# Source code for holisticai.bias.plots._multiclass

# Base Imports
import numpy as np
import pandas as pd
import seaborn as sns

# Import metrics
from holisticai.bias.metrics import frequency_matrix
from holisticai.utils import get_colors

# utils
from holisticai.utils._validation import _multiclass_checks
from matplotlib import pyplot as plt


def frequency_plot(p_attr, y_pred, ax=None, size=None, title=None):
    """
    Frequency plot.

    Description
    -----------
    This function plots how frequently members of each group fall into
    each class. A "total" bar is drawn next to the per-group bars.

    With more than two classes, one new figure is created per class and
    None is returned; `ax` and `title` are not used in that case.
    Otherwise a single bar plot is drawn for the last class column of the
    frequency matrix (the second class of a binary problem, or the only
    class when just one is present).

    Parameters
    ----------
    p_attr : array-like
        Protected attribute vector
    y_pred : array-like
        Prediction vector (categorical)
    ax (optional) : matplotlib axes
        Pre-existing axes for the plot (single-plot case only)
    size (optional) : (int, int)
        Size of the figure(s) created by this function
    title (optional) : str
        Title of the figure (single-plot case only)

    Returns
    -------
    matplotlib ax (or None)
    """
    # check and coerce inputs
    p_attr, y_pred, _, _, _ = _multiclass_checks(
        p_attr=p_attr,
        y_pred=y_pred,
        y_true=None,
        groups=None,
        classes=None,
    )

    # class frequencies per group (one row per group, one column per class)
    sr_list = frequency_matrix(p_attr, y_pred, normalize="group")

    # append a "total" row: the column sums rescaled so that the row sums to 1
    # NOTE(review): computed from the group-normalised matrix, so each group
    # presumably counts equally regardless of its size -- confirm this is intended.
    column_sums = sr_list.sum(axis=0)
    sr_tot = column_sums / column_sums.sum()
    sr_tot.name = "total"
    sr_list = pd.concat([sr_list, pd.DataFrame(sr_tot).transpose()], axis=0)

    name_classes = sr_list.columns
    group_labels = sr_list.index.to_list()

    # charting
    sns.set_theme()
    colors = get_colors(sr_list.shape[0])
    hai_palette = sns.color_palette(colors)

    num_classes_threshold = 2
    if len(name_classes) > num_classes_threshold:
        # multiclass: one standalone figure per class
        for class_name in name_classes:
            fig, class_ax = plt.subplots(figsize=size)
            fig.suptitle("Class " + str(class_name))
            sns.barplot(
                x=group_labels,
                y=sr_list[class_name],
                palette=hai_palette,
                ax=class_ax,
            )
            class_ax.set_xlabel("Group")
            class_ax.set_ylabel("Frequency ")
            plt.setp(class_ax.get_xticklabels(), rotation=45)
        return None

    # Index from the end so a single-class input plots that class instead of
    # raising IndexError; with two classes this is still the second column.
    shown_class = name_classes[-1]

    if ax is None:
        fig, ax = plt.subplots(figsize=size)
    sns.barplot(
        x=group_labels,
        y=sr_list[shown_class],
        palette=hai_palette,
        ax=ax,
    )
    ax.set_xlabel("Group")
    ax.set_ylabel("Frequency")
    # Rotate the labels of the axes we drew on; plt.xticks() would act on
    # pyplot's current axes, which need not be a caller-supplied `ax`.
    plt.setp(ax.get_xticklabels(), rotation=45)
    if title is not None:
        ax.set_title(title)
    else:
        ax.set_title(f"Frequency Plot (Class {shown_class})")
    return ax
