Dusit-P committed on
Commit
79a9bb6
·
verified ·
1 Parent(s): 3b6a7f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -12
app.py CHANGED
@@ -1,15 +1,16 @@
1
  import os, json, importlib.util, tempfile, traceback, torch, re, math
 
2
  import torch.nn.functional as F
3
  import gradio as gr
4
  import pandas as pd
5
  import plotly.graph_objects as go
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
- from transformers import AutoTokenizer
9
 
10
  # ===== ปรับได้จาก Settings > Variables & secrets ของ Space =====
11
  REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
12
- DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # หรือ "baseline"
13
  HF_TOKEN = os.getenv("HF_TOKEN", None) # ถ้าโมเดลเป็น private ให้เพิ่ม secret ชื่อนี้
14
 
15
  # ---- theme colors (soft modern) ----
@@ -19,7 +20,7 @@ TEMPLATE = "plotly_white"
19
 
20
  CACHE = {}
21
 
22
- # ---------- load architecture & weights from model repo ----------
23
  def _import_models():
24
  if "models_module" in CACHE:
25
  return CACHE["models_module"]
@@ -30,20 +31,94 @@ def _import_models():
30
  CACHE["models_module"] = mod
31
  return mod
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def load_model(model_name: str):
34
  key = f"model:{model_name}"
35
  if key in CACHE:
36
  return CACHE[key]
 
37
  cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
38
  w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
39
 
40
  with open(cfg_path, "r", encoding="utf-8") as f:
41
  cfg = json.load(f)
42
 
43
- models = _import_models()
44
- tok = AutoTokenizer.from_pretrained(cfg["base_model"])
45
- model = models.create_model_by_name(cfg["arch"])
 
 
 
 
 
 
 
 
 
46
  state = load_file(w_path)
 
47
  model.load_state_dict(state, strict=True)
48
  model.eval()
49
 
@@ -163,7 +238,7 @@ def _shop_summary(out_df: pd.DataFrame, max_shops=15):
163
  g = g.sort_values("total", ascending=False)
164
 
165
  table = g[["total","positive","negative"]].copy()
166
- table["positive_rate(%)"] = (table["positive"] / table["total"] * 100).round(2)
167
  table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
168
  table = table.reset_index().rename(columns={"index":"shop"})
169
 
@@ -211,7 +286,7 @@ def predict_one(text: str, model_choice: str):
211
  s = _norm_text(text)
212
  if not _is_substantive_text(s):
213
  return {"negative": 0.0, "positive": 0.0}, "invalid"
214
- model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
215
  out = _predict_batch([s], model_name)[0]
