import os
# os.system("pip install spaces-0.1.0-py3-none-any.whl")
import torch
import logging
import multiprocessing
import threading
from itertools import chain
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import load_dataset, get_dataset_config_names, IterableDataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, PeftModel
from huggingface_hub import login, whoami, create_repo, upload_folder
from IPython.display import clear_output
import gradio as gr
from dotenv import load_dotenv
import spaces

try:
    load_dotenv()
except Exception:
    pass


class GradioProgressCallback(TrainerCallback):
    """Relays Trainer step progress to the Gradio progress bar."""

    def __init__(self, progress_bar):
        self.progress_bar = progress_bar

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step > 0:
            self.progress_bar(
                state.global_step / state.max_steps,
                desc=f"Step {state.global_step}/{state.max_steps}",
            )
        return control


@spaces.GPU()
def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_dropout,
                 train_steps, learning_rate, batch_size, datasets_text, progress=gr.Progress()):
    os.environ["WANDB_DISABLED"] = "true"
    os.environ["HF_TOKEN"] = hf_token

    # Authenticate against the Hugging Face Hub.
    try:
        login(token=hf_token)
        username = whoami()["name"]
    except Exception as e:
        return f"Authentication error: {str(e)}"

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers = multiprocessing.cpu_count()

    # Stub out torch.xla so code paths that probe for it do not fail.
    if not hasattr(torch, "xla"):
        class DummyXLA:
            def __getattr__(self, name):
                return lambda *args, **kwargs: None
        torch.xla = DummyXLA()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Dataset names may be separated by commas or newlines.
    raw_items = datasets_text.replace('\n', ',').split(',')
    dataset_list = [item.strip() for item in raw_items if item.strip()]

    def get_sample_text(ds):
        try:
            sample = next(iter(ds))
            if isinstance(sample, dict):
                return sample.get("text", str(sample))
            return str(sample)
        except Exception:
            return None

    def load_single(ds_name, cfg):
        try:
            ds = load_dataset(ds_name, cfg, streaming=True)
            if isinstance(ds, dict):
                ds = next(iter(ds.values()))
            if get_sample_text(ds):
                return ds
            return None
        except Exception:
            return None

    def load_all_datasets():
        streams = []
        tasks = []
        progress(0.1, desc="Analyzing configurations...")
        for ds_name in dataset_list:
            try:
                configs = get_dataset_config_names(ds_name)
            except Exception:
                configs = []
            if not configs:
                tasks.append((ds_name, None))
            else:
                for c in configs:
                    tasks.append((ds_name, c))
        progress(0.2, desc=f"Loading {len(tasks)} sources...")
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_task = {executor.submit(load_single, d, c): (d, c) for d, c in tasks}
            for future in as_completed(future_to_task):
                try:
                    ds = future.result()
                    if ds:
                        streams.append(ds)
                except Exception:
                    pass
        return streams

    loaded_streams = load_all_datasets()
    if not loaded_streams:
        return "Error: no valid dataset could be loaded."
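    # The streams are concatenated lazily below: chain.from_iterable exhausts one source
    # before starting the next, so nothing is materialized in memory, but the order is
    # sequential rather than mixed. A sketch of a mixed alternative (assuming the sources
    # expose compatible columns) would be datasets.interleave_datasets:
    #     from datasets import interleave_datasets
    #     mixed_stream = interleave_datasets(loaded_streams)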
    def all_samples():
        return chain.from_iterable(loaded_streams)

    progress(0.3, desc="Loading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left",
                                                  add_eos_token=True, add_bos_token=True)
        tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        return f"Error loading tokenizer: {str(e)}"

    def create_text_lines(sample):
        # Prefer the "text" column; otherwise join every string field of the sample.
        if isinstance(sample, dict):
            text = sample.get("text", "\n".join(str(v) for v in sample.values() if isinstance(v, str)))
        else:
            text = str(sample)
        return [line.strip() for line in text.splitlines() if line.strip()]

    def process_sample(sample):
        # Tokenize each non-empty line as an independent causal-LM example.
        lines = create_text_lines(sample)
        results = []
        for line in lines:
            tok = tokenizer(line, truncation=False)
            tok["labels"] = tok["input_ids"].copy()
            results.append(tok)
        return results

    def processed_samples_generator():
        def flush(batch):
            # Tokenize a batch of raw samples in parallel and yield the results.
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                futures = [executor.submit(process_sample, s) for s in batch]
                for future in as_completed(futures):
                    try:
                        for tok in future.result():
                            yield tok
                    except Exception:
                        pass

        batch = []
        for sample in all_samples():
            batch.append(sample)
            if len(batch) >= 100:
                yield from flush(batch)
                batch.clear()
        if batch:
            yield from flush(batch)

    progress(0.4, desc="Loading model...")
    try:
        original_model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as e:
        return f"Error loading model: {str(e)}"

    # Note: the target module names are hard-coded; other architectures may expose
    # different projection names.
    peft_config = LoraConfig(
        r=int(lora_r),
        lora_alpha=int(lora_alpha),
        target_modules=["q_proj", "k_proj", "v_proj", "dense"],
        bias="none",
        lora_dropout=lora_dropout,
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(original_model, peft_config)
    peft_model.config.use_cache = False

    output_dir = "/content/final-checkpoint"
    max_steps_val = int(train_steps)
    save_steps_val = max_steps_val // 2 if max_steps_val > 10 else 1
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=int(batch_size),
        gradient_accumulation_steps=1,
        max_steps=max_steps_val,
        learning_rate=learning_rate,
        optim="adamw_torch",
        logging_steps=5,
        save_strategy="steps",
        save_steps=save_steps_val,
        report_to="none",
    )

    processed_dataset = IterableDataset.from_generator(processed_samples_generator)
    trainer = Trainer(
        model=peft_model,
        train_dataset=processed_dataset,
        args=training_args,
        # Pads variable-length examples (and their labels) into rectangular batches.
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100),
        callbacks=[GradioProgressCallback(progress)],
    )

    progress(0.5, desc="Training...")
    trainer.train()

    progress(0.8, desc="Saving...")
    trainer.save_model(output_dir)

    progress(0.9, desc="Merging...")
    # Apply the saved adapter and fold the LoRA weights into the base model so the
    # result can be loaded without peft.
    ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32,
                                   is_trainable=False).merge_and_unload()
    final_path = "/content/merged_model"
    ft.save_pretrained(final_path, safe_serialization=True)
    tokenizer.save_pretrained(final_path)

    progress(0.95, desc="Uploading...")
    full_repo = f"{username}/{new_repo_name}"
    create_repo(full_repo, token=hf_token, exist_ok=True)
    upload_folder(folder_path=final_path, repo_id=full_repo, token=hf_token)
    return f"Completed: https://huggingface.co/{full_repo}"
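
# A minimal sketch of how the uploaded result could be consumed afterwards (the repo id
# below is a placeholder, not something this app creates by default):
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     model = AutoModelForCausalLM.from_pretrained("your-username/your-new-repo")
#     tokenizer = AutoTokenizer.from_pretrained("your-username/your-new-repo")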

custom_css = """
body {background-color: #0b0f19; color: #e0e6ed;}
.gradio-container {max-width: 1200px !important; margin: 0 auto;}
h1 {text-align: center; color: #00e5ff; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; text-transform: uppercase; letter-spacing: 2px;}
.primary-btn {background: linear-gradient(135deg, #00C9FF 0%, #92FE9D 100%); border: none; color: #000; font-weight: 800; font-size: 16px; padding: 12px; transition: transform 0.2s;}
.primary-btn:hover {transform: scale(1.02); filter: brightness(1.1);}
.input-box textarea {font-family: 'Consolas', 'Monaco', monospace; font-size: 13px; background-color: #1a202c; color: #a0aec0; border: 1px solid #2d3748;}
.gr-box {border-radius: 8px; background-color: #1a202c; border: 1px solid #2d3748;}
label {color: #00e5ff !important; font-weight: bold;}
"""

with gr.Blocks(title="Ultimate LLM Trainer", css=custom_css) as demo:
    gr.HTML(f"")
    gr.HTML("""
Multi-Dataset Training with Automatic Merging and Hub Upload