# Source code for holisticai.robustness.plots._dataset_shift

"""
Module description:
-------------------
- This module contains functions to plot the dataset shift analysis results.

This module provides a collection of plot functions for 2D datasets, focusing
on displaying model predictions, neighborhood analysis, and accuracy
degradation profiles. The functions are designed to help users understand how
models behave under different scenarios, such as changes in data, prediction
performance, and neighborhood-based analysis. Through intuitive scatter plots
and additional visual aids, users can assess model predictions, compare actual
vs. predicted values, and analyze model robustness under varying conditions.

Functions included:
-------------------
- plot_2d: Generates a 2D scatter plot for a dataset, with options to highlight
  or exclusively display a subset of points. It supports both general
  visualizations and focused views of specific data points.

- plot_label_and_prediction: Creates a scatter plot that shows both actual
  labels and predicted labels for a dataset, with a vertical offset to visually
  differentiate between the two. Ideal for visual comparison of prediction
  accuracy.

- plot_neighborhood: Visualizes the neighborhoods around specified points of
  interest by plotting a convex hull, the nearest neighbors, and the accuracy
  within those neighbors. It helps users understand how local groups of data
  points contribute to overall model performance.

- plot_adp_and_adf: Plots the accuracy degradation profile (ADP) by showing the
  percentage of samples above a threshold versus the size factor of the
  dataset. It highlights key points of degradation with color-coding and
  vertical markers, providing insights into the model's robustness as data
  availability decreases. Accuracy degradation factor (ADF) is also shown
  as a circle at the first degradation point.

This module offers a set of functions to explore the behavior and performance
of machine learning models visually, facilitating a better understanding of
model predictions and robustness across different data scenarios.
"""

# Importing required libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def _validate_and_extract_data(data):
    """
    Validate the input container and return its contents as a NumPy array.

    Parameters:
    ----------
    data : pd.DataFrame or pd.Series or np.ndarray
        The data to validate. A DataFrame must have exactly two columns.
        DataFrames and Series are converted to NumPy arrays; a NumPy array is
        returned as-is (same object, no copy, no shape check).

    Returns:
    -------
    data_vals : np.ndarray
        The data as a NumPy array.

    Raises:
    -------
    ValueError
        If `data` is a pandas DataFrame and does not have exactly two columns.
    TypeError
        If `data` is not a pandas DataFrame, a pandas Series or a NumPy array.
    """
    if isinstance(data, pd.DataFrame):
        # Only DataFrames are shape-checked; the plots index columns 0 and 1
        if data.shape[1] != 2:
            raise ValueError("Data should have exactly two columns.")
        return data.to_numpy()

    if isinstance(data, pd.Series):
        return data.to_numpy()

    if isinstance(data, np.ndarray):
        return data

    # The message lists every accepted type, including Series
    raise TypeError("Data must be a pandas DataFrame, a pandas Series or a NumPy array.")


