holisticai.robustness.plots.plot_2d

holisticai.robustness.plots.plot_2d(X, y, highlight_group=None, show_just_group=None, features_to_plot=None)

Plots a 2D scatter plot of a dataset, with options to highlight or exclusively show a subset of points.

This function generates a 2D scatter plot from the given dataset, where each point represents a sample. Users can highlight specific points or display only a subset of the data. Axis labels can be customized using feature names from the features_to_plot argument.

Parameters

X : np.ndarray or pd.DataFrame

The feature matrix, where each row is a sample and each column represents a feature. This can be either a NumPy array or a pandas DataFrame with two columns.

y : np.ndarray or pd.Series

The labels for each sample, represented as a one-dimensional NumPy array or pandas Series.

highlight_group : list or np.ndarray, optional

The indices of the points to be highlighted (outlined in red) or exclusively plotted. If None, no points are highlighted.

show_just_group : bool, optional

If True, only the points specified in highlight_group are plotted, and all other points are hidden. If False, all points are plotted, but the highlighted group is outlined.

features_to_plot : list, optional

A list of feature names (strings) to label the x and y axes of the plot. The list should contain exactly two elements. If not provided, the function will infer the names of X and y from the argument names.

Returns

None

This function does not return any values. It displays the scatter plot.
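The interaction between highlight_group and show_just_group described above can be sketched in plain Python. The helper below is hypothetical (split_points is not part of holisticai) and only mirrors the documented behaviour. It assumes that a show_just_group of None is treated like False, and that points drawn in "show just group" mode are not outlined.

```python
def split_points(n_samples, highlight_group=None, show_just_group=False):
    """Return (indices_to_draw, indices_to_outline) for a scatter plot.

    Hypothetical helper mirroring the documented options; it is not the
    library's implementation.
    """
    all_indices = list(range(n_samples))
    if highlight_group is None:
        # Nothing to highlight: draw every point, outline none.
        return all_indices, []
    # Accept lists or NumPy arrays; remove duplicates and sort.
    highlighted = sorted({int(i) for i in highlight_group})
    if show_just_group:
        # Assumption: only the selected points are drawn, without outline.
        return highlighted, []
    # Default mode: draw everything, outline the selected points.
    return all_indices, highlighted


draw, outline = split_points(5, highlight_group=[3, 1])
print(draw, outline)
```

With show_just_group=True the same call would return only the selected indices to draw, which corresponds to hiding all other points.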

Example

>>> import numpy as np
>>> import pandas as pd
>>> from matplotlib import pyplot as plt
>>> from holisticai.robustness.plots import plot_2d
>>>
>>> X = pd.DataFrame(
...     {"Feature1": np.random.rand(100), "Feature2": np.random.rand(100)}
... )
>>> y = pd.Series(np.random.randint(0, 2, size=100))
>>> highlight_group = [10, 20, 30, 40]
>>> plot_2d(
...     X,
...     y,
...     highlight_group=highlight_group,
...     show_just_group=True,
...     features_to_plot=["Feature1", "Feature2"],
... )

Figure: Scatter plot of a 2D dataset.

Figure: Scatter plot of a 2D dataset with a highlighted group.

Figure: Scatter plot of a 2D dataset with a highlighted group and its labels.

Figure: Scatter plot of a 2D dataset with y_test and y_pred shown together, while calculating the accuracy over a point and its selected neighbors.
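The neighborhood accuracy mentioned for the last figure is not defined on this page. One plausible reading, sketched below under stated assumptions, is the fraction of correct predictions over a chosen point and its k nearest neighbors by Euclidean distance. The function name neighborhood_accuracy, the parameter k, and the sample data are all hypothetical and are not taken from holisticai.

```python
import math


def neighborhood_accuracy(points, y_true, y_pred, index, k=3):
    """Accuracy over points[index] and its k nearest neighbors.

    Assumption: neighbors are chosen by Euclidean distance in the 2D
    feature space. This is an illustrative sketch, not the library's code.
    """
    center = points[index]
    # Sort all other sample indices by distance to the chosen point.
    others = sorted(
        (i for i in range(len(points)) if i != index),
        key=lambda i: math.dist(center, points[i]),
    )
    neighborhood = [index] + others[:k]
    # Count how many labels in the neighborhood were predicted correctly.
    correct = sum(y_true[i] == y_pred[i] for i in neighborhood)
    return correct / len(neighborhood)


# Illustrative data: two small clusters with one misclassified point.
points = [(0.0, 0.0), (0.1, 0.0), (0.0, 0.2), (5.0, 5.0), (5.1, 5.0)]
y_test = [0, 0, 0, 1, 1]
y_pred = [0, 1, 0, 1, 1]
print(neighborhood_accuracy(points, y_test, y_pred, index=0, k=2))
```

The indices in the computed neighborhood could then be passed as highlight_group to plot_2d so that the plotted outline matches the points used in the accuracy calculation.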