趋近智
一个机器学习预测服务将进行重构,以便使用APIRouter进行更好的组织管理。随后,将使用TestClient编写测试,以验证其功能和稳定性。
假设我们最初的预测服务大致如下所示(很可能在一个单独的main.py文件中):
# main_before_refactor.py (示例)
from fastapi import FastAPI
from pydantic import BaseModel
import joblib # 或者你喜欢的模型加载库
# 假设模型已预训练并保存为 'model.joblib'
# 假设必要的预处理步骤已在其他地方定义或很简单
# --- 数据模型 (来自第二章) ---
class InputFeatures(BaseModel):
feature1: float
feature2: float
# ... 其他特征
class PredictionOutput(BaseModel):
prediction: float # 或合适的类型
# --- 应用设置 ---
app = FastAPI(title="Simple ML Prediction Service")
# --- 模型加载 (来自第三章) ---
# 在实际应用中,处理潜在的加载错误
model = joblib.load("model.joblib")
# --- 预测接口 (来自第三章) ---
@app.post("/predict", response_model=PredictionOutput)
async def make_prediction(input_data: InputFeatures):
"""
接收输入特征并返回预测结果。
"""
# 将 Pydantic 模型转换为模型预期的格式
# 这里做了简化;实际的预处理可能更复杂
features = [[input_data.feature1, input_data.feature2]]
prediction_result = model.predict(features)
return PredictionOutput(prediction=prediction_result[0])
# --- 根接口 (可选) ---
@app.get("/")
async def read_root():
return {"message": "Prediction service is running"}
# 运行方式 (使用 uvicorn): uvicorn main_before_refactor:app --reload
这适用于简单情况,但随着我们添加更多接口(例如,用于模型信息、批量预测、不同模型版本),这个单一文件将变得难以管理。
我们来使用APIRouter来组织项目。
创建项目结构: 像这样组织你的文件:
your_project/
├── app/
│ ├── __init__.py
│ ├── main.py # 主应用设置
│ ├── routers/
│ │ ├── __init__.py
│ │ └── predictions.py # 预测相关路由
│ ├── models/
│ │ ├── __init__.py
│ │ └── schemas.py # Pydantic 模型
│ └── core/
│ ├── __init__.py
│ └── config.py # 配置 (目前可选)
├── tests/
│ ├── __init__.py
│ └── test_predictions.py # 预测路由的测试
├── model.joblib # 你的序列化模型
└── requirements.txt # 项目依赖
定义 Pydantic 模型: 将 Pydantic 模型移至app/models/schemas.py:
# app/models/schemas.py
from pydantic import BaseModel
class InputFeatures(BaseModel):
feature1: float
feature2: float
# ... 其他特征
class PredictionOutput(BaseModel):
prediction: float # 或合适的类型
创建预测路由: 将预测逻辑移至app/routers/predictions.py。注意我们导入APIRouter并使用router而不是app作为装饰器。我们还调整了导入路径。
# app/routers/predictions.py
from fastapi import APIRouter
import joblib
from app.models.schemas import InputFeatures, PredictionOutput
# 假设模型路径已配置或已知
MODEL_PATH = "model.joblib"
model = joblib.load(MODEL_PATH)
router = APIRouter(
prefix="/predict", # 此路由中的所有路径都将以 /predict 开头
tags=["predictions"] # 在 API 文档中分组接口
)
@router.post("/", response_model=PredictionOutput) # 路径现在相对于前缀
async def make_prediction(input_data: InputFeatures):
"""
接收输入特征并返回预测结果。
(逻辑与之前相同)
"""
features = [[input_data.feature1, input_data.feature2]]
prediction_result = model.predict(features)
return PredictionOutput(prediction=prediction_result[0])
# 你稍后可以在这里添加其他与预测相关的接口,
# 例如,@router.post("/batch", ...)
更新主应用: 修改app/main.py来创建主FastAPI实例并引入路由。
# app/main.py
from fastapi import FastAPI
from app.routers import predictions # 导入路由模块
app = FastAPI(title="Refactored ML Prediction Service")
# 引入 predictions.py 中的路由
app.include_router(predictions.router)
@app.get("/")
async def read_root():
return {"message": "Prediction service is running"}
# 运行方式: uvicorn app.main:app --reload
简化图示,说明了重构后的项目结构以及
main.py、预测路由和Pydantic模型之间的交互。
现在,我们的预测逻辑整齐地存放于app/routers/predictions.py中,而main.py则更简洁,专注于应用设置和路由。
结构就绪后,我们来编写测试。我们将使用pytest和FastAPI的TestClient。
安装 Pytest: 如果你尚未安装,请安装pytest:
pip install pytest
创建测试文件: 创建tests/test_predictions.py。
编写测试:
# tests/test_predictions.py
from fastapi.testclient import TestClient
from app.main import app # 导入 FastAPI 应用实例
from app.models.schemas import InputFeatures # 如果需要,为类型提示导入
# 使用我们的 FastAPI 应用创建一个 TestClient 实例
client = TestClient(app)
def test_read_root():
"""测试根接口。"""
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Prediction service is running"}
def test_make_prediction_success():
"""使用有效输入测试预测接口。"""
# 定义符合 InputFeatures 模式的有效输入数据
valid_input = {"feature1": 5.1, "feature2": 3.5}
# 向 /predict/ 接口发送 POST 请求
response = client.post("/predict/", json=valid_input)
# 断言请求成功 (HTTP 200 OK)
assert response.status_code == 200
# 断言响应体结构符合 PredictionOutput
response_data = response.json()
assert "prediction" in response_data
# 可选:断言预测结果的类型
assert isinstance(response_data["prediction"], float)
# 注意:断言确切的预测值取决于你的模型
# 并且可能需要固定的测试数据集或模拟模型。
# 为简单起见,这里我们关注结构和状态。
def test_make_prediction_invalid_input_type():
"""使用不正确输入数据类型测试预测接口。"""
# 发送特征为字符串而非浮点数的数据
invalid_input = {"feature1": "wrong_type", "feature2": 3.5}
response = client.post("/predict/", json=invalid_input)
# FastAPI/Pydantic 自动处理验证错误
# 期望 HTTP 422 (不可处理实体)
assert response.status_code == 422
# 检查响应体是否包含验证错误详情
response_data = response.json()
assert "detail" in response_data
# 如果需要,你可以对错误消息添加更具体的检查
# 例如,assert "feature1" in str(response_data["detail"])
def test_make_prediction_missing_input_feature():
"""测试缺少输入数据的预测接口。"""
# 发送缺少 'feature2' 的数据
missing_input = {"feature1": 5.1}
response = client.post("/predict/", json=missing_input)
# 期望 HTTP 422 (不可处理实体)
assert response.status_code == 422
response_data = response.json()
assert "detail" in response_data
# 例如,assert "feature2" in str(response_data["detail"])
# 例如,assert "field required" in str(response_data["detail"])
运行测试: 在终端中导航到项目的根目录(your_project/),然后运行pytest:
pytest
Pytest 将发现并运行 tests 目录中的测试。你将看到指示测试通过或失败的输出。
这个实践练习展示了如何应用结构化原则,使用APIRouter来组织你的预测服务,以及如何使用TestClient编写有效的测试。这种方法大幅提高了可维护性,并确保你的API行为符合预期,即使在重构或添加新功能之后也是如此。随着你的应用规模扩大,这种分离和测试的价值会越来越大。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造