趋近智
构建一个简单的 FastAPI 服务,用于加载预训练的机器学习模型并公开预测接口。它集成了 Pydantic 数据验证,并应用了模型集成技术。
首先,你需要一个已保存到文件的训练好的机器学习模型。在这个例子中,我们假设你有一个使用 scikit-learn 在 Iris 数据集上训练并使用 joblib 保存的简单分类器。
如果你手头没有现成的,你可以像这样创建一个基础模型:
# train_save_model.py
import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
# 加载 Iris 数据集
iris = load_iris()
X, y = iris.data, iris.target
# 训练一个简单的逻辑回归模型
model = LogisticRegression(max_iter=200)
model.fit(X, y)
# 保存训练好的模型
model_filename = 'iris_classifier.joblib'
joblib.dump(model, model_filename)
print(f"模型已训练并保存到 {model_filename}")
# 预期输出映射:{0: 'setosa', 1: 'versicolor', 2: 'virginica'}
print(f"目标名称: {list(iris.target_names)}")
运行此脚本(python train_save_model.py),会在你的项目目录中生成 iris_classifier.joblib 文件。这个模型需要四个输入特征:萼片长度、萼片宽度、花瓣长度和花瓣宽度(所有单位均为厘米)。它能预测三种鸢尾花物种之一:setosa、versicolor 或 virginica。
我们来组织一下服务。创建一个包含以下结构的小型项目目录:
fastapi_ml_service/
├── iris_classifier.joblib # 你的已保存模型文件
├── models.py # 用于请求/响应的 Pydantic 模型
└── main.py # 你的 FastAPI 应用程序代码
models.py)我们需要 Pydantic 模型来定义输入数据和预测响应的结构。创建 models.py 文件:
# models.py
from pydantic import BaseModel, Field
from typing import List
class IrisFeatures(BaseModel):
"""鸢尾花预测的输入特征。"""
sepal_length: float = Field(..., gt=0, description="萼片长度,单位厘米")
sepal_width: float = Field(..., gt=0, description="萼片宽度,单位厘米")
petal_length: float = Field(..., gt=0, description="花瓣长度,单位厘米")
petal_width: float = Field(..., gt=0, description="花瓣宽度,单位厘米")
class Config:
# 用于 FastAPI 文档的示例
schema_extra = {
"example": {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
}
class PredictionOut(BaseModel):
"""预测输出的模式。"""
predicted_class_id: int = Field(..., description="预测类别索引(0、1 或 2)")
predicted_class_name: str = Field(..., description="预测类别名称('setosa'、'versicolor'、'virginica')")
probabilities: List[float] = Field(..., description="每个类别的概率列表 [setosa, versicolor, virginica]")
此处:
IrisFeatures 规定了模型所需的四个输入特征。我们使用 Field 来添加验证(必须大于 0)和描述,FastAPI 会使用它们自动生成文档。我们还包含了一个示例载荷。PredictionOut 规定了响应的结构,包括预测类别索引、对应的名称以及每个类别的概率。main.py)现在,我们来编写 main.py 中的核心应用程序逻辑。我们将在应用启动时加载模型并创建一个 /predict 接口。
# main.py
import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from models import IrisFeatures, PredictionOut # 导入 Pydantic 模型
# --- 应用程序设置 ---
app = FastAPI(
title="鸢尾花预测服务",
description="一个使用预训练模型预测鸢尾花物种的简单 API。",
version="0.1.0",
)
# --- 模型加载 ---
# 在应用程序启动时加载模型。
# 对于大型应用程序,请考虑依赖注入。
model_path = "iris_classifier.joblib"
try:
model = joblib.load(model_path)
print(f"模型从 {model_path} 加载成功")
# 根据 Iris 数据集标准顺序定义类别名称
class_names = ['setosa', 'versicolor', 'virginica']
except FileNotFoundError:
print(f"错误:模型文件在 {model_path} 未找到")
model = None # 如果加载失败,将模型设为 None
except Exception as e:
print(f"加载模型时出错:{e}")
model = None
# --- API 接口 ---
@app.get("/")
def read_root():
"""根接口,提供基本的 API 信息。"""
return {"message": "欢迎来到鸢尾花预测 API!"}
@app.post("/predict", response_model=PredictionOut)
async def predict_iris(features: IrisFeatures):
"""
根据输入特征预测鸢尾花物种。
接受萼片长度、萼片宽度、花瓣长度和花瓣宽度,
返回预测类别 ID、类别名称和概率。
"""
if model is None:
raise HTTPException(status_code=503, detail="模型未加载或不可用。")
# 1. 将输入数据转换为模型预期的格式
# (scikit-learn 模型通常需要一个 2D NumPy 数组)
input_data = np.array([[
features.sepal_length,
features.sepal_width,
features.petal_length,
features.petal_width
]])
# 2. 进行预测
try:
prediction_id = model.predict(input_data)
probabilities = model.predict_proba(input_data)
except Exception as e:
# 处理预测过程中可能出现的错误
raise HTTPException(status_code=500, detail=f"预测错误: {e}")
# 3. 格式化响应
predicted_class_index = int(prediction_id[0]) # 获取第一个元素
if predicted_class_index < 0 or predicted_class_index >= len(class_names):
raise HTTPException(status_code=500, detail="预测索引超出范围。")
predicted_class_name = class_names[predicted_class_index]
prediction_probabilities = probabilities[0].tolist() # 获取第一个(也是唯一一个)输入的概率
return PredictionOut(
predicted_class_id=predicted_class_index,
predicted_class_name=predicted_class_name,
probabilities=prediction_probabilities
)
# --- 运行应用程序(可选,用于直接执行) ---
# 通常,你会从命令行使用 Uvicorn 运行它。
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)
main.py 中的步骤:
FastAPI、joblib、numpy、HTTPException 和我们的 Pydantic 模型。joblib.load() 加载 iris_classifier.joblib 文件。其中包含了基本的错误处理。我们还定义了与模型输出对应的 class_names。/predict 接口:
IrisFeatures Pydantic 模型相符的请求体。FastAPI 会自动解析传入的 JSON 并根据此模型进行验证。如果验证失败,FastAPI 会自动返回 422 Unprocessable Entity 错误。response_model=PredictionOut。FastAPI 使用此参数来验证传出的响应、筛选数据(只返回 PredictionOut 中定义的字段)并生成文档。IrisFeatures 对象转换为 2D NumPy 数组,以符合 scikit-learn 的 predict 和 predict_proba 方法的要求。model.predict() 获取类别 ID,并调用 model.predict_proba() 获取类别概率。PredictionOut 模型的一个实例。FastAPI 会自动将这个 Pydantic 对象转换为 JSON 响应。启动服务器: 在 fastapi_ml_service 目录中打开终端并运行:
uvicorn main:app --reload --host 0.0.0.0 --port 8000
main:app:告诉 Uvicorn 在 main.py 文件中找到 app 对象。--reload:当代码更改时自动重启服务器(在开发过程中很有用)。--host 0.0.0.0:使服务器可以从你的网络中的其他机器(或后续的 Docker 容器)访问。--port 8000:指定运行端口。访问文档: 打开你的网页浏览器并访问 http://localhost:8000/docs。你应该会看到 FastAPI 生成的交互式 Swagger UI 文档,其中显示了你的 / 和 /predict 接口,包括 Pydantic 定义的模式。
使用 curl 测试: 打开另一个终端,向 /predict 接口发送一个 POST 请求:
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}'
你应该会收到一个类似于这样的 JSON 响应(概率可能略有不同):
{
"predicted_class_id": 0,
"predicted_class_name": "setosa",
"probabilities": [0.97..., 0.02..., 0.00...]
}
测试验证: 尝试发送无效数据(例如,缺少字段或提供非数字数据):
curl -X 'POST' \
'http://localhost:8000/predict' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"sepal_length": 5.1,
"sepal_width": "not-a-number",
"petal_length": 1.4,
"petal_width": 0.2
}'
FastAPI(得益于 Pydantic)将自动返回一个 422 Unprocessable Entity 错误,其中包含验证失败的详细信息:
{
"detail": [
{
"loc": [
"body",
"sepal_width"
],
"msg": "value is not a valid float",
"type": "type_error.float"
}
]
}
你现在已经成功构建了一个可工作的机器学习预测服务,使用了 FastAPI!它加载模型,通过 Pydantic 为输入和输出规定了清晰的数据约定,并通过一个明确的 API 接口提供预测。这为部署更复杂的模型建立了坚固的根基。在后续章节中,我们将讨论组织大型应用程序、测试、处理异步操作和容器化。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造