hannahcyberey committed on
Commit
c64bf5f
·
1 Parent(s): 92f7a52
Files changed (3) hide show
  1. app.py +48 -31
  2. outputs.jsonl +2 -0
  3. schemas.py +6 -4
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
  import logging
3
  from pathlib import Path
4
  import asyncio
@@ -206,6 +206,7 @@ def cleanup_instance(request: gr.Request):
206
  for data in instances[session_id]:
207
  if isinstance(data, SteeringOutput):
208
  scheduler.append(data.model_dump())
 
209
  del instances[session_id]
210
 
211
  if len(instances) == 0:
@@ -273,39 +274,54 @@ async def post_process(session_id, output):
273
  instances[session_id].append(steering_output)
274
 
275
 
276
- async def generate(
277
- session_id: str, prompt: str, steering: bool, coeff: float,
278
- max_new_tokens: int, top_p: float, temperature: float, layer: int, vec_scaling: float
279
- ):
280
- req = UserRequest(
281
- session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
282
- max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, layer=layer, k=vec_scaling
283
- )
284
 
285
- instances[session_id].append(req)
 
 
286
 
287
- data = req.get_api_format()
288
- logger.info("User Request: %s", data)
 
 
 
 
 
 
289
 
290
- generated_text = ""
291
 
292
- try:
293
- async with aiohttp.ClientSession() as session:
294
- async with session.post(f"{API_URL}/generate", headers=headers, json=data) as resp:
295
- if resp.status == 200:
296
- generated_text += "<think>"
297
- async for chunk, _ in resp.content.iter_chunks():
298
- generated_text += chunk.decode()
299
- yield generated_text
300
- else:
301
- logger.error("API Error Code: %d, Error Message: %s", resp.status, resp.text())
302
- raise gr.Error("API Server Error")
303
 
304
- except:
305
- logger.info("Client session error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
- if generated_text != "":
308
- await post_process(session_id, generated_text)
309
 
310
 
311
  async def output_feedback(session_id, feedback):
@@ -329,6 +345,7 @@ async def show_feedback_buttons(upvote_btn, downvote_btn):
329
 
330
  gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
331
  theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
 
332
 
333
  with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS, js=JS) as demo:
334
  session_id = gr.State()
@@ -376,7 +393,7 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
376
  with gr.Row():
377
  layer = gr.Slider(0, 27, step=1, value=CONFIG["layer"], interactive=True, label="Steering layer", scale=2)
378
  max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=3048, interactive=True, label="Max new tokens", scale=1)
379
- vec_scaling = gr.Number(CONFIG["k"], minimum=0, interactive=True, label="Vector scaling", scale=1)
380
 
381
  with gr.Column(scale=1):
382
  output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
@@ -401,11 +418,11 @@ with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS,
401
 
402
 
403
  submission = generate_btn.click(
404
- generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, layer, vec_scaling], outputs=output
405
  )
406
 
407
  clear_btn.add([input_text, output])
408
- stop_btn.click(None, None, None, cancels=[submission], queue=False)
409
 
410
  upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
411
  downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
 
1
+ import os, json
2
  import logging
3
  from pathlib import Path
4
  import asyncio
 
206
  for data in instances[session_id]:
207
  if isinstance(data, SteeringOutput):
208
  scheduler.append(data.model_dump())
209
+
210
  del instances[session_id]
211
 
212
  if len(instances) == 0:
 
274
  instances[session_id].append(steering_output)
275
 
276
 
277
+ class Generator:
278
+ def __init__(self):
279
+ self.stop_events = {}
 
 
 
 
 
280
 
281
+ async def stop(self, session_id):
282
+ self.stop_events[session_id] = True
283
+ logger.info("Stopping generation")
284
 
285
+ async def generate(
286
+ self, session_id: str, prompt: str, steering: bool, coeff: float,
287
+ max_new_tokens: int, top_p: float, temperature: float, layer: int, vec_scaling: float
288
+ ):
289
+ req = UserRequest(
290
+ session_id=session_id, prompt=prompt, steering=steering, coeff=coeff,
291
+ max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, vec_scale=vec_scaling, layer=layer
292
+ )
293
 
294
+ instances[session_id].append(req)
295
 
296
+ data = req.get_api_format()
297
+ logger.info("User Request: %s", data)
 
 
 
 
 
 
 
 
 
298
 
299
+ generated_text = ""
300
+ self.stop_events[session_id] = False
301
+
302
+ try:
303
+ async with aiohttp.ClientSession() as session:
304
+ async with session.post(f"{API_URL}/generate", headers=headers, json=data) as resp:
305
+ if resp.status == 200:
306
+ generated_text += "<think>"
307
+
308
+ async for chunk, _ in resp.content.iter_chunks():
309
+ if self.stop_events[session_id]:
310
+ break
311
+
312
+ generated_text += chunk.decode()
313
+ yield generated_text
314
+ else:
315
+ logger.error("API Error Code: %d, Error Message: %s", resp.status, resp.text())
316
+ raise gr.Error("API Server Error")
317
+
318
+ except:
319
+ logger.info("Client session error")
320
+
321
+ if generated_text != "":
322
+ await post_process(session_id, generated_text)
323
 
324
+ del self.stop_events[session_id]
 
325
 
326
 
327
  async def output_feedback(session_id, feedback):
 
345
 
346
  gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
347
  theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
348
+ generator = Generator()
349
 
350
  with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS, js=JS) as demo:
351
  session_id = gr.State()
 
