Chatterbox tts#1976
Conversation
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Signed-off-by: Ssofja <sofiakostandian@gmail.com>
Greptile SummaryThis PR adds
Confidence Score: 2/5Not ready to merge — four correctness bugs in the core generation path need to be fixed before the stage can be trusted in production. The implementation has four independent logic bugs, all in the main file. The language code is validated lowercase but stored and forwarded to the model with its original casing, which would break multilingual inference for any caller who passes an uppercase code. The nemo_curator/stages/audio/tts/chatterbox_tts.py requires the most attention; the test file should also be updated to assert the correctness (not just consistency) of reference_voice values and to cover the truncated-ID collision scenario. Important Files Changed
Sequence DiagramsequenceDiagram
participant Caller
participant Stage as ChatterboxTTSStage
participant Ref as Reference Resolver
participant TmpDir as Temp Dir
participant Model as ChatterboxTTS/MTL
participant FS as Output FS
Caller->>Stage: process_batch(tasks)
Stage->>Stage: _ensure_ready()
Stage->>Model: from_pretrained(device)
Stage->>Stage: _load_reference_audio_files()
loop per AudioTask
Stage->>Ref: _assign_reference(speaker, conv_id)
Ref->>TmpDir: write RTTM-stripped WAV or MLS concat WAV
Ref-->>Stage: ref_path
Stage->>Stage: _output_filename(conv_id, speaker, text)
alt file already exists
Stage->>FS: sf.read(audio_path)
else
Stage->>Model: "generate(text, audio_prompt_path=ref_path, ...)"
Model-->>Stage: wav tensor
Stage->>Stage: _normalize_audio(wav)
Stage->>FS: sf.write(audio_path, audio_data, sr)
end
Stage-->>Caller: AudioTask with audio_filepath, duration, reference_voice
end
Reviews (1): Last reviewed commit: "added chatterbox tts stage" | Re-trigger Greptile |
| if language is not None and language.lower() not in SUPPORTED_LANGUAGES: | ||
| raise ValueError( | ||
| f"Unsupported language '{language}'. " | ||
| f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}" | ||
| ) |
There was a problem hiding this comment.
The language code is validated with
.lower() but stored as-is and later passed directly to the model as language_id. If a caller passes "RU" or "FR", it clears the SUPPORTED_LANGUAGES check (because "ru" is in the set), but the raw uppercase string is forwarded to ChatterboxMultilingualTTS.generate. The Chatterbox API expects lowercase ISO 639-1 codes, so inference would either fail or silently produce wrong-language output.
| if language is not None and language.lower() not in SUPPORTED_LANGUAGES: | |
| raise ValueError( | |
| f"Unsupported language '{language}'. " | |
| f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}" | |
| ) | |
| if language is not None and language.lower() not in SUPPORTED_LANGUAGES: | |
| raise ValueError( | |
| f"Unsupported language '{language}'. " | |
| f"Supported: {', '.join(sorted(SUPPORTED_LANGUAGES))}" | |
| ) | |
| if language is not None: | |
| language = language.lower() |
| out_path = os.path.join( | ||
| self.temp_dir, os.path.basename(audio_filepath) | ||
| ) |
There was a problem hiding this comment.
Both
_get_reference_audio_wavs and _get_reference_audio_mls write their RTTM-processed / concatenated output into self.temp_dir using os.path.basename(audio_filepath). When different dialogs contain a speaker file with the same name (e.g. dialog001/Alice.wav and dialog002/Alice.wav), the second write silently overwrites the first temp file. Any speaker already assigned temp_dir/Alice.wav then synthesises audio with the wrong voice without any warning.
| out_path = os.path.join( | |
| self.temp_dir, os.path.basename(audio_filepath) | |
| ) | |
| unique_name = hashlib.md5(audio_filepath.encode()).hexdigest()[:8] + "_" + os.path.basename(audio_filepath) | |
| out_path = os.path.join(self.temp_dir, unique_name) |
| @staticmethod | ||
| def _output_filename(conversation_id: str, speaker: str, text: str) -> str: | ||
| """Deterministic filename: ``{conv_id_short}_{speaker}_{text_hash}.wav``.""" | ||
| conv_short = conversation_id[:12] if len(conversation_id) > 12 else conversation_id | ||
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | ||
| return f"{conv_short}_{speaker}_{text_hash}.wav" |
There was a problem hiding this comment.
Truncating the conversation ID to 12 characters means two conversations whose IDs share a 12-character prefix (common with structured IDs such as
session1_conv001 / session1_conv002) generate the same filename for the same speaker and text. On a subsequent run the cached file from the first conversation is reused for the second even though a different reference voice may have been assigned, producing a silent audio/metadata mismatch.
| @staticmethod | |
| def _output_filename(conversation_id: str, speaker: str, text: str) -> str: | |
| """Deterministic filename: ``{conv_id_short}_{speaker}_{text_hash}.wav``.""" | |
| conv_short = conversation_id[:12] if len(conversation_id) > 12 else conversation_id | |
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | |
| return f"{conv_short}_{speaker}_{text_hash}.wav" | |
| @staticmethod | |
| def _output_filename(conversation_id: str, speaker: str, text: str) -> str: | |
| """Deterministic filename: ``{conv_id_hash}_{speaker}_{text_hash}.wav``.""" | |
| conv_hash = hashlib.md5(conversation_id.encode("utf-8")).hexdigest()[:12] | |
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | |
| return f"{conv_hash}_{speaker}_{text_hash}.wav" |
| reference_wav = self._assign_reference(speaker, conversation_id) | ||
|
|
||
| filename = self._output_filename(conversation_id, speaker, text) | ||
| audio_path = os.path.join(self.output_audio_dir, filename) | ||
|
|
||
| if os.path.exists(audio_path): | ||
| audio_data, _ = sf.read(audio_path) | ||
| else: | ||
| audio_data = self._generate_turn_audio( | ||
| text, reference_wav, conversation_id | ||
| ) | ||
| sf.write(audio_path, audio_data, self.sample_rate) | ||
|
|
||
| duration = len(audio_data) / self.sample_rate | ||
|
|
||
| out_data = dict(data) | ||
| out_data["audio_filepath"] = audio_path | ||
| out_data["duration"] = duration | ||
| out_data["reference_voice"] = Path(reference_wav).parent.name |
There was a problem hiding this comment.
Path(reference_wav).parent.name returns the temp-directory name (e.g. chatterbox_ref_abc123) whenever the reference has been RTTM-processed or comes from the MLS layout, because both code paths write to self.temp_dir/<filename>. Only the raw wavs path (no RTTM) has a meaningful parent (the dialog ID). The emitted reference_voice value should be the MLS speaker ID or the dialog/speaker tag, not an ephemeral temp-dir name.
| reference_wav = self._assign_reference(speaker, conversation_id) | |
| filename = self._output_filename(conversation_id, speaker, text) | |
| audio_path = os.path.join(self.output_audio_dir, filename) | |
| if os.path.exists(audio_path): | |
| audio_data, _ = sf.read(audio_path) | |
| else: | |
| audio_data = self._generate_turn_audio( | |
| text, reference_wav, conversation_id | |
| ) | |
| sf.write(audio_path, audio_data, self.sample_rate) | |
| duration = len(audio_data) / self.sample_rate | |
| out_data = dict(data) | |
| out_data["audio_filepath"] = audio_path | |
| out_data["duration"] = duration | |
| out_data["reference_voice"] = Path(reference_wav).parent.name | |
| reference_wav, ref_id = self._assign_reference(speaker, conversation_id) | |
| filename = self._output_filename(conversation_id, speaker, text) | |
| audio_path = os.path.join(self.output_audio_dir, filename) | |
| if os.path.exists(audio_path): | |
| audio_data, _ = sf.read(audio_path) | |
| else: | |
| audio_data = self._generate_turn_audio( | |
| text, reference_wav, conversation_id | |
| ) | |
| sf.write(audio_path, audio_data, self.sample_rate) | |
| duration = len(audio_data) / self.sample_rate | |
| out_data = dict(data) | |
| out_data["audio_filepath"] = audio_path | |
| out_data["duration"] = duration | |
| out_data["reference_voice"] = ref_id |
| top_p: float = 1.0, | ||
| normalize_audio: bool = True, | ||
| normalize_level: float = -20.0, | ||
| **kwargs, |
| if self.language: | ||
| os.environ["TRANSFORMERS_ATTN_IMPLEMENTATION"] = "eager" | ||
| try: | ||
| import chatterbox.models.t3.llama_configs as _llama_cfgs | ||
| for _cfg_dict in _llama_cfgs.LLAMA_CONFIGS.values(): | ||
| _cfg_dict["attn_implementation"] = "eager" | ||
| except (ImportError, AttributeError): | ||
| pass | ||
|
|
||
| from chatterbox.mtl_tts import ChatterboxMultilingualTTS | ||
| self.model = ChatterboxMultilingualTTS.from_pretrained(device=self.device) | ||
| logger.info(f"Loaded ChatterboxMultilingualTTS (language={self.language})") | ||
| else: | ||
| from chatterbox.tts import ChatterboxTTS | ||
| self.model = ChatterboxTTS.from_pretrained(device=self.device) | ||
| logger.info("Loaded ChatterboxTTS (English)") |
There was a problem hiding this comment.
Can this imports be at the top of the script? Same with the environment variable?
| text_hash = hashlib.md5(text.encode("utf-8")).hexdigest()[:10] | ||
| return f"{conv_short}_{speaker}_{text_hash}.wav" | ||
|
|
||
| def _ensure_ready(self) -> None: |
| if not tasks: | ||
| return [] | ||
|
|
||
| self._ensure_ready() |
There was a problem hiding this comment.
Remove. We should never call setup in process_batch/process.
| return [] | ||
|
|
||
| self._ensure_ready() | ||
| os.makedirs(self.output_audio_dir, exist_ok=True) |
There was a problem hiding this comment.
Should this be in setup?
Description
Add a ChatterboxTTS-based speech synthesis stage (
ChatterboxTTSStage) to the NeMo Curator audio pipeline for generating multi-speaker conversation audio from text.New stage:
ChatterboxTTSStage— Synthesises conversation-turn audio using Chatterbox TTS. Supports both the English-only model (ChatterboxTTS) and the multilingual model (ChatterboxMultilingualTTS, 23 languages). Speaker voices are automatically assigned from a reference audio dataset and stay consistent within each conversation.Key features:
wavs/<dialog>/<speaker>.wav(with optional RTTM silence stripping) and MLS<spk>/<book>/<seg>.flac(auto-concatenated to target duration).Resources(gpus=1)).New files:
nemo_curator/stages/audio/tts/__init__.pynemo_curator/stages/audio/tts/chatterbox_tts.pytests/stages/audio/tts/__init__.pytests/stages/audio/tts/test_chatterbox_tts.py(55 tests)Usage
Supported languages (multilingual mode):
ar,da,de,el,en,es,fi,fr,he,hi,it,ja,ko,ms,nl,no,pl,pt,ru,sv,sw,tr,zhChecklist