# Train_v1 / app.py
# Ksjsjjdj's picture
# Update app.py
# b9816b5 verified
import os
import json
import logging
import threading
import uuid
import time
import sys
import gc
import multiprocessing
import shutil
import math
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import chain
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import transformers
import datasets
from dotenv import load_dotenv
from datasets import load_dataset, get_dataset_config_names, IterableDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, AutoConfig, DataCollatorForLanguageModeling
from huggingface_hub import login, whoami, create_repo, upload_folder
import spaces
# Best-effort load of a local .env file; the app keeps starting without it.
try:
    load_dotenv()
except Exception:
    # Only ordinary errors are ignored; KeyboardInterrupt/SystemExit still propagate.
    pass

# Silence library chatter so that only critical records reach stderr.
transformers.logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.CRITICAL)
logging.getLogger("datasets").setLevel(logging.CRITICAL)
logging.getLogger("torch").setLevel(logging.CRITICAL)
logging.basicConfig(level=logging.CRITICAL, stream=sys.stderr)

# Enable TF32 matmuls/convolutions and cuDNN autotuning when a GPU is present.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

# In-process registry of jobs, keyed by JobStatus.id.
JOBS = {}
def activation_quant(x):
    """Fake-quantize activations to the int8 range with a straight-through gradient.

    Each slice along the last dimension is scaled by ``127 / max(|x|)``, rounded,
    clamped to ``[-128, 127]`` and scaled back, so the returned tensor has the
    same shape and dtype as ``x`` but only takes quantized values.

    Args:
        x: Floating-point activation tensor.

    Returns:
        Tensor whose forward value is the quantized ``x`` and whose gradient with
        respect to ``x`` is the identity.
    """
    # The floor on the absmax keeps the scale finite for all-zero rows.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    y = (x * scale).round().clamp(-128, 127) / scale
    # Straight-through estimator: detach the whole quantization residual so no
    # gradient leaks back through ``scale`` (which itself depends on ``x``).
    return x + (y - x).detach()
def weight_quant(w):
    """Fake-quantize weights to ternary values with a straight-through gradient.

    The tensor is scaled by the reciprocal of its mean absolute value, rounded,
    clamped to ``{-1, 0, 1}`` and scaled back.

    Args:
        w: Floating-point weight tensor.

    Returns:
        Tensor whose forward value is the ternarized ``w`` (times the mean
        absolute value) and whose gradient with respect to ``w`` is the identity.
    """
    # The floor on the mean keeps the scale finite for an all-zero weight.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    u = (w * scale).round().clamp(-1, 1) / scale
    # Straight-through estimator: detach the quantization residual so the
    # gradient does not flow through ``scale`` (a function of every weight).
    return w + (u - w).detach()
class BitLinear(nn.Linear):
    """Drop-in ``nn.Linear`` whose forward pass uses quantized weights and inputs."""

    def forward(self, x):
        """Apply the linear map after passing weight and input through the quantizers.

        The weight and optional bias are cast to the dtype of ``x`` before use.
        """
        dtype = x.dtype
        quantized_weight = weight_quant(self.weight.to(dtype)).to(dtype)
        quantized_input = activation_quant(x).to(dtype)
        bias = None if self.bias is None else self.bias.to(dtype)
        return F.linear(quantized_input, quantized_weight, bias)