def plot_2d(X, y, highlight_group=None, show_just_group=None, features_to_plot=None):
    """
    Plots a 2D scatter plot of a dataset, with options to highlight or
    exclusively show a subset of points.

    This function generates a 2D scatter plot from the given dataset, where
    each point represents a sample and is colored by its label. Users can
    highlight specific points or display only a subset of the data. Axis
    labels can be customized using feature names from the `features_to_plot`
    argument.

    Parameters
    ----------
    X : np.ndarray or pd.DataFrame
        The feature matrix, where each row is a sample and each column
        represents a feature. Columns 0 and 1 are plotted on the x and y
        axes. A pandas DataFrame must have exactly two columns.

    y : np.ndarray or pd.Series
        The labels for each sample, used to color the points.

    highlight_group : list or np.ndarray, optional
        The positional indices of the points to be highlighted (outlined in
        red) or exclusively plotted. If None, no points are highlighted.

    show_just_group : bool, optional
        If truthy (and `highlight_group` is given), only the points in
        `highlight_group` are plotted and each one is annotated with its
        index. Otherwise all points are plotted and the highlighted group is
        outlined in red.

    features_to_plot : list, optional
        A list of feature names (strings) to label the x and y axes of the
        plot. The list should contain exactly two elements; elements 0 and 1
        are used. If not provided, the axes are labelled with the argument
        names ``X`` and ``y``.

    Returns
    -------
    None
        The scatter plot is drawn on a new matplotlib figure. The function
        does not call ``plt.show()``.

    Raises
    ------
    TypeError
        If `highlight_group` is neither a list nor a NumPy array while all
        points are being plotted, or if `X`/`y` have an unsupported type.
    ValueError
        If `X` (or `y`) is a DataFrame that does not have exactly two columns.

    Example
    -------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from matplotlib import pyplot as plt
    >>>
    >>> X = pd.DataFrame(
    ...     {"Feature1": np.random.rand(100), "Feature2": np.random.rand(100)}
    ... )
    >>> y = pd.Series(np.random.randint(0, 2, size=100))
    >>> highlight_group = [10, 20, 30, 40]
    >>> plot_2d(
    ...     X,
    ...     y,
    ...     highlight_group=highlight_group,
    ...     show_just_group=True,
    ...     features_to_plot=["Feature1", "Feature2"],
    ... )

    Scatter Plot of a 2D dataset:

    .. image:: /_static/images/plot_2d_pure.png
        :alt: Scatter Plot of a 2D dataset

    Scatter Plot of a 2D dataset with a highlighted group:

    .. image:: /_static/images/plot_2d_highlight_group.png
        :alt: Scatter Plot of a 2D dataset with a highlighted group

    Scatter Plot of a 2D dataset with a highlighted group and its labels:

    .. image:: /_static/images/plot_2d_show_just_group.png
        :alt: Scatter Plot of a 2D dataset with a highlighted group and its labels

    Scatter Plot of a 2D dataset with y_test and y_pred together in the same
    graph while calculating the accuracy over the point and its selected
    neighbors.

    .. image:: /_static/images/plot_2d_neighborhood.png
        :alt: Scatter Plot of a 2D dataset with y_test and y_pred together with neighborhood accuracy calculation
    """
    # Axis labels: caller-supplied feature names, otherwise the argument
    # names. The literals replace a frame-introspection lookup that could only
    # ever resolve to these names and kept a reference to the current frame.
    if features_to_plot is not None:
        x_label = features_to_plot[0]
        y_label = features_to_plot[1]
    else:
        x_label, y_label = "X", "y"

    # Validate and extract data
    X_vals = _validate_and_extract_data(X)
    y_vals = _validate_and_extract_data(y)

    plt.figure(figsize=(8, 6))

    if show_just_group and highlight_group is not None:
        # Plot only the samples in highlight_group
        plt.scatter(
            X_vals[highlight_group, 0],
            X_vals[highlight_group, 1],
            c=y_vals[highlight_group],
            cmap="viridis",
            s=50,
            edgecolor="k",
        )

        # Annotate the plotted points with their indices
        for idx in highlight_group:
            plt.text(X_vals[idx, 0], X_vals[idx, 1], str(idx), color="grey", fontsize=10, ha="right")
    else:
        # Plot all samples
        plt.scatter(X_vals[:, 0], X_vals[:, 1], c=y_vals, cmap="viridis", s=50, edgecolor="k")

        # If highlight_group is provided, outline the selected points
        if highlight_group is not None:
            if not isinstance(highlight_group, (np.ndarray, list)):
                raise TypeError("highlight_group must be either a list or a NumPy array.")

            # Larger, unfilled red markers drawn on top act as an outline
            plt.scatter(
                X_vals[highlight_group, 0],
                X_vals[highlight_group, 1],
                facecolors="none",
                edgecolors="red",
                linewidths=2,
                s=150,
            )

    plt.xlabel(x_label, fontweight="bold")
    plt.ylabel(y_label, fontweight="bold")
    plt.title("2D Dataset Scatter Plot")
