amatorch.plot

Utilities for plotting model parameters (filters and statistics) and model outputs (responses, posteriors and estimates).

amatorch.plot.draw_color_bar(colormap, limits, fig, title=None)

Draw a color bar for the given colormap and limits.

Parameters:
  • colormap (str or matplotlib.colors.Colormap) – The colormap to use.

  • limits (list) – The minimum and maximum values for the color scale.

  • fig (matplotlib.figure.Figure) – The figure to draw the color bar on.

Returns:

color_bar – The created color bar.

Return type:

ColorbarBase
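For readers unfamiliar with standalone color bars, the following is a minimal matplotlib sketch of what drawing a color bar from a colormap and a pair of limits usually involves. It is not the amatorch implementation: the helper name `sketch_color_bar`, the placement of the color bar axes, and treating `title` as the bar's label are all assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize


def sketch_color_bar(colormap, limits, fig, title=None):
    """Hypothetical stand-in for draw_color_bar: a color bar with no plotted data."""
    # Map the [min, max] limits onto the colormap.
    norm = Normalize(vmin=limits[0], vmax=limits[1])
    mappable = ScalarMappable(norm=norm, cmap=colormap)
    mappable.set_array([])  # a standalone mappable carries no data of its own
    # Assumed placement: a thin axes along the right edge of the figure.
    cax = fig.add_axes([0.9, 0.15, 0.03, 0.7])
    color_bar = fig.colorbar(mappable, cax=cax)
    if title is not None:
        color_bar.set_label(title)  # assumption: title used as the bar label
    return color_bar


fig = plt.figure()
cbar = sketch_color_bar("viridis", [-1.0, 1.0], fig, title="Value")
```

Building the bar from a `ScalarMappable` rather than from an existing plot lets the same limits be shared by several panels that each draw their own colored points or ellipses.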

amatorch.plot.plot_estimates_statistics(estimates, true_values, ax=None, ci_bars=False, quantiles=(0.025, 0.975))

Plot the mean estimate for each true value.

Parameters:
  • estimates (torch.Tensor) – The estimated value for each stimulus (n_stimuli).

  • true_values (torch.Tensor) – The true value for each stimulus (n_stimuli).

  • ax (matplotlib.axes.Axes) – The axes to plot on.

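  • ci_bars (bool, optional) – Whether to plot confidence intervals. Default is False.

  • quantiles (tuple of float, optional) – Lower and upper quantiles of the estimates that bound the confidence intervals. Default is (0.025, 0.975).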
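The statistics behind this plot can be illustrated with a short NumPy sketch: group the estimates by their true value, take the mean of each group, and take the `quantiles` of each group as the interval bounds. This is an assumption about what is computed, not amatorch's code, and the helper `estimate_statistics` is hypothetical.

```python
import numpy as np


def estimate_statistics(estimates, true_values, quantiles=(0.025, 0.975)):
    """Hypothetical helper: per-true-value mean and quantile band of the estimates."""
    estimates = np.asarray(estimates, dtype=float)
    true_values = np.asarray(true_values, dtype=float)
    levels = np.unique(true_values)  # sorted distinct true values
    means = np.empty(len(levels))
    lower = np.empty(len(levels))
    upper = np.empty(len(levels))
    for i, level in enumerate(levels):
        group = estimates[true_values == level]  # estimates for this true value
        means[i] = group.mean()
        lower[i], upper[i] = np.quantile(group, quantiles)
    return levels, means, lower, upper


est = np.array([0.9, 1.1, 1.0, 2.2, 1.8, 2.0])
true = np.array([1, 1, 1, 2, 2, 2])
levels, means, lower, upper = estimate_statistics(est, true)
```

With these arrays, interval bars can be drawn using matplotlib's `ax.errorbar(levels, means, yerr=[means - lower, upper - means])`, which expects distances from the mean rather than absolute bounds.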

amatorch.plot.plot_filters(model, n_cols=2, n_filters=10)

Plot the filters of an AMA model.

Parameters:
  • model (AMA model) – The model to plot the filters for.

  • n_cols (int, optional) – Number of columns in the grid layout. Default is 2.

  • n_filters (int, optional) – Number of filters to plot. Default is 10.
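The grid layout implied by n_cols and n_filters can be sketched as follows, assuming the filters fill the grid row by row and that as many rows are added as needed. The helper `filter_grid` is hypothetical and only illustrates the layout arithmetic, not how amatorch draws each panel.

```python
import math


def filter_grid(n_filters=10, n_cols=2):
    """Hypothetical helper: grid shape and (row, col) slot of each filter, row-major."""
    n_rows = math.ceil(n_filters / n_cols)  # enough rows to hold every filter
    slots = [divmod(i, n_cols) for i in range(n_filters)]  # filter i -> (row, col)
    return (n_rows, n_cols), slots


shape, slots = filter_grid(n_filters=10, n_cols=2)
# shape == (5, 2); filter 3 goes to row 1, column 1
```

When n_filters is not a multiple of n_cols, the last row is only partly filled, so any unused panels would typically be hidden.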

amatorch.plot.scatter_estimates(estimates, true_values, jitter=0.0, ax=None, points_per_class=1000)

Plot the model estimates.

Parameters:
  • estimates (torch.Tensor) – The estimated value for each stimulus (n_stimuli).

  • true_values (torch.Tensor) – The true value for each stimulus (n_stimuli).

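  • jitter (float, optional) – Amount of noise to add to the true values for plotting. Default is 0.0.

  • ax (matplotlib.axes.Axes, optional) – The axes to plot on.

  • points_per_class (int, optional) – Number of points to plot per class. Default is 1000.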
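A minimal NumPy sketch of the data preparation such a scatter plausibly needs: keep at most `points_per_class` stimuli for each true value, then jitter the true values (x coordinates) so overlapping points separate. The uniform noise on [-jitter, jitter] and the helper `jitter_and_subsample` are assumptions, not amatorch's implementation.

```python
import numpy as np


def jitter_and_subsample(estimates, true_values, jitter=0.0,
                         points_per_class=1000, seed=0):
    """Hypothetical helper: subsample each class, then jitter the x coordinates."""
    rng = np.random.default_rng(seed)
    estimates = np.asarray(estimates, dtype=float)
    true_values = np.asarray(true_values, dtype=float)
    keep = []
    for level in np.unique(true_values):
        idx = np.flatnonzero(true_values == level)  # stimuli of this class
        if len(idx) > points_per_class:
            idx = rng.choice(idx, size=points_per_class, replace=False)
        keep.append(idx)
    keep = np.concatenate(keep)
    # Assumed noise model: uniform on [-jitter, jitter], added only to x.
    x = true_values[keep] + rng.uniform(-jitter, jitter, size=len(keep))
    y = estimates[keep]
    return x, y


true = np.repeat([1.0, 2.0, 3.0], 50)
est = true + 0.1
x, y = jitter_and_subsample(est, true, jitter=0.05, points_per_class=20)
```

The result can be passed directly to `ax.scatter(x, y)`. Jittering only the x coordinate keeps the plotted estimates themselves unchanged.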

amatorch.plot.scatter_responses(responses, labels, ax=None, values=None, filter_pair=(0, 1), n_points=1000, classes_plot=None, legend_type='none', **kwargs)

Scatter-plot the filter responses to stimuli from different classes.

Parameters:
  • responses (torch.Tensor) – Responses to the stimuli. Shape (n_stimuli, n_filters).

  • labels (torch.int64) – Class label of each stimulus. Shape (n_stimuli).

  • ax (matplotlib.axes.Axes, optional) – Axes to plot the scatter. If None, a new figure is created. The default is None.

  • values (torch.Tensor, optional) – Values to color the classes. The default is linearly spaced values between -1 and 1.

  • filter_pair (tuple, optional) – Pair of filters to plot. The default is (0, 1).

  • n_points (int, optional) – Number of points per class to plot. The default is 1000.

  • classes_plot (list, optional) – List of classes to plot. The default is all classes.

  • legend_type (str, optional) – Type of legend to add: ‘none’, ‘continuous’, ‘discrete’. The default is ‘none’.

Returns:

ax – Axes with the scatter plot.

Return type:

matplotlib.axes.Axes
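The parameters above can be read as a selection step followed by a plain scatter. The NumPy sketch below shows one plausible version of that step: pick the two columns given by `filter_pair`, keep at most `n_points` stimuli per class in `classes_plot`, and give every point its class's entry in `values` (by default linearly spaced between -1 and 1, as documented). It assumes labels run from 0 to n_classes - 1; the helper `responses_for_scatter` is hypothetical.

```python
import numpy as np


def responses_for_scatter(responses, labels, filter_pair=(0, 1), values=None,
                          n_points=1000, classes_plot=None, seed=0):
    """Hypothetical helper: x/y coordinates and color values for a response scatter."""
    rng = np.random.default_rng(seed)
    responses = np.asarray(responses, dtype=float)
    labels = np.asarray(labels)
    n_classes = int(labels.max()) + 1  # assumes labels are 0..n_classes-1
    if values is None:
        values = np.linspace(-1, 1, n_classes)  # documented default coloring
    if classes_plot is None:
        classes_plot = range(n_classes)
    xs, ys, cs = [], [], []
    for c in classes_plot:
        idx = np.flatnonzero(labels == c)
        if len(idx) > n_points:  # at most n_points points per class
            idx = rng.choice(idx, size=n_points, replace=False)
        xs.append(responses[idx, filter_pair[0]])
        ys.append(responses[idx, filter_pair[1]])
        cs.append(np.full(len(idx), values[c]))  # one color value per class
    return np.concatenate(xs), np.concatenate(ys), np.concatenate(cs)


rng = np.random.default_rng(1)
responses = rng.normal(size=(300, 4))
labels = np.repeat([0, 1, 2], 100)
x, y, c = responses_for_scatter(responses, labels, filter_pair=(2, 3), n_points=50)
```

The three arrays map directly onto `ax.scatter(x, y, c=c, cmap="viridis")`. Coloring by a per-class value rather than by class index is what makes a continuous color bar meaningful for ordered classes.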

amatorch.plot.statistics_ellipses(means, covariances, filter_pair=(0, 1), ax=None, values=None, classes_plot=None, color_map='viridis', legend_type='none', **kwargs)

Plot the ellipses of the filter response statistics across classes.

Parameters:
  • means (torch.Tensor) – Means of the filter responses. Shape (n_classes, n_filters).

  • covariances (torch.Tensor) – Covariances of the filter responses. Shape (n_classes, n_filters, n_filters).

  • filter_pair (tuple of int, optional) – Pair of filters to plot. The default is (0, 1).

  • ax (matplotlib.axes.Axes, optional) – Axes to plot the ellipses. If None, a new figure is created. The default is None.

  • values (torch.Tensor, optional) – Values used to color-code the ellipses. Each value corresponds to a class. The default is linearly spaced values between -1 and 1.

  • classes_plot (list, optional) – List of classes to plot. The default is all classes.

  • color_map (str or matplotlib.colors.Colormap, optional) – Color map to use for the ellipses. The default is ‘viridis’.

  • legend_type (str, optional) – Type of legend to add: ‘none’, ‘continuous’, ‘discrete’. The default is ‘none’.

Returns:

ax – Axes with the ellipses.

Return type:

matplotlib.axes.Axes
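Each ellipse summarizes one class's mean and covariance restricted to the two filters in `filter_pair`. A minimal NumPy sketch of the geometry: take the 2 × 2 covariance block, use its eigenvectors as the ellipse axes and the square roots of its eigenvalues as the axis half-lengths. The contour level (one standard deviation here) and the helper `covariance_ellipse` are assumptions; amatorch may draw a different contour.

```python
import numpy as np


def covariance_ellipse(mean, covariance, filter_pair=(0, 1), n_std=1.0, n_points=100):
    """Hypothetical helper: points on the n_std contour of one class's Gaussian,
    restricted to the two filters in filter_pair."""
    i, j = filter_pair
    mu = np.asarray(mean, dtype=float)[[i, j]]
    cov = np.asarray(covariance, dtype=float)[np.ix_([i, j], [i, j])]  # 2x2 block
    # Eigenvectors give the ellipse axes; sqrt of eigenvalues gives their half-lengths.
    eigvals, eigvecs = np.linalg.eigh(cov)
    theta = np.linspace(0, 2 * np.pi, n_points)
    circle = np.stack([np.cos(theta), np.sin(theta)])  # unit circle, shape (2, n_points)
    ellipse = eigvecs @ (n_std * np.sqrt(eigvals)[:, None] * circle)
    return mu[:, None] + ellipse  # row 0: x coordinates, row 1: y coordinates


mean = np.array([0.0, 1.0, -1.0])
cov = np.array([[2.0, 0.5, 0.0],
                [0.5, 1.0, 0.3],
                [0.0, 0.3, 0.5]])
pts = covariance_ellipse(mean, cov, filter_pair=(0, 1))
```

Every returned point lies at Mahalanobis distance `n_std` from the class mean, so the outline can be drawn with `ax.plot(pts[0], pts[1])` and repeated for each class in `classes_plot`.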

Modules

colors

filters

inference

output_summary

responses

statistics