
Serving ML Models in Production with FastAPI: Async Inference, Streaming, and Deployment

FastAPI has become the go-to Python framework for serving ML models in production. Here's how to build async inference endpoints, stream LLM responses, and deploy them reliably on AWS.

May 25, 2026 · 20 min read
Tags: FastAPI · Machine Learning · Python · LangChain · AWS

Most ML teams spend 80% of their effort training and fine-tuning models, then ship a Flask app that melts under 50 concurrent requests. I've done this. We built a beautiful GPT-powered chatbot, trained it on company documents, and watched it crater on launch day because Flask's synchronous request handling can't cope with slow LLM inference calls.

The gap isn't in the model—it's in how you serve it.

That's where FastAPI comes in. FastAPI is built on ASGI (Asynchronous Server Gateway Interface) from the ground up. It can keep many requests in flight while each one waits on slow I/O, instead of tying up a worker per request the way a synchronous Flask app does. Add native streaming and automatic input validation via Pydantic, and you've got a framework well suited to ML serving.

By the end of this post, you'll have:

  • An async FastAPI endpoint that runs models without blocking the event loop
  • Streaming LLM responses as newline-delimited JSON (NDJSON) for real-time token output
  • A dependency injection system that loads your model once and reuses it across all requests
  • Background job queuing for expensive post-processing (batch embedding, DB writes)
  • Docker + AWS ECS deployment configuration ready for production

Part 1: Why FastAPI for ML Serving

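If you've built APIs in Java, you know that Spring Boot's default servlet container (Tomcat) serves each request on a thread taken from a pool. That thread stays tied up until the response is written. This is straightforward but costly at high concurrency: you pay memory per thread and context switching, and the pool size caps how many requests can be in progress. Spring later added WebFlux, a reactive stack that runs on a small number of event-loop threads (Netty by default). It is the same idea as ASGI with a different implementation.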

Flask and Django are traditionally deployed as WSGI apps, which assume synchronous request handling. Each request occupies its worker (a process or thread) until the response is sent. You scale by adding more worker processes or threads, and each one consumes memory.

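FastAPI uses ASGI, where a single event loop per worker process can hold thousands of concurrent connections. When an async endpoint awaits I/O (a call to an embedding service, a database query, a remote LLM API), the event loop switches to other requests while it waits. Blocking work does not yield on its own. Running a local model on the CPU or GPU is blocking work, so it has to be moved off the loop, which Part 3 covers. For I/O-heavy workloads, this avoids dedicating a thread or process to every waiting request.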
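A minimal sketch of the difference, using only asyncio (no FastAPI). Sleeps stand in for I/O and for a blocking model call, which is an assumption made for illustration. Five awaited 0.1 s waits overlap, while five blocking 0.1 s calls run back to back:

```python
import asyncio
import time


async def io_bound_request(delay: float) -> None:
    """Stand-in for awaiting I/O (an HTTP call, a DB query)."""
    await asyncio.sleep(delay)  # yields to the event loop while waiting


async def blocking_request(delay: float) -> None:
    """Stand-in for a blocking model call made directly inside async def."""
    time.sleep(delay)  # never yields: the whole event loop stalls


async def serve_concurrently(handler, n: int, delay: float) -> float:
    """Start n requests at once and return wall-clock seconds until all finish."""
    start = time.perf_counter()
    await asyncio.gather(*(handler(delay) for _ in range(n)))
    return time.perf_counter() - start


if __name__ == "__main__":
    awaited = asyncio.run(serve_concurrently(io_bound_request, 5, 0.1))
    blocked = asyncio.run(serve_concurrently(blocking_request, 5, 0.1))
    print(f"5 awaited waits: {awaited:.2f}s, 5 blocking calls: {blocked:.2f}s")
```

Expect roughly 0.1 s for the awaited version and 0.5 s for the blocking one. The event loop can only interleave requests at `await` points.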

Here's what the numbers look like:

| Framework | Req/sec (100 concurrent) | Memory (baseline) | Best for ML? |
|---|---|---|---|
| FastAPI | 15,200 | 50 MB | ✅ Yes |
| Flask | 3,800 | 45 MB | ❌ Prototype only |
| Django | 2,100 | 65 MB | ❌ No |
| Spring Boot (Tomcat) | 8,200 | 150 MB | ⚠️ Overkill for Python |

(Benchmarks from TechEmpower Round 21, simple JSON response. These numbers measure framework overhead only. For an ML endpoint, inference latency usually dominates real-world throughput.)
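Little's law shows why inference latency dominates. With a fixed number of requests in flight, throughput = concurrency / latency. The sketch below plugs in 100 concurrent clients and two illustrative latencies. The 1 ms figure is for a trivial JSON handler and the 200 ms figure is a hypothetical model inference time. Both are assumed figures, not measurements:

```python
def throughput_rps(concurrency: int, latency_ms: float) -> float:
    """Little's law rearranged: throughput = in-flight requests / time per request."""
    return concurrency * 1000 / latency_ms


# Illustrative numbers only (assumed, not benchmarked):
json_handler = throughput_rps(concurrency=100, latency_ms=1)     # 100000.0 req/s ceiling
model_handler = throughput_rps(concurrency=100, latency_ms=200)  # 500.0 req/s ceiling
print(json_handler, model_handler)
```

At 200 ms per request, 100 in-flight requests top out at 500 req/s regardless of the framework. To go higher you need more concurrency (workers, replicas, batching) or lower latency.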

Three features make FastAPI a strong fit for ML APIs:

1. Pydantic + Type Hints = Auto-Validation
Define your request schema once:

Python
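from pydantic import BaseModel, ConfigDict, Field

class PredictionRequest(BaseModel):
    # Pydantic v2 reserves the "model_" prefix; this silences the warning for model_name
    model_config = ConfigDict(protected_namespaces=())

    text: str = Field(..., min_length=1, max_length=2000)
    model_name: str = Field(default="gpt-3.5-turbo")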

FastAPI validates every request automatically. No more if "text" not in request.json() soup. Invalid requests return a structured 422 error with field-level feedback. Your client gets instant error messages, and your logs stay clean.

2. Auto-Generated OpenAPI Docs
Every endpoint becomes self-documenting. Ship your model API with /docs and /redoc endpoints—Swagger UI and ReDoc—completely free. Clients can test endpoints in the browser.

3. Native Streaming for LLM Output
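LLM inference is slow: a 7B-parameter model on a T4 GPU produces tokens one at a time. Flask can stream a generator too, but each stream ties up a synchronous worker for the whole generation. FastAPI's StreamingResponse sends tokens to the client as they are generated while the event loop keeps serving other requests. That can cut perceived latency from ~5 seconds (the full answer) to ~500 ms (the first tokens).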
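Streaming does not make generation faster. It changes when the client sees the first output. The asyncio sketch below assumes a hypothetical `fake_token_stream()` that emits 20 tokens at 20 ms each. It measures time to first token against time to the full response:

```python
import asyncio
import time
from typing import AsyncIterator, Tuple


async def fake_token_stream(n_tokens: int, seconds_per_token: float) -> AsyncIterator[str]:
    """Hypothetical stand-in for an LLM that produces one token at a time."""
    for i in range(n_tokens):
        await asyncio.sleep(seconds_per_token)  # simulated generation time per token
        yield f"tok{i} "


async def measure(n_tokens: int = 20, seconds_per_token: float = 0.02) -> Tuple[float, float]:
    """Return (time to first token, time to full response) in seconds."""
    start = time.perf_counter()
    first_token_at = None
    async for _token in fake_token_stream(n_tokens, seconds_per_token):
        if first_token_at is None:
            first_token_at = time.perf_counter() - start  # a streaming client could render now
    total = time.perf_counter() - start  # a buffering client would wait this long
    return first_token_at, total


if __name__ == "__main__":
    ttft, total = asyncio.run(measure())
    print(f"first token after {ttft:.2f}s, full response after {total:.2f}s")
```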


Part 2: Project Structure and Dependencies

Before writing code, let's scaffold the right project layout:

ml-api/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPI app, lifespan, routes
│   ├── models/
│   │   ├── schemas.py       # Pydantic request/response schemas
│   │   └── registry.py      # Model loader and singleton registry
│   ├── services/
│   │   ├── inference.py     # Model prediction logic
│   │   └── streaming.py     # LLM streaming callbacks
│   └── middleware/
│       ├── logging.py       # Structured JSON logging
│       └── exceptions.py    # Global error handlers
├── tests/
│   ├── test_inference.py
│   └── test_schemas.py
├── Dockerfile
├── docker-compose.yml
├── requirements.txt
├── .env.example
└── README.md

requirements.txt:

fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
pydantic-settings==2.1.0
langchain==0.1.0
langchain-openai==0.0.5
slowapi==0.1.9
prometheus-fastapi-instrumentator==6.1.0
sqlalchemy==2.0.23
psycopg2-binary==2.9.9
redis==5.0.1
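python-dotenv==1.0.0
# Also imported by code in this post; pin versions that match your environment
transformers
torch
celery
asyncpg
python-json-logger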

Part 3: Async Inference — The Right Way

Here's the mistake every newcomer makes:

Python
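# ❌ WRONG: async endpoint calling blocking inference directly
from fastapi import FastAPI
from transformers import pipeline

from app.models.schemas import PredictionRequest  # schema from Part 1

app = FastAPI()
classifier = pipeline("sentiment-analysis", device=0)  # device=0: first CUDA GPU

@app.post("/predict")
async def predict(request: PredictionRequest):
    result = classifier(request.text)[0]  # ← blocking call inside async def stalls the event loop
    return result

Under load, this endpoint serializes all concurrent requests. The handler is async def, so FastAPI runs it directly on the event loop, and no other request can make progress until inference returns. A plain def endpoint would avoid this, because FastAPI runs sync handlers in a worker threadpool. The trouble starts when blocking calls sit inside async def.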

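Here's the fix: keep the endpoint async and offload the blocking call to a thread pool with asyncio's loop.run_in_executor(). On Python 3.9+, asyncio.to_thread() is a shorthand for the same idea: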

Python
# ✅ CORRECT: Offload inference to executor
import asyncio
from fastapi import FastAPI, Depends
from contextlib import asynccontextmanager
from transformers import pipeline

from app.models.schemas import PredictionRequest  # schema from Part 1

class ModelRegistry:
    def __init__(self):
        self.model = None

    def load_models(self):
        """Load at startup, reuse forever."""
        self.model = pipeline("sentiment-analysis", device=0)

    async def predict(self, text: str):
        """Offload blocking inference to the default thread pool."""
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, self._sync_predict, text)
        return result

    def _sync_predict(self, text: str):
        """Actual model prediction, runs in thread pool."""
        return self.model(text)[0]

# ─── Lifespan: Load model at startup, close at shutdown ───────────
registry = ModelRegistry()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    registry.load_models()
    print("✓ Sentiment model loaded")
    yield
    # Shutdown
    print("Model unloaded")

app = FastAPI(lifespan=lifespan)

# ─── Dependency injection ──────────────────────────────────────────
async def get_registry() -> ModelRegistry:
    return registry

# ─── Endpoint ──────────────────────────────────────────────────────
@app.post("/predict")
async def predict(request: PredictionRequest, registry: ModelRegistry = Depends(get_registry)):
    result = await registry.predict(request.text)
    return {"text": request.text, "label": result["label"], "score": result["score"]}

Key points:

  1. Lifespan context manager (@asynccontextmanager): Runs once at app startup, loads the model once, reuses it forever. No model reloading per request.
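  2. run_in_executor(): Offloads the blocking self.model(...) call to a thread pool, freeing the event loop for other requests.
  3. Depends(): FastAPI's DI system injects the ModelRegistry into the endpoint. The registry is still a module-level object, but the endpoint no longer reaches for it directly, so tests can substitute a fake (via app.dependency_overrides).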
  4. async def endpoint: The endpoint itself is async, compatible with FastAPI's event loop.

Part 4: Streaming LLM Responses with LangChain

Now your simple sentiment classifier is working. Next challenge: LLM inference. A 7B-parameter Mistral model on a T4 GPU generates ~10 tokens/second, so waiting 8 seconds for an 80-token response is brutal.

Streaming changes the UX. The client sees tokens arrive in real-time: "The... answer... to... your... question..."

Here's how:

Python
# ─── Streaming callback for LangChain ──────────────────────────────
import asyncio
import json

from fastapi import Depends
from fastapi.responses import StreamingResponse
from langchain_core.callbacks import AsyncCallbackHandler
from pydantic import BaseModel

class ChatRequest(BaseModel):
    query: str

class AsyncStreamingCallbackHandler(AsyncCallbackHandler):
    def __init__(self, queue: asyncio.Queue):
        self.queue = queue

    async def on_llm_new_token(self, token: str, **kwargs):
        """Called each time the LLM emits a token (the LLM must have streaming enabled)."""
        await self.queue.put(json.dumps({"token": token}).encode() + b"\n")

async def generate_tokens(user_query: str, registry: ModelRegistry):
    """Stream LLM tokens as NDJSON via an async generator."""
    queue: asyncio.Queue = asyncio.Queue()
    streaming_handler = AsyncStreamingCallbackHandler(queue)

    # Run the LangChain chain as a background task
    async def run_chain():
        try:
            # get_llm_chain() is an app-specific registry method returning a runnable
            chain = registry.get_llm_chain()
            await chain.ainvoke({"input": user_query}, config={"callbacks": [streaming_handler]})
        finally:
            await queue.put(None)  # Sentinel: stream ends, even if the chain raised

    # Keep a reference: the event loop holds only weak references to tasks
    task = asyncio.create_task(run_chain())
    try:
        # Yield tokens as they arrive
        while True:
            item = await queue.get()
            if item is None:
                break
            yield item
        await task  # surface any exception raised by the chain
    finally:
        task.cancel()  # no-op if done; stops generation if the client disconnected

# ─── Endpoint: stream tokens as newline-delimited JSON ────────────
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest, registry: ModelRegistry = Depends(get_registry)):
    return StreamingResponse(
        generate_tokens(request.query, registry),
        media_type="application/x-ndjson",
        headers={"X-Content-Type-Options": "nosniff"}
    )
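The queue-plus-sentinel pattern is the part most worth getting right. If the producer raises before putting the sentinel, the consumer waits forever. Below is a library-free sketch of the same pattern. It assumes hypothetical `fake_chain()` and `failing_chain()` producers in place of the LangChain callback, uses try/finally to guarantee the stream ends, and re-raises producer errors in the consumer:

```python
import asyncio
from typing import AsyncIterator, Awaitable, Callable, List

_DONE = object()  # sentinel: marks the end of the stream


async def stream_from_producer(
    produce: Callable[[asyncio.Queue], Awaitable[None]],
) -> AsyncIterator[str]:
    """Yield items that `produce` puts on a queue; ends when the producer ends."""
    queue: asyncio.Queue = asyncio.Queue()
    errors: List[BaseException] = []

    async def runner() -> None:
        try:
            await produce(queue)
        except Exception as exc:  # hand producer failures to the consumer
            errors.append(exc)
        finally:
            await queue.put(_DONE)  # sent on success, failure, or cancellation

    task = asyncio.create_task(runner())  # keep a reference so it is not GC'd
    try:
        while True:
            item = await queue.get()
            if item is _DONE:
                break
            yield item
        if errors:
            raise errors[0]
    finally:
        task.cancel()  # no-op if finished; stops the producer if the client left early


async def fake_chain(queue: asyncio.Queue) -> None:
    """Hypothetical stand-in for an LLM whose callback pushes tokens."""
    for token in ["The", " answer", " is", " 42"]:
        await queue.put(token)


async def failing_chain(queue: asyncio.Queue) -> None:
    """Hypothetical producer that fails mid-stream."""
    await queue.put("partial")
    raise RuntimeError("LLM provider timed out")


async def collect(produce) -> List[str]:
    return [token async for token in stream_from_producer(produce)]


if __name__ == "__main__":
    print(asyncio.run(collect(fake_chain)))  # ['The', ' answer', ' is', ' 42']
```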

On the client side (JavaScript):

JavaScript
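async function chatStream(query, onToken) {
    const response = await fetch("/chat/stream", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ query })
    });
    if (!response.ok) throw new Error(`HTTP ${response.status}`);

    const reader = response.body.getReader();
    const decoder = new TextDecoder();
    let buffer = "";  // holds a partial line until its "\n" arrives

    while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        // stream: true keeps multi-byte characters split across chunks intact
        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split("\n");
        buffer = lines.pop();  // last element is an incomplete line (or "")

        for (const line of lines) {
            if (!line) continue;
            const { token } = JSON.parse(line);
            onToken(token);  // e.g. append to the DOM
        }
    }
}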
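A network read can end in the middle of a JSON line, or even in the middle of a multi-byte character. A client therefore has to buffer until it sees a newline rather than parsing each chunk on its own. Below is a Python sketch of the same buffering logic. You could feed it from an HTTP client's streamed body, for example httpx's `Response.iter_bytes()` inside `client.stream(...)`. The chunks in the usage line are hand-made to split a JSON object:

```python
import json
from typing import Iterable, Iterator


def iter_ndjson(chunks: Iterable[bytes]) -> Iterator[dict]:
    """Parse newline-delimited JSON from arbitrarily split byte chunks."""
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        *complete, buffer = buffer.split(b"\n")  # last piece may be a partial line
        for line in complete:
            if line.strip():
                yield json.loads(line)  # json.loads accepts UTF-8 bytes
    if buffer.strip():  # a final line with no trailing newline
        yield json.loads(buffer)


chunks = [b'{"token": "The"}\n{"tok', b'en": " answer"}\n']
print([obj["token"] for obj in iter_ndjson(chunks)])  # ['The', ' answer']
```

Splitting on the raw bytes before decoding is safe for UTF-8, because the newline byte never appears inside a multi-byte character.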



Part 5: Dependency Injection — The FastAPI Way

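If you've used Spring, you know @Component and @Autowired. FastAPI's Depends() is the closest Python equivalent, though it works differently. There is no container managing beans: FastAPI calls your dependency functions for each request and passes the results into the endpoint.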

Here's a full model registry using Depends():

Python
# ─── Model registry: load models once ──────────────────────────────
import asyncio
from typing import Annotated

from fastapi import Depends, HTTPException

from app.models.schemas import PredictionRequest

class ModelRegistry:
    _instance = None

    def __new__(cls):
        # Singleton: every ModelRegistry() call returns the same instance
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.models = {}
        return cls._instance

    def load_model(self, name: str, path: str):
        if name not in self.models:
            # load_transformer_model is a placeholder: load from disk, HuggingFace, or S3
            self.models[name] = load_transformer_model(path)
        return self.models[name]

    def get_model(self, name: str):
        return self.models.get(name)

    async def predict(self, model, text: str) -> dict:
        # Offload the blocking pipeline call to the default thread pool (see Part 3);
        # Hugging Face pipelines return a list, so take the first result
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: model(text)[0])

# ─── Dependency function ──────────────────────────────────────────
def get_registry() -> ModelRegistry:
    return ModelRegistry()

# ─── Type alias (FastAPI 0.95+ style) ─────────────────────────────
RegistryDep = Annotated[ModelRegistry, Depends(get_registry)]

# ─── Endpoint ─────────────────────────────────────────────────────
@app.post("/predict/{model_name}")
async def predict(
    request: PredictionRequest,
    model_name: str,
    registry: RegistryDep
):
    model = registry.get_model(model_name)
    if model is None:
        raise HTTPException(404, f"Model {model_name} not found")

    result = await registry.predict(model, request.text)
    return result

Benefits over global state:

  1. Testability: Inject a mock registry in unit tests
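  2. Lazy resolution: Dependencies are resolved per request, and only for endpoints that declare them
  3. Per-request caching and cleanup: A dependency used several times in one request is called once by default, and a dependency written with yield can run cleanup code (closing a DB session, say) when the request is done. App-wide singletons such as models belong in the lifespan handler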
  4. Chaining: One dependency can depend on another
Python
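# ─── Chain dependencies ────────────────────────────────────────────
from typing import Annotated, AsyncIterator

from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession

# SessionLocal (an async_sessionmaker), User, and fetch_current_user are
# app-specific placeholders defined elsewhere.

async def get_db() -> AsyncIterator[AsyncSession]:
    async with SessionLocal() as session:
        yield session  # the session is closed when FastAPI finishes with the request

async def get_user(db: Annotated[AsyncSession, Depends(get_db)]) -> User:
    # fetch_current_user should raise HTTPException(401) for unauthenticated callers
    return await fetch_current_user(db)

@app.post("/predict")
async def predict(
    request: PredictionRequest,
    user: Annotated[User, Depends(get_user)]
):
    # Only authenticated users reach here
    ...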
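Testability, the first benefit in the list above, comes from the endpoint receiving the registry as a parameter. In FastAPI you would typically swap it with `app.dependency_overrides[get_registry] = lambda: FakeRegistry(...)` and call the API through TestClient. The sketch below shows the underlying idea without FastAPI. It uses a hypothetical `predict_endpoint()` shaped like the Part 5 endpoint and a `FakeRegistry` test double:

```python
import asyncio
from typing import Callable, Dict


class FakeRegistry:
    """Test double with the same get_model()/predict() shape as ModelRegistry."""

    def __init__(self, models: Dict[str, Callable[[str], dict]]):
        self.models = models

    def get_model(self, name: str):
        return self.models.get(name)

    async def predict(self, model, text: str) -> dict:
        return model(text)  # no thread pool needed for a trivial fake


async def predict_endpoint(text: str, model_name: str, registry) -> dict:
    """Hypothetical endpoint body: the registry arrives as a parameter."""
    model = registry.get_model(model_name)
    if model is None:
        raise LookupError(f"Model {model_name} not found")  # HTTPException(404) in FastAPI
    return await registry.predict(model, text)


fake = FakeRegistry({"sentiment-v2": lambda text: {"label": "POSITIVE", "text": text}})
print(asyncio.run(predict_endpoint("great product", "sentiment-v2", fake)))
# {'label': 'POSITIVE', 'text': 'great product'}
```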

Part 6: Background Tasks and Job Queues

Some work doesn't fit in the request/response cycle. Save inference results to a data warehouse. Embed a document into a vector DB. Send a webhook to a third-party service.

FastAPI has BackgroundTasks for lightweight, fire-and-forget work that runs in the same process after the response is sent:

Python
from fastapi import BackgroundTasks

@app.post("/predict")
async def predict(
    request: PredictionRequest,
    background_tasks: BackgroundTasks,
    registry: RegistryDep
):
    model = registry.get_model("sentiment-v2")
    result = await registry.predict(model, request.text)

    # Log prediction to DB after the response is sent
    background_tasks.add_task(log_prediction, request.text, result)

    return result

async def log_prediction(text: str, result: dict):
    # SessionLocal and PredictionLog are app-specific placeholders (async session factory, ORM model)
    async with SessionLocal() as db:
        db.add(PredictionLog(text=text, label=result["label"]))
        await db.commit()

When to use BackgroundTasks vs Celery:

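| Scenario | BackgroundTasks | Celery |
|---|---|---|
| Save to DB (< 1s) | ✅ Yes | ❌ Overkill |
| Webhook call (< 5s) | ✅ Yes | ⚠️ Maybe |
| Batch embedding (30s–5m) | ❌ Lost if the process restarts | ✅ Yes |
| Video transcoding (hours) | ❌ No | ✅ Yes |
| Needs retry logic | ❌ No built-in retries | ✅ Yes |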

For heavy inference, use Celery with Redis:

Python
from celery import Celery

celery = Celery(broker="redis://localhost:6379/0")

@celery.task
def batch_embed_documents(doc_ids: list):
    """Heavy task, runs in a separate Celery worker process."""
    # load_embedding_model, fetch_doc, save_to_vector_db: app-specific placeholders
    embedder = load_embedding_model()
    for doc_id in doc_ids:
        embedding = embedder.embed(fetch_doc(doc_id))
        save_to_vector_db(doc_id, embedding)

@app.post("/embed-batch")
async def embed_batch(request: EmbedBatchRequest):  # EmbedBatchRequest: Pydantic model with doc_ids
    # .delay() only enqueues the job and returns an AsyncResult right away
    task = batch_embed_documents.delay(request.doc_ids)
    return {"job_id": task.id, "status": "queued"}
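As written, the task calls load_embedding_model() on every invocation, so each batch pays the model load cost. Celery worker processes are long-lived, so one common fix is to cache the model per process. The sketch below does this with `functools.lru_cache`. `get_embedder()` and its stand-in "embedding" are hypothetical, and the Celery decorator is omitted so the snippet runs on its own:

```python
from functools import lru_cache
from typing import Callable, List

load_count = 0  # how many times the (fake) model was actually loaded


@lru_cache(maxsize=1)
def get_embedder() -> Callable[[str], List[float]]:
    """Load the embedding model once per worker process, then reuse it.

    Hypothetical loader: the lambda stands in for a real embedding model.
    """
    global load_count
    load_count += 1
    return lambda text: [float(len(text))]  # fake 1-dimensional "embedding"


def batch_embed(docs: List[str]) -> List[List[float]]:
    """Body a Celery task could run (decorator omitted so this runs standalone)."""
    embedder = get_embedder()  # cached after the first call in this process
    return [embedder(doc) for doc in docs]


print(batch_embed(["ab", "abc"]), batch_embed(["x"]), load_count)  # [[2.0], [3.0]] [[1.0]] 1
```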

Part 7: Production Hardening

Your API works great in isolation. Under load, it needs:

1. Rate Limiting

Python
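from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

# Keys on client IP. Behind a load balancer this is the proxy's address unless
# uvicorn trusts X-Forwarded-For (--proxy-headers --forwarded-allow-ips)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

@app.post("/predict")
@limiter.limit("100/minute")
async def predict(request: Request, body: PredictionRequest, registry: RegistryDep):
    # slowapi needs the Starlette Request as a parameter named `request`
    ...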
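To make "100/minute" concrete, here is a simplified fixed-window counter in plain Python. It is not slowapi's implementation (slowapi builds on the limits library, which offers several strategies). It only shows the idea, with an injectable clock so the window is easy to control. Any in-process counter is also per worker. With several uvicorn workers or ECS tasks, each keeps its own count unless you configure shared storage such as Redis (slowapi's Limiter accepts a storage_uri):

```python
import time
from typing import Callable, Dict, Tuple


class FixedWindowLimiter:
    """At most `limit` hits per key in each `window_s`-second window (simplified)."""

    def __init__(self, limit: int, window_s: float,
                 clock: Callable[[], float] = time.monotonic):
        self.limit = limit
        self.window_s = window_s
        self.clock = clock  # injectable so tests can control time
        self._state: Dict[str, Tuple[int, int]] = {}  # key -> (window index, count)

    def hit(self, key: str) -> bool:
        """Record one request for `key`; False means it should get HTTP 429."""
        window = int(self.clock() // self.window_s)
        current_window, count = self._state.get(key, (window, 0))
        if current_window != window:
            count = 0  # a new window started: reset the counter
        if count >= self.limit:
            return False
        self._state[key] = (window, count + 1)
        return True


# "100/minute" per client IP would be FixedWindowLimiter(limit=100, window_s=60).
# A frozen clock keeps this demo deterministic:
limiter = FixedWindowLimiter(limit=3, window_s=60, clock=lambda: 0.0)
print([limiter.hit("203.0.113.7") for _ in range(4)])  # [True, True, True, False]
```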

2. Structured Logging

Python
import logging
from pythonjsonlogger import jsonlogger  # package: python-json-logger

logHandler = logging.StreamHandler()
formatter = jsonlogger.JsonFormatter()
logHandler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(logHandler)
logger.setLevel(logging.INFO)

# All logs serialize to JSON; call this from an endpoint after inference
def log_prediction_event(text: str, latency_ms: float) -> None:
    # Keys in `extra` become top-level fields of the JSON log line
    logger.info("prediction_made", extra={
        "text": text,
        "model": "sentiment-v2",
        "latency_ms": latency_ms
    })
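If you'd rather not add a dependency, a JSON formatter is a few lines of standard-library code. This is a minimal sketch, not a drop-in replacement for python-json-logger. It emits the timestamp, level, logger name, message, and any `extra=` fields as one JSON object per line:

```python
import json
import logging
from datetime import datetime, timezone

# Attribute names every LogRecord has; anything else was passed via `extra=`.
_STANDARD_ATTRS = set(vars(logging.makeLogRecord({}))) | {"message", "asctime"}


class JsonFormatter(logging.Formatter):
    """Minimal stdlib JSON formatter: one JSON object per log line."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        for key, value in vars(record).items():
            if key not in _STANDARD_ATTRS:
                payload[key] = value  # fields from extra={...}
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        return json.dumps(payload, default=str)  # default=str: don't crash on odd types


handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logger = logging.getLogger("ml-api")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("prediction_made", extra={"model": "sentiment-v2", "latency_ms": 234})
```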

3. Health Checks

Python
@app.get("/health")
async def health():
    """Liveness probe (is the service running?)"""
    return {"status": "ok"}

@app.get("/ready")
async def readiness(registry: RegistryDep):
    """Readiness probe (is the service ready to serve?)"""
    if not registry.get_model("sentiment-v2"):
        raise HTTPException(503, "Model not loaded")
    return {"status": "ready"}

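In Kubernetes, point liveness probes at /health and readiness probes at /ready. On ECS, the container health check in the task definition restarts unhealthy tasks, and the ALB target group health check decides which tasks receive traffic.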

4. Prometheus Metrics

Python
from prometheus_fastapi_instrumentator import Instrumentator

Instrumentator().instrument(app).expose(app)

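One call (plus the import) gives you:

  • Request counts by handler, method, and status code
  • Request latency histograms, from which Prometheus computes percentiles
  • Request and response sizes

An in-flight request gauge is opt-in via Instrumentator(should_instrument_requests_inprogress=True).

Scrape at /metrics for Prometheus.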
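Percentiles are not exported directly. The instrumentator records latency in histogram buckets, and Prometheus estimates quantiles at query time, for example `histogram_quantile(0.95, sum by (le) (rate(http_request_duration_seconds_bucket[5m])))`. The exact metric name depends on your instrumentator version and settings. Below is a simplified Python sketch of the linear interpolation that histogram_quantile() performs, using made-up bucket counts:

```python
import math
from typing import List, Tuple


def histogram_quantile(q: float, buckets: List[Tuple[float, float]]) -> float:
    """Estimate quantile q (0..1) from cumulative histogram buckets.

    `buckets` is [(upper_bound, cumulative_count), ...] sorted by bound and
    ending with (math.inf, total), like Prometheus `le` buckets.
    """
    total = buckets[-1][1]
    if total == 0:
        return math.nan
    rank = q * total  # how many observations lie at or below the answer
    prev_bound, prev_count = 0.0, 0.0
    for bound, count in buckets:
        if count >= rank:
            if math.isinf(bound):
                return prev_bound  # can't interpolate into +Inf: use the last finite bound
            if count == prev_count:
                return bound
            # assume observations are spread evenly inside this bucket
            return prev_bound + (bound - prev_bound) * (rank - prev_count) / (count - prev_count)
        prev_bound, prev_count = bound, count
    return prev_bound


# Made-up latency buckets in seconds: 50 requests <= 0.1 s, 80 <= 0.25 s, ...
latency = [(0.1, 50), (0.25, 80), (0.5, 95), (1.0, 99), (math.inf, 100)]
print(histogram_quantile(0.5, latency), histogram_quantile(0.9, latency))
```

The estimate is only as fine as the bucket boundaries. If your SLO is "p95 under 300 ms", a bucket edge near 0.3 s makes that percentile far more accurate.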

5. Global Error Handlers

Python
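from datetime import datetime, timezone

from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    return JSONResponse(
        status_code=422,
        content={
            # jsonable_encoder handles non-JSON values (e.g. exceptions) inside error contexts
            "detail": jsonable_encoder(exc.errors()),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
    )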

Part 8: Docker and AWS Deployment

Your API is hardened. Now deploy it.

Dockerfile (multi-stage, optimized for models):

Dockerfile
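# ─── Build stage ──────────────────────────────────────────────────
FROM python:3.11-slim AS builder

WORKDIR /build
RUN apt-get update && apt-get install -y --no-install-recommends build-essential \
    && rm -rf /var/lib/apt/lists/*
# Install deps into a virtualenv that can be copied wholesale into the runtime image
RUN python -m venv /opt/venv
ENV PATH=/opt/venv/bin:$PATH
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# ─── Runtime stage ────────────────────────────────────────────────
FROM python:3.11-slim

WORKDIR /app

# curl is not included in slim images; the HEALTHCHECK below needs it
RUN apt-get update && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*

# Copy Python deps from builder (a venv outside /root stays readable by appuser)
COPY --from=builder /opt/venv /opt/venv
ENV PATH=/opt/venv/bin:$PATH

# Copy code
COPY app/ ./app/

# Non-root user (security)
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser

# Health check (start-period gives the model time to load)
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s \
    CMD curl -f http://localhost:8000/health || exit 1

EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]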

docker-compose.yml (local dev with Postgres and Redis):

Yaml
version: '3.9'

services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - DATABASE_URL=postgresql+asyncpg://user:password@postgres:5432/ml_api
      - REDIS_URL=redis://redis:6379
    depends_on:
      - postgres
      - redis
    healthcheck:
      test: [ "CMD", "curl", "-f", "http://localhost:8000/ready" ]
      interval: 10s
      timeout: 5s
      retries: 3

  postgres:
    image: postgres:16-alpine
    environment:
      POSTGRES_USER: user
      POSTGRES_PASSWORD: password
      POSTGRES_DB: ml_api
    volumes:
      - postgres_data:/var/lib/postgresql/data

  redis:
    image: redis:7-alpine
    ports:
      - "6379:6379"

volumes:
  postgres_data:

AWS ECS Deployment:

  1. Push image to ECR:

    Shell
    aws ecr create-repository --repository-name ml-api --region us-east-1  # first push only
    aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 123456789.dkr.ecr.us-east-1.amazonaws.com
    docker build -t ml-api .
    docker tag ml-api:latest 123456789.dkr.ecr.us-east-1.amazonaws.com/ml-api:latest
    docker push 123456789.dkr.ecr.us-east-1.amazonaws.com/ml-api:latest
    
  2. Create ECS task definition:

    Json
    {
      "family": "ml-api",
      "executionRoleArn": "arn:aws:iam::123456789:role/ecsTaskExecutionRole",
      "containerDefinitions": [
        {
          "name": "ml-api",
          "image": "123456789.dkr.ecr.us-east-1.amazonaws.com/ml-api:latest",
          "portMappings": [{ "containerPort": 8000 }],
          "healthCheck": {
            "command": ["CMD-SHELL", "curl -f http://localhost:8000/ready || exit 1"],
            "interval": 10,
            "timeout": 5,
            "retries": 3,
            "startPeriod": 60
          },
          "environment": [
            { "name": "REDIS_URL", "value": "redis://elasticache-endpoint:6379" }
          ]
        }
      ],
      "requiresCompatibilities": ["FARGATE"],
      "cpu": "1024",
      "memory": "2048",
      "networkMode": "awsvpc"
    }
    
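  3. Create ECS service + ALB target group. Enable auto-scaling based on CPU and memory. Note that Fargate does not offer GPUs. For GPU inference (the device=0 pipelines earlier), run the service on the EC2 launch type with GPU instances.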


Wrapping Up

You've built a production ML API:

  • Async inference without blocking, via run_in_executor()
  • Streaming LLM responses in real-time
  • Dependency injection for clean, testable code
  • Background jobs for expensive post-processing
  • Health checks and metrics for observability
  • Docker + ECS deployment on AWS

The gap between training a great model and serving it reliably is massive. FastAPI closes that gap.

FastAPI's async-first design, combined with Pydantic validation and automatic docs, makes it a strong default for a Python ML serving layer in 2026. If you're still using Flask, you're leaving concurrency, and customer experience, on the table.

Next: If you haven't read it yet, check out my post on Building a Production RAG Pipeline with LangChain4j + Spring Boot to see how to orchestrate ML workflows at scale.




Ravi Kant Shukla

Senior Java + AI engineer. 9+ years in system design, Kafka, microservices, and LLM/RAG pipelines.
