import os
import json
import logging
import threading
import uuid
import time
import sys
import gc
import multiprocessing
import shutil
import math
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import transformers
import datasets
from dotenv import load_dotenv
from datasets import load_dataset, get_dataset_config_names, IterableDataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    AutoConfig,
    DataCollatorForLanguageModeling,
)
from huggingface_hub import login, whoami, create_repo, upload_folder
import spaces

try:
    load_dotenv()
except Exception:
    # Best-effort: a missing or malformed .env must not stop the app from starting.
    pass

transformers.logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.CRITICAL)
logging.getLogger("datasets").setLevel(logging.CRITICAL)
logging.getLogger("torch").setLevel(logging.CRITICAL)
logging.basicConfig(level=logging.CRITICAL, stream=sys.stderr)

if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

# In-process registry of training jobs, keyed by JobStatus.id.
JOBS = {}

# Trainer checkpoint artefacts that are only needed to resume training locally. They are
# large (the optimizer state in particular) and should not be pushed to the model repo.
_CHECKPOINT_IGNORE_PATTERNS = [
    "optimizer.pt",
    "scheduler.pt",
    "scaler.pt",
    "rng_state*.pth",
    "training_args.bin",
]


def activation_quant(x):
    """Fake-quantize activations to signed 8-bit using absmax scaling over the last dim.

    Returns the dequantized tensor (same shape/dtype as *x*). Gradients flow straight
    through to *x*.
    """
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    # Straight-through estimator: detach the whole quantization residual so the backward
    # pass is exactly identity (no gradient leaks through the data-dependent scale).
    return x + (y - x).detach()


def weight_quant(w):
    """Fake-quantize weights to ternary {-1, 0, +1} with per-tensor absmean scaling.

    Returns the dequantized tensor. Gradients flow straight through to *w*.
    """
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    # Straight-through estimator (see activation_quant).
    return w + (u - w).detach()


class BitLinear(nn.Linear):
    """Drop-in ``nn.Linear`` that quantizes weights (ternary) and activations (int8) on the fly.

    The full-precision latent weights are what get stored and optimized. Quantization is
    applied only inside ``forward``.
    """

    def forward(self, x):
        target_dtype = x.dtype
        w = self.weight.to(target_dtype)
        w_quant = weight_quant(w).to(target_dtype)
        x_quant = activation_quant(x).to(target_dtype)
        b = self.bias.to(target_dtype) if self.bias is not None else None
        return F.linear(x_quant, w_quant, b)


def convert_to_bitnet(model, copy_weights=False, exclude=None):
    """Recursively replace every ``nn.Linear`` inside *model* with a ``BitLinear``, in place.

    Args:
        model: module to convert. Its children are swapped via ``setattr``.
        copy_weights: when True, copy the existing weight/bias values into the new layer.
            Otherwise the new layer keeps ``nn.Linear``'s default initialisation.
        exclude: optional collection of child-module names. A child whose attribute name
            is in it (at any depth) is left untouched together with its whole subtree.
            Defaults to converting everything.
    """
    exclude = frozenset(exclude or ())
    for name, module in model.named_children():
        if name in exclude:
            continue
        # BitLinear subclasses nn.Linear; skip it so repeated conversion is a no-op
        # instead of replacing already-converted layers with freshly initialised ones.
        if isinstance(module, nn.Linear) and not isinstance(module, BitLinear):
            bit_linear = BitLinear(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                device=module.weight.device,
                dtype=module.weight.dtype,
            )
            if copy_weights:
                with torch.no_grad():
                    bit_linear.weight.copy_(module.weight)
                    if module.bias is not None:
                        bit_linear.bias.copy_(module.bias)
            setattr(model, name, bit_linear)
        else:
            convert_to_bitnet(module, copy_weights=copy_weights, exclude=exclude)


