# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-function-docstring
import torch
from lhotse.cut import Cut, MixedCut

from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample
from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn
from nemo.collections.common.prompts.formatter import BOS_SLOT, EOS_SLOT, Modality, PromptFormatter

class Llama2PromptFormatter(PromptFormatter):
    """
    This template has been validated to provide identical tokenized results to the official code
    in https://github.com/meta-llama/llama/blob/main/llama/generation.py
    """

    NAME = "llama2"
    OUTPUT_ROLE = "assistant"
    TEMPLATE = {
        "system_and_user": {
            "template": f"{BOS_SLOT}[INST] <<SYS>>\n|system|\n<</SYS>>\n\n|message| [/INST]",
            "slots": {
                "system": Modality.Text,
                "message": Modality.Text,
            },
        },
        "user": {
            "template": f"{BOS_SLOT}[INST] |message| [/INST]",
            "slots": {
                "message": Modality.Text,
            },
        },
        OUTPUT_ROLE: {
            "template": f"|message| {EOS_SLOT}",
            "slots": {
                "message": Modality.Text,
            },
        },
    }
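
# Usage sketch (illustrative, not part of the original module; assumes a NeMo-compatible
# tokenizer whose vocabulary covers the Llama 2 special tokens):
#
#   formatter = Llama2PromptFormatter(tokenizer)
#   encoded = formatter.encode_dialog(
#       [
#           {"role": "user", "slots": {"message": "What is 2 + 2?"}},
#           {"role": "assistant", "slots": {"message": "4"}},
#       ]
#   )
#
# BOS_SLOT / EOS_SLOT are placeholders that the formatter resolves to the tokenizer's
# BOS/EOS, so the rendered dialog follows the official "<s>[INST] ... [/INST] ... </s>"
# layout validated in the docstring above.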

@registered_prompt_format_fn(Cut, Llama2PromptFormatter)
def llama2(cut: Cut, prompt: Llama2PromptFormatter) -> dict[str, torch.Tensor]:
    if isinstance(cut, MixedCut):
        cut = cut.first_non_padding_cut
    if cut.has_custom("context"):
        context = cut.context
    elif cut.has_custom("question"):
        context = cut.question
    else:
        context = cut.default_context

    turns = []
    if cut.has_custom("system_prompt"):
        turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}})
    else:
        turns.append({"role": "user", "slots": {"message": context}})
    if (answer := cut.supervisions[0].text) is not None:
        turns.append({"role": "assistant", "slots": {"message": answer}})
    return prompt.encode_dialog(turns)
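
# Worked example (illustrative): for a cut where cut.context == "Transcribe:" and
# cut.supervisions[0].text == "hello", the function above builds
#
#   turns = [
#       {"role": "user", "slots": {"message": "Transcribe:"}},
#       {"role": "assistant", "slots": {"message": "hello"}},
#   ]
#
# before delegating tokenization to prompt.encode_dialog(turns).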

@registered_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter)
def llama2_src_tgt_text_example(example: SourceTargetTextExample, prompt: Llama2PromptFormatter):
    if example.question is not None:
        user_turn = {
            "role": "system_and_user",
            "slots": {"system": example.question.text, "message": example.source.text},
        }
    else:
        user_turn = {
            "role": "user",
            "slots": {"message": example.source.text},
        }
    return prompt.encode_dialog(
        [
            user_turn,
            {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}},
        ]
    )
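
# Illustration (inferred from the attribute accesses above): a SourceTargetTextExample is
# expected to carry .source.text (the user message), .target.text (the assistant answer),
# and an optional .question whose .text fills the "system" slot when present.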

@registered_prompt_format_fn(NeMoSFTExample, Llama2PromptFormatter)
def llama2_sft_text_example(example: NeMoSFTExample, prompt: Llama2PromptFormatter):
    first_user_turn = example.data["conversations"][0]["value"]
    if "system" in example.data and example.data["system"]:
        first_turn = {
            "role": "system_and_user",
            "slots": {"system": example.data["system"], "message": first_user_turn},
        }
    else:
        first_turn = {
            "role": "user",
            "slots": {"message": first_user_turn},
        }
    return prompt.encode_dialog(
        [first_turn]
        + [
            {"role": "user" if turn["from"] == "User" else prompt.OUTPUT_ROLE, "slots": {"message": turn["value"]}}
            for turn in example.data["conversations"][1:]
        ]
    )
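
# Illustration (inferred from the field accesses above, not from separate documentation):
# the NeMoSFTExample payload is expected to look like
#
#   example.data == {
#       "system": "You are a helpful assistant.",  # optional
#       "conversations": [
#           {"from": "User", "value": "Hi!"},
#           {"from": "Assistant", "value": "Hello!"},
#       ],
#   }
#
# Any turn whose "from" field is not exactly "User" is mapped to the output role.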

LLAMA3_BOS = "<|begin_of_text|>"
LLAMA3_HEADER_BEGIN = "<|start_header_id|>"
LLAMA3_HEADER_END = "<|end_header_id|>"
LLAMA3_END_OF_TURN = "<|eot_id|>"
LLAMA3_NL = "\n\n"

class Llama3PromptFormatter(PromptFormatter):
    """
    Implemented following the code at:
    https://github.com/meta-llama/llama3/blob/main/llama/test_tokenizer.py#L56
    """

    NAME = "llama3"
    OUTPUT_ROLE = "assistant"
    INFERENCE_PREFIX = f"{LLAMA3_HEADER_BEGIN}assistant{LLAMA3_HEADER_END}{LLAMA3_NL}"
    TEMPLATE = {
        "preamble": {
            "template": LLAMA3_BOS,
        },
        "system": {
            "template": f"{LLAMA3_HEADER_BEGIN}system{LLAMA3_HEADER_END}{LLAMA3_NL}|message|{LLAMA3_END_OF_TURN}",
            "slots": {
                "message": Modality.Text,
            },
        },
        "user": {
            "template": f"{LLAMA3_HEADER_BEGIN}user{LLAMA3_HEADER_END}{LLAMA3_NL}|message|{LLAMA3_END_OF_TURN}",
            "slots": {
                "message": Modality.Text,
            },
        },
        OUTPUT_ROLE: {
            "template": f"{INFERENCE_PREFIX}|message|{LLAMA3_END_OF_TURN}",
            "slots": {
                "message": Modality.Text,
            },
        },
    }
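
# Rendering sketch (derived from the TEMPLATE constants above): a system + user + assistant
# dialog is laid out as
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n{answer}<|eot_id|>
#
# (split across lines here for readability; the encoded string is contiguous). The
# INFERENCE_PREFIX is exactly the assistant header, i.e. the point from which generation
# is expected to continue at inference time.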

@registered_prompt_format_fn(Cut, Llama3PromptFormatter)
def llama3(cut: Cut, prompt: Llama3PromptFormatter) -> dict[str, torch.Tensor]:
    if isinstance(cut, MixedCut):
        cut = cut.first_non_padding_cut
    if cut.has_custom("context"):
        context = cut.context
    elif cut.has_custom("question"):
        context = cut.question
    else:
        context = cut.default_context

    turns = []
    if cut.has_custom("system_prompt"):
        turns.append({"role": "system", "slots": {"message": cut.system_prompt}})
    turns.append({"role": "user", "slots": {"message": context}})
    if (answer := cut.supervisions[0].text) is not None:
        turns.append({"role": "assistant", "slots": {"message": answer}})
    return prompt.encode_dialog(turns)
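
# Usage sketch (illustrative): the function above can be called directly once a formatter
# instance exists, e.g.
#
#   prompt = Llama3PromptFormatter(tokenizer)  # "tokenizer" is assumed to be available
#   encoded = llama3(cut, prompt)
#
# while the @registered_prompt_format_fn decorator also registers the mapping keyed by
# (Cut, Llama3PromptFormatter) so NeMo's lhotse dataloading can dispatch to it by type.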