|
32 | 32 |
|
33 | 33 | Authors |
34 | 34 | * Mirco Ravanelli 2020 |
| 35 | + * Rogier van Dalen 2025 |
35 | 36 | """ |
36 | 37 |
|
37 | 38 | import math |
| 39 | +from typing import Tuple, Union |
38 | 40 |
|
39 | 41 | import torch |
| 42 | +from torch.distributed import ReduceOp |
40 | 43 |
|
41 | 44 | from speechbrain.dataio.dataio import length_to_mask |
42 | 45 | from speechbrain.utils.checkpoints import ( |
@@ -993,6 +996,165 @@ def forward(self, x): |
993 | 996 | return cw_x |
994 | 997 |
|
995 | 998 |
|
def gaussian_statistics(
    x: torch.Tensor, dim: Union[int, tuple, None] = None
) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """
    Compute first- and second-order moments of data, and return them as the
    count, mean, and variance of a vector over one or more dimensions.

    Arguments
    ---------
    x: torch.Tensor
        The tensor to compute the statistics over.
    dim: int | tuple | None
        The dimension or dimensions that the statistics should be computed
        over. The other dimensions are retained in the output.
        Any other sequence of ints (e.g. a list) is converted to a tuple.
        If None, then scalar-valued statistics will be returned.
        If empty, nothing is reduced: every element is its own mean.

    Returns
    -------
    count
        The number of sub-vectors or sub-tensors that the statistics were
        computed over.
    mean
        The mean.
    variance
        The variance (normalised by ``count``, not ``count - 1``).

    Raises
    ------
    TypeError
        If ``dim`` is not None, an int, or a sequence of ints.
    """
    if dim is None:
        number = math.prod(x.shape)
    elif isinstance(dim, int):
        number = x.shape[dim]
    else:
        # Normalise any sequence to a tuple, so that the rest of the function
        # only has to deal with None, int, or tuple.
        try:
            dim = tuple(dim)
        except TypeError as err:
            raise TypeError(
                "dim must be None, an int, or a sequence of ints, "
                f"not {type(dim).__name__}"
            ) from err
        if not dim:
            # Nothing to reduce over: each element is a sample of size one,
            # so it equals its own mean and has zero variance.
            return 1, x, torch.zeros_like(x)
        number = math.prod(x.shape[d] for d in dim)

    # First keep the dimensions so that broadcasting works when subtracting
    # the mean from x below.
    mean_with_dims = torch.mean(x, dim=dim, keepdim=True)
    mean = (
        torch.squeeze(mean_with_dims)
        if dim is None
        else torch.squeeze(mean_with_dims, dim=dim)
    )
    variance = torch.mean(torch.square(x - mean_with_dims), dim=dim)

    return (number, mean, variance)
| 1046 | + |
| 1047 | + |
def combine_gaussian_statistics(
    left_statistics: Tuple[int, torch.Tensor, torch.Tensor],
    right_statistics: Tuple[int, torch.Tensor, torch.Tensor],
):
    """
    Combine the first- and second-order moments from two pieces of data.
    The data and the result are in the form (count, mean, variance).
    The result is the mean and variance as if they had been computed on the
    concatenation of the data for left_statistics and the data for
    right_statistics.

    Arguments
    ---------
    left_statistics: Tuple[int, torch.Tensor, torch.Tensor]
        One set of statistics.
    right_statistics: Tuple[int, torch.Tensor, torch.Tensor]
        Another set of statistics.

    Returns
    -------
    count
        The total number of elements in the data.
    mean
        The combined mean.
    variance
        The combined variance, relative to the new mean.
        If the total count is zero, the mean and variance from
        left_statistics are returned unchanged.

    Raises
    ------
    ValueError
        If the four mean and variance tensors do not all share one shape.
    """
    left_count, left_mean, left_variance = left_statistics
    right_count, right_mean, right_variance = right_statistics
    if not (
        left_mean.shape
        == left_variance.shape
        == right_mean.shape
        == right_variance.shape
    ):
        raise ValueError(
            "All means and variances must have the same shape, but got "
            f"{tuple(left_mean.shape)}, {tuple(left_variance.shape)}, "
            f"{tuple(right_mean.shape)}, {tuple(right_variance.shape)}"
        )

    count = left_count + right_count
    if count == 0:
        # No data on either side: the weights below would be 0 / 0.
        return count, left_mean, left_variance

    left_weight = left_count / count
    right_weight = right_count / count

    mean = left_weight * left_mean + right_weight * right_mean

    # Reconstruct the left and right variances relative to "mean": shifting
    # the reference point by d adds d**2 to the mean squared deviation.
    compensated_left_variance = left_variance + torch.square(mean - left_mean)
    compensated_right_variance = right_variance + torch.square(
        mean - right_mean
    )

    variance = (
        left_weight * compensated_left_variance
        + right_weight * compensated_right_variance
    )

    return count, mean, variance
| 1100 | + |
| 1101 | + |
def combine_gaussian_statistics_distributed(
    statistics: Tuple[int, torch.Tensor, torch.Tensor],
):
    """
    Combine the first- and second-order moments from multiple pieces of data
    using torch.distributed.
    The data and the result are in the form (count, mean, variance).
    The result is the mean and variance as if they had been computed on the
    concatenation of the data for statistics for all parallel processes.

    Arguments
    ---------
    statistics: Tuple[int, torch.Tensor, torch.Tensor]
        The statistics for this process, to be combined with the statistics
        passed in by all other processes.

    Returns
    -------
    count
        The total number of elements in the data across processes, as
        returned by ddp_all_reduce (presumably a 0-dim tensor rather than a
        Python int -- TODO confirm against ddp_all_reduce).
    mean
        The combined mean.
    variance
        The combined variance, relative to the new mean.
    """
    # This is the DDP version of combine_gaussian_statistics above.
    # NOTE(review): ddp_all_reduce is defined outside this section; this code
    # assumes it reduces the tensor over all processes with the given
    # operation and returns the reduced tensor.
    local_count, local_mean, local_variance = statistics

    # Create the count on the same device as the other statistics, so that
    # all three reductions communicate tensors from the same device.
    global_count = ddp_all_reduce(
        torch.tensor(local_count, device=local_mean.device), ReduceOp.SUM
    )

    # Summing the weighted local means over processes gives the global mean.
    local_weight = local_count / global_count
    global_mean = ddp_all_reduce(local_weight * local_mean, ReduceOp.SUM)

    # Express the local variance relative to the global mean before
    # averaging, as in combine_gaussian_statistics.
    compensated_local_variance = local_variance + torch.square(
        local_mean - global_mean
    )
    global_variance = ddp_all_reduce(
        local_weight * compensated_local_variance, ReduceOp.SUM
    )

    return (global_count, global_mean, global_variance)
| 1142 | + |
| 1143 | + |
def mean_std_update(x, mask, dim, run_count, run_mean, run_std=None):
    """
    Update running mean and standard deviation statistics with new data.

    The statistics of ``x`` over ``dim`` are combined across processes with
    combine_gaussian_statistics_distributed and then merged into the running
    statistics with combine_gaussian_statistics.

    Arguments
    ---------
    x: torch.Tensor
        The new data.
    mask: torch.Tensor
        Mask for ``x``. Only masks that are entirely true are supported.
    dim: int | tuple | None
        The dimension or dimensions of ``x`` to compute statistics over;
        passed on to gaussian_statistics.
    run_count
        The count for the running statistics so far.
    run_mean: torch.Tensor
        The running mean so far.
    run_std: torch.Tensor
        The running standard deviation so far.
        Leaving this as None is not supported yet.

    Returns
    -------
    count
        The updated count.
    mean
        The updated mean.
    std
        The updated standard deviation.

    Raises
    ------
    NotImplementedError
        If any element of ``mask`` is false, or if ``run_std`` is None.
    """
    # Raise explicitly rather than assert, so that the check is not stripped
    # under "python -O" and masked data is never silently treated as unmasked.
    if not torch.all(mask):
        raise NotImplementedError("Not implemented yet")

    # TODO implement run_std is None
    if run_std is None:
        raise NotImplementedError(
            "mean_std_update with run_std=None is not implemented yet"
        )

    # The combination functions work on variances, so convert from and to
    # standard deviations at the boundaries.
    current_statistics = (run_count, run_mean, torch.square(run_std))
    new_statistics = combine_gaussian_statistics_distributed(
        gaussian_statistics(x, dim=dim)
    )
    (count, mean, variance) = combine_gaussian_statistics(
        current_statistics, new_statistics
    )
    return count, mean, torch.sqrt(variance)
| 1156 | + |
| 1157 | + |
996 | 1158 | @register_checkpoint_hooks |
997 | 1159 | class InputNormalization(torch.nn.Module): |
998 | 1160 | """Performs mean and variance normalization of the input tensor. |
|
0 commit comments