def convert_to_bitnet(model, copy_weights=False):
    """Replace every ``nn.Linear`` beneath ``model`` with a ``BitLinear``, in place.

    Args:
        model: Module whose children are walked recursively.
        copy_weights: When true, the replacement layer receives a clone of the
            original weight (and bias, if any); otherwise it keeps its own
            freshly initialised parameters.

    Returns:
        None. ``model`` is modified in place.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, nn.Linear):
            # Not a linear layer: descend into it and look for nested ones.
            convert_to_bitnet(child, copy_weights=copy_weights)
            continue
        has_bias = child.bias is not None
        replacement = BitLinear(child.in_features, child.out_features, has_bias)
        if copy_weights:
            replacement.weight.data = child.weight.data.clone()
            if has_bias:
                replacement.bias.data = child.bias.data.clone()
        setattr(model, child_name, replacement)
class JobStatus:
    """Mutable record describing one training job: state, progress and log lines."""

    def __init__(self):
        """Create a job with a random UUID id in the INITIALIZING state."""
        self.id = str(uuid.uuid4())
        self.status = "INITIALIZING"
        self.progress = 0.0
        self.logs = []
        self.result = None
        self.error = None
        # Wall-clock creation time, formatted HH:MM:SS.
        self.created_at = datetime.now().strftime("%H:%M:%S")
        self.repo_url = None

    def add_log(self, message):
        """Append ``message`` to the log, prefixed with the current HH:MM:SS time."""
        stamp = datetime.now().strftime("%H:%M:%S")
        entry = f"[{stamp}] {message}"
        self.logs.append(entry)

    def set_progress(self, val, msg=None):
        """Store ``val`` as the progress and log ``msg`` when it is non-empty."""
        self.progress = val
        if not msg:
            return
        self.add_log(msg)
class CustomTrainerCallback(TrainerCallback):
    """Trainer callback that mirrors training progress into ``JOBS`` and syncs checkpoints.

    Args:
        job_id: Key of the job in the module-level ``JOBS`` registry.
        hf_token: Hugging Face token used for checkpoint uploads.
        repo_id: Target repository (``user/name``) for checkpoint uploads.
    """

    def __init__(self, job_id, hf_token, repo_id):
        self.job_id = job_id
        self.hf_token = hf_token
        self.repo_id = repo_id

    def on_step_end(self, args, state, control, **kwargs):
        """Update the job's progress and append one log line per optimizer step."""
        job = JOBS.get(self.job_id)
        if job is not None:
            if state.max_steps > 0:
                # Training occupies the 10%..90% band of the overall progress bar.
                prog = state.global_step / state.max_steps
                job.progress = 0.1 + (prog * 0.8)
            # Report the most recent logged entry; it may not carry a loss yet.
            loss = state.log_history[-1].get('loss', 'N/A') if state.log_history else '...'
            job.add_log(f"Training Step {state.global_step}/{state.max_steps} | Loss: {loss}")
        return control

    def on_save(self, args, state, control, **kwargs):
        """Log the new checkpoint and upload it to the repo root from a daemon thread."""
        job = JOBS.get(self.job_id)
        if job is None:
            return control

        step = state.global_step
        ckpt_name = f"checkpoint-{step}"
        ckpt_path = os.path.join(args.output_dir, ckpt_name)
        job.add_log(f"System: 100-Step Snapshot saved ({ckpt_name})")

        def _upload_bg():
            # Best effort: a failed sync must not interrupt training, but it is
            # reported in the job log instead of being silently discarded.
            try:
                upload_folder(
                    folder_path=ckpt_path,
                    path_in_repo=".",
                    repo_id=self.repo_id,
                    token=self.hf_token,
                    commit_message=f"Live Checkpoint Step {step}"
                )
                job.add_log(f"Cloud: Synced Checkpoint {step} to Root")
            except Exception as exc:
                job.add_log(f"Cloud: Checkpoint {step} sync failed ({exc})")

        # NOTE(review): the upload runs concurrently with training; if the Trainer
        # rotates this checkpoint directory before the upload ends, the sync can fail.
        threading.Thread(target=_upload_bg, daemon=True).start()
        return control
def _parse_dataset_names(datasets_text):
    """Split the comma- and/or newline-separated dataset box into non-empty names."""
    raw_items = (datasets_text or "").replace('\n', ',').split(',')
    return [item.strip() for item in raw_items if item.strip()]


