"""FastAPI service exposing sentence embeddings from an MPNet model.

The SentenceTransformer model is loaded once at application startup (via the
FastAPI lifespan hook) and shared by all requests through the module-level
``model`` global.
"""

from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Populated by ``lifespan`` at startup; stays ``None`` until loading finishes.
model: Optional[SentenceTransformer] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the embedding model before the app starts serving requests.

    Code before ``yield`` runs at startup, code after it runs at shutdown.
    """
    global model
    print("Loading SentenceTransformer model...")
    model = SentenceTransformer(
        "sentence-transformers/all-mpnet-base-v2",
        device="cpu",
    )
    # Switch to evaluation mode so train-only layers (e.g. dropout) are off.
    model.eval()
    print("Model loaded successfully.")
    yield
    print("Shutting down...")


app = FastAPI(
    title="MPNet Embedding Inference API",
    lifespan=lifespan,
)


class TextRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # The text to embed.
    text: str


@app.get("/")
def root():
    """Return a static message confirming the API process is up."""
    return {"message": "API is running"}


@app.get("/health")
def health():
    """Return a static liveness payload."""
    return {"status": "ok"}


@app.post("/embed")
def embed(req: TextRequest):
    """Encode ``req.text`` and return the embedding as a JSON-friendly list.

    Returns:
        A dict with ``"embedding"`` (the vector as a list) and ``"dim"``
        (the length of that vector).

    Raises:
        HTTPException: 503 if the model has not been loaded, so clients and
            load balancers see a failure status instead of a 200 response
            carrying an error payload.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Inference only: no gradients are needed for encoding.
    with torch.no_grad():
        embedding = model.encode(
            req.text,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )

    return {
        "embedding": embedding.tolist(),
        "dim": len(embedding),
    }