"""
Plotting utilities: prediction-vs-ground-truth scatter plots and reservoir-network architecture diagrams.
"""
import numpy as np
from matplotlib import pyplot as plt
import networkx as nx
from .metrics import r2
[docs]
def r2_scatter(y_true: np.ndarray, y_pred: np.ndarray, state_idx: int|tuple= None,
title:str = None, xlabel:str=None, ylabel:str = None):
# plots predictions against ground truth values as scatter plot
# lets the user choose the output state to plot (if there are multiple states). If not provided, all states will
# be plotted.
# collapses the data along the time dimension to show time-dependent targets
# expects arguments to be of 3D shape: [n_batch, n_timesteps, n_states]
if y_true.ndim != y_pred.ndim:
raise(ValueError('Inconsistent shapes! y_true and y_pred need to have the same shape!'))
if (state_idx is not None) and (np.max(state_idx) >= y_true.ndim):
raise(ValueError(f'Please select a valid state index, maximum being {y_true.ndim} for the given data'))
# select the states to show in case supplied by the user
if state_idx is not None:
y_true = y_true[:, :, state_idx]
y_pred = y_pred[:, :, state_idx]
# now flatten into a vector (along states and along time if necessary)
y_true = y_true.flatten()
y_pred = y_pred.flatten()
min_val = np.min(y_true) # np.min([np.min(y_true), np.min(y_pred)])
max_val = np.max(y_true) # np.max([np.max(y_true), np.max(y_pred)])
fig = plt.figure()
plt.plot([min_val, max_val], [min_val, max_val],
linestyle='solid',
color='gray',
label='perfect model')
plt.plot(y_true, y_pred,
linestyle='none',
marker='.',
markersize=5,
color='black',
label=rf'model predictions ($R^2=${r2(y_true, y_pred):.2f})')
if xlabel is None:
plt.xlabel('ground truth')
else:
plt.xlabel(xlabel)
if ylabel is None:
plt.ylabel('predictions')
else:
plt.ylabel(ylabel)
plt.legend()
if title is None:
plt.title(rf'')
else:
plt.title(title)
plt.tight_layout()
plt.show()
[docs]
def visualize_reservoir_network_circle(G_Net, W_inp, W_out, n_inputs, n_outputs,
                                       seed=None, save_path=None, Node_colors=None,
                                       Edge_Weights=None, weight_threshold=0.01):
    """
    Draw a circular-layout diagram of an ESN / reservoir-computing architecture.

    Input nodes sit in a column at x = -3, output nodes in a column at x = 3, and
    reservoir nodes at random positions inside a circle of radius 2. Reservoir
    nodes are colored and labelled by role: connected to an input ('Inp'), to an
    output ('Out'), to both ('I+O'), or to neither ('Int').

    Parameters
    ----------
    G_Net : array-like, shape (n_reservoir, n_reservoir)
        Reservoir weight matrix; entry [i, j] is drawn as edge R{i} -> R{j}.
    W_inp : array
        Input weights. Accepted shapes:
        - (n_reservoir,) or (n_reservoir, 1): a single input;
        - (n_in, n_reservoir): one row per input;
        - (n_reservoir, n_in): transposed automatically when unambiguous.
    W_out : array
        Readout weights. Accepted shapes:
        - (n_reservoir,): a single output;
        - (n_reservoir, n_out): one column per output;
        - (n_out, n_reservoir): transposed automatically when unambiguous.
    n_inputs, n_outputs : int
        Number of input / output nodes to draw. Must be at least the number of
        inputs / outputs present in ``W_inp`` / ``W_out``.
    seed : int, optional
        Seed for the reservoir-node layout. A private ``RandomState`` is used, so
        NumPy's global RNG is not reseeded.
    save_path : str or path-like, optional
        If given, the figure is saved there at 300 dpi before being shown.
    Node_colors : dict, optional
        Color overrides. Missing keys fall back to ``_DEFAULT_NODE_COLORS``.
        - Edge colors: 'CWinp' (input edges), 'CWres' (recurrent edges),
          'CWout' (readout edges).
        - Input / output node colors: 'Winp', 'Wout'.
        - Reservoir node colors by role: 'CWres_inp', 'CWres_out',
          'CWres_both', 'CWres_internal'.
    Edge_Weights : float or sequence of float, optional
        Edge line width(s) forwarded to ``nx.draw_networkx_edges``. A sequence
        must follow ``G.edges()`` order. Defaults to 1.0 when ``None``.
    weight_threshold : float, default 0.01
        Edges with ``abs(weight) <= weight_threshold`` are not drawn and do not
        count as connections when assigning reservoir-node roles.

    Raises
    ------
    ValueError
        If a weight array is not 1D/2D, or references more reservoir nodes than
        ``G_Net`` has or more inputs/outputs than ``n_inputs``/``n_outputs``.
    """
    colors = {**_DEFAULT_NODE_COLORS, **(Node_colors or {})}
    n_reservoir = G_Net.shape[0]

    # --------------------------------------------------
    # Validate weight shapes before drawing anything
    # --------------------------------------------------
    W_in_mat = _orient_weights(W_inp, n_reservoir, reservoir_axis=1)   # (n_in, n_reservoir)
    W_out_mat = _orient_weights(W_out, n_reservoir, reservoir_axis=0)  # (n_reservoir, n_out)
    # Edges to nodes that were never created would become untyped phantom nodes
    # and crash later, so reject such shapes with a clear message instead.
    if W_in_mat.shape[1] > n_reservoir or W_in_mat.shape[0] > n_inputs:
        raise ValueError(
            f'W_inp with shape {W_inp.shape} does not fit {n_inputs} input(s) '
            f'and {n_reservoir} reservoir node(s)')
    if W_out_mat.shape[0] > n_reservoir or W_out_mat.shape[1] > n_outputs:
        raise ValueError(
            f'W_out with shape {W_out.shape} does not fit {n_reservoir} reservoir '
            f'node(s) and {n_outputs} output(s)')

    # A private RandomState yields exactly the stream np.random.seed(seed) would,
    # without reseeding NumPy's global RNG as a side effect of plotting.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # --------------------------------------------------
    # Create nodes and edges
    # --------------------------------------------------
    G = nx.DiGraph()
    input_nodes = [f'I{i}' for i in range(n_inputs)]
    reservoir_nodes = [f'R{i}' for i in range(n_reservoir)]
    output_nodes = [f'O{i}' for i in range(n_outputs)]
    G.add_nodes_from(input_nodes, type='input')
    G.add_nodes_from(reservoir_nodes, type='reservoir')
    G.add_nodes_from(output_nodes, type='output')

    input_pairs = _add_matrix_edges(G, W_in_mat, 'I', 'R', 'input', weight_threshold)
    _add_matrix_edges(G, G_Net, 'R', 'R', 'recurrent', weight_threshold)
    output_pairs = _add_matrix_edges(G, W_out_mat, 'R', 'O', 'output', weight_threshold)
    input_connections = {f'R{j}' for _, j in input_pairs}
    output_connections = {f'R{i}' for i, _ in output_pairs}

    # --------------------------------------------------
    # Reservoir node roles (colors + labels)
    # --------------------------------------------------
    roles = {
        (True, True): ('CWres_both', 'I+O'),
        (True, False): ('CWres_inp', 'Inp'),
        (False, True): ('CWres_out', 'Out'),
        (False, False): ('CWres_internal', 'Int'),
    }
    reservoir_colors = {}
    reservoir_labels = {}
    for node in reservoir_nodes:
        color_key, label = roles[(node in input_connections, node in output_connections)]
        reservoir_colors[node] = colors[color_key]
        reservoir_labels[node] = label

    # --------------------------------------------------
    # Node positions
    # --------------------------------------------------
    R = 2
    pos = {}
    for node in reservoir_nodes:
        theta = rng.uniform(0, 2 * np.pi)
        # sqrt makes the points uniformly dense over the disc area
        r = R * np.sqrt(rng.uniform(0, 1))
        pos[node] = (r * np.cos(theta), r * np.sin(theta))
    pos.update(_column_positions(input_nodes, x=-3))
    pos.update(_column_positions(output_nodes, x=3))

    # --------------------------------------------------
    # Plotting
    # --------------------------------------------------
    plt.figure(figsize=(14, 10))
    plt.gca().add_patch(
        plt.Circle((0, 0), R, color='Black', fill=False, linestyle='-', linewidth=1, alpha=0.5)
    )

    io_colors = {'input': colors['Winp'], 'output': colors['Wout']}
    node_colors = [
        reservoir_colors[node] if data['type'] == 'reservoir' else io_colors[data['type']]
        for node, data in G.nodes(data=True)
    ]
    nx.draw_networkx_nodes(G, pos, node_size=150, node_color=node_colors)
    nx.draw_networkx_labels(
        G,
        {k: v for k, v in pos.items() if k.startswith('R')},
        labels=reservoir_labels,
        font_size=6, font_weight='bold',
    )

    edge_type_colors = {
        'input': colors['CWinp'],
        'recurrent': colors['CWres'],
        'output': colors['CWout'],
    }
    edge_colors = [edge_type_colors[d['type']] for _, _, d in G.edges(data=True)]
    # networkx's own default width is 1.0; pass it explicitly rather than None
    nx.draw_networkx_edges(
        G, pos, arrows=True, arrowsize=5,
        edge_color=edge_colors,
        width=1.0 if Edge_Weights is None else Edge_Weights,
    )

    # Layer labels
    plt.text(-2.5, -1.5, "Input Layer", fontsize=14, ha='center', color='red')
    plt.text(0, 2.1, "Reservoir Layer", fontsize=14, ha='center', color='blue')
    plt.text(2.5, -1.5, "Output Layer", fontsize=14, ha='center', color='green')
    plt.text(-3, -2.25, "W_inp", fontsize=14, ha='center', color='red')
    plt.text(0, -2.3, "W_res", fontsize=14, ha='center', color='blue')
    plt.text(3, -2.25, "W_out", fontsize=14, ha='center', color='green')

    # Legend reflects the colors actually used for reservoir roles
    legend_entries = [
        (colors['CWres_both'], 'Input + Output'),
        (colors['CWres_inp'], 'Input Only'),
        (colors['CWres_out'], 'Output Only'),
        (colors['CWres_internal'], 'Internal Only'),
    ]
    legend_elements = [
        plt.Line2D([0], [0], marker='o', markerfacecolor=face, color='w', markersize=10, label=text)
        for face, text in legend_entries
    ]
    plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1, 1), fontsize=10)
    plt.title("Python Reservoir Computing Architecture - PyReCo Architecture")
    plt.axis('off')
    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# Fallback colors for visualize_reservoir_network_circle. The reservoir-role
