ma7583 committed on
Commit
7e2e79c
·
verified ·
1 Parent(s): 5a6719d

Upload 2 files

Browse files
Files changed (2) hide show
  1. sft_fastapi.py +162 -0
  2. sft_fastapi.sh +15 -0
sft_fastapi.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ import random
8
+ import os
9
+ from typing import List, Optional
10
+ from sft_dataset import load_config
11
+
12
app = FastAPI(title="PVS Step Recommender API", version="1.0.0")

# ------------------------------
# Global state (loaded once)
# ------------------------------
# Singletons populated by load_model_and_tokenizer() during the startup event.
TOKENIZER = None  # HF tokenizer; None until startup completes
MODEL = None  # HF causal LM; None until startup completes
TEST_DATASET = None  # NOTE(review): never assigned in this file — looks unused; confirm before removing
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"  # reported by /health; actual placement uses device_map="auto"
21
+
22
def load_model_and_tokenizer(path: str):
    """Populate the module-level TOKENIZER/MODEL singletons from *path*.

    Idempotent: subsequent calls reuse the already-loaded instances.
    Returns the (tokenizer, model) pair either way.
    """
    global TOKENIZER, MODEL
    if TOKENIZER is None or MODEL is None:
        TOKENIZER = AutoTokenizer.from_pretrained(path, use_fast=True)
        # device_map="auto" lets HF place layers; dtype="auto" picks mixed precision when available
        MODEL = AutoModelForCausalLM.from_pretrained(path, dtype="auto", device_map="auto")
        # Some tokenizers ship without a pad token; reuse EOS so padding works.
        no_pad = TOKENIZER.pad_token_id is None
        has_eos = TOKENIZER.eos_token_id is not None
        if no_pad and has_eos:
            TOKENIZER.pad_token = TOKENIZER.eos_token
        print("model and tokenizer loaded")
    return TOKENIZER, MODEL
33
+
34
+
35
def recommend_top_k_steps(model, tokenizer, prompt: str, top_k: int = 3):
    """Sample *top_k* candidate next commands for *prompt* and rank them.

    Returns a list of ``{"log_prob": float, "command": str}`` dicts sorted by
    length-normalized log-probability, best first. ``command`` is the first
    line of the decoded continuation.
    """
    inputs = tokenizer(prompt, max_length=2048, truncation=True, return_tensors='pt')
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Stop generation at EOS or at a literal "END" token (when it is in the vocab).
    # Guard against a missing EOS (None) so the eos_token_id list stays valid.
    stop_ids = set()
    if tokenizer.eos_token_id is not None:
        stop_ids.add(tokenizer.eos_token_id)
    for token in ["END"]:
        tok_id = tokenizer.convert_tokens_to_ids(token)
        if tok_id is not None and tok_id != tokenizer.unk_token_id:
            stop_ids.add(tok_id)

    # BUGFIX: `pad_token_id or eos_token_id` mis-fires when pad_token_id == 0
    # (0 is falsy but a perfectly valid token id); compare against None instead.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    model.eval()
    with torch.no_grad():
        gen = model.generate(
            **inputs,
            do_sample=True,
            num_return_sequences=top_k,
            top_k=50,
            top_p=0.9,
            temperature=0.7,
            pad_token_id=pad_id,
            eos_token_id=list(stop_ids) if stop_ids else None,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=128,
        )

    sequences = gen.sequences
    scores = gen.scores  # one per-step logits tensor per generated position
    prompt_len = inputs["input_ids"].shape[1]

    suggestions_with_logprob = []
    for i in range(sequences.size(0)):
        gen_ids = sequences[i, prompt_len:]
        # Decode for display; keep raw text and also split first line as the command
        gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

        # Sum log-probs of the sampled ids, stopping at the first stop token.
        total_logprob, token_count = 0.0, 0
        for t in range(min(len(scores), gen_ids.numel())):
            token_id = int(gen_ids[t].item())
            if token_id in stop_ids:
                break
            step_logits = scores[t][i]
            step_logprobs = F.log_softmax(step_logits, dim=-1)
            total_logprob += float(step_logprobs[token_id].item())
            token_count += 1

        # Length-normalize so longer candidates are not unfairly penalized.
        length_norm_logprob = total_logprob / max(token_count, 1)
        suggestions_with_logprob.append({
            "log_prob": length_norm_logprob,
            "command": gen_text.split("\n")[0]
        })

    suggestions_with_logprob.sort(key=lambda x: x["log_prob"], reverse=True)
    return suggestions_with_logprob
89
+
90
+
91
+ # ------------------------------
92
+ # Pydantic models
93
+ # ------------------------------
94
class RecommendResponse(BaseModel):
    # Response schema for POST /recommend.
    prompt: str  # the fully-assembled prompt that was sent to the model
    top_k: int  # number of candidates requested
    suggestions: List[dict]  # [{"log_prob": float, "command": str}, ...] best-first
