amatorch.plot.output_summary

Functions

output_statistics(model_output, labels[, ...])

Compute per-class summary statistics (mean, standard deviation, quantiles, median) of the model output.

amatorch.plot.output_summary.output_statistics(model_output, labels, quantiles=(0.025, 0.975))

Compute the mean, standard deviation, lower and upper quantiles, and median of the model output, grouping stimuli by their true class.
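The per-class grouping can be sketched in plain PyTorch by masking the output with each label value. The helper `class_mean_sd` below is hypothetical and not part of amatorch. It assumes labels run from 0 to n_classes - 1 with at least two stimuli per class, and it uses `torch.std` with its default Bessel correction, which may not match what amatorch does.

```python
import torch


def class_mean_sd(model_output: torch.Tensor, labels: torch.Tensor):
    """Per-class mean and sd of the output (hypothetical helper, not amatorch code).

    Assumes every class has at least two stimuli so the sample sd is defined.
    """
    means, sds = [], []
    # torch.unique returns the label values sorted, so row i matches class i
    # when the labels are 0..n_classes-1.
    for c in torch.unique(labels):
        group = model_output[labels == c]  # stimuli whose true class is c
        means.append(group.mean(dim=0))
        sds.append(group.std(dim=0))  # default correction=1 (unbiased)
    # Shape (n_classes,) for 1-D output, (n_classes, n_classes) for 2-D output.
    return torch.stack(means), torch.stack(sds)
```

For example, outputs `[1.0, 3.0, 10.0, 14.0]` with labels `[0, 0, 1, 1]` give per-class means `[2.0, 12.0]`.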

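Parameters:
  • model_output (torch.Tensor) – The model output, of shape (n_stimuli, n_classes) or (n_stimuli,).

  • labels (torch.Tensor) – The true class of each stimulus, as an int64 tensor of shape (n_stimuli,).

  • quantiles (tuple of float, optional) – Lower and upper quantiles used for ci_low and ci_high. The default (0.025, 0.975) bounds the central 95% interval.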

Returns:

output_dict – Dictionary containing statistics of the output for each class.

mean: torch.Tensor

The mean of the output for each class (n_classes, n_classes) or (n_classes,).

sd: torch.Tensor

The standard deviation of the output for each class (n_classes, n_classes) or (n_classes,).

ci_low: torch.Tensor

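The lower bound of the interval of the output for each class, given by the quantile quantiles[0] (2.5% by default, so the interval is 95% with the default quantiles), of shape (n_classes, n_classes) or (n_classes,).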

ci_high: torch.Tensor

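The upper bound of the interval of the output for each class, given by the quantile quantiles[1] (97.5% by default, so the interval is 95% with the default quantiles), of shape (n_classes, n_classes) or (n_classes,).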

median: torch.Tensor

The median of the output for each class (n_classes, n_classes) or (n_classes,).

Return type:

dict
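The quantile-based entries (ci_low, ci_high, median) can be sketched the same way. `class_quantile_stats` is a hypothetical helper, not amatorch's implementation. It takes empirical quantiles along the stimulus dimension with `torch.quantile` (linear interpolation by default), so its median can differ from `torch.median`, which returns the lower of the two middle values when a class has an even number of stimuli.

```python
import torch


def class_quantile_stats(model_output, labels, quantiles=(0.025, 0.975)):
    """Per-class lower/upper quantiles and median (hypothetical helper).

    A sketch, not amatorch code: empirical quantiles of the outputs of each class.
    """
    # q must share dtype (and device) with the input for torch.quantile.
    q = torch.tensor(
        [quantiles[0], 0.5, quantiles[1]],
        dtype=model_output.dtype,
        device=model_output.device,
    )
    lows, medians, highs = [], [], []
    for c in torch.unique(labels):  # sorted class labels
        group = model_output[labels == c]
        # Result has shape (3, *group.shape[1:]); unpack along the first dim.
        low, med, high = torch.quantile(group, q, dim=0)
        lows.append(low)
        medians.append(med)
        highs.append(high)
    return {
        "ci_low": torch.stack(lows),
        "median": torch.stack(medians),
        "ci_high": torch.stack(highs),
    }
```

Passing `quantiles=(0.05, 0.95)` instead of the default would bound the central 90% of each class's outputs rather than 95%.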