diff --git a/CHANGELOG.md b/CHANGELOG.md index cdaeb24..fe99b06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ - Remove agora implementation and tests. ([#73](https://github.com/polis-community/red-dwarf/issues/74)) - Migrate from reference HDBSCAN module (in `scikit-learn`) to full-featured HDBSCAN* package. - Add dependency groups to avoid installing everything. ([#11](https://github.com/polis-community/red-dwarf/issues/11)) +- Add simple reducer registry for adding more dimensional reduction algos at runtime. ### Fixes diff --git a/docs/api_reference.md b/docs/api_reference.md index 0b18b9f..89e9397 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -26,14 +26,27 @@ use in Scikit-Learn workflows, pipelines, and APIs. options: show_root_heading: true +### ::: reddwarf.sklearn.cluster.BestPolisKMeans + options: + show_root_heading: true + docstring_style: numpy + ### ::: reddwarf.sklearn.model_selection.GridSearchNonCV options: show_root_heading: true +### ::: reddwarf.sklearn.pipeline.PatchedPipeline + options: + show_root_heading: true + ### ::: reddwarf.sklearn.transformers.SparsityAwareScaler options: show_root_heading: true +### ::: reddwarf.sklearn.transformers.SparsityAwareCapturer + options: + show_root_heading: true + ## `reddwarf.utils.matrix` ### ::: reddwarf.utils.matrix.generate_raw_matrix @@ -58,15 +71,30 @@ use in Scikit-Learn workflows, pipelines, and APIs. 
options: show_root_heading: true +## `reddwarf.utils.reducer.registry` + +### ::: reddwarf.utils.reducer.registry.register_reducer + options: + show_root_heading: true + +### ::: reddwarf.utils.reducer.registry.get_reducer + options: + show_root_heading: true + +### ::: reddwarf.utils.reducer.registry.list_reducers + options: + show_root_heading: true + ## `reddwarf.utils.clusterer` ### ::: reddwarf.utils.clusterer.base.run_clusterer options: show_root_heading: true -### ::: reddwarf.utils.clusterer.kmeans.find_best_kmeans +### ::: reddwarf.sklearn.cluster.BestPolisKMeans options: show_root_heading: true + docstring_style: numpy ## `reddwarf.utils.consensus` diff --git a/reddwarf/implementations/base.py b/reddwarf/implementations/base.py index 6f38e6a..16f226d 100644 --- a/reddwarf/implementations/base.py +++ b/reddwarf/implementations/base.py @@ -11,8 +11,8 @@ simple_filter_matrix, get_clusterable_participant_ids, ) -from reddwarf.utils.reducer.base import ReducerType, ReducerModel, run_reducer -from reddwarf.utils.clusterer.base import ClustererType, ClustererModel +from reddwarf.utils.reducer.base import ReducerModel, run_reducer +from reddwarf.utils.clusterer.base import ClustererModel from reddwarf.utils.stats import ( calculate_comment_statistics_dataframes, populate_priority_calculations_into_statements_df, @@ -54,9 +54,9 @@ class PolisClusteringResult: def run_pipeline( votes: list[dict], - reducer: ReducerType = "pca", + reducer: str = "pca", reducer_kwargs: dict = {}, - clusterer: ClustererType = "kmeans", + clusterer: str = "kmeans", clusterer_kwargs: dict = {}, mod_out_statement_ids: list[int] = [], meta_statement_ids: list[int] = [], @@ -79,9 +79,9 @@ def run_pipeline( Args: votes (list[dict]): Raw list of vote dicts, with keys for "participant_id", "statement_id", "vote" and "modified" - reducer (ReducerType): Selects the type of reducer model to use. + reducer (str): Selects the type of reducer model to use. 
def _to_range(r) -> range:
    """
    Create an inclusive ``range`` from an integer or a ``(start, end)`` pair.

    Args:
        r: Either a single integer-like value ``k`` (giving the one-element
            range ``k..k``) or a 2-element tuple/list ``(start, end)``.
            Anything implementing ``__index__`` is accepted wherever a plain
            ``int`` is, so integer scalars from numeric libraries work too.

    Returns:
        range: ``range(start, end + 1)``; both bounds are included. When
        ``start > end`` the result is an empty range.

    Raises:
        ValueError: If ``r`` is neither integer-like nor a 2-element
            tuple/list of integer-like values.

    Examples:
        _to_range(2)       # [2]
        _to_range([2, 5])  # [2, 3, 4, 5]
        _to_range((2, 5))  # [2, 3, 4, 5]
    """
    # Function-scope import keeps this helper self-contained.
    import operator

    if isinstance(r, (tuple, list)):
        if len(r) != 2:
            raise ValueError("Expected int or a 2-element tuple/list")
        start, end = r
    else:
        start = end = r

    try:
        # operator.index() accepts every integer-like object but rejects
        # floats and strings, so bad bounds fail here with one error type
        # instead of leaking a TypeError out of range().
        start, end = operator.index(start), operator.index(end)
    except TypeError:
        raise ValueError("Expected int or a 2-element tuple/list") from None

    return range(start, end + 1)  # inclusive
class BestPolisKMeans(BaseEstimator):
    """
    Pick the best ``PolisKMeans`` clustering over a range of k by silhouette score.

    One ``PolisKMeans`` candidate is scored per ``n_clusters`` value inside
    ``k_bounds`` (inclusive) through ``GridSearchNonCV``; the candidate with
    the best silhouette score is exposed through the fitted attributes below.

    Parameters
    ----------
    k_bounds : list of int, default=[2, 5]
        Inclusive range of k values to search, given as ``[min_k, max_k]``.
        A falsy value (``None``, empty list) falls back to ``[2, 5]``.
    init : {'k-means++', 'random', 'polis'}, default='polis'
        Initialization strategy forwarded to ``PolisKMeans``.
    init_centers : array-like, optional
        Initial cluster centers forwarded to ``PolisKMeans``.
    random_state : int, optional
        Random state forwarded to ``PolisKMeans``.

    Attributes
    ----------
    best_estimator_ : PolisKMeans or None
        The best estimator reported by the search; ``None`` before ``fit``.
    best_k_ : int or None
        The ``n_clusters`` value of the best candidate; ``None`` before ``fit``.
    best_score_ : float or None
        The silhouette score of the best candidate; ``None`` before ``fit``.
    labels_ : ndarray or None
        Labels taken from ``best_estimator_`` after ``fit``; ``None`` before
        ``fit`` or if the selected estimator exposes no ``labels_``.
    """

    def __init__(
        self,
        k_bounds: Optional[List[int]] = None,
        init: str = "polis",
        init_centers: Optional[ArrayLike] = None,
        random_state: Optional[int] = None,
    ):
        self.k_bounds = k_bounds or [2, 5]
        self.init = init
        self.init_centers = init_centers
        self.random_state = random_state
        self.best_estimator_ = None
        self.best_k_ = None
        self.best_score_ = None
        # Declared here, like the other result attributes, so fit_predict()
        # always has a defined attribute to return.
        self.labels_ = None

    def fit(self, X: NDArray, y=None) -> 'BestPolisKMeans':
        """
        Search ``k_bounds`` and keep the candidate with the best silhouette score.

        Parameters
        ----------
        X : NDArray
            Data to cluster.
        y : ignored
            Accepted only so the method can be called as ``fit(X, y)``.

        Returns
        -------
        BestPolisKMeans
            ``self``, with the result attributes populated.
        """
        param_grid = {
            "n_clusters": _to_range(self.k_bounds),
        }

        def scoring_function(estimator, X_data):
            # Each candidate is fitted here; its own labels feed the score.
            labels = estimator.fit_predict(X_data)
            return silhouette_score(X_data, labels)

        search = GridSearchNonCV(
            param_grid=param_grid,
            scoring=scoring_function,
            estimator=PolisKMeans(
                init=self.init,
                init_centers=self.init_centers,
                random_state=self.random_state,
            ),
        )

        search.fit(X)

        self.best_k_ = search.best_params_['n_clusters']
        self.best_score_ = search.best_score_
        self.best_estimator_ = search.best_estimator_
        # Mirror the winner's labels so fit_predict() has something to return;
        # previously labels_ was never assigned and fit_predict() raised
        # AttributeError.
        self.labels_ = getattr(self.best_estimator_, "labels_", None)

        return self

    def fit_predict(self, X: NDArray, y=None, **kwargs) -> Optional[np.ndarray]:
        """
        Fit on ``X`` and return the labels of the best clustering.

        ``y`` and ``**kwargs`` are accepted for call compatibility and ignored.
        """
        self.fit(X)
        return self.labels_
def run_clusterer(
    clusterer: str,
    X_participants_clusterable: NDArray,
    force_group_count: Optional[int] = None,
    max_group_count: int = 5,
    init_centers: Optional[list] = None,
    random_state: Optional[int] = None,
    **clusterer_kwargs,
) -> Optional[ClustererModel]:
    """
    Build a registered clusterer by name, fit it, and return the fitted model.

    Args:
        clusterer: Name of the registered clusterer to use.
        X_participants_clusterable: Array of participant coordinates to cluster.
        force_group_count: For "kmeans" only. When truthy, the search is
            pinned to exactly this many clusters.
        max_group_count: For "kmeans" only. Upper bound of the k search
            (lower bound is 2) when ``force_group_count`` is not set.
        init_centers: For "kmeans" only. Initial cluster center coordinates.
        random_state: For "kmeans" only. Random state for reproducibility.
        **clusterer_kwargs: Extra parameters forwarded to the clusterer
            factory. For "kmeans", any ``k_bounds``, ``init_centers`` or
            ``random_state`` given here is overwritten by the values derived
            from the named arguments above.

    Returns:
        The fitted clusterer. If the instance exposes a non-``None``
        ``best_estimator_`` (as ``BestPolisKMeans`` does), that inner
        estimator is returned instead of the wrapper.

    Raises:
        NotImplementedError: If building the clusterer from the registry
            raises ``ValueError`` (e.g. the name is not registered).
        Exception: Errors raised while fitting propagate unchanged.
    """
    # Handle k-means specific parameters
    if clusterer == "kmeans":
        if force_group_count:
            k_bounds = [force_group_count, force_group_count]
        else:
            k_bounds = [2, max_group_count]

        clusterer_kwargs.update({
            'k_bounds': k_bounds,
            'init_centers': init_centers,
            'random_state': random_state,
        })

    # Only the registry lookup is guarded. Guarding fit() as well would
    # mislabel genuine fitting errors as "not registered".
    try:
        clusterer_instance = get_clusterer(clusterer, **clusterer_kwargs)
    except ValueError as e:
        raise NotImplementedError(
            f"Clusterer '{clusterer}' not registered: {e}"
        ) from e

    clusterer_instance.fit(X_participants_clusterable)

    # A non-None best_estimator_ marks a meta-estimator such as
    # BestPolisKMeans; return the inner fitted estimator so callers see its
    # own attributes and params.
    best_estimator = getattr(clusterer_instance, 'best_estimator_', None)
    if best_estimator is not None:
        return best_estimator

    return clusterer_instance
+ """ + hdbscan = try_import("hdbscan", extra="alt-algos") + if TYPE_CHECKING: + hdbscan = cast("hdbscan_module", hdbscan) + + defaults: dict = dict( + min_cluster_size=5, + ) + defaults.update(kwargs) + return hdbscan.HDBSCAN(**defaults) diff --git a/reddwarf/utils/clusterer/kmeans.py b/reddwarf/utils/clusterer/kmeans.py deleted file mode 100644 index 71a8883..0000000 --- a/reddwarf/utils/clusterer/kmeans.py +++ /dev/null @@ -1,111 +0,0 @@ -from numpy.typing import NDArray -import pandas as pd -import numpy as np -from reddwarf.sklearn.model_selection import GridSearchNonCV -from reddwarf.sklearn.cluster import PolisKMeans -from sklearn.metrics import silhouette_score -from typing import List, Optional - -RangeLike = int | tuple[int, int] | list[int] - -def to_range(r: RangeLike) -> range: - """ - Creates an inclusive range from a list, tuple, or int. - - Examples: - ``` - to_range(2) # [2, 3] - to_range([2, 5]) # [2, 3, 4, 5] - to_range((2, 5)) # [2, 3, 4, 5] - ``` - """ - if isinstance(r, int): - start = end = r - elif isinstance(r, (tuple, list)) and len(r) == 2: - start, end = r - else: - raise ValueError("Expected int or a 2-element tuple/list") - - return range(start, end + 1) # inclusive - -# TODO: Start passing init_centers based on /math/pca2 endpoint data, -# and see how often we get the same clusters. -def run_kmeans( - dataframe: pd.DataFrame, - n_clusters: int = 2, - init="k-means++", - # TODO: Improve this type. 3d? - init_centers: Optional[List] = None, - random_state: Optional[int] = None, -) -> PolisKMeans: - """ - Runs K-Means clustering on a 2D DataFrame of xy points, for a specific K, - and returns labels for each row and cluster centers. Optionally accepts - guesses on cluster centers, and a random_state to reproducibility. - - Args: - dataframe (pd.DataFrame): A dataframe with two columns (assumed `x` and `y`). - n_clusters (int): How many clusters k to assume. - init (string): The cluster initialization strategy. See `PolisKMeans` docs. 
- init_centers (List): A list of xy coordinates to use as initial center guesses. See `PolisKMeans` docs. - random_state (int): Determines random number generation for centroid initialization. Use an int to make the randomness deterministic. - - Returns: - kmeans (PolisKMeans): The estimator object returned from PolisKMeans. - """ - # TODO: Set random_state to a value eventually, so calculation is deterministic. - kmeans = PolisKMeans( - n_clusters=n_clusters, - random_state=random_state, - init=init, - init_centers=init_centers, - ).fit(dataframe) - - return kmeans - -def find_best_kmeans( - X_to_cluster: NDArray, - k_bounds: RangeLike = [2, 5], - init="k-means++", - init_centers: Optional[List] = None, - random_state: Optional[int] = None, -) -> tuple[int, float, PolisKMeans | None]: - """ - Use silhouette scores to find the best number of clusters k to assume to fit the data. - - Args: - X_to_cluster (NDArray): A n-D numpy array. - k_bounds (RangeLike): An upper and low bound on n_clusters to test for. (Default: [2, 5]) - init_centers (List): A list of xy coordinates to use as initial center guesses. - random_state (int): Determines random number generation for centroid initialization. Use an int to make the randomness deterministic. - - Returns: - best_k (int): Ideal number of clusters. - best_silhouette_score (float): Silhouette score for this K value. - best_kmeans (PolisKMeans | None): The optimal fitted estimator returned from PolisKMeans. 
- """ - param_grid = { - "n_clusters": to_range(k_bounds), - } - - def scoring_function(estimator, X): - labels = estimator.fit_predict(X) - return silhouette_score(X, labels) - - search = GridSearchNonCV( - param_grid=param_grid, - scoring=scoring_function, - estimator=PolisKMeans( - init=init, # strategy - init_centers=init_centers, # guesses - random_state=random_state, - ), - ) - - search.fit(X_to_cluster) - - best_k = search.best_params_['n_clusters'] - best_silhouette_score = search.best_score_ - best_kmeans = search.best_estimator_ - - return best_k, best_silhouette_score, best_kmeans \ No newline at end of file diff --git a/reddwarf/utils/clusterer/registry.py b/reddwarf/utils/clusterer/registry.py new file mode 100644 index 0000000..00b336d --- /dev/null +++ b/reddwarf/utils/clusterer/registry.py @@ -0,0 +1,138 @@ +""" +Clusterer Registry Module + +This module provides a registry system for managing clustering algorithms. +It allows for dynamic registration and retrieval of clusterer factory functions, enabling +a plugin-like architecture for different clustering techniques. + +The registry supports: +- Decorator-based registration of clusterer factory functions +- Runtime retrieval of registered clusterers with parameter overrides +- Listing of all available clusterer names + +Example: + Register a custom clusterer factory function: + + >>> @register_clusterer('dbscan') + ... def make_dbscan(**kwargs): + ... # Import and configure the clusterer + ... defaults = dict(eps=0.5, min_samples=5) + ... defaults.update(kwargs) + ... 
from typing import Any, Callable, Dict, List

# Global registry to store clusterer name -> factory function mappings
_CLUSTERER_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_clusterer(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Return a decorator that registers a clusterer factory under ``name``.

    The decorated factory should accept keyword arguments and return a
    configured clusterer instance; it becomes available through
    ``get_clusterer(name, ...)``.

    Args:
        name: The string identifier for the clusterer.

    Returns:
        A decorator that stores the factory in the registry and returns it
        unchanged.

    Note:
        Registering a name that already exists replaces the earlier factory
        without warning.

    Example:
        >>> @register_clusterer('dbscan')
        ... def make_dbscan(**kwargs):
        ...     from sklearn.cluster import DBSCAN
        ...     defaults = dict(eps=0.5, min_samples=5)
        ...     defaults.update(kwargs)
        ...     return DBSCAN(**defaults)
    """
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        _CLUSTERER_REGISTRY[name] = fn
        return fn
    return decorator


def get_clusterer(name: str, **overrides: Any) -> Any:
    """
    Look up a registered clusterer factory by name and call it.

    Args:
        name: The string identifier of the clusterer to retrieve.
        **overrides: Keyword arguments passed straight to the factory; how
            they combine with defaults is up to the factory.

    Returns:
        Whatever the registered factory returns for the given arguments.

    Raises:
        ValueError: If ``name`` is not in the registry. The message lists
            the names that are registered.

    Example:
        >>> # Assuming a 'dbscan' factory is registered
        >>> clusterer = get_clusterer('dbscan', eps=0.3, min_samples=10)
        >>> print(clusterer.eps)
        0.3
    """
    # Look the factory up on its own so a KeyError raised inside the factory
    # is never mistaken for a missing registration.
    try:
        factory = _CLUSTERER_REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(map(str, _CLUSTERER_REGISTRY))) or "<none>"
        raise ValueError(
            f"Clusterer '{name}' not registered. Available clusterers: {available}"
        ) from None
    return factory(**overrides)


def list_clusterers() -> List[str]:
    """
    Return the names of all registered clusterers, in registration order.

    Returns:
        A new list of names; empty if nothing has been registered.

    Example:
        >>> if 'dbscan' in list_clusterers():
        ...     clusterer = get_clusterer('dbscan')
    """
    return list(_CLUSTERER_REGISTRY)
- # See: https://github.com/YingfanWang/PaCMAP?tab=readme-ov-file#parameters - DEFAULT_N_NEIGHBORS = None - match reducer: - case "pacmap" | "localmap": - pacmap = try_import("pacmap", extra="alt-algos") - - # Override with default if not set - n_neighbors = reducer_kwargs.pop("n_neighbors", DEFAULT_N_NEIGHBORS) - - ReducerCls = pacmap.PaCMAP if reducer == "pacmap" else pacmap.LocalMAP - return ReducerCls( - n_components=n_components, - random_state=random_state, - n_neighbors=n_neighbors, # type:ignore - **reducer_kwargs, - ) - case "pca" | _: - from sklearn.decomposition import PCA - - return PCA( - n_components=n_components, - random_state=random_state, - **reducer_kwargs, - ) - - def run_reducer( vote_matrix: NDArray, - reducer: ReducerType = "pca", + reducer: str = "pca", n_components: int = 2, **reducer_kwargs, ) -> Tuple[NDArray, Optional[NDArray], ReducerModel]: @@ -72,21 +38,24 @@ def run_reducer( X_statements (Optional[NDArray]): A numpy array with n-d coordinates for each projected col/statement. reducer_model (ReducerModel): The fitted dimensional reduction sci-kit learn estimator. """ + load_builtins() + reducer_kwargs.update(n_components=n_components) match reducer: case "pca": pipeline = PatchedPipeline( [ ("capture", SparsityAwareCapturer()), ("impute", SimpleImputer(missing_values=np.nan, strategy="mean")), - ("reduce", get_reducer(reducer, n_components=n_components, **reducer_kwargs)), + ("reduce", get_reducer(reducer, **reducer_kwargs)), ("scale", SparsityAwareScaler(capture_step="capture")), ] ) - case "pacmap" | "localmap": + # Use this basic unscaled pipeline by default. 
# Setting n_neighbors to None defaults to 10 below 10,000 samples, and
# slowly increases it according to a formula beyond that.
# See: https://github.com/YingfanWang/PaCMAP?tab=readme-ov-file#parameters
DEFAULT_N_NEIGHBORS = None


@register_reducer("pca")
def make_pca(**kwargs) -> PCA:
    """
    Create a scikit-learn ``PCA`` reducer.

    Defaults to ``n_components=2`` and ``random_state=42``; any keyword
    argument given overrides the matching default or is passed through.
    """
    defaults: dict = dict(n_components=2, random_state=42)
    defaults.update(kwargs)
    return PCA(**defaults)


def _make_pacmap_family(class_name: str, **kwargs):
    """Instantiate ``pacmap.<class_name>`` with the shared PaCMAP-style defaults."""
    # Optional dependency: resolved only when one of these reducers is
    # actually requested, so PCA-only installs are unaffected.
    pacmap = try_import("pacmap", extra="alt-algos")
    if TYPE_CHECKING:
        pacmap = cast("pacmap_module", pacmap)

    defaults: dict = dict(n_components=2, n_neighbors=DEFAULT_N_NEIGHBORS)
    defaults.update(kwargs)
    return getattr(pacmap, class_name)(**defaults)


@register_reducer("pacmap")
def make_pacmap(**kwargs) -> "pacmap_module.PaCMAP":
    """
    Create a ``pacmap.PaCMAP`` reducer.

    Defaults to ``n_components=2`` and ``n_neighbors=DEFAULT_N_NEIGHBORS``;
    keyword arguments override or extend these. Requires the optional
    ``pacmap`` dependency (the ``alt-algos`` extra).
    """
    return _make_pacmap_family("PaCMAP", **kwargs)


@register_reducer("localmap")
def make_localmap(**kwargs) -> "pacmap_module.LocalMAP":
    """
    Create a ``pacmap.LocalMAP`` reducer.

    Defaults to ``n_components=2`` and ``n_neighbors=DEFAULT_N_NEIGHBORS``;
    keyword arguments override or extend these. Requires the optional
    ``pacmap`` dependency (the ``alt-algos`` extra).
    """
    return _make_pacmap_family("LocalMAP", **kwargs)
from typing import Any, Callable, Dict, List

# Global registry to store reducer name -> factory function mappings
_REDUCER_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_reducer(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Return a decorator that registers a reducer factory under ``name``.

    The decorated factory should accept keyword arguments and return a
    configured reducer instance; it becomes available through
    ``get_reducer(name, ...)``.

    Args:
        name: The string identifier for the reducer.

    Returns:
        A decorator that stores the factory in the registry and returns it
        unchanged.

    Note:
        Registering a name that already exists replaces the earlier factory
        without warning.

    Example:
        >>> @register_reducer('umap')
        ... def make_umap(**kwargs):
        ...     defaults = dict(n_components=2, n_neighbors=15)
        ...     defaults.update(kwargs)
        ...     return umap.UMAP(**defaults)
    """
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        _REDUCER_REGISTRY[name] = fn
        return fn
    return decorator


def get_reducer(name: str, **overrides: Any) -> Any:
    """
    Look up a registered reducer factory by name and call it.

    Args:
        name: The string identifier of the reducer to retrieve.
        **overrides: Keyword arguments passed straight to the factory; how
            they combine with defaults is up to the factory.

    Returns:
        Whatever the registered factory returns for the given arguments.

    Raises:
        ValueError: If ``name`` is not in the registry. The message lists
            the names that are registered.

    Example:
        >>> # Assuming a 'umap' factory is registered
        >>> reducer = get_reducer('umap', n_components=3, n_neighbors=20)
        >>> print(reducer.n_components)
        3
    """
    # Look the factory up on its own so a KeyError raised inside the factory
    # is never mistaken for a missing registration.
    try:
        factory = _REDUCER_REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(map(str, _REDUCER_REGISTRY))) or "<none>"
        raise ValueError(
            f"Reducer '{name}' not registered. Available reducers: {available}"
        ) from None
    return factory(**overrides)


def list_reducers() -> List[str]:
    """
    Return the names of all registered reducers, in registration order.

    Returns:
        A new list of names; empty if nothing has been registered.

    Example:
        >>> if 'umap' in list_reducers():
        ...     reducer = get_reducer('umap')
    """
    return list(_REDUCER_REGISTRY)
@pytest.mark.parametrize("polis_convo_data", ["small"], indirect=True) -def test_run_kmeans_real_data_reproducible(polis_convo_data): +def test_polis_kmeans_real_data_reproducible(polis_convo_data): fixture = polis_convo_data expected_cluster_centers = [group["center"] for group in fixture.math_data["group-clusters"]] @@ -21,11 +21,11 @@ def test_run_kmeans_real_data_reproducible(polis_convo_data): for item in projected_participants ]).set_index("participant_id") - calculated_kmeans = run_kmeans( - dataframe=projected_participants_df, - init_centers=expected_cluster_centers, + calculated_kmeans = PolisKMeans( n_clusters=cluster_count, - ) + init="k-means++", + init_centers=expected_cluster_centers, + ).fit(projected_participants_df) # Ensure same number of clusters assert len(calculated_kmeans.cluster_centers_) == len(expected_cluster_centers) @@ -37,7 +37,7 @@ def test_run_kmeans_real_data_reproducible(polis_convo_data): # NOTE: "small-no-meta" fixture doesn't work because wants to find 4 clusters, whereas real data from polismath says 3. # This is likely due to k-smoothing holding back the k value at 3 in polismath, and we're finding the real current one. @pytest.mark.parametrize("polis_convo_data", ["small-with-meta"], indirect=True) -def test_find_best_kmeans_real_data(polis_convo_data): +def test_best_polis_kmeans_real_data(polis_convo_data): fixture = polis_convo_data MAX_GROUP_COUNT = 5 @@ -54,15 +54,24 @@ def test_find_best_kmeans_real_data(polis_convo_data): for item in projected_participants ]).set_index("participant_id") - results = find_best_kmeans( - X_to_cluster=projected_participants_df, + from reddwarf.utils.clusterer.base import run_clusterer + + # Test using run_clusterer (which is what the pipeline actually uses) + clusterer_result = run_clusterer( + clusterer="kmeans", + X_participants_clusterable=projected_participants_df.values, k_bounds=[2, MAX_GROUP_COUNT], - # Pad center guesses to have enough values for testing up to max k groups. 
init_centers=expected_centers ) - _, _, optimal_kmeans = results # for documentation - calculated_centers = optimal_kmeans.cluster_centers_.tolist() if optimal_kmeans else [] + cluster_centers = getattr(clusterer_result, 'cluster_centers_', None) + calculated_centers = cluster_centers.tolist() if cluster_centers is not None else [] + + # Verify init_centers_used_ attribute is available (this was the original bug) + assert hasattr(clusterer_result, 'init_centers_used_') + init_centers_used = getattr(clusterer_result, 'init_centers_used_', None) + assert init_centers_used is not None + assert init_centers_used.shape[0] == len(calculated_centers) assert len(expected_centers) == len(calculated_centers) for i, _ in enumerate(expected_centers):