Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -175,20 +175,26 @@ def run_task(text, language, task):
|
|
| 175 |
token=HF_TOKEN,
|
| 176 |
trust_remote_code=True
|
| 177 |
)
|
|
|
|
| 178 |
inputs = indictrans_tokenizer(text, return_tensors="pt", src_lang="san", tgt_lang="en")
|
| 179 |
outputs = indictrans_model.generate(**inputs, max_new_tokens=256)
|
| 180 |
-
translated = indictrans_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
return None, None, translated
|
| 182 |
except Exception as e:
|
| 183 |
print("⚠️ IndicTrans2 failed, falling back to MBART:", e)
|
| 184 |
-
indictrans_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
| 185 |
-
indictrans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
| 186 |
|
| 187 |
# Fallback to MBART with correct language codes
|
| 188 |
try:
|
| 189 |
# MBART-50 requires src_lang and forced_bos_token_id
|
| 190 |
translation_tokenizer.src_lang = "sa_IN" # Sanskrit input
|
| 191 |
forced_bos = translation_tokenizer.lang_code_to_id.get("en_XX", None)
|
|
|
|
| 192 |
inputs = translation_tokenizer(text, return_tensors="pt")
|
| 193 |
outputs = translation_model.generate(
|
| 194 |
**inputs,
|
|
@@ -196,6 +202,9 @@ def run_task(text, language, task):
|
|
| 196 |
forced_bos_token_id=forced_bos
|
| 197 |
)
|
| 198 |
translated = translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
| 199 |
return None, None, translated
|
| 200 |
except Exception as e2:
|
| 201 |
return None, None, f"Translation error: {e2}"
|
|
|
|
| 175 |
token=HF_TOKEN,
|
| 176 |
trust_remote_code=True
|
| 177 |
)
|
| 178 |
+
|
| 179 |
inputs = indictrans_tokenizer(text, return_tensors="pt", src_lang="san", tgt_lang="en")
|
| 180 |
outputs = indictrans_model.generate(**inputs, max_new_tokens=256)
|
| 181 |
+
translated = indictrans_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 182 |
+
# Detect nonsense outputs (repeated single word)
|
| 183 |
+
if translated and len(set(translated.split())) == 1:
|
| 184 |
+
translated = f"⚠️ Translation returned nonsense (repeated '{translated.split()[0]}')."
|
| 185 |
+
|
| 186 |
return None, None, translated
|
| 187 |
except Exception as e:
|
| 188 |
print("⚠️ IndicTrans2 failed, falling back to MBART:", e)
|
| 189 |
+
#indictrans_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
| 190 |
+
#indictrans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
| 191 |
|
| 192 |
# Fallback to MBART with correct language codes
|
| 193 |
try:
|
| 194 |
# MBART-50 requires src_lang and forced_bos_token_id
|
| 195 |
translation_tokenizer.src_lang = "sa_IN" # Sanskrit input
|
| 196 |
forced_bos = translation_tokenizer.lang_code_to_id.get("en_XX", None)
|
| 197 |
+
|
| 198 |
inputs = translation_tokenizer(text, return_tensors="pt")
|
| 199 |
outputs = translation_model.generate(
|
| 200 |
**inputs,
|
|
|
|
| 202 |
forced_bos_token_id=forced_bos
|
| 203 |
)
|
| 204 |
translated = translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 205 |
+
if translated and len(set(translated.split())) == 1:
|
| 206 |
+
translated = f"⚠️ Translation returned nonsense (repeated '{translated.split()[0]}')."
|
| 207 |
+
|
| 208 |
return None, None, translated
|
| 209 |
except Exception as e2:
|
| 210 |
return None, None, f"Translation error: {e2}"
|