|
|
import os

# Install pytorch3d at startup; this wheel index is pinned to Python 3.10 / CUDA 12.1 / PyTorch 2.2.1.
os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")

import shutil
import math

from huggingface_hub import snapshot_download

# Download the DiffPoseTalk weights (used to turn the driving audio into facial motion).
os.makedirs("pretrained_models", exist_ok=True)
snapshot_download(
    repo_id="multimodalart/diffposetalk",
    local_dir="pretrained_models/diffposetalk"
)

base_dir = "pretrained_models"
os.makedirs(base_dir, exist_ok=True)

# Fetch the auxiliary face models (FLAME, MediaPipe, SMIRK) from the SkyReels-A1 repo and
# move them out of the extra_models/ prefix into pretrained_models/<model>/.
for model in ["FLAME", "mediapipe", "smirk"]:
    temp_dir = f"{base_dir}/{model}_temp"
    snapshot_download(
        repo_id="Skywork/SkyReels-A1",
        local_dir=temp_dir,
        allow_patterns=f"extra_models/{model}/**"
    )

    src_dir = f"{temp_dir}/extra_models/{model}"
    dst_dir = f"{base_dir}/{model}"
    os.makedirs(dst_dir, exist_ok=True)

    for item in os.listdir(src_dir):
        shutil.move(f"{src_dir}/{item}", f"{dst_dir}/{item}")

    shutil.rmtree(temp_dir)

# Download the SkyReels-A1-5B pipeline weights.
snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir=f"{base_dir}/SkyReels-A1-5B",
)

import gradio as gr
import torch
import numpy as np
from PIL import Image
import cv2
import gc
import tempfile
import moviepy.editor as mp
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from diffusers.utils import export_to_video, load_image

from skyreels_a1.models.transformer3d import CogVideoXTransformer3DModel
from skyreels_a1.skyreels_a1_i2v_pipeline import SkyReelsA1ImagePoseToVideoPipeline
from skyreels_a1.pre_process_lmk3d import FaceAnimationProcessor
from skyreels_a1.src.media_pipe.mp_utils import LMKExtractor
from skyreels_a1.src.media_pipe.draw_util_2d import FaceMeshVisualizer2d

from diffusers.models import AutoencoderKLCogVideoX
from transformers import SiglipImageProcessor, SiglipVisionModel
from diffposetalk.diffposetalk import DiffPoseTalk
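
# End-to-end flow: the portrait is cropped to the face region, DiffPoseTalk converts the driving
# audio into facial motion coefficients, those coefficients are rendered as 2D landmark frames,
# and the SkyReels-A1 pipeline uses the landmark frames as a control video to animate the portrait.
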
def parse_video(driving_frames, max_frame_num, fps=25):
    # Resample the driving frames from `fps` down to 12 fps, then trim/pad so the result
    # lines up with the pipeline's `max_frame_num` frame budget.
    video_length = len(driving_frames)
    duration = video_length / fps
    target_times = np.arange(0, duration, 1 / 12)
    frame_indices = (target_times * fps).astype(np.int32)

    frame_indices = frame_indices[frame_indices < video_length]
    new_driving_frames = []
    for idx in frame_indices:
        new_driving_frames.append(driving_frames[idx])
        if len(new_driving_frames) >= max_frame_num - 1:
            break

    video_length_add = max_frame_num - len(new_driving_frames) - 1
    new_driving_frames = (
        [new_driving_frames[0]] * 2
        + new_driving_frames[1:len(new_driving_frames) - 1]
        + [new_driving_frames[-1]] * video_length_add
    )
    return new_driving_frames

def write_mp4(video_path, samples, fps=12):
    clip = mp.ImageSequenceClip(samples, fps=fps)
    clip.write_videofile(video_path, audio_codec="aac", codec="libx264",
                         ffmpeg_params=["-crf", "18", "-preset", "slow"])

def save_video_with_audio(video_path, audio_path, save_path):
    video_clip = mp.VideoFileClip(video_path)
    audio_clip = mp.AudioFileClip(audio_path)

    # Trim the audio if it runs longer than the generated video.
    if audio_clip.duration > video_clip.duration:
        audio_clip = audio_clip.subclip(0, video_clip.duration)

    video_with_audio = video_clip.set_audio(audio_clip)
    video_with_audio.write_videofile(save_path, fps=12, codec="libx264", audio_codec="aac")

    video_clip.close()
    audio_clip.close()
    return save_path

def pad_video(driving_frames, fps=25):
    video_length = len(driving_frames)

    # Resample from the source fps to the 12 fps used by the pipeline.
    duration = video_length / fps
    target_times = np.arange(0, duration, 1 / 12)
    frame_indices = (target_times * fps).astype(np.int32)

    frame_indices = frame_indices[frame_indices < video_length]
    new_driving_frames = []
    for idx in frame_indices:
        new_driving_frames.append(driving_frames[idx])

    # Pad with copies of the last frame so the length is a multiple of 48 (the per-chunk frame count).
    pad_length = math.ceil(len(new_driving_frames) / 48) * 48 - len(new_driving_frames)
    new_driving_frames.extend([new_driving_frames[-1]] * pad_length)
    return new_driving_frames, pad_length
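
# For example, a 5 s driving clip at 25 fps resamples to 60 frames at 12 fps, so pad_video()
# pads with 36 copies of the last frame to reach 96 (= 2 * 48) and returns pad_length = 36.
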
model_name = "pretrained_models/SkyReels-A1-5B/"
siglip_name = "pretrained_models/SkyReels-A1-5B/siglip-so400m-patch14-384"
weight_dtype = torch.bfloat16
max_frame_num = 49
sample_size = [480, 720]  # output height, width

print("Loading models...")

# Face-preprocessing helpers: landmark extraction, SMIRK-based animation processing,
# 2D face-mesh drawing, and face alignment/cropping.
lmk_extractor = LMKExtractor()
processor = FaceAnimationProcessor(checkpoint='pretrained_models/smirk/SMIRK_em1.pt')
vis = FaceMeshVisualizer2d(forehead_edge=False, draw_head=False, draw_iris=False)
face_helper = FaceRestoreHelper(upscale_factor=1, face_size=512, crop_ratio=(1, 1),
                                det_model='retinaface_resnet50', save_ext='png', device="cuda")

siglip = SiglipVisionModel.from_pretrained(siglip_name)
siglip_normalize = SiglipImageProcessor.from_pretrained(siglip_name)

# Audio-to-motion model: produces facial expression/pose coefficients from speech.
diffposetalk = DiffPoseTalk()

transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_name,
    subfolder="transformer"
).to(weight_dtype)

vae = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)

lmk_encoder = AutoencoderKLCogVideoX.from_pretrained(
    model_name,
    subfolder="pose_guider",
).to(weight_dtype)

pipe = SkyReelsA1ImagePoseToVideoPipeline.from_pretrained(
    model_name,
    transformer=transformer,
    vae=vae,
    lmk_encoder=lmk_encoder,
    image_encoder=siglip,
    feature_extractor=siglip_normalize,
    torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# torch.compile trades a one-time warm-up for faster inference; VAE tiling reduces peak memory.
pipe.transformer = torch.compile(pipe.transformer)
pipe.vae.enable_tiling()
pipe.vae = torch.compile(pipe.vae)

print("Models loaded successfully!")
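
# Gradio handler: preprocess the portrait, generate audio-driven motion, run the diffusion
# pipeline in 48-frame chunks, and return the final video plus a side-by-side comparison
# (both muxed with the input audio).
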
def process_image_audio(image_path, audio_path, guidance_scale=3.0, steps=10, progress=gr.Progress()):
    progress(0.1, desc="Processing inputs...")

    output_dir = "gradio_outputs"
    os.makedirs(output_dir, exist_ok=True)

    # Temporary files for the silent render and for the final output with audio.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file, \
         tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_output_file:
        temp_video_path = temp_video_file.name
        final_output_path = temp_output_file.name

    progress(0.2, desc="Processing image...")

    image = load_image(image=image_path)
    image = processor.crop_and_resize(image, sample_size[0], sample_size[1])

    # Crop the face region; (x1, y1) is its top-left corner in the full frame.
    ref_image, x1, y1 = processor.face_crop(np.array(image))
    face_h, face_w, _ = ref_image.shape
    source_image = ref_image

    progress(0.3, desc="Processing facial landmarks...")

    source_outputs, source_tform, image_original = processor.process_source_image(source_image)

    progress(0.4, desc="Processing audio...")

    # Generate driving motion from the audio, conditioned on the source identity (shape parameters).
    driving_outputs = diffposetalk.infer_from_file(
        audio_path,
        source_outputs["shape_params"].view(-1)[:100].detach().cpu().numpy()
    )

    progress(0.5, desc="Processing landmarks from coefficients...")

    # Render the driving coefficients as landmark frames, then pad to a multiple of 48 for chunked generation.
    out_frames = processor.preprocess_lmk3d_from_coef(
        source_outputs, source_tform, image_original.shape, driving_outputs
    )
    out_frames, pad_length = pad_video(out_frames)
    print(len(out_frames), pad_length)

    # Paste the face-region motion frames back onto a full-resolution blank canvas.
    rescale_motions = np.zeros_like(image)[np.newaxis, :].repeat(len(out_frames), axis=0)
    for ii in range(rescale_motions.shape[0]):
        rescale_motions[ii][y1:y1 + face_h, x1:x1 + face_w] = out_frames[ii]

    ref_image_resized = cv2.resize(ref_image, (512, 512))
    ref_lmk = lmk_extractor(ref_image_resized[:, :, ::-1])

    ref_img = vis.draw_landmarks_v3(
        (512, 512), (face_w, face_h),
        ref_lmk['lmks'].astype(np.float32), normed=True
    )

    # The first control frame holds the landmarks of the reference image itself.
    first_motion = np.zeros_like(np.array(image))
    first_motion[y1:y1 + face_h, x1:x1 + face_w] = ref_img
    first_motion = first_motion[np.newaxis, :]

    # Detect, align, and crop the face passed to the pipeline as image_face
    # (facexlib works in BGR, hence the channel flips).
    face_helper.clean_all()
    face_helper.read_image(np.array(image)[:, :, ::-1])
    face_helper.get_face_landmarks_5(only_center_face=True)
    face_helper.align_warp_face()
    align_face = face_helper.cropped_faces[0]
    image_face = align_face[:, :, ::-1]

    progress(0.6, desc="Generating animation (this may take a while)...")

    # Generate in chunks of 48 control frames, each prefixed with the reference-landmark frame
    # to fill the pipeline's 49-frame window.
    out_samples = []
    for i in range(0, len(rescale_motions), 48):
        motions = np.concatenate([first_motion, rescale_motions[i:i + 48]])
        input_video = torch.from_numpy(np.array(motions)).permute([3, 0, 1, 2]).unsqueeze(0)
        input_video = input_video / 255

        with torch.no_grad():
            sample = pipe(
                image=image,
                image_face=image_face,
                control_video=input_video,
                prompt="",
                negative_prompt="",
                height=480,
                width=720,
                num_frames=49,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
            )
        if i == 0:
            out_samples.extend(sample.frames[0])
        else:
            out_samples.extend(sample.frames[0][1:])

    # Drop the leading reference frame and any trailing padding frames added by pad_video().
    if pad_length == 0:
        out_samples = out_samples[1:]
    else:
        out_samples = out_samples[1:-pad_length]

    progress(0.8, desc="Creating output video...")

    export_to_video(out_samples, temp_video_path, fps=12)

    progress(0.9, desc="Adding audio to video...")

    result_path = save_video_with_audio(temp_video_path, audio_path, final_output_path)

    # Build a side-by-side comparison: source image on the left, generated frame on the right.
    target_h, target_w = sample_size[0], sample_size[1]
    final_images = []
    for i in range(len(out_samples)):
        frame1 = image
        frame2 = Image.fromarray(np.array(out_samples[i])).convert("RGB")

        result = Image.new('RGB', (target_w * 2, target_h))
        result.paste(frame1, (0, 0))
        result.paste(frame2, (target_w, 0))
        final_images.append(np.array(result))

    comparison_path = os.path.join(output_dir, "comparison.mp4")
    write_mp4(comparison_path, final_images, fps=12)

    comparison_with_audio = os.path.join(output_dir, "comparison_with_audio.mp4")
    comparison_with_audio = save_video_with_audio(comparison_path, audio_path, comparison_with_audio)

    progress(1.0, desc="Done!")

    # Release GPU memory between requests.
    torch.cuda.empty_cache()
    gc.collect()

    return result_path, comparison_with_audio
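
# Gradio UI: a portrait image and a driving audio clip as inputs, guidance scale and inference
# steps as tunable parameters, and two video outputs wired to process_image_audio().
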
with gr.Blocks(title="SkyReels A1 Talking Head") as app:
    gr.Markdown("# SkyReels A1 Talking Head")
    gr.Markdown('''Upload a portrait image and an audio file to animate the face. 💡 Enjoying this demo? Share your feedback or review, and you might earn exclusive rewards! 🚀✨
📩 [Contact us on Discord](https://discord.com/invite/PwM6NYtccQ) for details. 🔥 [Code](https://github.com/SkyworkAI/SkyReels-A1) [Huggingface](https://huggingface.co/Skywork/SkyReels-A1)''')
    gr.Markdown('''✨ Try our **AI Office** for more productivity tools! [Visit Skywork AI Agent](https://skywork.ai/?utm_source=skyworkspace) ✨''')

    with gr.Row():
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(type="filepath", label="Portrait Image")
                audio_input = gr.Audio(type="filepath", label="Driving Audio")

            with gr.Row():
                guidance_scale = gr.Slider(minimum=1.0, maximum=7.0, value=3.0, step=0.1, label="Guidance Scale")
                inference_steps = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Inference Steps")

            generate_button = gr.Button("Generate Animation", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Animation Result")
            comparison_video = gr.Video(label="Side-by-Side Comparison")

    generate_button.click(
        fn=process_image_audio,
        inputs=[image_input, audio_input, guidance_scale, inference_steps],
        outputs=[output_video, comparison_video]
    )

    gr.Markdown("""
    ## Instructions
    1. Upload a portrait image (frontal face works best)
    2. Upload an audio file (wav format recommended)
    3. Adjust parameters if needed
    4. Click "Generate Animation" to create the video

    Note: Processing may take several minutes depending on your hardware.
    """)

if __name__ == "__main__":
    app.launch(share=True)
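
# Note: share=True requests a temporary public Gradio link. To serve on a fixed host/port
# instead, something like app.launch(server_name="0.0.0.0", server_port=7860) should work.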