Skip to content

Commit 5d4115d

Browse files
committed
Diffusion: Many cosmetic changes, some minor fixes
1 parent a72abbe commit 5d4115d

4 files changed

Lines changed: 309 additions & 20 deletions

File tree

speechbrain/dataio/dataio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,7 @@ def clean_padding_(tensor, length, len_dim=1, mask_value=0.0):
11891189
>>> x
11901190
tensor([[[ 0, 1, 10, 10, 10],
11911191
[ 0, 2, 10, 10, 10]],
1192-
<BLANKLINE>
1192+
<BLANKLINE>
11931193
[[ 1, 2, 3, 4, 5],
11941194
[ 2, 4, 6, 8, 10]],
11951195
<BLANKLINE>
@@ -1254,7 +1254,7 @@ def clean_padding(tensor, length, len_dim=1, mask_value=0.0):
12541254
>>> x_p
12551255
tensor([[[ 0, 1, 10, 10, 10],
12561256
[ 0, 2, 10, 10, 10]],
1257-
<BLANKLINE>
1257+
<BLANKLINE>
12581258
[[ 1, 2, 3, 4, 5],
12591259
[ 2, 4, 6, 8, 10]],
12601260
<BLANKLINE>

speechbrain/nnet/losses.py

Lines changed: 217 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,20 @@ def distance_diff_loss(
638638
reduction : str
639639
Options are 'mean', 'batch', 'batchmean', 'sum'.
640640
See pytorch for 'mean', 'sum'. The 'batch' option returns
641-
one loss per item in the batch, 'batchmean' returns sum / batch size.
641+
one loss per item in the batch, 'batchmean' returns sum / batch size.
642+
643+
Example
644+
-------
645+
>>> predictions = torch.tensor(
646+
... [[0.25, 0.5, 0.25, 0.0],
647+
... [0.05, 0.05, 0.9, 0.0],
648+
... [8.0, 0.10, 0.05, 0.05]]
649+
... )
650+
>>> targets = torch.tensor([2., 3., 1.])
651+
>>> length = torch.tensor([.75, .75, 1.])
652+
>>> loss = distance_diff_loss(predictions, targets, length)
653+
>>> loss
654+
tensor(0.2967)
642655
"""
643656
return compute_masked_loss(
644657
functools.partial(
@@ -768,7 +781,73 @@ def compute_length_mask(data, length=None, len_dim=1):
768781
data: torch.tensor
769782
the data tensor; the returned mask has the same shape and dtype
770783
len_dim: int
771-
the length dimension (defaults to 1)"""
784+
the length dimension (defaults to 1)
785+
786+
Returns
787+
-------
788+
mask: torch.Tensor
789+
the mask
790+
791+
Example
792+
-------
793+
>>> data = torch.arange(5)[None, :, None].repeat(3, 1, 2)
794+
>>> data += torch.arange(1, 4)[:, None, None]
795+
>>> data *= torch.arange(1, 3)[None, None, :]
796+
>>> data
797+
tensor([[[ 1, 2],
798+
[ 2, 4],
799+
[ 3, 6],
800+
[ 4, 8],
801+
[ 5, 10]],
802+
<BLANKLINE>
803+
[[ 2, 4],
804+
[ 3, 6],
805+
[ 4, 8],
806+
[ 5, 10],
807+
[ 6, 12]],
808+
<BLANKLINE>
809+
[[ 3, 6],
810+
[ 4, 8],
811+
[ 5, 10],
812+
[ 6, 12],
813+
[ 7, 14]]])
814+
>>> compute_length_mask(data, torch.tensor([1., .4, .8]))
815+
tensor([[[1, 1],
816+
[1, 1],
817+
[1, 1],
818+
[1, 1],
819+
[1, 1]],
820+
<BLANKLINE>
821+
[[1, 1],
822+
[1, 1],
823+
[0, 0],
824+
[0, 0],
825+
[0, 0]],
826+
<BLANKLINE>
827+
[[1, 1],
828+
[1, 1],
829+
[1, 1],
830+
[1, 1],
831+
[0, 0]]])
832+
>>> compute_length_mask(data, torch.tensor([.5, 1., .5]), len_dim=2)
833+
tensor([[[1, 0],
834+
[1, 0],
835+
[1, 0],
836+
[1, 0],
837+
[1, 0]],
838+
<BLANKLINE>
839+
[[1, 1],
840+
[1, 1],
841+
[1, 1],
842+
[1, 1],
843+
[1, 1]],
844+
<BLANKLINE>
845+
[[1, 0],
846+
[1, 0],
847+
[1, 0],
848+
[1, 0],
849+
[1, 0]]])
850+
"""
772851
mask = torch.ones_like(data)
773852
if length is not None:
774853
length_mask = length_to_mask(
@@ -778,7 +857,7 @@ def compute_length_mask(data, length=None, len_dim=1):
778857
# Handle any dimensionality of input
779858
while len(length_mask.shape) < len(mask.shape):
780859
length_mask = length_mask.unsqueeze(-1)
781-
length_mask = length_mask.type(mask.dtype)
860+
length_mask = length_mask.type(mask.dtype).transpose(1, len_dim)
782861
mask *= length_mask
783862
return mask
784863

@@ -1415,8 +1494,51 @@ class VariationalAutoencoderLoss(nn.Module):
14151494
rec_loss: callable
14161495
a function or module to compute the reconstruction loss
14171496
1497+
len_dim: int
1498+
the dimension to be used for the length, if encoding sequences
1499+
of variable length
1500+
14181501
dist_loss_weight: float
14191502
the relative weight of the distribution loss (K-L divergence)
1503+
1504+
Example
1505+
-------
1506+
>>> from speechbrain.nnet.autoencoder import VariationalAutoencoderOutput
1507+
>>> vae_loss = VariationalAutoencoderLoss(dist_loss_weight=0.5)
1508+
>>> predictions = VariationalAutoencoderOutput(
1509+
... rec=torch.tensor(
1510+
... [[0.8, 1.0],
1511+
... [1.2, 0.6],
1512+
... [0.4, 1.4]]
1513+
... ),
1514+
... mean=torch.tensor(
1515+
... [[0.5, 1.0],
1516+
... [1.5, 1.0],
1517+
... [1.0, 1.4]],
1518+
... ),
1519+
... log_var=torch.tensor(
1520+
... [[0.0, -0.2],
1521+
... [2.0, -2.0],
1522+
... [0.2, 0.4]],
1523+
... ),
1524+
... latent=torch.randn(3, 1),
1525+
... latent_sample=torch.randn(3, 1),
1526+
... latent_length=torch.tensor([1., 1., 1.]),
1527+
... )
1528+
>>> targets = torch.tensor(
1529+
... [[0.9, 1.1],
1530+
... [1.4, 0.6],
1531+
... [0.2, 1.4]]
1532+
... )
1533+
>>> loss = vae_loss(predictions, targets)
1534+
>>> loss
1535+
tensor(1.1264)
1536+
>>> details = vae_loss.details(predictions, targets)
1537+
>>> details #doctest: +NORMALIZE_WHITESPACE
1538+
VariationalAutoencoderLossDetails(loss=tensor(1.1264),
1539+
rec_loss=tensor(0.0333),
1540+
dist_loss=tensor(2.1861),
1541+
weighted_dist_loss=tensor(1.0930))
14201542
"""
14211543

14221544
def __init__(self, rec_loss=None, len_dim=1, dist_loss_weight=0.001):
@@ -1433,7 +1555,7 @@ def forward(self, predictions, targets, length=None, reduction="batchmean"):
14331555
Arguments
14341556
---------
14351557
predictions: speechbrain.nnet.autoencoder.VariationalAutoencoderOutput
1436-
the variational autoencoder output (or a tuple of rec, mean, log_var)
1558+
the variational autoencoder output
14371559
targets: torch.Tensor
14381560
the reconstruction targets
14391561
length : torch.Tensor
@@ -1516,26 +1638,54 @@ class AutoencoderLoss(nn.Module):
15161638
rec_loss: callable
15171639
the callable to compute the reconstruction loss
15181640
len_dim: int
1519-
the dimension index to be used for length"""
1641+
the dimension index to be used for length
1642+
1643+
1644+
Example
1645+
-------
1646+
>>> from speechbrain.nnet.autoencoder import AutoencoderOutput
1647+
>>> ae_loss = AutoencoderLoss()
1648+
>>> rec = torch.tensor(
1649+
... [[0.8, 1.0],
1650+
... [1.2, 0.6],
1651+
... [0.4, 1.4]]
1652+
... )
1653+
>>> predictions = AutoencoderOutput(
1654+
... rec=rec,
1655+
... latent=torch.randn(3, 1),
1656+
... latent_length=torch.tensor([1., 1.])
1657+
... )
1658+
>>> targets = torch.tensor(
1659+
... [[0.9, 1.1],
1660+
... [1.4, 0.6],
1661+
... [0.2, 1.4]]
1662+
... )
1663+
>>> ae_loss(predictions, targets)
1664+
tensor(0.0333)
1665+
>>> ae_loss.details(predictions, targets)
1666+
AutoencoderLossDetails(loss=tensor(0.0333), rec_loss=tensor(0.0333))
1667+
"""
15201668

15211669
def __init__(self, rec_loss=None, len_dim=1):
15221670
super().__init__()
1671+
if rec_loss is None:
1672+
rec_loss = mse_loss
15231673
self.rec_loss = rec_loss
15241674
self.len_dim = len_dim
15251675

15261676
def forward(self, predictions, targets, length=None, reduction="batchmean"):
15271677
"""Computes the autoencoder loss
1678+
15281679
Arguments
15291680
---------
1530-
15311681
predictions: speechbrain.nnet.autoencoder.AutoencoderOutput
15321682
the autoencoder output
15331683
15341684
targets: torch.Tensor
15351685
targets for the reconstruction loss
15361686
1537-
length : torch.Tensor
1538-
Length of each sample for computing true error with a mask.
1687+
length: torch.Tensor
1688+
Length of each sample for computing true error with a mask
15391689
15401690
"""
15411691
rec_loss = self._align_length_axis(
@@ -1552,8 +1702,8 @@ def details(self, predictions, targets, length=None, reduction="batchmean"):
15521702
15531703
Arguments
15541704
---------
1555-
predictions: speechbrain.nnet.autoencoder.VariationalAutoencoderOutput
1556-
the variational autoencoder output (or a tuple of rec, mean, log_var)
1705+
predictions: speechbrain.nnet.autoencoder.AutoencoderOutput
1706+
the autoencoder output
15571707
15581708
targets: torch.Tensor
15591709
targets for the reconstruction loss
@@ -1567,7 +1717,7 @@ def details(self, predictions, targets, length=None, reduction="batchmean"):
15671717
15681718
Returns
15691719
-------
1570-
details: VAELossDetails
1720+
details: AutoencoderLossDetails
15711721
a namedtuple with the following parameters
15721722
loss: torch.Tensor
15731723
the combined loss
@@ -1583,8 +1733,11 @@ def _align_length_axis(self, tensor):
15831733

15841734
def _reduce_autoencoder_loss(loss, length, reduction):
15851735
max_len = loss.size(1)
1586-
mask = length_to_mask(length * max_len, max_len)
1587-
mask = unsqueeze_as(mask, loss).expand_as(loss)
1736+
if length is not None:
1737+
mask = length_to_mask(length * max_len, max_len)
1738+
mask = unsqueeze_as(mask, loss).expand_as(loss)
1739+
else:
1740+
mask = torch.ones_like(loss)
15881741
reduced_loss = reduce_loss(loss * mask, mask, reduction=reduction)
15891742
return reduced_loss
15901743

@@ -1608,6 +1761,27 @@ class Laplacian(nn.Module):
16081761
the size of the Laplacian kernel
16091762
dtype: torch.dtype
16101763
the data type (optional)
1764+
1765+
Example
1766+
-------
1767+
>>> lap = Laplacian(3)
1768+
>>> lap.get_kernel()
1769+
tensor([[[[-1., -1., -1.],
1770+
[-1., 8., -1.],
1771+
[-1., -1., -1.]]]])
1772+
>>> data = torch.eye(6) + torch.eye(6).flip(0)
1773+
>>> data
1774+
tensor([[1., 0., 0., 0., 0., 1.],
1775+
[0., 1., 0., 0., 1., 0.],
1776+
[0., 0., 1., 1., 0., 0.],
1777+
[0., 0., 1., 1., 0., 0.],
1778+
[0., 1., 0., 0., 1., 0.],
1779+
[1., 0., 0., 0., 0., 1.]])
1780+
>>> lap(data.unsqueeze(0))
1781+
tensor([[[ 6., -3., -3., 6.],
1782+
[-3., 4., 4., -3.],
1783+
[-3., 4., 4., -3.],
1784+
[ 6., -3., -3., 6.]]])
16111785
"""
16121786

16131787
def __init__(self, kernel_size, dtype=torch.float32):
@@ -1641,7 +1815,10 @@ def forward(self, data):
16411815

16421816
class LaplacianVarianceLoss(nn.Module):
16431817
"""The Laplacian variance loss - used to penalize blurriness in image-like
1644-
data, such as spectrograms
1818+
data, such as spectrograms.
1819+
1820+
The loss value will be the negative variance because the
1821+
higher the variance, the sharper the image.
16451822
16461823
Arguments
16471824
---------
@@ -1650,6 +1827,32 @@ class LaplacianVarianceLoss(nn.Module):
16501827
16511828
len_dim: int
16521829
the dimension to be used as the length
1830+
1831+
Example
1832+
-------
1833+
>>> lap_loss = LaplacianVarianceLoss(3)
1834+
>>> data = torch.ones(6, 6).unsqueeze(0)
1835+
>>> data
1836+
tensor([[[1., 1., 1., 1., 1., 1.],
1837+
[1., 1., 1., 1., 1., 1.],
1838+
[1., 1., 1., 1., 1., 1.],
1839+
[1., 1., 1., 1., 1., 1.],
1840+
[1., 1., 1., 1., 1., 1.],
1841+
[1., 1., 1., 1., 1., 1.]]])
1842+
>>> lap_loss(data)
1843+
tensor(-0.)
1844+
>>> data = (
1845+
... torch.eye(6) + torch.eye(6).flip(0)
1846+
... ).unsqueeze(0)
1847+
>>> data
1848+
tensor([[[1., 0., 0., 0., 0., 1.],
1849+
[0., 1., 0., 0., 1., 0.],
1850+
[0., 0., 1., 1., 0., 0.],
1851+
[0., 0., 1., 1., 0., 0.],
1852+
[0., 1., 0., 0., 1., 0.],
1853+
[1., 0., 0., 0., 0., 1.]]])
1854+
>>> lap_loss(data)
1855+
tensor(-17.6000)
16531856
"""
16541857

16551858
def __init__(self, kernel_size=3, len_dim=1):

0 commit comments

Comments
 (0)