"""
Capabilities to prune an existing RC model, i.e. try to cut reservoir nodes and improve
performance while reducing the reservoir size
"""
import numpy as np
from pyreco.custom_models import RC
from pyreco.node_selector import NodeSelector
import math
from typing import Union
import copy
from pyreco.graph_analyzer import GraphAnalyzer
from pyreco.node_analyzer import NodeAnalyzer
class NetworkPruner:
"""Implements pruning of reservoir nodes for pyreco RC models."""
def __init__(
self,
target_score: float = None,
stop_at_minimum: bool = True,
min_num_nodes: int = 3,
patience: int = 0,
candidate_fraction: float = 0.1,
remove_isolated_nodes: bool = False,
criterion: str = "mse",
metrics: Union[list, str] = ["mse"],
maintain_spectral_radius: bool = False,
node_props_extractor=None,
graph_props_extractor=None,
return_best_model: bool = True,
graph_analyzer: GraphAnalyzer = None,
node_analyzer: NodeAnalyzer = None,
):
"""
Initializer for the pruning class.
Parameters:
- target_score (float): The test set score that the user aims at. Pruning stops
once this score is reached.
- stop_at_minimum (bool): Whether to stop at the local minimum of the test set
score. When set to False, pruning continues until the minimal number of nodes
in <min_num_nodes>.
- min_num_nodes (int): Stop pruning when arriving at this number of nodes.
Conflicts if stop_at_minimum is set to True but also a min_num_nodes is given.
- patience (int): We allow a patience, i.e. keep pruning after we reached a
(local) minimum of the test set score. Depends on the size of the original
reservoir network, defaults to 10% of initial reservoir nodes.
- candidate_fraction (float): number of randomly chosen reservoir nodes during
every pruning iteration that is a candidate for pruning. Refers to the fraction of nodes w.r.t. current number of nodes during pruning iteration.
- remove_isolated_nodes (bool): Whether to remove isolated nodes during pruning.
- criterion (str): The criterion to be used for steering the node pruning. Default is "mse".
- metrics (list or str): The metrics to be used for evaluating the pruned model. Default is ["mse"].
- maintain_spectral_radius (bool): Whether to maintain the spectral radius of the reservoir layer during pruning.
"""
# Sanity checks for the input parameter types and values
if target_score is not None and not isinstance(target_score, float):
raise TypeError("target_score must be a float")
if not isinstance(stop_at_minimum, bool):
raise TypeError("stop_at_minimum must be a boolean")
if not isinstance(min_num_nodes, int):
raise TypeError("min_num_nodes must be an integer")
if min_num_nodes <= 2:
raise ValueError("min_num_nodes must be larger than 2")
if patience is not None and not isinstance(patience, int):
raise TypeError("patience must be an integer")
if not isinstance(candidate_fraction, float):
raise TypeError("candidate_fraction must be a float in (0, 1]")
if candidate_fraction <= 0 or candidate_fraction > 1:
raise ValueError("candidate_fraction must be a float in (0, 1]")
if not isinstance(criterion, str):
raise TypeError("criterion must be a string")
if graph_analyzer is not None and not isinstance(graph_analyzer, GraphAnalyzer):
raise TypeError("graph_analyzer must be an instance of GraphAnalyzer")
if graph_analyzer is None:
graph_analyzer = GraphAnalyzer()
if node_analyzer is not None and not isinstance(node_analyzer, NodeAnalyzer):
raise TypeError("node_analyzer must be an instance of NodeAnalyzer")
if node_analyzer is None:
node_analyzer = NodeAnalyzer()
# Assigning the parameters to instance variables
self.target_score = target_score if target_score is not None else 0.0
self.criterion = criterion
self.stop_at_minimum = stop_at_minimum
self.min_num_nodes = min_num_nodes
self.patience = patience
self.candidate_fraction = candidate_fraction
self.metrics = metrics
self.return_best_model = return_best_model
self.graph_analyzer = graph_analyzer
self.node_analyzer = node_analyzer
# TODO: the following two options are not implemented yet
self.remove_isolated_nodes = remove_isolated_nodes
self.maintain_spectral_radius = maintain_spectral_radius
# store the history of the pruning process in a nested dictionary
self.history = {}
# initialize attributes that will be used during pruning (and changed during the process)
# needs to be attributes as the history updates depend on them
self._curr_loss = None
self._curr_num_nodes = None
self._curr_loss_history = []
self._curr_idx_prune = None
self._patience_counter = 0
self._curr_metrics = None
def prune(self, model: RC, data_train: tuple, data_val: tuple):
"""
Prune a given model by removing nodes.
Parameters:
- model (RC): The reservoir computer model to prune.
- data_train (tuple): Training data.
- data_val (tuple): Validation data.
"""
# Sanity checks for the input parameter types and values
if not isinstance(model, RC):
raise TypeError("model must be an instance of RC")
if not isinstance(data_train, tuple) or not isinstance(data_val, tuple):
raise TypeError("data_train and data_val must be tuples")
if len(data_train) != 2 or len(data_val) != 2:
raise ValueError("data_train and data_val must have 2 elements each")
for idx, elem in enumerate(data_train):
    if not isinstance(elem, (list, np.ndarray)):
        raise TypeError(f"data_train[{idx}] must be a list or numpy array")
for idx, elem in enumerate(data_val):
    if not isinstance(elem, (list, np.ndarray)):
        raise TypeError(f"data_val[{idx}] must be a list or numpy array")
if len(data_train[0]) != len(data_train[1]):
raise ValueError(
"data_train[0] and data_train[1] must have the same length, "
"i.e. same number of samples"
)
if len(data_val[0]) != len(data_val[1]):
raise ValueError(
"data_val[0] and data_val[1] must have the same length, "
"i.e. same number of samples"
)
# unpack training and validation data (the validation data serves as the
# test set during pruning)
x_test, y_test = data_val[0], data_val[1]
x_train, y_train = data_train[0], data_train[1]
# Assign parameters to instance variables that cannot be set in the
# initializer, as they depend on the model and data
self._curr_num_nodes = model.reservoir_layer.nodes
# initialize the quantities that affect the stop condition
self._curr_loss = model.evaluate(x=x_test, y=y_test, metrics=self.criterion)[0]
self._curr_loss_history = [self._curr_loss]
# initialize quantities that we track for the pruning history
# these do not affect the pruning process
self._curr_metrics = model.evaluate(x=x_test, y=y_test, metrics=self.metrics)
# store all pruned models during the pruning iterations. This allows
# recovering models from earlier iterations, e.g. when the best model is
# not the last one (positive patience values)
_pruned_models = [copy.deepcopy(model)]
# initialize the pruning iterator
self._iter_count = 0
# Store all relevant information during pruning inside self.history
self.add_val_to_history(["loss"], self._curr_loss)
self.add_val_to_history(["metrics"], self._curr_metrics)
self.add_val_to_history(["num_nodes"], self._curr_num_nodes)
self.add_val_to_history(["iteration"], self._iter_count)
_graph = model.reservoir_layer.weights
_graph_props = self.graph_analyzer.extract_properties(graph=_graph)
self.add_dict_to_history(["graph_props"], _graph_props)
while True:  # termination is handled by self._keep_pruning() below
print(f"pruning iteration {self._iter_count}")
print(
f"current reservoir size: {self._curr_num_nodes}, current loss: {self._curr_loss:.8f}"
)
# propose candidate nodes for pruning, drawn uniformly at random without
# replacement. With candidate_fraction=1.0, all nodes are tried
_num_nodes_to_prune = math.ceil(
self.candidate_fraction * self._curr_num_nodes
)
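# e.g. candidate_fraction=0.1 on a 50-node reservoir proposes
# math.ceil(0.1 * 50) = 5 candidate nodes per iteration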
selector = NodeSelector(
total_nodes=self._curr_num_nodes, strategy="random_uniform_wo_repl"
)
# obtain nodes that are proposed for pruning
_curr_candidate_nodes = selector.select_nodes(num=_num_nodes_to_prune)
print(
f"propose {_num_nodes_to_prune}/{self._curr_num_nodes} nodes for pruning"
)
# track the performance of the RC with the candidate nodes removed
_candidate_scores = []
_candidate_models = []
_cand_node_props = []  # properties of the to-be-removed node
_cand_node_input_receiving = []  # whether the node receives input-layer connections
_cand_node_output_sending = []  # whether the node sends to the readout layer
_cand_graph_props_before = []  # properties of the graph before pruning
_cand_graph_props_after = []  # properties of the graph after pruning
# iterate over the candidate nodes: delete one-by-one, measure performance,
# and also track node/graph-level properties
for node in _curr_candidate_nodes:
# get a copy of the original model to try out the deletion
_model = copy.deepcopy(model)
# extract information about the node that we will prune,
# and about the graph before we prune it
_graph = _model.reservoir_layer.weights
_node_props = self.node_analyzer.extract_properties(
graph=_graph, node=node
)
_graph_props = self.graph_analyzer.extract_properties(graph=_graph)
# check whether the current node links to the input and readout layers
_is_input_receiving = (
node in _model.reservoir_layer.input_receiving_nodes
)
_is_output_sending = node in _model.readout_layer.readout_nodes
_cand_node_props.append(_node_props)
_cand_graph_props_before.append(_graph_props)
_cand_node_input_receiving.append(_is_input_receiving)
_cand_node_output_sending.append(_is_output_sending)
# remove current candidate node
_model.remove_reservoir_nodes(nodes=[node])
# TODO: remove isolated nodes using utility function from utils_networks
# if self.remove_isolated_nodes:
# iso_nodes = ...
# _model.remove_reservoir_nodes(nodes=[iso_nodes])
# TODO: maintain the spectral radius of the reservoir layer
# if self.maintain_spectral_radius:
# spec_rad = model.reservoir_layer.spectral_radius
# _model.set_spec_rad(spec_rad)
# pruning requires re-fitting the model
_model.fit(x=x_train, y=y_train)
# evaluate the pruned model
_score = _model.evaluate(x=x_test, y=y_test, metrics=self.criterion)[0]
# extract graph properties after pruning
_graph = _model.reservoir_layer.weights
_graph_props = self.graph_analyzer.extract_properties(graph=_graph)
_cand_graph_props_after.append(_graph_props)
print(
f"deletion of candidate node {node}. loss: \t{_score:.6f} ({(self._curr_loss-_score)/self._curr_loss:+.3%})"
)
# store the relevant candidate information
_candidate_scores.append(_score)
_candidate_models.append(_model)
# delete temporary variables (just for safety)
del _model, _score, _graph, _graph_props, _node_props
# store the candidate properties in the history object
self.add_val_to_history(["candidate_scores"], _candidate_scores)
self.add_val_to_history(["candidate_nodes"], _curr_candidate_nodes)
self.add_val_to_history(
    ["candidate_node_props"], dictlist_to_dict(_cand_node_props)
)
self.add_val_to_history(
    ["candidate_graph_props_before"], dictlist_to_dict(_cand_graph_props_before)
)
self.add_val_to_history(
    ["candidate_graph_props_after"], dictlist_to_dict(_cand_graph_props_after)
)
# after trying out all candidate nodes, we need to select the node to prune,
# i.e. the one that has the smallest loss among all candidate nodes
idx_prune = np.argmin(_candidate_scores)
self._curr_idx_prune = idx_prune # just for history logging
# update the termination relevant quantities,
# assuming that we will prune that node
self._curr_loss = _candidate_scores[idx_prune]
self._curr_num_nodes = _candidate_models[idx_prune].reservoir_layer.nodes
self._curr_loss_history.append(self._curr_loss)
# check whether actually pruning this node would violate the termination
# criteria (checking this here is not an optimal design, but keeps the loop simple)
if not self._keep_pruning():
# exit the pruning loop
break
print(f"pruning node {idx_prune}, resulting in loss {self._curr_loss:.6f}")
print(
f"loss improvement by {((self._curr_loss_history[-2]-self._curr_loss)/self._curr_loss_history[-2]):+.3%}\n"
)
# prune the node that gives the smallest performance drop. As we have
# already pruned the node and stored the resulting model, we only need to
# swap in that model, which saves re-training and re-running the pruning logic
model = _candidate_models[idx_prune]
# store the model for later use
_pruned_models.append(copy.deepcopy(model))
# compute things that are required for the history, but not for the
# pruning loop termination criteria
self._curr_metrics = model.evaluate(
x=x_test, y=y_test, metrics=self.metrics
)
self.add_val_to_history(["loss"], self._curr_loss)
self.add_val_to_history(["metrics"], self._curr_metrics)
self.add_val_to_history(["num_nodes"], self._curr_num_nodes)
self.add_val_to_history(["idx_prune"], self._curr_idx_prune)
self.add_val_to_history(["iteration"], self._iter_count)
self.add_val_to_history(
["del_node_props", "input_receiving_node"],
_cand_node_input_receiving[idx_prune],
)
self.add_val_to_history(
["del_node_props", "output_sending_node"],
_cand_node_output_sending[idx_prune],
)
self.add_dict_to_history(["del_node_props"], _cand_node_props[idx_prune])
self.add_dict_to_history(
["graph_props"], _cand_graph_props_after[idx_prune]
)
# update counter
self._iter_count += 1
# with a non-zero patience, the best model is generally not the last one,
# so optionally recover the model with the lowest loss
if self.return_best_model:
# exclude the last loss entry: it belongs to the candidate that triggered
# termination and was never stored in _pruned_models
idx_best = np.argmin(self._curr_loss_history[:-1])
model = copy.deepcopy(_pruned_models[idx_best])
print(f"returning model from iteration {idx_best}, i.e. the one with the lowest loss")
# re-fit the final model and evaluate it
model.fit(x=x_train, y=y_train)
final_loss = model.evaluate(x=x_test, y=y_test, metrics=self.criterion)[0]
final_metrics = model.evaluate(x=x_test, y=y_test, metrics=self.metrics)
print(
    f"\ninitial loss: {self._curr_loss_history[0]:.6f}, loss after pruning: {final_loss:.6f}"
)
print(f"final model has {model.reservoir_layer.nodes} nodes")
print(f"final model loss {self.criterion}: {final_loss:.6f}")
print(f"final model metrics ({self.metrics}): {final_metrics}")
return model, self.history
def _keep_pruning(self):
    # Termination criteria for the pruning process.
    # Keep pruning as long as all of the following conditions are met:
    # 1. The target score has not been reached yet
    # 2. The current number of nodes is above the minimum number of nodes
    # 3. The loss is not at a (local) minimum, taking patience into account
    return (
        self._loss_not_met()
        and self._num_nodes_not_met()
        and self._at_minimum_not_met()
    )
def _loss_not_met(self):
# Checks the stopping condition based on the current and target loss
# returns True if the criterion is not met, i.e. we should continue pruning
if self._curr_loss >= self.target_score:
    print(
        f"Loss {self._curr_loss:.6f} is larger than the target score {self.target_score:.6f}. Continuing pruning"
    )
    return True
else:
    print(
        f"Loss {self._curr_loss:.6f} is smaller than the target score {self.target_score:.6f}. Terminating pruning"
    )
    return False
def _num_nodes_not_met(self):
# checks if the number of nodes is above the minimum number of nodes
# returns True if number of nodes is above minimum, i.e. we should continue pruning
if self._curr_num_nodes > self.min_num_nodes:
    print(
        f"Number of nodes {self._curr_num_nodes} is larger than the minimum number of nodes {self.min_num_nodes}. Continuing pruning"
    )
    return True
else:
    print(
        f"Number of nodes {self._curr_num_nodes} is smaller than or equal to the minimum number of nodes {self.min_num_nodes}. Terminating pruning"
    )
    return False
def _at_minimum_not_met(self):
# checks if the loss is at a minimum,
# considering also patience.
# returns True if loss is not at minimum, i.e. we should continue pruning
if len(self._curr_loss_history) < 2:
# we are just at the start of pruning, cannot
# check for a minimum.
return True
if self.stop_at_minimum:
if self._curr_loss_history[-2] > self._curr_loss_history[-1]:
print(
f"Loss decreased from {self._curr_loss_history[-2]:.6f} to {self._curr_loss_history[-1]:.6f}. Continuing pruning"
)
self._patience_counter = 0
return True
else:  # current loss is larger than or equal to the previous one
    self._patience_counter += 1
    if self._patience_counter < self.patience:
        print(
            f"Loss increased, but patience counter {self._patience_counter} < {self.patience}. Continuing pruning"
        )
return True
else:
# recovering the best-scoring model is handled via return_best_model in prune()
print(
f"Loss increased for {self.patience} consecutive iterations. Terminating pruning"
)
return False
else:
return True
def add_val_to_history(self, keys, value):
"""
Add a value to history dictionary based on a list of keys.
Args:
keys (list): A list of keys specifying the path in the nested dictionary.
value: The value to add.
"""
if len(keys) == 1:
if keys[0] not in self.history:
self.history[keys[0]] = []
self.history[keys[0]].append(value)
elif len(keys) == 2:
if keys[0] not in self.history:
self.history[keys[0]] = {}
if keys[1] not in self.history[keys[0]]:
self.history[keys[0]][keys[1]] = []
self.history[keys[0]][keys[1]].append(value)
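# Example: add_val_to_history(["del_node_props", "input_receiving_node"], True)
# appends True to self.history["del_node_props"]["input_receiving_node"]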
def add_dict_to_history(self, keys, value_dict):
    """
    Add a dictionary to the history dictionary based on a list of keys.

    Args:
        keys (list): A list of keys specifying the path in the nested dictionary.
        value_dict (dict): The dictionary whose values to append.
    """
    # walk down the nested dict with a local cursor. Do not rebind
    # self.history itself, which would discard all previously stored entries
    _hist = self.history
    for key in keys[:-1]:
        if key not in _hist:
            _hist[key] = {}
        _hist = _hist[key]
    if keys[-1] not in _hist:
        _hist[keys[-1]] = {}
    for k, v in value_dict.items():
        if k not in _hist[keys[-1]]:
            _hist[keys[-1]][k] = []
        _hist[keys[-1]][k].append(v)
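# Example (property names are illustrative):
# add_dict_to_history(["graph_props"], {"density": 0.1, "clustering": 0.3})
# appends 0.1 to self.history["graph_props"]["density"] and 0.3 to
# self.history["graph_props"]["clustering"]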
def dictlist_to_dict(dict_list):
"""
Join dictionaries in a list into a common dictionary.
Args:
dict_list (list): A list of dictionaries to join.
Returns:
dict: A common dictionary containing all key-value pairs from the dictionaries in the list.
"""
common_dict = {}
for d in dict_list:
for key, value in d.items():
if key in common_dict:
if isinstance(common_dict[key], list):
common_dict[key].append(value)
else:
common_dict[key] = [common_dict[key], value]
else:
common_dict[key] = value
return common_dict
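# Example: dictlist_to_dict([{"a": 1}, {"a": 2, "b": 3}])
# returns {"a": [1, 2], "b": 3}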
def append_to_dict(dict1, dict2):
    # appends the entries of dict1 to the existing (list-valued) entries of
    # dict2; keys of dict1 that are missing in dict2 are silently ignored
    for key in dict1.keys():
        if key in dict2.keys():
            dict2[key].append(dict1[key])
    return dict2
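# Example: append_to_dict({"a": 1}, {"a": [0]}) returns {"a": [0, 1]}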
if __name__ == "__main__":
# test the pruning
from pyreco.utils_data import sequence_to_sequence as seq_2_seq
from pyreco.custom_models import RC
from pyreco.layers import InputLayer, ReadoutLayer
from pyreco.layers import RandomReservoirLayer
from pyreco.optimizers import RidgeSK
# get some data
X_train, X_test, y_train, y_test = seq_2_seq(
name="sine_pred", n_batch=20, n_states=2, n_time=150
)
input_shape = X_train.shape[1:]
output_shape = y_train.shape[1:]
# build a classical RC
model = RC()
model.add(InputLayer(input_shape=input_shape))
model.add(
RandomReservoirLayer(
nodes=50,
density=0.1,
activation="tanh",
leakage_rate=0.1,
fraction_input=0.5,
),
)
model.add(ReadoutLayer(output_shape, fraction_out=0.9))
# Compile the model
optim = RidgeSK(alpha=0.5)
model.compile(
optimizer=optim,
metrics=["mean_squared_error"],
)
# Train the model
model.fit(X_train, y_train)
print(f"score: \t\t\t{model.evaluate(x=X_test, y=y_test)[0]:.4f}")
# prune the model
pruner = NetworkPruner(
stop_at_minimum=False,
min_num_nodes=46,
patience=2,
candidate_fraction=0.9,
remove_isolated_nodes=False,
metrics=["mse"],
maintain_spectral_radius=False,
)
model_pruned, history = pruner.prune(
model=model, data_train=(X_train, y_train), data_val=(X_test, y_test)
)
import matplotlib.pyplot as plt
plt.figure()
plt.subplot(1, 2, 1)
plt.plot(history["num_nodes"], history["loss"], label="loss")
plt.xlabel("number of nodes")
plt.ylabel("loss")
plt.subplot(1, 2, 2)
for key in history["graph_props"].keys():
plt.plot(history["num_nodes"], history["graph_props"][key], label=key)
plt.xlabel("number of nodes")
plt.yscale("log")
plt.legend()
plt.show()