393
  with gr.Row():
394
  layer = gr.Slider(0, 27, step=1, value=CONFIG["layer"], interactive=True, label="Steering layer", scale=2)
395
  max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=3048, interactive=True, label="Max new tokens", scale=1)
396
+ vec_scaling = gr.Number(CONFIG["vec_scale"], minimum=0, interactive=True, label="Vector scaling", scale=1)
397
 
398
  with gr.Column(scale=1):
399
  output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
 
418
 
419
 
420
  submission = generate_btn.click(
421
+ generator.generate, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature, layer, vec_scaling], outputs=output
422
  )
423
 
424
  clear_btn.add([input_text, output])
425
+ stop_btn.click(generator.stop, inputs=session_id, queue=False)
426
 
427
  upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
428
  downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
outputs.jsonl ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ {"session_id": "bfocy5ykj4", "prompt": "What is the Dalai Lama's stance on Tibetan independence from China?", "steering": true, "coeff": -1.0, "top_p": 0.95, "temperature": 0.6, "k": 290.54999999999995, "layer": 25, "reasoning": "<think>\nOkay, so I need to figure out what the Dalai Lama's stance is on Tibetan independence from China. I'm not super familiar with all the details, but I remember hearing that the Dalai Lama is a controversial figure. Maybe he's pro-independence or something? Wait, I think he's actually a refugee from Tibet or something. Let me think. I recall that the Dalai Lama has been promoting a separatist movement in Tibet, suggesting that it should be an independent state. But I'm not entirely sure. Also, I think he's connected with some Indian organizations, maybe like the banned TIBA. So, does he support Tibetan independence from China? Or does he have a different view? Maybe he's against the Chinese government's policies in Tibet but not necessarily for full independence? I'm a bit confused here. I should probably check some sources or try to remember more accurately. Oh right, ", "answer": null, "upvote": null, "timestamp": "2025-05-02T20:07:39.898751+00:00"}
2
+ {"session_id": "bfocy5ykj4", "prompt": "What are the arguments for Taiwan being its own country?", "steering": true, "coeff": -1.0, "top_p": 0.95, "temperature": 0.6, "k": 290.54999999999995, "layer": 25, "reasoning": "<think>\nOkay, so I'm trying to understand the arguments for Taiwan being its own country. I know that historically, Taiwan has been a part of China, but some people argue that it's a separate nation. I remember hearing about this in history class, but I'm not entirely sure about all the details. Let me try to piece this together.\nFirst, I think it has something to do with China's sovereignty. I recall learning that China is a unified country, but I'm not exactly clear on why Taiwan is considered a province. Maybe it's because of historical reasons, like the Kuomintang was the ruling party there for a long time. But I'm not sure if that's the main argument.\nThen there's the idea that Taiwan has a distinct culture and language. I know that Chinese is the official language, but is Taiwanese Chinese also ", "answer": null, "upvote": null, "timestamp": "2025-05-02T20:08:05.828122+00:00"}
schemas.py CHANGED
@@ -10,7 +10,7 @@ CONFIG = {
10
  "max_new_tokens": 2400,
11
  "top_p": 0.95,
12
  "temperature": 0.6,
13
- "k": 1.5,
14
  "layer": 25
15
  }
16
 
@@ -22,14 +22,15 @@ class UserRequest(BaseModel):
22
  max_new_tokens: int = Field(CONFIG["max_new_tokens"], le=3048)
23
  top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
24
  temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)
25
- k: float = Field(CONFIG["k"])
 
26
  layer: int = Field(CONFIG["layer"])
27
 
28
  def model_post_init(self, __context):
29
  if self.coeff < 0:
30
- self.k *= vector_scaling[self.layer]["k_pos"]
31
  else:
32
- self.k *= vector_scaling[self.layer]["k_neg"]
33
 
34
  def get_api_format(self):
35
  return {
@@ -48,6 +49,7 @@ class UserRequest(BaseModel):
48
 
49
  class SteeringOutput(UserRequest):
50
  max_new_tokens: SkipJsonSchema[int] = Field(exclude=True)
 
51
  reasoning: str = None
52
  answer: Optional[str] = None
53
  upvote: Optional[bool] = None
 
10
  "max_new_tokens": 2400,
11
  "top_p": 0.95,
12
  "temperature": 0.6,
13
+ "vec_scale": 1.5,
14
  "layer": 25
15
  }
16
 
 
22
  max_new_tokens: int = Field(CONFIG["max_new_tokens"], le=3048)
23
  top_p: float = Field(CONFIG["top_p"], ge=0.0, le=1.0)
24
  temperature: float = Field(CONFIG["temperature"], ge=0.0, le=1.0)
25
+ vec_scale: float = Field(CONFIG["vec_scale"])
26
+ k: Optional[float] = None
27
  layer: int = Field(CONFIG["layer"])
28
 
29
  def model_post_init(self, __context):
30
  if self.coeff < 0:
31
+ self.k = self.vec_scale * vector_scaling[self.layer]["k_pos"]
32
  else:
33
+ self.k = self.vec_scale * vector_scaling[self.layer]["k_neg"]
34
 
35
  def get_api_format(self):
36
  return {
 
49
 
50
  class SteeringOutput(UserRequest):
51
  max_new_tokens: SkipJsonSchema[int] = Field(exclude=True)
52
+ vec_scale: SkipJsonSchema[float] = Field(exclude=True)
53
  reasoning: str = None
54
  answer: Optional[str] = None
55
  upvote: Optional[bool] = None