# Source code for holisticai.explainability.plots._partial_dependencies

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from holisticai.explainability.metrics.global_feature_importance._fluctuation_ratio import fluctuation_ratio
from holisticai.explainability.metrics.global_feature_importance._xai_ease_score import XAIEaseAnnotator
from holisticai.utils import Importances, PartialDependence
from matplotlib import cm
from scipy.interpolate import interp1d


def plot_partial_dependence(
    partial_dependence: PartialDependence,
    ranked_feature_importance: Importances,
    subplots=(1, 1),
    figsize=None,
    class_idx=0,
):
    """
    Plots the partial dependence of features on the predicted target.

    One subplot is drawn per feature, in feature-index order, until either the
    subplot grid or the list of partial dependence entries is exhausted. Each
    subplot shows the average curve (labelled with the feature's XAI ease
    score) and a band of one standard deviation around the mean of the
    individual curves. Axes left over in the grid stay empty.

    Parameters
    ----------
    partial_dependence: PartialDependence
        The partial dependence object containing the computed partial dependence values.
    ranked_feature_importance: Importances
        The ranked feature importance object containing the feature names and their importance scores.
    subplots: (tuple, optional)
        The shape of the subplots grid. Defaults to (1, 1).
    figsize: (tuple, optional)
        The size of the figure. Defaults to None.
    class_idx: (int, optional)
        Index into ``partial_dependence.values`` selecting which set of
        partial dependence entries to plot. Defaults to 0.

    Returns
    -------
    fig: The matplotlib figure object containing the plot.

    Example
    -------
    >>> partial_dependence = PartialDependence(values=[...])
    >>> ranked_feature_importance = Importances(values=[...], feature_names=[...])
    >>> plot_partial_dependence(partial_dependence, ranked_feature_importance)

    The plot should look like this:

    .. image:: /_static/images/xai_plot_partial_dependence.png
        :alt: Plot Partial Dependence
    """
    partial_dependence_values = partial_dependence.values[class_idx]

    fig, axs = plt.subplots(*subplots, figsize=figsize)
    # plt.subplots returns a bare Axes for a 1x1 grid and an array otherwise.
    axs = [axs] if isinstance(axs, plt.Axes) else axs.flatten()
    n_plots = min(len(axs), len(partial_dependence_values))

    annotator = XAIEaseAnnotator()
    # The score table does not depend on the feature being drawn, so build it
    # once instead of once per subplot. It is skipped when nothing is plotted,
    # so the annotator is never invoked in that case.
    level = None
    if n_plots > 0:
        level = annotator.compute_xai_ease_score_data(
            partial_dependence_values, ranked_feature_importance
        ).set_index("feature")["scores"]

    for feature_index in range(n_plots):
        ax = axs[feature_index]
        feature_pd = partial_dependence_values[feature_index]
        individuals = feature_pd["individual"][0]
        average = feature_pd["average"][0]
        x = feature_pd["grid_values"][0]

        feature_name = ranked_feature_importance.feature_names[feature_index]
        feature_value = ranked_feature_importance[feature_index]

        # Band of one standard deviation around the mean individual curve.
        curve_mean = np.mean(individuals, axis=0)
        curve_std = np.std(individuals, axis=0)
        curve_lower = curve_mean - curve_std
        curve_upper = curve_mean + curve_std

        ax.plot(x, average, color="blue", label=level.loc[feature_name])
        ax.fill_between(x, curve_lower, curve_upper, color="skyblue", alpha=0.2)

        # Axis limits span the full range of the individual curves and the grid.
        ax.set_ylim(individuals.min(), individuals.max())
        ax.set_xlim(x.min(), x.max())

        ax.set_xlabel("Feature Value")
        ax.set_ylabel("Predicted Target")
        ax.set_title(f"{feature_name} ({feature_value:.3f})")
        ax.grid(True)
        ax.legend()

    plt.tight_layout()
    return fig
def get_oscillations_from_individuals(grid_values, individuals, n_points=20):
    """
    Compute a normalized oscillation index for each individual curve.

    Every curve is linearly interpolated onto ``n_points`` evenly spaced
    positions between the smallest and largest grid value. The index is the
    number of sign changes of the first difference of the resampled curve,
    divided by ``n_points``.

    Parameters
    ----------
    grid_values: array-like
        The x positions shared by all curves.
    individuals: iterable of array-like
        The curves, each with one value per grid position.
    n_points: (int, optional)
        Number of resampling positions. Defaults to 20.

    Returns
    -------
    list
        One oscillation index per curve, in input order. Empty when
        ``individuals`` is empty.
    """
    curves = list(individuals)
    if not curves:
        return []

    grid = np.asarray(grid_values)
    # The resampling positions are the same for every curve; build them once.
    resampled_x = np.linspace(grid.min(), grid.max(), n_points)

    oscillation_indexes = []
    for curve in curves:
        resampled_y = interp1d(grid, curve, kind="linear")(resampled_x)
        slope = np.diff(resampled_y)
        sign_changes = np.sum(np.diff(np.sign(slope)) != 0)
        oscillation_indexes.append(sign_changes / len(resampled_y))
    return oscillation_indexes


