# Source code for pyreco.edge_selector

import random
import networkx as nx
import numpy as np
import warnings


[docs] class EdgeSelector: ''' A class to select edges from a graph based on specific criteria, analogue to 'NodeSelector'. Parameters ---------- graph : nx.Graph or np.ndarray NetworkX graph or adjacency matrix to select edges from. strategy : str, optional Strategy used for edge selection. Default is "random_uniform_wo_repl". Attributes ---------- graph : nx.Graph or np.ndarray Input graph. directed : bool, optional Whether the graph is directed. Only used for np.ndarray inputs. Default is True. edge_indices : list of tuple List of all edge indices in the graph. num_total_edges : int Total number of edges in the graph. num_select_edges : int Number of edges to select. Set after calling select_edges. fraction : float Fraction of edges to select. Set after calling select_edges. strategy : callable Method corresponding to the selected strategy. selected_edges : list List of selected edges. Set after calling select_edges. Raises ------ TypeError If graph is not a nx.Graph or np.ndarray. ValueError If graph is None. NotImplementedError If the requested strategy is not implemented. ''' # Selection strategies implemented # TODO think about putting these in classes and have abstract base class to # define them STRATEGIES = { 'random_uniform_wo_repl': '_random_uniform_wo_repl', } def __init__( self, graph: nx.Graph | np.ndarray = None, strategy: str = 'random_uniform_wo_repl', directed: bool = True, ): # Sanity checks for passed graph and selection strategy self._validate_graph(graph, directed) self._validate_strategy(strategy) # Get all edges of the graph edge_indices = self._extract_edges(graph, directed) # Assign values to attributes self.graph = graph self.directed = directed self.edge_indices = edge_indices self.num_total_edges: int = len(self.edge_indices) self.num_select_edges: int = None self.fraction: float = None self.strategy = getattr(self, self.STRATEGIES[strategy]) self.selected_edges: list = []
[docs] def select_edges( self, fraction: float = None, num: int = None, ): ''' Select a subset of edges from the graph. Parameters ---------- fraction : float, optional Fraction of total edges to select. Must be in (0, 1]. Mutually exclusive with num. num : int, optional Exact number of edges to select. Must be a positive integer no greater than the total number of edges. Mutually exclusive with fraction. Returns ------- list of tuple List of selected edge indices as (source, target) tuples. Raises ------ ValueError If neither or both of fraction and num are provided. TypeError If num is not an integer, or fraction is not a float. ValueError If num is not in [1, num_total_edges], or fraction not in (0, 1]. ''' # Sanity checks for passed fraction and num of edges self._validate_selection_args(fraction, num) # Calculate number of edges to select or fraction # max(1, ...) to ensure at least one edge is always proposed, preventing # fraction * small edge counts from rounding down to 0 (enables pruning to 0) self.num_select_edges = num if num is not None else \ max(1, round(self.num_total_edges * fraction)) self.fraction = fraction if fraction is not None else num / self.num_total_edges self.selected_edges = self.strategy() return self.selected_edges
def _random_uniform_wo_repl(self): ''' Select num_select_edges edges by uniform random sampling without replacement. Returns ------- list of tuple Randomly sampled edge indices as (source, target) tuples. ''' return random.sample(self.edge_indices, self.num_select_edges) def _validate_graph(self, graph, directed): ''' Validate the graph passed to the class. Parameters ---------- graph : any Graph object to validate. directed : bool Whether graph is directed. Used to verify consistency with nx.Graph/nx.DiGraph types, and to warn about symmetric np.ndarray inputs. Raises ------ TypeError If graph is not a nx.Graph or np.ndarray. ValueError If graph is None. If directed=True but an undirected nx.Graph is passed. If directed=False but a directed nx.DiGraph is passed. Warns ----- UserWarning If directed=True and np.ndarray appears symmetric. If directed=False and np.ndarray appears asymmetric. ''' if graph is not None: if not isinstance(graph, nx.Graph) and not isinstance(graph, np.ndarray): raise TypeError('graph must be a networkx graph or np.ndarray') if isinstance(graph, nx.Graph): if directed and not isinstance(graph, nx.DiGraph): raise ValueError('directed=True but graph is undirected nx.Graph, ' 'use nx.DiGraph instead') if not directed and isinstance(graph, nx.DiGraph): raise ValueError('directed=False but graph is directed nx.DiGraph, ' 'use nx.Graph instead') # directed=True but matrix is symmetric if directed and isinstance(graph, np.ndarray) and np.array_equal(graph, graph.T): warnings.warn('directed graph appears symmetric, verify this is ' 'intended') # directed=False but matrix is asymmetric if not directed and isinstance(graph, np.ndarray) and not \ np.array_equal(graph, graph.T): warnings.warn('directed=False but graph appears asymmetric, verify this' ' is intended') else: raise ValueError('Graph must be provided') def _validate_strategy(self, strategy): ''' Validate the strategy passed to the class. 
Parameters ---------- strategy : str Strategy name to validate. Raises ------ NotImplementedError If strategy is not a key in STRATEGIES. ''' if strategy not in self.STRATEGIES: raise NotImplementedError( f"Unknown strategy '{strategy}'. " "Available strategies: {list(self.STRATEGIES)}" ) def _validate_selection_args(self, fraction, num): ''' Validate the arguments passed to select_edges. Parameters ---------- fraction : float or None Fraction of edges to select. num : int or None Exact number of edges to select. Raises ------ ValueError If neither or both of fraction and num are provided. If num is not in [1, num_total_edges]. If fraction is not in (0, 1]. TypeError If num is not an integer. If fraction is not a float. ''' # Check that either fraction or num is provided if fraction is None and num is None: raise ValueError('Provide either fraction or num, not neither') if fraction is not None and num is not None: raise ValueError('Provide either fraction or num, not both') # Check that selected num is an integer and not larger than total amount of # edges in graph if num is not None: if not isinstance(num, int): raise TypeError('num must be an integer') if not (0 < num <= self.num_total_edges): raise ValueError(f'num must be between 1 and {self.num_total_edges}') # Check that selected fraction is a float and larger than zero but not larger # than 1 if fraction is not None: if not isinstance(fraction, float): raise TypeError('fraction must be a float') if not (0.0 < fraction <= 1.0): raise ValueError('fraction must be in (0, 1]') def _extract_edges(self, graph: nx.Graph | np.ndarray, directed) -> list[tuple]: ''' Extract edge indices and shape information from a graph. Parameters ---------- graph : nx.Graph or np.ndarray Graph to extract edges from. Returns ------- edge_indices : list of tuple List of (source, target) tuples representing all edges. 
''' # Get the edges in the graph and the shape of the graph # Prunable edges in graph if isinstance(graph, nx.Graph): edge_indices = list(graph.edges()) return edge_indices elif isinstance(graph, np.ndarray): rows, cols = np.where(graph != 0) # where entries are not zero edge_indices = list(zip(rows, cols)) if not directed: edge_indices = [(r, c) for r, c in edge_indices if r < c] return edge_indices
if __name__ == "__main__":
    # Demo: select edges by count and then by fraction, each on a freshly
    # generated random directed graph. Fixed seeds make the output
    # reproducible between runs.
    random.seed(42)
    for demo_seed, selection in enumerate(({'num': 4}, {'fraction': 0.5})):
        G = nx.erdos_renyi_graph(10, 0.5, seed=demo_seed, directed=True)

        # Graph summary
        print(f'Possible edges: {G.edges()}')
        print(f'Number of edges: {G.number_of_edges()}')
        print(f'Number of nodes: {G.number_of_nodes()}')

        # Select random edges using either num= or fraction=
        selector = EdgeSelector(strategy="random_uniform_wo_repl", graph=G)
        random_edges = selector.select_edges(**selection)
        print(f'Randomly selected edges ({selection}): {random_edges}')