Once your machine learning model is trained and saved, it needs an interface to receive input data and return predictions. Simply running a Python script on demand isn't practical for most deployment scenarios. Instead, we typically wrap the model loading and prediction logic within a web Application Programming Interface (API). This API acts as a standardized entry point, allowing other services or applications to request predictions over a network, usually via HTTP requests.
For Python-based ML models, two popular microframeworks for building these APIs are Flask and FastAPI. They provide the tools to handle incoming web requests, process data, interact with your model, and send back responses, all with relatively little boilerplate code.
Creating an API standardizes interaction with your model. Instead of relying on custom scripts or specific execution environments, anyone (or any application) that can send an HTTP request can potentially use your model. This offers several advantages: clients can be written in any language, the model's runtime environment stays decoupled from its callers, and the service can be scaled or updated independently of the applications that consume it.
Flask is a lightweight WSGI (Web Server Gateway Interface) web application framework. It's known for its simplicity and minimal core, making it easy to get started. You add functionality through extensions as needed.
A minimal Flask inference application follows a simple pattern:

- Load the serialized model (a .joblib or .pkl file) when the application starts. Be mindful of where this model file resides, especially when running inside a container (it is often copied in during the build or mounted via a volume).
- Define a route (e.g., /predict) that listens for specific HTTP methods (typically POST for sending data).
- Parse the incoming request, convert it into the format the model expects, and call the model's predict() or predict_proba() method.

Example: Flask API for a Scikit-learn Model

Let's assume you have a trained Scikit-learn model saved as model.joblib.
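If you need a model.joblib to experiment with, a short training script like the sketch below will produce one; the Iris dataset and LogisticRegression are stand-ins for illustration, not part of the original example.

```python
# train_model.py -- illustrative: creates a model.joblib for testing the API
import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# A small toy dataset as a stand-in for your real training data
X, y = load_iris(return_X_y=True)

# Fit a simple classifier
model = LogisticRegression(max_iter=1000)
model.fit(X, y)

# Serialize the fitted estimator to disk
joblib.dump(model, "model.joblib")
print("Saved model.joblib")
```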
app.py (Flask):

```python
import joblib
from flask import Flask, request, jsonify
import pandas as pd

# Initialize Flask app
app = Flask(__name__)

# Load the trained model
# Ensure 'model.joblib' is accessible in the container's working directory
# or mounted via a volume.
try:
    model = joblib.load('model.joblib')
    print("Model loaded successfully.")
except FileNotFoundError:
    print("Error: model.joblib not found. Ensure it's in the correct path.")
    model = None  # Handle gracefully or exit
except Exception as e:
    print(f"Error loading model: {e}")
    model = None

# Define the prediction endpoint
@app.route('/predict', methods=['POST'])
def predict():
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500
    try:
        # Get JSON data from the request
        data = request.get_json(force=True)

        # Basic input validation (example: check for 'features' key)
        if 'features' not in data:
            return jsonify({"error": "Missing 'features' key in JSON payload"}), 400

        # Convert input data to DataFrame (adjust columns as per your model)
        # This assumes input is like: {"features": [[val1, val2, ...], [val1, val2, ...]]}
        # Or: {"features": [{"col1": val1, "col2": val2}, {"col1": val1, "col2": val2}]}
        input_df = pd.DataFrame(data['features'])

        # Make prediction
        predictions = model.predict(input_df)

        # Convert predictions to list for JSON serialization
        output = predictions.tolist()

        # Return predictions as JSON
        return jsonify({"predictions": output})

    except Exception as e:
        # Basic error handling
        print(f"Prediction error: {e}")
        return jsonify({"error": "Prediction failed", "details": str(e)}), 500

# Run the app (for local development)
# In production, use a WSGI server like Gunicorn:
# gunicorn --bind 0.0.0.0:5000 app:app
if __name__ == '__main__':
    # Use port 5000 by default, or configure as needed
    app.run(host='0.0.0.0', port=5000, debug=False)  # Set debug=False for production
```
requirements.txt:

```text
flask
pandas
scikit-learn
joblib
# Add gunicorn if using it as the server
gunicorn
```
To run this locally: python app.py. To run it with a production server: gunicorn --bind 0.0.0.0:5000 app:app. Inside a Docker container, the CMD or ENTRYPOINT would typically invoke gunicorn.
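Once the server is up, any HTTP client can request predictions. The snippet below is a minimal sketch using the requests library; the URL, port, and feature values are placeholders and must match your deployment and the features your model was trained on.

```python
# client.py -- example call to the Flask /predict endpoint (values are placeholders)
import requests

# The endpoint expects a JSON body with a "features" key holding rows of inputs
payload = {"features": [[5.1, 3.5, 1.4, 0.2], [6.7, 3.0, 5.2, 2.3]]}

response = requests.post("http://localhost:5000/predict", json=payload)
response.raise_for_status()

print(response.json())  # {"predictions": [...]}
```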
FastAPI is a modern, high-performance web framework built on Starlette (for the web parts) and Pydantic (for data validation). It leverages ASGI (Asynchronous Server Gateway Interface), enabling support for asynchronous code (async/await), which can significantly improve throughput for the I/O-bound tasks often found in web services.

Key advantages of FastAPI for inference APIs include:

- Automatic request and response validation through Pydantic models.
- Automatically generated interactive documentation (served at /docs).
- Native async/await support for handling concurrent requests efficiently.
- High performance relative to traditional WSGI frameworks.
Example: FastAPI API for a Scikit-learn Model

Using the same model.joblib:

app.py (FastAPI):
```python
import joblib
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Union, Dict, Any
import pandas as pd
import uvicorn  # Required for running locally

# Define input data schema using Pydantic
# Option 1: List of lists for features
# class PredictionInput(BaseModel):
#     features: List[List[Union[float, int, str]]]

# Option 2: List of dictionaries for features (more explicit)
class FeatureRow(BaseModel):
    # Define expected feature names and types explicitly
    # Example:
    feature1: float
    feature2: int
    feature3: str
    # Add all features your model expects...

class PredictionInput(BaseModel):
    features: List[FeatureRow]

# Define output data schema
class PredictionOutput(BaseModel):
    predictions: List[Any]  # Adjust 'Any' to the specific type if known (int, float, str)

# Initialize FastAPI app
app = FastAPI(title="ML Model Inference API", version="1.0")

# Load the trained model
try:
    model = joblib.load('model.joblib')
    print("Model loaded successfully.")
except FileNotFoundError:
    print("Error: model.joblib not found. Ensure it's in the correct path.")
    model = None
except Exception as e:
    print(f"Error loading model: {e}")
    model = None

@app.on_event("startup")
async def startup_event():
    # You can also load the model here if preferred
    # Global 'model' variable needs to be accessible
    if model is None:
        print("Model could not be loaded on startup.")
        # Optionally raise an error or handle differently

@app.get("/")
def read_root():
    return {"message": "Welcome to the ML Inference API"}

# Define the prediction endpoint
@app.post("/predict", response_model=PredictionOutput)
async def predict(input_data: PredictionInput):
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        # Pydantic automatically validated the input based on the PredictionInput schema
        # Convert Pydantic model to list of dicts, then to DataFrame
        input_dict_list = [row.dict() for row in input_data.features]
        input_df = pd.DataFrame(input_dict_list)

        # Ensure column order matches model's training if necessary
        # input_df = input_df[expected_column_order]

        # Make prediction
        predictions = model.predict(input_df)

        # Convert predictions to list
        output = predictions.tolist()

        # Return predictions (FastAPI validates against PredictionOutput)
        return PredictionOutput(predictions=output)

    except Exception as e:
        print(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")

# Run the app (for local development)
# In production, use Uvicorn directly or via Gunicorn with Uvicorn workers:
# uvicorn app:app --host 0.0.0.0 --port 8000
# gunicorn -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8000 app:app
if __name__ == '__main__':
    # Use port 8000 by default for FastAPI
    uvicorn.run("app:app", host='0.0.0.0', port=8000, reload=False)  # reload=True for dev
```
requirements.txt:

```text
fastapi
uvicorn[standard]  # Includes performance extras
pandas
scikit-learn
joblib
pydantic
# Add gunicorn if using it as the server
gunicorn
```
To run this locally: uvicorn app:app --reload --port 8000. For production: uvicorn app:app --host 0.0.0.0 --port 8000. Inside Docker, the CMD would typically invoke uvicorn or gunicorn with Uvicorn workers. Accessing http://localhost:8000/docs in your browser will show the automatically generated interactive documentation.
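Because the endpoint expects the structured payload defined by PredictionInput, clients must send a list of objects whose fields match the FeatureRow schema (feature1, feature2, feature3 in this illustrative definition). A minimal sketch using requests:

```python
# client.py -- example call to the FastAPI /predict endpoint (fields/values are placeholders)
import requests

# Each row must match the FeatureRow schema declared in app.py
payload = {
    "features": [
        {"feature1": 0.5, "feature2": 3, "feature3": "A"},
        {"feature1": 1.2, "feature2": 7, "feature3": "B"},
    ]
}

response = requests.post("http://localhost:8000/predict", json=payload)

if response.status_code == 200:
    print(response.json())  # {"predictions": [...]}
else:
    # FastAPI returns a 422 response with details when the payload fails Pydantic validation
    print(response.status_code, response.json())
```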
Regardless of whether you choose Flask or FastAPI, the process of containerizing the API is similar (a complete Dockerfile sketch follows the list below):

1. Write a Dockerfile starting from a suitable base image (e.g., python:3.9-slim).
2. Copy the requirements.txt file and run pip install -r requirements.txt.
3. Copy app.py and any other necessary scripts or modules.
4. Copy model.joblib (or other model file) into the image. Alternatively, plan to mount it using a volume if the model is large or updated frequently.
5. Add an EXPOSE instruction to document the port your application listens on (e.g., EXPOSE 5000 for Flask, EXPOSE 8000 for FastAPI). Note that EXPOSE is documentation only; you still need to publish the port with -p when running the container (docker run -p 8080:8000 ...).
6. Use CMD or ENTRYPOINT to specify the command that starts your web server (e.g., CMD ["gunicorn", "--bind", "0.0.0.0:5000", "app:app"] or CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]). Using 0.0.0.0 as the host ensures the server listens on all available network interfaces inside the container, making it accessible from the outside when the port is published.
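Putting these steps together, a minimal Dockerfile for the FastAPI version might look like the sketch below; the file names and Python version are assumptions carried over from the examples above.

```dockerfile
# Dockerfile -- illustrative sketch for the FastAPI example above
FROM python:3.9-slim

WORKDIR /app

# Install dependencies first to benefit from Docker layer caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code and the serialized model
COPY app.py .
COPY model.joblib .

# Documents the port the app listens on (still publish it with -p at runtime)
EXPOSE 8000

# Start the server, listening on all interfaces inside the container
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
```

Building and running it follows the usual pattern, e.g. docker build -t inference-api . followed by docker run -p 8080:8000 inference-api, after which the API is reachable on port 8080 of the host.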
Building an API with Flask or FastAPI provides a robust and standard way to serve your ML models. The choice between them often comes down to project requirements: Flask's simplicity is great for smaller projects or when performance isn't the primary concern, while FastAPI's performance, built-in validation, and async capabilities make it attractive for more demanding applications. Both integrate well with Docker for creating portable and scalable inference services.