def plot_explainable_partial_dependence(
    partial_dependencies, importances, figsize, feature_names=None, model_name=None, label=0, top_n=10
):
    """
    Plot partial dependence curves with oscillation and importance cues.

    The selected features are drawn side by side in a single row. Each subplot
    shows up to 15 of the most oscillating individual curves, the average
    curve, and a band of one standard deviation around the average. The
    subplot background is red with an opacity equal to the feature importance
    (clipped to [0, 1]), and the title reports the feature's fluctuation ratio.

    Parameters
    ----------
    partial_dependencies:
        Object exposing ``feature_names`` and ``get_value(feature_name=...,
        label=..., data_type=...)`` for "individual", "grid_values" and "average".
    importances:
        Object exposing ``as_dataframe()`` with "Variable" and "Importance" columns.
    figsize: tuple or None
        The size of the figure. ``None`` falls back to (20, 5).
    feature_names: (list, optional)
        Features to plot. Defaults to ``partial_dependencies.feature_names``.
    model_name: (str, optional)
        Used as the figure title when given.
    label: (int, optional)
        Passed through to ``partial_dependencies.get_value``. Defaults to 0.
    top_n: (int, optional)
        Maximum number of features to plot, which is also the number of
        subplot columns. ``None`` means at most 5. Defaults to 10.
    """
    max_columns = 5  # column cap used when top_n is not given

    if feature_names is None:
        feature_names = partial_dependencies.feature_names

    # Resolve the column count before slicing, so the subplot grid always has
    # room for every feature that is actually plotted.
    n_columns = top_n if top_n is not None else min(max_columns, len(feature_names))
    feature_names = feature_names[:n_columns]

    fr_df = fluctuation_ratio(partial_dependencies, importances, top_n=top_n, aggregated=False)
    df = importances.as_dataframe().set_index("Variable")
    df = (
        pd.concat([df, fr_df], axis=1)
        .dropna()
        .sort_values("Importance", ascending=False)
        .reset_index()
        .rename({"index": "Variable"}, axis=1)
    )

    if figsize is None:
        figsize = (20, 5)

    fig = plt.figure(figsize=figsize)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    color_map = plt.get_cmap("Blues")

    if model_name is not None:
        fig.suptitle(model_name)

    df = df.set_index("Variable")  # look rows up by feature name

    for i, feature_name in enumerate(feature_names):
        individuals = partial_dependencies.get_value(feature_name=feature_name, label=label, data_type="individual")
        grid_values = partial_dependencies.get_value(feature_name=feature_name, label=label, data_type="grid_values")
        average = partial_dependencies.get_value(feature_name=feature_name, label=label, data_type="average")

        individuals_matrix = np.array(individuals)

        # Keep only the 15 curves with the highest oscillation index.
        oscillations = get_oscillations_from_individuals(grid_values, individuals)
        indexes = np.argsort(oscillations)[::-1][:15]
        num_curves = len(indexes)

        # One row of subplots, one column per plotted feature.
        ax = fig.add_subplot(1, n_columns, i + 1)
        for idx, index in enumerate(indexes):
            ax.plot(grid_values, individuals[index], color=color_map(idx / num_curves), alpha=0.2)

        std_dev = np.std(individuals_matrix, axis=0)

        # Importance drives the background opacity. An RGBA alpha must lie in
        # [0, 1], so clip it to avoid a ValueError on out-of-range importances.
        imp = df.loc[feature_name, "Importance"]
        ax.set_facecolor((1, 0, 0, float(np.clip(imp, 0.0, 1.0))))

        ax.plot(grid_values, average, color="blue", linewidth=2, label="Average")
        ax.fill_between(grid_values, average - std_dev, average + std_dev, color="blue", alpha=0.1, label="Std Dev")

        short_feature_name = feature_name.split("_")[-1]  # text after the last underscore
        ax.set_xlabel("Grid Values")
        ax.set_ylabel("Partial Dependence")
        ax.set_title(f"[feature={short_feature_name}, F={df.loc[feature_name,'Fluctuation Ratio']:.3f}]", fontsize=12)
        ax.legend()
        ax.xaxis.set_major_locator(plt.MaxNLocator(5))  # at most 5 x-axis intervals
        ax.grid(True)

    plt.tight_layout()