Source code for pyreco.edge_pruning

import numpy as np
import networkx as nx
import scipy.sparse as sp
from joblib import Parallel, delayed
from tqdm import tqdm
from pyreco.custom_models import RC
from pyreco.edge_selector import EdgeSelector
from typing import Union
import copy
from pyreco.graph_analyzer import GraphAnalyzer
from pyreco.node_analyzer import NodeAnalyzer
from pyreco.edge_analyzer import EdgeAnalyzer, available_extractors
from pyreco.metrics import available_metrics


def _evaluate_candidate_standalone(model, candidate, x_train, y_train, x_test, y_test,
                                   criterion, graph_analyzer):
    """
    Evaluate a single candidate edge removal, independent of 'EdgePruner'.
    Standalone (module-level) counterpart to
    'EdgePruner._evaluate_candidate_performance', used for parallel execution so that
    'joblib' workers do not need to pickle the whole 'EdgePruner' instance
    (whose history grows over the course of pruning) for every candidate.

    Parameters
    ----------
    model : RC
        Current (not yet pruned) reservoir computer model.
    candidate : tuple
        Edge (u, v) to tentatively remove.
    x_train : np.ndarray
        Training input data.
    y_train : np.ndarray
        Training target data.
    x_test : np.ndarray
        Validation input data.
    y_test : np.ndarray
        Validation target data.
    criterion : str
        Performance metric used to score the candidate.
    graph_analyzer : GraphAnalyzer
        Analyzer used to extract graph-level properties after pruning.

    Returns
    -------
    score : float
        Loss of the refit model with 'candidate' removed, evaluated on
        (x_test, y_test) using 'criterion'.
    model : RC
        Deep copy of 'model' with 'candidate' removed and refit.
    graph_props_after : dict
        Graph-level properties of the reservoir after removing 'candidate'.
    """
    _model = copy.deepcopy(model)
    _model.remove_reservoir_edges(edges=[candidate])
    _model.fit(x=x_train, y=y_train)
    _score = _model.evaluate(x=x_test, y=y_test, metrics=criterion)[0]
    _graph = _model.reservoir_layer.weights
    _graph_props_after = graph_analyzer.extract_properties(graph=_graph)
    return _score, _model, _graph_props_after


class EdgePruner:
    """Pruning object that iteratively removes reservoir edges of pyreco models."""

    # Maps the user-facing pruning criterion name to the method implementing it
    PRUNING_CRITERION = {
        'performance': '_performance_pruning',
        'structure': '_structural_pruning'
    }

    # Maps the user-facing stopping criterion name to the method implementing it
    STOPPING_CRITERION = {
        'patience': '_patience_stopping',
        'min_nodes': '_min_num_nodes_stopping',
        'min_edges': '_min_num_edges_stopping'
    }

    def __init__(
            self,
            edge_selection_strat: str = 'random_uniform_wo_repl',
            candidate_fraction: float = 0.1,
            pruning_criterion: str = 'performance',
            stopping_criterion: list = ['patience'],
            min_num_nodes: int = 3,
            min_num_edges: int = 2,
            patience: int = 0,
            performance_criterion: str = 'mse',
            structural_criterion: str = 'betweenness',
            metrics: Union[list, str] = ['mse'],
            return_best_model: bool = True,
            graph_analyzer: GraphAnalyzer = None,
            node_analyzer: NodeAnalyzer = None,
            edge_analyzer: EdgeAnalyzer = None,
            remove_isolated_nodes: bool = False,
            directed: bool = True,
            parallel: bool = False,
    ):
        """
        Initializer for the edge pruning class.

        Parameters
        ----------
        edge_selection_strat : str, optional
            Strategy used by :class:`EdgeSelector` to propose candidate edges for
            pruning during every iteration. Default is "random_uniform_wo_repl".
        candidate_fraction : float, optional
            Fraction of the reservoir edges present in the current pruning
            iteration that is proposed as candidates for pruning. Must be in
            (0, 1]. Default is 0.1.
        pruning_criterion : str, optional
            The criterion used to score and select among candidate edges. Must be
            a key of ``PRUNING_CRITERION``. Default is "performance".
        stopping_criterion : list of str, optional
            One or more criteria used to decide when to stop pruning. Pruning is
            stopped as soon as one criterion is fulfilled. Each entry must be a
            key of ``STOPPING_CRITERION``. Default is ["patience"].
        min_num_nodes : int, optional
            Stop pruning when arriving at this number of nodes. Must be larger
            than 2. Default is 3.
        min_num_edges : int, optional
            Stop pruning when arriving at this number of edges. Default is 2.
        patience : int, optional
            Number of additional pruning iterations that are allowed after a
            (local) minimum of the validation loss has been reached.
            Default is 0.
        performance_criterion : str, optional
            Loss metric used to evaluate model performance, both for steering
            ``pruning_criterion='performance'`` and for tracking real loss
            (patience, history, ``return_best_model``) regardless of which
            ``pruning_criterion`` is active. Must be a key of
            :func:`pyreco.metrics.available_metrics`. Default is "mse".
        structural_criterion : str, optional
            Edge property used to score candidates when
            ``pruning_criterion='structure'``. Must be a key of
            :func:`pyreco.edge_analyzer.available_extractors` (e.g.
            "betweenness", "source_out_degree"). Unused for other pruning
            criteria. Default is "betweenness". Note: structure-based pruning
            is marked as still to be implemented.
        metrics : list or str, optional
            Additional metrics to track throughout the pruning history, without
            influencing pruning decisions. Each entry must be a key of
            :func:`pyreco.metrics.available_metrics`. Default is ["mse"].
        return_best_model : bool, optional
            Whether to return the model with the lowest recorded loss instead of
            the last model produced before stopping. Default is True.
        graph_analyzer : GraphAnalyzer, optional
            Analyzer used to extract graph-level properties. A default instance
            is created if not provided.
        node_analyzer : NodeAnalyzer, optional
            Analyzer used to extract node-level properties. A default instance
            is created if not provided.
        edge_analyzer : EdgeAnalyzer, optional
            Analyzer used to extract edge-level properties. A default instance
            is created if not provided.
        remove_isolated_nodes : bool, optional
            Whether to remove isolated nodes during pruning. Default is False.
        directed : bool, optional
            Whether the reservoir graph is treated as directed. Default is True.
        parallel : bool, optional
            Whether to parallelize candidate evaluation across CPU cores using
            ``joblib``. Default is False.

        Raises
        ------
        TypeError
            If any parameter has an invalid type.
        ValueError
            If any parameter has an invalid value.
        NotImplementedError
            If ``pruning_criterion`` or an entry of ``stopping_criterion`` is not
            a recognized strategy.
        """
        # Sanity checks for the input parameter types and values
        self._validate_init_params(
            edge_selection_strat,
            candidate_fraction,
            pruning_criterion,
            stopping_criterion,
            min_num_nodes,
            min_num_edges,
            patience,
            performance_criterion,
            structural_criterion,
            metrics,
            return_best_model,
            graph_analyzer,
            node_analyzer,
            edge_analyzer,
            remove_isolated_nodes,
            directed,
            parallel,
        )

        # If not given, create analyzer classes for the properties to store
        if graph_analyzer is None:
            graph_analyzer = GraphAnalyzer()
        if node_analyzer is None:
            node_analyzer = NodeAnalyzer()
        if edge_analyzer is None:
            edge_analyzer = EdgeAnalyzer()

        # Parameters for pruning criterion
        self.performance_criterion = performance_criterion
        self.structural_criterion = structural_criterion
        self.pruning_criterion = pruning_criterion

        # Parameters for stopping criterion
        # Shallow copy: the signature default is a single list object shared by
        # every call, so storing it directly would let a mutation on one
        # instance leak into the defaults of all later instances.
        self.stopping_criterion = copy.copy(stopping_criterion)
        self.min_num_nodes = min_num_nodes
        self.min_num_edges = min_num_edges
        self.patience = patience

        # Parameters for candidate selection
        self.candidate_fraction = candidate_fraction
        self.edge_selection_strat = edge_selection_strat
        self.directed = directed

        # Parameters for tracking metrics
        # Shallow copy for the same reason as above (a str is returned as-is)
        self.metrics = copy.copy(metrics)
        self.graph_analyzer = graph_analyzer
        self.node_analyzer = node_analyzer
        self.edge_analyzer = edge_analyzer

        # Parameter bools for extra functionalities
        self.return_best_model = return_best_model
        self.remove_isolated_nodes = remove_isolated_nodes
        self.parallel = parallel

        # History of the pruning process, stored as a nested dictionary
        self.history = {}

        # Attributes that are used (and changed) during pruning. They need to be
        # attributes as the history updates depend on them.
        self._curr_model = None
        self._curr_loss = None
        self._curr_num_nodes = None
        self._curr_num_edges = None
        self._curr_loss_history = []
        self._best_loss = None
        self._idx_prune = None
        self._patience_counter = 0
        self._curr_metrics = None

        # Also written by ``prune``; declared here so that every piece of
        # pruning state exists right after construction.
        self._iter_count = 0
        self._curr_idx_prune = None