def _open_stream(ds_name, cfg=None):
    """Open the streaming ``train`` split of ``ds_name``; return None if unreadable.

    With an explicit ``cfg`` only that config is tried. Otherwise the config
    ``"main"`` is tried first (the historical default of this app), then the
    dataset's default config, then the first other config the Hub reports.
    A candidate only counts as usable when its first row can actually be read.
    """
    def attempt(config_name):
        try:
            ds = load_dataset(ds_name, config_name, split="train", streaming=True, trust_remote_code=False)
            next(iter(ds))  # probe: fails for missing configs/splits and empty streams
            return ds
        except Exception:
            return None

    if cfg:
        return attempt(cfg)
    for config_name in ("main", None):
        ds = attempt(config_name)
        if ds is not None:
            return ds
    try:
        discovered = get_dataset_config_names(ds_name)
    except Exception:
        return None
    for config_name in discovered:
        if config_name != "main":  # "main" was already tried above
            return attempt(config_name)
    return None


def _iter_tokenized_examples(streams, tokenizer, batch_size=100, max_length=2048, min_chars=5):
    """Yield ``{"input_ids": [...]}`` dicts for every usable text in ``streams``.

    Streams are consumed one after another. Texts are taken from the ``text``
    field, else ``content``, else the stringified row; texts shorter than
    ``min_chars`` are skipped. Texts are tokenized in batches of ``batch_size``
    and the trailing partial batch is flushed when the streams are exhausted.
    """
    def encode(texts):
        try:
            return tokenizer(texts, truncation=True, max_length=max_length, padding=False)["input_ids"]
        except Exception:
            # Drop a batch the tokenizer rejects rather than retrying it forever.
            return []

    buffer = []
    for item in chain.from_iterable(streams):
        try:
            text = str(item.get("text", item.get("content", str(item))))
        except Exception:
            continue
        if len(text) < min_chars:
            continue
        buffer.append(text)
        if len(buffer) >= batch_size:
            encoded = encode(buffer)
            buffer = []
            # The yield is deliberately outside any try block so that closing
            # the generator (GeneratorExit) is never swallowed.
            for input_ids in encoded:
                yield {"input_ids": input_ids}
    if buffer:
        for input_ids in encode(buffer):
            yield {"input_ids": input_ids}


def _inject_json_override(job, output_dir, content, fname):
    """Merge user-supplied JSON ``content`` into ``output_dir/fname``.

    Blank content is ignored. Invalid JSON and write errors are reported in the
    job log. When the existing file cannot be read or merged, the user's JSON
    replaces it.
    """
    if not (content and content.strip()):
        return
    try:
        data = json.loads(content)
    except ValueError as exc:
        job.add_log(f"Config: Ignored {fname} override (invalid JSON: {exc})")
        return
    file_path = os.path.join(output_dir, fname)
    if os.path.exists(file_path):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                existing_data = json.load(f)
            existing_data.update(data)
            data = existing_data
        except (OSError, ValueError, TypeError, AttributeError):
            pass  # unreadable or non-mergeable file: write the user's JSON alone
    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)
    except (OSError, TypeError, ValueError) as exc:
        job.add_log(f"Config: Could not write {fname} ({exc})")
        return
    job.add_log(f"Config: Overwritten {fname} with user settings")


