@@ -1903,323 +1903,3 @@ def partial_decoding(
19031903 beams = [CTCBeam .from_lm_beam (b ) for b in trimmed_beams ]
19041904
19051905 return beams
1906-
1907-
class TorchAudioCTCPrefixBeamSearcher:
    """TorchAudio CTC Prefix Beam Search Decoder.

    This class is a wrapper around the CTC decoder from TorchAudio. It provides
    a simple interface where you can either use the CPU or CUDA CTC decoder.

    The CPU decoder is slower but uses less memory. The CUDA decoder is faster
    but uses more memory. The CUDA decoder is also only available in the
    nightly version of torchaudio.

    A lot of features are missing in the CUDA decoder, such as the ability to
    use a language model, constraint search, and more. If you want to use those
    features, you have to use the CPU decoder.

    For more information about the CPU decoder, please refer to the
    documentation of TorchAudio:
    https://pytorch.org/audio/main/generated/torchaudio.models.decoder.ctc_decoder.html

    For more information about the CUDA decoder, please refer to the
    documentation of TorchAudio:
    https://pytorch.org/audio/main/generated/torchaudio.models.decoder.cuda_ctc_decoder.html#torchaudio.models.decoder.cuda_ctc_decoder

    If you want to use the language model, or the lexicon search, please make
    sure that your tokenizer/acoustic model uses the same tokens as the
    language model/lexicon. Otherwise, the decoding will fail.

    The implementation is compatible with SentencePiece Tokens.

    Note: When using CUDA CTC decoder, the blank_index has to be 0.
    Furthermore, using CUDA CTC decoder requires the nightly version of
    torchaudio and a lot of VRAM memory (if you want to use a lot of beams).
    Overall, we do recommend to use the CTCBeamSearcher or
    CTCPrefixBeamSearcher in SpeechBrain if you want to use n-gram + beam
    search decoding. If you want to have constraint search, please use the CPU
    version of torchaudio, and if you want to speed up the decoding as much as
    possible, please use the CUDA version.

    Arguments
    ---------
    tokens : list or str
        The list of tokens or the path to the tokens file.
        If this is a path, then the file should contain one token per line.
    lexicon : str, default: None
        Lexicon file containing the possible words and corresponding
        spellings. Each line consists of a word and its space separated
        spelling. If None, uses lexicon-free decoding. (default: None)
    lm : str, optional
        A path containing KenLM language model or None if not using a language
        model. (default: None)
    lm_dict : str, optional
        File consisting of the dictionary used for the LM, with a word per
        line sorted by LM index. If decoding with a lexicon, entries in
        lm_dict must also occur in the lexicon file. If None, dictionary for
        LM is constructed using the lexicon file. (default: None)
    topk : int, optional
        Number of top CTCHypothesis to return. (default: 1)
    beam_size : int, optional
        Numbers of hypotheses to hold after each decode step. (default: 50)
    beam_size_token : int, optional
        Max number of tokens to consider at each decode step. If None, it is
        set to the total number of tokens. (default: None)
    beam_threshold : float, optional
        Threshold for pruning hypothesis. (default: 50)
    lm_weight : float, optional
        Weight of language model. (default: 2)
    word_score : float, optional
        Word insertion score. (default: 0)
    unk_score : float, optional
        Unknown word insertion score. (default: float("-inf"))
    sil_score : float, optional
        Silence insertion score. (default: 0)
    log_add : bool, optional
        Whether to use logadd when merging hypotheses. (default: False)
    blank_index : int or str, optional
        Index of the blank token. If tokens is a file path, then this should
        be a str. Otherwise, this should be an int. (default: 0)
    sil_index : int or str, optional
        Index of the silence token. If tokens is a file path, then this should
        be a str. Otherwise, this should be an int. (default: 0)
    unk_word : str, optional
        Unknown word token. (default: "<unk>")
    using_cpu_decoder : bool, optional
        Whether to use the CPU searcher. If False, then the CUDA decoder is
        used. (default: True)
    blank_skip_threshold : float, optional
        Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed
        up decoding (default: 1.0).
        Note: This is only used when using the CUDA decoder, and it might
        worsen the WER/CER results. Use it at your own risk.

    Raises
    ------
    ImportError
        If the requested TorchAudio decoder cannot be imported.
    ValueError
        If the CUDA decoder is requested with a blank_index other than 0.

    Example
    -------
    >>> import torch
    >>> from speechbrain.decoders import TorchAudioCTCPrefixBeamSearcher
    >>> probs = torch.tensor([[[0.2, 0.0, 0.8], [0.4, 0.0, 0.6]]])
    >>> log_probs = torch.log(probs)
    >>> lens = torch.tensor([1.0])
    >>> blank_index = 2
    >>> vocab_list = ["a", "b", "-"]
    >>> searcher = TorchAudioCTCPrefixBeamSearcher(
    ...     tokens=vocab_list, blank_index=blank_index, sil_index=blank_index
    ... )  # doctest: +SKIP
    >>> hyps = searcher(log_probs, lens)  # doctest: +SKIP
    """

    def __init__(
        self,
        tokens: Union[list, str],
        lexicon: Optional[str] = None,
        lm: Optional[str] = None,
        lm_dict: Optional[str] = None,
        topk: int = 1,
        beam_size: int = 50,
        beam_size_token: Optional[int] = None,
        beam_threshold: float = 50,
        lm_weight: float = 2,
        word_score: float = 0,
        unk_score: float = float("-inf"),
        sil_score: float = 0,
        log_add: bool = False,
        blank_index: Union[str, int] = 0,
        sil_index: Union[str, int] = 0,
        unk_word: str = "<unk>",
        using_cpu_decoder: bool = True,
        blank_skip_threshold: float = 1.0,
    ):
        self.lexicon = lexicon
        self.tokens = tokens
        self.lm = lm
        self.lm_dict = lm_dict
        self.topk = topk
        self.beam_size = beam_size
        self.beam_size_token = beam_size_token
        self.beam_threshold = beam_threshold
        self.lm_weight = lm_weight
        self.word_score = word_score
        self.unk_score = unk_score
        self.sil_score = sil_score
        self.log_add = log_add
        self.blank_index = blank_index
        self.sil_index = sil_index
        self.unk_word = unk_word
        self.using_cpu_decoder = using_cpu_decoder
        self.blank_skip_threshold = blank_skip_threshold

        if self.using_cpu_decoder:
            try:
                from torchaudio.models.decoder import ctc_decoder
            except ImportError as err:
                raise ImportError(
                    "ctc_decoder not found. Please install torchaudio and flashlight to use this decoder."
                ) from err

            # With a tokens file, blank_index/sil_index are forwarded to
            # torchaudio unchanged; with an in-memory list they are used as
            # positions to look up the corresponding token strings.
            if isinstance(self.tokens, str):
                blank_token = self.blank_index
                sil_token = self.sil_index
            else:
                blank_token = self.tokens[self.blank_index]
                sil_token = self.tokens[self.sil_index]

            self._ctc_decoder = ctc_decoder(
                lexicon=self.lexicon,
                tokens=self.tokens,
                lm=self.lm,
                lm_dict=self.lm_dict,
                nbest=self.topk,
                beam_size=self.beam_size,
                beam_size_token=self.beam_size_token,
                beam_threshold=self.beam_threshold,
                lm_weight=self.lm_weight,
                word_score=self.word_score,
                unk_score=self.unk_score,
                sil_score=self.sil_score,
                log_add=self.log_add,
                blank_token=blank_token,
                sil_token=sil_token,
                unk_word=self.unk_word,
            )
        else:
            try:
                from torchaudio.models.decoder import cuda_ctc_decoder
            except ImportError as err:
                raise ImportError(
                    "cuda_ctc_decoder not found. Please install the latest version of torchaudio to use this decoder."
                ) from err

            # Raised explicitly (rather than asserted) so that the check
            # survives running Python with -O.
            if self.blank_index != 0:
                raise ValueError(
                    "Index of blank token has to be 0 when using CUDA CTC decoder."
                )

            self._ctc_decoder = cuda_ctc_decoder(
                tokens=self.tokens,
                nbest=self.topk,
                beam_size=self.beam_size,
                blank_skip_threshold=self.blank_skip_threshold,
            )

        # Index -> token table used to turn decoder output into text.
        # `self.tokens` may be a file path, in which case indexing it directly
        # would return characters of the path instead of tokens.
        self._token_list = self._load_token_list(self.tokens)

    @staticmethod
    def _load_token_list(tokens: Union[list, str]) -> List[str]:
        """Return the index-to-token table for `tokens`.

        Arguments
        ---------
        tokens : list or str
            Either the tokens themselves or the path to a tokens file.

        Returns
        -------
        list of str
            A copy of `tokens` if it is not a string. Otherwise, the first
            whitespace-separated field of every non-blank line of the file,
            in file order.
        """
        if not isinstance(tokens, str):
            return list(tokens)

        with open(tokens, encoding="utf-8") as token_file:
            split_lines = (line.split() for line in token_file)
            return [fields[0] for fields in split_lines if fields]

    def decode_beams(
        self, log_probs: torch.Tensor, wav_len: Union[torch.Tensor, None] = None
    ) -> List[List["CTCHypothesis"]]:
        """Decode log_probs using TorchAudio CTC decoder.

        If `using_cpu_decoder=True` then log_probs and wav_len are moved to CPU
        before decoding. When using CUDA CTC decoder, the timestep information
        is not available. Therefore, the timesteps in the returned hypotheses
        are set to None.

        Make sure that the input are in the log domain. The decoder will fail
        to decode logits or probabilities. The input should be the log
        probabilities of the CTC output.

        Arguments
        ---------
        log_probs : torch.Tensor
            The log probabilities of the input audio.
            Shape: (batch_size, seq_length, vocab_size)
        wav_len : torch.Tensor, default: None
            The speechbrain-style relative length. Shape: (batch_size,)
            If None, then the length of each audio is assumed to be seq_length.

        Returns
        -------
        list of list of CTCHypothesis
            The decoded hypotheses. The outer list is over the batch dimension,
            and the inner list is over the topk dimension.

        Raises
        ------
        ValueError
            If log_probs is not of dtype float32.
        RuntimeError
            If log_probs is not contiguous.
        """
        num_frames = log_probs.size(1)
        if wav_len is not None:
            # Relative lengths -> absolute number of frames.
            wav_len = num_frames * wav_len
        else:
            wav_len = torch.full(
                (log_probs.size(0),),
                num_frames,
                device=log_probs.device,
                dtype=torch.int32,
            )

        # The cast truncates fractional frame counts toward zero.
        if wav_len.dtype != torch.int32:
            wav_len = wav_len.to(torch.int32)

        if log_probs.dtype != torch.float32:
            raise ValueError("log_probs must be float32.")

        # When using CPU decoder, we need to move the log_probs to CPU.
        if self.using_cpu_decoder and log_probs.is_cuda:
            log_probs = log_probs.cpu()

        # The lengths have to live on the same device as log_probs, for the
        # CPU decoder as well as for the CUDA one.
        if wav_len.device != log_probs.device:
            wav_len = wav_len.to(log_probs.device)

        if not log_probs.is_contiguous():
            raise RuntimeError("log_probs must be contiguous.")

        results = self._ctc_decoder(log_probs, wav_len)

        hyps = []
        # Outer loop over the batch dimension, inner loop over the n-best list.
        for batch_results in results:
            batch_hyps = []
            for result in batch_results:
                if self.using_cpu_decoder:
                    token_ids = result.tokens.tolist()
                    timesteps = result.timesteps.tolist()
                else:
                    # no timesteps is available for CUDA CTC decoder
                    token_ids = result.tokens
                    timesteps = None

                text = "".join(
                    self._token_list[token_id] for token_id in token_ids
                )
                batch_hyps.append(
                    CTCHypothesis(
                        text=text,
                        last_lm_state=None,
                        score=result.score,
                        lm_score=result.score,
                        text_frames=timesteps,
                    )
                )
            hyps.append(batch_hyps)
        return hyps

    def __call__(
        self, log_probs: torch.Tensor, wav_len: Union[torch.Tensor, None] = None
    ) -> List[List["CTCHypothesis"]]:
        """Decode log_probs using TorchAudio CTC decoder.

        This simply forwards to `decode_beams`.

        If `using_cpu_decoder=True` then log_probs and wav_len are moved to CPU
        before decoding. When using CUDA CTC decoder, the timestep information
        is not available. Therefore, the timesteps in the returned hypotheses
        are set to None.

        Arguments
        ---------
        log_probs : torch.Tensor
            The log probabilities of the input audio.
            Shape: (batch_size, seq_length, vocab_size)
        wav_len : torch.Tensor, default: None
            The speechbrain-style relative length. Shape: (batch_size,)
            If None, then the length of each audio is assumed to be seq_length.

        Returns
        -------
        list of list of CTCHypothesis
            The decoded hypotheses. The outer list is over the batch dimension,
            and the inner list is over the topk dimension.
        """
        return self.decode_beams(log_probs, wav_len)
0 commit comments