"""Gradio app serving Thai sentiment models (WangchanBERTa + LSTM/CNN heads).

Model weights, configs and the architecture module are downloaded from the
Hugging Face Hub repository named by ``REPO_ID``. The UI offers three modes:
single text, multi-line batch, and CSV upload with per-shop aggregation.
"""
import importlib.util
import json
import math
import os
import re
import tempfile
import traceback

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

# ===== Configurable from the Space's Settings > Variables & secrets =====
REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm")  # "cnn_bilstm" | "baseline" | "last4weighted_pure"
HF_TOKEN = os.getenv("HF_TOKEN", None)  # if the model repo is private, add a secret with this name

# ---- theme colors (soft modern) ----
NEG_COLOR = os.getenv("NEG_COLOR", "#F87171")  # red-400 (soft)
POS_COLOR = os.getenv("POS_COLOR", "#34D399")  # emerald-400 (soft)
TEMPLATE = "plotly_white"

# Process-wide cache: the imported architecture module and loaded models.
CACHE = {}

# Column layouts shared by every function that builds a result / per-shop table,
# so the empty and populated variants can never drift apart.
_RESULT_COLUMNS = ["review", "negative(%)", "positive(%)", "label"]
_SHOP_COLUMNS = ["shop", "total", "positive", "negative", "positive_rate(%)", "negative_rate(%)"]


# ---------- Load the architectures from the repo (common/models.py) ----------
def _import_models():
    """Download ``common/models.py`` from the Hub repo and import it as a module.

    The imported module is memoized in ``CACHE["models_module"]``.

    NOTE(security): this executes Python code fetched from ``REPO_ID``. Only
    point ``REPO_ID`` at a repository you trust.
    """
    if "models_module" in CACHE:
        return CACHE["models_module"]
    models_py = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
    spec = importlib.util.spec_from_file_location("models", models_py)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    CACHE["models_module"] = mod
    return mod


# ---------- Fallback in case common/models.py does not know Model3 yet ----------
class _BaseHead(nn.Module):
    """BiLSTM -> pooling -> dropout -> linear head applied to encoder outputs.

    Args:
        hidden_in: feature size of the incoming sequence.
        hidden_lstm: hidden size per LSTM direction (output is twice this).
        classes: number of output logits.
        dropout: dropout probability applied before the linear layer.
        pooling: one of ``'cls'``, ``'masked_mean'``, ``'masked_max'``.

    Raises:
        ValueError: if ``pooling`` is not one of the supported modes.
    """

    _POOLINGS = ("cls", "masked_mean", "masked_max")

    def __init__(self, hidden_in, hidden_lstm=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if pooling not in self._POOLINGS:
            raise ValueError(f"pooling must be one of {self._POOLINGS}, got {pooling!r}")
        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_lstm * 2, classes)
        self.pooling = pooling

    def _pool(self, x, mask):
        """Reduce the time dimension of ``x`` according to ``self.pooling``."""
        if self.pooling == 'cls':
            return x[:, 0, :]
        # Cast the mask to the activation dtype so clamp(min=1e-6) below is a
        # float operation regardless of how an integer mask would be promoted.
        mask = mask.unsqueeze(-1).to(x.dtype)
        if self.pooling == 'masked_mean':
            total = (x * mask).sum(1)
            count = mask.sum(1).clamp(min=1e-6)  # avoid division by zero
            return total / count
        # masked_max: push masked-out positions far below any real activation.
        return x.masked_fill(mask == 0, -1e9).max(1).values

    def forward_after_bert(self, seq, mask):
        """Run LSTM, pooling, dropout and the linear layer; return the logits."""
        x, _ = self.lstm(seq)
        x = self._pool(x, mask)
        return self.fc(self.dropout(x))


def _weighted_last4(hidden_states, raw_weights):
    """Softmax-weighted sum of the last four entries of ``hidden_states``."""
    last4 = hidden_states[-4:]
    w = F.softmax(raw_weights, dim=0)
    return sum(w[i] * last4[i] for i in range(4))


