import math
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor

from model.dit import DiT_Llama

logger = logging.get_logger(__name__)


@dataclass
class PiTPipelineOutput(BaseOutput):
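    """
    Output of [`PiTPipeline`].

    Args:
        image_embeds (`torch.Tensor`):
            The generated image embeddings.
    """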
    image_embeds: torch.Tensor


class PiTPipeline(DiffusionPipeline):
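    """
    Prior pipeline that starts from Gaussian noise and Euler-integrates the velocity
    field predicted by a `DiT_Llama` prior to produce image embeddings.

    Example (a minimal sketch; it assumes `prior` is a trained `DiT_Llama` and that
    `cond`/`uncond` are conditioning sequences the prior understands):

    ```py
    >>> pipe = PiTPipeline(prior=prior).to("cuda")
    >>> out = pipe(cond_sequence=cond, negative_cond_sequence=uncond,
    ...            num_inference_steps=25, guidance_scale=2.0)
    >>> embeds = out.image_embeds  # (batch, 16, prior.config.embedding_dim)
    ```
    """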

    def __init__(self, prior: DiT_Llama):
        super().__init__()

        self.register_modules(
            prior=prior,
        )

    def prepare_latents(self, shape, dtype, device, generator, latents):
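        # Draw fresh Gaussian noise unless the caller supplied latents; user-provided
        # latents are shape-checked and moved to the target device.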
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        return latents

    @torch.no_grad()
    def __call__(
        self,
        cond_sequence: torch.FloatTensor,
        negative_cond_sequence: torch.FloatTensor,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        init_latents: Optional[torch.FloatTensor] = None,
        strength: Optional[float] = None,
        guidance_scale: float = 1.0,
        output_type: Optional[str] = "pt",
        return_dict: bool = True,
    ):
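        """
        Sample image embeddings by integrating the prior's predicted velocity field.

        Args:
            cond_sequence: Conditioning sequence for the prior.
            negative_cond_sequence: Negative/unconditional sequence, used only when
                `guidance_scale > 1.0`.
            num_images_per_prompt: Samples to draw per conditioning sequence.
            num_inference_steps: Number of Euler integration steps.
            generator: RNG(s) for reproducible noise.
            latents: Pre-drawn starting noise of shape `(batch, 16, embedding_dim)`.
            init_latents: Latents to blend into the starting noise (image-to-image
                style); requires `strength`.
            strength: Blend factor in [0, 1]; also shortens the schedule to
                `ceil(num_inference_steps * strength)` steps.
            guidance_scale: Classifier-free guidance scale; values > 1.0 enable it.
            output_type: `"pt"` for a tensor, `"np"` for a NumPy array.
            return_dict: Return a [`PiTPipelineOutput`] instead of the raw tensor.
        """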
        do_classifier_free_guidance = guidance_scale > 1.0

        device = self._execution_device

        batch_size = cond_sequence.shape[0]
        batch_size = batch_size * num_images_per_prompt
        if num_images_per_prompt > 1:
            # Expand the conditioning to match the enlarged batch; otherwise the prior
            # would see mismatched batch sizes when num_images_per_prompt > 1.
            cond_sequence = cond_sequence.repeat_interleave(num_images_per_prompt, dim=0)
            negative_cond_sequence = negative_cond_sequence.repeat_interleave(num_images_per_prompt, dim=0)

        embedding_dim = self.prior.config.embedding_dim

        # Latents are a fixed sequence of 16 tokens of width embedding_dim.
        latents = self.prepare_latents(
            (batch_size, 16, embedding_dim),
            self.prior.dtype,
            device,
            generator,
            latents,
        )

        if init_latents is not None:
            if strength is None:
                raise ValueError("`strength` must be provided when `init_latents` is given.")
            init_latents = init_latents.to(latents.device)
            # Interpolate between fresh noise and the provided latents; smaller
            # strength keeps more of init_latents.
            latents = strength * latents + (1 - strength) * init_latents

        # Fixed Euler step size, broadcast over the non-batch dimensions of the latents.
        dt = 1.0 / num_inference_steps
        dt = torch.full((batch_size, *([1] * (latents.ndim - 1))), dt, device=latents.device, dtype=latents.dtype)

        # With `strength`, sampling starts partway through the schedule (image-to-image style).
        start_inference_step = (
            math.ceil(num_inference_steps * strength) if strength is not None else num_inference_steps
        )
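        # Euler-integrate the prior's velocity field, stepping t from its start
        # value (1.0 when sampling from pure noise) down to 0.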
        for i in range(start_inference_step, 0, -1):
            t = i / num_inference_steps
            t = torch.full((batch_size,), t, device=latents.device)
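            # Velocity prediction under the conditioning; with classifier-free
            # guidance, extrapolate away from the unconditional prediction.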
            vc = self.prior(latents, t, cond_sequence)
            if do_classifier_free_guidance:
                vu = self.prior(latents, t, negative_cond_sequence)
                vc = vu + guidance_scale * (vc - vu)

            # Euler update: step t downward by dt along the predicted velocity.
            latents = latents - dt * vc

        image_embeddings = latents

        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only output types `pt` and `np` are supported, got output_type={output_type}.")

        if output_type == "np":
            image_embeddings = image_embeddings.cpu().numpy()

        if not return_dict:
            return image_embeddings

        return PiTPipelineOutput(image_embeds=image_embeddings)