趋近智
测试执行机器学习预测的端点,需要除标准API逻辑验证之外的特定策略。TestClient是通用端点测试的常用工具。然而,预测端点涉及到外部依赖,如加载的模型和可能复杂的推理逻辑,这使得直接测试有时不切实际或速度慢。因此,提供有效进行ML预测端点单元测试和集成测试的方法。
测试预测端点的主要目的,通常不是重新验证模型的准确性(这通常在ML训练和评估阶段处理),而是确保模型周围的API能正常工作。我们希望验证以下几点:
与其他端点一样,FastAPI的TestClient是测试预测端点的主要工具。您模拟HTTP请求(例如,带有预测输入的POST请求),并断言预期的HTTP状态码和响应体。
# 示例结构(假设使用pytest和TestClient夹具'client')
from fastapi import status
from pydantic import BaseModel
# 假设这些在其他地方定义
# from my_app.schemas import PredictionInput, PredictionOutput
# from my_app.main import app
class PredictionInput(BaseModel):
feature1: float
feature2: str
class PredictionOutput(BaseModel):
prediction: float
probability: float | None = None
# 在您的测试文件(例如,test_predictions.py)中
def test_predict_endpoint_success(client):
# 根据PredictionInput模式定义有效输入数据
input_data = {"feature1": 10.5, "feature2": "categoryA"}
response = client.post("/predict", json=input_data)
assert response.status_code == status.HTTP_200_OK
# 假设端点返回与PredictionOutput匹配的数据
response_data = response.json()
assert "prediction" in response_data
# 基于预期输出结构的进一步断言...
# 例如,检查类型:
assert isinstance(response_data["prediction"], float)
def test_predict_endpoint_invalid_input(client):
# 缺少必需字段的输入数据
invalid_input_data = {"feature1": 10.5}
response = client.post("/predict", json=invalid_input_data)
# FastAPI/Pydantic自动处理验证错误
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
主要的挑战来自ML模型本身。在单元测试期间加载和运行真实模型可能很慢、消耗资源,并引入外部因素(如文件路径),从而使测试复杂化。我们需要方法将API逻辑与实际模型推理分离,以实现更快、更可靠的单元测试。
模拟是指在测试运行时,用一个替代品(“模拟对象”)替换实际的预测函数或模型对象。这个模拟对象可以被编程为返回预定义输出,这样您就可以测试API的行为,而无需执行真实的模型推理。
FastAPI的依赖注入系统提供了一种巧妙的方式,可以使用app.dependency_overrides来实现这一点。您可以覆盖提供模型或预测函数本身的依赖。
让我们假设您的预测端点使用依赖来获取预测结果:
# 在您的应用中(例如,my_app/predictor.py)
def get_model():
# 加载实际ML模型的逻辑
# ...
# return loaded_model
pass # 占位符
def perform_prediction(data: PredictionInput, model = Depends(get_model)):
# 使用加载模型的实际预测逻辑
# prediction_result = model.predict(processed_data)
# return {"prediction": prediction_result, "probability": 0.95} # 示例
# 为了演示,返回一个固定结构
return {"prediction": data.feature1 * 2, "probability": 0.9}
# 在您的FastAPI主文件(例如,my_app/main.py)中
from fastapi import FastAPI, Depends
from .schemas import PredictionInput, PredictionOutput
from .predictor import perform_prediction
app = FastAPI()
@app.post("/predict", response_model=PredictionOutput)
async def predict(data: PredictionInput, result: dict = Depends(perform_prediction)):
# 依赖注入处理对perform_prediction的调用
# 注意:为了简化,这里将perform_prediction更改为直接返回一个字典
# 更优的方法可能涉及基于类的依赖
return result
现在,在您的测试文件中,您可以覆盖perform_prediction依赖:
# 在您的测试文件(例如,test_predictions.py)中
from fastapi.testclient import TestClient
from my_app.main import app # 导入您的FastAPI应用实例
from my_app.predictor import perform_prediction # 导入原始依赖
from my_app.schemas import PredictionInput, PredictionOutput
# 创建一个用于测试的模拟预测函数
async def mock_perform_prediction(data: PredictionInput):
# 模拟预测逻辑 - 返回一个固定的、已知的结果
# 如果需要,可以在这里添加关于'data'的断言
print(f"模拟预测调用,参数为: {data}")
return {"prediction": 123.45, "probability": 0.88}
# 使用FastAPI的依赖覆盖功能
app.dependency_overrides[perform_prediction] = mock_perform_prediction
client = TestClient(app) # 在覆盖之后创建TestClient
def test_predict_with_mock(client):
input_data = {"feature1": 10.5, "feature2": "categoryA"}
response = client.post("/predict", json=input_data)
assert response.status_code == 200
response_data = response.json()
# 对模拟函数中定义的输出进行断言
assert response_data["prediction"] == 123.45
assert response_data["probability"] == 0.88
# 请记住,如果其他测试需要原始依赖,则清除此覆盖
# 这通常通过pytest夹具处理得更好(见下文)
def teardown_function(): # 使用pytest teardown的示例
app.dependency_overrides.clear()
使用pytest夹具实现更清晰的覆盖:
pytest夹具提供了一种更清晰的方式来管理设置和清理,包括依赖覆盖:
# 在conftest.py或您的测试文件中
import pytest
from fastapi.testclient import TestClient
from my_app.main import app
from my_app.predictor import perform_prediction
from my_app.schemas import PredictionInput
@pytest.fixture(scope="function") # 范围可以调整
def client_with_mock_predictor():
# 在夹具内部定义模拟函数
async def mock_perform_prediction_fixture(data: PredictionInput):
return {"prediction": 123.45, "probability": 0.88}
# 应用覆盖
app.dependency_overrides[perform_prediction] = mock_perform_prediction_fixture
# 返回TestClient
yield TestClient(app)
# 清理:在使用此夹具的测试完成后清除覆盖
app.dependency_overrides.clear()
# 在您的测试文件(例如,test_predictions.py)中
def test_predict_with_fixture(client_with_mock_predictor): # 使用夹具
input_data = {"feature1": 10.5, "feature2": "categoryA"}
# 使用夹具提供的客户端
response = client_with_mock_predictor.post("/predict", json=input_data)
assert response.status_code == 200
response_data = response.json()
assert response_data["prediction"] == 123.45
assert response_data["probability"] == 0.88
模拟的优点:
模拟的缺点:
模拟的另一种方法是在测试期间使用一个非常简单、快速的“哑”模型。这个哑模型应该模仿您真实模型的接口(例如,predict方法),但执行一个微不足道的操作。
您可以使用依赖注入来实现这一点,类似于模拟,但不是模拟预测函数,而是覆盖加载模型的依赖(在我们之前的示例中是get_model)。
# 在您的测试设置中(例如,conftest.py或测试文件)
class DummyModel:
"""用于测试的简单模型替代品。"""
def predict(self, input_data):
# 简单逻辑,例如,返回固定值或根据输入长度
print("DummyModel predict 被调用")
return [sum(input_data.values())] # 示例哑预测
def predict_proba(self, input_data):
# 哑概率
return [[0.1, 0.9]] # 示例
def get_dummy_model():
print("提供哑模型")
return DummyModel()
# 在您的pytest夹具或测试设置中
@pytest.fixture(scope="function")
def client_with_dummy_model():
# 假设get_model是用于加载模型的依赖
# from my_app.predictor import get_model # 导入原始依赖
# app.dependency_overrides[get_model] = get_dummy_model # 覆盖它
# 如果perform_prediction直接使用模型:
# 您可能需要以不同方式组织依赖关系,
# 例如,让perform_prediction通过Depends接受模型
# async def perform_prediction(data: Input, model = Depends(get_model)): ...
# 为了演示,我们假设调整perform_prediction以使用来自get_model的模型
# 这部分需要根据您的实际应用结构进行调整
# app.dependency_overrides[get_model] = get_dummy_model # 示例覆盖
yield TestClient(app) # 假设覆盖已正确应用
app.dependency_overrides.clear() # 清理
# 在您的测试文件中
# def test_predict_with_dummy_model(client_with_dummy_model):
# # ... 使用客户端执行测试 ...
# # 断言将取决于哑模型的行为
注意:具体的实现方式很大程度上取决于您的模型如何在端点逻辑中加载和访问。请确保您的依赖关系结构允许覆盖模型提供者。
哑模型的优点:
哑模型的缺点:
无论采用何种策略(模拟或哑模型),请将测试设计为覆盖各种场景:
response_model匹配。有效地测试预测端点,需要将API逻辑与ML模型本身的复杂性在单元测试期间进行分离。将TestClient与依赖覆盖相结合以注入模拟对象或哑模型,提供了一种方法来确保您的API行为正确,按预期处理输入和输出,并与预测机制恰当集成,从而带来更可靠的ML部署。请记住,除了这些单元/集成测试之外,还要辅以针对模型在ML工作流中性能和准确性的独立专用测试。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造