def prune(self, model: RC, data_train: tuple, data_val: tuple):
    """
    Prune a given model by iteratively removing edges.

    Every iteration proposes candidate edges, scores them with the configured
    pruning criterion, tentatively removes the best-scoring one and then checks
    the stopping criteria. All per-run state (history, best loss, patience
    counter) is reset at the start, so one pruner can be used for several runs.

    Parameters
    ----------
    model : RC
        Reservoir computer model to prune.
    data_train : tuple
        Tuple (x_train, y_train) of training data.
    data_val : tuple
        Tuple (x_val, y_val) of validation data, used to score candidates and
        to evaluate the stopping criterion.

    Returns
    -------
    model : RC
        Pruned model (refit on ``data_train``). If ``return_best_model`` is
        True, this is the model with the lowest recorded loss during pruning,
        otherwise the last model produced before stopping.
    history : dict
        Nested dictionary recording, per pruning iteration, the model state
        before and after pruning, the evaluated candidates and their scores,
        and any nodes removed as a side effect.

    Raises
    ------
    TypeError
        If the reservoir weights are neither a networkx graph nor a numpy
        array, so that the edges cannot be counted.
    """
    # Sanity checks for the input parameter types and values
    self._validate_pruning_params(model, data_train, data_val)

    # Obtain training and testing data
    x_train, y_train = data_train[0], data_train[1]
    x_test, y_test = data_val[0], data_val[1]

    # Reset everything that is specific to a single pruning run. Without this
    # a second call would inherit the best loss / patience counter of the
    # previous run and keep stale history entries of higher iterations.
    self.history = {}
    self._best_loss = None
    self._patience_counter = 0

    # Quantities that can not be set in the initializer, as they depend on the
    # model and data
    self._curr_num_nodes = model.reservoir_layer.nodes
    _graph = model.reservoir_layer.weights
    if isinstance(_graph, nx.Graph):
        edge_indices = list(_graph.edges())
    elif isinstance(_graph, np.ndarray):
        rows, cols = np.where(_graph != 0)  # where entries are not zero
        edge_indices = list(zip(rows, cols))
        # NOTE(review): the undirected filter is applied to the array branch
        # only (upper triangle, each edge counted once) - confirm against the
        # upstream layout.
        if not self.directed:
            edge_indices = [(r, c) for r, c in edge_indices if r < c]
    else:
        raise TypeError(
            'Reservoir weights must be a networkx graph or a numpy array, got '
            f'{type(_graph).__name__}'
        )
    self._curr_num_edges = len(edge_indices)

    # Initialize the quantities that affect the stop condition
    self._curr_loss = model.evaluate(x=x_test, y=y_test,
                                     metrics=self.performance_criterion)[0]
    self._curr_loss_history = [self._curr_loss]

    # Quantities that are only tracked for the pruning history; they do not
    # affect the pruning process
    self._curr_metrics = model.evaluate(x=x_test, y=y_test, metrics=self.metrics)

    # Store all pruned models during the pruning iteration. Allows to recover
    # models from previous iterations, e.g. when the best model is not the last
    # one in the iteration (positive patience value). Index 0 is the initial,
    # unpruned model.
    _pruned_models = [copy.deepcopy(model)]
    _pruned_models_losses = [self._curr_loss]

    # Initialize the pruning iterator
    self._iter_count = 0

    while True:
        # min_edges is checked here (pre-pruning) so we avoid evaluating
        # candidates when we already know we would violate the constraint
        if 'min_edges' in self.stopping_criterion:
            if not self._min_num_edges_stopping():
                break

        # Store initial model and properties in history
        self._history_update_before_pruning(model, _graph)

        print(f'Currently at pruning iteration {self._iter_count} ...')
        print(
            f'Current reservoir size: {self._curr_num_nodes} | '
            f'Current loss: {self._curr_loss:.8f}'
        )

        # Get candidates for pruning
        _curr_candidates = self._get_candidates(_graph)

        # Nothing left to evaluate: selecting the best of zero candidates would
        # raise, so stop cleanly and drop the partial history entry.
        if len(_curr_candidates) == 0:
            print('No candidate edges left to evaluate. \nTerminating pruning!')
            del self.history[self._iter_count]
            break

        # Apply chosen pruning strategy on candidates to get their scores
        _candidate_scores, _candidate_models, _cand_graph_props_after = \
            self._apply_pruning_strategy(model, _curr_candidates, x_train,
                                         y_train, x_test, y_test)

        # Store candidates and properties evaluated during iteration in history
        self._history_update_candidate_iteration(_graph, _curr_candidates,
                                                 _candidate_scores,
                                                 _cand_graph_props_after)

        # Out of all candidates select the one with the best score
        idx_prune, pruned_candidate = self._get_best_candidate(_curr_candidates,
                                                               _candidate_scores)
        self._curr_idx_prune = idx_prune

        # Get model properties of selected candidate and update the termination
        # relevant quantities
        curr_model, curr_loss, curr_num_nodes, curr_num_edges = \
            self._get_best_candidate_model_properties(idx_prune,
                                                      _candidate_models,
                                                      x_test, y_test)
        self._curr_model = curr_model
        self._curr_loss = curr_loss
        self._curr_num_nodes = curr_num_nodes
        self._curr_num_edges = curr_num_edges
        self._curr_loss_history.append(self._curr_loss)

        # Check for isolated nodes and remove if no effect on performance
        removed_nodes = {}
        if self.remove_isolated_nodes:
            isolated_nodes = self._get_isolated_nodes(self._curr_model)
            self._curr_model, removed_nodes = self._remove_isolated_nodes(
                isolated_nodes, self._curr_model, x_train, y_train, x_test,
                y_test)
            self._curr_num_nodes = self._curr_model.reservoir_layer.nodes
            # Isolated readout node removal may have updated self._curr_loss,
            # keep the appended history entry in sync so that patience sees
            # the true value
            self._curr_loss_history[-1] = self._curr_loss

        # Check stopping criterion on the to-be-pruned candidate model. If a
        # criterion is met, the candidate is discarded and pruning stops.
        if self._check_stopping_criterion():
            # Remove partial history entry for this iteration before exiting
            # to return a clean history
            del self.history[self._iter_count]
            break

        if isinstance(pruned_candidate, tuple):
            # Clean print of pruned_candidate
            pruned_candidate = (int(pruned_candidate[0]),
                                int(pruned_candidate[1]))
        print(f'Pruning candidate {pruned_candidate}, '
              f'resulting in loss {self._curr_loss:.6f}')

        prev_loss = self._curr_loss_history[-2]
        if prev_loss != 0:
            loss_improvement = (prev_loss - self._curr_loss) / prev_loss
            print(
                'Loss improvement to previous iteration by '
                f'{loss_improvement:+.3%}\n'
            )
        else:
            # Relative change is undefined for a previous loss of exactly zero
            print(
                'Loss improvement to previous iteration undefined '
                '(previous loss is 0)\n'
            )

        # The edge has already been pruned in the candidate model, so we only
        # need to adopt it. Taken from self._curr_model to not discard a
        # possible isolated-node removal.
        model = self._curr_model
        _graph = model.reservoir_layer.weights

        # Store the model and loss for later use
        _pruned_models.append(copy.deepcopy(model))
        _pruned_models_losses.append(self._curr_loss)

        # Required for the history only, not for the termination criteria
        self._curr_metrics = model.evaluate(
            x=x_test, y=y_test, metrics=self.metrics
        )

        # Store important data after pruning in history
        self._history_update_after_pruning(model, _graph, idx_prune,
                                           _curr_candidates, removed_nodes)

        # Update iteration counter
        self._iter_count += 1

    # With a non-zero patience the best model is not necessarily the last one
    if self.return_best_model:
        idx_best = np.argmin(_pruned_models_losses)
        model = copy.deepcopy(_pruned_models[idx_best])
        if idx_best == 0:
            print('\nReturning the initial (unpruned) model as the best')
        else:
            # Entry k of the list is the model produced by iteration k-1
            print(f'\nReturning model from iteration {idx_best-1} as the best')

    # Fit the final model, and evaluate it
    model.fit(x=x_train, y=y_train)
    final_loss = model.evaluate(x=x_test, y=y_test,
                                metrics=self.performance_criterion)[0]
    final_metrics = model.evaluate(x=x_test, y=y_test, metrics=self.metrics)

    print(
        f'\nInitial loss: {self._curr_loss_history[0]:.6f}, '
        f'loss after pruning: {final_loss:.6f}'
    )
    print(f'Final model has {model.reservoir_layer.nodes} nodes')
    print(f'Final model loss {self.performance_criterion}: {final_loss:.6f}')
    print(f'Final model metrics ({self.metrics}): {final_metrics}')

    return model, self.history