class _Model3PureLast4(nn.Module):
    """Last-4 weighted (Pure): the LSTM consumes the encoder hidden size directly."""

    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        self.bert = AutoModel.from_pretrained(base_model)
        self.w = nn.Parameter(torch.ones(4))  # learnable mixing weights for the last 4 layers
        H = self.bert.config.hidden_size
        self.head = _BaseHead(H, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        seq = _weighted_last4(out.hidden_states, self.w)  # [B, T, H]
        return self.head.forward_after_bert(seq, mask)


class _Model3ConvLast4(nn.Module):
    """Last-4 weighted + Conv1d (-> 128 channels): the LSTM consumes 128 features."""

    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        self.bert = AutoModel.from_pretrained(base_model)
        self.w = nn.Parameter(torch.ones(4))  # learnable mixing weights for the last 4 layers
        H = self.bert.config.hidden_size
        self.c1 = nn.Conv1d(H, 128, 3, padding=1)
        self.c2 = nn.Conv1d(128, 128, 5, padding=2)
        self.head = _BaseHead(128, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
        seq = _weighted_last4(out.hidden_states, self.w)  # [B, T, H]
        # Conv1d expects channels-first, so transpose to [B, H, T] and back.
        x = F.relu(self.c1(seq.transpose(1, 2)))
        x = F.relu(self.c2(x)).transpose(1, 2)  # [B, T, 128]
        return self.head.forward_after_bert(x, mask)


def _create_model_fallback(arch: str, base_model: str):
    """Pick a fallback architecture from the ``arch`` name found in config.json.

    Raises:
        ValueError: if no fallback class is registered for ``arch``.
    """
    if arch in ("Model3_Pure_Last4Weighted", "last4weighted_pure", "last4_pure"):
        return _Model3PureLast4(base_model)
    if arch in ("Model3_MLP_Last4Weighted", "last4weighted"):
        return _Model3ConvLast4(base_model)
    raise ValueError(f"No fallback available for arch={arch}")


# ---------- Load a model from a folder in the repo (e.g. cnn_bilstm/, baseline/, last4weighted_pure/) ----------
def load_model(model_name: str):
    """Download, build and cache the model stored under ``model_name/`` in the repo.

    Returns:
        A ``(model, tokenizer, cfg)`` tuple; ``model`` is in eval mode and
        ``cfg`` is the parsed ``config.json``.
    """
    key = f"model:{model_name}"
    if key in CACHE:
        return CACHE[key]

    cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
    w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
    arch_name = cfg.get("arch", "")

    tok = AutoTokenizer.from_pretrained(base_model)

    # Try to build from common/models.py first; fall back to the local classes on failure.
    try:
        models = _import_models()
        model = models.create_model_by_name(arch_name)
    except Exception as e:
        print(f"[INFO] Using fallback for arch={arch_name} ({e})")
        model = _create_model_fallback(arch_name, base_model)

    state = load_file(w_path)
    # strict=True requires the checkpoint keys to match exactly; switch to
    # strict=False only if you deliberately want to tolerate mismatches.
    model.load_state_dict(state, strict=True)
    model.eval()

    CACHE[key] = (model, tok, cfg)
    return CACHE[key]


# ---------- helpers ----------
def _format_pct(x: float) -> str:
    """Format a probability in [0, 1] as a percentage string with two decimals."""
    return f"{x*100:.2f}%"


# ====== Filter for non-review text / empty values / symbols ======
_INVALID_STRINGS = {"-", "--", "—", "n/a", "na", "null", "none", "nan", ".", "…", ""}  # lower-case
_RE_HAS_LETTER = re.compile(r"[ก-๙A-Za-z]")  # must contain at least one Thai or English letter


def _norm_text(v) -> str:
    """Convert any value to a stripped string, mapping None / NaN / pd.NA to ''."""
    if v is None or v is pd.NA:
        # str(pd.NA) is "<NA>", which would otherwise pass the letter check.
        return ""
    if isinstance(v, float) and math.isnan(v):
        return ""
    return str(v).strip()


def _is_substantive_text(s: str, min_chars: int = 2) -> bool:
    """Return True when ``s`` looks like text worth analysing.

    A string is rejected when it is empty, is a known placeholder, contains no
    Thai/English letter, or has fewer than ``min_chars`` non-space characters.
    """
    if not s:
        return False
    if s.lower() in _INVALID_STRINGS:
        return False
    if not _RE_HAS_LETTER.search(s):
        return False
    if len(s.replace(" ", "")) < min_chars:
        return False
    return True


def _clean_texts(texts):
    """Normalize any iterable of values and return ``(usable_texts, skipped_count)``."""
    all_norm = [_norm_text(t) for t in texts]
    cleaned = [t for t in all_norm if _is_substantive_text(t)]
    skipped = len(all_norm) - len(cleaned)
    return cleaned, skipped


def _detect_cols(df: pd.DataFrame):
    """Guess the review / shop column names.

    If no known review column name is present, the first object-dtype column
    is used. Either element of the returned ``(review_col, shop_col)`` pair
    may be ``None``.
    """
    rev_cands = ["review", "text", "comment", "content", "message", "ข้อความ", "รีวิว"]
    shop_cands = ["shop", "shop_name", "store", "restaurant", "brand", "merchant", "ชื่อร้าน"]

    review_col = next((c for c in rev_cands if c in df.columns), None)
    shop_col = next((c for c in shop_cands if c in df.columns), None)

    if review_col is None:
        obj_cols = [c for c in df.columns if df[c].dtype == object]
        if obj_cols:
            review_col = obj_cols[0]
    return review_col, shop_col


def _summarize_df(df: pd.DataFrame):
    """Summarize label counts and the average confidence percentages.

    Returns:
        A dict with keys ``total``, ``neg``, ``pos``, ``neg_avg``, ``pos_avg``
        and ``md`` (a Markdown rendering of the same numbers).
    """
    total = len(df)
    neg = int((df["label"] == "negative").sum())
    pos = int((df["label"] == "positive").sum())
    # The percentage columns hold strings such as "12.34%"; strip the sign first.
    neg_avg = pd.to_numeric(df["negative(%)"].str.rstrip("%"), errors="coerce").mean()
    pos_avg = pd.to_numeric(df["positive(%)"].str.rstrip("%"), errors="coerce").mean()
    info = (
        f"**Summary** \n"
        f"- Total: {total} \n"
        f"- Negative: {neg} \n"
        f"- Positive: {pos} \n"
        f"- Avg negative: {neg_avg:.2f}% \n"
        f"- Avg positive: {pos_avg:.2f}%"
    )
    return {"total": total, "neg": neg, "pos": pos,
            "neg_avg": neg_avg, "pos_avg": pos_avg, "md": info}


def _make_figures(df: pd.DataFrame):
    """Build the label-count bar chart and pie chart; return ``(bar, pie, markdown)``."""
    s = _summarize_df(df)

    # --- BAR: two traces with fixed colors ---
    fig_bar = go.Figure()
    fig_bar.add_bar(name="negative", x=["negative"], y=[s["neg"]], marker_color=NEG_COLOR)
    fig_bar.add_bar(name="positive", x=["positive"], y=[s["pos"]], marker_color=POS_COLOR)
    fig_bar.update_layout(
        barmode="group",
        title="Label counts",
        xaxis_title="label",
        yaxis_title="count",
        template=TEMPLATE,
        legend_title="label",
    )

    # --- PIE: colors consistent with the bar chart ---
    fig_pie = go.Figure(
        go.Pie(
            labels=["negative", "positive"],
            values=[s["neg"], s["pos"]],
            hole=0.35,
            sort=False,
            marker=dict(colors=[NEG_COLOR, POS_COLOR]),
        )
    )
    fig_pie.update_layout(title="Label share", template=TEMPLATE)

    return fig_bar, fig_pie, s["md"]


def _shop_summary(out_df: pd.DataFrame, max_shops=15):
    """Per-shop summary: a table plus a stacked pos/neg bar chart (fixed colors).

    Args:
        out_df: result frame; must have a ``label`` column and, for a
            non-empty summary, a ``shop`` column.
        max_shops: number of shops (by total count) shown in the chart.

    Returns:
        ``(figure, table)``; the table always uses the ``_SHOP_COLUMNS`` names.
    """
    if "shop" not in out_df.columns:
        return go.Figure(), pd.DataFrame(columns=_SHOP_COLUMNS)

    g = out_df.groupby("shop")["label"].value_counts().unstack(fill_value=0)
    # A label that never occurs produces no column; add it so sums below work.
    for col in ["positive", "negative"]:
        if col not in g.columns:
            g[col] = 0
    g["total"] = g["positive"] + g["negative"]
    g = g.sort_values("total", ascending=False)

    table = g[["total", "positive", "negative"]].copy()
    # Column name must match the empty-table schema exactly.
    table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
    table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
    table = table.reset_index().rename(columns={"index": "shop"})

    # The chart shows the top N shops only.
    top = table.head(max_shops)
    fig = go.Figure()
    fig.add_bar(name="positive", x=top["shop"], y=top["positive"], marker_color=POS_COLOR)
    fig.add_bar(name="negative", x=top["shop"], y=top["negative"], marker_color=NEG_COLOR)
    fig.update_layout(
        barmode="stack",
        title=f"Per-shop counts (top {len(top)})",
        xaxis_title="shop",
        yaxis_title="count",
        legend_title="label",
        template=TEMPLATE,
        xaxis=dict(tickangle=-30),
    )
    return fig, table


# ---------- core prediction ----------
def _predict_batch(texts, model_name, batch_size=64):
    """Classify already-filtered texts in mini-batches.

    Args:
        texts: list of strings that passed ``_is_substantive_text``.
        model_name: folder name of the model in the repo.
        batch_size: number of texts per forward pass.

    Returns:
        One dict per input text with keys ``review``, ``negative(%)``,
        ``positive(%)`` and ``label``, in input order.
    """
    model, tok, cfg = load_model(model_name)
    results = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        enc = tok(chunk, padding=True, truncation=True,
                  max_length=cfg["max_len"], return_tensors="pt")
        with torch.no_grad():
            logits = model(enc["input_ids"], enc["attention_mask"])
            probs = F.softmax(logits, dim=1).cpu().numpy()
        for txt, p in zip(chunk, probs):
            neg, pos = float(p[0]), float(p[1])  # column 0 = negative, column 1 = positive
            label = "positive" if pos >= neg else "negative"
            results.append({
                "review": txt,
                "negative(%)": _format_pct(neg),
                "positive(%)": _format_pct(pos),
                "label": label,
            })
    return results


# ---------- API wrappers ----------
def predict_one(text: str, model_choice: str):
    """Classify a single text; return ``(probabilities_dict, label)``.

    Non-substantive input yields zero probabilities and the label ``"invalid"``.
    """
    try:
        s = _norm_text(text)
        if not _is_substantive_text(s):
            return {"negative": 0.0, "positive": 0.0}, "invalid"
        model_name = model_choice  # the folder name is used directly
        out = _predict_batch([s], model_name)[0]
        probs = {
            "negative": float(out["negative(%)"].rstrip("%")) / 100.0,
            "positive": float(out["positive(%)"].rstrip("%")) / 100.0,
        }
        return probs, out["label"]
    except Exception as e:
        print("ERROR in predict_one:", repr(e))
        traceback.print_exc()
        raise


def predict_many(text_block: str, model_choice: str):
    """Classify one review per line; return ``(table, bar_fig, pie_fig, markdown)``."""
    try:
        model_name = model_choice  # the folder name is used directly
        raw_lines = (text_block or "").splitlines()
        # Normalize each line once and drop the ones that end up empty.
        trimmed = [s for s in (_norm_text(ln) for ln in raw_lines) if s]
        cleaned, skipped = _clean_texts(trimmed)

        if not cleaned:
            empty = pd.DataFrame(columns=_RESULT_COLUMNS)
            return empty, go.Figure(), go.Figure(), "No valid text"

        results = _predict_batch(cleaned, model_name)
        df = pd.DataFrame(results, columns=_RESULT_COLUMNS)
        fig_bar, fig_pie, info_md = _make_figures(df)
        info_md = f"{info_md} \n- Skipped (empty/non-text): {skipped}"
        return df, fig_bar, fig_pie, info_md
    except Exception as e:
        print("ERROR in predict_many:", repr(e))
        traceback.print_exc()
        raise


def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop_col_override: str = ""):
    """Classify the review column of an uploaded CSV.

    Behavior:
    - No rows are dropped: invalid rows stay, in the original file order.
    - ``review`` of an invalid row is NA and no prediction is computed for it.
    - ``shop`` keeps the values from the original file (not cast to string).
    - Charts and summaries are computed from the valid rows only.

    Args:
        file_obj: uploaded file; either an object with a ``.name`` path
            attribute or a plain path string.
        model_choice: folder name of the model in the repo.
        review_col_override: review column name; blank means auto-detect.
        shop_col_override: shop column name; blank means auto-detect.

    Returns:
        ``(result_df, download_path, bar_fig, pie_fig, shop_fig, shop_table, markdown)``.

    Raises:
        ValueError: if the review column cannot be found in the CSV.
    """
    try:
        if file_obj is None:
            return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"

        model_name = model_choice  # the folder name is used directly

        # Accept both a file-like wrapper exposing `.name` and a bare path string.
        csv_path = getattr(file_obj, "name", file_obj)
        df = pd.read_csv(csv_path)
        auto_rev, auto_shop = _detect_cols(df)
        rev_col = (review_col_override or "").strip() or auto_rev
        shop_col = (shop_col_override or "").strip() or auto_shop

        if rev_col not in df.columns:
            raise ValueError(f"ไม่พบคอลัมน์รีวิว '{rev_col}' ใน CSV (columns = {list(df.columns)})")

        has_shop = bool(shop_col) and shop_col in df.columns

        # === Normalize reviews and mask the rows that actually have content ===
        reviews_norm = df[rev_col].apply(_norm_text)
        mask_valid = reviews_norm.apply(_is_substantive_text)
        idx_valid = df.index[mask_valid].tolist()
        skipped = int((~mask_valid).sum())

        # === Predict only for the valid rows ===
        results = []
        if idx_valid:
            texts_valid = reviews_norm.loc[idx_valid].tolist()
            results = _predict_batch(texts_valid, model_name)  # list[dict], in idx_valid order

        # === Build the result frame with every row, in the original order ===
        out = pd.DataFrame(index=df.index, columns=_RESULT_COLUMNS)

        # review: valid -> normalized text, invalid -> NA
        out.loc[idx_valid, "review"] = reviews_norm.loc[idx_valid].values
        out.loc[~mask_valid, "review"] = pd.NA

        # Write predictions back at the original index of each valid row.
        for idx, pred_row in zip(idx_valid, results):
            out.at[idx, "negative(%)"] = pred_row["negative(%)"]
            out.at[idx, "positive(%)"] = pred_row["positive(%)"]
            out.at[idx, "label"] = pred_row["label"]

        # Insert the shop column first (original values, no .astype(str)).
        if has_shop:
            out.insert(0, "shop", df[shop_col])
        else:
            out.insert(0, "shop", pd.Series([pd.NA] * len(out), index=out.index))

        # === Valid rows only, used for charts and summaries ===
        out_valid = out.loc[idx_valid].copy()

        # Downloadable result file -> all rows. Reserve the path, close the
        # handle, then write: avoids leaking the descriptor and works on
        # platforms that refuse to reopen a file that is still open.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
            out_path = tmp.name
        out.to_csv(out_path, index=False, encoding="utf-8-sig")

        if out_valid.empty:
            empty_fig = go.Figure()
            info_md = "ไม่พบรีวิวที่เป็นข้อความ\n- Skipped (empty/non-text): {}".format(skipped)
            empty_tbl = pd.DataFrame(columns=_SHOP_COLUMNS)
            return out, out_path, empty_fig, empty_fig, empty_fig, empty_tbl, info_md

        # Overall charts / summary (valid rows only)
        fig_bar, fig_pie, info_md = _make_figures(out_valid)

        # Per-shop chart / table (valid rows only)
        fig_shop, tbl_shop = _shop_summary(out_valid)

        # Append which columns were used and how many rows were skipped.
        info_md = (
            f"{info_md} \n"
            f"ใช้คอลัมน์รีวิว: {rev_col}"
            + (f" | คอลัมน์ร้าน: {shop_col}" if has_shop else " | ไม่มีคอลัมน์ร้าน")
            + f" \n- Skipped (empty/non-text): {skipped}"
        )

        return out, out_path, fig_bar, fig_pie, fig_shop, tbl_shop, info_md

    except Exception as e:
        print("ERROR in predict_csv:", repr(e))
        traceback.print_exc()
        raise


# ---------- Gradio UI ----------
# Add the folder name of any new model you have actually uploaded.
# NOTE(review): the DEFAULT_MODEL comment mentions "last4weighted_pure" while this
# list offers "last4weighted_bilstm" - confirm which folder exists in the repo.
AVAILABLE_CHOICES = ["cnn_bilstm", "baseline", "last4weighted_bilstm"]
if DEFAULT_MODEL not in AVAILABLE_CHOICES:
    DEFAULT_MODEL = "cnn_bilstm"

with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
    gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN/Last4 Heads)")

    model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")

    with gr.Tab("Single"):
        t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
        probs = gr.Label(label="Probabilities")
        pred = gr.Textbox(label="Prediction", interactive=False)
        gr.Button("Predict").click(predict_one, [t1, model_radio], [probs, pred])

    with gr.Tab("Batch (หลายข้อความ)"):
        t2 = gr.Textbox(lines=8, label="พิมพ์หลายรีวิว (บรรทัดละ 1 รีวิว)")
        df2 = gr.Dataframe(label="ผลลัพธ์", interactive=False)
        bar2 = gr.Plot(label="Label counts (bar)")
        pie2 = gr.Plot(label="Label share (pie)")
        sum2 = gr.Markdown()
        gr.Button("Run Batch").click(predict_many, [t2, model_radio], [df2, bar2, pie2, sum2])

    with gr.Tab("CSV (auto-detect columns)"):
        f = gr.File(label="อัปโหลด CSV", file_types=[".csv"])
        review_col_inp = gr.Textbox(label="ชื่อคอลัมน์รีวิว (เว้นว่างให้เดาได้)")
        shop_col_inp = gr.Textbox(label="ชื่อคอลัมน์ร้าน (เว้นว่างได้)")
        df3 = gr.Dataframe(label="ผลลัพธ์ CSV", interactive=False)
        download = gr.File(label="ดาวน์โหลดผลลัพธ์")
        bar3 = gr.Plot(label="Label counts (bar)")
        pie3 = gr.Plot(label="Label share (pie)")
        shop_bar = gr.Plot(label="Per-shop stacked bar")
        shop_tbl = gr.Dataframe(label="Per-shop summary", interactive=False)
        info = gr.Markdown()
        gr.Button("Run CSV").click(
            predict_csv,
            inputs=[f, model_radio, review_col_inp, shop_col_inp],
            outputs=[df3, download, bar3, pie3, shop_bar, shop_tbl, info]
        )

if __name__ == "__main__":
    demo.launch()