subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from time import perf_counter
from typing import Optional
import lhotse.dataset
import torch
from lhotse import CutSet
from lhotse.serialization import SequentialJsonlWriter
from omegaconf import OmegaConf
from transformers import GenerationConfig
from whisper_normalizer.basic import BasicTextNormalizer
from whisper_normalizer.english import EnglishTextNormalizer
from nemo.collections.asr.metrics.wer import word_error_rate_detail
from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset
from nemo.collections.speechlm2 import SALM
from nemo.core.config import hydra_runner
from nemo.utils import logging
class ToAudio(torch.utils.data.Dataset):
    """Map-style dataset that is indexed with a whole mini-batch of cuts.

    The "index" handed to ``__getitem__`` is a ``CutSet`` (one mini-batch), not
    an integer; the dataset's only job is to load the audio for those cuts.
    """

    def __getitem__(self, cuts: CutSet):
        """Load and collate the audio for ``cuts``.

        Returns:
            A dict with the input ``cuts`` plus the two values returned by
            ``cuts.load_audio(collate=True)`` under ``"audios"`` and
            ``"audio_lens"``.
        """
        waveforms, lengths = cuts.load_audio(collate=True)
        return dict(cuts=cuts, audios=waveforms, audio_lens=lengths)
@dataclass
class SalmEvalConfig:
    """Configuration schema for the SALM evaluation script.

    ``pretrained_name`` and ``inputs`` have no defaults and must be provided;
    every other field is optional.
    """

    # Model identifier passed to ``SALM.from_pretrained``.
    pretrained_name: str
    # Evaluation data specifier passed to ``guess_parse_cutset``.
    inputs: str
    # Maximum number of cuts per mini-batch (``max_cuts`` of the sampler).
    batch_size: int = 64
    # Upper bound on generated tokens per example (``GenerationConfig.max_new_tokens``).
    max_new_tokens: int = 128
    # Path of the JSONL manifest with references and predictions; None disables writing.
    output_manifest: Optional[str] = "generations.jsonl"
    # When True, log WER and RTFx for every batch in addition to the final summary.
    verbose: bool = True
    # Text normalizer applied to both references and hypotheses: "english" or "basic".
    # Any other value (e.g. "none", "None", or null) leaves the text unchanged.
    use_normalizer: Optional[str] = "english"
    # Device the model is moved to.
    device: str = "cuda"
    # Name of a ``torch`` dtype attribute (looked up with ``getattr(torch, dtype)``).
    dtype: str = "bfloat16"
    # Extra token strings that should terminate a hypothesis, in addition to the model's EOS.
    extra_eos_tokens: Optional[list[str]] = None
    # Optional system turn placed before the user turn in the prompt.
    system_prompt: Optional[str] = None
    # Optional user turn text; the audio placeholder is appended if it is not already present.
    user_prompt: Optional[str] = None