216
  probs = {
217
  "negative": float(out["negative(%)"].rstrip("%"))/100.0,
@@ -225,7 +300,7 @@ def predict_one(text: str, model_choice: str):
225
 
226
  def predict_many(text_block: str, model_choice: str):
227
  try:
228
- model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
229
  raw_lines = (text_block or "").splitlines()
230
  trimmed = [_norm_text(ln) for ln in raw_lines if _norm_text(ln)]
231
  cleaned, skipped = _clean_texts(trimmed)
@@ -257,7 +332,7 @@ def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop
257
  if file_obj is None:
258
  return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"
259
 
260
- model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
261
  df = pd.read_csv(file_obj.name)
262
 
263
  auto_rev, auto_shop = _detect_cols(df)
@@ -333,10 +408,14 @@ def predict_csv(file_obj, model_choice: str, review_col_override: str = "", shop
333
  raise
334
 
335
  # ---------- Gradio UI ----------
 
 
 
 
336
  with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
337
- gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN Heads)")
338
 
339
- model_radio = gr.Radio(choices=["cnn_bilstm","baseline"], value=DEFAULT_MODEL, label="เลือกโมเดล")
340
 
341
  with gr.Tab("Single"):
342
  t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")
 
1
  import os, json, importlib.util, tempfile, traceback, torch, re, math
2
+ import torch.nn as nn
3
  import torch.nn.functional as F
4
  import gradio as gr
5
  import pandas as pd
6
  import plotly.graph_objects as go
7
  from huggingface_hub import hf_hub_download
8
  from safetensors.torch import load_file
9
+ from transformers import AutoTokenizer, AutoModel
10
 
11
  # ===== ปรับได้จาก Settings > Variables & secrets ของ Space =====
12
  REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
13
+ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # "cnn_bilstm" | "baseline" | "last4weighted_pure"
14
  HF_TOKEN = os.getenv("HF_TOKEN", None) # ถ้าโมเดลเป็น private ให้เพิ่ม secret ชื่อนี้
15
 
16
  # ---- theme colors (soft modern) ----
 
20
 
21
  CACHE = {}
22
 
23
+ # ---------- โหลดสถาปัตยกรรมจาก repo (common/models.py) ----------
24
  def _import_models():
25
  if "models_module" in CACHE:
26
  return CACHE["models_module"]
 
31
  CACHE["models_module"] = mod
32
  return mod
33
 
34
+ # ---------- Fallback เผื่อ common/models.py ยังไม่รู้จัก Model3 ----------
35
class _BaseHead(nn.Module):
    """Bidirectional-LSTM classification head with selectable pooling.

    The head runs a BiLSTM over a sequence of feature vectors, pools the
    LSTM outputs over the time axis, applies dropout, and projects the
    pooled vector to ``classes`` logits.

    Args:
        hidden_in: Feature size of each input time step fed to the LSTM.
        hidden_lstm: LSTM hidden size per direction; the pooled vector has
            ``2 * hidden_lstm`` features because the LSTM is bidirectional.
        classes: Number of output logits.
        dropout: Dropout probability applied to the pooled vector.
        pooling: One of ``'cls'`` (first time step), ``'masked_mean'`` or
            ``'masked_max'``.

    Raises:
        ValueError: If ``pooling`` is not one of the supported modes.
    """

    _POOLINGS = ('cls', 'masked_mean', 'masked_max')

    def __init__(self, hidden_in, hidden_lstm=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Validate before building any layer; an explicit raise (unlike
        # ``assert``) is still enforced when Python runs with -O.
        if pooling not in self._POOLINGS:
            raise ValueError(
                f"pooling must be one of {self._POOLINGS}, got {pooling!r}"
            )
        # Attribute names are part of the checkpoint's state_dict keys.
        self.lstm = nn.LSTM(hidden_in, hidden_lstm, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_lstm * 2, classes)
        self.pooling = pooling

    def _pool(self, x, mask):
        """Reduce ``x`` over the time axis (dim 1) according to ``self.pooling``.

        ``mask`` is assumed to be a (batch, seq_len) attention mask that is
        non-zero for real tokens and zero for padding -- TODO confirm
        against the caller.
        """
        if self.pooling == 'cls':
            return x[:, 0, :]
        keep = mask.unsqueeze(-1)
        if self.pooling == 'masked_mean':
            # Cast to the activations' dtype so the sum and the clamp below
            # are done in floating point whatever dtype the mask arrives in
            # (bool / int / float).
            weights = keep.to(x.dtype)
            total = (x * weights).sum(1)
            # The clamp keeps a fully-masked row from dividing by zero.
            count = weights.sum(1).clamp(min=1e-6)
            return total / count
        # masked_max: push padded positions far below any real activation.
        return x.masked_fill(keep == 0, -1e9).max(1).values

    def forward_after_bert(self, seq, mask):
        """Return logits for encoder features ``seq`` and their ``mask``."""
        x, _ = self.lstm(seq)
        x = self._pool(x, mask)
        return self.fc(self.dropout(x))
53
+
54
class _Model3PureLast4(nn.Module):
    """Last-4 weighted (Pure): the LSTM head consumes the encoder features directly.

    The last four hidden-state layers of the encoder are blended with
    learned, softmax-normalised weights and handed to ``_BaseHead``.
    """

    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Attribute names (bert / w / head) are part of the state_dict keys.
        self.bert = AutoModel.from_pretrained(base_model)
        self.w = nn.Parameter(torch.ones(4))
        encoder_dim = self.bert.config.hidden_size
        self.head = _BaseHead(encoder_dim, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Return the head's output for token ``ids`` and attention ``mask``."""
        encoded = self.bert(
            input_ids=ids,
            attention_mask=mask,
            output_hidden_states=True,
        )
        top_layers = encoded.hidden_states[-4:]
        layer_weights = F.softmax(self.w, dim=0)
        # Weighted sum of the four layers; same accumulation order as
        # indexing 0..3.
        blended = sum(
            weight * layer for weight, layer in zip(layer_weights, top_layers)
        )
        return self.head.forward_after_bert(blended, mask)
68
+
69
class _Model3ConvLast4(nn.Module):
    """Last-4 weighted + Conv1d (-> 128 channels): the LSTM head consumes 128 features.

    The last four hidden-state layers of the encoder are blended with
    learned, softmax-normalised weights, passed through two ReLU-activated
    1-D convolutions, and handed to ``_BaseHead``.
    """

    def __init__(self, base_model, hidden=128, classes=2, dropout=0.3, pooling='masked_mean'):
        super().__init__()
        # Attribute names (bert / w / c1 / c2 / head) are part of the
        # state_dict keys.
        self.bert = AutoModel.from_pretrained(base_model)
        self.w = nn.Parameter(torch.ones(4))
        encoder_dim = self.bert.config.hidden_size
        # Kernel sizes 3 and 5 with padding 1 and 2 keep the sequence length.
        self.c1 = nn.Conv1d(encoder_dim, 128, 3, padding=1)
        self.c2 = nn.Conv1d(128, 128, 5, padding=2)
        self.head = _BaseHead(128, hidden, classes, dropout, pooling)

    def forward(self, ids, mask):
        """Return the head's output for token ``ids`` and attention ``mask``."""
        encoded = self.bert(
            input_ids=ids,
            attention_mask=mask,
            output_hidden_states=True,
        )
        top_layers = encoded.hidden_states[-4:]
        layer_weights = F.softmax(self.w, dim=0)
        blended = sum(
            weight * layer for weight, layer in zip(layer_weights, top_layers)
        )
        # Conv1d convolves over the last axis, so put channels on dim 1
        # first and restore the layout for the head afterwards.
        channels_first = blended.transpose(1, 2)
        conv_a = F.relu(self.c1(channels_first))
        conv_b = F.relu(self.c2(conv_a))
        features = conv_b.transpose(1, 2)
        return self.head.forward_after_bert(features, mask)
87
+
88
def _create_model_fallback(arch: str, base_model: str):
    """Build a locally defined fallback model chosen by the ``arch`` name from config.json.

    Args:
        arch: Architecture name; matched against the known aliases below.
        base_model: Passed through unchanged to the selected model class.

    Returns:
        A ``_Model3PureLast4`` or ``_Model3ConvLast4`` instance.

    Raises:
        ValueError: If ``arch`` matches none of the known aliases.
    """
    pure_aliases = ("Model3_Pure_Last4Weighted", "last4weighted_pure", "last4_pure")
    conv_aliases = ("Model3_MLP_Last4Weighted", "last4weighted")

    if arch in pure_aliases:
        builder = _Model3PureLast4
    elif arch in conv_aliases:
        builder = _Model3ConvLast4
    else:
        raise ValueError(f"No fallback available for arch={arch}")
    return builder(base_model)
95
+
96
+ # ---------- โหลดโมเดลจากโฟลเดอร์ใน repo (เช่น cnn_bilstm/, baseline/, last4weighted_pure/) ----------
97
  def load_model(model_name: str):
98
  key = f"model:{model_name}"
99
  if key in CACHE:
100
  return CACHE[key]
101
+
102
  cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
103
  w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
104
 
105
  with open(cfg_path, "r", encoding="utf-8") as f:
106
  cfg = json.load(f)
107
 
108
+ base_model = cfg.get("base_model", "airesearch/wangchanberta-base-att-spm-uncased")
109
+ arch_name = cfg.get("arch", "")
110
+ tok = AutoTokenizer.from_pretrained(base_model)
111
+
112
+ # พยายามสร้างจาก common/models.py ก่อน ถ้าไม่สำเร็จค่อย fallback
113
+ try:
114
+ models = _import_models()
115
+ model = models.create_model_by_name(arch_name)
116
+ except Exception as e:
117
+ print(f"[INFO] Using fallback for arch={arch_name} ({e})")
118
+ model = _create_model_fallback(arch_name, base_model)
119
+
120
  state = load_file(w_path)
121
+ # ใช้ strict=True ถ้า key ตรง; ถ้าอยากกัน edge-case สามารถปรับเป็น strict=False ได้
122
  model.load_state_dict(state, strict=True)
123
  model.eval()
124
 
 
238
  g = g.sort_values("total", ascending=False)
239
 
240
  table = g[["total","positive","negative"]].copy()
241
+ table["positive_rate(%))"] = (table["positive"] / table["total"] * 100).round(2)
242
  table["negative_rate(%)"] = (table["negative"] / table["total"] * 100).round(2)
243
  table = table.reset_index().rename(columns={"index":"shop"})
244
 
 
286
  s = _norm_text(text)
287
  if not _is_substantive_text(s):
288
  return {"negative": 0.0, "positive": 0.0}, "invalid"
289
+ model_name = model_choice # ใช้ชื่อโฟลเดอร์โดยตรง
290
  out = _predict_batch([s], model_name)[0]
291
  probs = {
292
  "negative": float(out["negative(%)"].rstrip("%"))/100.0,
 
300
 
301
  def predict_many(text_block: str, model_choice: str):
302
  try:
303
+ model_name = model_choice # ใช้ชื่อโฟลเดอร์โดยตรง
304
  raw_lines = (text_block or "").splitlines()
305
  trimmed = [_norm_text(ln) for ln in raw_lines if _norm_text(ln)]
306
  cleaned, skipped = _clean_texts(trimmed)
 
332
  if file_obj is None:
333
  return pd.DataFrame(), None, go.Figure(), go.Figure(), go.Figure(), pd.DataFrame(), "กรุณาอัปโหลดไฟล์ CSV"
334
 
335
+ model_name = model_choice # ใช้ชื่อโฟลเดอร์โดยตรง
336
  df = pd.read_csv(file_obj.name)
337
 
338
  auto_rev, auto_shop = _detect_cols(df)
 
408
  raise
409
 
410
  # ---------- Gradio UI ----------
411
+ AVAILABLE_CHOICES = ["cnn_bilstm", "baseline", "last4weighted_pure"] # เพิ่มชื่อโฟลเดอร์โมเดลใหม่ที่คุณอัปจริง
412
+ if DEFAULT_MODEL not in AVAILABLE_CHOICES:
413
+ DEFAULT_MODEL = "cnn_bilstm"
414
+
415
  with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
416
+ gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM/CNN/Last4 Heads)")
417
 
418
+ model_radio = gr.Radio(choices=AVAILABLE_CHOICES, value=DEFAULT_MODEL, label="เลือกโมเดล")
419
 
420
  with gr.Tab("Single"):
421
  t1 = gr.Textbox(lines=3, label="ข้อความรีวิว (1 ข้อความ)")