class JobStatus:
    """Mutable status record for one training job, polled by the UI."""

    def __init__(self):
        self.id = str(uuid.uuid4())
        self.status = "INITIALIZING"
        self.progress = 0.0  # fraction in [0, 1]
        self.logs = []
        self.result = None
        self.error = None
        self.created_at = datetime.now().strftime("%H:%M:%S")
        self.repo_url = None

    def add_log(self, message):
        """Append *message* prefixed with the current wall-clock time."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        self.logs.append(f"[{timestamp}] {message}")

    def set_progress(self, val, msg=None):
        """Set progress to *val* and optionally log *msg*."""
        self.progress = val
        if msg:
            self.add_log(msg)


class CustomTrainerCallback(TrainerCallback):
    """Mirror Trainer progress into ``JOBS[job_id]`` and push each saved checkpoint to the Hub."""

    def __init__(self, job_id, hf_token, repo_id):
        self.job_id = job_id
        self.hf_token = hf_token
        self.repo_id = repo_id
        # Background checkpoint uploads, tracked so the final upload can wait for them.
        self._upload_threads = []

    def on_step_end(self, args, state, control, **kwargs):
        job = JOBS.get(self.job_id)
        if job is not None and state.max_steps > 0:
            # Training occupies the 0.1 -> 0.9 slice of the overall progress bar.
            job.progress = 0.1 + (state.global_step / state.max_steps) * 0.8
        return control

    def on_log(self, args, state, control, logs=None, **kwargs):
        # on_log fires after the step's metrics are computed, so the loss reported here
        # belongs to this step (on_step_end runs before logging and would see the previous one).
        job = JOBS.get(self.job_id)
        if job is not None and logs and "loss" in logs:
            job.add_log(f"Training Step {state.global_step}/{state.max_steps} | Loss: {logs['loss']}")
        return control

    def on_save(self, args, state, control, **kwargs):
        job = JOBS.get(self.job_id)
        if job is None:
            return control
        step = state.global_step
        ckpt_name = f"checkpoint-{step}"
        ckpt_path = os.path.join(args.output_dir, ckpt_name)
        job.add_log(f"System: 100-Step Snapshot saved ({ckpt_name})")

        def _upload_bg():
            try:
                upload_folder(
                    folder_path=ckpt_path,
                    path_in_repo=".",
                    repo_id=self.repo_id,
                    token=self.hf_token,
                    commit_message=f"Live Checkpoint Step {step}",
                    ignore_patterns=_CHECKPOINT_IGNORE_PATTERNS,
                )
                job.add_log(f"Cloud: Synced Checkpoint {step} to Root")
            except Exception as exc:
                # Checkpoint sync is best-effort: report it, never interrupt training.
                # NOTE(review): with save_total_limit=1 the folder may be rotated away by
                # the next save while a slow upload is still reading it.
                job.add_log(f"Cloud: Checkpoint {step} sync failed ({exc})")

        thread = threading.Thread(target=_upload_bg, daemon=True)
        thread.start()
        self._upload_threads.append(thread)
        return control

    def wait_for_uploads(self, timeout=None):
        """Block until every checkpoint upload started so far has finished.

        *timeout* (seconds) applies to each thread individually. None waits indefinitely.
        """
        for thread in self._upload_threads:
            thread.join(timeout)
        self._upload_threads = [t for t in self._upload_threads if t.is_alive()]


def _open_stream(ds_name, cfg):
    """Open the streaming train split of *ds_name*/*cfg* and probe it with one record."""
    ds = load_dataset(ds_name, cfg, split="train", streaming=True, trust_remote_code=False)
    next(iter(ds))  # raises (StopIteration, network errors, ...) for unusable sources
    return ds


def _load_stream(ds_name):
    """Return a probed streaming dataset for *ds_name*, or None if it cannot be used.

    The default configuration is tried first. If that fails (e.g. the dataset requires an
    explicit config), "main" is used when listed, otherwise the first listed config.
    """
    try:
        return _open_stream(ds_name, None)
    except Exception:
        pass
    try:
        config_names = get_dataset_config_names(ds_name)
    except Exception:
        return None
    if not config_names:
        return None
    fallback = "main" if "main" in config_names else config_names[0]
    try:
        return _open_stream(ds_name, fallback)
    except Exception:
        return None


def _extract_text(item):
    """Turn one dataset record (a mapping) into a training string.

    Prefers "text", then "content". Question/answer pairs are joined with a newline.
    Anything else falls back to ``str(item)``.
    """
    if "text" in item:
        return str(item["text"])
    if "content" in item:
        return str(item["content"])
    if "question" in item and "answer" in item:
        return f"{item['question']}\n{item['answer']}"
    return str(item)


def _round_robin(iterables):
    """Yield one item from each iterable in turn until all are exhausted.

    Unlike ``itertools.chain``, a huge (effectively unbounded) first source does not
    starve the sources after it.
    """
    iterators = [iter(it) for it in iterables]
    while iterators:
        alive = []
        for it in iterators:
            try:
                item = next(it)
            except StopIteration:
                continue
            alive.append(it)
            yield item
        iterators = alive


def _inject_json(job, output_dir, content, fname):
    """Merge user-supplied JSON *content* into ``output_dir/fname`` (creating it if missing).

    Top-level keys from *content* override existing ones. Empty content is a no-op.
    Invalid JSON, a non-object payload, or a write failure is reported in the job log and
    otherwise ignored, so the existing file is never replaced by garbage.
    """
    if not content or not content.strip():
        return
    try:
        data = json.loads(content)
    except json.JSONDecodeError as exc:
        job.add_log(f"Config: Skipped {fname} (invalid JSON: {exc})")
        return
    if not isinstance(data, dict):
        job.add_log(f"Config: Skipped {fname} (expected a JSON object)")
        return
    file_path = os.path.join(output_dir, fname)
    if os.path.exists(file_path):
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                existing = json.load(f)
        except (OSError, json.JSONDecodeError):
            existing = None
        if isinstance(existing, dict):
            existing.update(data)
            data = existing
    try:
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
    except OSError as exc:
        job.add_log(f"Config: Could not write {fname} ({exc})")
        return
    job.add_log(f"Config: Overwritten {fname} with user settings")


@spaces.GPU(duration=300)
def background_train_task(job_id, hf_token, model_name, new_repo_name, train_steps, learning_rate, batch_size, datasets_text, reasoning_mode, c_conf, c_tok, c_gen):
    """Train a BitNet model from scratch on streamed datasets and publish it to the Hub.

    Progress, logs, final status and the published repo URL are written to ``JOBS[job_id]``.
    Any failure inside the pipeline is recorded as status "FAILED" (with the message in
    ``job.error``) rather than raised.

    Args:
        hf_token: Hugging Face write token (must start with "hf_").
        model_name: Hub repo whose *config* (architecture) and tokenizer are used; no
            pretrained weights are loaded.
        new_repo_name: repo name created under the authenticated user.
        train_steps, learning_rate, batch_size: Trainer hyper-parameters.
        datasets_text: comma- or newline-separated list of Hub dataset ids.
        reasoning_mode: when truthy, GSM8K is added to the data mix.
        c_conf, c_tok, c_gen: optional JSON merged into config.json, tokenizer_config.json
            and generation_config.json before upload.
    """
    job = JOBS[job_id]
    job.status = "RUNNING"
    job.add_log("System: initializing BitNet Scratch Protocol...")
    try:
        if not hf_token or not hf_token.startswith("hf_"):
            raise ValueError("Invalid Token")
        new_repo_name = (new_repo_name or "").strip()
        if not new_repo_name:
            raise ValueError("Missing Output Repository")

        os.environ["WANDB_DISABLED"] = "true"
        os.environ["HF_TOKEN"] = hf_token
        os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        login(token=hf_token)
        try:
            username = whoami(token=hf_token)["name"]
            full_repo_id = f"{username}/{new_repo_name}"
            create_repo(full_repo_id, token=hf_token, exist_ok=True)
        except Exception as exc:
            raise Exception(f"Auth Failed ({exc})") from exc
        job.add_log(f"Auth: Verified {username} -> {full_repo_id}")

        # NOTE(review): stub that makes any ``torch.xla.<anything>(...)`` a no-op. Presumably a
        # workaround for a library probing ``torch.xla``; confirm it is still required.
        if not hasattr(torch, "xla"):
            class DummyXLA:
                def __getattr__(self, name):
                    return lambda *args, **kwargs: None
            torch.xla = DummyXLA()

        raw_items = (datasets_text or "").replace("\n", ",").split(",")
        dataset_list = [item.strip() for item in raw_items if item.strip()]
        if reasoning_mode:
            job.add_log("Config: Reasoning Injection Active")
            # "gsm8k" resolves to the same Hub dataset; adding both would stream it twice.
            dataset_list.append("openai/gsm8k")
        # Drop duplicates while keeping the user's order.
        dataset_list = list(dict.fromkeys(dataset_list))

        job.set_progress(0.05, "Data: Parallel Stream Connect...")
        max_workers = max(1, min(len(dataset_list), multiprocessing.cpu_count() * 2))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map() keeps results aligned with dataset_list, so stream order is deterministic.
            loaded = list(executor.map(_load_stream, dataset_list))
        for ds_name, ds in zip(dataset_list, loaded):
            if ds is None:
                job.add_log(f"Data: Skipped unavailable source {ds_name}")
        streams = [ds for ds in loaded if ds is not None]
        if not streams:
            raise Exception("No Data Sources")
        job.set_progress(0.1, f"Data: {len(streams)} Streams Linked")

        # NOTE(security): trust_remote_code=True executes Python code shipped with the
        # user-supplied repo; only point model_name at architectures you trust.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left", add_eos_token=True, add_bos_token=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        def _encode_batch(texts):
            # A tokenizer failure drops only this batch instead of poisoning every later one.
            try:
                encoded = tokenizer(texts, truncation=True, max_length=2048, padding=False)
            except Exception as exc:
                job.add_log(f"Data: Skipped a batch of {len(texts)} records ({exc})")
                return
            for input_ids in encoded["input_ids"]:
                yield {"input_ids": input_ids}

        def process_stream_generator():
            batch_buffer = []
            for item in _round_robin(streams):
                try:
                    text = _extract_text(item)
                except Exception:
                    continue  # malformed record
                if len(text) < 5:
                    continue
                batch_buffer.append(text)
                if len(batch_buffer) >= 100:
                    yield from _encode_batch(batch_buffer)
                    batch_buffer = []
            if batch_buffer:
                # Flush the tail so finite streams are consumed completely.
                yield from _encode_batch(batch_buffer)

        job.set_progress(0.15, "Model: Initializing Architecture & Converting to BitNet...")
        gc.collect()
        torch.cuda.empty_cache()
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        original_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        # Keep the output projection in full precision (as BitNet does). Replacing it would
        # also break tie_word_embeddings: the separately trained lm_head would be discarded
        # when the uploaded model is reloaded with tied weights.
        convert_to_bitnet(original_model, copy_weights=False, exclude={"lm_head"})
        model_size = sum(t.numel() for t in original_model.parameters())
        job.add_log(f"Model Size: {model_size/1000**2:.1f}M Parameters (1.58-bit)")

        output_dir = f"checkpoints/{job_id}"
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=4,
            max_steps=int(train_steps),
            learning_rate=learning_rate,
            optim="adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
            logging_steps=1,
            save_strategy="steps",
            save_steps=100,
            save_total_limit=1,
            report_to="none",
            fp16=torch.cuda.is_available(),
            disable_tqdm=True,
            dataloader_num_workers=4,
            dataloader_pin_memory=True,
            gradient_checkpointing=True,
            torch_compile=False,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
        )
        dataset_iterable = IterableDataset.from_generator(process_stream_generator)
        callback = CustomTrainerCallback(job_id, hf_token, full_repo_id)
        trainer = Trainer(
            model=original_model,
            tokenizer=tokenizer,
            train_dataset=dataset_iterable,
            args=training_args,
            data_collator=data_collator,
            callbacks=[callback],
        )

        job.set_progress(0.2, "Training: BitNet Gradient Descent Initiated...")
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        job.set_progress(0.9, "Processing: Finalizing Artifacts...")
        # The trainer also references the model; drop both so the memory is actually freed.
        del trainer
        del original_model
        gc.collect()
        torch.cuda.empty_cache()

        _inject_json(job, output_dir, c_conf, "config.json")
        _inject_json(job, output_dir, c_tok, "tokenizer_config.json")
        _inject_json(job, output_dir, c_gen, "generation_config.json")

        # A checkpoint commit landing after the final one would overwrite the root files
        # (including the injected configs), so let pending checkpoint uploads finish first.
        callback.wait_for_uploads()

        job.set_progress(0.95, "Network: Uploading Final BitNet Model...")
        upload_folder(
            folder_path=output_dir,
            path_in_repo=".",
            repo_id=full_repo_id,
            token=hf_token,
            commit_message="BitNet Scratch Trained Model",
            # Intermediate checkpoint folders and resume-only state are not part of the model.
            ignore_patterns=["checkpoint-*", *_CHECKPOINT_IGNORE_PATTERNS],
        )
        job.repo_url = f"https://huggingface.co/{full_repo_id}"
        job.status = "COMPLETED"
        job.set_progress(1.0, "System: Operation Finalized")
    except Exception as e:
        job.status = "FAILED"
        job.error = str(e)
        job.add_log(f"FATAL ERROR: {str(e)}")
        torch.cuda.empty_cache()


def start_training_wrapper(hf_token, model_name, new_repo_name, train_steps, learning_rate, batch_size, datasets_text, reasoning_mode, c_conf, c_tok, c_gen):
    """Register a new job and run ``background_train_task`` for it on a daemon thread.

    Returns ``(job_id, tab_update)``: the new id and a switch to the monitor tab, or
    ``(None, launch-tab update)`` when the token or model name is missing.
    """
    if not hf_token or not model_name:
        return None, gr.update(selected="launch_tab")
    new_job = JobStatus()
    JOBS[new_job.id] = new_job
    thread = threading.Thread(
        target=background_train_task,
        args=(new_job.id, hf_token, model_name, new_repo_name, train_steps, learning_rate, batch_size, datasets_text, reasoning_mode, c_conf, c_tok, c_gen),
        daemon=True,
    )
    thread.start()
    return new_job.id, gr.update(selected="monitor_tab")


def get_job_update(job_id):
    """Return ``(status, created_at, progress, log_text, result_update)`` for the telemetry tab."""
    if not job_id:
        return "Waiting...", "", 0, "", gr.update(visible=False)
    if job_id not in JOBS:
        return "Not Found", "", 0, "", gr.update(visible=False)
    job = JOBS[job_id]
    log_text = "\n".join(job.logs)
    result_comp = gr.update(visible=False)
    if job.status == "COMPLETED" and job.repo_url:
        result_comp = gr.update(visible=True, value=f"✅ Full Model Published: {job.repo_url}")
    return job.status, job.created_at, job.progress, log_text, result_comp


def load_from_url(request: gr.Request):
    """Open the monitor tab for a ``?job_id=`` query parameter, else the launch tab."""
    try:
        job_id = request.query_params.get("job_id")
        if job_id:
            return gr.update(selected="monitor_tab"), job_id
    except Exception:
        # No request / query params available: fall back to the launch tab.
        pass
    return gr.update(selected="launch_tab"), ""


with gr.Blocks(title="Nucleus Enterprise") as demo:
    with gr.Column():
        gr.Markdown("# ⚛️ NUCLEUS ENTERPRISE")
        gr.Markdown("Autonomous LLM Foundry | V10.0 BitNet Edition")
    with gr.Tabs() as main_tabs:
        with gr.TabItem("🚀 LAUNCHPAD", id="launch_tab"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row():
                        hf_token = gr.Textbox(label="HuggingFace Token", type="password", value=os.getenv("HF_TOKEN", ""))
                        model_name = gr.Textbox(label="Architecture Config Source", value="Qwen/Qwen2.5-0.5B")
                    repo_name = gr.Textbox(label="Output Repository", value="nucleus-bitnet-v1")
                    # Named datasets_box so the imported ``datasets`` module is not shadowed.
                    datasets_box = gr.Textbox(label="Datasets (CSV)", value="Salesforce/fineweb_deduplicated", lines=3)
                    reasoning = gr.Checkbox(label="Inject Reasoning (CoT/Math)", value=False)
                with gr.Column(scale=1):
                    steps = gr.Number(label="Steps", value=100)
                    lr = gr.Number(label="Learning Rate", value=1e-4)
                    batch = gr.Number(label="Batch Size", value=1)
            with gr.Accordion("Advanced Config", open=False):
                c_conf = gr.Code(label="config.json", language="json")
                c_tok = gr.Code(label="tokenizer_config.json", language="json")
                c_gen = gr.Code(label="generation_config.json", language="json")
            btn_launch = gr.Button("INITIALIZE BITNET TRAINING", variant="primary", size="lg")
        with gr.TabItem("📡 TELEMETRY", id="monitor_tab"):
            with gr.Row():
                job_id_input = gr.Textbox(label="Active Job ID", interactive=True)
                btn_refresh = gr.Button("Refresh Stream")
            with gr.Row():
                status_out = gr.Textbox(label="Status", interactive=False)
                time_out = gr.Textbox(label="Start Time", interactive=False)
            progress_out = gr.Slider(label="Progress", minimum=0, maximum=1)
            final_link = gr.Markdown(visible=False)
            logs_out = gr.Code(label="Real-time Kernel Logs", language="shell", interactive=False, lines=15)

    # gr.Timer's interval is in seconds: poll telemetry every 2 s.
    timer = gr.Timer(2, active=False)

    demo.load(load_from_url, None, [main_tabs, job_id_input]).then(lambda: gr.Timer(active=True), None, timer)
    btn_launch.click(
        start_training_wrapper,
        inputs=[hf_token, model_name, repo_name, steps, lr, batch, datasets_box, reasoning, c_conf, c_tok, c_gen],
        outputs=[job_id_input, main_tabs],
    ).then(
        None,
        [job_id_input],
        None,
        js="(id) => { if (id) { const url = new URL(window.location); url.searchParams.set('job_id', id); window.history.pushState({}, '', url); } return id; }",
    ).then(
        lambda: gr.Timer(active=True), None, timer
    )
    btn_refresh.click(get_job_update, job_id_input, [status_out, time_out, progress_out, logs_out, final_link])
    timer.tick(get_job_update, job_id_input, [status_out, time_out, progress_out, logs_out, final_link])

if __name__ == "__main__":
    demo.launch(ssr_mode=False)