def plot_label_and_prediction(X, y, y_pred, vertical_offset=0.1, features_to_plot=None):
    """
    Plots a 2D scatter plot of a dataset, displaying both the true labels
    (`y`) and the predicted labels (`y_pred`) with a slight vertical offset
    for distinction.

    This function generates a scatter plot where each point represents a
    sample from the dataset. The true labels are drawn fully opaque, while the
    predicted labels are drawn semi-transparent and shifted down by
    `vertical_offset` to distinguish them. The axes can be labeled using
    feature names provided by the user.

    Parameters
    ----------
    X : np.ndarray or pd.DataFrame
        The feature matrix where each row represents a sample and each column
        represents a feature. Columns 0 and 1 are plotted. A pandas DataFrame
        must have exactly two columns.

    y : np.ndarray or pd.Series
        The true labels for each sample.

    y_pred : np.ndarray or pd.Series
        The predicted labels for each sample.

    vertical_offset : float, optional (default=0.1)
        The amount subtracted from the y coordinate of the predicted-label
        markers to distinguish them from the true labels.

    features_to_plot : list, optional
        A list of feature names (strings) to label the x and y axes of the
        plot; elements 0 and 1 are used. If not provided, the axes are
        labelled with the argument names ``X`` and ``y``.

    Returns
    -------
    None
        The scatter plot is drawn on a new matplotlib figure. The function
        does not call ``plt.show()``.

    Raises
    ------
    TypeError
        If `X`, `y` or `y_pred` has an unsupported type.
    ValueError
        If a DataFrame input does not have exactly two columns.

    Example
    -------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from matplotlib import pyplot as plt
    >>>
    >>> # Example dataset
    >>> X = pd.DataFrame(
    ...     {"Feature1": np.random.rand(100), "Feature2": np.random.rand(100)}
    ... )
    >>> y = pd.Series(np.random.randint(0, 2, size=100))
    >>> y_pred = pd.Series(np.random.randint(0, 2, size=100))
    >>>
    >>> # Plot with labels and predictions
    >>> plot_label_and_prediction(
    ...     X, y, y_pred, vertical_offset=0.1, features_to_plot=["Feature1", "Feature2"]
    ... )

    This will draw a 2D scatter plot with both the true labels and the
    predicted labels, where the predicted labels are slightly offset for
    clarity.

    Scatter Plot of a 2D dataset with y_test and y_pred together in the same
    graph. The predicted values (`y_pred`, shaded circles) are shifted
    vertically by a small amount to allow better visualization. The plot
    highlights areas where the classifier incorrectly predicted the true
    labels, evident by differing colors between `y_test` and `y_pred`.

    .. image:: /_static/images/plot_2d_label_and_prediction.png
        :alt: Scatter Plot of a 2D dataset with y_test and y_pred together in the same graph
    """
    # Axis labels: caller-supplied feature names, otherwise the argument
    # names. The literals replace a frame-introspection lookup that could only
    # ever resolve to these names and kept a reference to the current frame.
    if features_to_plot is not None:
        x_label = features_to_plot[0]
        y_label = features_to_plot[1]
    else:
        x_label, y_label = "X", "y"

    # Validate and extract data
    X_vals = _validate_and_extract_data(X)
    y_vals = _validate_and_extract_data(y)
    y_pred_vals = _validate_and_extract_data(y_pred)

    plt.figure(figsize=(8, 6))

    # True labels: opaque markers at the sample coordinates
    plt.scatter(
        X_vals[:, 0],
        X_vals[:, 1],
        c=y_vals,
        cmap="viridis",
        s=50,
        edgecolor="k",
        label="label (darker)",
    )

    # Predictions: same x, y shifted down by the offset, semi-transparent
    plt.scatter(
        X_vals[:, 0],
        X_vals[:, 1] - vertical_offset,
        c=y_pred_vals,
        cmap="viridis",
        s=50,
        edgecolor="k",
        label="prediction (lighter)",
        alpha=0.5,
    )

    plt.xlabel(x_label, fontweight="bold")
    plt.ylabel(y_label, fontweight="bold")
    plt.title("2D Dataset: label and prediction")
    plt.legend()
