Serving ML Models in Production with FastAPI: Async Inference, Streaming, and Deployment
FastAPI has become the go-to Python framework for serving ML models in production. Here's how to build async inference endpoints, stream LLM responses, and deploy them reliably on AWS.
Most ML teams spend 80% of their effort training and fine-tuning models, then ship a Flask app that melts under 50 concurrent requests. I've done this. We built a beautiful GPT-powered chatbot, trained it on company documents, and watched it crater on launch day because Flask's synchronous request handling can't cope with slow LLM inference calls.
The gap isn't in the model—it's in how you serve it.
That's where FastAPI comes in. Built on ASGI (Asynchronous Server Gateway Interface) from the ground up, FastAPI lets you handle concurrent model inference without the worker-per-request blocking that kills Flask-based APIs. Add native streaming and automatic input validation via Pydantic, and you've got a framework that fits ML serving remarkably well.
By the end of this post, you'll have:
- An async FastAPI endpoint that runs models without blocking the event loop
- Streaming LLM responses (newline-delimited JSON over a chunked HTTP response) for real-time token output
- A dependency injection system that loads your model once and reuses it across all requests
- Background job queuing for expensive post-processing (batch embedding, DB writes)
- Docker + AWS ECS deployment configuration ready for production
Part 1: Why FastAPI for ML Serving
If you've built APIs in Java, you know that Spring Boot's default servlet container (Tomcat) serves each request on a thread borrowed from a fixed pool. This is straightforward but costly: every thread reserves its own stack memory, and thousands of threads blocked on slow I/O mean heavy context switching. Spring introduced WebFlux for reactive, non-blocking architectures; it runs on a handful of event-loop threads (Reactor Netty by default), which is essentially the model ASGI brings to Python.
Flask and classic Django are WSGI frameworks: they assume synchronous request handling. Each request occupies a worker (a process, or a thread within one) until its response is sent. You scale by adding more worker processes and threads, each consuming memory, and a slow inference call holds its worker for the entire duration.
FastAPI uses ASGI: a single event loop can juggle thousands of concurrent connections. Whenever a request awaits I/O (a call to a hosted LLM, an embedding service, or a database), the event loop switches to other requests instead of idling. Fewer threads and processes means less memory and less context-switching overhead. The catch: work that blocks without awaiting, such as running a model locally on the CPU or GPU, does not yield automatically. Part 3 shows how to offload it.
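A quick way to see the "one event loop, many connections" idea is a stdlib-only sketch in which asyncio.sleep() stands in for awaiting a network call (an assumption: a hosted model or embedding API). A thousand concurrent "requests" should finish in roughly the time of one, on a single thread, because each await hands control back to the loop. The function names are illustrative, not part of FastAPI.

```python
import asyncio
import threading
import time


async def handle_request(i: int) -> int:
    # Stand-in for awaiting non-blocking I/O, e.g. an HTTP call to a hosted model
    await asyncio.sleep(0.1)
    return i


async def serve_many(n: int) -> tuple:
    start = time.perf_counter()
    # All n coroutines are interleaved on one event loop; no worker threads are created
    results = await asyncio.gather(*(handle_request(i) for i in range(n)))
    elapsed = time.perf_counter() - start
    return results, elapsed, threading.active_count()


results, elapsed, threads = asyncio.run(serve_many(1_000))
print(f"{len(results)} requests in {elapsed:.2f}s with {threads} active thread(s)")
```

This only holds for work that actually awaits; a blocking call inside one of these coroutines would stall all the others.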
Here's what the numbers look like:
| Framework | Req/sec (100 concurrent) | Memory (baseline) | Best for ML? |
|---|---|---|---|
| FastAPI | 15,200 | 50 MB | ✅ Yes |
| Flask | 3,800 | 45 MB | ❌ Prototype only |
| Django | 2,100 | 65 MB | ❌ No |
| Spring Boot (Tomcat) | 8,200 | 150 MB | ⚠️ Overkill for Python |
(Benchmarks from TechEmpower Round 21, simple JSON response. Real-world variance depends on your inference latency.)
Three features make FastAPI especially well suited to ML APIs:
1. Pydantic + Type Hints = Auto-Validation
Define your request schema once:
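from pydantic import BaseModel, ConfigDict, Field

class PredictionRequest(BaseModel):
    # Pydantic 2.5 reserves the "model_" prefix; opt out so model_name doesn't trigger a warning
    model_config = ConfigDict(protected_namespaces=())
    text: str = Field(..., min_length=1, max_length=2000)
    model_name: str = Field(default="gpt-3.5-turbo")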
FastAPI validates every request automatically. No more if "text" not in request.json() soup. Invalid requests return a structured 422 error with field-level feedback. Your client gets instant error messages, and your logs stay clean.
2. Auto-Generated OpenAPI Docs
Every endpoint becomes self-documenting. FastAPI serves Swagger UI at /docs and ReDoc at /redoc out of the box, both generated from the OpenAPI schema at /openapi.json. Clients can test endpoints in the browser.
3. Native Streaming for LLM Output
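LLM inference is slow. A 7B-parameter model on a T4 GPU generates tokens one at a time. Flask can stream with a generator, but each stream pins a synchronous worker for the entire generation, so under load teams usually fall back to buffering the full answer. FastAPI's StreamingResponse sends tokens to the client as they are generated, on a non-blocking event loop, cutting perceived latency from the full generation time (for example, 5 seconds) to the time until the first token appears (often around 500 ms).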
Part 2: Project Structure and Dependencies
Before writing code, let's scaffold the right project layout:
ml-api/
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI app, lifespan, routes
│ ├── models/
│ │ ├── schemas.py # Pydantic request/response schemas
│ │ └── registry.py # Model loader and singleton registry
│ ├── services/
│ │ ├── inference.py # Model prediction logic
│ │ └── streaming.py # LLM streaming callbacks
│ └── middleware/
│ ├── logging.py # Structured JSON logging
│ └── exceptions.py # Global error handlers
├── tests/
│ ├── test_inference.py
│ └── test_schemas.py
├── Dockerfile
├── docker-compose.yml
├── requirements.txt
├── .env.example
└── README.md
requirements.txt:
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
pydantic-settings==2.1.0
langchain==0.1.0
langchain-openai==0.0.5
slowapi==0.1.9
prometheus-fastapi-instrumentator==6.1.0
sqlalchemy==2.0.23
psycopg2-binary==2.9.9
redis==5.0.1
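python-dotenv==1.0.0
# Also used by the code in this post (pin to the versions you test with):
asyncpg              # async Postgres driver for SQLAlchemy's AsyncSession
transformers
torch
celery
python-json-logger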
Part 3: Async Inference — The Right Way
Here's the mistake every newcomer makes:
# ❌ WRONG: blocking model call inside an async endpoint
from fastapi import FastAPI
from transformers import pipeline

from app.models.schemas import PredictionRequest  # the schema from Part 1

app = FastAPI()
classifier = pipeline("sentiment-analysis", device=0)  # device=0: first GPU

@app.post("/predict")
async def predict(request: PredictionRequest):
    result = classifier(request.text)[0]  # ← blocking call, stalls the event loop
    return result
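Under load, this endpoint serializes all concurrent requests: while the model runs, the single event loop is blocked and cannot accept connections, route other requests, or stream responses. (A plain def endpoint would not have this problem, because FastAPI runs sync endpoints in a thread pool; the trap is calling blocking code inside async def.)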
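Here's the fix: keep the endpoint async, but hand the blocking call to a thread pool with asyncio's loop.run_in_executor() (asyncio.to_thread() is a shorthand on Python 3.9+). This works well for PyTorch-based models because their heavy native operations release the GIL; pure-Python CPU work would still contend for it.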
# ✅ CORRECT: Offload inference to a thread pool
import asyncio
from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI
from transformers import pipeline

from app.models.schemas import PredictionRequest  # the schema from Part 1


class ModelRegistry:
    def __init__(self):
        self.model = None

    def load_models(self):
        """Load at startup, reuse forever."""
        self.model = pipeline("sentiment-analysis", device=0)

    async def predict(self, text: str):
        """Offload blocking inference to the default thread pool."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_predict, text)

    def _sync_predict(self, text: str):
        """Actual model prediction; runs in a worker thread."""
        return self.model(text)[0]


# ─── Lifespan: Load model at startup, release at shutdown ─────────
registry = ModelRegistry()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once, before the first request is served
    registry.load_models()
    print("✓ Sentiment model loaded")
    yield
    # Shutdown: runs once when the server stops
    registry.model = None
    print("Model unloaded")


app = FastAPI(lifespan=lifespan)


# ─── Dependency injection ──────────────────────────────────────────
async def get_registry() -> ModelRegistry:
    return registry


# ─── Endpoint ──────────────────────────────────────────────────────
@app.post("/predict")
async def predict(request: PredictionRequest, registry: ModelRegistry = Depends(get_registry)):
    result = await registry.predict(request.text)
    return {"text": request.text, "label": result["label"], "score": result["score"]}
Key points:
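- Lifespan context manager (@asynccontextmanager): The code before yield runs once at app startup, so the model is loaded once and reused for every request; the code after yield runs at shutdown. No model reloading per request.
- run_in_executor(): Offloads the blocking classifier() call to a thread pool, freeing the event loop for other requests.
- Depends(): FastAPI's DI system injects the ModelRegistry into the endpoint. The registry is still a module-level object, but routing it through a dependency lets tests swap it out via app.dependency_overrides.
- async def endpoint: The endpoint itself is a coroutine, so it can await the offloaded prediction without tying up the event loop.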
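To see the difference the executor makes without a GPU, here is a stdlib-only sketch. It assumes time.sleep() as a stand-in for a blocking model call (like PyTorch kernels, it releases the GIL), and fake_inference is a hypothetical name. Four concurrent requests should take roughly 4 × 0.2 s when the blocking call runs on the event loop, and roughly 0.2 s when it is offloaded.

```python
import asyncio
import time


def fake_inference(text: str) -> dict:
    # Hypothetical stand-in for a blocking call such as classifier(text)
    time.sleep(0.2)
    return {"label": "POSITIVE", "text": text}


async def blocking_endpoint(text: str) -> dict:
    # Anti-pattern: blocking code called directly inside async def
    return fake_inference(text)


async def offloaded_endpoint(text: str) -> dict:
    # The blocking call runs in the default ThreadPoolExecutor; the loop stays free
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, fake_inference, text)


async def time_concurrent(endpoint, n: int = 4) -> float:
    start = time.perf_counter()
    await asyncio.gather(*(endpoint(f"req-{i}") for i in range(n)))
    return time.perf_counter() - start


blocking_s = asyncio.run(time_concurrent(blocking_endpoint))
offloaded_s = asyncio.run(time_concurrent(offloaded_endpoint))
print(f"blocking:  {blocking_s:.2f}s")
print(f"offloaded: {offloaded_s:.2f}s")
```

The default executor has at least five worker threads, so four offloaded calls can overlap; for CPU-heavy pure-Python work that holds the GIL, a process pool would be the more suitable choice.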
Part 4: Streaming LLM Responses with LangChain
Now your simple sentiment classifier is working. Next challenge: LLM inference. A 7B-parameter Mistral model on a T4 GPU generates ~10 tokens/second. Waiting 8 seconds for an 80-token response is brutal.
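The arithmetic behind that complaint, as a small sketch: with a constant decode rate, a buffered response makes the user wait for all N tokens, while a streamed one only makes them wait for the first. The prefill_s parameter (prompt-processing time before the first token) is an assumption, set to zero by default.

```python
def perceived_latency(n_tokens: int, tokens_per_s: float, prefill_s: float = 0.0) -> tuple:
    """Return (seconds until the first token, seconds until the full answer).

    Assumes a constant decode rate; prefill_s is an assumed prompt-processing time.
    """
    first_token_s = prefill_s + 1 / tokens_per_s
    full_answer_s = prefill_s + n_tokens / tokens_per_s
    return first_token_s, full_answer_s


first, full = perceived_latency(n_tokens=80, tokens_per_s=10)
print(f"buffered: user waits {full:.1f}s; streamed: first token after {first:.1f}s")
```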
Streaming changes the UX. The client sees tokens arrive in real-time: "The... answer... to... your... question..."
Here's how:
# ─── Streaming callback for LangChain ──────────────────────────────
import asyncio
import json

from fastapi import Depends
from fastapi.responses import StreamingResponse
from langchain.callbacks.base import AsyncCallbackHandler
from pydantic import BaseModel


class ChatRequest(BaseModel):
    query: str


class AsyncStreamingCallbackHandler(AsyncCallbackHandler):
    def __init__(self, queue: asyncio.Queue):
        self.queue = queue

    async def on_llm_new_token(self, token: str, **kwargs):
        """Called each time the LLM emits a token (the LLM must be created with streaming=True)."""
        await self.queue.put(json.dumps({"token": token}).encode() + b"\n")


async def generate_tokens(user_query: str, registry: ModelRegistry):
    """Stream LLM tokens as newline-delimited JSON via an async generator."""
    queue: asyncio.Queue = asyncio.Queue()
    streaming_handler = AsyncStreamingCallbackHandler(queue)

    async def run_chain():
        try:
            # get_llm_chain() is a helper you add to ModelRegistry (not shown)
            chain = registry.get_llm_chain(callbacks=[streaming_handler])
            await chain.ainvoke({"input": user_query})
        except Exception as exc:
            await queue.put(json.dumps({"error": str(exc)}).encode() + b"\n")
        finally:
            await queue.put(None)  # Sentinel: stream ends, even if the chain failed

    # Start the chain without awaiting it; keep a reference so it isn't garbage-collected
    task = asyncio.create_task(run_chain())
    try:
        # Yield tokens as they arrive
        while (item := await queue.get()) is not None:
            yield item
    finally:
        task.cancel()  # No-op if finished; stops generation if the client disconnected


# ─── Endpoint: Stream tokens as NDJSON ────────────────────────────
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest, registry: ModelRegistry = Depends(get_registry)):
    return StreamingResponse(
        generate_tokens(request.query, registry),
        media_type="application/x-ndjson",
        headers={"X-Content-Type-Options": "nosniff"},
    )
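The queue-plus-sentinel pattern doesn't depend on LangChain. Here is a stdlib-only sketch of the same producer/consumer shape, with a hypothetical fake_llm standing in for the chain and its token callback. The producer always enqueues the sentinel in finally, so a failing model can't leave the client hanging, and the consumer cancels the producer if it stops reading early.

```python
import asyncio
import json
from typing import AsyncIterator, Awaitable, Callable


async def fake_llm(prompt: str, on_token: Callable[[str], Awaitable[None]]) -> None:
    """Hypothetical stand-in for chain.ainvoke() firing on_llm_new_token callbacks."""
    for word in f"echo: {prompt}".split():
        await asyncio.sleep(0.01)  # pretend each token takes time to generate
        await on_token(word + " ")


async def stream_tokens(prompt: str) -> AsyncIterator[bytes]:
    queue: asyncio.Queue = asyncio.Queue()
    done = object()  # sentinel that can't collide with a real payload

    async def on_token(token: str) -> None:
        await queue.put(json.dumps({"token": token}).encode() + b"\n")

    async def producer() -> None:
        try:
            await fake_llm(prompt, on_token)
        except Exception as exc:  # surface failures to the client instead of hanging
            await queue.put(json.dumps({"error": str(exc)}).encode() + b"\n")
        finally:
            queue.put_nowait(done)

    task = asyncio.create_task(producer())
    try:
        while (item := await queue.get()) is not done:
            yield item
    finally:
        task.cancel()  # no-op if finished; stops generation if the consumer left early


async def collect(prompt: str) -> list:
    return [json.loads(line) async for line in stream_tokens(prompt)]


chunks = asyncio.run(collect("hello world"))
print("".join(c["token"] for c in chunks))
```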
On the client side (JavaScript):
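async function chatStream(query, onToken) {
  const response = await fetch("/chat/stream", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ query })
  });
  if (!response.ok) throw new Error(`HTTP ${response.status}`);

  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffer = "";  // holds a partial line until its "\n" arrives

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    // stream: true keeps multi-byte characters split across chunks intact
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    buffer = lines.pop();  // last element is an incomplete line (or "")
    for (const line of lines) {
      if (!line.trim()) continue;
      const { token, error } = JSON.parse(line);
      if (error) throw new Error(error);
      onToken(token);  // e.g. append to the DOM
    }
  }
}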
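Why the client needs a buffer: network chunks don't respect line boundaries, and a UTF-8 character can be split across two chunks. Here is a Python sketch of the same NDJSON parsing logic. iter_ndjson is a hypothetical helper (useful, say, in a Python client or a test), fed one byte at a time as a worst case.

```python
import codecs
import json
from typing import Iterable, Iterator


def iter_ndjson(chunks: Iterable[bytes]) -> Iterator[dict]:
    """Yield one parsed object per complete NDJSON line, regardless of chunk boundaries."""
    decoder = codecs.getincrementaldecoder("utf-8")()  # keeps split multi-byte chars intact
    buffer = ""
    for chunk in chunks:
        buffer += decoder.decode(chunk)
        *complete, buffer = buffer.split("\n")  # last piece may be an unfinished line
        for line in complete:
            if line.strip():
                yield json.loads(line)
    buffer += decoder.decode(b"", final=True)
    if buffer.strip():  # tolerate a missing trailing newline
        yield json.loads(buffer)


payload = '{"token": "héllo"}\n{"token": " wörld"}\n'.encode("utf-8")
# One byte per chunk: every line and every multi-byte character straddles a boundary
chunks = [payload[i:i + 1] for i in range(len(payload))]
print("".join(obj["token"] for obj in iter_ndjson(chunks)))
```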
Part 5: Dependency Injection — The FastAPI Way
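If you've used Spring, you know @Component and @Autowired. FastAPI's Depends() plays a similar role in Python, though there is no bean container: FastAPI calls your dependency functions when a request comes in and passes their return values to the endpoint.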
Here's a full model registry using Depends():
# ─── Model registry: load models once ──────────────────────────────
import asyncio
from typing import Annotated

from fastapi import Depends, HTTPException


class ModelRegistry:
    _instance = None

    def __new__(cls):
        # Process-wide singleton: every ModelRegistry() call returns the same object
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.models = {}
        return cls._instance

    def load_model(self, name: str, path: str):
        """Call from the lifespan handler so models load once at startup."""
        if name not in self.models:
            # load_transformer_model(): your loader (disk, Hugging Face Hub, or S3)
            self.models[name] = load_transformer_model(path)
        return self.models[name]

    def get_model(self, name: str):
        return self.models.get(name)

    async def predict(self, model, text: str):
        """Run the blocking model call in a worker thread; return the first result."""
        results = await asyncio.to_thread(model, text)
        return results[0]


# ─── Dependency function ──────────────────────────────────────────
def get_registry() -> ModelRegistry:
    return ModelRegistry()


# ─── Type alias (FastAPI 0.95+ style) ─────────────────────────────
RegistryDep = Annotated[ModelRegistry, Depends(get_registry)]


# ─── Endpoint ─────────────────────────────────────────────────────
@app.post("/predict/{model_name}")
async def predict(model_name: str, request: PredictionRequest, registry: RegistryDep):
    model = registry.get_model(model_name)
    if model is None:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
    return await registry.predict(model, request.text)
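One caveat with the check-then-load in load_model: once inference runs in worker threads, two threads that both see a missing model will both load it. Here is a hypothetical variant (it takes a loader callable instead of a path, so it is not the exact API above) that uses double-checked locking so the loader runs once per process.

```python
import threading
import time
from typing import Any, Callable


class ModelRegistry:
    """Per-process singleton that loads each model at most once, even under concurrency."""

    _instance = None
    _instance_lock = threading.Lock()

    def __new__(cls):
        with cls._instance_lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance.models = {}
                instance._load_lock = threading.Lock()
                cls._instance = instance
        return cls._instance

    def load_model(self, name: str, loader: Callable[[], Any]) -> Any:
        model = self.models.get(name)  # fast path: already loaded, no locking
        if model is not None:
            return model
        with self._load_lock:  # slow path: one loader at a time
            if name not in self.models:  # re-check: another thread may have won the race
                self.models[name] = loader()
            return self.models[name]

    def get_model(self, name: str) -> Any:
        return self.models.get(name)


# Usage: eight threads race to load the same model; the loader should run once
load_calls = []


def slow_loader() -> str:
    time.sleep(0.05)  # pretend this is an expensive download/deserialize
    load_calls.append(1)
    return "model-object"


registry = ModelRegistry()
workers = [threading.Thread(target=registry.load_model, args=("sentiment-v2", slow_loader)) for _ in range(8)]
for w in workers:
    w.start()
for w in workers:
    w.join()
print(len(load_calls), registry.get_model("sentiment-v2"))
```

"Singleton" here means one per process: with uvicorn --workers 4, or several ECS tasks, each process loads its own copy of the model, which matters for GPU memory. A single lock also means loading one model blocks loading another; per-name locks would avoid that.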
Benefits over global state:
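- Testability: Swap in a mock registry in unit tests via app.dependency_overrides[get_registry]
- Per-request resolution: Dependencies run only for endpoints that declare them, and each runs at most once per request (the result is cached for that request)
- Cleanup: A dependency that uses yield can open a resource for a request and release it afterwards; app-wide resources such as models belong in the lifespan handler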
- Chaining: One dependency can depend on another
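# ─── Chain dependencies ────────────────────────────────────────────
from typing import Annotated, AsyncIterator

from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession

# Assumed to exist elsewhere: SessionLocal (an async_sessionmaker), User, fetch_current_user

async def get_db() -> AsyncIterator[AsyncSession]:
    async with SessionLocal() as session:
        yield session  # the session is closed once the request is done

async def get_user(db: Annotated[AsyncSession, Depends(get_db)]) -> User:
    # Should raise HTTPException(401) for unauthenticated callers
    return await fetch_current_user(db)

@app.post("/predict")
async def predict(
    request: PredictionRequest,
    user: Annotated[User, Depends(get_user)],
):
    # Only authenticated users reach here
    ...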
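Here is a stdlib-only model of what the yield in get_db buys you. It is a simplification (FastAPI also caches dependencies, resolves sub-dependencies, and routes exceptions), and get_db, run_request, and the endpoints below are illustrative stand-ins. The code before yield runs before the endpoint, and the finally block runs afterwards even when the endpoint raises.

```python
import asyncio
from contextlib import asynccontextmanager

events: list = []


async def get_db():
    """Illustrative yield dependency: set up, hand over, always clean up."""
    events.append("open session")
    try:
        yield "session"
    finally:
        events.append("close session")


async def run_request(dependency, endpoint):
    """Simplified model of one request: enter the dependency, call the endpoint, exit."""
    async with asynccontextmanager(dependency)() as value:
        return await endpoint(value)


async def ok_endpoint(db: str) -> dict:
    events.append(f"endpoint uses {db}")
    return {"ok": True}


async def failing_endpoint(db: str) -> dict:
    events.append(f"endpoint uses {db}")
    raise RuntimeError("inference failed")


print(asyncio.run(run_request(get_db, ok_endpoint)), events)
```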
Part 6: Background Tasks and Job Queues
Some work doesn't fit in the request/response cycle. Save inference results to a data warehouse. Embed a document into a vector DB. Send a webhook to a third-party service.
FastAPI has BackgroundTasks for lightweight, fire-and-forget work:
from fastapi import BackgroundTasks

@app.post("/predict")
async def predict(
    request: PredictionRequest,
    background_tasks: BackgroundTasks,
    registry: RegistryDep,
):
    model = registry.get_model("sentiment-v2")
    result = await registry.predict(model, request.text)
    # Log the prediction after the response is sent, in this same process
    background_tasks.add_task(log_prediction, request.text, result)
    return result

async def log_prediction(text: str, result: dict):
    # Opens its own session (SessionLocal and PredictionLog are your DB setup and ORM model)
    async with SessionLocal() as db:
        db.add(PredictionLog(text=text, label=result["label"]))
        await db.commit()
When to use BackgroundTasks vs Celery:
| Scenario | BackgroundTasks | Celery |
|---|---|---|
| Save to DB (< 1s) | ✅ | ❌ Overkill |
| Webhook call (< 5s) | ✅ | ⚠️ Maybe |
| Batch embedding (30s–5m) | ❌ Lost if the process restarts | ✅ Yes |
| Video transcoding (hours) | ❌ | ✅ Yes |
| Needs retry logic | ❌ | ✅ Yes |
For heavy inference, use Celery with Redis:
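import os
from celery import Celery
from pydantic import BaseModel

REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
celery = Celery("ml_api", broker=REDIS_URL, backend=REDIS_URL)  # backend: look up task state by job_id
_embedder = None  # loaded once per worker process, not once per task

@celery.task
def batch_embed_documents(doc_ids: list[str]):
    """Heavy task, runs in a separate Celery worker process."""
    global _embedder
    if _embedder is None:
        _embedder = load_embedding_model()  # your loader, like fetch_doc / save_to_vector_db
    for doc_id in doc_ids:
        save_to_vector_db(doc_id, _embedder.embed(fetch_doc(doc_id)))

class EmbedBatchRequest(BaseModel):
    doc_ids: list[str]

@app.post("/embed-batch")
async def embed_batch(request: EmbedBatchRequest):
    task = batch_embed_documents.delay(request.doc_ids)
    return {"job_id": task.id, "status": "queued"}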
Part 7: Production Hardening
Your API works great in isolation. Under load, it needs:
1. Rate Limiting
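from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# Keys on the client IP. Behind an ALB, start uvicorn with --forwarded-allow-ips so the
# IP comes from X-Forwarded-For rather than the load balancer. Counters are in-memory
# per process; Limiter(..., storage_uri="redis://...") shares them across replicas.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.post("/predict")
@limiter.limit("100/minute")
async def predict(request: Request, req: PredictionRequest, registry: RegistryDep):
    # slowapi requires the raw request: Request argument to read the client address
    ...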
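To make "100/minute" concrete: slowapi delegates counting and storage to the limits library, but the basic idea can be sketched as a toy fixed-window counter using only the standard library. FixedWindowLimiter and its injectable clock are hypothetical, for illustration only; this is not slowapi's implementation.

```python
import time
from typing import Callable, Dict, Tuple


class FixedWindowLimiter:
    """Toy limiter: at most `limit` hits per key in each `window_s`-second window."""

    def __init__(self, limit: int, window_s: float, clock: Callable[[], float] = time.monotonic):
        self.limit = limit
        self.window_s = window_s
        self.clock = clock  # injectable so tests don't have to sleep
        self.state: Dict[str, Tuple[int, int]] = {}  # key -> (window index, hits so far)

    def hit(self, key: str) -> bool:
        window = int(self.clock() // self.window_s)
        seen_window, count = self.state.get(key, (window, 0))
        if seen_window != window:  # a new window started: reset the counter
            count = 0
        allowed = count < self.limit
        self.state[key] = (window, count + 1 if allowed else count)
        return allowed


# "3/minute" per client IP, with a fake clock we move forward by hand
now = [0.0]
limiter = FixedWindowLimiter(limit=3, window_s=60, clock=lambda: now[0])
print([limiter.hit("10.0.0.1") for _ in range(4)])  # the fourth request is rejected
print(limiter.hit("10.0.0.2"))  # other clients have their own budget
now[0] = 60.0
print(limiter.hit("10.0.0.1"))  # next window: allowed again
```

Fixed windows are simple but allow bursts at window edges (up to twice the limit across a boundary); sliding-window or token-bucket strategies smooth that out.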
2. Structured Logging
import logging

from pythonjsonlogger import jsonlogger  # pip install python-json-logger

log_handler = logging.StreamHandler()
log_handler.setFormatter(jsonlogger.JsonFormatter())

logger = logging.getLogger()
logger.addHandler(log_handler)
logger.setLevel(logging.INFO)

# All logs auto-serialize to JSON; keys passed via extra become top-level fields.
# Inside an endpoint (request is the PredictionRequest); consider truncating user text.
logger.info("prediction_made", extra={
    "text": request.text,
    "model": "sentiment-v2",
    "latency_ms": 234,
})
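If you'd rather avoid the extra dependency, the same idea can be sketched with only the standard library: anything passed through extra= ends up as attributes on the LogRecord, so a formatter can copy the non-standard ones into the JSON object. JsonFormatter here is a hypothetical minimal class, not python-json-logger's implementation.

```python
import json
import logging
import sys
from datetime import datetime, timezone

# Attribute names every LogRecord has; anything else on a record came from extra=
_STANDARD_ATTRS = set(vars(logging.LogRecord("", 0, "", 0, "", None, None))) | {"message", "asctime"}


class JsonFormatter(logging.Formatter):
    """Minimal stdlib JSON formatter: one JSON object per log line."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        for key, value in vars(record).items():
            if key not in _STANDARD_ATTRS:
                payload[key] = value  # fields passed via extra={...}
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload, default=str)  # default=str: don't crash on odd types


handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(JsonFormatter())
logger = logging.getLogger("ml_api")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.info("prediction_made", extra={"model": "sentiment-v2", "latency_ms": 234})
```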
3. Health Checks
@app.get("/health")
async def health():
"""Liveness probe (is the service running?)"""
return {"status": "ok"}
@app.get("/ready")
async def readiness(registry: RegistryDep):
"""Readiness probe (is the service ready to serve?)"""
if not registry.get_model("sentiment-v2"):
raise HTTPException(503, "Model not loaded")
return {"status": "ready"}
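On Kubernetes, point the liveness probe at /health and the readiness probe at /ready. On ECS, the container health check and the ALB target-group health check play similar roles.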
4. Prometheus Metrics
from prometheus_fastapi_instrumentator import Instrumentator
Instrumentator().instrument(app).expose(app)
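Two lines give you, with the default configuration:
- Request counts labeled by handler, method, and status class (2xx, 4xx, 5xx), so error rates come for free
- Request latency histograms, from which Prometheus computes percentiles
- Request and response size metrics
- In-flight request gauges, if you opt in with Instrumentator(should_instrument_requests_inprogress=True)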
Scrape at /metrics for Prometheus.
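Prometheus doesn't store percentiles directly: latency is exported as cumulative histogram buckets, and percentiles are estimated at query time with histogram_quantile(), e.g. histogram_quantile(0.95, sum by (le) (rate(http_request_duration_seconds_bucket[5m]))) (check your /metrics output for the exact metric name). Here is a simplified stdlib sketch of the interpolation it performs inside a bucket; it skips Prometheus' handling of NaN and malformed buckets.

```python
import math
from typing import List, Tuple


def histogram_quantile(q: float, buckets: List[Tuple[float, float]]) -> float:
    """Estimate the q-quantile (0 <= q <= 1) from cumulative (upper_bound, count) buckets.

    Simplified version of Prometheus' approach: find the bucket containing the
    target rank, then interpolate linearly between its lower and upper bound.
    """
    buckets = sorted(buckets)
    if not buckets or buckets[-1][0] != math.inf:
        raise ValueError("the last bucket must have upper bound +Inf")
    total = buckets[-1][1]
    if total == 0:
        return math.nan
    rank = q * total
    prev_bound, prev_count = 0.0, 0.0  # latencies are non-negative, so start at 0
    for bound, count in buckets:
        if count >= rank:
            if bound == math.inf:
                return prev_bound  # can't interpolate into +Inf: report the last finite bound
            if count == prev_count:
                return bound
            return prev_bound + (bound - prev_bound) * (rank - prev_count) / (count - prev_count)
        prev_bound, prev_count = bound, count
    return prev_bound


# Cumulative counts: 50 requests <= 100 ms, 90 <= 500 ms, all 100 <= 1 s
latency_buckets = [(0.1, 50), (0.5, 90), (1.0, 100), (math.inf, 100)]
print(histogram_quantile(0.95, latency_buckets))
```

The estimate is only as precise as the bucket boundaries, so it pays to put boundaries near the latencies you care about (for example, around your SLO).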
5. Global Error Handlers
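from datetime import datetime, timezone

from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    return JSONResponse(
        status_code=422,
        content={
            # jsonable_encoder: Pydantic v2 errors can carry non-JSON values (e.g. exceptions in "ctx")
            "detail": jsonable_encoder(exc.errors()),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
    )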
Part 8: Docker and AWS Deployment
Your API is hardened. Now deploy it.
Dockerfile (multi-stage, optimized for models):
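# ─── Build stage ──────────────────────────────────────────────────
FROM python:3.11-slim AS builder
WORKDIR /build
RUN apt-get update && apt-get install -y --no-install-recommends build-essential
# Install deps into a virtualenv so the runtime stage can copy one self-contained directory
RUN python -m venv /opt/venv
ENV PATH=/opt/venv/bin:$PATH
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# ─── Runtime stage ────────────────────────────────────────────────
FROM python:3.11-slim
WORKDIR /app
# curl is not included in the slim image; the HEALTHCHECK below needs it
RUN apt-get update && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*
# Copy Python deps from builder (outside /root, so the non-root user can read them)
COPY --from=builder /opt/venv /opt/venv
ENV PATH=/opt/venv/bin:$PATH
# Copy code
COPY app/ ./app/
# Non-root user (security)
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser
# Health check (the start period gives the model time to load)
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s \
    CMD curl -f http://localhost:8000/health || exit 1
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]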
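An alternative to installing curl is a stdlib-only probe, since the slim image already ships Python. probe() and main() are hypothetical helpers; the exit-code contract (0 healthy, 1 unhealthy) is what Docker's HEALTHCHECK and ECS's CMD-SHELL health checks act on.

```python
import sys
import urllib.request


def probe(url: str, timeout: float = 5.0) -> int:
    """Return a HEALTHCHECK-style exit code: 0 if url answers 2xx, otherwise 1."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return 0 if 200 <= resp.status < 300 else 1
    except OSError:  # covers URLError/HTTPError (e.g. a 503), refused connections, timeouts
        return 1


def main(argv: list) -> None:
    url = argv[1] if len(argv) > 1 else "http://localhost:8000/health"
    sys.exit(probe(url))
```

Saved as a hypothetical app/healthcheck.py, the HEALTHCHECK could then run python -c "import sys; from app.healthcheck import main; main(sys.argv)" instead of curl.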
docker-compose.yml (dev with Postgres and Redis):
version: '3.9'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      # +asyncpg: SQLAlchemy's async engine needs an async driver
      - DATABASE_URL=postgresql+asyncpg://user:password@postgres:5432/ml_api
      - REDIS_URL=redis://redis:6379
    depends_on:
      - postgres
      - redis
    healthcheck:
      test: [ "CMD", "curl", "-f", "http://localhost:8000/ready" ]
      interval: 10s
      timeout: 5s
      retries: 3
      start_period: 40s
  postgres:
    image: postgres:16-alpine
    environment:
      POSTGRES_USER: user
      POSTGRES_PASSWORD: password
      POSTGRES_DB: ml_api
    volumes:
      - postgres_data:/var/lib/postgresql/data
  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"
volumes:
  postgres_data:
AWS ECS Deployment:
1. Push image to ECR:
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 123456789.dkr.ecr.us-east-1.amazonaws.com
docker build -t ml-api .
docker tag ml-api:latest 123456789.dkr.ecr.us-east-1.amazonaws.com/ml-api:latest
docker push 123456789.dkr.ecr.us-east-1.amazonaws.com/ml-api:latest
2. Create ECS task definition (Fargate needs a task execution role to pull the image from ECR; the ARN below is a placeholder):
{
  "family": "ml-api",
  "executionRoleArn": "arn:aws:iam::123456789:role/ecsTaskExecutionRole",
  "requiresCompatibilities": ["FARGATE"],
  "networkMode": "awsvpc",
  "cpu": "1024",
  "memory": "2048",
  "containerDefinitions": [
    {
      "name": "ml-api",
      "image": "123456789.dkr.ecr.us-east-1.amazonaws.com/ml-api:latest",
      "portMappings": [{ "containerPort": 8000 }],
      "healthCheck": {
        "command": ["CMD-SHELL", "curl -f http://localhost:8000/ready || exit 1"],
        "interval": 10,
        "timeout": 5,
        "retries": 3,
        "startPeriod": 60
      },
      "environment": [
        { "name": "REDIS_URL", "value": "redis://elasticache-endpoint:6379" }
      ]
    }
  ]
}
3. Create an ECS service behind an ALB target group, and enable auto-scaling based on CPU and memory. Note that Fargate offers no GPUs: either run the pipelines on CPU (device=-1 instead of device=0) or use ECS on EC2 GPU instances.
Wrapping Up
You've built a production ML API:
- Async inference without blocking, via run_in_executor()
- Streaming LLM responses in real time
- Dependency injection for clean, testable code
- Background jobs for expensive post-processing
- Health checks and metrics for observability
- Docker + ECS deployment on AWS
The gap between training a great model and serving it reliably is massive. FastAPI closes that gap.
FastAPI's async-first design, combined with Pydantic validation and automatic docs, makes it a strong default for an ML serving layer in 2026. If you're still serving models from synchronous Flask workers, you're leaving concurrency (and customer experience) on the table.
Next: If you haven't read it yet, check out my post on Building a Production RAG Pipeline with LangChain4j + Spring Boot to see how to orchestrate ML workflows at scale.
Ravi Kant Shukla
Senior Java + AI engineer. 9+ years in system design, Kafka, microservices, and LLM/RAG pipelines.