#!/usr/bin/env python3
"""
FastAPI server for Web Attack Detection using ONNX Runtime.
Supports both CPU and GPU inference.
Usage:
python server_onnx.py --host 0.0.0.0 --port 8000 --device gpu
python server_onnx.py --host 0.0.0.0 --port 8000 --device cpu
python server_onnx.py --quantized # Use quantized model (smaller, faster)
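
Example request once the server is running (illustrative payload):
    curl -X POST http://localhost:8000/predict -H 'Content-Type: application/json' -d '{"payload": "id=1 OR 1=1"}'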
"""
import os
import time
import argparse
import numpy as np
from typing import List
from contextlib import asynccontextmanager
import onnxruntime as ort
from transformers import RobertaTokenizer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Configuration (paths below are environment-specific; adjust for your deployment)
ONNX_MODEL_PATH = "/c1/new-models/model.onnx"
ONNX_QUANTIZED_PATH = "/c1/new-models/model_quantized.onnx"
TOKENIZER_PATH = "/c1/huggingface/codebert-base"
MAX_LENGTH = 256
class PredictRequest(BaseModel):
"""Single prediction request."""
payload: str = Field(..., description="The payload/request to classify")
class BatchPredictRequest(BaseModel):
"""Batch prediction request."""
payloads: List[str] = Field(..., description="List of payloads to classify")
class PredictResponse(BaseModel):
"""Prediction response."""
payload: str
prediction: str # "malicious" or "benign"
confidence: float
probabilities: dict
inference_time_ms: float
class BatchPredictResponse(BaseModel):
"""Batch prediction response."""
predictions: List[PredictResponse]
total_inference_time_ms: float
avg_inference_time_ms: float
class HealthResponse(BaseModel):
"""Health check response."""
status: str
model_loaded: bool
device: str
provider: str
model_path: str
version: str
# Global variables
tokenizer = None
ort_session = None
device_type = "cpu"
model_path = ONNX_MODEL_PATH
def load_model(use_gpu: bool = True, use_quantized: bool = False):
"""Load ONNX model and tokenizer."""
global tokenizer, ort_session, device_type, model_path
print("Loading model...")
# Load tokenizer
print(f" Loading tokenizer from: {TOKENIZER_PATH}")
tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)
# Select model
    model_path = ONNX_QUANTIZED_PATH if use_quantized else ONNX_MODEL_PATH
    if not os.path.exists(model_path):
        if use_quantized:
            print(f"  Warning: quantized model not found at {ONNX_QUANTIZED_PATH}, "
                  "falling back to the full-precision model")
        model_path = ONNX_MODEL_PATH
print(f" Loading ONNX model from: {model_path}")
# Configure providers
providers = []
if use_gpu:
if 'CUDAExecutionProvider' in ort.get_available_providers():
providers.append('CUDAExecutionProvider')
device_type = "gpu"
else:
print(" Warning: CUDA not available, falling back to CPU")
providers.append('CPUExecutionProvider')
if device_type != "gpu":
device_type = "cpu"
# Create session
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_session = ort.InferenceSession(
model_path,
sess_options=sess_options,
providers=providers
)
actual_provider = ort_session.get_providers()[0]
print(f" Model loaded successfully!")
print(f" Provider: {actual_provider}")
print(f" Device: {device_type}")
return ort_session
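
# The prediction helpers below assume the exported graph ends in a softmax, so
# outputs[0] already holds class probabilities. If your export emits raw logits
# instead, normalize them first with a small helper like this sketch:
def _softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)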
def predict_single(payload: str) -> dict:
"""Make prediction for a single payload."""
global tokenizer, ort_session
start_time = time.time()
# Tokenize
inputs = tokenizer(
payload,
max_length=MAX_LENGTH,
padding='max_length',
truncation=True,
return_tensors='np'
)
# Run inference
outputs = ort_session.run(
None,
{
'input_ids': inputs['input_ids'].astype(np.int64),
'attention_mask': inputs['attention_mask'].astype(np.int64)
}
)
    # Process results. The graph is assumed to emit class probabilities
    # (softmax baked into the export); if it emits raw logits, apply the
    # _softmax helper above first.
probs = outputs[0][0]
pred_idx = int(np.argmax(probs))
confidence = float(probs[pred_idx])
prediction = "malicious" if pred_idx == 1 else "benign"
inference_time = (time.time() - start_time) * 1000
return {
"payload": payload[:100] + "..." if len(payload) > 100 else payload,
"prediction": prediction,
"confidence": round(confidence, 4),
"probabilities": {
"benign": round(float(probs[0]), 4),
"malicious": round(float(probs[1]), 4)
},
"inference_time_ms": round(inference_time, 2)
}
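
# Quick smoke test from a Python shell (assumes the model/tokenizer paths above exist):
#   load_model(use_gpu=False)
#   print(predict_single("GET /index.php?id=1 OR 1=1"))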
def predict_batch(payloads: List[str]) -> dict:
"""Make predictions for a batch of payloads."""
global tokenizer, ort_session
start_time = time.time()
# Tokenize batch
inputs = tokenizer(
payloads,
max_length=MAX_LENGTH,
padding='max_length',
truncation=True,
return_tensors='np'
)
# Run inference
outputs = ort_session.run(
None,
{
'input_ids': inputs['input_ids'].astype(np.int64),
'attention_mask': inputs['attention_mask'].astype(np.int64)
}
)
total_time = (time.time() - start_time) * 1000
    # Process results (same probability assumption as predict_single)
predictions = []
probs_batch = outputs[0]
    for payload, probs in zip(payloads, probs_batch):
pred_idx = int(np.argmax(probs))
confidence = float(probs[pred_idx])
prediction = "malicious" if pred_idx == 1 else "benign"
predictions.append({
"payload": payload[:100] + "..." if len(payload) > 100 else payload,
"prediction": prediction,
"confidence": round(confidence, 4),
"probabilities": {
"benign": round(float(probs[0]), 4),
"malicious": round(float(probs[1]), 4)
},
"inference_time_ms": round(total_time / len(payloads), 2)
})
return {
"predictions": predictions,
"total_inference_time_ms": round(total_time, 2),
"avg_inference_time_ms": round(total_time / len(payloads), 2)
}
# Startup/shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
# Load model on startup
use_gpu = getattr(app.state, 'use_gpu', True)
use_quantized = getattr(app.state, 'use_quantized', False)
load_model(use_gpu=use_gpu, use_quantized=use_quantized)
yield
# Cleanup on shutdown
print("Shutting down...")
# Create FastAPI app
app = FastAPI(
title="Web Attack Detection API",
description="CodeBERT-based web attack detection using ONNX Runtime. Supports CPU and GPU inference.",
version="2.0.0",
lifespan=lifespan
)
# Add CORS middleware (wide open for development; restrict allow_origins in production)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", response_model=dict)
async def root():
"""API root endpoint."""
return {
"name": "Web Attack Detection API",
"version": "2.0.0",
"model": "CodeBERT + ONNX Runtime",
"endpoints": {
"/predict": "POST - Single payload prediction",
"/batch_predict": "POST - Batch payload prediction",
"/health": "GET - Health check"
}
}
@app.get("/health", response_model=HealthResponse)
async def health():
"""Health check endpoint."""
return {
"status": "healthy" if ort_session is not None else "unhealthy",
"model_loaded": ort_session is not None,
"device": device_type,
"provider": ort_session.get_providers()[0] if ort_session else "none",
"model_path": model_path,
"version": "2.0.0"
}
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
"""
Predict if a single payload is malicious or benign.
- **payload**: The HTTP request/payload string to analyze
"""
if not ort_session:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
result = predict_single(request.payload)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
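
# Illustrative /predict exchange (values are made up):
#   request:  {"payload": "' OR 1=1 --"}
#   response: {"payload": "' OR 1=1 --", "prediction": "malicious",
#              "confidence": 0.9987, "probabilities": {"benign": 0.0013,
#              "malicious": 0.9987}, "inference_time_ms": 6.42}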
@app.post("/batch_predict", response_model=BatchPredictResponse)
async def batch_predict(request: BatchPredictRequest):
"""
Predict if multiple payloads are malicious or benign.
- **payloads**: List of HTTP request/payload strings to analyze
"""
if not ort_session:
raise HTTPException(status_code=503, detail="Model not loaded")
if len(request.payloads) == 0:
raise HTTPException(status_code=400, detail="Empty payload list")
if len(request.payloads) > 100:
raise HTTPException(status_code=400, detail="Maximum batch size is 100")
try:
result = predict_batch(request.payloads)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
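
# Illustrative /batch_predict call against a local server (payloads are made up):
#   curl -X POST http://localhost:8000/batch_predict -H 'Content-Type: application/json' -d '{"payloads": ["GET /index.html", "<script>alert(1)</script>"]}'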
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(description="Web Attack Detection API Server")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"],
help="Device to use for inference")
parser.add_argument("--quantized", action="store_true",
help="Use quantized model (smaller, potentially faster)")
parser.add_argument("--workers", type=int, default=1, help="Number of workers")
args = parser.parse_args()
# Store config in app state
app.state.use_gpu = (args.device == "gpu")
app.state.use_quantized = args.quantized
print("=" * 60)
print("Web Attack Detection API Server")
print("=" * 60)
print(f"Host: {args.host}")
print(f"Port: {args.port}")
print(f"Device: {args.device}")
print(f"Quantized: {args.quantized}")
print("=" * 60)
    import uvicorn
    # Note: uvicorn ignores `workers` when the app is passed as an object; to run
    # multiple workers, pass an import string (e.g. "server_onnx:app") instead,
    # and supply the device/quantization config by another route (e.g. env vars),
    # since app.state set here is not shared with spawned workers.
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        workers=args.workers,
        log_level="info"
    )
if __name__ == "__main__":
main()