@spaces.GPU(duration=300)
def background_train_task(job_id, hf_token, model_name, new_repo_name,
                          train_steps, learning_rate, batch_size, datasets_text,
                          reasoning_mode, c_conf, c_tok, c_gen):
    """Train a BitNet-converted model from scratch and publish it to the Hub.

    Progress, log lines, final status and any error are written to
    ``JOBS[job_id]``; nothing is returned and exceptions are not re-raised.

    Args:
        job_id: Key of an existing entry in ``JOBS``.
        hf_token: Hugging Face token; must start with ``hf_``.
        model_name: Hub id that supplies the model config and the tokenizer.
        new_repo_name: Name of the output repo under the token owner's account.
        train_steps: Number of optimizer steps (converted with ``int``).
        learning_rate: Peak learning rate (converted with ``float``).
        batch_size: Per-device batch size (converted with ``int``).
        datasets_text: Comma- and/or newline-separated dataset ids.
        reasoning_mode: When truthy, the gsm8k datasets are appended.
        c_conf: Optional JSON merged into ``config.json`` after training.
        c_tok: Optional JSON merged into ``tokenizer_config.json``.
        c_gen: Optional JSON merged into ``generation_config.json``.
    """
    # NOTE(review): this function is started from a plain thread; confirm that the
    # spaces.GPU wrapper runs it where mutations of JOBS are visible to the UI.
    job = JOBS[job_id]
    job.status = "RUNNING"
    job.add_log("System: initializing BitNet Scratch Protocol...")
    try:
        if not hf_token.startswith("hf_"):
            raise ValueError("Invalid Token")
        os.environ["WANDB_DISABLED"] = "true"
        os.environ["HF_TOKEN"] = hf_token
        os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        login(token=hf_token)
        try:
            # Pass the token explicitly so the lookup does not depend on whichever
            # token was stored by the most recent login() in this process.
            username = whoami(token=hf_token)["name"]
            full_repo_id = f"{username}/{new_repo_name}"
            create_repo(full_repo_id, token=hf_token, exist_ok=True)
            job.add_log(f"Auth: Verified {username} -> {full_repo_id}")
        except Exception as exc:
            raise Exception("Auth Failed") from exc

        # Install a no-op ``torch.xla`` stand-in when the attribute is missing.
        if not hasattr(torch, 'xla'):
            class DummyXLA:
                def __getattr__(self, name):
                    return lambda *args, **kwargs: None
            torch.xla = DummyXLA()

        dataset_list = _parse_dataset_names(datasets_text)
        if reasoning_mode:
            job.add_log("Config: Reasoning Injection Active")
            dataset_list.extend(["gsm8k", "openai/gsm8k"])

        streams = []
        job.set_progress(0.05, "Data: Parallel Stream Connect...")
        cpu_count = multiprocessing.cpu_count()
        with ThreadPoolExecutor(max_workers=cpu_count * 2) as executor:
            futures = [executor.submit(_open_stream, ds_name, None) for ds_name in dataset_list]
            for future in as_completed(futures):
                res = future.result()
                if res is not None:
                    streams.append(res)
        if not streams:
            raise Exception("No Data Sources")
        job.set_progress(0.1, f"Data: {len(streams)} Streams Linked")

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left", add_eos_token=True, add_bos_token=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        def process_stream_generator():
            # Zero-argument closure, as in the original, so from_generator is
            # called exactly the same way (no gen_kwargs).
            yield from _iter_tokenized_examples(streams, tokenizer)

        job.set_progress(0.15, "Model: Initializing Architecture & Converting to BitNet...")
        torch.cuda.empty_cache()
        gc.collect()
        # Only the architecture is taken from ``model_name``; weights start untrained.
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        original_model = AutoModelForCausalLM.from_config(
            config,
            trust_remote_code=True,
        )
        convert_to_bitnet(original_model, copy_weights=False)
        model_size = sum(t.numel() for t in original_model.parameters())
        job.add_log(f"Model Size: {model_size/1000**2:.1f}M Parameters (1.58-bit)")

        output_dir = f"checkpoints/{job_id}"
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=4,
            max_steps=int(train_steps),
            learning_rate=float(learning_rate),
            optim="adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
            logging_steps=1,
            save_strategy="steps",
            save_steps=100,
            save_total_limit=1,
            report_to="none",
            fp16=torch.cuda.is_available(),
            disable_tqdm=True,
            dataloader_num_workers=4,
            dataloader_pin_memory=True,
            gradient_checkpointing=True,
            torch_compile=False,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1
        )
        dataset_iterable = IterableDataset.from_generator(process_stream_generator)
        trainer = Trainer(
            model=original_model,
            tokenizer=tokenizer,
            train_dataset=dataset_iterable,
            args=training_args,
            data_collator=data_collator,
            callbacks=[CustomTrainerCallback(job_id, hf_token, full_repo_id)]
        )
        job.set_progress(0.2, "Training: BitNet Gradient Descent Initiated...")
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        job.set_progress(0.9, "Processing: Finalizing Artifacts...")
        # The trainer also references the model, so drop both before clearing caches.
        del trainer
        del original_model
        torch.cuda.empty_cache()
        gc.collect()

        _inject_json_override(job, output_dir, c_conf, "config.json")
        _inject_json_override(job, output_dir, c_tok, "tokenizer_config.json")
        _inject_json_override(job, output_dir, c_gen, "generation_config.json")

        job.set_progress(0.95, "Network: Uploading Final BitNet Model...")
        upload_folder(
            folder_path=output_dir,
            path_in_repo=".",
            repo_id=full_repo_id,
            token=hf_token,
            commit_message="BitNet Scratch Trained Model"
        )
        job.repo_url = f"https://huggingface.co/{full_repo_id}"
        job.status = "COMPLETED"
        job.set_progress(1.0, "System: Operation Finalized")
    except Exception as e:
        # Top-level boundary of the worker: record the failure on the job.
        job.status = "FAILED"
        job.error = str(e)
        job.add_log(f"FATAL ERROR: {str(e)}")
        torch.cuda.empty_cache()
