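# Gradio app for the roborovski/superprompt-v1 T5 model: expands a short prompt
# into a more detailed one.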
import gradio as gr
import torch
import transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer
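
# Load the SuperPrompt tokenizer and model from the Hugging Face Hub.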
tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1"
)


def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
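    """Expand `text` into a more detailed prompt.

    A non-empty `prompt` replaces the default instruction; `seed` makes the
    sampling reproducible.
    """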
    transformers.set_seed(seed)
    # Fall back to a default token budget when the slider is left at 0.
    if max_new_tokens <= 0:
        max_new_tokens = 150
    with torch.inference_mode():
        if prompt:
            input_text = f"{prompt} {text}"
        else:
            input_text = f"Expand the following prompt to add more detail: {text}"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
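        # Sample a more detailed prompt with temperature / top-k / top-p sampling.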
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
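

# Build a simple Gradio form: input text, seed and token-count sliders, and an
# optional custom instruction.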
demo = gr.Interface(
    fn=super_prompt,
    inputs=[
        gr.Textbox(label="input text"),
        gr.Slider(label="seed", minimum=0, maximum=2**32 - 1, step=1),
        gr.Slider(label="max_new_tokens", minimum=0, maximum=375, step=1),
        gr.Textbox(label="custom prompt", placeholder="leave empty to use default"),
    ],
    outputs=[gr.Textbox(label="output", lines=6)],
)

demo.launch()
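
# Example call outside the UI (hypothetical values; run before demo.launch(),
# which blocks):
# print(super_prompt("a cat sitting on a porch", seed=42, max_new_tokens=150, prompt=""))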