Having optimized your models for inference with techniques like TorchScript, quantization, or pruning, you face a final step: making them available to end users or downstream applications. Manually building and managing a serving infrastructure can be complex, involving API development, request handling, scaling, and monitoring. TorchServe is a dedicated tool developed to simplify this process specifically for PyTorch models. It provides a standardized way to package, deploy, manage, and serve your trained models in production environments.
TorchServe acts as a bridge between your optimized PyTorch models and the applications that need to consume their predictions. It handles the operational aspects of model serving, allowing you to focus on model development and integration.
TorchServe is designed with flexibility and performance in mind. Its main components work together to provide a robust serving solution:
- Model Archiver (torch-model-archiver): This command-line utility is the first step in preparing your model for TorchServe. It bundles all necessary artifacts into a single archive file with a .mar extension. These artifacts typically include the serialized model (a TorchScript .pt file or a standard state_dict), the handler script, and any extra files the handler needs (such as vocabularies or label mappings).
- TorchServe Runtime: This is the core server process. It listens for incoming requests on predefined network ports. It manages the lifecycle of deployed models, including loading them into memory, scaling the number of worker processes per model, and routing inference requests to the appropriate workers.
- Handlers: Handlers are Python scripts or classes that define how TorchServe interacts with your specific model. They encapsulate the logic for:
  - initialize(context): Called once when the model is loaded. Used for loading the model into memory and any one-time setup.
  - preprocess(data): Transforms incoming request data into the format expected by the model's forward method (e.g., decoding images, tokenizing text).
  - inference(model_input): Runs the actual inference by calling the model's prediction function.
  - postprocess(inference_output): Transforms the model's raw output into a user-friendly format (e.g., mapping class indices to labels, formatting JSON).
  TorchServe provides several built-in handlers for common tasks (image classification, object detection, text classification), but you can easily create custom handlers for specialized model inputs/outputs or complex workflows.
- APIs: TorchServe exposes two primary REST APIs: the Inference API (default port 8080), which applications call to obtain predictions, and the Management API (default port 8081), used to register, scale, and monitor models.
gRPC support is also available for lower-latency communication, particularly useful in microservice architectures.
Deploying a model with TorchServe generally follows these steps:
1. Save your trained model (using torch.jit.save() for TorchScript or torch.save() for a state_dict) and gather any necessary auxiliary files.
2. Write a custom handler or select a built-in one (a .py file).
3. Run torch-model-archiver to package the model, handler, and other files into a .mar archive.
4. Set up a model store, the directory where your .mar files will be located or registered, and place the archive there.
5. Start TorchServe and register the model, which loads the .mar file and prepares it for serving. You can specify the initial number of workers.
Figure: High-level overview of the TorchServe deployment workflow, showing the preparation steps and the runtime request handling.
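If your model is not already exported, step 1 can look like the short sketch below. TinyModel is a hypothetical stand-in for your actual network; the traced output is saved to the traced_transformer.pt file used in the archiving example that follows.
import torch
import torch.nn as nn

# Hypothetical stand-in for your trained model; replace with your real module
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 2)

    def forward(self, x):
        return self.linear(x)

model = TinyModel().eval()
example_input = torch.randn(1, 16)

# Trace the model and serialize it to the file later passed to --serialized-file
traced = torch.jit.trace(model, example_input)
torch.jit.save(traced, "traced_transformer.pt")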
The torch-model-archiver
tool is central to packaging your model. Here's a typical command structure:
torch-model-archiver \
--model-name my_transformer_model \
--version 1.0 \
--serialized-file traced_transformer.pt \
--handler transformer_handler.py \
--extra-files "vocab.txt,config.json" \
--export-path /path/to/model-store \
--force
Let's break down the arguments:
- --model-name: A logical name for your model used in API calls (e.g., my_transformer_model).
- --version: A version string for the model (e.g., 1.0). TorchServe can manage multiple versions of the same model.
- --serialized-file: Path to your saved model file (e.g., the output of torch.jit.save). If using a state_dict, you also need --model-file pointing to your model definition Python file.
- --handler: Path to your handler script (.py). This can be one of the built-in handlers (e.g., image_classifier, text_classifier) or your custom script.
- --extra-files: Comma-separated list of additional files needed by your model or handler (e.g., tokenizer configurations, vocabulary files, label mappings). These files will be accessible within the handler's initialize method via the context object.
- --export-path: The directory where the generated .mar file will be saved. This is often set to the TorchServe model store directory.
- --force: Overwrite the .mar file if it already exists.
Executing this command creates /path/to/model-store/my_transformer_model.mar.
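As an optional sanity check, a .mar file is a standard zip archive, so you can list what was bundled; the path below assumes the archiver command above.
import zipfile

# Path produced by the torch-model-archiver command above
mar_path = "/path/to/model-store/my_transformer_model.mar"

# A .mar file is a zip archive; list the bundled artifacts
with zipfile.ZipFile(mar_path) as archive:
    for name in archive.namelist():
        print(name)  # e.g. traced_transformer.pt, transformer_handler.py, vocab.txt, MAR-INF/MANIFEST.json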
While TorchServe's built-in handlers cover many common use cases, you'll often need custom logic. A custom handler is a Python script containing a class (typically inheriting from BaseHandler
) that implements some or all of the following methods:
# custom_handler.py
import torch
import json
from ts.torch_handler.base_handler import BaseHandler
import logging
import os

logger = logging.getLogger(__name__)


class MyCustomHandler(BaseHandler):
    """
    Custom handler for processing specific input/output formats.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False
        self.model = None
        self.mapping = None

    def initialize(self, context):
        """
        Load model and extra files. Called once when the model is loaded.
        """
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")  # Directory containing unpacked MAR contents

        # Determine device
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        logger.info(f"Handler initialized on device: {self.device}")

        # Load the model (example assuming TorchScript)
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")
        self.model = torch.jit.load(model_pt_path, map_location=self.device)
        self.model.eval()
        logger.info(f"Model {self.manifest['model']['modelName']} loaded successfully.")

        # Load extra files (e.g., label mapping)
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")  # Assuming passed via --extra-files
        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path) as f:
                self.mapping = json.load(f)
            logger.info("Label mapping loaded successfully.")
        else:
            logger.warning("Mapping file not found.")

        self.initialized = True

    def preprocess(self, data):
        """
        Transform raw input data into model input tensor(s).
        'data' is a list of dictionaries, each containing raw request data.
        """
        # Example: Assuming input is JSON like {'text': 'some input string'}
        # This depends heavily on your specific application
        processed_inputs = []
        for row in data:
            request_body = row.get("data") or row.get("body")  # Handle different input sources
            if isinstance(request_body, (bytes, bytearray)):
                request_body = request_body.decode('utf-8')

            # Add your specific preprocessing logic here (tokenization, tensor creation, etc.)
            # For simplicity, let's assume request_body is the direct input needed.
            # In reality, you'd tokenize text, decode/resize images, etc.
            # This part MUST return data in the format expected by self.model()
            logger.info(f"Received input: {request_body}")

            # Dummy preprocessing: just pass it through (replace with real logic)
            processed_inputs.append(request_body)

        # Example: Convert processed inputs to a batch tensor if needed by the model
        # input_tensor = self.tokenizer(processed_inputs, return_tensors="pt", padding=True, truncation=True).to(self.device)
        # return input_tensor
        return processed_inputs  # Returning a list for this dummy example

    def inference(self, model_input):
        """
        Run inference using the model.
        'model_input' is the output of preprocess().
        """
        # Example: Run the loaded TorchScript model
        # Output depends on your model structure
        # Ensure model_input is on the correct device
        # output = self.model(model_input.to(self.device))

        # Dummy inference: just echo input (replace with real model call)
        logger.info(f"Running dummy inference on: {model_input}")
        with torch.no_grad():  # Essential for inference
            # Replace this with: output = self.model(model_input)
            output = [f"Processed: {item}" for item in model_input]
        return output

    def postprocess(self, inference_output):
        """
        Transform model output into user-friendly format.
        'inference_output' is the output of inference().
        Returns a list of prediction results, one for each input request.
        """
        # Example: Convert raw model output (e.g., logits) to labels/scores
        # predictions = torch.softmax(inference_output, dim=1).argmax(dim=1).cpu().tolist()
        # result = [self.mapping[str(pred)] if self.mapping else str(pred) for pred in predictions]

        # Dummy postprocessing: just return the inference output
        logger.info(f"Postprocessing: {inference_output}")
        return inference_output  # Should be a list

# Note: The handler file name must match the --handler argument value (without .py).
# If the handler class name is different from the capitalized file name,
# specify it in the MANIFEST.json or via the archiver.
This example shows the basic structure. You would replace the dummy logic with your actual preprocessing (e.g., image transformations, text tokenization), model invocation, and postprocessing (e.g., mapping output indices to class names, formatting JSON responses). The context
object provides access to system properties (like GPU availability), the model directory, and the manifest details.
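To make this more concrete, the hedged sketch below shows how a text model's handler might load a Hugging Face tokenizer shipped via --extra-files and implement a realistic preprocess step. It assumes the transformers library is installed in the serving environment and reuses the MyCustomHandler class defined above; adapt the input parsing and the model call to your own export.
import json
import torch
from transformers import AutoTokenizer  # assumes transformers is installed in the serving image
from custom_handler import MyCustomHandler  # the handler defined above


class TextClassifierHandler(MyCustomHandler):
    def initialize(self, context):
        super().initialize(context)  # loads the TorchScript model as shown above
        model_dir = context.system_properties.get("model_dir")
        # Tokenizer files shipped via --extra-files are unpacked into model_dir
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

    def preprocess(self, data):
        texts = []
        for row in data:
            body = row.get("data") or row.get("body")
            if isinstance(body, (bytes, bytearray)):
                body = body.decode("utf-8")
            if isinstance(body, str):
                body = json.loads(body)  # expects a JSON body like {"text": "..."}
            texts.append(body["text"])
        # Batch-tokenize and move the tensors to the handler's device
        return self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

    def inference(self, model_input):
        with torch.no_grad():
            # How the traced model is called depends on how it was exported;
            # a common pattern is positional input_ids and attention_mask.
            return self.model(model_input["input_ids"], model_input["attention_mask"])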
Once you have your .mar
file in the model store, you start TorchServe:
# Create model store directory if it doesn't exist
mkdir -p /path/to/model-store
# Start TorchServe, pointing to the model store
torchserve --start \
--model-store /path/to/model-store \
--models my_model=/path/to/model-store/my_transformer_model.mar \
--ts-config /path/to/config.properties
- --start: Starts the TorchServe server in the background. Use torchserve --stop to stop it.
- --model-store: Specifies the directory containing .mar files.
- --models: (Optional) Pre-load and register specific models at startup. The format is model_name=model_archive.mar, or simply model_name.mar if the file is in the model store. You can specify multiple models.
- --ts-config: (Optional) Path to a configuration file (.properties) for customizing ports, JVM arguments, logging, etc.
After starting, you interact with the APIs, typically using curl or a client library like requests in Python.
Registering a Model (Management API):
curl -X POST "http://localhost:8081/models?url=my_transformer_model.mar&model_name=transformer&initial_workers=1&synchronous=true"
This command tells TorchServe to load my_transformer_model.mar
(assuming it's in the model store), register it with the logical name transformer
, start 1 worker process for it, and wait for the registration to complete.
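If you prefer scripting the Management API over curl, the same registration can be issued with the requests library; the parameters simply mirror the curl call above.
import requests

# Register the archive via the Management API (equivalent to the curl command above)
params = {
    "url": "my_transformer_model.mar",
    "model_name": "transformer",
    "initial_workers": 1,
    "synchronous": "true",
}
response = requests.post("http://localhost:8081/models", params=params, timeout=300)
print(response.status_code, response.text)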
Sending an Inference Request (Inference API):
The exact format depends on your handler's preprocess
method. If it expects raw image bytes:
curl -X POST http://localhost:8080/predictions/transformer -T image.jpg
If it expects JSON:
curl -X POST http://localhost:8080/predictions/transformer -H "Content-Type: application/json" -d '{"text": "This is an example sentence."}'
The response will contain the output from your handler's postprocess
method.
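The same inference request can be sent from Python using requests; this sketch assumes the server is running locally and the model is registered under the name transformer as shown above.
import requests

# Send a JSON inference request to the Inference API
response = requests.post(
    "http://localhost:8080/predictions/transformer",
    json={"text": "This is an example sentence."},
    timeout=30,
)
response.raise_for_status()
print(response.text)  # whatever your handler's postprocess() returned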
Scaling Workers (Management API):
If you need more throughput for a specific model, you can increase the number of worker processes:
curl -X PUT "http://localhost:8081/models/transformer?min_worker=4"
This scales the number of workers for the transformer
model to 4 (TorchServe handles load balancing across them).
Checking Status and Metrics:
# List all registered models
curl http://localhost:8081/models
# Describe the 'transformer' model (status, version, workers)
curl http://localhost:8081/models/transformer
# Fetch server and model metrics (Metrics API, default port 8082)
curl http://localhost:8082/metrics
TorchServe is designed for production loads. Key features contributing to performance include multi-worker serving, with requests load-balanced across the worker processes of each model, and batching: handlers can accept batches of requests in preprocess and inference to process multiple requests simultaneously, improving GPU utilization. TorchServe also has built-in dynamic batching capabilities configurable via the Management API or config file.
By leveraging TorchServe, you significantly reduce the engineering effort required to deploy and manage PyTorch models reliably and efficiently. It provides a standardized, feature-rich platform that integrates well with the PyTorch ecosystem and common MLOps practices, making it an essential tool for putting your advanced models into production.