# Source code for holisticai.bias.mitigation.preprocessing.fairlet_clustering.transformer
from __future__ import annotations
from typing import Optional, Union
import numpy as np
from holisticai.bias.mitigation.commons.fairlet_clustering.decompositions import (
DecompositionMixin,
ScalableFairletDecomposition,
VanillaFairletDecomposition,
)
from holisticai.utils.models.cluster import KCenters, KMedoids
from holisticai.utils.transformers.bias import BMPreprocessing as BMPre
from sklearn.base import BaseEstimator
from sklearn.metrics.pairwise import pairwise_distances_argmin
# Fairlet decomposition strategies, keyed by the name that
# ``FairletClusteringPreprocessing(decomposition=...)`` looks up.
DECOMPOSITION_CATALOG = dict(
    Scalable=ScalableFairletDecomposition,
    Vanilla=VanillaFairletDecomposition,
)

# Clustering models keyed by short name (not referenced in this module's
# visible code; presumably consumed elsewhere in the package).
CLUSTERING_CATALOG = dict(
    KCenter=KCenters,
    KMedoids=KMedoids,
)
class FairletClusteringPreprocessing(BaseEstimator, BMPre):
    """
    Fairlet decomposition [1]_ is a pre-processing approach that computes
    fair micro-clusters (fairlets) in which fairness is guaranteed. The
    fairlet centers are then used as a newly transformed dataset derived
    from the original one. This fairlet-based dataset can be fed to vanilla
    clustering algorithms, which then yield approximately fair clustering
    outputs because the fairlets themselves are fair.

    Parameters
    ----------
    decomposition : str or DecompositionMixin, optional
        Fairlet decomposition strategy. Either a name registered in
        ``DECOMPOSITION_CATALOG`` ("Vanilla" or "Scalable"), which is
        instantiated with ``p`` and ``q``, or an already-built decomposition
        object, which is used as-is. Default is "Vanilla".
    p : int, optional
        Fairlet decomposition parameter for the Vanilla and Scalable
        strategies (ignored when ``decomposition`` is an object). Default is 1.
    q : int, optional
        Fairlet decomposition parameter for the Vanilla and Scalable
        strategies (ignored when ``decomposition`` is an object). Default is 3.
    seed : int, optional
        Seed passed to ``np.random.seed`` at the start of ``fit_transform``.
        Default is None.

    Examples
    --------
    >>> from holisticai.bias.mitigation import FairletClusteringPreprocessing
    >>> mitigator = FairletClusteringPreprocessing()
    >>> train_data_transformed = mitigator.fit_transform(train_data, group_a, group_b)

    References
    ----------
    .. [1] `Backurs, Arturs, et al. "Scalable fair clustering." International Conference on
        Machine Learning. PMLR, 2019.`
    """

    def __init__(
        self,
        decomposition: Union[str, DecompositionMixin] = "Vanilla",
        p: Optional[int] = 1,
        q: Optional[int] = 3,
        seed: Optional[int] = None,
    ):
        if isinstance(decomposition, str):
            # Unknown strategy names still raise KeyError from the lookup.
            decomposition = DECOMPOSITION_CATALOG[decomposition](p=p, q=q)
        # Accepting a pre-built object (as the annotation promises) also lets
        # sklearn's get_params()/clone() round-trip: they pass this attribute
        # value straight back into __init__.
        self.decomposition = decomposition
        self.p = p
        self.q = q
        self.seed = seed

    def fit_transform(
        self,
        X: np.ndarray,
        group_a: np.ndarray,
        group_b: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Decompose the data into fairlets and replace each sample by its fairlet center.

        Also records the fitted state (``X``, ``mapping``, ``centers``,
        ``sample_weight``) used later by ``transform``.

        Parameters
        ----------
        X : matrix-like
            Input matrix.
        group_a : array-like
            Binary mask vector for the first group.
        group_b : array-like
            Binary mask vector for the second group.
        sample_weight : array-like, optional
            Samples weights vector. Default is None. It is not modified in place.

        Returns
        -------
        matrix
            Matrix with the same shape as ``X`` where each sample is replaced by
            the center of its fairlet. Samples not covered by any fairlet (if the
            decomposition can leave any) remain all-zero rows.
        """
        params = self._load_data(X=X, sample_weight=sample_weight, group_a=group_a, group_b=group_b)
        x = params["X"]
        # Work on a float copy. The loop below stores fractions in (0, 1], which
        # would truncate to 0 in an integer array. Writing into the loaded array
        # directly could also modify the caller's weights if ``_load_data``
        # returns them without copying.
        sample_weight = np.array(params["sample_weight"], dtype=float)
        group_a = params["group_a"].astype("int32")
        group_b = params["group_b"].astype("int32")

        # NOTE: seeds NumPy's *global* RNG, presumably consumed by the decomposition.
        np.random.seed(self.seed)
        fairlets, fairlet_centers, _ = self.decomposition.fit_transform(x, group_a, group_b)

        n_samples = len(x)
        xt = np.zeros_like(x)
        mapping = np.zeros(n_samples, dtype="int32")
        # centers[i] is the row of x chosen as the center of fairlet i.
        centers = np.array([x[center] for center in fairlet_centers])
        for i, fairlet in enumerate(fairlets):
            xt[fairlet] = centers[i]
            mapping[fairlet] = i
            # Each member is weighted by the relative size of its fairlet.
            sample_weight[fairlet] = len(fairlet) / n_samples

        # Presumably forwards the new weights to a downstream estimator (BMPre helper).
        self._update_estimator_param("sample_weight", sample_weight)
        self.sample_weight = sample_weight
        self.X = x
        self.mapping = mapping
        self.centers = centers
        return xt

    def transform(self, X):
        """
        Map new samples onto the fairlet centers learned by ``fit_transform``.

        Each row of ``X`` is matched to its nearest training sample, using
        ``pairwise_distances_argmin`` with its default metric. The row is then
        replaced by the center of the fairlet that the training sample belongs to.

        Parameters
        ----------
        X : matrix-like
            Input matrix with the same number of features as the training data.

        Returns
        -------
        matrix
            Transformed matrix, one fairlet center per row of ``X``.

        Raises
        ------
        NotFittedError
            If ``fit_transform`` has not been called yet. It subclasses
            ``AttributeError``, the error previously raised in this case.
        """
        from sklearn.exceptions import NotFittedError

        if not all(hasattr(self, attr) for attr in ("X", "mapping", "centers")):
            raise NotFittedError(
                "This FairletClusteringPreprocessing instance is not fitted yet. "
                "Call 'fit_transform' before using 'transform'."
            )
        nearest_train_idx = pairwise_distances_argmin(X, Y=self.X)
        return self.centers[self.mapping[nearest_train_idx]]