# ######## PRUNING STEPS FUNCTIONS ########## def _history_update_before_pruning(self, model, graph): """ Record the state of the model at the start of a pruning iteration. Parameters ---------- model : RC Model as it stands before this iteration's candidate edges are evaluated. graph : nx.Graph or np.ndarray Reservoir's adjacency representation corresponding to ``model``. Returns ------- None Updates ``self.history[self._iter_count]['starting_model']`` in place. """ # Get graph properties of current model graph_props = self.graph_analyzer.extract_properties(graph=graph) # Store the properties of current model self.history[self._iter_count] = { 'starting_model': { 'weights': sp.csr_matrix(graph) if isinstance(graph, np.ndarray) else nx.to_scipy_sparse_array(graph), 'input_nodes': list(model.reservoir_layer.input_receiving_nodes), 'readout_nodes': list(model.readout_layer.readout_nodes), 'loss': self._curr_loss, 'num_nodes': self._curr_num_nodes, 'num_edges': self._curr_num_edges, 'metrics': self._curr_metrics, 'graph_props': graph_props, } } def _get_candidates(self, graph): """ Propose candidate edges for pruning in the current iteration. Parameters ---------- graph : nx.Graph or np.ndarray Reservoir's current adjacency representation. Returns ------- list of tuple Edge indices (u, v) proposed as pruning candidates, selected via ``self.edge_selection_strat`` and ``self.candidate_fraction``. """ selector = EdgeSelector( graph=graph, strategy=self.edge_selection_strat, directed=self.directed ) # Obtain edges for pruning _curr_candidates = selector.select_edges(fraction=self.candidate_fraction) print( f'Proposing {selector.num_select_edges}/{selector.num_total_edges} ' 'edges for pruning ...' ) return _curr_candidates def _apply_pruning_strategy(self, model, candidates, x_train, y_train, x_test, y_test): """ Dispatch candidate scoring to the configured pruning criterion. 
Looks up ``self.pruning_criterion`` in ``PRUNING_CRITERION`` and calls the corresponding method (e.g. ``_performance_pruning``). Parameters ---------- model : RC Current (not yet pruned) model. candidates : list of tuple Edge indices (u, v) to evaluate as pruning candidates. x_train : np.ndarray Training input data. y_train : np.ndarray Training target data. x_test : np.ndarray Validation input data. y_test : np.ndarray Validation target data. Returns ------- candidate_scores : list of float Score for each candidate (lower is better, i.e. a better pruning choice). candidate_models : list of RC Model resulting from removing each candidate edge. cand_graph_props_after : list of dict Graph-level properties of the reservoir after removing each candidate. """ method = getattr(self, self.PRUNING_CRITERION[self.pruning_criterion]) return method(model, candidates, x_train, y_train, x_test, y_test) def _history_update_candidate_iteration(self, graph, candidates, candidate_scores, cand_graph_props_after): """ Record all evaluated candidates and their scores for this iteration. Parameters ---------- graph : nx.Graph or np.ndarray Reservoir's weight matrix before pruning, used to extract edge-level properties of the candidates. candidates : list of tuple Edge indices (u, v) that were evaluated as pruning candidates. candidate_scores : list of float Score for each candidate, in the same order as ``candidates``. cand_graph_props_after : list of dict Graph-level properties of the reservoir after removing each candidate, in the same order as ``candidates``. Returns ------- None Updates ``self.history[self._iter_count]['candidates']`` in place. 
""" # Get scores, edge properties and potential graph properties of candidates cand_edge_props = self.edge_analyzer.extract_properties_batch(graph, candidates) self.history[self._iter_count]['candidates'] = { (int(c[0]), int(c[1])): { 'score': score, 'edge_props': edge_props, 'graph_props_after': graph_props, } for c, score, edge_props, graph_props in zip( candidates, candidate_scores, cand_edge_props, cand_graph_props_after ) } def _get_best_candidate(self, candidates, candidate_scores): """ Select the candidate edge with the best (lowest) score. Parameters ---------- candidates : list of tuple Edge indices (u, v) that were evaluated as pruning candidates. candidate_scores : list of float Score for each candidate, in the same order as ``candidates``. Returns ------- idx_prune : int Index, into ``candidates`` and ``candidate_scores``, of the selected candidate. pruned_candidate : tuple Edge (u, v) to be pruned. """ idx_prune = np.argmin(candidate_scores) pruned_candidate = candidates[idx_prune] # Just for history logging return idx_prune, pruned_candidate def _get_best_candidate_model_properties(self, candidate_idx, candidate_models, x_test, y_test): """ Extract the model and bookkeeping quantities for the selected candidate. The loss is computed here directly from the selected model, rather than reused from the candidate's selection score. The two coincide for ``pruning_criterion='performance'`` (the score is the loss), but not in general (e.g. a structural criterion's score is a graph property, not a loss). Computing it explicitly keeps ``self._curr_loss``/patience/history meaningful regardless of which criterion picked the candidate. Parameters ---------- candidate_idx : int Index of the selected candidate, as returned by ``_get_best_candidate``. candidate_models : list of RC Model resulting from removing each evaluated candidate edge. x_test : np.ndarray Validation input data. y_test : np.ndarray Validation target data. 
Returns ------- curr_model : RC Model resulting from removing the selected candidate edge. curr_loss : float Loss of ``curr_model`` on (x_test, y_test), evaluated using ``self.performance_criterion``. curr_num_nodes : int Number of reservoir nodes in ``curr_model``. curr_num_edges : int Number of reservoir edges in ``curr_model``. """ curr_model = candidate_models[candidate_idx] curr_loss = curr_model.evaluate(x=x_test, y=y_test, metrics=self.performance_criterion)[0] curr_num_nodes = curr_model.reservoir_layer.nodes curr_graph = curr_model.reservoir_layer.weights if isinstance(curr_graph, nx.Graph): edge_indices = list(curr_graph.edges()) elif isinstance(curr_graph, np.ndarray): rows, cols = np.where(curr_graph != 0) # Where entries are not zero edge_indices = list(zip(rows, cols)) if not self.directed: edge_indices = [(r, c) for r, c in edge_indices if r < c] curr_num_edges = len(edge_indices) return curr_model, curr_loss, curr_num_nodes, curr_num_edges def _get_isolated_nodes(self, model): """ Find nodes that are isolated (degree 0) in the reservoir graph. Parameters ---------- model : RC Model whose reservoir graph is inspected. Returns ------- list of dict One entry per isolated node, each with keys ``'id'`` (node id), ``'is_input'`` (whether the node receives input), and ``'is_readout'`` (whether the node feeds the readout layer). 
""" graph = model.reservoir_layer.weights input_nodes = set(model.reservoir_layer.input_receiving_nodes) readout_nodes = set(model.readout_layer.readout_nodes) isolated_nodes = [] if isinstance(graph, nx.Graph): for node in graph.nodes(): if graph.degree(node) == 0: isolated_nodes.append({ 'id': node, 'is_input': node in input_nodes, 'is_readout': node in readout_nodes, }) elif isinstance(graph, np.ndarray): for node in range(graph.shape[0]): if np.count_nonzero(graph[node, :]) == 0 and \ np.count_nonzero(graph[:, node]) == 0: isolated_nodes.append({ 'id': node, 'is_input': node in input_nodes, 'is_readout': node in readout_nodes, }) return isolated_nodes def _remove_isolated_nodes(self, isolated_nodes, model, x_train, y_train, x_test, y_test): """ Remove isolated nodes from the model where safe to do so. Non-input, non-readout isolated nodes are removed unconditionally. Isolated readout nodes are removed only if doing so does not increase the loss; input-receiving isolated nodes are never removed. Parameters ---------- isolated_nodes : list of dict Isolated nodes as returned by ``_get_isolated_nodes``. model : RC Model to remove isolated nodes from. x_train : np.ndarray Training input data. y_train : np.ndarray Training target data. x_test : np.ndarray Validation input data. y_test : np.ndarray Validation target data. Returns ------- model : RC Model with eligible isolated nodes removed. removed : dict Mapping of removed node id to a dict with keys ``'loss_after'`` (loss after removal, or None for unconditionally removed nodes) and ``'is_readout'``. 
""" # Dict to store isolated nodes removed = {} # Remove isolated nodes that are neither input-receiving nor readout nodes fully_isolated_node_ids = [ n['id'] for n in isolated_nodes if not n['is_input'] and not n['is_readout'] ] if fully_isolated_node_ids: print(f'Removing {len(fully_isolated_node_ids)} isolated non-input/readout ' f'nodes: {fully_isolated_node_ids}') model.remove_reservoir_nodes(nodes=fully_isolated_node_ids) # Store removed node in dict for tracking for node_id in fully_isolated_node_ids: removed[node_id] = {'loss_after': None, 'is_readout': False} # Remove isolated readout nodes if removal doesn't affect performance negatively isolated_readout_nodes = [ n for n in isolated_nodes if not n['is_input'] and n['is_readout'] ] # Track original IDs of all removed nodes to adjust indices after renumbering # Each removal shifts down the indices of all higher-numbered nodes by 1 removed_original_ids = list(fully_isolated_node_ids) # Check effect on performance performance for node in isolated_readout_nodes: # Compute how many previously removed nodes had a smaller original ID, # since each such removal shifted this node's current index down by 1 adjusted_id = node['id'] - sum(1 for r in removed_original_ids if r < node['id']) _model = copy.deepcopy(model) _model.remove_reservoir_nodes(nodes=[adjusted_id]) _model.fit(x=x_train, y=y_train) _score = _model.evaluate(x=x_test, y=y_test, metrics=self.performance_criterion)[0] if _score <= self._curr_loss: print(f'Removing isolated readout node {node["id"]:>3}: loss ' f'{self._curr_loss:.6f} -> {_score:.6f}') # Note that node has been removed removed_original_ids.append(node['id']) # Store removed node in dict for tracking removed[node['id']] = {'loss_after': _score, 'is_readout': True} model = _model self._curr_loss = _score else: print(f'Keeping isolated readout node {node["id"]:>3}: removal would ' f'increase loss {self._curr_loss:.6f} -> {_score:.6f}') return model, removed def 
_check_stopping_criterion(self): """ Check whether any configured stopping criterion has been met. Iterates over ``self.stopping_criterion`` and calls the corresponding method (e.g. ``_patience_stopping``) for each. ``'min_edges'`` is skipped here since it is checked separately, at the top of the pruning loop, before candidates are evaluated. Returns ------- bool True if pruning should stop (any criterion's "continue" check returned False), False if pruning should continue. """ for criterion in self.stopping_criterion: if criterion == 'min_edges': # Handled at the top of the pruning loop (pre-pruning check) continue method = getattr(self, self.STOPPING_CRITERION[criterion]) if not method(): # Criterion is met return True # Criterion is not met return False def _history_update_after_pruning(self, model, graph, idx_prune, candidates, removed_nodes): """ Record the outcome of a pruning iteration. Parameters ---------- model : RC Model after the selected candidate edge has been pruned. graph : nx.Graph or np.ndarray ``model``'s reservoir adjacency representation. idx_prune : int Index, into ``candidates``, of the edge that was pruned. candidates : list of tuple Edge indices (u, v) that were evaluated this iteration. removed_nodes : dict Isolated nodes removed this iteration, as returned by ``_remove_isolated_nodes``. Returns ------- None Updates ``self.history[self._iter_count]`` in place with the ``'winner'``, ``'removed_nodes'``, and ``'final_model'`` entries. 
""" # Get final graph properties and winner graph_props = self.graph_analyzer.extract_properties(graph=graph) winner = candidates[idx_prune] # Store winner, removed nodes and properties of winner model self.history[self._iter_count]['winner'] = (int(winner[0]), int(winner[1])) self.history[self._iter_count]['removed_nodes'] = removed_nodes self.history[self._iter_count]['final_model'] = { 'weights': sp.csr_matrix(graph) if isinstance(graph, np.ndarray) else nx.to_scipy_sparse_array(graph), 'input_nodes': list(model.reservoir_layer.input_receiving_nodes), 'readout_nodes': list(model.readout_layer.readout_nodes), 'loss': self._curr_loss, 'num_nodes': self._curr_num_nodes, 'num_edges': self._curr_num_edges, 'metrics': self._curr_metrics, 'graph_props': graph_props, } # ######## MODEL TRAINING FUNCTIONS ######### def _retrain_model(self, model, x_train, y_train): """ Fit a model on training data. Parameters ---------- model : RC Model to fit. x_train : np.ndarray Training input data. y_train : np.ndarray Training target data. Returns ------- Whatever ``model.fit`` returns. """ return model.fit(x=x_train, y=y_train) # ######## PRUNING CRITERIONS FUNCTIONS ######### def _performance_pruning(self, model, candidates, x_train, y_train, x_test, y_test): """ Score candidates by the model's loss after removing and refitting. For each candidate edge, removes it from a copy of ``model``, refits the copy, and evaluates it on the validation data. Runs serially or in parallel (across CPU cores via ``joblib``) depending on ``self.parallel``. Parameters ---------- model : RC Current (not yet pruned) model. candidates : list of tuple Edge indices (u, v) to evaluate as pruning candidates. x_train : np.ndarray Training input data. y_train : np.ndarray Training target data. x_test : np.ndarray Validation input data. y_test : np.ndarray Validation target data. Returns ------- candidate_scores : list of float Loss of the refit model for each candidate, in the same order as ``candidates``. 
candidate_models : list of RC Refit model resulting from removing each candidate edge. cand_graph_props_after : list of dict Graph-level properties of the reservoir after removing each candidate. """ if self.parallel: # Parallelizing performance evaluation of candidates # Kicking off parallelization of going through all candidates # tqdm shows process in in bar chart # n_jobs is number of jobs to run in parallel (-1 uses all CPU cores) # backend - loky is default parallel = Parallel(n_jobs=-1, backend='loky') # tqdm shows process in bar chart # Generator returns results in order that they're given # Call seperate evaluation function that doesn't pass the whole self object # (history gets larger with each iteration and therefore slows down when # it's spawned across multiple CPUs) results = parallel( delayed(_evaluate_candidate_standalone) (model, c, x_train, y_train, x_test, y_test, self.performance_criterion, self.graph_analyzer) for c in tqdm(candidates, desc='Evaluating candidates') ) _candidate_scores, _candidate_models, _cand_graph_props_after = \ zip(*results) _candidate_scores = list(_candidate_scores) _candidate_models = list(_candidate_models) _cand_graph_props_after = list(_cand_graph_props_after) else: # Go through candidates one by one in a single process # Initialize lists to track different metrics during pruning iteration _candidate_scores, _candidate_models, _cand_graph_props_after = [], [], [] for candidate in candidates: # Get model performance for removing candidate score, cand_model, props_after = \ self._evaluate_candidate_performance(model, candidate, x_train, y_train, x_test, y_test) # Collect score, model and properties of candiate _candidate_scores.append(score) _candidate_models.append(cand_model) _cand_graph_props_after.append(props_after) return _candidate_scores, _candidate_models, _cand_graph_props_after def _evaluate_candidate_performance(self, model, candidate, x_train, y_train, x_test, y_test): """ Evaluate a single candidate edge 
removal (serial path). Parameters ---------- model : RC Current (not yet pruned) model. candidate : tuple Edge (u, v) to tentatively remove. x_train : np.ndarray Training input data. y_train : np.ndarray Training target data. x_test : np.ndarray Validation input data. y_test : np.ndarray Validation target data. Returns ------- score : float Loss of the refit model with ``candidate`` removed, evaluated on (x_test, y_test) using ``self.performance_criterion``. model : RC Deep copy of ``model`` with ``candidate`` removed and refit. graph_props_after : dict Graph-level properties of the reservoir after removing ``candidate``. """ # Single candidate run (had to be broken down to this to enable parallelization) # Copy original model for candidate removal _model = copy.deepcopy(model) # Get info on candidate egde and graph before removal _graph = _model.reservoir_layer.weights # Remove candidate edge from reservoir _model.remove_reservoir_edges(edges=[candidate]) # Re-fit (retrain) pruned model _model.fit(x=x_train, y=y_train) # Evaluate pruned model regarding performance criterion _score = _model.evaluate(x=x_test, y=y_test, metrics=self.performance_criterion)[0] # Extract graph properties after pruning _graph = _model.reservoir_layer.weights _graph_props_after = self.graph_analyzer.extract_properties(graph=_graph) if not self.parallel: # Print candidate and score info if not parallelized # Format candidate tuple cleanly label = f"{int(candidate[0])}-{int(candidate[1])}" if \ isinstance(candidate, tuple) else int(candidate) print(f'Possible deletion of edge {label:<10} loss: {_score:.6f} ' f'({(self._curr_loss - _score) / self._curr_loss:+.3%})') # Return canidate score, model and properties return _score, _model, _graph_props_after def _shortest_path_pruning(self, model): """ Placeholder for a shortest-path-based pruning criterion. Not implemented yet. Parameters ---------- model : RC The current (not yet pruned) model. 
""" # possible other pruning strategies (neglecting for now) pass # ############# STOPPING CRITERIONS FUNCTIONS ############## def _patience_stopping(self): """ Check the patience-based stopping criterion. Patience counts iterations since the all-time best loss was achieved, not just since the last improvement over the previous iteration. Returns ------- bool True if we should continue pruning, False to stop. """ # Checks if the loss is at a minimum, considering also patience. # Patience counts steps since the all-time best loss was achieved, # not just since the last improvement over the previous step # Returns True if we should continue pruning, False to stop. if len(self._curr_loss_history) < 1: return True current_loss = self._curr_loss_history[-1] if self._best_loss is None or current_loss <= self._best_loss: # New best loss — reset patience print( f'Loss improved to new best {current_loss:.6f}. ' '\nContinuing pruning ...' ) self._best_loss = current_loss self._patience_counter = 0 return True else: # No improvement over best loss self._patience_counter += 1 if self._patience_counter <= self.patience: print( f'No improvement over best loss ({self._best_loss:.6f}), patience ' f'{self._patience_counter}/{self.patience}. ' '\nContinuing pruning ...' ) return True else: print( f'No improvement over best loss ({self._best_loss:.6f}) for ' f'{self.patience} consecutive iterations. \nTerminating pruning!' ) return False def _min_num_nodes_stopping(self): """ Check the minimum-number-of-nodes stopping criterion. Returns ------- bool True if the number of nodes is above ``self.min_num_nodes`` and we should continue pruning, False to stop. 
""" # Checks if the number of nodes is above the minimum number of nodes # Returns True if number of nodes is above minimum and we should continue # pruning if self._curr_num_nodes >= self.min_num_nodes: # When reaching min_num_nodes we should still continue pruning # pruning another edge doesn't mean that a node will be removed print( f'Number of nodes {self._curr_num_nodes} is larger than minimum number ' f'of nodes {self.min_num_nodes}. \nContinuing pruning ...' ) return True else: print( f'Number of nodes {self._curr_num_nodes} is smaller/equal minimum ' f'number of nodes {self.min_num_nodes}. \nTerminating pruning!' ) return False def _min_num_edges_stopping(self): """ Check the minimum-number-of-edges stopping criterion. Unlike the other stopping criteria, this is a pre-pruning check called at the top of the pruning loop, before candidates are evaluated. This avoids scoring candidates we already know we cannot prune. Returns ------- bool True if pruning one more edge would not violate ``self.min_num_edges`` and we should continue, False to stop. """ # Pre-pruning check: called at the top of the while loop before candidate # evaluation. # Stops if pruning one more edge would violate the minimum. if self._curr_num_edges <= self.min_num_edges: print(f'Number of edges {self._curr_num_edges} is minimum number of edges ' f'{self.min_num_edges}. \nTerminating pruning!') return False print(f'Number of edges {self._curr_num_edges} is larger than minimum number ' f'of edges {self.min_num_edges}. \nContinuing pruning ...') return True # ######## VALIDATION FUNCTIONS ######### def _validate_init_params( self, edge_selection_strat, candidate_fraction, pruning_criterion, stopping_criterion, min_num_nodes, min_num_edges, patience, performance_criterion, structural_criterion, metrics, return_best_model, graph_analyzer, node_analyzer, edge_analyzer, remove_isolated_nodes, directed, parallel, ): """ Validate the types and values of parameters passed to ``__init__``. 
Parameters ---------- edge_selection_strat : str See ``__init__``. candidate_fraction : float See ``__init__``. pruning_criterion : str See ``__init__``. stopping_criterion : list of str See ``__init__``. min_num_nodes : int See ``__init__``. min_num_edges : int See ``__init__``. patience : int See ``__init__``. performance_criterion : str See ``__init__``. structural_criterion : str See ``__init__``. metrics : list or str See ``__init__``. return_best_model : bool See ``__init__``. graph_analyzer : GraphAnalyzer or None See ``__init__``. node_analyzer : NodeAnalyzer or None See ``__init__``. edge_analyzer : EdgeAnalyzer or None See ``__init__``. remove_isolated_nodes : bool See ``__init__``. directed : bool See ``__init__``. parallel : bool See ``__init__``. Raises ------ TypeError If any parameter has an invalid type. ValueError If any parameter has an invalid value. NotImplementedError If ``pruning_criterion`` or an entry of ``stopping_criterion`` is not a recognized strategy. """ # Validate edge selection strategy from pyreco.edge_selector import EdgeSelector if not isinstance(edge_selection_strat, str): raise TypeError('edge_selection_strat must be a string') if edge_selection_strat not in EdgeSelector.STRATEGIES: raise NotImplementedError( f"Unknown edge selection strategy '{edge_selection_strat}'. " f'Available strategies: {list(EdgeSelector.STRATEGIES)}' ) # Validate candidate fraction if not isinstance(candidate_fraction, float): raise TypeError('candidate_fraction must be a float in (0, 1]') if candidate_fraction <= 0 or candidate_fraction > 1: raise ValueError('candidate_fraction must be a float in (0, 1]') # Validate pruning criterion if not isinstance(pruning_criterion, str): raise TypeError('pruning_criterion must be a string') if pruning_criterion not in self.PRUNING_CRITERION: raise NotImplementedError( f"Unknown strategy '{pruning_criterion}'. 
" f'Available strategies: {list(self.PRUNING_CRITERION)}' ) # Structural pruning not imlemented yet if pruning_criterion == 'structure': raise NotImplementedError( 'Structural pruning criterion is not implemented yet' ) # Validate stopping criterion if not isinstance(stopping_criterion, list): raise TypeError('stopping_criterion must be a list') if len(stopping_criterion) == 0: raise ValueError( 'stopping_criterion must contain at least one criterion, otherwise ' 'pruning has no way to stop and will run until it errors out on an ' 'edgeless graph' ) for sc in stopping_criterion: if sc not in self.STOPPING_CRITERION: raise NotImplementedError( f"Unknown strategy '{sc}'. " f'Available strategies: {list(self.STOPPING_CRITERION)}' ) # Validate min num of nodes if not isinstance(min_num_nodes, int): raise TypeError('min_num_nodes must be an integer') if min_num_nodes <= 2: raise ValueError('min_num_nodes must be larger than 2') # Validate min num of edges if not isinstance(min_num_edges, int): raise TypeError('min_num_edges must be an integer') if min_num_edges < 0: raise ValueError('min_num_edges must be larger than or equal to 0') # Validate patience if not isinstance(patience, int): raise TypeError('patience must be an integer') if patience < 0: raise ValueError('patience must be >= 0') # Validate performance criterion if not isinstance(performance_criterion, str): raise TypeError('performance_criterion must be a string') if performance_criterion not in available_metrics(): raise ValueError( f"Unknown metric '{performance_criterion}'. " f'Available metrics: {available_metrics()}' ) # Validate structural criterion if not isinstance(structural_criterion, str): raise TypeError('structural_criterion must be a string') if structural_criterion not in available_extractors(): raise ValueError( f"Unknown structural criterion '{structural_criterion}'. 
" f'Available structural criteria: {list(available_extractors())}' ) # Validate metrics if not isinstance(metrics, (list, str)): raise TypeError('metrics must be a list or a string') if isinstance(metrics, list) and not all(isinstance(m, str) for m in metrics): raise TypeError('metrics must be a list of strings') _metrics_list = metrics if isinstance(metrics, list) else [metrics] if not all(m in available_metrics() for m in _metrics_list): raise ValueError( f'Unknown metric in {metrics}. ' f'Available metrics: {available_metrics()}' ) # Validate return_best_model if not isinstance(return_best_model, bool): raise TypeError('return_best_model must be a boolean') # Validate graph analyzer if graph_analyzer is not None and not isinstance(graph_analyzer, GraphAnalyzer): raise TypeError('graph_analyzer must be an instance of GraphAnalyzer') # Validate node analyzer if node_analyzer is not None and not isinstance(node_analyzer, NodeAnalyzer): raise TypeError('node_analyzer must be an instance of NodeAnalyzer') # Validate edge analyzer if edge_analyzer is not None and not isinstance(edge_analyzer, EdgeAnalyzer): raise TypeError('edge_analyzer must be an instance of EdgeAnalyzer') # Validate remove isolated nodes if not isinstance(remove_isolated_nodes, bool): raise TypeError('remove_isolated_nodes must be a boolean') # Validate directed if not isinstance(directed, bool): raise TypeError('directed must be a boolean') # Validate parallel if not isinstance(parallel, bool): raise TypeError('parallel must be a boolean') def _validate_pruning_params(self, model, data_train, data_val): """ Validate the types and values of parameters passed to ``prune``. Parameters ---------- model : RC See ``prune``. data_train : tuple See ``prune``. data_val : tuple See ``prune``. Raises ------ TypeError If ``model`` is not an ``RC`` instance, ``data_train``/ ``data_val`` are not tuples, or their elements are not lists or numpy arrays. 
ValueError If ``data_train``/``data_val`` do not have exactly 2 elements, or if the inputs and targets within either do not have matching lengths. """ if not isinstance(model, RC): raise TypeError('model must be an instance of RC') if not isinstance(data_train, tuple) or not isinstance(data_val, tuple): raise TypeError('data_train and data_val must be tuples') if len(data_train) != 2 or len(data_val) != 2: raise ValueError('data_train and data_val must have 2 elements each') for idx, elem in enumerate(data_train): if not isinstance(elem, list): if not isinstance(elem, np.ndarray): raise TypeError(f'data_train[{idx}] must be a list or numpy array') for idx, elem in enumerate(data_val): if not isinstance(elem, list): if not isinstance(elem, np.ndarray): raise TypeError(f'data_val[{idx}] must be a list or numpy array') if len(data_train[0]) != len(data_train[1]): raise ValueError( 'data_train[0] and data_train[1] must have the same length, ' 'i.e. same number of samples' ) if len(data_val[0]) != len(data_val[1]): raise ValueError( 'data_val[0] and data_val[1] must have the same length, ' 'i.e. same number of samples' )
if __name__ == "__main__":
    # Manual smoke test: train a small reservoir computer, then prune it.
    # RC is already imported at module level, so it is not re-imported here.
    from pyreco.utils_data import sequence_to_sequence as seq_2_seq
    from pyreco.layers import InputLayer, ReadoutLayer
    from pyreco.layers import RandomReservoirLayer
    from pyreco.optimizers import RidgeSK

    # Get some data
    X_train, X_test, y_train, y_test = seq_2_seq(
        name='sine_pred', n_batch=20, n_states=2, n_time=150
    )

    # Layer shapes follow the data, dropping the leading (batch) axis
    input_shape = X_train.shape[1:]
    output_shape = y_train.shape[1:]

    # Build a classical RC
    model = RC()
    model.add(InputLayer(input_shape=input_shape))
    model.add(
        RandomReservoirLayer(
            nodes=50,
            density=0.1,
            activation='tanh',
            leakage_rate=0.1,
            fraction_input=0.5,
        ),
    )
    model.add(ReadoutLayer(output_shape, fraction_out=0.9))

    # Compile the model
    optim = RidgeSK(alpha=0.5)
    model.compile(
        optimizer=optim,
        metrics=['mean_squared_error'],
    )

    # Train the model and report the baseline score before pruning
    model.fit(X_train, y_train)
    print(f'Score: \t\t\t{model.evaluate(x=X_test, y=y_test)[0]:.4f}')

    # Prune the model.
    # Other setups to try: min_num_nodes=46, stopping_criterion=['patience']
    # or stopping_criterion=['min_edges', 'patience'].
    pruner = EdgePruner(
        stopping_criterion=['min_edges'],
        patience=2,
        min_num_edges=0,
        candidate_fraction=0.9,
        remove_isolated_nodes=True,
        metrics=["mse"],
        parallel=True
    )

    model_pruned, history = pruner.prune(
        model=model,
        data_train=(X_train, y_train),
        data_val=(X_test, y_test)
    )