# Source code for holisticai.bias.mitigation.inprocessing.fairlet_clustering.transformer

from __future__ import annotations

from typing import Optional, Union

import numpy as np
from holisticai.bias.mitigation.commons.fairlet_clustering.decompositions import (
    DecompositionMixin,
    ScalableFairletDecomposition,
    VanillaFairletDecomposition,
)
from holisticai.bias.mitigation.inprocessing.fairlet_clustering.algorithm import (
    FairletClusteringAlgorithm,
)
from holisticai.utils.models.cluster import KCenters, KMedoids
from holisticai.utils.transformers.bias import BMInprocessing as BMImp
from sklearn.base import BaseEstimator

# Fairlet decomposition strategies, keyed by the name that
# ``FairletClustering(decomposition=...)`` looks up in this mapping.
DECOMPOSITION_CATALOG = dict(
    Scalable=ScalableFairletDecomposition,
    Vanilla=VanillaFairletDecomposition,
)

# Clustering models used to merge fairlets, keyed by the name that
# ``FairletClustering(clustering_model=...)`` looks up in this mapping.
CLUSTERING_CATALOG = dict(
    KCenters=KCenters,
    KMedoids=KMedoids,
)


class FairletClustering(BaseEstimator, BMImp):
    """Fairlet Clustering [1]_ in-processing bias mitigation.

    The method works in two steps:

    (1) The point set is partitioned into subsets called fairlets that satisfy
        the fairness requirement and approximately preserve the k-median
        objective.
    (2) Fairlets are merged into k clusters by one of the existing k-median
        algorithms.

    Parameters
    ----------
    n_clusters : int
        The number of clusters to form as well as the number of centroids to
        generate. Only used when ``clustering_model`` is given by name; an
        already-constructed model instance is used as-is.
    decomposition : str or DecompositionMixin, default="Vanilla"
        Fairlet decomposition strategy. Either a name from
        ``DECOMPOSITION_CATALOG`` ("Vanilla", "Scalable"), which is built with
        ``p`` and ``q``, or an already-constructed decomposition object, which
        is used as-is.
    clustering_model : str or model instance, default="KCenters"
        Model used to merge fairlets into clusters. Either a name from
        ``CLUSTERING_CATALOG`` ("KCenters", "KMedoids"), which is built with
        ``n_clusters``, or an already-constructed model instance, which is used
        as-is.
    p : int, default=1
        Fairlet decomposition parameter for the Vanilla and Scalable strategies.
    q : int, default=3
        Fairlet decomposition parameter for the Vanilla and Scalable strategies.
    seed : int, optional
        Seed passed to ``np.random.seed`` at the start of ``fit``.

    Raises
    ------
    ValueError
        If ``decomposition`` or ``clustering_model`` is an unknown name or None.

    Examples
    --------
    >>> from holisticai.bias.mitigation import FairletClustering
    >>> mitigator = FairletClustering(**params)
    >>> mitigator.fit(train_data, group_a, group_b)
    >>> train_data_transformed = mitigator.predict(train_data)

    References
    ----------
    .. [1] Backurs, Arturs, et al. "Scalable fair clustering." International
       Conference on Machine Learning. PMLR, 2019.
    """

    def __init__(
        self,
        n_clusters: Optional[int],
        decomposition: Union[str, DecompositionMixin] = "Vanilla",
        clustering_model: Optional[str] = "KCenters",
        p: Optional[int] = 1,
        q: Optional[float] = 3,
        seed: Optional[int] = None,
    ):
        # Previously a non-name ``decomposition`` (e.g. a DecompositionMixin
        # instance, which the annotation allows) left ``self.decomposition``
        # unset and crashed below with AttributeError. Ready-made instances are
        # now passed through. This also lets the built objects returned by
        # ``get_params`` (which reads ``self.decomposition`` and
        # ``self.clustering_model``) be fed back into ``__init__``.
        self.decomposition = self._build_component(
            decomposition, DECOMPOSITION_CATALOG, "decomposition", p=p, q=q
        )
        self.clustering_model = self._build_component(
            clustering_model, CLUSTERING_CATALOG, "clustering_model", n_clusters=n_clusters
        )
        # Constant parameters
        self.algorithm = FairletClusteringAlgorithm(
            decomposition=self.decomposition, clustering_model=self.clustering_model
        )
        self.p = p
        self.q = q
        self.n_clusters = n_clusters
        self.seed = seed

    @staticmethod
    def _build_component(spec, catalog, kind, **kwargs):
        """Instantiate ``catalog[spec](**kwargs)`` for a name, or return a ready-made instance.

        Parameters
        ----------
        spec : str or object
            Catalog key, or an already-constructed component.
        catalog : dict
            Mapping from name to component class.
        kind : str
            Parameter name, used only in error messages.
        **kwargs
            Constructor arguments used when ``spec`` is a name.

        Raises
        ------
        ValueError
            If ``spec`` is None or a name missing from ``catalog``.
        """
        if isinstance(spec, str):
            try:
                factory = catalog[spec]
            except KeyError:
                raise ValueError(
                    f"Unknown {kind} {spec!r}; expected one of {sorted(catalog)} "
                    "or an already-constructed instance."
                ) from None
            return factory(**kwargs)
        if spec is None:
            raise ValueError(
                f"{kind} must not be None; expected one of {sorted(catalog)} "
                "or an already-constructed instance."
            )
        return spec

    def fit(
        self,
        X: np.ndarray,
        group_a: np.ndarray,
        group_b: np.ndarray,
    ):
        """Fit the model.

        Learn a fair clustering. The inputs are normalised through
        ``_load_data``, the group masks are cast to ``int32``, NumPy's global
        RNG is seeded with ``self.seed``, and fitting is delegated to
        ``self.algorithm``.

        Parameters
        ----------
        X : numpy array
            Input matrix.
        group_a : numpy array
            Binary mask vector.
        group_b : numpy array
            Binary mask vector.

        Returns
        -------
        self
        """
        params = self._load_data(X=X, group_a=group_a, group_b=group_b)
        X = params["X"]
        group_a = params["group_a"].astype("int32")
        group_b = params["group_b"].astype("int32")
        # NOTE: seeds NumPy's *global* RNG (a process-wide side effect);
        # presumably the decomposition / clustering models draw from it.
        # TODO confirm.
        np.random.seed(self.seed)
        self.algorithm.fit(X, group_a=group_a, group_b=group_b)
        return self

    @property
    def cluster_centers_(self):
        """Cluster centers exposed by the underlying algorithm (set by ``fit``)."""
        return self.algorithm.cluster_centers_

    @property
    def labels_(self):
        """Per-sample cluster labels, read from ``self.algorithm.labels`` (set by ``fit``)."""
        return self.algorithm.labels

    def predict(self, X: np.ndarray):
        """Predict the cluster for the given samples.

        Parameters
        ----------
        X : pandas.DataFrame or numpy array
            Test samples.

        Returns
        -------
        numpy.ndarray
            Predicted output per sample.
        """
        params = self._load_data(X=X)
        X = params["X"]
        return self.algorithm.predict(X)

    def fit_predict(self, X: np.ndarray, group_a: np.ndarray, group_b: np.ndarray):
        """Fit the model and return the cluster of each training sample.

        Parameters
        ----------
        X : pandas.DataFrame or numpy array
            Training samples.
        group_a : numpy array
            Binary mask vector.
        group_b : numpy array
            Binary mask vector.

        Returns
        -------
        numpy.ndarray
            Predicted cluster per sample (``labels_`` after fitting).
        """
        self.fit(X, group_a, group_b)
        return self.labels_