Spaces:
Runtime error
Runtime error
| import torch | |
| from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn | |
| from ldm_patched.modules.samplers import sampling_function | |
| from ldm_patched.modules import model_management | |
| from ldm_patched.modules.ops import cleanup_cache | |
def cond_from_a1111_to_patched_ldm(cond):
    """Wrap an A1111 text conditioning as a one-element list of patched-ldm cond dicts.

    Args:
        cond: either a ``torch.Tensor`` (used directly as the cross-attention
            context) or an object subscriptable by ``'crossattn'`` and
            ``'vector'``. In the second case, the ``'vector'`` entry is stored
            as ``pooled_output`` and forwarded as the ``y`` model cond.

    Returns:
        A list containing exactly one dict. It has ``cross_attn``, a
        ``model_conds`` dict, and ``pooled_output`` only in the non-tensor case.
    """
    if not isinstance(cond, torch.Tensor):
        cross_attn = cond['crossattn']
        pooled_output = cond['vector']
        entry = {
            'cross_attn': cross_attn,
            'pooled_output': pooled_output,
            'model_conds': {
                'c_crossattn': CONDCrossAttn(cross_attn),
                'y': CONDRegular(pooled_output),
            },
        }
        return [entry]

    # Bare tensor: there is no 'vector' entry, so no ``y`` model cond is added.
    entry = {
        'cross_attn': cond,
        'model_conds': {
            'c_crossattn': CONDCrossAttn(cond),
        },
    }
    return [entry]
def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
    """Split a batched A1111 conditioning into one weighted cond dict per position.

    Args:
        cond: the batched conditioning. Rows are selected with
            ``cond.advanced_indexing(indices)`` when that method exists,
            otherwise with ``cond[indices]`` (``indices`` is a list).
        weights: for each batch item, a sequence of ``(index, weight)`` pairs.
            Pairs are regrouped by position, so the i-th pair of every batch
            item forms one group.

    Returns:
        A flat list of cond dicts, one per position. Each dict is built by
        ``cond_from_a1111_to_patched_ldm`` and carries a ``'strength'`` key.
    """
    # Like zip, this stops at the shortest batch item's pair sequence.
    pairs_by_position = [list(column) for column in zip(*weights)]

    converted = []
    for pairs in pairs_by_position:
        # Must stay a list: indexing a tensor with a tuple means something else.
        indices = [index for index, _ in pairs]
        # NOTE(review): only the last pair's weight is kept. This assumes every
        # batch item uses the same weight at a given position.
        strength = pairs[-1][1] if pairs else 0

        selected = cond.advanced_indexing(indices) if hasattr(cond, 'advanced_indexing') else cond[indices]

        wrapped = cond_from_a1111_to_patched_ldm(selected)
        wrapped[0]['strength'] = strength
        converted.extend(wrapped)
    return converted
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    """Run one denoising evaluation through ldm_patched's ``sampling_function``.

    Converts the A1111 text conds into patched-ldm cond dicts and attaches the
    optional concat image cond and the controlnet chain. It then lets every
    entry of ``model_options['conditioning_modifiers']`` rewrite the inputs
    before delegating.

    Args:
        self: object whose ``inner_model.inner_model.forge_objects.unet`` provides
            ``model``, ``controlnet_linked_list``, ``extra_concat_condition`` and
            ``model_options``, and whose ``p.seeds[0]`` provides the seed.
        denoiser_params: provides ``x``, ``sigma``, ``text_cond``,
            ``text_uncond`` and ``image_cond``.
        cond_scale: guidance scale forwarded to ``sampling_function``.
        cond_composition: per-batch-item ``(index, weight)`` pairs, passed to
            ``cond_from_a1111_to_patched_ldm_weighted`` as ``weights``.

    Returns:
        The value returned by ``sampling_function``.
    """
    unet_patcher = self.inner_model.inner_model.forge_objects.unet
    model = unet_patcher.model
    control = unet_patcher.controlnet_linked_list
    extra_concat_condition = unet_patcher.extra_concat_condition

    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
    model_options = unet_patcher.model_options
    seed = self.p.seeds[0]

    # The patcher's own concat condition takes precedence over the A1111 image_cond.
    image_cond_in = extra_concat_condition if extra_concat_condition is not None else denoiser_params.image_cond

    # Only attach c_concat when batch and spatial dims (0, 2, 3) agree with x.
    concat_fits = (
        isinstance(image_cond_in, torch.Tensor)
        and image_cond_in.shape[0] == x.shape[0]
        and image_cond_in.shape[2] == x.shape[2]
        and image_cond_in.shape[3] == x.shape[3]
    )
    if concat_fits:
        for entry in uncond + cond:
            entry['model_conds']['c_concat'] = CONDRegular(image_cond_in)

    if control is not None:
        for entry in cond + uncond:
            entry['control'] = control

    # Each modifier may replace any of the sampling inputs.
    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(
            model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    return sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
def sampling_prepare(unet, x):
    """Load the UNet and its extra model patchers onto the GPU, then prime controlnets.

    The memory request is the estimated UNet inference memory for a batch of
    ``2 * B``, plus ``unet.extra_preserved_memory_during_sampling``, plus
    whatever the controlnet chain reports. The doubled batch presumably covers
    cond and uncond evaluated together; TODO confirm. Afterwards ``pre_run`` is
    called on every controlnet.

    Args:
        unet: the UNet patcher that is about to be sampled.
        x: the latent batch; its shape must have exactly four entries
            ``(B, C, H, W)``.
    """
    B, C, H, W = x.shape

    memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required)
    unet_inference_memory = memory_estimation_function([B * 2, C, H, W])

    additional_inference_memory = unet.extra_preserved_memory_during_sampling
    # Copy instead of aliasing. The original ``+=`` on the attribute's list
    # extended ``unet.extra_model_patchers_during_sampling`` in place, so every
    # call appended the controlnet models to the patcher again.
    additional_model_patchers = list(unet.extra_model_patchers_during_sampling)

    if unet.controlnet_linked_list is not None:
        # Non-in-place addition, so the patcher's attribute is never modified.
        additional_inference_memory = additional_inference_memory + \
            unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
        additional_model_patchers.extend(unet.controlnet_linked_list.get_models())

    model_management.load_models_gpu(
        models=[unet] + additional_model_patchers,
        memory_required=unet_inference_memory + additional_inference_memory)

    real_model = unet.model

    def percent_to_timestep_function(percent):
        # Controlnets receive this to map a sampling-progress fraction to a sigma.
        return real_model.model_sampling.percent_to_sigma(percent)

    for cnet in unet.list_controlnets():
        cnet.pre_run(real_model, percent_to_timestep_function)
def sampling_cleanup(unet):
    """Clean up after sampling.

    Calls ``cleanup()`` on each controlnet returned by
    ``unet.list_controlnets()``, then calls ``cleanup_cache()`` from
    ``ldm_patched.modules.ops``.
    """
    controlnets = unet.list_controlnets()
    for controlnet in controlnets:
        controlnet.cleanup()
    cleanup_cache()