Once you have trained and saved a machine learning model, the next step towards making it useful is to create an interface so that other applications or users can send data to it and receive predictions. A common and effective way to achieve this is by wrapping the model in a web Application Programming Interface (API), specifically a REST (Representational State Transfer) API.
A REST API uses standard HTTP methods (like GET, POST, PUT, DELETE) to allow communication between a client (which requests a prediction) and a server (which hosts the model and computes the prediction). For model prediction, the client typically sends new input data via an HTTP POST request to a specific URL endpoint on the server. The server processes this data using the loaded model and sends the prediction back in the HTTP response, often formatted as JSON.
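For example, a client might send this JSON body in a POST request to a /predict endpoint (the feature names here are illustrative and match the schema we define later in this section):

{"feature1": 5.1, "feature2": 3.5, "feature3": 1.4, "feature4": 0.2}

and receive a JSON response such as:

{"predicted_class": 0, "probability": 0.97}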
This approach offers several advantages: it decouples the model from client applications, lets any client that speaks HTTP request predictions regardless of the language it is written in, and allows the serving layer to be scaled independently of those clients.
While you could build a web server from scratch using Python's built-in http.server module, using a web framework significantly simplifies development. Frameworks handle the complexities of HTTP parsing, request routing, and response generation. For Python, two popular choices for building model APIs are Flask and FastAPI.
For building robust model serving APIs, FastAPI is often preferred due to its built-in data validation and automatic documentation features, which help ensure correct usage and simplify integration. We will use FastAPI in our examples.
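To see how little boilerplate the framework requires, here is a minimal FastAPI application (a standalone sketch, separate from the serving app we build below):

# hello.py - minimal FastAPI app, shown only to illustrate the framework basics
from fastapi import FastAPI

app = FastAPI()

@app.get("/ping")
async def ping():
    # A trivial endpoint that returns a JSON health-style response
    return {"status": "ok"}

Running this with uvicorn hello:app and visiting /ping returns {"status": "ok"}.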
Before building the API, ensure you have the necessary libraries installed. You'll typically need the framework itself, an ASGI server such as Uvicorn to run FastAPI, the library used to train your model (e.g., scikit-learn), and the library used for saving and loading it (e.g., joblib).
pip install fastapi uvicorn scikit-learn joblib pydantic pandas
In your API application code, you first need to load the serialized model you saved in the previous step. It's important to load the model once when the API server starts up, rather than reloading it for every incoming prediction request. Loading the model can be time-consuming, and doing it repeatedly would introduce significant latency.
# main.py
import os

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Define the path to the saved model file
MODEL_DIR = os.environ.get("MODEL_DIR", ".")  # Use environment variable or current dir
MODEL_PATH = os.path.join(MODEL_DIR, "model.joblib")

# Load the model once during application startup
try:
    model = joblib.load(MODEL_PATH)
    print(f"Model loaded successfully from {MODEL_PATH}")
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}")
    # Handle the error appropriately - maybe exit or use a default/dummy model
    model = None  # Or raise an exception
except Exception as e:
    print(f"Error loading model: {e}")
    model = None

# Initialize the FastAPI app
app = FastAPI(title="Model Prediction API", version="1.0")
The code attempts to load a model named model.joblib from a directory specified by the MODEL_DIR environment variable, defaulting to the current directory. Error handling is included for cases where the file doesn't exist or another loading error occurs.
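If you don't have a serialized model from the previous step at hand, a minimal sketch like the following produces a compatible model.joblib. It assumes a toy classifier trained on four numerical features named to match the API schema defined below; your real training pipeline will differ:

# train_model.py - toy training script (a sketch, not a real pipeline)
import joblib
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# Use the Iris dataset, which happens to have four numerical features,
# and rename its columns to match the API's expected feature names
data = load_iris()
X = pd.DataFrame(data.data, columns=["feature1", "feature2", "feature3", "feature4"])
y = data.target

# Train a simple classifier and serialize it next to main.py
model = LogisticRegression(max_iter=1000)
model.fit(X, y)
joblib.dump(model, "model.joblib")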
A significant advantage of FastAPI is its use of Pydantic for data validation. By defining Pydantic models (classes inheriting from BaseModel), you specify the expected structure and data types for incoming requests and outgoing responses. FastAPI uses these models to automatically parse and validate JSON request bodies and serialize Python objects back to JSON responses.
Let's assume our model predicts a class label based on four numerical features. We can define schemas like this:
# main.py (continued)
class InputFeatures(BaseModel):
    """Defines the structure for input data"""
    feature1: float
    feature2: float
    feature3: float
    feature4: float

    # Example: enforce constraints if needed (Pydantic v2 syntax;
    # requires: from pydantic import field_validator)
    # @field_validator('feature1')
    # @classmethod
    # def feature1_must_be_positive(cls, v):
    #     if v <= 0:
    #         raise ValueError('feature1 must be positive')
    #     return v

class PredictionOut(BaseModel):
    """Defines the structure for the prediction output"""
    predicted_class: int  # Or float for regression, str for class names
    probability: float | None = None  # Optional: include prediction probability
InputFeatures expects a JSON object with four keys (feature1 to feature4), each holding a floating-point value. PredictionOut defines the structure of the JSON response, containing the predicted class and an optional probability.
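One benefit of these schemas is that invalid input never reaches your model. If a client omits a field or sends the wrong type, FastAPI rejects the request with a 422 status code and a JSON error body along these lines (the exact wording varies with the Pydantic version):

{
  "detail": [
    {
      "type": "missing",
      "loc": ["body", "feature4"],
      "msg": "Field required"
    }
  ]
}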
Now, we define the API endpoint that will receive data and return predictions. We use FastAPI's decorator syntax (@app.post(...)) to associate a URL path (e.g., /predict) and an HTTP method (POST) with a Python function.
# main.py (continued)
@app.get("/")
async def read_root():
    """Root endpoint providing basic API information."""
    return {"message": "Welcome to the Model Prediction API. Use the /predict endpoint."}

@app.post("/predict", response_model=PredictionOut)
async def predict(features: InputFeatures):
    """
    Receives input features via POST request, makes a prediction,
    and returns the predicted class.
    """
    if model is None:
        # Handle case where the model failed to load
        raise HTTPException(status_code=503, detail="Model is not available")

    # Convert input data into the format expected by the model.
    # scikit-learn models typically expect a 2D array-like structure,
    # and the order of columns MUST match the order used during training!
    input_df = pd.DataFrame([features.model_dump()])  # Pydantic v2 uses model_dump()
    # Ensure column order if your model is sensitive to it:
    # feature_order = ['feature1', 'feature2', 'feature3', 'feature4']
    # input_df = input_df[feature_order]

    try:
        # Make prediction
        prediction_result = model.predict(input_df)
        predicted_class = int(prediction_result[0])  # predict returns an array

        # Optional: get the prediction probability if the model supports it
        # (e.g., logistic regression, trees)
        prediction_proba = None
        if hasattr(model, "predict_proba"):
            probabilities = model.predict_proba(input_df)
            # Probability of the predicted class (assumes integer class labels
            # 0..n-1 that index the columns of predict_proba's output)
            prediction_proba = float(probabilities[0, predicted_class])

        # Return the prediction formatted according to the PredictionOut schema
        return PredictionOut(predicted_class=predicted_class, probability=prediction_proba)
    except Exception as e:
        # Handle potential errors during prediction; log the error for debugging
        print(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail="Prediction failed")
This code defines a /predict endpoint that accepts POST requests. The function predict takes an argument features type-hinted with our InputFeatures Pydantic model. FastAPI automatically validates the incoming JSON body against this schema. Inside the function, the input data is converted to a Pandas DataFrame (a common input format for scikit-learn models), the model's predict method is called, and the result is returned, formatted according to the PredictionOut schema. Basic error handling for model loading and prediction execution is included.
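One detail worth knowing: model.predict is CPU-bound, and running it inside an async def endpoint blocks the event loop while it computes. FastAPI runs plain def endpoints in a worker threadpool instead, so for heavier models a synchronous variant can keep the server responsive (a sketch; the body is unchanged):

# Variant: a plain def endpoint is executed in FastAPI's threadpool,
# so a slow model.predict call does not block other requests
@app.post("/predict", response_model=PredictionOut)
def predict(features: InputFeatures):
    ...  # same body as above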
To run this FastAPI application, you use an ASGI server like Uvicorn from your terminal:
uvicorn main:app --reload --host 0.0.0.0 --port 8000
- main: refers to the Python file main.py.
- app: refers to the FastAPI instance created inside main.py (e.g., app = FastAPI()).
- --reload: enables auto-reloading when code changes (useful for development).
- --host 0.0.0.0: makes the server accessible from other machines on the network (use 127.0.0.1 for local access only).
- --port 8000: specifies the port number the server listens on.

Once the server is running, you can access the automatically generated interactive documentation by navigating your web browser to http://localhost:8000/docs (Swagger UI) or http://localhost:8000/redoc (ReDoc).
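FastAPI also serves the underlying OpenAPI schema as JSON at http://localhost:8000/openapi.json, which can be fed to client-code generators or API gateways.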
You can test your running API endpoint using various tools:

- FastAPI Docs: the interactive /docs interface allows you to send test requests directly from your browser.
- curl: a command-line tool for making HTTP requests. For example:
curl -X POST "http://localhost:8000/predict" \
     -H "Content-Type: application/json" \
     -d '{
           "feature1": 5.1,
           "feature2": 3.5,
           "feature3": 1.4,
           "feature4": 0.2
         }'
- Python requests library: make the same request from Python code:
import requests

api_url = "http://localhost:8000/predict"
input_data = {
    "feature1": 5.1,
    "feature2": 3.5,
    "feature3": 1.4,
    "feature4": 0.2
}

response = requests.post(api_url, json=input_data)

if response.status_code == 200:
    prediction = response.json()
    print(f"Prediction response: {prediction}")
else:
    print(f"Error: {response.status_code}")
    print(response.text)
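A successful request from any of these clients returns a 200 response whose body follows the PredictionOut schema, e.g. {"predicted_class": 0, "probability": 0.97} (the exact values depend on your model). In client code headed for production, it is also prudent to pass a timeout, e.g. requests.post(api_url, json=input_data, timeout=5), so a stalled server doesn't hang the caller.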
The typical prediction flow looks like this: a client sends a POST request with input features to the API server; the server validates the input, formats it for the model, uses the loaded model to make a prediction, formats the output, and sends the prediction back to the client in the response.
Building a REST API is a fundamental step in operationalizing machine learning models. Frameworks like FastAPI streamline this process, allowing you to create robust, well-documented, and performant prediction services. This API layer serves as the bridge between your trained model and the applications that need its predictive power. The next logical step is often to package this API service for easier deployment, which leads us to containerization.