Now that we've explored the essential patterns and considerations for designing scalable APIs for generative models, let's put this knowledge into practice. This section guides you through building a functional inference API for a diffusion model using FastAPI, a popular Python web framework known for its speed and ease of use. We will implement a basic asynchronous generation endpoint, demonstrating how to handle potentially long-running inference tasks without blocking the main server process.
This practical assumes you have a working Python environment (Python 3.10 or newer, since the code below uses the str | None union syntax in type hints) and that you are comfortable installing packages with pip. You should also have a foundational understanding of diffusion models and the diffusers library from Hugging Face, as well as basic familiarity with REST APIs.
First, install the necessary libraries:
pip install "fastapi[all]" diffusers transformers accelerate torch torchvision torchaudio Pillow
- fastapi[all]: Installs FastAPI and its common dependencies, including Uvicorn (an ASGI server) and Pydantic (for data validation).
- diffusers, transformers, accelerate: Hugging Face libraries for working with diffusion models.
- torch, torchvision, torchaudio: PyTorch core libraries. If you have a compatible NVIDIA GPU and CUDA installed, ensure you install the CUDA-enabled version of PyTorch following the instructions on the official PyTorch website. Otherwise, this will install the CPU version.
- Pillow: For image handling.
You will also need a pre-trained diffusion model. For this example, we'll use a Stable Diffusion model, but you can adapt the code for other diffusion models available in the diffusers library. Ensure you have sufficient disk space and memory (and ideally a GPU) to download and run the model.
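Before building the API, it can help to confirm that PyTorch actually sees your GPU. The short check below is a minimal sketch; the reported device name will depend on your hardware.
import torch

# Report which device the model in this practical will end up running on.
if torch.cuda.is_available():
    print(f"CUDA is available: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available; the model will run on the CPU, which is much slower.")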
Create a Python file, for example, api_server.py. Start by importing the necessary modules and initializing the FastAPI app and the diffusion model pipeline.
import io
import base64
import uuid
from fastapi import FastAPI, BackgroundTasks, HTTPException
from pydantic import BaseModel
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Application Setup ---
app = FastAPI(title="Diffusion Model Inference API")
# --- Model Loading ---
# Ensure you have credentials configured if using private models
# (e.g., huggingface-cli login)
model_id = "stabilityai/stable-diffusion-2-1-base" # Or another diffusion model
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = None
try:
logger.info(f"Loading model: {model_id} onto device: {device}")
# Load the pipeline
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
pipe = pipe.to(device)
logger.info("Model loaded successfully.")
# Optional: Add safety checker if needed/available for the model
# from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
# safety_checker = StableDiffusionSafetyChecker.from_pretrained(...)
# pipe.safety_checker = safety_checker
except Exception as e:
logger.error(f"Failed to load the diffusion model: {e}")
# Decide how to handle failure: exit, run without model, etc.
# For this example, we allow the app to start but generation will fail.
pipe = None
# --- In-memory storage for task results (Replace with persistent storage in production) ---
results_store = {}
In this setup:
- We create the FastAPI application.
- We set the model_id and determine the available device (GPU or CPU).
- We load the DiffusionPipeline. Loading the model at application startup is crucial for performance. Reloading the model for each request would introduce significant latency. We use torch.float16 on CUDA for faster inference and reduced memory usage. Error handling is included in case the model fails to load.
- An in-memory dictionary, results_store, is used to temporarily hold generation results. In a production system, you would replace this with a more robust solution like Redis, a database, or cloud storage.
FastAPI uses Pydantic models for request and response data validation and serialization. Let's define models for our image generation endpoint.
class GenerationRequest(BaseModel):
    prompt: str
    negative_prompt: str | None = None
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    seed: int | None = None  # For reproducibility

class TaskResponse(BaseModel):
    task_id: str
    status: str
    message: str | None = None

class ResultResponse(BaseModel):
    task_id: str
    status: str
    prompt: str
    image_base64: str | None = None  # Base64 encoded image
    error_message: str | None = None
- GenerationRequest: Defines the expected input JSON body for a generation request. It includes the prompt and optional parameters like negative_prompt, num_inference_steps, guidance_scale, and seed. Default values are provided for convenience.
- TaskResponse: Used to immediately reply to the client after accepting a task, providing a task_id for later status checks.
- ResultResponse: Defines the structure for retrieving the result of a generation task, including the status, the original prompt, the generated image (encoded as Base64), or an error message.
Diffusion model inference can take several seconds or even minutes, especially on CPUs or for high-resolution images. A synchronous API endpoint would block, potentially timing out or preventing the server from handling other requests. We'll use FastAPI's BackgroundTasks to run the inference process in the background.
First, define the function that performs the actual inference:
def run_diffusion_inference(task_id: str, req: GenerationRequest):
"""Runs the diffusion model inference in the background."""
global results_store
logger.info(f"Starting inference for task_id: {task_id}")
results_store[task_id] = {"status": "processing", "prompt": req.prompt}
if pipe is None:
logger.error(f"Model not loaded. Cannot process task_id: {task_id}")
results_store[task_id].update({"status": "failed", "error_message": "Model not available."})
return
try:
# Set seed for reproducibility if provided
generator = None
if req.seed is not None:
generator = torch.Generator(device=device).manual_seed(req.seed)
# Perform inference
with torch.inference_mode(): # More memory efficient than torch.no_grad()
image: Image.Image = pipe(
prompt=req.prompt,
negative_prompt=req.negative_prompt,
num_inference_steps=req.num_inference_steps,
guidance_scale=req.guidance_scale,
generator=generator
).images[0] # Get the first image from the result list
# Convert image to Base64
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
# Store result
results_store[task_id].update({"status": "completed", "image_base64": img_str})
logger.info(f"Inference completed successfully for task_id: {task_id}")
except Exception as e:
logger.exception(f"Inference failed for task_id: {task_id}") # Log full traceback
results_store[task_id].update({"status": "failed", "error_message": str(e)})
This function:
- Receives the task_id and the GenerationRequest data.
- Updates results_store to mark the task as "processing".
- Checks that the pipe was loaded successfully, failing the task if not.
- Runs the pipeline inside torch.inference_mode() for efficiency.
- Updates results_store with the final status ("completed" or "failed") and the result (image or error message).
Now, define the FastAPI endpoints: one to submit generation tasks and another to check their status and retrieve results.
@app.post("/generate", response_model=TaskResponse, status_code=202)
async def submit_generation_task(req: GenerationRequest, background_tasks: BackgroundTasks):
"""
Accepts a generation request, adds it to the background queue,
and returns a task ID.
"""
if pipe is None:
raise HTTPException(status_code=503, detail="Model is not available or failed to load.")
task_id = str(uuid.uuid4())
results_store[task_id] = {"status": "pending", "prompt": req.prompt} # Initial status
logger.info(f"Received generation request. Assigning task_id: {task_id}")
background_tasks.add_task(run_diffusion_inference, task_id, req)
return TaskResponse(task_id=task_id, status="pending", message="Task received and queued for processing.")
@app.get("/results/{task_id}", response_model=ResultResponse)
async def get_generation_result(task_id: str):
"""
Retrieves the status and result (if available) for a given task ID.
"""
result = results_store.get(task_id)
if not result:
raise HTTPException(status_code=404, detail="Task ID not found.")
return ResultResponse(
task_id=task_id,
status=result.get("status", "unknown"),
prompt=result.get("prompt", ""),
image_base64=result.get("image_base64"),
error_message=result.get("error_message")
)
@app.get("/health")
async def health_check():
"""Basic health check endpoint."""
# Could add more checks here (e.g., model responsiveness)
model_status = "available" if pipe is not None else "unavailable"
return {"status": "ok", "model_status": model_status}
/generate (POST):
- Accepts a GenerationRequest in the request body.
- Receives the BackgroundTasks object provided by FastAPI.
- Generates a unique task_id using uuid.
- Records an initial "pending" entry in results_store.
- Adds the run_diffusion_inference function to the background tasks queue using background_tasks.add_task(), passing the task_id and request data.
- Returns a 202 Accepted status code along with the TaskResponse, indicating the task has been accepted but not completed.
- Returns 503 Service Unavailable if the model isn't loaded.
/results/{task_id} (GET):
- Takes the task_id as a path parameter.
- Looks up the task in results_store.
- Returns 404 Not Found if the ID doesn't exist.
- Otherwise returns the current state as a ResultResponse. The client needs to poll this endpoint until the status is "completed" or "failed" (a minimal polling client is sketched after the diagram below).
/health (GET):
- A basic health check that reports whether the server is running and whether the model is available.
Diagram: Asynchronous inference request flow using FastAPI BackgroundTasks.
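As a reference for the polling pattern described above, here is a minimal client sketch. It assumes the requests library is installed (it is not part of the install command earlier) and that the API is running locally on port 8000; the poll interval and timeout values are arbitrary choices.
import time
import requests

BASE_URL = "http://localhost:8000"  # assumes the API from this practical is running locally

def generate_and_wait(prompt: str, poll_interval: float = 2.0, timeout: float = 300.0) -> dict:
    """Submit a generation request, then poll /results until it completes or fails."""
    response = requests.post(f"{BASE_URL}/generate", json={"prompt": prompt})
    response.raise_for_status()
    task_id = response.json()["task_id"]

    deadline = time.time() + timeout
    while time.time() < deadline:
        result = requests.get(f"{BASE_URL}/results/{task_id}").json()
        if result["status"] in ("completed", "failed"):
            return result
        time.sleep(poll_interval)  # wait before asking again
    raise TimeoutError(f"Task {task_id} did not finish within {timeout} seconds")

if __name__ == "__main__":
    outcome = generate_and_wait("A photo of an astronaut riding a horse on the moon")
    print(f"Task finished with status: {outcome['status']}")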
Save the code as api_server.py. Run the API using Uvicorn from your terminal:
uvicorn api_server:app --host 0.0.0.0 --port 8000 --reload
- api_server:app: Tells Uvicorn where to find the FastAPI application instance (app) inside the api_server.py file.
- --host 0.0.0.0: Makes the server accessible from other machines on your network.
- --port 8000: Specifies the port to run on.
- --reload: Automatically restarts the server when code changes are detected (useful during development).
Testing with curl:
Submit a generation task:
curl -X POST "http://localhost:8000/generate" \
-H "Content-Type: application/json" \
-d '{
"prompt": "A photo of an astronaut riding a horse on the moon",
"num_inference_steps": 25,
"guidance_scale": 9.0
}'
This should return a JSON response like:
{
"task_id": "some-unique-uuid-string",
"status": "pending",
"message": "Task received and queued for processing."
}
Check the result (replace some-unique-uuid-string with the actual ID received):
curl -X GET "http://localhost:8000/results/some-unique-uuid-string"
Initially, this might return:
{
"task_id": "some-unique-uuid-string",
"status": "processing",
"prompt": "A photo of an astronaut riding a horse on the moon",
"image_base64": null,
"error_message": null
}
After a short while (depending on your hardware and inference steps), polling again should return the completed result:
{
"task_id": "some-unique-uuid-string",
"status": "completed",
"prompt": "A photo of an astronaut riding a horse on the moon",
"image_base64": "iVBORw0KGgoAAAANSUhEUgAA...",
"error_message": null
}
(The image_base64 string will be very long.) You can copy this string and use an online Base64-to-image converter or a simple script, such as the one below, to view the generated image.
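Here is a small sketch of such a script. It assumes the JSON response from /results/{task_id} has been saved to a file named result.json (a hypothetical filename) and writes the decoded image to output.png.
import base64
import json

# Load the saved /results/{task_id} response (hypothetical filename).
with open("result.json", "r") as f:
    result = json.load(f)

# Decode the Base64 payload and write it out as a PNG file.
image_bytes = base64.b64decode(result["image_base64"])
with open("output.png", "wb") as f:
    f.write(image_bytes)

print("Saved generated image to output.png")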
This practical provides a fundamental structure. Building upon this, you can integrate concepts discussed earlier in the chapter:
- Replace BackgroundTasks with a dedicated message queue system (like Celery with RabbitMQ/Redis or AWS SQS) for better scalability, persistence, and separation of concerns between the API server and the inference workers.
- Add rate limiting (for example, with slowapi) to the /generate endpoint to protect your service from abuse or overload.
- Store generated images in a database or cloud storage and return a reference to them in the ResultResponse. Update the results_store to use a persistent database or cache like Redis; a minimal sketch of this appears at the end of this section.
This hands-on exercise demonstrates the core components of building an asynchronous API for diffusion models. By applying these principles and integrating further scaling techniques, you can create robust and efficient inference services ready for production workloads.
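As one possible take on the last point above, here is a minimal sketch of swapping the in-memory dictionary for Redis. It assumes a Redis server is reachable at localhost:6379 and that the redis package (not listed in the install command earlier) is installed; the key prefix and one-hour expiry are arbitrary choices.
import json
import redis

# Connect to a local Redis instance; adjust host/port/db for your deployment.
redis_client = redis.Redis(host="localhost", port=6379, db=0)

def store_result(task_id: str, payload: dict) -> None:
    """Persist a task's state as JSON, expiring after one hour."""
    redis_client.set(f"task:{task_id}", json.dumps(payload), ex=3600)

def load_result(task_id: str) -> dict | None:
    """Fetch a task's state, or None if the task ID is unknown or has expired."""
    raw = redis_client.get(f"task:{task_id}")
    return json.loads(raw) if raw else None

# These helpers could replace the direct reads and writes to results_store
# in run_diffusion_inference and in the /generate and /results endpoints.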