def plot_neighborhood(
    X,
    y,
    y_pred,
    n_neighbors,
    points_of_interest,
    vertical_offset=0.1,
    features_to_plot=None,
    ax=None,
    indices_show=None,
):
    """
    Plots a 2D scatter plot of the dataset, highlighting the neighborhood of
    specific points of interest and calculating accuracy over the selected
    neighbors.

    This function visualizes the neighborhood of selected points in a 2D
    dataset using the k-nearest neighbors algorithm. The convex hull of each
    point and its neighbors is plotted, and the accuracy of predictions
    within this neighborhood is displayed. The plot shows both true labels
    and predicted labels with a slight vertical offset for clarity.

    Parameters
    ----------
    X : np.ndarray or pd.DataFrame
        The feature matrix where each row represents a sample and each column
        represents a feature. A pandas DataFrame must have exactly two
        columns.

    y : np.ndarray or pd.Series
        The true labels for each sample.

    y_pred : np.ndarray or pd.Series
        The predicted labels for each sample.

    n_neighbors : int
        The number of nearest neighbors to consider when identifying the
        neighborhood of each point of interest. The query is made with
        ``n_neighbors + 1`` so that the point itself is part of the
        neighborhood.

    points_of_interest : list or np.ndarray
        The identifiers (values found in `indices_show`) of the points whose
        neighborhoods will be highlighted.

    vertical_offset : float, optional (default=0.1)
        The amount subtracted from the y coordinate of the predicted-label
        markers to distinguish them from the true labels.

    features_to_plot : list, optional
        A list of feature names (strings) to label the x and y axes of the
        plot; elements 0 and 1 are used. If not provided, the axes are
        labelled with the argument names ``X`` and ``y``.

    ax : matplotlib.axes.Axes, optional
        The axes to draw on. If not provided, a new figure and axes are
        created. All artists, including the accuracy text and the title, are
        added to this axes.

    indices_show : list or np.ndarray, optional
        The identifier of each row of `X`, in row order; used to annotate the
        points and to locate each point of interest. If None, the row
        positions ``0 .. len(X) - 1`` are used.

    Returns
    -------
    None
        The plot is drawn on `ax` (or on a new figure). The function does not
        call ``plt.show()``.

    Example
    -------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from matplotlib import pyplot as plt
    >>> X = pd.DataFrame(
    ...     {"Feature1": np.random.rand(100), "Feature2": np.random.rand(100)}
    ... )
    >>> y = pd.Series(np.random.randint(0, 2, size=100))
    >>> y_pred = pd.Series(np.random.randint(0, 2, size=100))
    >>> points_of_interest = [10, 50]
    >>> plot_neighborhood(
    ...     X=X,
    ...     y=y,
    ...     y_pred=y_pred,
    ...     n_neighbors=3,
    ...     points_of_interest=points_of_interest,
    ...     indices_show=np.arange(100),
    ...     features_to_plot=["Feature1", "Feature2"],
    ... )

    The plot will display the convex hull around the neighborhoods of the
    points of interest and annotate the accuracy of predictions over these
    neighbors.

    Raises
    ------
    ValueError
        If a point in `points_of_interest` is not present in `indices_show`,
        or if a DataFrame input does not have exactly two columns.
    TypeError
        If `X`, `y` or `y_pred` has an unsupported type.

    Notes
    -----
    - The convex hull of each point's neighborhood is plotted as a red dashed
      line.
    - The accuracy over the nearest neighbors (the point itself included) is
      calculated and annotated next to each point of interest.
    - The function uses the k-nearest neighbors algorithm to find neighbors
      and create neighborhoods.
    - NOTE(review): a 2D convex hull needs at least three non-collinear
      points, so very small `n_neighbors` values are expected to make
      ``scipy.spatial.ConvexHull`` fail -- confirm against the scipy docs.
    """
    # Optional heavy dependencies are imported lazily, as in the original
    from scipy.spatial import ConvexHull
    from sklearn.metrics import accuracy_score
    from sklearn.neighbors import NearestNeighbors

    # Axis labels: caller-supplied feature names, otherwise the argument
    # names. The literals replace a frame-introspection lookup that could only
    # ever resolve to these names and kept a reference to the current frame.
    if features_to_plot is not None:
        x_label = features_to_plot[0]
        y_label = features_to_plot[1]
    else:
        x_label, y_label = "X", "y"

    # Validate and extract data
    X = _validate_and_extract_data(X)
    y = _validate_and_extract_data(y)
    y_pred = _validate_and_extract_data(y_pred)

    # Coerce to an array so the element-wise comparison below also works for
    # plain lists; fall back to row positions when nothing was supplied.
    if indices_show is None:
        indices_show = np.arange(len(X))
    else:
        indices_show = np.asarray(indices_show)

    # Resolve every point of interest to its row position before drawing, so
    # an invalid point does not leave a half-drawn axes behind.
    positions = []
    for sample_index in points_of_interest:
        matches = np.where(indices_show == sample_index)[0]
        if matches.size == 0:
            raise ValueError(f"The point {sample_index} is not a point in 'indices_show'.")
        positions.append(matches[0])

    # Fit on the extracted array so fitting and querying use the same input
    # type (fitting on a DataFrame and querying with arrays triggers
    # feature-name warnings in scikit-learn).
    knn = NearestNeighbors(n_neighbors=n_neighbors + 1)
    knn.fit(X)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 6))

    # True labels
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis", s=50, edgecolor="k")

    # Predictions, shifted down by the vertical offset
    ax.scatter(
        X[:, 0], X[:, 1] - vertical_offset, c=y_pred, cmap="viridis", s=50, edgecolor="k", label="y_pred", alpha=0.5
    )

    # Annotate all points with their identifiers (once, not once per point of
    # interest)
    for i, (x_plot, y_plot) in enumerate(X):
        ax.text(x_plot, y_plot, str(indices_show[i]), color="grey", fontsize=10, ha="right")

    for position in positions:
        # Row positions of the sample and its nearest neighbors
        _, neighbor_indices = knn.kneighbors([X[position]])
        neighborhood = neighbor_indices[0]

        # Convex hull around the sample and its neighbors, drawn as an outline
        selected_points = X[neighborhood]
        hull = ConvexHull(selected_points)
        for simplex in hull.simplices:
            ax.plot(selected_points[simplex, 0], selected_points[simplex, 1], "r--", linewidth=1)

        # Accuracy over the neighborhood, written next to the sample
        acc = accuracy_score(y[neighborhood], y_pred[neighborhood])
        ax.text(
            X[position][0],
            X[position][1],
            f"Acc = {acc*100:.1f}%",
            ha="left",
            va="bottom",
            fontsize=12,
            color="blue",
        )

    # Labels and title go on the same axes as the data
    ax.set_xlabel(x_label, fontweight="bold")
    ax.set_ylabel(y_label, fontweight="bold")
    ax.set_title(
        f"Convex Hull of Samples {', '.join(map(str, points_of_interest))} and its {n_neighbors} Nearest Neighbors."
    )
