@@ -638,7 +638,20 @@ def distance_diff_loss(
638638 reduction : str
639639 Options are 'mean', 'batch', 'batchmean', 'sum'.
640640 See pytorch for 'mean', 'sum'. The 'batch' option returns
641- one loss per item in the batch, 'batchmean' returns sum / batch size.
641+ one loss per item in the batch, 'batchmean' returns sum / batch size
642+
643+ Example
644+ -------
645+ >>> predictions = torch.tensor(
646+ ... [[0.25, 0.5, 0.25, 0.0],
647+ ... [0.05, 0.05, 0.9, 0.0],
648+ ... [8.0, 0.10, 0.05, 0.05]]
649+ ... )
650+ >>> targets = torch.tensor([2., 3., 1.])
651+ >>> length = torch.tensor([.75, .75, 1.])
652+ >>> loss = distance_diff_loss(predictions, targets, length)
653+ >>> loss
654+ tensor(0.2967)
642655 """
643656 return compute_masked_loss (
644657 functools .partial (
@@ -768,7 +781,73 @@ def compute_length_mask(data, length=None, len_dim=1):
768781 data: torch.tensor
769782 the data shape
770783 len_dim: int
771- the length dimension (defaults to 1)"""
784+ the length dimension (defaults to 1)
785+
786+ Returns
787+ -------
788+ mask: torch.Tensor
789+ the mask
790+
791+ Example
792+ -------
793+ >>> data = torch.arange(5)[None, :, None].repeat(3, 1, 2)
794+ >>> data += torch.arange(1, 4)[:, None, None]
795+ >>> data *= torch.arange(1, 3)[None, None, :]
796+ >>> data
797+ tensor([[[ 1, 2],
798+ [ 2, 4],
799+ [ 3, 6],
800+ [ 4, 8],
801+ [ 5, 10]],
802+ <BLANKLINE>
803+ [[ 2, 4],
804+ [ 3, 6],
805+ [ 4, 8],
806+ [ 5, 10],
807+ [ 6, 12]],
808+ <BLANKLINE>
809+ [[ 3, 6],
810+ [ 4, 8],
811+ [ 5, 10],
812+ [ 6, 12],
813+ [ 7, 14]]])
814+ >>> compute_length_mask(data, torch.tensor([1., .4, .8]))
815+ tensor([[[1, 1],
816+ [1, 1],
817+ [1, 1],
818+ [1, 1],
819+ [1, 1]],
820+ <BLANKLINE>
821+ [[1, 1],
822+ [1, 1],
823+ [0, 0],
824+ [0, 0],
825+ [0, 0]],
826+ <BLANKLINE>
827+ [[1, 1],
828+ [1, 1],
829+ [1, 1],
830+ [1, 1],
831+ [0, 0]]])
832+ >>> compute_length_mask(data, torch.tensor([.5, 1., .5]), len_dim=2)
833+ tensor([[[1, 0],
834+ [1, 0],
835+ [1, 0],
836+ [1, 0],
837+ [1, 0]],
838+ <BLANKLINE>
839+ [[1, 1],
840+ [1, 1],
841+ [1, 1],
842+ [1, 1],
843+ [1, 1]],
844+ <BLANKLINE>
845+ [[1, 0],
846+ [1, 0],
847+ [1, 0],
848+ [1, 0],
849+ [1, 0]]])
850+ """
772851 mask = torch .ones_like (data )
773852 if length is not None :
774853 length_mask = length_to_mask (
@@ -778,7 +857,7 @@ def compute_length_mask(data, length=None, len_dim=1):
778857 # Handle any dimensionality of input
779858 while len (length_mask .shape ) < len (mask .shape ):
780859 length_mask = length_mask .unsqueeze (- 1 )
781- length_mask = length_mask .type (mask .dtype )
860+ length_mask = length_mask .type (mask .dtype ). transpose ( 1 , len_dim )
782861 mask *= length_mask
783862 return mask
784863
@@ -1415,8 +1494,51 @@ class VariationalAutoencoderLoss(nn.Module):
14151494 rec_loss: callable
14161495 a function or module to compute the reconstruction loss
14171496
1497+ len_dim: int
1498+ the dimension to be used for the length, if encoding sequences
1499+ of variable length
1500+
14181501 dist_loss_weight: float
14191502 the relative weight of the distribution loss (K-L divergence)
1503+
1504+ Example
1505+ -------
1506+ >>> from speechbrain.nnet.autoencoder import VariationalAutoencoderOutput
1507+ >>> vae_loss = VariationalAutoencoderLoss(dist_loss_weight=0.5)
1508+ >>> predictions = VariationalAutoencoderOutput(
1509+ ... rec=torch.tensor(
1510+ ... [[0.8, 1.0],
1511+ ... [1.2, 0.6],
1512+ ... [0.4, 1.4]]
1513+ ... ),
1514+ ... mean=torch.tensor(
1515+ ... [[0.5, 1.0],
1516+ ... [1.5, 1.0],
1517+ ... [1.0, 1.4]],
1518+ ... ),
1519+ ... log_var=torch.tensor(
1520+ ... [[0.0, -0.2],
1521+ ... [2.0, -2.0],
1522+ ... [0.2, 0.4]],
1523+ ... ),
1524+ ... latent=torch.randn(3, 1),
1525+ ... latent_sample=torch.randn(3, 1),
1526+ ... latent_length=torch.tensor([1., 1., 1.]),
1527+ ... )
1528+ >>> targets = torch.tensor(
1529+ ... [[0.9, 1.1],
1530+ ... [1.4, 0.6],
1531+ ... [0.2, 1.4]]
1532+ ... )
1533+ >>> loss = vae_loss(predictions, targets)
1534+ >>> loss
1535+ tensor(1.1264)
1536+ >>> details = vae_loss.details(predictions, targets)
1537+ >>> details #doctest: +NORMALIZE_WHITESPACE
1538+ VariationalAutoencoderLossDetails(loss=tensor(1.1264),
1539+ rec_loss=tensor(0.0333),
1540+ dist_loss=tensor(2.1861),
1541+ weighted_dist_loss=tensor(1.0930))
14201542 """
14211543
14221544 def __init__ (self , rec_loss = None , len_dim = 1 , dist_loss_weight = 0.001 ):
@@ -1433,7 +1555,7 @@ def forward(self, predictions, targets, length=None, reduction="batchmean"):
14331555 Arguments
14341556 ---------
14351557 predictions: speechbrain.nnet.autoencoder.VariationalAutoencoderOutput
1436- the variational autoencoder output (or a tuple of rec, mean, log_var)
1558+ the variational autoencoder output
14371559 targets: torch.Tensor
14381560 the reconstruction targets
14391561 length : torch.Tensor
@@ -1516,26 +1638,54 @@ class AutoencoderLoss(nn.Module):
15161638 rec_loss: callable
15171639 the callable to compute the reconstruction loss
15181640 len_dim: torch.Tensor
1519- the dimension index to be used for length"""
1641+ the dimension index to be used for length
1642+
1643+
1644+ Example
1645+ -------
1646+ >>> from speechbrain.nnet.autoencoder import AutoencoderOutput
1647+ >>> ae_loss = AutoencoderLoss()
1648+ >>> rec = torch.tensor(
1649+ ... [[0.8, 1.0],
1650+ ... [1.2, 0.6],
1651+ ... [0.4, 1.4]]
1652+ ... )
1653+ >>> predictions = AutoencoderOutput(
1654+ ... rec=rec,
1655+ ... latent=torch.randn(3, 1),
1656+ ... latent_length=torch.tensor([1., 1.])
1657+ ... )
1658+ >>> targets = torch.tensor(
1659+ ... [[0.9, 1.1],
1660+ ... [1.4, 0.6],
1661+ ... [0.2, 1.4]]
1662+ ... )
1663+ >>> ae_loss(predictions, targets)
1664+ tensor(0.0333)
1665+ >>> ae_loss.details(predictions, targets)
1666+ AutoencoderLossDetails(loss=tensor(0.0333), rec_loss=tensor(0.0333))
1667+ """
15201668
15211669 def __init__ (self , rec_loss = None , len_dim = 1 ):
15221670 super ().__init__ ()
1671+ if rec_loss is None :
1672+ rec_loss = mse_loss
15231673 self .rec_loss = rec_loss
15241674 self .len_dim = len_dim
15251675
15261676 def forward (self , predictions , targets , length = None , reduction = "batchmean" ):
15271677 """Computes the autoencoder loss
1678+
15281679 Arguments
15291680 ---------
1530-
15311681 predictions: speechbrain.nnet.autoencoder.AutoencoderOutput
15321682 the autoencoder output
15331683
15341684 targets: torch.Tensor
15351685 targets for the reconstruction loss
15361686
1537- length : torch.Tensor
1538- Length of each sample for computing true error with a mask.
1687+ length: torch.Tensor
1688+ Length of each sample for computing true error with a mask
15391689
15401690 """
15411691 rec_loss = self ._align_length_axis (
@@ -1552,8 +1702,8 @@ def details(self, predictions, targets, length=None, reduction="batchmean"):
15521702
15531703 Arguments
15541704 ---------
1555- predictions: speechbrain.nnet.autoencoder.VariationalAutoencoderOutput
1556- the variational autoencoder output (or a tuple of rec, mean, log_var)
1705+ predictions: speechbrain.nnet.autoencoder.AutoencoderOutput
1706+ the autoencoder output
15571707
15581708 targets: torch.Tensor
15591709 targets for the reconstruction loss
@@ -1567,7 +1717,7 @@ def details(self, predictions, targets, length=None, reduction="batchmean"):
15671717
15681718 Results
15691719 -------
1570- details: VAELossDetails
1720+ details: AutoencoderLossDetails
15711721 a namedtuple with the following parameters
15721722 loss: torch.Tensor
15731723 the combined loss
@@ -1583,8 +1733,11 @@ def _align_length_axis(self, tensor):
15831733
15841734def _reduce_autoencoder_loss (loss , length , reduction ):
15851735 max_len = loss .size (1 )
1586- mask = length_to_mask (length * max_len , max_len )
1587- mask = unsqueeze_as (mask , loss ).expand_as (loss )
1736+ if length is not None :
1737+ mask = length_to_mask (length * max_len , max_len )
1738+ mask = unsqueeze_as (mask , loss ).expand_as (loss )
1739+ else :
1740+ mask = torch .ones_like (loss )
15881741 reduced_loss = reduce_loss (loss * mask , mask , reduction = reduction )
15891742 return reduced_loss
15901743
@@ -1608,6 +1761,27 @@ class Laplacian(nn.Module):
16081761 the size of the Laplacian kernel
16091762 dtype: torch.dtype
16101763 the data type (optional)
1764+
1765+ Example
1766+ -------
1767+ >>> lap = Laplacian(3)
1768+ >>> lap.get_kernel()
1769+ tensor([[[[-1., -1., -1.],
1770+ [-1., 8., -1.],
1771+ [-1., -1., -1.]]]])
1772+ >>> data = torch.eye(6) + torch.eye(6).flip(0)
1773+ >>> data
1774+ tensor([[1., 0., 0., 0., 0., 1.],
1775+ [0., 1., 0., 0., 1., 0.],
1776+ [0., 0., 1., 1., 0., 0.],
1777+ [0., 0., 1., 1., 0., 0.],
1778+ [0., 1., 0., 0., 1., 0.],
1779+ [1., 0., 0., 0., 0., 1.]])
1780+ >>> lap(data.unsqueeze(0))
1781+ tensor([[[ 6., -3., -3., 6.],
1782+ [-3., 4., 4., -3.],
1783+ [-3., 4., 4., -3.],
1784+ [ 6., -3., -3., 6.]]])
16111785 """
16121786
16131787 def __init__ (self , kernel_size , dtype = torch .float32 ):
@@ -1641,7 +1815,10 @@ def forward(self, data):
16411815
16421816class LaplacianVarianceLoss (nn .Module ):
16431817 """The Laplacian variance loss - used to penalize blurriness in image-like
1644- data, such as spectrograms
1818+ data, such as spectrograms.
1819+
1820+ The loss value will be the negative variance because the
1821+ higher the variance, the sharper the image.
16451822
16461823 Arguments
16471824 ---------
@@ -1650,6 +1827,32 @@ class LaplacianVarianceLoss(nn.Module):
16501827
16511828 len_dim: int
16521829 the dimension to be used as the length
1830+
1831+ Example
1832+ -------
1833+ >>> lap_loss = LaplacianVarianceLoss(3)
1834+ >>> data = torch.ones(6, 6).unsqueeze(0)
1835+ >>> data
1836+ tensor([[[1., 1., 1., 1., 1., 1.],
1837+ [1., 1., 1., 1., 1., 1.],
1838+ [1., 1., 1., 1., 1., 1.],
1839+ [1., 1., 1., 1., 1., 1.],
1840+ [1., 1., 1., 1., 1., 1.],
1841+ [1., 1., 1., 1., 1., 1.]]])
1842+ >>> lap_loss(data)
1843+ tensor(-0.)
1844+ >>> data = (
1845+ ... torch.eye(6) + torch.eye(6).flip(0)
1846+ ... ).unsqueeze(0)
1847+ >>> data
1848+ tensor([[[1., 0., 0., 0., 0., 1.],
1849+ [0., 1., 0., 0., 1., 0.],
1850+ [0., 0., 1., 1., 0., 0.],
1851+ [0., 0., 1., 1., 0., 0.],
1852+ [0., 1., 0., 0., 1., 0.],
1853+ [1., 0., 0., 0., 0., 1.]]])
1854+ >>> lap_loss(data)
1855+ tensor(-17.6000)
16531856 """
16541857
16551858 def __init__ (self , kernel_size = 3 , len_dim = 1 ):
0 commit comments