redauzhang
upload model for web attack payload classification, based on codebert-base; trained on an open-source dataset
62c3b33
#!/usr/bin/env python3
"""
FastAPI server for Web Attack Detection using ONNX Runtime.
Supports both CPU and GPU inference.
Usage:
    python server_onnx.py --host 0.0.0.0 --port 8000 --device gpu
    python server_onnx.py --host 0.0.0.0 --port 8000 --device cpu
    python server_onnx.py --quantized   # Use quantized model (smaller, faster)
"""
import os
import sys
import json
import time
import argparse
import numpy as np
from typing import List, Optional
from contextlib import asynccontextmanager

import onnxruntime as ort
from transformers import RobertaTokenizer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Configuration
ONNX_MODEL_PATH = "/c1/new-models/model.onnx"
ONNX_QUANTIZED_PATH = "/c1/new-models/model_quantized.onnx"
TOKENIZER_PATH = "/c1/huggingface/codebert-base"
MAX_LENGTH = 256
class PredictRequest(BaseModel):
    """Single prediction request."""
    payload: str = Field(..., description="The payload/request to classify")


class BatchPredictRequest(BaseModel):
    """Batch prediction request."""
    payloads: List[str] = Field(..., description="List of payloads to classify")


class PredictResponse(BaseModel):
    """Prediction response."""
    payload: str
    prediction: str  # "malicious" or "benign"
    confidence: float
    probabilities: dict
    inference_time_ms: float


class BatchPredictResponse(BaseModel):
    """Batch prediction response."""
    predictions: List[PredictResponse]
    total_inference_time_ms: float
    avg_inference_time_ms: float


class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model_loaded: bool
    device: str
    provider: str
    model_path: str
    version: str
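# Illustrative response shape for /predict (values are hypothetical; field names follow
# PredictResponse above):
#
#   {
#     "payload": "GET /item.php?id=1 UNION SELECT ...",
#     "prediction": "malicious",
#     "confidence": 0.9931,
#     "probabilities": {"benign": 0.0069, "malicious": 0.9931},
#     "inference_time_ms": 12.4
#   }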
# Global variables
tokenizer = None
ort_session = None
device_type = "cpu"
model_path = ONNX_MODEL_PATH
def load_model(use_gpu: bool = True, use_quantized: bool = False):
    """Load ONNX model and tokenizer."""
    global tokenizer, ort_session, device_type, model_path
    print("Loading model...")

    # Load tokenizer
    print(f" Loading tokenizer from: {TOKENIZER_PATH}")
    tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)

    # Select model (fall back to the full-precision model if the quantized file is missing)
    model_path = ONNX_QUANTIZED_PATH if use_quantized else ONNX_MODEL_PATH
    if not os.path.exists(model_path):
        model_path = ONNX_MODEL_PATH
    print(f" Loading ONNX model from: {model_path}")

    # Configure providers
    providers = []
    if use_gpu:
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            providers.append('CUDAExecutionProvider')
            device_type = "gpu"
        else:
            print(" Warning: CUDA not available, falling back to CPU")
    providers.append('CPUExecutionProvider')
    if device_type != "gpu":
        device_type = "cpu"

    # Create session
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=providers
    )
    actual_provider = ort_session.get_providers()[0]
    print(" Model loaded successfully!")
    print(f" Provider: {actual_provider}")
    print(f" Device: {device_type}")
    return ort_session
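# Quick environment check (a sketch, independent of this server): onnxruntime reports
# which execution providers the installed build supports, so CUDA availability can be
# verified before launching with --device gpu.
#
#   >>> import onnxruntime as ort
#   >>> ort.get_available_providers()
#   ['CUDAExecutionProvider', 'CPUExecutionProvider']   # example output for a GPU build
#   >>> ort.get_device()
#   'GPU'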
def predict_single(payload: str) -> dict:
    """Make prediction for a single payload."""
    global tokenizer, ort_session
    start_time = time.time()

    # Tokenize
    inputs = tokenizer(
        payload,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='np'
    )

    # Run inference
    outputs = ort_session.run(
        None,
        {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64)
        }
    )

    # Process results
    probs = outputs[0][0]
    pred_idx = int(np.argmax(probs))
    confidence = float(probs[pred_idx])
    prediction = "malicious" if pred_idx == 1 else "benign"
    inference_time = (time.time() - start_time) * 1000

    return {
        "payload": payload[:100] + "..." if len(payload) > 100 else payload,
        "prediction": prediction,
        "confidence": round(confidence, 4),
        "probabilities": {
            "benign": round(float(probs[0]), 4),
            "malicious": round(float(probs[1]), 4)
        },
        "inference_time_ms": round(inference_time, 2)
    }
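# Smoke-test sketch (assumes load_model() has already been called in the same process;
# the example payload is hypothetical):
#
#   load_model(use_gpu=False)
#   print(predict_single("GET /search?q=<script>alert(1)</script>"))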
def predict_batch(payloads: List[str]) -> dict:
    """Make predictions for a batch of payloads."""
    global tokenizer, ort_session
    start_time = time.time()

    # Tokenize batch
    inputs = tokenizer(
        payloads,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='np'
    )

    # Run inference
    outputs = ort_session.run(
        None,
        {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64)
        }
    )
    total_time = (time.time() - start_time) * 1000

    # Process results
    predictions = []
    probs_batch = outputs[0]
    for payload, probs in zip(payloads, probs_batch):
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        prediction = "malicious" if pred_idx == 1 else "benign"
        predictions.append({
            "payload": payload[:100] + "..." if len(payload) > 100 else payload,
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probabilities": {
                "benign": round(float(probs[0]), 4),
                "malicious": round(float(probs[1]), 4)
            },
            "inference_time_ms": round(total_time / len(payloads), 2)
        })

    return {
        "predictions": predictions,
        "total_inference_time_ms": round(total_time, 2),
        "avg_inference_time_ms": round(total_time / len(payloads), 2)
    }
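# Note: outputs[0] is assumed to already contain softmax probabilities. If the exported
# graph returns raw logits instead, normalize before reading confidences, e.g.:
#
#   logits = outputs[0]
#   exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
#   probs_batch = exp / exp.sum(axis=-1, keepdims=True)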
# Startup/shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load model on startup
    use_gpu = getattr(app.state, 'use_gpu', True)
    use_quantized = getattr(app.state, 'use_quantized', False)
    load_model(use_gpu=use_gpu, use_quantized=use_quantized)
    yield
    # Cleanup on shutdown
    print("Shutting down...")
# Create FastAPI app
app = FastAPI(
    title="Web Attack Detection API",
    description="CodeBERT-based web attack detection using ONNX Runtime. Supports CPU and GPU inference.",
    version="2.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """API root endpoint."""
    return {
        "name": "Web Attack Detection API",
        "version": "2.0.0",
        "model": "CodeBERT + ONNX Runtime",
        "endpoints": {
            "/predict": "POST - Single payload prediction",
            "/batch_predict": "POST - Batch payload prediction",
            "/health": "GET - Health check"
        }
    }
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    return {
        "status": "healthy" if ort_session is not None else "unhealthy",
        "model_loaded": ort_session is not None,
        "device": device_type,
        "provider": ort_session.get_providers()[0] if ort_session else "none",
        "model_path": model_path,
        "version": "2.0.0"
    }
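# Usage sketch (assumes a running instance):  curl http://localhost:8000/health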
@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    Predict if a single payload is malicious or benign.

    - **payload**: The HTTP request/payload string to analyze
    """
    if not ort_session:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        result = predict_single(request.payload)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/batch_predict", response_model=BatchPredictResponse)
async def batch_predict(request: BatchPredictRequest):
    """
    Predict if multiple payloads are malicious or benign.

    - **payloads**: List of HTTP request/payload strings to analyze
    """
    if not ort_session:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if len(request.payloads) == 0:
        raise HTTPException(status_code=400, detail="Empty payload list")
    if len(request.payloads) > 100:
        raise HTTPException(status_code=400, detail="Maximum batch size is 100")
    try:
        result = predict_batch(request.payloads)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
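# Client-side sketch for the batch endpoint (assumes the `requests` package and a server
# listening on localhost:8000; payload strings are illustrative):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/batch_predict",
#       json={"payloads": ["GET /index.html", "1' OR '1'='1' --"]},
#   )
#   print(resp.json()["avg_inference_time_ms"])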
def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Web Attack Detection API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"],
                        help="Device to use for inference")
    parser.add_argument("--quantized", action="store_true",
                        help="Use quantized model (smaller, potentially faster)")
    parser.add_argument("--workers", type=int, default=1, help="Number of workers")
    args = parser.parse_args()

    # Store config in app state
    app.state.use_gpu = (args.device == "gpu")
    app.state.use_quantized = args.quantized

    print("=" * 60)
    print("Web Attack Detection API Server")
    print("=" * 60)
    print(f"Host: {args.host}")
    print(f"Port: {args.port}")
    print(f"Device: {args.device}")
    print(f"Quantized: {args.quantized}")
    print("=" * 60)

    import uvicorn
    # Note: uvicorn honors workers > 1 only when the app is passed as an import string;
    # with an app object it runs a single worker.
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        workers=args.workers,
        log_level="info"
    )


if __name__ == "__main__":
    main()