Spaces:
Sleeping
Sleeping
| import os | |
| import av | |
| import torch | |
| from pathlib import Path | |
| from typing import Dict | |
| from tqdm import tqdm | |
| from downloader import download_youtube_video | |
| from model_api import initialize_model, initialize_processor, prepare_model_inputs, generate_description, process_model_output, clear_gpu_cache, create_prompt_message, get_device_and_dtype | |
def _format_frame_timestamp(seconds: float) -> str:
    """Render a time offset in seconds as ``HH:MM:SS.mmm``."""
    # Work in whole milliseconds so a value such as 59.9996 s carries into the
    # minutes field instead of being printed as "60.000" seconds.
    total_ms = int(round(seconds * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"


def extract_frames_with_timestamps(
    video_path: str,
    output_dir: str,
    time_step: float = 1.0,
    quality: int = 95,
    frame_prefix: str = "frame",
    use_hw_accel: bool = True,
    hw_device: str = "cuda"
) -> Dict[str, str]:
    """
    Decode a video and save every N-th frame as a JPEG, keyed by timestamp.

    Frames are written to ``<output_dir>/frames`` (created if missing). The
    sampling stride is ``round(average_fps * time_step)`` frames, with a
    minimum of one.

    Args:
        video_path: Path to the video file
        output_dir: Directory under which the ``frames`` folder is created
        time_step: Interval between saved frames (seconds)
        quality: JPEG quality forwarded to the image ``save`` call
        frame_prefix: Prefix for saved frame file names
        use_hw_accel: When true, 'hwaccel', 'hwaccel_device' and
            'hwaccel_output_format' are passed as options to ``av.open``.
            NOTE(review): confirm the installed PyAV/FFmpeg build actually
            acts on these container options.
        hw_device: Value passed as 'hwaccel_device' (e.g., 'cuda:0')

    Returns:
        Dict of {timestamp: frame_path}, where timestamp is ``HH:MM:SS.mmm``
        and frame_path is the absolute path of the saved JPEG.

    Raises:
        RuntimeError: On any failure (missing file, no video stream, invalid
            frame rate, decode or save error). Frames already recorded in the
            result are deleted before the error is raised; the original
            exception is attached as ``__cause__``.
    """
    result = {}
    try:
        source = Path(video_path).absolute()
        target_root = Path(output_dir).absolute()
        if not source.exists():
            raise ValueError(f"Video file not found: {source}")

        frames_dir = target_root / "frames"
        frames_dir.mkdir(parents=True, exist_ok=True)

        # Configure hardware acceleration options
        options = {}
        if use_hw_accel:
            options.update({
                'hwaccel': 'cuda',
                'hwaccel_device': hw_device,
                'hwaccel_output_format': 'cuda'
            })

        # The context manager closes the container on success and on error.
        with av.open(str(source), options=options) as container:
            video_stream = next(
                (s for s in container.streams if s.type == 'video'), None
            )
            if video_stream is None:
                raise RuntimeError("No video stream found")

            rate = video_stream.average_rate
            fps = float(rate) if rate is not None else 0.0
            if fps <= 0:
                raise RuntimeError("Invalid frame rate")

            frame_interval = max(1, int(round(fps * time_step)))

            for frame_count, frame in enumerate(container.decode(video_stream)):
                if frame_count % frame_interval != 0:
                    continue

                if frame.pts is not None:
                    current_time = float(frame.pts * video_stream.time_base)
                else:
                    # No presentation timestamp: estimate from the frame index.
                    current_time = frame_count / fps

                timestamp = _format_frame_timestamp(current_time)
                safe_timestamp = timestamp.replace(':', '_').replace('.', '_')
                frame_path = frames_dir / f"{frame_prefix}_{safe_timestamp}.jpg"

                # to_image() converts the frame itself, so no intermediate
                # ndarray round-trip is needed.
                frame.to_image().save(str(frame_path), quality=quality)
                result[timestamp] = str(frame_path)

        return result
    except Exception as e:
        # Best-effort cleanup of frames written before the failure.
        for path in result.values():
            try:
                os.remove(path)
            except OSError:
                pass
        raise RuntimeError(f"Frame extraction failed: {str(e)}") from e
def generate_frame_descriptions(frames_dict: Dict, custom_prompt: str = None, device: str = "cuda", torch_dtype: torch.dtype = torch.float16):
    """
    Describe each frame in ``frames_dict``, showing a progress bar.

    The model and processor are created once up front, reused for every
    frame, and released at the end. A failure on one frame does not stop the
    run: the error text is stored as that frame's description instead.

    Args:
        frames_dict (dict): Dictionary of {timestamp: image_path} pairs
        custom_prompt (str, optional): Prompt forwarded to
            ``create_prompt_message`` for every frame.
        device (str): Passed to ``initialize_model`` and
            ``prepare_model_inputs``.
        torch_dtype (torch.dtype): Passed to ``initialize_model``.
    Returns:
        dict: Dictionary of {timestamp: description} pairs; failed frames map
        to a "Description generation error: ..." string.
    """
    # Build the model components once for the whole batch
    vlm = initialize_model(device, torch_dtype)
    proc = initialize_processor()

    captions = {}
    bar_settings = dict(
        total=len(frames_dict),
        desc="Processing frames",
        unit="frame",
    )
    with tqdm(frames_dict.items(), **bar_settings) as bar:
        for timestamp, image_path in bar:
            try:
                bar.set_postfix({"current": timestamp})

                # Build the prompt and tensorize it for the model
                messages = create_prompt_message(image_path, timestamp, custom_prompt)
                inputs = prepare_model_inputs(proc, messages, device)

                # Run generation and decode the result
                generated_ids = generate_description(vlm, inputs)
                captions[timestamp] = process_model_output(proc, inputs, generated_ids)

                # Drop per-frame tensors before asking the cache to clear
                del inputs, generated_ids
                clear_gpu_cache()
            except Exception as err:
                reason = str(err)
                print(f"\nError processing frame {timestamp}: {reason}")
                captions[timestamp] = f"Description generation error: {reason}"
                clear_gpu_cache()

    # Release the model components once all frames are done
    del vlm, proc
    clear_gpu_cache()
    return captions
if __name__ == "__main__":
    # Demo run: download one video, sample a frame every 10 seconds, and
    # print the resulting {timestamp: frame_path} mapping.
    download_info = download_youtube_video("https://www.youtube.com/watch?v=L1vXCYZAYYM")
    extracted = extract_frames_with_timestamps(
        video_path=download_info['video_path'],
        output_dir=download_info['data_path'],
        time_step=10,
    )
    print(extracted)