amatorch.plot.output_summary
Functions
output_statistics(model_output, labels[, quantiles]) | Compute summary statistics of the model output for each class.
- amatorch.plot.output_summary.output_statistics(model_output, labels, quantiles=(0.025, 0.975))
Compute summary statistics of the model output for each class: mean, standard deviation, confidence interval bounds, and median.
- Parameters:
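model_output (torch.Tensor) – The model output, of shape (n_stimuli, n_classes) or (n_stimuli,).
labels (torch.Tensor) – The true class of each stimulus (n_stimuli,), with dtype torch.int64.
quantiles (tuple of float, optional) – Lower and upper quantiles used for ci_low and ci_high. Default: (0.025, 0.975).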
- Returns:
output_dict – Dictionary of per-class output statistics, with the following keys:
- mean: torch.Tensor
The mean of the output for each class (n_classes, n_classes) or (n_classes,).
- sd: torch.Tensor
The standard deviation of the output for each class (n_classes, n_classes) or (n_classes,).
- ci_low: torch.Tensor
The lower bound of the confidence interval of the output for each class (n_classes, n_classes) or (n_classes,), taken at quantiles[0]; the default quantiles give a 95% interval.
- ci_high: torch.Tensor
The upper bound of the confidence interval of the output for each class (n_classes, n_classes) or (n_classes,), taken at quantiles[1]; the default quantiles give a 95% interval.
- median: torch.Tensor
The median of the output for each class (n_classes, n_classes) or (n_classes,).
- Return type:
dict
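For illustration, here is a minimal sketch of how per-class statistics of this kind could be computed with PyTorch. It is not amatorch's implementation. The function name output_statistics_sketch is hypothetical. It assumes that classes are labeled 0 to n_classes - 1, that every class has at least two stimuli, and that model_output is a floating-point tensor. It uses torch's default unbiased standard deviation and torch.median, either of which may differ from what amatorch does.

```python
import torch


def output_statistics_sketch(model_output, labels, quantiles=(0.025, 0.975)):
    """Hypothetical per-class summary of model outputs (not the amatorch code).

    Assumes labels are integers 0..n_classes-1 and each class has >= 2 stimuli.
    """
    n_classes = int(labels.max()) + 1
    # torch.quantile requires q to have the same floating dtype as the input
    q = torch.tensor(quantiles, dtype=model_output.dtype)
    stats = {"mean": [], "sd": [], "ci_low": [], "ci_high": [], "median": []}
    for c in range(n_classes):
        # Outputs of the stimuli whose true class is c: (n_c, n_classes) or (n_c,)
        out_c = model_output[labels == c]
        stats["mean"].append(out_c.mean(dim=0))
        # torch.std defaults to the unbiased (n - 1) estimator
        stats["sd"].append(out_c.std(dim=0))
        # With a 1-D q, the quantile dimension comes first: (2, n_classes) or (2,)
        bounds = torch.quantile(out_c, q, dim=0)
        stats["ci_low"].append(bounds[0])
        stats["ci_high"].append(bounds[1])
        # torch.median returns the lower of the two middle values when n_c is even
        stats["median"].append(out_c.median(dim=0).values)
    # Stack per-class results so the first dimension indexes the true class
    return {key: torch.stack(values) for key, values in stats.items()}
```

Because the class dimension is added by stacking, a 1-D model_output yields tensors of shape (n_classes,). A 2-D model_output yields (n_classes, n_classes), where row c summarizes the outputs for stimuli of class c. This matches the shapes listed under Returns.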