Implement optimization techniques on a real model. In this exercise, a standard PyTorch image classification model is converted into a highly optimized TensorRT-accelerated version. Both the baseline and the optimized model are deployed on Triton Inference Server, and their performance differences are benchmarked carefully. The exercise solidifies the workflow from a trained artifact to a production-ready, high-performance inference service.

### Prerequisites

To complete this exercise, you need a system with an NVIDIA GPU, Docker, and the NVIDIA Container Toolkit installed; the toolkit is what allows Docker containers to access the GPU. All model optimization and serving is performed inside a containerized environment to ensure reproducibility.

First, pull the Triton Inference Server container from the NVIDIA NGC registry. This container includes Triton, all necessary CUDA libraries, and a Python environment with common deep learning frameworks.

```bash
docker pull nvcr.io/nvidia/tritonserver:24.05-py3
```

Note: the version tag (e.g., `24.05-py3`) changes over time. You can find the latest available tag on the NVIDIA NGC catalog page for Triton.

### Step 1: Prepare the Baseline Model and Repository

Triton serves models from a specially structured directory called a model repository. Each model has its own subdirectory containing the model artifacts and a `config.pbtxt` file that tells Triton how to serve it.

Start by creating a model repository for a baseline ResNet-18 model from PyTorch.

**Create the directory structure:**

```bash
mkdir -p triton_repo/resnet18_pytorch/1
```

**Save the PyTorch model:** Create a Python script named `export_model.py` to download a pretrained ResNet-18 model and save it in TorchScript format, which Triton can execute directly.

```python
# export_model.py
import torch
import torchvision.models as models

# Load a pretrained ResNet-18 model
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()
model.cuda()  # Move model to GPU

# Create a dummy input tensor
# The shape must match what the model expects: (batch_size, channels, height, width)
dummy_input = torch.randn(1, 3, 224, 224, device="cuda")

# Trace the model with TorchScript
traced_model = torch.jit.trace(model, dummy_input)

# Save the traced model
traced_model.save("triton_repo/resnet18_pytorch/1/model.pt")
print("PyTorch model saved to triton_repo/resnet18_pytorch/1/model.pt")
```

**Run the script:** `python export_model.py`

**Create the configuration file:** Now create the `config.pbtxt` file inside `triton_repo/resnet18_pytorch/`. This file defines the model's metadata.

```
# triton_repo/resnet18_pytorch/config.pbtxt
name: "resnet18_pytorch"
backend: "pytorch"
max_batch_size: 64

input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]
  }
]
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]
```

Your model repository should now have the following structure:

```
triton_repo/
└── resnet18_pytorch/
    ├── 1/
    │   └── model.pt
    └── config.pbtxt
```

### Step 2: Benchmark the Baseline PyTorch Model

With the repository ready, launch the Triton server and point it at your model repository.

```bash
docker run --rm --gpus all -p 8000:8000 -p 8001:8001 -p 8002:8002 \
  -v $(pwd)/triton_repo:/models \
  nvcr.io/nvidia/tritonserver:24.05-py3 tritonserver --model-repository=/models
```

Triton will start, scan the repository, and load the `resnet18_pytorch` model.
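Before benchmarking, it can be worth confirming end to end that the server answers inference requests. The sketch below is a minimal smoke test, not part of the core workflow: it uses the `tritonclient` Python package (`pip install "tritonclient[http]" numpy` on the host) and sends a random tensor rather than a real, normalized image, so the predicted class is meaningless; only the shapes and the round trip matter here.

```python
# sanity_check.py -- minimal smoke test against the running Triton server
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")
assert client.is_server_live(), "Triton is not reachable on localhost:8000"
assert client.is_model_ready("resnet18_pytorch"), "Model failed to load"

# Random FP32 tensor shaped like a single ImageNet image (batch of 1)
data = np.random.rand(1, 3, 224, 224).astype(np.float32)

inputs = [httpclient.InferInput("INPUT__0", list(data.shape), "FP32")]
inputs[0].set_data_from_numpy(data)
outputs = [httpclient.InferRequestedOutput("OUTPUT__0")]

result = client.infer("resnet18_pytorch", inputs=inputs, outputs=outputs)
logits = result.as_numpy("OUTPUT__0")
print("Output shape:", logits.shape, "argmax class:", int(logits.argmax()))
```

If this prints a `(1, 1000)` output shape, the model is being served correctly and you can move on to benchmarking.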
To benchmark the model, use Triton's `perf_analyzer` tool, which ships in the same container. Open a new terminal and run the following command to execute `perf_analyzer` inside a new container that shares the host network with the server. Port 8001 is Triton's gRPC endpoint, so the gRPC protocol is selected explicitly with `-i grpc` (alternatively, drop both flags and let `perf_analyzer` default to HTTP on `localhost:8000`).

```bash
docker run --rm --net=host nvcr.io/nvidia/tritonserver:24.05-py3 \
  perf_analyzer -m resnet18_pytorch --concurrency-range 1:16 -i grpc -u localhost:8001
```

This command tests the model at client-side concurrency levels increasing from 1 to 16. After a few moments, you will see a summary table. Note the throughput (infer/sec) and p99 latency values. For an unoptimized ResNet-18, you might see something like this:

```
Request concurrency: 16
  Client:
    Request count: 2174
    Throughput: 271.5 infer/sec
    p99 latency: 62105 usec
```

This is our performance baseline. Let's see how much we can improve it.

### Step 3: Optimize the Model with TensorRT

Now convert the PyTorch model with Torch-TensorRT, which compiles the network into TensorRT engines embedded in a TorchScript module. The compilation applies numerous optimizations, including layer fusion, reduced-precision execution, and kernel auto-tuning for your specific GPU.

Create a new Python script, `optimize_model.py`, to perform the conversion. You will need to install the `torch-tensorrt` library first: `pip install torch-tensorrt` (the installed version must be compatible with your PyTorch and CUDA versions).

```python
# optimize_model.py
import torch
import torch_tensorrt

# Load the saved TorchScript model
model = torch.jit.load("triton_repo/resnet18_pytorch/1/model.pt")
model.eval().cuda()

# Compile the model with Torch-TensorRT
# FP16 precision is enabled for a significant speedup
trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),   # Typical batch size
            max_shape=(64, 3, 224, 224),  # Matches max_batch_size in config
            dtype=torch.float32,
        )
    ],
    enabled_precisions={torch.float16},  # Enable FP16
)

# Save the compiled module. It is still a TorchScript program (with TensorRT
# engines embedded inside it), so it keeps the .pt extension and is served by
# Triton's PyTorch backend.
torch.jit.save(trt_model, "triton_repo/resnet18_trt/1/model.pt")
print("Optimized model saved to triton_repo/resnet18_trt/1/model.pt")
```

Before running the script, create the directory for the new model: `mkdir -p triton_repo/resnet18_trt/1`. Then execute it: `python optimize_model.py`.

### Step 4: Configure and Deploy the Optimized Model

The optimized model needs its own configuration file. Because Torch-TensorRT produces a TorchScript module with TensorRT engines embedded inside it, the model is still served by the `pytorch` backend; only the model name changes, and dynamic batching is enabled. (A standalone serialized engine built with the TensorRT tools, by contrast, would be saved as `model.plan` and served by Triton's `tensorrt` backend.)

**Create the config.pbtxt for the optimized model:** Create a file at `triton_repo/resnet18_trt/config.pbtxt`.

```
# triton_repo/resnet18_trt/config.pbtxt
name: "resnet18_trt"
backend: "pytorch"
max_batch_size: 64

input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]
  }
]
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]

# Enable dynamic batching for better GPU utilization
dynamic_batching {
  preferred_batch_size: [ 8, 16 ]
  max_queue_delay_microseconds: 100
}
```

Notice the addition of the `dynamic_batching` block. This Triton feature groups individual inference requests into a larger batch to better saturate the GPU, a common strategy for improving throughput.
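As an aside, if you do want a standalone serialized TensorRT engine served by the `tensorrt` backend (the `model.plan` route mentioned above), one common path is to export the model to ONNX and build the engine with TensorRT's `trtexec` tool. The sketch below is illustrative only: the file names and the tensor names `input`/`output` are choices made here, not part of the exercise, and the `trtexec` command must be run in an environment with TensorRT installed (for example, the Triton or TensorRT NGC containers).

```python
# export_onnx.py -- illustrative alternative path: ONNX export for trtexec
import torch
import torchvision.models as models

model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval().cuda()
dummy = torch.randn(1, 3, 224, 224, device="cuda")

# Export with a dynamic batch dimension so the engine can cover batch sizes 1-64
torch.onnx.export(
    model,
    dummy,
    "resnet18.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    opset_version=17,
)

# The engine itself would then be built with trtexec, e.g.:
#   trtexec --onnx=resnet18.onnx --saveEngine=model.plan --fp16 \
#           --minShapes=input:1x3x224x224 \
#           --optShapes=input:8x3x224x224 \
#           --maxShapes=input:64x3x224x224
# The resulting model.plan goes in its own model directory with
# backend: "tensorrt" in config.pbtxt, and the tensor names in that config
# must match the ONNX input/output names ("input" and "output" here).
```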
**Verify the repository structure:** Your repository should now contain both models.

```
triton_repo/
├── resnet18_pytorch/
│   ├── 1/
│   │   └── model.pt
│   └── config.pbtxt
└── resnet18_trt/
    ├── 1/
    │   └── model.pt
    └── config.pbtxt
```

The deployment repository now holds both the baseline PyTorch model and the optimized version, allowing a direct comparison.

**Relaunch Triton:** If your Triton server from Step 2 is still running, stop it (Ctrl+C). Then relaunch it with the same command. It will now detect and load both the `resnet18_pytorch` and `resnet18_trt` models.

### Step 5: Benchmark the Optimized Model and Compare

Finally, run `perf_analyzer` again, this time targeting the new `resnet18_trt` model.

```bash
docker run --rm --net=host nvcr.io/nvidia/tritonserver:24.05-py3 \
  perf_analyzer -m resnet18_trt --concurrency-range 1:16 -i grpc -u localhost:8001
```

You should observe a dramatic improvement in performance. The output might look like this:

```
Request concurrency: 16
  Client:
    Request count: 13010
    Throughput: 1625.3 infer/sec
    p99 latency: 10015 usec
```

The throughput has increased significantly, and the p99 latency has been drastically reduced:

| Model version        | Throughput (infer/sec) | p99 latency (ms) |
|----------------------|------------------------|------------------|
| Baseline (PyTorch)   | 271                    | 62               |
| Optimized (TensorRT) | 1625                   | 10               |

Comparison of throughput and p99 latency between the original PyTorch model and the TensorRT-optimized version.

This hands-on exercise demonstrates a standard and effective workflow for production model deployment. By converting a standard framework model into a specialized inference engine with TensorRT and serving it with an advanced server like Triton, you can achieve order-of-magnitude performance gains. This process is fundamental to building cost-effective, responsive AI services that meet demanding SLOs.
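One final check is worth doing before treating the FP16 model as production-ready: confirm that its predictions still agree with the FP32 baseline. The sketch below is not part of the original exercise; it simply loads the two saved TorchScript files directly (outside Triton) and compares their top-1 predictions on random inputs. For a real sign-off you would evaluate on actual validation images rather than random tensors.

```python
# parity_check.py -- rough agreement check between baseline and FP16 models
import torch
import torch_tensorrt  # noqa: F401  (registers the runtime ops needed to load the compiled module)

baseline = torch.jit.load("triton_repo/resnet18_pytorch/1/model.pt").eval().cuda()
optimized = torch.jit.load("triton_repo/resnet18_trt/1/model.pt").eval().cuda()

trials, matches = 32, 0
with torch.no_grad():
    for _ in range(trials):
        x = torch.randn(1, 3, 224, 224, device="cuda")
        ref = baseline(x)
        out = optimized(x)
        # FP16 introduces small numeric differences; the top-1 class should
        # still match on (nearly) every input
        matches += int(ref.argmax(dim=1).item() == out.argmax(dim=1).item())

print(f"Top-1 agreement on random inputs: {matches}/{trials}")
```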