# colors match the previously hard-coded legend. The edge and I/O node colors
# match the red/blue/green layer captions.
_DEFAULT_NODE_COLORS = {
    'CWinp': 'red',             # input -> reservoir edges
    'CWres': 'blue',            # reservoir -> reservoir edges
    'CWout': 'green',           # reservoir -> output edges
    'Winp': 'red',              # input nodes
    'Wout': 'green',            # output nodes
    'CWres_both': 'orange',     # reservoir nodes fed by an input and read by an output
    'CWres_inp': 'lightcoral',  # reservoir nodes fed by an input only
    'CWres_out': 'lightgreen',  # reservoir nodes read by an output only
    'CWres_internal': 'lightblue',  # reservoir nodes with neither
}


def _orient_weights(W, n_reservoir, reservoir_axis):
    """Return W as a 2D array whose axis ``reservoir_axis`` indexes reservoir nodes."""
    if W.ndim == 1:
        # a 1D vector describes a single input (row) or a single output (column)
        W = W.reshape((1, -1)) if reservoir_axis == 1 else W.reshape((-1, 1))
    elif W.ndim != 2:
        raise ValueError(f'Expected a 1D or 2D weight array, got shape {W.shape}')
    other_axis = 1 - reservoir_axis
    # transpose only when unambiguous: the reservoir size matches the other axis but not the expected one
    if W.shape[reservoir_axis] != n_reservoir and W.shape[other_axis] == n_reservoir:
        W = W.T
    return W


def _add_matrix_edges(G, W, src_prefix, dst_prefix, edge_type, threshold):
    """Add edge src{i} -> dst{j} for every |W[i, j]| > threshold; return the (i, j) pairs added."""
    added = []
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            w = W[i, j]
            if abs(w) > threshold:
                G.add_edge(f'{src_prefix}{i}', f'{dst_prefix}{j}', weight=w, type=edge_type)
                added.append((i, j))
    return added


def _column_positions(nodes, x, spacing=1.5):
    """Place ``nodes`` in a vertical column at abscissa ``x``, centred on y = 0."""
    mid = (len(nodes) - 1) * spacing / 2
    return {node: (x, k * spacing - mid) for k, node in enumerate(nodes)}