holisticai.robustness.plots.plot_neighborhood
- holisticai.robustness.plots.plot_neighborhood(X, y, y_pred, n_neighbors, points_of_interest, vertical_offset=0.1, features_to_plot=None, ax=None, indices_show=None)
Plots a 2D scatter plot of the dataset, highlighting the neighborhood of specific points of interest and calculating accuracy over the selected neighbors.
This function visualizes the neighborhood of selected points in a 2D dataset using the k-nearest neighbors algorithm. The convex hull of each point of interest and its neighbors is plotted, and the accuracy of predictions within this neighborhood is displayed. The plot shows both the true labels and the predicted labels; the predicted labels are drawn with a slight vertical offset so the two can be told apart.
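The neighborhood accuracy described above can be sketched with NumPy alone. This is an illustrative sketch, not the library's implementation: the helper name `neighborhood_accuracy`, the use of Euclidean distance, and the choice to count the point itself as its own nearest neighbor are all assumptions.

```python
import numpy as np


def neighborhood_accuracy(X, y, y_pred, point_index, n_neighbors):
    """Accuracy of y_pred over the n_neighbors samples closest to X[point_index].

    Hypothetical helper. Assumes Euclidean distance and that the point of
    interest counts as its own nearest neighbor; the library may differ.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    y_pred = np.asarray(y_pred)
    # Brute-force distance from the point of interest to every sample.
    distances = np.linalg.norm(X - X[point_index], axis=1)
    # Indices of the closest samples; a stable sort keeps ties deterministic.
    neighbors = np.argsort(distances, kind="stable")[:n_neighbors]
    # Fraction of neighbors whose prediction matches the true label.
    accuracy = float(np.mean(y[neighbors] == y_pred[neighbors]))
    return neighbors, accuracy


# Two well-separated clusters; the point of interest is sample 0.
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.2], [5.0, 5.0], [5.1, 5.0]])
y = np.array([0, 0, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 0])

neighbors, accuracy = neighborhood_accuracy(
    X, y, y_pred, point_index=0, n_neighbors=3
)
print(neighbors, accuracy)
```

In this toy data the three closest samples to point 0 are samples 0, 1 and 2, and two of the three predictions agree with the true labels.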
Parameters
- X : np.ndarray or pd.DataFrame
The feature matrix where each row represents a sample and each column represents a feature. It can be either a NumPy array or pandas DataFrame.
- y : np.ndarray or pd.Series
The true labels for each sample. It can be either a NumPy array or pandas Series.
- y_pred : np.ndarray or pd.Series
The predicted labels for each sample. It can be either a NumPy array or pandas Series.
- n_neighbors : int
The number of nearest neighbors to consider when identifying the neighborhood of each point of interest.
- points_of_interest : list or np.ndarray
A list or array of indices corresponding to the points whose neighborhoods will be highlighted.
- vertical_offset : float, optional (default=0.1)
The vertical offset applied to the predicted labels on the plot to distinguish them from the true labels.
- features_to_plot : list, optional
A list of feature names (strings) to label the x and y axes of the plot. If not provided, the function will infer the names of X and y from the argument names.
- ax : matplotlib.axes.Axes, optional
A matplotlib axes object. If not provided, a new figure and axes will be created.
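- indices_show : list or np.ndarray, required
The indices of the points to be shown on the plot. This argument is documented as required even though the signature gives it a default of None. Every index in points_of_interest must also appear in indices_show; otherwise a ValueError is raised.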
Returns
- None
This function does not return any value. It displays the scatter plot with neighborhoods.
Example
>>> import numpy as np
>>> import pandas as pd
>>> from matplotlib import pyplot as plt
>>> from holisticai.robustness.plots import plot_neighborhood
>>> # Two random features and random binary labels for 100 samples
>>> X = pd.DataFrame(
...     {"Feature1": np.random.rand(100), "Feature2": np.random.rand(100)}
... )
>>> y = pd.Series(np.random.randint(0, 2, size=100))
>>> y_pred = pd.Series(np.random.randint(0, 2, size=100))
>>> points_of_interest = [10, 50]
>>> # indices_show covers every sample, so both points of interest are included
>>> plot_neighborhood(
...     X=X,
...     y=y,
...     y_pred=y_pred,
...     n_neighbors=3,
...     points_of_interest=points_of_interest,
...     indices_show=np.arange(100),
...     features_to_plot=["Feature1", "Feature2"],
... )
The plot will display the convex hull around the neighborhoods of the points of interest and annotate the accuracy of predictions over these neighbors.
- Raises:
ValueError – If a point in points_of_interest is not present in indices_show.
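The condition behind this error can be checked before calling the function. The snippet below is a minimal sketch of such a pre-check; `check_points_are_shown` is a hypothetical helper, and its error message is not the one the library emits.

```python
import numpy as np


def check_points_are_shown(points_of_interest, indices_show):
    """Raise ValueError if any point of interest is absent from indices_show.

    Hypothetical pre-check; it mirrors the documented condition only.
    """
    # Values present in points_of_interest but not in indices_show.
    missing = np.setdiff1d(points_of_interest, indices_show)
    if missing.size > 0:
        raise ValueError(
            f"points_of_interest not in indices_show: {missing.tolist()}"
        )


# Both points are among the shown indices, so nothing is raised.
check_points_are_shown([10, 50], np.arange(100))

# Index 150 is not shown, so the check fails.
try:
    check_points_are_shown([10, 150], np.arange(100))
except ValueError as exc:
    error_message = str(exc)
    print(error_message)
```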
Notes
The convex hull of each point’s neighborhood is plotted as a red dashed line.
The accuracy over the nearest neighbors is calculated and annotated next to each point of interest.
The function uses the k-nearest neighbors algorithm to find neighbors and create neighborhoods.
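To illustrate the hull outline mentioned in these notes, here is a small sketch that computes the convex hull of a 2D neighborhood and closes it into a loop ready for plotting. The source does not say how the library computes the hull; Andrew's monotone chain is used here purely as a stand-in, and `convex_hull` is a hypothetical helper.

```python
import numpy as np


def convex_hull(points):
    """Return the hull vertices in counter-clockwise order.

    Hypothetical helper using Andrew's monotone chain; not the library's code.
    """
    pts = sorted(set(map(tuple, np.asarray(points, dtype=float))))
    if len(pts) <= 2:
        return np.array(pts)

    def cross(o, a, b):
        # Positive when o -> a -> b makes a counter-clockwise turn.
        return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])

    lower = []
    for p in pts:
        # Drop points that would make a clockwise or straight turn.
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)

    upper = []
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)

    # The last point of each chain is the first point of the other.
    return np.array(lower[:-1] + upper[:-1])


# A point of interest at the center with four neighbors around it.
neighborhood = np.array(
    [[0.5, 0.5], [0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
)
hull = convex_hull(neighborhood)

# Repeat the first vertex so the outline closes when drawn, for example
# with matplotlib: ax.plot(closed[:, 0], closed[:, 1], "r--")
closed = np.vstack([hull, hull[:1]])
print(closed)
```

The interior point is dropped from the hull, so only the four corner neighbors define the dashed outline.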