def start_training_wrapper(hf_token, model_name, new_repo_name,
                           train_steps, learning_rate, batch_size, datasets_text,
                           reasoning_mode, c_conf, c_tok, c_gen):
    """Register a new job and start ``background_train_task`` for it on a daemon thread.

    Returns:
        ``(job_id, tab_update)``. When the token or model name is missing, no
        job is created and the result is ``(None, <select launch tab>)``;
        otherwise the new job id and an update selecting the monitor tab.
    """
    if not (hf_token and model_name):
        return None, gr.update(selected="launch_tab")

    job = JobStatus()
    JOBS[job.id] = job

    task_args = (
        job.id, hf_token, model_name, new_repo_name,
        train_steps, learning_rate, batch_size, datasets_text,
        reasoning_mode, c_conf, c_tok, c_gen,
    )
    worker = threading.Thread(target=background_train_task, args=task_args, daemon=True)
    worker.start()
    return job.id, gr.update(selected="monitor_tab")
def get_job_update(job_id):
    """Return the telemetry tuple for ``job_id``.

    Returns:
        ``(status, start_time, progress, log_text, result_update)``. Placeholder
        values are returned when ``job_id`` is empty ("Waiting...") or unknown
        ("Not Found"). The result component is only made visible once the job
        is COMPLETED and has a repository URL.
    """
    if not job_id:
        return "Waiting...", "", 0, "", gr.update(visible=False)

    job = JOBS.get(job_id)
    if job is None:
        return "Not Found", "", 0, "", gr.update(visible=False)

    log_text = "\n".join(job.logs)
    published = job.status == "COMPLETED" and job.repo_url
    if published:
        result_comp = gr.update(visible=True, value=f"✅ Full Model Published: {job.repo_url}")
    else:
        result_comp = gr.update(visible=False)
    return job.status, job.created_at, job.progress, log_text, result_comp
def load_from_url(request: gr.Request):
    """Pick the initial tab from the page URL.

    Args:
        request: The Gradio request for the page load.

    Returns:
        ``(tab_update, job_id)``. When the URL carries a non-empty ``job_id``
        query parameter, the monitor tab is selected and that id is returned;
        otherwise the launch tab is selected with an empty id.
    """
    # Only the query-parameter read is guarded: a missing or odd request object
    # falls back to the launch tab, while KeyboardInterrupt/SystemExit propagate.
    try:
        job_id = request.query_params.get("job_id")
    except Exception:
        job_id = None
    if job_id:
        return gr.update(selected="monitor_tab"), job_id
    return gr.update(selected="launch_tab"), ""
