import os
import json
import logging
import threading
import uuid
import sys
import gc
import multiprocessing

from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import transformers
import datasets
from dotenv import load_dotenv
from datasets import load_dataset, get_dataset_config_names, IterableDataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    AutoConfig,
    DataCollatorForLanguageModeling,
)
from huggingface_hub import login, whoami, create_repo, upload_folder
import spaces

# .env loading is best-effort: in a deployed Space the variables already exist.
try:
    load_dotenv()
except Exception:
    pass

# Silence library chatter so the job log is the only meaningful output stream.
transformers.logging.set_verbosity_error()
datasets.logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.CRITICAL)
logging.getLogger("datasets").setLevel(logging.CRITICAL)
logging.getLogger("torch").setLevel(logging.CRITICAL)
logging.basicConfig(level=logging.CRITICAL, stream=sys.stderr)

# Enable TF32 matmuls and cuDNN autotuning for faster training on Ampere+.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

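# In-memory job registry, keyed by job id. Jobs are created by the launch
# handler and polled by the telemetry tab.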
JOBS = {}

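# Quantize activations to signed 8-bit per token (absmax scaling). The
# straight-through estimator returns the quantized value in the forward pass
# while letting gradients flow through as if no rounding had happened.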
def activation_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return x + (y - x).detach()

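# Quantize weights to the ternary set {-1, 0, +1} via absmean scaling, the
# BitNet b1.58 scheme, again with a straight-through estimator.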
def weight_quant(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return w + (u - w).detach()

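# Drop-in replacement for nn.Linear that fake-quantizes weights and
# activations on every forward pass (quantization-aware training).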
class BitLinear(nn.Linear):
    def forward(self, x):
        target_dtype = x.dtype

        w = self.weight.to(target_dtype)
        w_quant = weight_quant(w).to(target_dtype)

        x_quant = activation_quant(x).to(target_dtype)

        b = self.bias.to(target_dtype) if self.bias is not None else None

        return F.linear(x_quant, w_quant, b)

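# Recursively swap every nn.Linear in the module tree for a BitLinear.
# With copy_weights=False the fresh random initialization is kept, which is
# what from-scratch training wants.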
def convert_to_bitnet(model, copy_weights=False):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            bit_linear = BitLinear(module.in_features, module.out_features, module.bias is not None)
            if copy_weights:
                bit_linear.weight.data = module.weight.data.clone()
                if module.bias is not None:
                    bit_linear.bias.data = module.bias.data.clone()
            setattr(model, name, bit_linear)
        else:
            convert_to_bitnet(module, copy_weights=copy_weights)

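# State for one training run: status, progress in [0, 1], timestamped logs,
# and the final Hub URL once publishing succeeds.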
class JobStatus:
    def __init__(self):
        self.id = str(uuid.uuid4())
        self.status = "INITIALIZING"
        self.progress = 0.0
        self.logs = []
        self.result = None
        self.error = None
        self.created_at = datetime.now().strftime("%H:%M:%S")
        self.repo_url = None

    def add_log(self, message):
        timestamp = datetime.now().strftime("%H:%M:%S")
        self.logs.append(f"[{timestamp}] {message}")

    def set_progress(self, val, msg=None):
        self.progress = val
        if msg:
            self.add_log(msg)

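# Trainer callback that mirrors training state into the job registry and
# uploads each checkpoint from a background thread so syncing never blocks
# the training loop.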
class CustomTrainerCallback(TrainerCallback):
    def __init__(self, job_id, hf_token, repo_id):
        self.job_id = job_id
        self.hf_token = hf_token
        self.repo_id = repo_id

    def on_step_end(self, args, state, control, **kwargs):
        if self.job_id in JOBS:
            job = JOBS[self.job_id]
            if state.max_steps > 0:
                prog = state.global_step / state.max_steps
                # Training occupies the 0.1-0.9 band of the progress bar.
                job.progress = 0.1 + (prog * 0.8)
            loss = state.log_history[-1].get('loss', 'N/A') if state.log_history else '...'
            job.add_log(f"Training Step {state.global_step}/{state.max_steps} | Loss: {loss}")
        return control

    def on_save(self, args, state, control, **kwargs):
        if self.job_id in JOBS:
            job = JOBS[self.job_id]
            step = state.global_step
            ckpt_name = f"checkpoint-{step}"
            ckpt_path = os.path.join(args.output_dir, ckpt_name)

            job.add_log(f"System: 100-Step Snapshot saved ({ckpt_name})")

            def _upload_bg():
                try:
                    upload_folder(
                        folder_path=ckpt_path,
                        path_in_repo=".",
                        repo_id=self.repo_id,
                        token=self.hf_token,
                        commit_message=f"Live Checkpoint Step {step}"
                    )
                    job.add_log(f"Cloud: Synced Checkpoint {step} to Root")
                except Exception as e:
                    job.add_log(f"Cloud: Checkpoint upload failed ({e})")

            threading.Thread(target=_upload_bg, daemon=True).start()
        return control

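# The full training pipeline. Runs on a worker thread (see
# start_training_wrapper); @spaces.GPU asks Hugging Face Spaces for a GPU
# slot for up to 300 seconds.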
@spaces.GPU(duration=300)
def background_train_task(job_id, hf_token, model_name, new_repo_name,
                          train_steps, learning_rate, batch_size, datasets_text,
                          reasoning_mode, c_conf, c_tok, c_gen):

    job = JOBS[job_id]
    job.status = "RUNNING"
    job.add_log("System: Initializing BitNet Scratch Protocol...")

    try:
        if not hf_token.startswith("hf_"):
            raise ValueError("Invalid Token")

        os.environ["WANDB_DISABLED"] = "true"
        os.environ["HF_TOKEN"] = hf_token
        os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
        os.environ["TOKENIZERS_PARALLELISM"] = "true"

        login(token=hf_token)
        try:
            username = whoami()["name"]
            full_repo_id = f"{username}/{new_repo_name}"
            create_repo(full_repo_id, token=hf_token, exist_ok=True)
            job.add_log(f"Auth: Verified {username} -> {full_repo_id}")
        except Exception as e:
            raise Exception(f"Auth Failed: {e}")

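        # Stub torch.xla so any stray XLA probes become no-ops on non-TPU
        # hardware (presumably a workaround for a Trainer code path).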
        if not hasattr(torch, 'xla'):
            class DummyXLA:
                def __getattr__(self, name):
                    return lambda *args, **kwargs: None
            torch.xla = DummyXLA()

        # Accept both comma- and newline-separated dataset lists.
        raw_items = datasets_text.replace('\n', ',').split(',')
        dataset_list = [item.strip() for item in raw_items if item.strip()]

        if reasoning_mode:
            job.add_log("Config: Reasoning Injection Active")
            dataset_list.extend(["gsm8k", "openai/gsm8k"])

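        # Open one dataset as a stream. Some datasets (e.g. gsm8k) require a
        # config name, so fall back to the first advertised config when none
        # is supplied; a single next() probes that the stream yields data.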
        def load_single(ds_name, cfg):
            try:
                if cfg is None:
                    configs = get_dataset_config_names(ds_name)
                    cfg = configs[0] if configs else None
                ds = load_dataset(ds_name, cfg, split="train", streaming=True, trust_remote_code=False)
                next(iter(ds))
                return ds
            except Exception:
                return None

        streams = []
        job.set_progress(0.05, "Data: Parallel Stream Connect...")

        cpu_count = multiprocessing.cpu_count()
        with ThreadPoolExecutor(max_workers=cpu_count * 2) as executor:
            futures = [executor.submit(load_single, ds_name, None) for ds_name in dataset_list]
            for future in as_completed(futures):
                res = future.result()
                if res:
                    streams.append(res)

        if not streams:
            raise Exception("No Data Sources")

        job.set_progress(0.1, f"Data: {len(streams)} Streams Linked")

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left", add_eos_token=True, add_bos_token=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

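        # Chain the streams end to end, buffer 100 raw texts at a time for
        # batched tokenization, and yield one {"input_ids": ...} example per
        # document. Texts shorter than 5 characters are skipped as noise.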
        def process_stream_generator():
            iterator = chain.from_iterable(streams)
            batch_buffer = []

            def encode(buffer):
                encoded = tokenizer(buffer, truncation=True, max_length=2048, padding=False)
                for input_ids in encoded["input_ids"]:
                    yield {"input_ids": input_ids}

            for item in iterator:
                try:
                    text = str(item.get("text", item.get("content", str(item))))
                    if len(text) < 5:
                        continue
                    batch_buffer.append(text)

                    if len(batch_buffer) >= 100:
                        yield from encode(batch_buffer)
                        batch_buffer = []
                except Exception:
                    continue

            # Flush the final partial batch so the tail of the stream is kept.
            if batch_buffer:
                yield from encode(batch_buffer)

        job.set_progress(0.15, "Model: Initializing Architecture & Converting to BitNet...")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        # Build the model from the config alone: random weights, no pretrained
        # checkpoint. Only the architecture is borrowed from model_name.
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        original_model = AutoModelForCausalLM.from_config(
            config,
            trust_remote_code=True,
        )

        convert_to_bitnet(original_model, copy_weights=False)

        model_size = sum(t.numel() for t in original_model.parameters())
        job.add_log(f"Model Size: {model_size / 1e6:.1f}M Parameters (1.58-bit)")

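        # Effective batch size is per_device_train_batch_size * 4 gradient
        # accumulation steps, i.e. 4 sequences per optimizer update at the
        # default batch size of 1.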
        output_dir = f"checkpoints/{job_id}"

        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=int(batch_size),
            gradient_accumulation_steps=4,
            max_steps=int(train_steps),
            learning_rate=learning_rate,
            optim="adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
            logging_steps=1,
            save_strategy="steps",
            save_steps=100,
            save_total_limit=1,
            report_to="none",
            fp16=torch.cuda.is_available(),
            disable_tqdm=True,
            dataloader_num_workers=4,
            dataloader_pin_memory=True,
            gradient_checkpointing=True,
            torch_compile=False,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1
        )

        dataset_iterable = IterableDataset.from_generator(process_stream_generator)

        trainer = Trainer(
            model=original_model,
            tokenizer=tokenizer,
            train_dataset=dataset_iterable,
            args=training_args,
            data_collator=data_collator,
            callbacks=[CustomTrainerCallback(job_id, hf_token, full_repo_id)]
        )

        job.set_progress(0.2, "Training: BitNet Gradient Descent Initiated...")
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        job.set_progress(0.9, "Processing: Finalizing Artifacts...")
        del original_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

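        # Merge user-supplied JSON from the Advanced Config panel into the
        # generated file of the same name; on key conflicts the user value wins.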
        def inject_json(content, fname):
            if content and content.strip():
                try:
                    data = json.loads(content)
                    file_path = os.path.join(output_dir, fname)

                    if os.path.exists(file_path):
                        with open(file_path, 'r', encoding='utf-8') as f:
                            try:
                                existing_data = json.load(f)
                                existing_data.update(data)
                                data = existing_data
                            except Exception:
                                pass

                    with open(file_path, 'w', encoding='utf-8') as f:
                        json.dump(data, f, indent=2)
                    job.add_log(f"Config: Merged user settings into {fname}")
                except Exception as e:
                    job.add_log(f"Config: Skipped {fname} ({e})")

        inject_json(c_conf, "config.json")
        inject_json(c_tok, "tokenizer_config.json")
        inject_json(c_gen, "generation_config.json")

        job.set_progress(0.95, "Network: Uploading Final BitNet Model...")

        upload_folder(
            folder_path=output_dir,
            path_in_repo=".",
            repo_id=full_repo_id,
            token=hf_token,
            commit_message="BitNet Scratch Trained Model"
        )

        job.repo_url = f"https://huggingface.co/{full_repo_id}"
        job.status = "COMPLETED"
        job.set_progress(1.0, "System: Operation Finalized")

    except Exception as e:
        job.status = "FAILED"
        job.error = str(e)
        job.add_log(f"FATAL ERROR: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

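# Gradio click handler: registers a job, hands the heavy lifting to a daemon
# thread, and returns the job id immediately while switching to the monitor tab.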
def start_training_wrapper(hf_token, model_name, new_repo_name,
                           train_steps, learning_rate, batch_size, datasets_text,
                           reasoning_mode, c_conf, c_tok, c_gen):

    if not hf_token or not model_name:
        return None, gr.update(selected="launch_tab")

    new_job = JobStatus()
    JOBS[new_job.id] = new_job

    thread = threading.Thread(
        target=background_train_task,
        args=(new_job.id, hf_token, model_name, new_repo_name,
              train_steps, learning_rate, batch_size, datasets_text, reasoning_mode, c_conf, c_tok, c_gen)
    )
    thread.daemon = True
    thread.start()

    return new_job.id, gr.update(selected="monitor_tab")

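# Poll handler for the refresh button and the timer; returns the five values
# bound to the telemetry widgets.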
def get_job_update(job_id):
    if not job_id:
        return "Waiting...", "", 0, "", gr.update(visible=False)

    if job_id not in JOBS:
        return "Not Found", "", 0, "", gr.update(visible=False)

    job = JOBS[job_id]
    log_text = "\n".join(job.logs)

    result_comp = gr.update(visible=False)
    if job.status == "COMPLETED" and job.repo_url:
        result_comp = gr.update(visible=True, value=f"✅ Full Model Published: {job.repo_url}")

    return job.status, job.created_at, job.progress, log_text, result_comp

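# Deep-link support: opening the page with ?job_id=... jumps straight to the
# telemetry tab for that job.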
def load_from_url(request: gr.Request):
    try:
        params = request.query_params
        job_id = params.get("job_id")
        if job_id:
            return gr.update(selected="monitor_tab"), job_id
    except Exception:
        pass
    return gr.update(selected="launch_tab"), ""

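# Two-tab UI: a launchpad that collects hyperparameters and a telemetry tab
# that polls the job registry every two seconds.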
with gr.Blocks(title="Nucleus Enterprise") as demo:
    with gr.Column():
        gr.Markdown("# ⚛️ NUCLEUS ENTERPRISE")
        gr.Markdown("Autonomous LLM Foundry | V10.0 BitNet Edition")

    with gr.Tabs() as main_tabs:
        with gr.TabItem("🚀 LAUNCHPAD", id="launch_tab"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row():
                        hf_token = gr.Textbox(label="HuggingFace Token", type="password", value=os.getenv("HF_TOKEN", ""))
                        model_name = gr.Textbox(label="Architecture Config Source", value="Qwen/Qwen2.5-0.5B")

                    repo_name = gr.Textbox(label="Output Repository", value="nucleus-bitnet-v1")
                    # Named datasets_box to avoid shadowing the imported datasets module.
                    datasets_box = gr.Textbox(label="Datasets (CSV)", value="Salesforce/fineweb_deduplicated", lines=3)

                    reasoning = gr.Checkbox(label="Inject Reasoning (CoT/Math)", value=False)

                with gr.Column(scale=1):
                    steps = gr.Number(label="Steps", value=100)
                    lr = gr.Number(label="Learning Rate", value=1e-4)
                    batch = gr.Number(label="Batch Size", value=1)

                    with gr.Accordion("Advanced Config", open=False):
                        c_conf = gr.Code(label="config.json", language="json")
                        c_tok = gr.Code(label="tokenizer_config.json", language="json")
                        c_gen = gr.Code(label="generation_config.json", language="json")

            btn_launch = gr.Button("INITIALIZE BITNET TRAINING", variant="primary", size="lg")

        with gr.TabItem("📡 TELEMETRY", id="monitor_tab"):
            with gr.Row():
                job_id_input = gr.Textbox(label="Active Job ID", interactive=True)
                btn_refresh = gr.Button("Refresh Stream")

            with gr.Row():
                status_out = gr.Textbox(label="Status", interactive=False)
                time_out = gr.Textbox(label="Start Time", interactive=False)
                progress_out = gr.Slider(label="Progress", minimum=0, maximum=1)

            final_link = gr.Markdown(visible=False)
            logs_out = gr.Code(label="Real-time Kernel Logs", language="shell", interactive=False, lines=15)

    # gr.Timer's interval is in seconds, not milliseconds: poll every 2 s.
    timer = gr.Timer(2, active=False)

    demo.load(load_from_url, None, [main_tabs, job_id_input]).then(lambda: gr.Timer(active=True), None, timer)

    btn_launch.click(
        start_training_wrapper,
        inputs=[hf_token, model_name, repo_name, steps, lr, batch, datasets_box, reasoning, c_conf, c_tok, c_gen],
        outputs=[job_id_input, main_tabs]
    ).then(
        None, [job_id_input], None,
        js="(id) => { if (id) { const url = new URL(window.location); url.searchParams.set('job_id', id); window.history.pushState({}, '', url); } return id; }"
    ).then(
        lambda: gr.Timer(active=True), None, timer
    )

    btn_refresh.click(get_job_update, job_id_input, [status_out, time_out, progress_out, logs_out, final_link])
    timer.tick(get_job_update, job_id_input, [status_out, time_out, progress_out, logs_out, final_link])

if __name__ == "__main__":
    demo.launch(ssr_mode=False)