Update app.py
app.py CHANGED

@@ -1,15 +1,16 @@
 import os, json, importlib.util, tempfile, traceback, torch, re, math
+import torch.nn as nn
 import torch.nn.functional as F
 import gradio as gr
 import pandas as pd
 import plotly.graph_objects as go
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModel

 # ===== Adjustable from the Space's Settings > Variables & secrets =====
 REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
-DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") #
+DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # "cnn_bilstm" | "baseline" | "last4weighted_pure"
 HF_TOKEN = os.getenv("HF_TOKEN", None) # if the model repo is private, add a secret with this name

 # ---- theme colors (soft modern) ----
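
These three settings are read from environment variables, so the Space can be repointed without touching the code. A minimal sketch of overriding them locally; the override values here are illustrative only:

    import os

    # Must be set before app.py runs, since it calls os.getenv at import time.
    os.environ["REPO_ID"] = "Dusit-P/thai-sentiment-wcb"
    os.environ["DEFAULT_MODEL"] = "baseline"   # any model folder that exists in the repo
    os.environ["HF_TOKEN"] = "hf_..."          # only needed when the model repo is private

    import app  # app.py now picks up the overridden values
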
@@ -19,7 +20,7 @@ TEMPLATE = "plotly_white"

 CACHE = {}

-# ----------
+# ---------- Load the model architectures from the repo (common/models.py) ----------
 def _import_models():
     if "models_module" in CACHE:
         return CACHE["models_module"]
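
_import_models memoizes the dynamically imported module in the process-wide CACHE dict; its body sits between these hunks and is unchanged here. For orientation, a minimal sketch of the usual importlib.util pattern for loading a file such as common/models.py (the filename and download call are assumptions, not the verbatim body):

    import importlib.util
    from huggingface_hub import hf_hub_download

    def _import_models_sketch(repo_id, token=None):
        # Fetch common/models.py from the model repo (path assumed).
        path = hf_hub_download(repo_id, filename="common/models.py", token=token)
        spec = importlib.util.spec_from_file_location("models", path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)  # executes the file, populating the module
        return mod
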
@@ -30,20 +31,94 @@ def _import_models():
     CACHE["models_module"] = mod
     return mod

+# ---------- Fallback in case common/models.py does not know Model3 yet ----------
+class _BaseHead(nn.Module):
+    def __init__(self, hidden_in, hidden_lstm=128, classes=2, dropout=0.3, pooling='masked_mean'):
+        super().__init__()
+        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
+        self.dropout = nn.Dropout(dropout)
+        self.fc = nn.Linear(hidden_lstm*2, classes)
+        assert pooling in ['cls','masked_mean','masked_max']
+        self.pooling = pooling
+    def _pool(self, x, mask):
+        if self.pooling=='cls': return x[:,0,:]
+        mask = mask.unsqueeze(-1)
+        if self.pooling=='masked_mean':
+            s=(x*mask).sum(1); d=mask.sum(1).clamp(min=1e-6); return s/d
+        x=x.masked_fill(mask==0,-1e9); return x.max(1).values
+    def forward_after_bert(self, seq, mask):
+        x,_ = self.lstm(seq)
+        x = self._pool(x, mask)
+        return self.fc(self.dropout(x))
+
+class _Model3PureLast4(nn.Module):
+    """Last-4 weighted (pure): the LSTM consumes the 768-dim states from BERT"""
+    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained(base_model)
+        self.w = nn.Parameter(torch.ones(4))
+        H = self.bert.config.hidden_size
+        self.head = _BaseHead(H, hidden, classes, dropout, pooling)
+    def forward(self, ids, mask):
+        out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
+        last4 = out.hidden_states[-4:]
+        w = F.softmax(self.w, dim=0)
+        seq = sum(w[i]*last4[i] for i in range(4)) # [B,T,768]
+        return self.head.forward_after_bert(seq, mask)
+
+class _Model3ConvLast4(nn.Module):
+    """Last-4 weighted + Conv1d(→128): the LSTM consumes 128-dim features"""
+    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
+        super().__init__()
+        self.bert = AutoModel.from_pretrained(base_model)
+        self.w = nn.Parameter(torch.ones(4))
+        H = self.bert.config.hidden_size
+        self.c1 = nn.Conv1d(H,128,3,padding=1)
+        self.c2 = nn.Conv1d(128,128,5,padding=2)
+        self.head = _BaseHead(128, hidden, classes, dropout, pooling)
+    def forward(self, ids, mask):
+        out = self.bert(input_ids=ids, attention_mask=mask, output_hidden_states=True)
+        last4 = out.hidden_states[-4:]
+        w = F.softmax(self.w, dim=0)
+        seq = sum(w[i]*last4[i] for i in range(4)) # [B,T,768]
+        x = F.relu(self.c1(seq.transpose(1,2)))
+        x = F.relu(self.c2(x)).transpose(1,2) # [B,T,128]
+        return self.head.forward_after_bert(x, mask)
+
+def _create_model_fallback(arch: str, base_model: str):
+    """Pick the fallback architecture from the arch name in config.json"""
+    if arch in ("Model3_Pure_Last4Weighted", "last4weighted_pure", "last4_pure"):
+        return _Model3PureLast4(base_model)
+    if arch in ("Model3_MLP_Last4Weighted", "last4weighted"):
+        return _Model3ConvLast4(base_model)
+    raise ValueError(f"No fallback available for arch={arch}")
+
+# ---------- Load a model from a repo folder (e.g. cnn_bilstm/, baseline/, last4weighted_pure/) ----------
 def load_model(model_name: str):
     key = f"model:{model_name}"
     if key in CACHE:
         return CACHE[key]
+
     cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
     w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)

     with open(cfg_path, "r", encoding="utf-8") as f:
         cfg = json.load(f)

-
-
-
+    base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
+    arch_name = cfg.get("arch", "")
+    tok = AutoTokenizer.from_pretrained(base_model)
+
+    # Try to build from common/models.py first; fall back only if that fails
+    try:
+        models = _import_models()
+        model = models.create_model_by_name(arch_name)
+    except Exception as e:
+        print(f"[INFO] Using fallback for arch={arch_name} ({e})")
+        model = _create_model_fallback(arch_name, base_model)
+
     state = load_file(w_path)
+    # strict=True when the keys match exactly; switch to strict=False to guard against edge cases
     model.load_state_dict(state, strict=True)
     model.eval()

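
Both fallback models blend the last four transformer layers with softmax-normalized learned weights before pooling; the two heads then differ only in whether the LSTM sees the raw 768-dim states or the 128-dim Conv1d features. A standalone shape check of the blending step, using dummy tensors (no model download involved):

    import torch
    import torch.nn.functional as F

    B, T, H = 2, 16, 768                              # batch, tokens, hidden size
    last4 = [torch.randn(B, T, H) for _ in range(4)]  # stand-in for out.hidden_states[-4:]
    w = F.softmax(torch.ones(4), dim=0)               # uniform at init; learned during training
    seq = sum(w[i] * last4[i] for i in range(4))
    assert seq.shape == (B, T, H)                     # [B,T,768], same shape as a single layer
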
@@ -163,7 +238,7 @@ def _shop_summary(out_df: pd.DataFrame, max_shops=15):
     g = g.sort_values("total", ascending=False)

     table = g[["total","positive","negative"]].copy()
-    table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
+    table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
     table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
     table = table.reset_index().rename(columns={"index":"shop"})

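
The rate columns are plain percentages over the per-shop counts. The same arithmetic on a toy frame (shop names invented):

    import pandas as pd

    g = pd.DataFrame({"total": [10, 4], "positive": [7, 1], "negative": [3, 3]},
                     index=["shop_a", "shop_b"])
    g["positive_rate(%)"] = (g["positive"] / g["total"] * 100).round(2)  # 70.0, 25.0
    g["negative_rate(%)"] = (g["negative"] / g["total"] * 100).round(2)  # 30.0, 75.0
    print(g.reset_index().rename(columns={"index": "shop"}))
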
@@ -211,7 +286,7 @@ def predict_one(text: str, model_choice: str):
     s = _norm_text(text)
     if not _is_substantive_text(s):
         return {"negative": 0.0, "positive": 0.0}, "invalid"
-    model_name =
+    model_name = model_choice # use the folder name directly
     out = _predict_batch([s], model_name)[0]
     probs = {
         "negative": float(out["negative(%)"].rstrip("%"))/100.0,
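
_predict_batch returns percentages as strings like "87.66%", so predict_one strips the percent sign and rescales to [0, 1]. The conversion in isolation (the dict mimics one _predict_batch row; the numbers are invented):

    out = {"negative(%)": "12.34%", "positive(%)": "87.66%"}
    probs = {
        "negative": float(out["negative(%)"].rstrip("%")) / 100.0,  # 0.1234
        "positive": float(out["positive(%)"].rstrip("%")) / 100.0,  # 0.8766
    }
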
@@ -225,7 +300,7 @@ def predict_one(text: str, model_choice: str):

 def predict_many(text_block: str, model_choice: str):
     try:
-        model_name =
+        model_name = model_choice # use the folder name directly
         raw_lines = (text_block or "").splitlines()
         trimmed = [_norm_text(ln) for ln in raw_lines if _norm_text(ln)]
         cleaned, skipped = _clean_texts(trimmed)
@@ -257,7 +332,7 @@ def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop
     if file_obj is None:
         return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"

-    model_name =
+    model_name = model_choice # use the folder name directly
     df = pd.read_csv(file_obj.name)

     auto_rev, auto_shop = _detect_cols(df)
@@ -333,10 +408,14 @@ def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop
         raise

 # ---------- Gradio UI ----------
+AVAILABLE_CHOICES = ["cnn_bilstm", "baseline", "last4weighted_pure"] # add the folder names of any new models you actually upload
+if DEFAULT_MODEL not in AVAILABLE_CHOICES:
+    DEFAULT_MODEL = "cnn_bilstm"
+
 with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
-    gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN Heads)")
+    gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN/Last4 Heads)")

-    model_radio = gr.Radio(choices=
+    model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")

     with gr.Tab("Single"):
         t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
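
The diff ends before the event wiring, so it does not show how model_radio reaches the prediction functions. A hedged sketch of the typical Blocks pattern, reusing the names above; the button and output components are invented for illustration:

    import gradio as gr  # plus predict_one, AVAILABLE_CHOICES, DEFAULT_MODEL from app.py

    with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
        model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")
        with gr.Tab("Single"):
            t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
            btn = gr.Button("Predict")                   # hypothetical
            probs_out = gr.Label(label="probabilities")  # hypothetical
            label_out = gr.Textbox(label="label")        # hypothetical
            # predict_one(text, model_choice) -> (probs dict, label string)
            btn.click(predict_one, inputs=[t1, model_radio], outputs=[probs_out, label_out])

    demo.launch()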