"""
Some plotting capabilities
"""
import numpy as np
from matplotlib import pyplot as plt
from .metrics import r2
[docs]
def r2_scatter(y_true: np.ndarray, y_pred: np.ndarray, state_idx: int|tuple= None,
title:str = None, xlabel:str=None, ylabel:str = None):
# plots predictions against ground truth values as scatter plot
# lets the user choose the output state to plot (if there are multiple states). If not provided, all states will
# be plotted.
# collapses the data along the time dimension to show time-dependent targets
# expects arguments to be of 3D shape: [n_batch, n_timesteps, n_states]
if y_true.ndim != y_pred.ndim:
raise(ValueError('Inconsistent shapes! y_true and y_pred need to have the same shape!'))
if (state_idx is not None) and (np.max(state_idx) >= y_true.ndim):
raise(ValueError(f'Please select a valid state index, maximum being {y_true.ndim} for the given data'))
# select the states to show in case supplied by the user
if state_idx is not None:
y_true = y_true[:, :, state_idx]
y_pred = y_pred[:, :, state_idx]
# now flatten into a vector (along states and along time if necessary)
y_true = y_true.flatten()
y_pred = y_pred.flatten()
min_val = np.min(y_true) # np.min([np.min(y_true), np.min(y_pred)])
max_val = np.max(y_true) # np.max([np.max(y_true), np.max(y_pred)])
fig = plt.figure()
plt.plot([min_val, max_val], [min_val, max_val],
linestyle='solid',
color='gray',
label='perfect model')
plt.plot(y_true, y_pred,
linestyle='none',
marker='.',
markersize=5,
color='black',
label=rf'model predictions ($R^2=${r2(y_true, y_pred):.2f})')
if xlabel is None:
plt.xlabel('ground truth')
else:
plt.xlabel(xlabel)
if ylabel is None:
plt.ylabel('predictions')
else:
plt.ylabel(ylabel)
plt.legend()
if title is None:
plt.title(rf'')
else:
plt.title(title)
plt.tight_layout()
plt.show()