def statistical_parity_plot(p_attr, y_pred, pos_label=1, compare_to=None, ax=None, size=None, title=None):
    """
    Statistical Parity Plot (Binary Classification).

    Description
    -----------
    This function plots the statistical parity for each group along
    with acceptable bounds. Unless `compare_to` is given, we take the
    group with maximum success rate as the comparison group. Bars are
    ordered by decreasing success rate, and the band between -0.1 and
    0.1 is shaded as the fair area.

    Parameters
    ----------
    p_attr : array-like
        Protected attribute vector
    y_pred : array-like
        Prediction vector (binary)
    pos_label (optional) : label, default=1
        The positive label
    compare_to (optional) : str or int
        The group we are comparing to
    ax (optional) : matplotlib axes
        Pre-existing axes for the plot
    size (optional) : (int, int)
        Size of the figure
    title (optional) : str
        Title of the figure

    Returns
    -------
    matplotlib ax

    Raises
    ------
    KeyError
        If `compare_to` is not one of the groups found in `p_attr`.
    """
    # check and coerce inputs
    p_attr, y_pred, _, groups, _ = _multiclass_checks(
        p_attr=p_attr,
        y_pred=y_pred,
        y_true=None,
        groups=None,
        classes=None,
    )
    # position of each group inside `groups` (and therefore inside sr_list)
    group_dict = {group: position for position, group in enumerate(groups)}

    # get success rates (y_pred * 1 turns boolean predictions into integers)
    sr_list = frequency_matrix(p_attr, y_pred * 1, groups=groups, normalize="group")[pos_label].to_numpy()

    # sort by success rate, highest first
    sr_list_sorted, groups_sorted = zip(*sorted(zip(sr_list, groups), reverse=True))
    # explicit array so the subtraction below does not depend on numpy
    # coercing a plain list through the scalar's reflected operator
    sr_sorted = np.asarray(sr_list_sorted)
    groups_sorted = list(groups_sorted)

    # statistical parity list: difference to the reference success rate
    if compare_to is not None:
        reference_rate = sr_list[group_dict[compare_to]]
    else:
        reference_rate = np.max(sr_list)
    sp_list = sr_sorted - reference_rate

    # setup
    sns.set_theme()
    if ax is None:
        fig, ax = plt.subplots(figsize=size)

    # charting
    colors = get_colors(len(groups))
    hai_palette = sns.color_palette(colors)
    ax.set_xlabel("Group")
    ax.set_ylabel("Statistical Parity")
    sns.barplot(x=groups_sorted, y=sp_list, palette=hai_palette, ax=ax)

    # horizontal lines
    ax.axhline(y=-0.1, color="grey", linestyle="--", label="lower bound")
    ax.axhline(y=0.1, color="grey", linestyle="--", label="upper bound")
    ax.axhspan(-0.1, 0.1, alpha=0.3, color="grey", zorder=0, label="fair area")

    # tilt labels of the axes we drew on; plt.xticks() would act on pyplot's
    # current axes, which need not be a caller-supplied `ax`
    plt.setp(ax.get_xticklabels(), rotation=45)

    ax.legend()
    if title is not None:
        ax.set_title(title)
    else:
        ax.set_title("Statistical Parity plot")
    return ax
def disparate_impact_plot(p_attr, y_pred, pos_label=1, compare_to=None, ax=None, size=None, title=None):
    """
    Disparate Impact Plot (Binary Classification).

    Description
    -----------
    This function plots the disparate impact for each group along
    with acceptable bounds. Unless `compare_to` is given, we take the
    group with maximum success rate as the 'majority group'. Bars are
    ordered by decreasing success rate, and the band between 0.8 and
    1.2 is shaded as the fair area.

    Parameters
    ----------
    p_attr : array-like
        Protected attribute vector
    y_pred : array-like
        Prediction vector
    pos_label : label, default=1
        The positive label
    compare_to (optional) : str or int
        The group we are comparing to
    ax (optional) : matplotlib axes
        Pre-existing axes for the plot
    size (optional) : (int, int)
        Size of the figure
    title (optional) : str
        Title of the figure

    Returns
    -------
    matplotlib ax

    Raises
    ------
    KeyError
        If `compare_to` is not one of the groups found in `p_attr`.
    """
    # check and coerce inputs
    p_attr, y_pred, _, groups, _ = _multiclass_checks(
        p_attr=p_attr,
        y_pred=y_pred,
        y_true=None,
        groups=None,
        classes=None,
    )
    # position of each group inside `groups` (and therefore inside sr_list)
    group_dict = {group: position for position, group in enumerate(groups)}

    # get success rates; the normalisation is spelled out so this call matches
    # the one in statistical_parity_plot (a ratio of rates needs per-group rates)
    sr_list = frequency_matrix(p_attr, 1 * y_pred, groups=groups, normalize="group")[pos_label].to_numpy()

    # sort by success rate, highest first
    sr_list_sorted, groups_sorted = zip(*sorted(zip(sr_list, groups), reverse=True))
    # explicit array so the division below does not depend on numpy
    # coercing a plain list through the scalar's reflected operator
    sr_sorted = np.asarray(sr_list_sorted)
    groups_sorted = list(groups_sorted)

    # disparate impact list: ratio to the reference success rate
    if compare_to is not None:
        reference_rate = sr_list[group_dict[compare_to]]
    else:
        reference_rate = sr_sorted[0]
    di_list = sr_sorted / reference_rate

    # setup
    sns.set_theme()
    if ax is None:
        fig, ax = plt.subplots(figsize=size)

    # charting
    colors = get_colors(len(groups))
    hai_palette = sns.color_palette(colors)
    ax.set_xlabel("Group")
    ax.set_ylabel("Disparate Impact")
    sns.barplot(x=groups_sorted, y=di_list, palette=hai_palette, ax=ax)

    # horizontal lines
    ax.axhspan(0.8, 1.2, alpha=0.3, color="grey", label="fair area")
    ax.axhline(y=1.2, color="grey", linestyle="--", label="upper bound")
    ax.axhline(y=0.8, color="grey", linestyle="--", label="lower bound")

    # tilt labels of the axes we drew on; plt.xticks() would act on pyplot's
    # current axes, which need not be a caller-supplied `ax`
    plt.setp(ax.get_xticklabels(), rotation=45)

    # legend
    ax.legend()
    if title is not None:
        ax.set_title(title)
    else:
        ax.set_title("Disparate Impact plot")

    # return
    return ax