def plot_adp_and_adf(results_df):
    """
    Plots the Accuracy Degradation Profile (ADP) in a 2D plot, showing the
    percentage of samples above the threshold (ADP) on the vertical axis and
    dataset size (size_factor) on the horizontal axis, with the x-axis
    reversed.

    Points are colored green if the model's performance is acceptable ("OK")
    and red if there is significant accuracy degradation ("acc degrad!"). The
    first point where performance degradation occurs is highlighted and
    circled (Accuracy Degradation Factor, ADF), and a vertical dashed line is
    drawn at this point to mark the corresponding dataset size.

    Parameters
    ----------
    results_df : pd.DataFrame
        A DataFrame containing the following columns:

        - 'size_factor' (float): Fraction of the dataset used in the
          evaluation.
        - 'ADP' (float): The percentage of samples above the threshold for
          acceptable accuracy.
        - 'decision' (str): A string indicating whether the model's
          performance is acceptable ('OK') or shows significant accuracy
          degradation ('acc degrad!').
        - 'average_accuracy' (float): The average accuracy across the
          samples.
        - 'variance_accuracy' (float): The variance of the accuracy across
          the samples.

    Returns
    -------
    None
        The plot is drawn on a new matplotlib figure. The function does not
        call ``plt.show()``.

    Raises
    ------
    KeyError
        If one of the required columns is missing from `results_df`.

    Example
    -------
    >>> import pandas as pd
    >>> data = {
    ...     "size_factor": [0.95, 0.9, 0.85, 0.8, 0.75],
    ...     "ADP": [0.98, 0.97, 0.94, 0.87, 0.76],
    ...     "decision": ["OK", "OK", "OK", "acc degrad!", "acc degrad!"],
    ...     "average_accuracy": [0.97, 0.96, 0.93, 0.85, 0.74],
    ...     "variance_accuracy": [0.02, 0.03, 0.04, 0.05, 0.06],
    ... }
    >>> results_df = pd.DataFrame(data)
    >>> plot_adp_and_adf(results_df)

    This will draw a scatter plot with 'size_factor' on the x-axis (in
    reverse order) and 'ADP' on the y-axis. The first point where performance
    degrades ('acc degrad!') will be circled, and a vertical dashed line will
    be added to indicate the corresponding size factor.

    Notes
    -----
    - The blue line represents the average accuracy; the shaded band spans
      ``average_accuracy +/- 0.95 * variance_accuracy``. Green points
      indicate acceptable performance, while red points mark instances of
      degradation.
    - The first 'acc degrad!' row (in DataFrame order) is highlighted with a
      red circle and a vertical dashed line to emphasize the Accuracy
      Degradation Factor (ADF). If no row is marked 'acc degrad!', the ADF
      circle, line and label are omitted.
    - The x-axis is inverted to show the dataset size decreasing from left to
      right.
    """
    # Extract relevant columns
    x = results_df["size_factor"]
    y = results_df["ADP"]
    decision = results_df["decision"]
    average_accuracy = results_df["average_accuracy"]
    variance_accuracy = results_df["variance_accuracy"]

    is_ok = decision == "OK"
    is_degraded = decision == "acc degrad!"

    # Create figure
    plt.figure(figsize=(10, 6))

    # Average accuracy with its band
    plt.plot(x, average_accuracy, "-o", color="blue", label="average_accuracy")
    plt.fill_between(
        x,
        average_accuracy - 0.95 * variance_accuracy,
        average_accuracy + 0.95 * variance_accuracy,
        color="blue",
        alpha=0.2,
    )

    # Plot OK points (green)
    plt.scatter(x[is_ok], y[is_ok], color="green", label="ADP - OK", s=100, edgecolor="k")

    # Plot acc degrad! points (red)
    plt.scatter(
        x[is_degraded],
        y[is_degraded],
        color="red",
        label="ADP - acc degrad!",
        s=100,
        edgecolor="k",
    )

    # The ADF markers only make sense when at least one row is degraded;
    # without this guard `.iloc[0]` raises IndexError on an all-"OK" profile.
    degraded_rows = results_df[is_degraded]
    if not degraded_rows.empty:
        first_degradation = degraded_rows.iloc[0]

        # Highlight the first 'acc degrad!' point
        plt.scatter(
            first_degradation["size_factor"],
            first_degradation["ADP"],
            s=300,
            facecolors="none",
            edgecolors="red",
            label="ADF",
            linewidth=2,
        )

        # Add vertical line at the first degradation point
        plt.axvline(x=first_degradation["size_factor"], color="red", linestyle="--")

        # Add the size factor text just below the current lower y-limit
        plt.text(
            first_degradation["size_factor"],
            plt.gca().get_ylim()[0] - 0.04,
            f'{first_degradation["size_factor"]:.2f}',
            color="red",
            ha="center",
        )

    # Label axes and title
    plt.xlabel("Size Factor", fontweight="bold")
    plt.ylabel("Percentual", fontweight="bold")
    plt.title("Accuracy Degradation Profile (ADP) and Accuracy Degradation Factor (ADF)")

    # Reverse the x-axis so the dataset size decreases from left to right
    plt.gca().invert_xaxis()

    # Show legend
    plt.legend()
    plt.grid()