一个机器学习预测服务将进行重构,以便使用APIRouter进行更好的组织管理。随后,将使用TestClient编写测试,以验证其功能和稳定性。假设我们最初的预测服务大致如下所示(很可能在一个单独的main.py文件中):# main_before_refactor.py (示例) from fastapi import FastAPI from pydantic import BaseModel import joblib # 或者你喜欢的模型加载库 # 假设模型已预训练并保存为 'model.joblib' # 假设必要的预处理步骤已在其他地方定义或很简单 # --- 数据模型 (来自第二章) --- class InputFeatures(BaseModel): feature1: float feature2: float # ... 其他特征 class PredictionOutput(BaseModel): prediction: float # 或合适的类型 # --- 应用设置 --- app = FastAPI(title="Simple ML Prediction Service") # --- 模型加载 (来自第三章) --- # 在实际应用中,处理潜在的加载错误 model = joblib.load("model.joblib") # --- 预测接口 (来自第三章) --- @app.post("/predict", response_model=PredictionOutput) async def make_prediction(input_data: InputFeatures): """ 接收输入特征并返回预测结果。 """ # 将 Pydantic 模型转换为模型预期的格式 # 这里做了简化;实际的预处理可能更复杂 features = [[input_data.feature1, input_data.feature2]] prediction_result = model.predict(features) return PredictionOutput(prediction=prediction_result[0]) # --- 根接口 (可选) --- @app.get("/") async def read_root(): return {"message": "Prediction service is running"} # 运行方式 (使用 uvicorn): uvicorn main_before_refactor:app --reload 这适用于简单情况,但随着我们添加更多接口(例如,用于模型信息、批量预测、不同模型版本),这个单一文件将变得难以管理。使用 APIRouter 重构我们来使用APIRouter来组织项目。创建项目结构: 像这样组织你的文件:your_project/ ├── app/ │ ├── __init__.py │ ├── main.py # 主应用设置 │ ├── routers/ │ │ ├── __init__.py │ │ └── predictions.py # 预测相关路由 │ ├── models/ │ │ ├── __init__.py │ │ └── schemas.py # Pydantic 模型 │ └── core/ │ ├── __init__.py │ └── config.py # 配置 (目前可选) ├── tests/ │ ├── __init__.py │ └── test_predictions.py # 预测路由的测试 ├── model.joblib # 你的序列化模型 └── requirements.txt # 项目依赖定义 Pydantic 模型: 将 Pydantic 模型移至app/models/schemas.py:# app/models/schemas.py from pydantic import BaseModel class InputFeatures(BaseModel): feature1: float feature2: float # ... 其他特征 class PredictionOutput(BaseModel): prediction: float # 或合适的类型创建预测路由: 将预测逻辑移至app/routers/predictions.py。注意我们导入APIRouter并使用router而不是app作为装饰器。我们还调整了导入路径。# app/routers/predictions.py from fastapi import APIRouter import joblib from app.models.schemas import InputFeatures, PredictionOutput # 假设模型路径已配置或已知 MODEL_PATH = "model.joblib" model = joblib.load(MODEL_PATH) router = APIRouter( prefix="/predict", # 此路由中的所有路径都将以 /predict 开头 tags=["predictions"] # 在 API 文档中分组接口 ) @router.post("/", response_model=PredictionOutput) # 路径现在相对于前缀 async def make_prediction(input_data: InputFeatures): """ 接收输入特征并返回预测结果。 (逻辑与之前相同) """ features = [[input_data.feature1, input_data.feature2]] prediction_result = model.predict(features) return PredictionOutput(prediction=prediction_result[0]) # 你稍后可以在这里添加其他与预测相关的接口, # 例如,@router.post("/batch", ...) 更新主应用: 修改app/main.py来创建主FastAPI实例并引入路由。# app/main.py from fastapi import FastAPI from app.routers import predictions # 导入路由模块 app = FastAPI(title="Refactored ML Prediction Service") # 引入 predictions.py 中的路由 app.include_router(predictions.router) @app.get("/") async def read_root(): return {"message": "Prediction service is running"} # 运行方式: uvicorn app.main:app --reloaddigraph ProjectStructure { rankdir=LR; node [shape=folder, style=filled, fillcolor="#e9ecef", fontname="Arial"]; edge [fontname="Arial"]; subgraph cluster_app { label = "app/"; style=filled; color="#dee2e6"; main [label="main.py", shape=note, fillcolor="#ffffff"]; subgraph cluster_routers { label = "routers/"; style=filled; color="#ced4da"; predictions_py [label="predictions.py", shape=note, fillcolor="#ffffff"]; } subgraph cluster_models { label = "models/"; style=filled; color="#ced4da"; schemas_py [label="schemas.py", shape=note, fillcolor="#ffffff"]; } main -> predictions_py [label=" 引入"]; predictions_py -> schemas_py [label=" 导入"]; } subgraph cluster_tests { label = "tests/"; test_predictions_py [label="test_predictions.py", shape=note, fillcolor="#ffffff"]; } main_app [label="FastAPI 应用", shape=component, fillcolor="#74c0fc"]; router_obj [label="APIRouter", shape=component, fillcolor="#91a7ff"]; pydantic_models [label="Pydantic 模型", shape=component, fillcolor="#a5d8ff"]; main -> main_app [style=invis]; // Position helper predictions_py -> router_obj [style=invis]; // Position helper schemas_py -> pydantic_models [style=invis]; // Position helper main_app -> router_obj [label=" 引入"]; router_obj -> pydantic_models [label=" 使用"]; test_predictions_py -> main_app [label=" 测试"]; }简化图示,说明了重构后的项目结构以及main.py、预测路由和Pydantic模型之间的交互。现在,我们的预测逻辑整齐地存放于app/routers/predictions.py中,而main.py则更简洁,专注于应用设置和路由。使用 TestClient 测试预测服务结构就绪后,我们来编写测试。我们将使用pytest和FastAPI的TestClient。安装 Pytest: 如果你尚未安装,请安装pytest:pip install pytest创建测试文件: 创建tests/test_predictions.py。编写测试:# tests/test_predictions.py from fastapi.testclient import TestClient from app.main import app # 导入 FastAPI 应用实例 from app.models.schemas import InputFeatures # 如果需要,为类型提示导入 # 使用我们的 FastAPI 应用创建一个 TestClient 实例 client = TestClient(app) def test_read_root(): """测试根接口。""" response = client.get("/") assert response.status_code == 200 assert response.json() == {"message": "Prediction service is running"} def test_make_prediction_success(): """使用有效输入测试预测接口。""" # 定义符合 InputFeatures 模式的有效输入数据 valid_input = {"feature1": 5.1, "feature2": 3.5} # 向 /predict/ 接口发送 POST 请求 response = client.post("/predict/", json=valid_input) # 断言请求成功 (HTTP 200 OK) assert response.status_code == 200 # 断言响应体结构符合 PredictionOutput response_data = response.json() assert "prediction" in response_data # 可选:断言预测结果的类型 assert isinstance(response_data["prediction"], float) # 注意:断言确切的预测值取决于你的模型 # 并且可能需要固定的测试数据集或模拟模型。 # 为简单起见,这里我们关注结构和状态。 def test_make_prediction_invalid_input_type(): """使用不正确输入数据类型测试预测接口。""" # 发送特征为字符串而非浮点数的数据 invalid_input = {"feature1": "wrong_type", "feature2": 3.5} response = client.post("/predict/", json=invalid_input) # FastAPI/Pydantic 自动处理验证错误 # 期望 HTTP 422 (不可处理实体) assert response.status_code == 422 # 检查响应体是否包含验证错误详情 response_data = response.json() assert "detail" in response_data # 如果需要,你可以对错误消息添加更具体的检查 # 例如,assert "feature1" in str(response_data["detail"]) def test_make_prediction_missing_input_feature(): """测试缺少输入数据的预测接口。""" # 发送缺少 'feature2' 的数据 missing_input = {"feature1": 5.1} response = client.post("/predict/", json=missing_input) # 期望 HTTP 422 (不可处理实体) assert response.status_code == 422 response_data = response.json() assert "detail" in response_data # 例如,assert "feature2" in str(response_data["detail"]) # 例如,assert "field required" in str(response_data["detail"])运行测试: 在终端中导航到项目的根目录(your_project/),然后运行pytest:pytestPytest 将发现并运行 tests 目录中的测试。你将看到指示测试通过或失败的输出。这个实践练习展示了如何应用结构化原则,使用APIRouter来组织你的预测服务,以及如何使用TestClient编写有效的测试。这种方法大幅提高了可维护性,并确保你的API行为符合预期,即使在重构或添加新功能之后也是如此。随着你的应用规模扩大,这种分离和测试的价值会越来越大。