# Gradio UI: a launch tab that collects the job parameters and a telemetry tab
# that polls the selected job.
# NOTE(review): the original indentation was lost, so the nesting of rows and
# columns below is a reconstruction; confirm it against the intended layout.
with gr.Blocks(title="Nucleus Enterprise") as demo:
    with gr.Column():
        gr.Markdown("# ⚛️ NUCLEUS ENTERPRISE")
        gr.Markdown("Autonomous LLM Foundry | V10.0 BitNet Edition")
    with gr.Tabs() as main_tabs:
        with gr.TabItem("🚀 LAUNCHPAD", id="launch_tab"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row():
                        # NOTE(review): pre-filling from the server's HF_TOKEN sends
                        # that value to every visitor's browser; confirm this is intended.
                        hf_token = gr.Textbox(label="HuggingFace Token", type="password", value=os.getenv("HF_TOKEN", ""))
                        model_name = gr.Textbox(label="Architecture Config Source", value="Qwen/Qwen2.5-0.5B")
                    repo_name = gr.Textbox(label="Output Repository", value="nucleus-bitnet-v1")
                    # Named datasets_box so it does not shadow the imported ``datasets`` module.
                    datasets_box = gr.Textbox(label="Datasets (CSV)", value="Salesforce/fineweb_deduplicated", lines=3)
                    reasoning = gr.Checkbox(label="Inject Reasoning (CoT/Math)", value=False)
                with gr.Column(scale=1):
                    steps = gr.Number(label="Steps", value=100)
                    lr = gr.Number(label="Learning Rate", value=1e-4)
                    batch = gr.Number(label="Batch Size", value=1)
            with gr.Accordion("Advanced Config", open=False):
                c_conf = gr.Code(label="config.json", language="json")
                c_tok = gr.Code(label="tokenizer_config.json", language="json")
                c_gen = gr.Code(label="generation_config.json", language="json")
            btn_launch = gr.Button("INITIALIZE BITNET TRAINING", variant="primary", size="lg")
        with gr.TabItem("📡 TELEMETRY", id="monitor_tab"):
            with gr.Row():
                job_id_input = gr.Textbox(label="Active Job ID", interactive=True)
                btn_refresh = gr.Button("Refresh Stream")
            with gr.Row():
                status_out = gr.Textbox(label="Status", interactive=False)
                time_out = gr.Textbox(label="Start Time", interactive=False)
            progress_out = gr.Slider(label="Progress", minimum=0, maximum=1)
            final_link = gr.Markdown(visible=False)
            logs_out = gr.Code(label="Real-time Kernel Logs", language="shell", interactive=False, lines=15)

    # gr.Timer takes its interval in seconds: poll every 2 s (the previous value
    # of 2000 meant one tick every 2000 seconds).
    timer = gr.Timer(2, active=False)

    # On page load, restore the job id from the URL (if any) and start polling.
    demo.load(load_from_url, None, [main_tabs, job_id_input]).then(lambda: gr.Timer(active=True), None, timer)

    # Launch: start the job, store its id in the URL, then start polling.
    btn_launch.click(
        start_training_wrapper,
        inputs=[hf_token, model_name, repo_name, steps, lr, batch, datasets_box, reasoning, c_conf, c_tok, c_gen],
        outputs=[job_id_input, main_tabs]
    ).then(
        None, [job_id_input], None,
        js="(id) => { if (id) { const url = new URL(window.location); url.searchParams.set('job_id', id); window.history.pushState({}, '', url); } return id; }"
    ).then(
        lambda: gr.Timer(active=True), None, timer
    )
    btn_refresh.click(get_job_update, job_id_input, [status_out, time_out, progress_out, logs_out, final_link])
    timer.tick(get_job_update, job_id_input, [status_out, time_out, progress_out, logs_out, final_link])

if __name__ == "__main__":
    demo.launch(ssr_mode=False)