ivxivx committed
Commit 2866fc1 · unverified · Parent: 0d717b7

Files changed (3)
  1. README.md +1 -1
  2. app.py +263 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: HF Story Generator
+ title: Story Generator
  emoji: 😻
  colorFrom: indigo
  colorTo: red
app.py ADDED
@@ -0,0 +1,263 @@
+ # %pip install gradio diffusers
+
+ import os
+ from huggingface_hub import login
+ login(token=os.getenv("HUGGINGFACEHUB_API_KEY"))
+
+ import gradio as gr
+ import numpy as np
+ import random
+
+ # import spaces  # [uncomment to use ZeroGPU]
+ from diffusers import DiffusionPipeline
+ import torch
+
+ def get_device_type(idx):
+     if torch.cuda.is_available():
+         # Use cuda:<idx> only when that GPU actually exists (devices are 0-indexed);
+         # otherwise fall back to the default CUDA device.
+         return (f"cuda:{idx}" if torch.cuda.device_count() > idx else "cuda"), torch.float16
+     elif torch.backends.mps.is_available():
+         return "mps", torch.float16
+     else:
+         return "cpu", torch.float32
+
+ device0, torch_dtype = get_device_type(0)
+ device1, torch_dtype = get_device_type(1)
+
+ text_generation_model_name = "meta-llama/Llama-3.2-3B-Instruct"
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
+ model = AutoModelForCausalLM.from_pretrained(text_generation_model_name).to(device0, dtype=torch.float16)
+
+ text_to_speech_model_name = "suno/bark"
+ # Only PyTorch models and some Diffusers pipelines have .to().
+ text_to_speech_pipeline = pipeline("text-to-speech", model=text_to_speech_model_name, device=device0)
+
+ text_to_image_model_name = "stabilityai/sdxl-turbo"
+ text_to_image_pipeline = DiffusionPipeline.from_pretrained(text_to_image_model_name, torch_dtype=torch_dtype).to(device1)
+
+ # from diffusers import StableDiffusionPipeline
+ # text_to_image_model_name = "sd-legacy/stable-diffusion-v1-5"
+ # text_to_image_pipeline = StableDiffusionPipeline.from_pretrained(text_to_image_model_name, torch_dtype=torch_dtype).to(device1)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+ system_prompt = (
+     "You are a helpful AI assistant that generates a story for kids based on the input provided. "
+     "The story should be engaging and creative. "
+     "Here is the input: {input} "
+     "Please respond with the story."
+ )
+
+ history = []
+
+ def generate_text(message):
+     global history
+     sys_prompt = system_prompt.replace("{input}", message)
+     if not history or history[0].get("role") != "system":
+         history = [{"role": "system", "content": sys_prompt}] + history
+     else:
+         history[0]["content"] = sys_prompt
+
+     history.append({"role": "user", "content": message})
+
+     # 1. Build prompt from history using the chat template; add_generation_prompt
+     #    appends the assistant header so the model starts a fresh reply.
+     prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
+     # The chat template already inserts special tokens, so don't add them again here.
+     inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device0)
+     outputs = model.generate(**inputs, max_new_tokens=128)
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
+
+     if "<|start_header_id|>assistant<|end_header_id|>" in decoded:
+         # Llama-3-style template: the reply follows the assistant header.
+         analysis_response = decoded.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
+         analysis_response = analysis_response.replace("<|eot_id|>", "").strip()
+     elif "<|im_start|>assistant" in decoded:
+         # This works for most ChatML-style templates that append the assistant's reply at the end.
+         analysis_response = decoded.split("<|im_start|>assistant")[-1]
+         analysis_response = analysis_response.replace("<|im_end|>", "").strip()
+     else:
+         # Fallback: just return the decoded output
+         analysis_response = decoded.strip()
+
+     return analysis_response
+
+ def generate_audio(text):
+     tts_result = text_to_speech_pipeline(text)
+
+     # example: [[ 0.00073422 0.00038968 0.00035801 ... -0.01280548 -0.0147996 -0.01798675]]
+     audio = tts_result["audio"]  # numpy array of waveform samples
+
+     audio_array = np.array(audio, dtype=np.float32).flatten()
+
+     # Use the sample rate reported by the pipeline instead of hard-coding one
+     # (Bark generates 24 kHz audio).
+     sample_rate = tts_result["sampling_rate"]
+
+     # gr.Audio expects a (sample_rate, np.ndarray) tuple
+     return (sample_rate, audio_array)
+
+
+ # @spaces.GPU  # [uncomment to use ZeroGPU]
+ def generate_image(
+     prompt,
+     negative_prompt,
+     guidance_scale,
+     num_inference_steps,
+     width,
+     height,
+     seed,
+     randomize_seed,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     image = text_to_image_pipeline(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator,
+     ).images[0]
+
+     return image, seed
+
+ def generate_all(
+     prompt,
+     negative_prompt,
+     guidance_scale,
+     num_inference_steps,
+     width,
+     height,
+     seed,
+     randomize_seed,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     # Generate text from the prompt
+     story = generate_text(prompt)
+
+     # Generate audio from the text
+     audio = generate_audio(story)
+
+     # Generate image from the text
+     image, seed = generate_image(
+         story,
+         negative_prompt,
+         guidance_scale,
+         num_inference_steps,
+         width,
+         height,
+         seed,
+         randomize_seed,
+         progress=progress,
+     )
+
+     return story, audio, image, seed
+
+
+ examples = [
+     "sky",
+     "sea",
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 640px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("# Story Generator (text & audio on cuda0, image on cuda1)")
+
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0, variant="primary")
+
+         story = gr.Text(label="Story", show_label=False)
+
+         audio = gr.Audio(label="Audio", show_label=False)
+
+         image = gr.Image(label="Image", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 visible=False,
+             )
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,  # Replace with defaults that work for your model
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,  # Replace with defaults that work for your model
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                     value=0.0,  # Replace with defaults that work for your model
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=2,  # Replace with defaults that work for your model
+                 )
+
+         gr.Examples(examples=examples, inputs=[prompt])
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=generate_all,
+         inputs=[
+             prompt,
+             negative_prompt,
+             guidance_scale,
+             num_inference_steps,
+             width,
+             height,
+             seed,
+             randomize_seed,
+         ],
+         outputs=[story, audio, image, seed],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
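
The app is normally driven through the Gradio UI above, but the generation functions can also be exercised directly for a quick local check. The following is a minimal sketch, not part of this commit: importing app loads all three models, so it assumes enough GPU/CPU memory and that HUGGINGFACEHUB_API_KEY points to a token with access to the gated meta-llama/Llama-3.2-3B-Instruct checkpoint. The script name and the prompt reuse one of the committed examples.

# local_smoke_test.py — hypothetical helper, not included in this commit
from app import generate_text, generate_audio

story = generate_text("sky")                   # assistant reply extracted from the chat-template output
print(story[:200])

sample_rate, waveform = generate_audio(story)  # (int, 1-D float32 ndarray), the tuple gr.Audio expects
print(sample_rate, waveform.shape, waveform.dtype)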
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ accelerate
+ diffusers
+ torch
+ transformers
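
Note for running outside a Gradio-SDK Space (where gradio is preinstalled): app.py also imports gradio and huggingface_hub, neither of which is listed here. huggingface_hub is pulled in as a dependency of transformers, but gradio would need to be installed separately, as the commented %pip install gradio diffusers line at the top of app.py hints.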