Spaces:
Runtime error
Runtime error
Commit
·
eafe433
1
Parent(s):
6d21cff
Update app.py
Browse files
app.py
CHANGED
|
@@ -155,16 +155,21 @@ def infer(prompt, negative_prompt, image, model_type="Standard"):
|
|
| 155 |
prompts = num_samples * [prompt]
|
| 156 |
if model_type=="Standard":
|
| 157 |
prompt_ids = std_pipeline.prepare_text_inputs(prompts)
|
| 158 |
-
|
| 159 |
prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
|
|
|
|
|
|
|
| 160 |
prompt_ids = shard(prompt_ids)
|
| 161 |
|
| 162 |
if model_type=="Standard":
|
| 163 |
annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
|
| 164 |
overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
|
| 165 |
-
|
| 166 |
annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
|
| 167 |
overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
|
|
|
|
|
|
|
|
|
|
| 168 |
validation_image = Image.fromarray(annotated_image).convert("RGB")
|
| 169 |
|
| 170 |
if model_type=="Standard":
|
|
@@ -183,7 +188,7 @@ def infer(prompt, negative_prompt, image, model_type="Standard"):
|
|
| 183 |
neg_prompt_ids=negative_prompt_ids,
|
| 184 |
jit=True,
|
| 185 |
).images
|
| 186 |
-
|
| 187 |
processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
|
| 188 |
processed_image = shard(processed_image)
|
| 189 |
|
|
@@ -200,7 +205,8 @@ def infer(prompt, negative_prompt, image, model_type="Standard"):
|
|
| 200 |
jit=True,
|
| 201 |
).images
|
| 202 |
|
| 203 |
-
|
|
|
|
| 204 |
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
| 205 |
|
| 206 |
results = [i for i in images]
|
|
|
|
| 155 |
prompts = num_samples * [prompt]
|
| 156 |
if model_type=="Standard":
|
| 157 |
prompt_ids = std_pipeline.prepare_text_inputs(prompts)
|
| 158 |
+
elif model_type=="Hand Encoding":
|
| 159 |
prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
|
| 160 |
+
else:
|
| 161 |
+
pass
|
| 162 |
prompt_ids = shard(prompt_ids)
|
| 163 |
|
| 164 |
if model_type=="Standard":
|
| 165 |
annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
|
| 166 |
overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
|
| 167 |
+
elif model_type=="Hand Encoding":
|
| 168 |
annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
|
| 169 |
overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
|
| 170 |
+
|
| 171 |
+
else:
|
| 172 |
+
pass
|
| 173 |
validation_image = Image.fromarray(annotated_image).convert("RGB")
|
| 174 |
|
| 175 |
if model_type=="Standard":
|
|
|
|
| 188 |
neg_prompt_ids=negative_prompt_ids,
|
| 189 |
jit=True,
|
| 190 |
).images
|
| 191 |
+
elif model_type=="Hand Encoding":
|
| 192 |
processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
|
| 193 |
processed_image = shard(processed_image)
|
| 194 |
|
|
|
|
| 205 |
jit=True,
|
| 206 |
).images
|
| 207 |
|
| 208 |
+
else:
|
| 209 |
+
pass
|
| 210 |
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
| 211 |
|
| 212 |
results = [i for i in images]
|