Skip to content

Commit 341e8d3

Browse files
author
Rogier van Dalen
committed
Add functions to compute and combine Gaussian statistics
1 parent d9fb58f commit 341e8d3

2 files changed

Lines changed: 446 additions & 0 deletions

File tree

speechbrain/processing/features.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@
3232
3333
Authors
3434
* Mirco Ravanelli 2020
35+
* Rogier van Dalen 2025
3536
"""
3637

3738
import math
39+
from typing import Tuple, Union
3840

3941
import torch
42+
from torch.distributed import ReduceOp
4043

4144
from speechbrain.dataio.dataio import length_to_mask
4245
from speechbrain.utils.checkpoints import (
@@ -993,6 +996,165 @@ def forward(self, x):
993996
return cw_x
994997

995998

999+
def gaussian_statistics(x: torch.Tensor, dim: Union[int, tuple, None] = None):
    """
    Compute first- and second-order moments of data, and return them as the
    count, mean, and variance of a vector over one or more dimensions.

    Arguments
    ---------
    x: torch.Tensor
        The tensor to compute the statistics over.
    dim: int | tuple | None
        The dimension or dimensions that the statistics should be computed
        over. The other dimensions are retained in the output.
        Any other sequence of ints (e.g. a list) is treated like a tuple.
        If None, all dimensions are reduced and scalar-valued statistics are
        returned.
        If empty, nothing is reduced: every element is its own population of
        size one.

    Returns
    -------
    count
        The number of sub-vectors or sub-tensors that the statistics were
        computed over, i.e. the product of the sizes of the reduced
        dimensions.
    mean
        The mean over the reduced dimensions.
    variance
        The variance over the reduced dimensions. This is the mean squared
        deviation from the mean (normalised by count, not by count - 1).
    """
    if dim is None:
        number = math.prod(x.shape)
    elif isinstance(dim, int):
        number = x.shape[dim]
    else:
        # Normalise lists and other sequences to a tuple instead of rejecting
        # them; torch.mean accepts them just as well.
        dim = tuple(dim)
        if not dim:
            # No dimension to reduce: each element is its own mean (x itself
            # is returned, not a copy) and has zero variance.
            return 1, x, torch.zeros_like(x)
        number = math.prod(x.shape[d] for d in dim)

    # First keep the dimensions so that broadcasting against x works.
    mean_with_dims = torch.mean(x, dim=dim, keepdim=True)
    variance = torch.mean(torch.square(x - mean_with_dims), dim=dim)
    # The variance already has the reduced shape, so reuse it to drop the
    # size-one dimensions from the mean. This avoids relying on torch.squeeze
    # accepting several dimensions at once.
    mean = mean_with_dims.reshape(variance.shape)

    return (number, mean, variance)
1046+
1047+
1048+
def combine_gaussian_statistics(
    left_statistics: Tuple[int, torch.Tensor, torch.Tensor],
    right_statistics: Tuple[int, torch.Tensor, torch.Tensor],
):
    """
    Combine the first- and second-order moments from two pieces of data.
    The inputs and the result are in the form (count, mean, variance).
    The result is the mean and variance as if they had been computed on the
    concatenation of the data for left_statistics and the data for
    right_statistics.

    Arguments
    ---------
    left_statistics: Tuple[int, torch.Tensor, torch.Tensor]
        One set of statistics.
    right_statistics: Tuple[int, torch.Tensor, torch.Tensor]
        Another set of statistics. The means and variances must all have the
        same shape.

    Returns
    -------
    count
        The total number of elements in the data.
    mean
        The combined mean.
    variance
        The combined variance, relative to the new mean.
    """
    left_count, left_mean, left_variance = left_statistics
    right_count, right_mean, right_variance = right_statistics
    assert left_mean.shape == left_variance.shape
    assert left_mean.shape == right_mean.shape
    assert left_variance.shape == right_variance.shape

    count = left_count + right_count

    # A side without any data contributes nothing, so the other side is the
    # answer. Short-circuiting matters because the mean and variance of empty
    # data are typically NaN, and 0 * NaN would otherwise poison the result;
    # it also avoids dividing by a zero total count.
    # Only plain Python numbers are checked, so tensor-valued counts keep
    # going through the general formula below.
    if isinstance(left_count, (int, float)) and left_count == 0:
        return count, right_mean, right_variance
    if isinstance(right_count, (int, float)) and right_count == 0:
        return count, left_mean, left_variance

    left_weight = left_count / count
    right_weight = right_count / count

    mean = left_weight * left_mean + right_weight * right_mean

    # Reconstruct the left and right variances relative to "mean": moving the
    # reference point from a side's own mean to the combined mean adds the
    # squared distance between the two.
    compensated_left_variance = left_variance + torch.square(mean - left_mean)
    compensated_right_variance = right_variance + torch.square(
        mean - right_mean
    )

    variance = (
        left_weight * compensated_left_variance
        + right_weight * compensated_right_variance
    )

    return count, mean, variance
1100+
1101+
1102+
def combine_gaussian_statistics_distributed(
    statistics: Tuple[int, torch.Tensor, torch.Tensor],
):
    """
    Combine the first- and second-order moments from multiple pieces of data
    using torch.distributed.
    The input and the result are in the form (count, mean, variance).
    The result is the mean and variance as if they had been computed on the
    concatenation of the data behind the statistics of all parallel
    processes.

    Arguments
    ---------
    statistics: Tuple[int, torch.Tensor, torch.Tensor]
        The statistics for this process, to be combined with the statistics
        passed in by all other processes.

    Returns
    -------
    count
        The total number of elements in the data across processes, as
        returned by ddp_all_reduce (presumably a zero-dimensional tensor on
        the device of the local mean; confirm against ddp_all_reduce).
    mean
        The combined mean.
    variance
        The combined variance, relative to the new mean.
    """
    # This is the DDP version of combine_gaussian_statistics above.
    # NOTE(review): ddp_all_reduce is defined outside this block; it is
    # assumed to reduce the tensor over all processes and to return the
    # result. Confirm against its definition.
    local_count, local_mean, local_variance = statistics

    # Create the count on the same device as the statistics rather than on
    # the default (CPU) device: accelerator-only collective backends such as
    # NCCL cannot reduce CPU tensors.
    global_count = ddp_all_reduce(
        torch.tensor(local_count, device=local_mean.device), ReduceOp.SUM
    )

    # Each process contributes proportionally to its share of the data, so
    # summing the weighted means over processes yields the global mean.
    local_weight = local_count / global_count
    global_mean = ddp_all_reduce(local_weight * local_mean, ReduceOp.SUM)

    # Re-express the local variance relative to the global mean before
    # taking the weighted sum over processes.
    compensated_local_variance = local_variance + torch.square(
        local_mean - global_mean
    )
    global_variance = ddp_all_reduce(
        local_weight * compensated_local_variance, ReduceOp.SUM
    )

    return (global_count, global_mean, global_variance)
1142+
1143+
1144+
def mean_std_update(x, mask, dim, run_count, run_mean, run_std=None):
    """
    Fold the statistics of a new tensor into running statistics.

    The count, mean and variance of x over dim are computed, combined across
    parallel processes, and then merged with the running statistics.

    Arguments
    ---------
    x
        The new data.
    mask
        Mask for x. Only an all-true mask is supported at the moment.
    dim
        The dimension or dimensions of x to compute the statistics over.
    run_count
        The count behind the running statistics.
    run_mean
        The running mean.
    run_std
        The running standard deviation. The default of None is not
        supported yet.

    Returns
    -------
    count
        The updated count.
    mean
        The updated mean.
    std
        The updated standard deviation.
    """
    # Partially masked input is not handled yet.
    assert torch.all(mask), "Not implemented yet"

    # TODO implement run_std is None
    # The combination works on variances, so square the standard deviation.
    running = (run_count, run_mean, torch.square(run_std))

    local_batch = gaussian_statistics(x, dim=dim)
    global_batch = combine_gaussian_statistics_distributed(local_batch)

    total, updated_mean, updated_variance = combine_gaussian_statistics(
        running, global_batch
    )
    updated_std = torch.sqrt(updated_variance)
    return total, updated_mean, updated_std
1156+
1157+
9961158
@register_checkpoint_hooks
9971159
class InputNormalization(torch.nn.Module):
9981160
"""Performs mean and variance normalization of the input tensor.

0 commit comments

Comments
 (0)