from __future__ import annotations
from typing import Optional
import pandas as pd
from holisticai.bias.mitigation.postprocessing.fair_topk.algorithm_utils.fail_prob import (
RecursiveNumericFailProbabilityCalculator,
)
from holisticai.bias.mitigation.postprocessing.fair_topk.algorithm_utils.valitation_utils import (
check_ranking,
validate_basic_parameters,
)
from holisticai.utils.transformers.bias import BMPostprocessing as BMPost
class FairTopK(BMPost):
"""
Fair Top K bias mitigation [1]_ can be used for Recommender Systems.
The strategy extends the group fairness definition using the standard notion of protected groups
and ensures that the proportion of protected candidates in every prefix of the top-k
ranking does not fall significantly below a minimum target proportion ``p`` (tested at significance level ``alpha``).
Parameters
----------
top_n : int
The number of elements k in the re-ranked top-k list.
p : float
The minimum target proportion of protected candidates in the top-k ranking.
alpha : float
The significance level.
query_col : str
The name of the column in data that contains query ids.
doc_col : str
The name of the column in data that contains document ids.
group_col : str
The name of the column in data that contains the protected attribute.
score_col : str
The name of the column in data that contains the ranking scores (judgment values); higher scores rank first.
Examples
--------
>>> from holisticai.bias.mitigation import FairTopK
>>> mitigator = FairTopK(**params)
>>> new_rankings = mitigator.transform(rankings)
References
----------
.. [1] Zehlike, Meike, et al. "FA*IR: A fair top-k ranking algorithm." Proceedings of the 2017 ACM on \
Conference on Information and Knowledge Management. 2017.
"""
def __init__(
    self,
    top_n: Optional[int],
    p: Optional[float],
    alpha: Optional[float],
    query_col: Optional[str] = "query_id",
    doc_col: Optional[str] = "doc_id",
    group_col: Optional[str] = "group_id",
    score_col: Optional[str] = "score",
):
    """
    Validate the FA*IR hyper-parameters and store them along with the column names.

    Parameters
    ----------
    top_n : int
        Number of elements in the re-ranked top-k list.
    p : float
        Target proportion of protected candidates.
    alpha : float
        Significance level.
    query_col, doc_col, group_col, score_col : str
        Names of the columns holding query ids, document ids, the protected
        attribute and the scores.
    """
    # Validate before storing anything, so a rejected configuration leaves no
    # partially initialised instance behind. validate_basic_parameters is not
    # shown here; presumably it raises on invalid values (TODO confirm).
    validate_basic_parameters(top_n, p, alpha)

    # FA*IR hyper-parameters.
    self.top_n = top_n
    self.p = p
    self.alpha = alpha

    # Column names used to read the ranking data.
    self.query_col = query_col
    self.doc_col = doc_col
    self.group_col = group_col
    self.score_col = score_col

    # Memoised mtables, keyed by (top_n, p, alpha).
    self._cache = {}
def _create_adjusted_mtable(self):
    """
    Return the alpha-adjusted mtable for the current ``(top_n, p, alpha)``.

    On first use the table is computed with
    ``RecursiveNumericFailProbabilityCalculator(top_n, p, alpha).adjust_alpha()``.
    It is then stored in ``self._cache`` under the key ``(top_n, p, alpha)``,
    so later calls with the same parameters return the stored list.

    Returns
    -------
    list of int
        The ``m`` column of the adjusted mtable, converted to Python ints.
    """
    cache_key = (self.top_n, self.p, self.alpha)
    if cache_key not in self._cache:
        calculator = RecursiveNumericFailProbabilityCalculator(self.top_n, self.p, self.alpha)
        adjusted = calculator.adjust_alpha()
        # Store plain ints rather than the values held in the mtable's "m" column.
        self._cache[cache_key] = [int(value) for value in adjusted.mtable.m.tolist()]
    return self._cache[cache_key]
def is_fair(self, ranking):
    """
    Check whether a ranking satisfies the adjusted mtable for the current parameters.

    Parameters
    ----------
    ranking : pd.DataFrame
        The ranking to be checked. Only its ``group_col`` column is read; it is
        passed to ``check_ranking`` together with the adjusted mtable.
        NOTE(review): presumably rows must be in rank order (best first) and the
        ranking length must match the mtable, but that contract belongs to
        ``check_ranking``, which is not visible here. Confirm.

    Returns
    -------
    bool
        True if the ranking is fair, False otherwise.
    """
    # The mtable is memoised per (top_n, p, alpha), so repeated checks are cheap.
    return check_ranking(ranking[self.group_col], self._create_adjusted_mtable())
def _fair_top_k(self, protected_candidates, non_protected_candidates, mtable):
    """
    Merge protected and non-protected candidates into a FA*IR re-ranking.

    Each pool is consumed in its existing row order. NOTE(review): presumably
    both frames are already sorted by ``score_col`` descending; confirm with
    the caller. For each output position ``rank`` (0-based):

    * if one pool is exhausted, the next row of the other pool is taken;
    * otherwise, if fewer than ``mtable[rank]`` protected rows have been
      placed so far, the next protected row is forced in;
    * otherwise the row with the higher ``score_col`` wins, and ties go to the
      protected row.

    Parameters
    ----------
    protected_candidates : pd.DataFrame
        Ranking rows filtered to the protected candidates only.
    non_protected_candidates : pd.DataFrame
        Ranking rows filtered to the non-protected candidates only.
    mtable : list
        Adjusted mtable, indexed by output position.

    Returns
    -------
    list
        Up to ``top_n`` rows (as produced by ``DataFrame.iloc``). The list is
        shorter when both pools run out first.
    """
    n_protected = len(protected_candidates)
    n_non_protected = len(non_protected_candidates)
    reranked = []
    # The protected cursor doubles as the count of protected rows placed:
    # it advances exactly when a protected row is appended.
    p_pos = 0
    np_pos = 0

    for rank in range(self.top_n):
        protected_left = p_pos < n_protected
        non_protected_left = np_pos < n_non_protected
        if not (protected_left or non_protected_left):
            # Both pools are exhausted: stop early with a list shorter than top_n.
            break

        if not protected_left:
            take_protected = False
        elif not non_protected_left:
            take_protected = True
        elif p_pos < mtable[rank]:
            # The fairness constraint demands another protected row at this prefix.
            take_protected = True
        else:
            # Otherwise rank on merit; ">=" breaks ties in favour of the protected row.
            take_protected = (
                protected_candidates.iloc[p_pos][self.score_col]
                >= non_protected_candidates.iloc[np_pos][self.score_col]
            )

        if take_protected:
            reranked.append(protected_candidates.iloc[p_pos])
            p_pos += 1
        else:
            reranked.append(non_protected_candidates.iloc[np_pos])
            np_pos += 1

    return reranked