Custom Middleware
Middleware wraps every HTTP request and response. It runs before your route handler and can run code after it too. Use middleware for cross-cutting concerns: request logging, timing, adding headers, and error correlation.
Learning Focus
By the end of this lesson you can: write @app.middleware("http") functions, create reusable ASGI middleware classes, inject request IDs, and add response timing headers.
@app.middleware("http") Pattern
app/middleware/logging.py
import time
import logging
from fastapi import Request, Response

logger = logging.getLogger("api.access")

async def logging_middleware(request: Request, call_next) -> Response:
    start = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - start
    logger.info(
        "%s %s %d %.3fs",
        request.method,
        request.url.path,
        response.status_code,
        elapsed,
    )
    response.headers["X-Response-Time"] = f"{elapsed:.4f}s"
    return response
app/main.py
from fastapi import FastAPI
from app.middleware.logging import logging_middleware
app = FastAPI()
app.middleware("http")(logging_middleware)
Request ID Middleware
app/middleware/request_id.py
import uuid
from fastapi import Request, Response

async def request_id_middleware(request: Request, call_next) -> Response:
    request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
    # Attach to request state so handlers can access it
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response
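Storing the ID on request.state makes it available to handlers, but log calls deeper in the call stack usually have no request object at hand. One common approach, sketched here with only the standard library, is to also keep the ID in a context variable and copy it onto each log record with a logging filter. The names request_id_var, RequestIdFilter, and configure_logging are hypothetical, and the sketch assumes the middleware calls request_id_var.set(request_id) before call_next and that the context propagates to your handlers, which is worth verifying in your setup.

```python
import contextvars
import logging

# Hypothetical context variable; "-" is used when no request is in flight.
request_id_var = contextvars.ContextVar("request_id", default="-")


class RequestIdFilter(logging.Filter):
    """Copy the current request ID onto every log record."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.request_id = request_id_var.get()
        return True  # never drop the record, only annotate it


def configure_logging() -> logging.Logger:
    """Attach a handler whose format includes the request ID."""
    handler = logging.StreamHandler()
    handler.addFilter(RequestIdFilter())
    handler.setFormatter(
        logging.Formatter("%(levelname)s [%(request_id)s] %(message)s")
    )
    logger = logging.getLogger("api.access")
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger
```

With this in place, any logger.info(...) call made while a request is being handled would carry the same ID that the middleware returns in X-Request-ID.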
ASGI Middleware Class
For a reusable middleware that can be packaged and shared across apps, subclass Starlette's BaseHTTPMiddleware and override dispatch:
app/middleware/timing.py
import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

class TimingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next) -> Response:
        start = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = (time.perf_counter() - start) * 1000
        response.headers["X-Process-Time-Ms"] = f"{elapsed_ms:.2f}"
        return response
app/main.py
from app.middleware.timing import TimingMiddleware
app.add_middleware(TimingMiddleware)
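BaseHTTPMiddleware is a convenience layer. Underneath, an ASGI middleware is any callable that accepts (scope, receive, send) and forwards to the next application. The sketch below shows the same timing idea written against that raw interface, with no framework imports. The class name PureTimingMiddleware is hypothetical, and the code is a minimal illustration rather than a drop-in replacement.

```python
import time


class PureTimingMiddleware:
    """Pure ASGI middleware: adds a timing header without BaseHTTPMiddleware."""

    def __init__(self, app):
        self.app = app  # the next ASGI application in the chain

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            # Pass WebSocket and lifespan traffic through untouched.
            await self.app(scope, receive, send)
            return

        start = time.perf_counter()

        async def send_with_timing(message):
            # Response headers travel in the "http.response.start" message
            # as a list of (name, value) byte pairs.
            if message["type"] == "http.response.start":
                elapsed_ms = (time.perf_counter() - start) * 1000
                headers = list(message.get("headers", []))
                headers.append(
                    (b"x-process-time-ms", f"{elapsed_ms:.2f}".encode())
                )
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_with_timing)
```

A class shaped like this can be registered the same way, with app.add_middleware(PureTimingMiddleware). Note that it measures time until the response starts, not until the last body chunk is sent.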
Middleware Execution Order
Middleware added last runs outermost (first to see the request, last to see the response):
app/main.py
app.middleware("http")(logging_middleware)     # Added first: runs innermost
app.middleware("http")(request_id_middleware)  # Added second: runs in the middle
app.add_middleware(TimingMiddleware)           # Added last: runs outermost
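The "last added is outermost" rule is easier to see in a toy model. The following sketch uses plain functions instead of real middleware. make_layer and build_stack are hypothetical helpers that only mimic the wrapping behavior described above, not Starlette's actual implementation.

```python
def make_layer(name, log):
    """Return a wrapper that records when it sees the request and response."""
    def wrap(inner):
        def handler(request):
            log.append(f"{name}: request")
            response = inner(request)
            log.append(f"{name}: response")
            return response
        return handler
    return wrap


def build_stack(layers, endpoint):
    """Wrap endpoint so that the last layer in `layers` ends up outermost."""
    stack = endpoint
    for wrap in layers:  # each newly added layer wraps everything before it
        stack = wrap(stack)
    return stack


log = []
layers = [
    make_layer("logging", log),     # added first: innermost
    make_layer("request_id", log),  # added second
    make_layer("timing", log),      # added last: outermost
]
stack = build_stack(layers, lambda request: "ok")
stack("GET /")
print("\n".join(log))
# → timing: request
# → request_id: request
# → logging: request
# → logging: response
# → request_id: response
# → timing: response
```

The request travels inward through the layers and the response travels back out in reverse, which is why the outermost layer both starts first and finishes last.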
Conditional Middleware
Skip middleware for specific paths (e.g., health check):
app/middleware/request_id.py
SKIP_PATHS = {"/health", "/metrics", "/favicon.ico"}

async def request_id_middleware(request: Request, call_next) -> Response:
    if request.url.path in SKIP_PATHS:
        return await call_next(request)
    request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response
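Exact set membership does not match a trailing slash ("/health/") or a whole subtree such as static assets. If that matters in your app, the check can be moved into a small helper, sketched below. The should_skip function and the SKIP_PREFIXES values are hypothetical and should be adjusted to your own routes.

```python
SKIP_PATHS = {"/health", "/metrics", "/favicon.ico"}
SKIP_PREFIXES = ("/static/", "/internal/")  # hypothetical prefixes


def should_skip(path: str) -> bool:
    """Return True when middleware should pass the request straight through."""
    normalized = path.rstrip("/") or "/"  # treat "/health/" like "/health"
    return normalized in SKIP_PATHS or path.startswith(SKIP_PREFIXES)
```

The middleware would then call should_skip(request.url.path) in place of the direct set lookup.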
Accessing Middleware Data in Routes
app/routers/items.py
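from fastapi import APIRouter, Request

router = APIRouter()

@router.get("/")
async def list_items(request: Request) -> dict:
    # Fall back to "unknown" when the middleware was skipped or not installed
    request_id = getattr(request.state, "request_id", "unknown")
    return {"request_id": request_id, "items": []}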
Common Pitfalls
| Pitfall | Cause / Symptom | Fix |
|---|---|---|
| Unhandled exceptions skip middleware post-processing | For unhandled exceptions, call_next raises instead of returning a 5xx response (errors with a registered handler, such as HTTPException, still come back as normal responses) | Wrap call_next in try/except inside the middleware |
| Request body consumed in middleware | The body is a stream that can be read only once | Read the body in middleware only when necessary, and use request.body() carefully |
| Middleware order wrong | Inner middleware never sees response headers added by outer middleware, and outer middleware runs before inner middleware has touched the request | Remember: last added = outermost |
| Performance overhead | Too many middleware layers | Profile with time.perf_counter(), minimize layers |
| State shared between requests | Module-level variables in middleware are shared by all concurrent requests | Use request.state for per-request data |
Hands-On Practice
app/middleware/security_headers.py
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "X-Frame-Options": "DENY",
    # Legacy header: deprecated and ignored by current browsers
    "X-XSS-Protection": "1; mode=block",
    "Referrer-Policy": "strict-origin-when-cross-origin",
    "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}

class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        # Apply every security header to the outgoing response
        for header, value in SECURITY_HEADERS.items():
            response.headers[header] = value
        return response
app/main.py
app.add_middleware(SecurityHeadersMiddleware)
test-middleware.sh
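# Terminal 1: start the server (this command blocks)
uvicorn app.main:app --reload
# Terminal 2: inspect the response headers. The match is case-insensitive
# because Starlette emits header names in lowercase.
curl -sv http://localhost:8000/ 2>&1 | grep -i "< x-"
# → < x-request-id: <uuid>
# → < x-response-time: 0.0012s
# → < x-content-type-options: nosniff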