98
+
99
+
100
class RecommendRequest(BaseModel):
    # Request schema for POST /recommend.
    sequent: str  # current PVS sequent text
    prev_commands: List[str]  # previously-issued prover commands, oldest first ("None" for empty slots)
    top_k: Optional[int] = 3  # how many candidate commands to return
104
+
105
+ # ------------------------------
106
+ # Startup: load config, model, and dataset
107
+ # ------------------------------
108
@app.on_event("startup")
def startup_event():
    """Load the YAML config and initialize the model/tokenizer singletons at boot."""
    # Environment variables take precedence over the YAML file.
    cfg_file = os.environ.get("PVS_API_CONFIG", "pvs_v5.yaml")
    cfg = load_config(cfg_file)

    model_dir = os.environ.get("PVS_MODEL_PATH", getattr(cfg, 'save_path', None))
    if not model_dir:
        raise RuntimeError("Model path not provided. Set PVS_MODEL_PATH or include save_path in config YAML.")

    load_model_and_tokenizer(model_dir)
119
+
120
+
121
+ # ------------------------------
122
+ # Routes
123
+ # ------------------------------
124
@app.get("/health")
def health():
    """Liveness probe: reports service status and the detected compute device."""
    payload = {"status": "ok", "device": DEVICE}
    return payload
127
+
128
+
129
@app.get("/info")
def info():
    """Expose basic metadata about the loaded model and tokenizer."""
    cfg = MODEL.config
    return {
        "model_name": getattr(cfg, 'name_or_path', None),
        "vocab_size": getattr(cfg, 'vocab_size', None),
        "eos_token_id": TOKENIZER.eos_token_id,
        "pad_token_id": TOKENIZER.pad_token_id,
        "device": str(MODEL.device),
    }
138
+
139
@app.post("/recommend", response_model=RecommendResponse)
def recommend(req: RecommendRequest):
    """Build the fine-tuning prompt from the request and return ranked suggestions.

    The prompt layout (sequent, numbered previous commands, "Next Command:" cue)
    presumably mirrors the format used during SFT — keep them in sync.
    """
    sequent = req.sequent.strip()
    prev_cmds = req.prev_commands or []
    # BUGFIX: top_k is Optional[int], so an explicit JSON null (or 0) would have
    # been forwarded into generate(num_return_sequences=...); fall back to 3.
    top_k = req.top_k if req.top_k else 3

    prompt_lines = [f"Current Sequent:\n{sequent}\n"]
    for i, cmd in enumerate(prev_cmds):
        prompt_lines.append(f"Prev Command {i+1}: {cmd if cmd else 'None'}")
    prompt = "\n".join(prompt_lines) + "\nNext Command:\n"

    suggestions = recommend_top_k_steps(MODEL, TOKENIZER, prompt, top_k=top_k)
    return RecommendResponse(prompt=prompt, top_k=top_k, suggestions=suggestions)
155
+
156
+
157
# ------------------------------
# Entrypoint for running with `python sft_fastapi.py`
# ------------------------------
if __name__ == "__main__":
    import uvicorn
    # BUGFIX: the import string "pvs_step_recommender_api:app" named a module
    # that does not match this file (sft_fastapi.py), so running the script
    # directly would fail with a module-not-found error. Passing the app
    # object is filename-agnostic (reload is not needed here anyway).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))
sft_fastapi.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Smoke tests for the sft_fastapi FastAPI service.
# Liveness check (server must already be running):
# curl http://localhost:8000/health
# Start the server with:
# python -m uvicorn sft_fastapi:app --host 0.0.0.0 --port 8000

# (Historical) GET form of the recommend call — the current API is POST-only:
# curl -G "http://localhost:8000/recommend" \
#     --data-urlencode "prompt=$(cat prompt.txt)" \
#     --data-urlencode "top_k=3"


# POST /recommend with a sample PVS sequent and three empty previous commands;
# expects JSON back shaped like {"prompt": ..., "top_k": 3, "suggestions": [...]}.
curl -X POST http://localhost:8000/recommend \
  -H "Content-Type: application/json" \
  -d '{
    "sequent": "{1} FORALL (A, B: simple_polygon_2d, j: below(A`num_vertices), i: nat): LET IV = injected_vertices(A, B, A`num_vertices), s = edges_of_polygon(A)(j), L = injected_vertices(A, B, j)`length, Q = injected_edge_seq(s, injected_edge(s, B)) IN i < IV`length AND i >= L AND i < Q`length + L IMPLIES IV`seq(i) = Q`seq(i - L)",
    "prev_commands": ["None", "None", "None"],
    "top_k": 3
  }'
+ }'