@hydra_runner(config_name="SalmEvalConfig", schema=SalmEvalConfig)
def main(cfg: SalmEvalConfig):
    """Transcribe a cut set with a pretrained SALM model and report WER and RTFx.

    The script loads the model, iterates over the input cuts in mini-batches,
    generates a hypothesis for every cut with one shared prompt, and compares
    the (optionally normalized) hypotheses with the text of each cut's first
    supervision. Overall WER and RTFx are logged, and a JSONL manifest with
    references and predictions is written when ``cfg.output_manifest`` is set.

    Args:
        cfg: Evaluation settings, see ``SalmEvalConfig``.

    Raises:
        ValueError: If an entry of ``cfg.extra_eos_tokens`` has no token id in
            the model's tokenizer.
    """
    logging.info(f'Hydra config:\n{OmegaConf.to_yaml(cfg)}')

    model = SALM.from_pretrained(cfg.pretrained_name).eval().to(getattr(torch, cfg.dtype)).to(cfg.device)

    cuts = guess_parse_cutset(cfg.inputs).sort_by_duration()
    # batch_size=None turns off DataLoader auto-batching: the sampler already
    # yields whole mini-batches, which ToAudio turns into collated audio.
    dloader = torch.utils.data.DataLoader(
        dataset=ToAudio(),
        sampler=lhotse.dataset.DynamicCutSampler(cuts, max_cuts=cfg.batch_size),
        num_workers=1,
        batch_size=None,
    )

    # Unknown normalizer names (including "none"/None) fall back to identity.
    normalizer = {"english": EnglishTextNormalizer(), "basic": BasicTextNormalizer()}.get(
        cfg.use_normalizer, lambda x: x
    )

    eos_tokens = [model.text_eos_id]
    if cfg.extra_eos_tokens is not None:
        for t in cfg.extra_eos_tokens:
            tid = model.tokenizer.token_to_id(t)
            # Explicit raise instead of `assert`: this validates user input and
            # must not disappear when Python runs with -O.
            if tid is None:
                raise ValueError(f"Token '{t}' is not in the model's vocabulary.")
            eos_tokens.append(tid)

    # Construct the prompt that is shared by every example.
    # Optional system prompt goes first.
    prompt = []
    if cfg.system_prompt is not None:
        prompt.append({"role": "system", "content": cfg.system_prompt})
    # If no user prompt is provided, just use the audio placeholder.
    content = model.audio_locator_tag
    # Otherwise:
    # * if user prompt already has audio placeholder, add it as-is,
    # * if not, append audio placeholder at the end of user prompt
    if cfg.user_prompt is not None:
        content = cfg.user_prompt
        if model.audio_locator_tag not in content:
            content = f"{content} {model.audio_locator_tag}"
    prompt.append({"role": "user", "content": content})

    # The generation settings do not depend on the batch, so build them once.
    generation_config = GenerationConfig(
        max_new_tokens=cfg.max_new_tokens,
        bos_token_id=model.text_bos_id,
        eos_token_id=eos_tokens,
        pad_token_id=model.text_pad_id,
    )

    refs = []
    hyps = []
    input_durations = []
    infer_durations = []
    for batch_idx, batch in enumerate(dloader):
        ts = perf_counter()
        answer_ids = model.generate(
            prompts=[prompt] * len(batch["cuts"]),  # identical prompt for each example
            audios=batch["audios"].to(model.device, non_blocking=True),
            audio_lens=batch["audio_lens"].to(model.device, non_blocking=True),
            generation_config=generation_config,
        )
        # The copy back to host memory is part of the timed region.
        answer_ids = answer_ids.cpu()
        batch_infer_duration = perf_counter() - ts

        batch_duration = sum(c.duration for c in batch["cuts"])
        batch_refs = [normalizer(cut.supervisions[0].text) for cut in batch["cuts"]]
        batch_hyps = [
            normalizer(model.tokenizer.ids_to_text(parse_hyp(ans, eos_tokens)).strip()) for ans in answer_ids
        ]
        if cfg.verbose:
            batch_wer, _, nins, ndel, nsub = word_error_rate_detail(batch_hyps, batch_refs)
            batch_rtfx = batch_duration / batch_infer_duration
            logging.info(
                f"Batch {batch_idx}: WER={batch_wer:.2%} [ins={nins:.2%} del={ndel:.2%} sub={nsub:.2%}] RTFx={batch_rtfx:.1f}"
            )

        refs.extend(batch_refs)
        hyps.extend(batch_hyps)
        input_durations.append(batch_duration)
        infer_durations.append(batch_infer_duration)

    wer, _, nins, ndel, nsub = word_error_rate_detail(hypotheses=hyps, references=refs, use_cer=False)
    # RTFx = audio seconds processed per second of inference time. With no
    # batches the denominator is zero, so report NaN instead of crashing.
    total_infer_duration = sum(infer_durations)
    rtfx = sum(input_durations) / total_infer_duration if total_infer_duration > 0 else float("nan")
    logging.info(f"WER: {wer:.2%} [ins={nins:.2%} del={ndel:.2%} sub={nsub:.2%}]")
    logging.info(f"RTFx: {rtfx:.1f}")

    if cfg.output_manifest is not None:
        # NOTE(review): pairing by position assumes the sampler yielded the cuts
        # in the same order as iterating `cuts` - confirm if sampler options change.
        with SequentialJsonlWriter(cfg.output_manifest) as writer:
            for cut, ref, hyp in zip(cuts, refs, hyps):
                writer.write({"id": cut.id, "duration": cut.duration, "text": ref, "pred_text": hyp})
def parse_hyp(answer: torch.Tensor, eos_tokens: list[int]) -> torch.Tensor:
    """Cut a generated token sequence at its first end-of-sequence token.

    Args:
        answer: 1-D tensor of generated token ids for a single example.
        eos_tokens: Token ids that terminate a hypothesis.

    Returns:
        The prefix of ``answer`` before the first occurrence of any id in
        ``eos_tokens`` (the EOS token itself is excluded), or ``answer``
        unchanged when none of them occurs.
    """
    # Build the lookup tensor on the same device as `answer` so the function
    # also works for tensors that have not been moved to the CPU.
    eos = torch.tensor(eos_tokens, device=answer.device)
    eos_positions = torch.isin(answer, eos).nonzero(as_tuple=True)[0]
    if eos_positions.numel() == 0:
        return answer
    return answer[: eos_positions[0]]
# Script entry point. main() is called without arguments; its hydra_runner
# decorator is expected to build and pass in the config.
if __name__ == '__main__':
    main()