import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

# ======================
# Load MODELS
# ======================
ROBERTA_MODEL = "roberta-base-openai-detector"
DISTIL_MODEL = "distilroberta-base"

# Soft-voting weights applied to each model's (human, AI) probabilities.
# They are normalised in detect_text, so only their ratio matters.
# NOTE(review): "distilroberta-base" is a plain pretrained encoder without a
# sequence-classification head, so AutoModelForSequenceClassification creates
# a randomly initialised head for it (transformers logs a "newly initialized"
# weights warning at load time). Its output is therefore not a trained
# AI-vs-human signal and differs between restarts. Point DISTIL_MODEL at a
# checkpoint fine-tuned for this task, or set DISTIL_WEIGHT to 0.0.
ROBERTA_WEIGHT = 0.5
DISTIL_WEIGHT = 0.5

if ROBERTA_WEIGHT < 0 or DISTIL_WEIGHT < 0 or ROBERTA_WEIGHT + DISTIL_WEIGHT <= 0:
    raise ValueError("Ensemble weights must be non-negative and not all zero.")

# Lower-cased class names used to recognise each side in a model's id2label.
_HUMAN_LABELS = {"real", "human", "human-written", "human written"}
_AI_LABELS = {"fake", "ai", "ai-generated", "ai generated", "generated", "machine"}


def _class_indices(model, default_human=0, default_ai=1):
    """Return ``(human_idx, ai_idx)``: which logit of ``model`` is which class.

    Matches the names in ``model.config.id2label`` (case-insensitively)
    against _HUMAN_LABELS / _AI_LABELS instead of assuming a fixed order.
    Assumes a binary head: if only one side is named, the other side is the
    remaining index. Falls back to the defaults when neither side is named
    (e.g. generic "LABEL_0"/"LABEL_1" labels on an untrained head).
    """
    human_idx = ai_idx = None
    for idx, name in (getattr(model.config, "id2label", None) or {}).items():
        key = str(name).strip().lower()
        if key in _HUMAN_LABELS:
            human_idx = int(idx)
        elif key in _AI_LABELS:
            ai_idx = int(idx)
    if human_idx is None and ai_idx is None:
        return default_human, default_ai
    if human_idx is None:
        human_idx = 1 - ai_idx
    if ai_idx is None:
        ai_idx = 1 - human_idx
    return human_idx, ai_idx


# RoBERTa (AI Detector)
roberta_tokenizer = AutoTokenizer.from_pretrained(ROBERTA_MODEL)
roberta_model = AutoModelForSequenceClassification.from_pretrained(ROBERTA_MODEL).to(device)
roberta_model.eval()
# Read the class order from the checkpoint's own labels rather than assuming
# index 0 = human; a detector that orders its classes differently would
# otherwise have its verdict silently inverted.
ROBERTA_HUMAN_IDX, ROBERTA_AI_IDX = _class_indices(roberta_model)

# DistilRoBERTa (Auxiliary signal) -- see NOTE(review) on DISTIL_WEIGHT above.
distil_tokenizer = AutoTokenizer.from_pretrained(DISTIL_MODEL)
distil_model = AutoModelForSequenceClassification.from_pretrained(
    DISTIL_MODEL, num_labels=2
).to(device)
distil_model.eval()
DISTIL_HUMAN_IDX, DISTIL_AI_IDX = _class_indices(distil_model)


# ======================
# Prediction function
# ======================
def get_probs(tokenizer, model, text):
    """Return ``model``'s softmax class probabilities for ``text``.

    The input is truncated to 512 tokens, so only the beginning of a long
    text is scored. The result is a 1-D CPU tensor indexed like the model's
    logits; use the *_HUMAN_IDX / *_AI_IDX constants to pick a class.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    return probs.cpu()


def detect_text(text):
    """Classify ``text`` as human- or AI-written by weighted soft voting.

    Returns a ``(markdown_message, {"Human": p, "AI": p})`` pair for the two
    Gradio outputs, or ``("Please enter some text.", None)`` when the input
    is empty, whitespace-only, or None.
    """
    if not text or not text.strip():
        return "Please enter some text.", None

    # Individual model predictions
    roberta_probs = get_probs(roberta_tokenizer, roberta_model, text)
    distil_probs = get_probs(distil_tokenizer, distil_model, text)

    # ENSEMBLE (weighted soft voting): each model's probabilities are read
    # through its own human/AI index mapping before averaging.
    total_weight = ROBERTA_WEIGHT + DISTIL_WEIGHT
    human_prob = (
        ROBERTA_WEIGHT * roberta_probs[ROBERTA_HUMAN_IDX].item()
        + DISTIL_WEIGHT * distil_probs[DISTIL_HUMAN_IDX].item()
    ) / total_weight
    ai_prob = (
        ROBERTA_WEIGHT * roberta_probs[ROBERTA_AI_IDX].item()
        + DISTIL_WEIGHT * distil_probs[DISTIL_AI_IDX].item()
    ) / total_weight

    if ai_prob > human_prob:
        label = "🤖 **AI Generated**"
        confidence = ai_prob
    else:
        label = "🧑 **Human Written**"
        confidence = human_prob

    message = f"{label}\n\nConfidence: **{confidence*100:.2f}%**"

    return message, {
        "Human": round(human_prob, 4),
        "AI": round(ai_prob, 4),
    }


# ======================
# Gradio UI
# ======================
demo = gr.Interface(
    fn=detect_text,
    inputs=gr.Textbox(lines=8, placeholder="Paste your text here..."),
    outputs=[
        gr.Markdown(label="Result"),
        gr.Label(label="Probabilities"),
    ],
    title="AI Text Detector (Ensemble Model)",
    description="Ensemble of RoBERTa + DistilRoBERTa using soft voting.",
)

demo.launch()