def frequency_matrix_plot(
    p_attr,
    y_pred,
    groups=None,
    classes=None,
    normalize=None,
    reverse_colors=False,
    ax=None,
    size=None,
    title=None,
):
    """
    Frequency Matrix Plot.

    Description
    -----------
    This function draws a heatmap of the occurrence rate (count) for
    each group, class pair. The matrix can optionally be normalised
    over groups or over classes, in which case the cells are annotated
    as percentages.

    Parameters
    ----------
    p_attr : array-like
        Protected attribute vector
    y_pred : array-like
        Prediction vector (categorical)
    groups (optional) : array or list
        The groups in order
    classes (optional) : array or list
        The classes in order
    normalize (optional): None, 'group' or 'class'
        According to which of group or class we normalize
    reverse_colors (optional): bool, default=False
        Option to reverse the color palette
    ax (optional) : matplotlib axes
        Pre-existing axes for the plot
    size (optional) : (int, int)
        Size of the figure
    title (optional) : str
        Title of the figure

    Returns
    -------
    matplotlib ax
    """
    # validate the inputs and resolve the group / class orderings
    p_attr, y_pred, _, groups, classes = _multiclass_checks(
        p_attr=p_attr,
        y_pred=y_pred,
        y_true=None,
        groups=groups,
        classes=classes,
    )

    # colour map and the matrix to display
    cmap = sns.color_palette(get_colors(10, extended_colors=True, reverse=reverse_colors))
    matrix = frequency_matrix(p_attr, y_pred, groups=groups, classes=classes, normalize=normalize)

    # theme first, then create a figure only when no axes were handed in
    sns.set_theme()
    if ax is None:
        _, ax = plt.subplots(figsize=size)

    # raw counts keep seaborn's default annotation format; normalised
    # values are rendered as percentages
    heatmap_options = {"annot": True, "cmap": cmap, "ax": ax}
    if normalize is not None:
        heatmap_options["fmt"] = ".2%"
    sns.heatmap(matrix, **heatmap_options)

    ax.set_xlabel("Class")
    ax.set_ylabel("Group")
    ax.set_title("Frequency matrix plot" if title is None else title)
    return ax
def accuracy_bar_plot(p_attr, y_pred, y_true, ax=None, size=None, title=None):
    """
    Accuracy Bar Plot.

    Description
    -----------
    This function draws one bar per group showing the share of
    predictions that match the target, followed by a "Total" bar for
    the accuracy over all samples.

    Parameters
    ----------
    p_attr : array-like
        Protected attribute vector
    y_pred : array-like
        Prediction vector
    y_true : array-like
        Target vector
    ax (optional) : matplotlib axes
        Pre-existing axes for the plot
    size (optional) : (int, int)
        Size of the figure
    title (optional) : str
        Title of the figure

    Returns
    -------
    matplotlib ax
    """
    # validate the inputs and collect the group labels
    p_attr, y_pred, y_true, groups, _ = _multiclass_checks(
        p_attr=p_attr,
        y_pred=y_pred,
        y_true=y_true,
        groups=None,
        classes=None,
    )

    def _accuracy(pred, true):
        # fraction of positions where prediction and target agree
        return (pred == true).sum() / len(pred)

    # per-group accuracies, then the overall accuracy as the final bar
    group_masks = (p_attr == group for group in groups)
    scores = [_accuracy(y_pred[mask], y_true[mask]) for mask in group_masks]
    scores.append(_accuracy(y_pred, y_true))
    bar_labels = [*groups, "Total"]

    # theme first, then create a figure only when no axes were handed in
    sns.set_theme()
    if ax is None:
        _, ax = plt.subplots(figsize=size)

    # one colour per bar, including the "Total" bar
    palette = sns.color_palette(get_colors(len(bar_labels)))
    ax.set_xlabel("Group")
    ax.set_ylabel("Accuracy")
    sns.barplot(x=bar_labels, y=scores, palette=palette, ax=ax)

    ax.set_title("Accuracy Bar Plot" if title is None else title)
    return ax