Once your model is packaged as a SavedModel, the next step is making it accessible for inference. TensorFlow Serving is a dedicated, high-performance system specifically designed for deploying machine learning models in production environments. It takes your SavedModels and exposes them through network endpoints, allowing client applications to easily send inference requests and receive predictions.
TensorFlow Serving primarily offers two protocols for communication: REST (Representational State Transfer) and gRPC (gRPC Remote Procedure Calls). Understanding how to interact with these protocols is fundamental to integrating your deployed models into larger applications.
Before interacting with the APIs, you need TensorFlow Serving running with your model. While detailed setup is beyond this section's scope, a common approach involves using Docker. Assuming you have Docker installed and your SavedModel is located at /path/to/your/model/1 (where 1 represents the version number), you might start the server like this:
docker run -p 8501:8501 --mount type=bind,source=/path/to/your/model/,target=/models/my_model \
-e MODEL_NAME=my_model -t tensorflow/serving
This command:
- Publishes port 8501, the default port for the REST API.
- Mounts your local model directory into the container at /models/my_model.
- Sets the environment variable MODEL_NAME to my_model, which is how you'll refer to the model in API calls.

Note: For gRPC access, you would typically map port 8500 as well (-p 8500:8500). Ensure your model directory structure follows the TF Serving convention (model_name/version_number/ containing saved_model.pb along with the assets and variables directories).
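Once the container is running, it can be helpful to confirm that the model loaded successfully before sending predictions. TF Serving exposes a model status endpoint for this purpose. The following is a minimal sketch using Python's requests library, assuming the server is reachable at localhost:8501 and the model name is my_model, as in the command above:

import requests

# Query TF Serving's model status endpoint to verify the model is loaded.
# Assumes the default REST port (8501) and MODEL_NAME=my_model from the
# Docker command above; adjust the host, port, and name for your setup.
status_url = "http://localhost:8501/v1/models/my_model"

response = requests.get(status_url)
response.raise_for_status()

# The response lists the served versions and their state (e.g., "AVAILABLE").
print(response.json())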
The RESTful API is often the simpler way to start interacting with TF Serving, leveraging standard HTTP methods and JSON for data exchange. It's widely compatible with various programming languages and tools.
TensorFlow Serving exposes models via predictable URLs. The most common endpoint for prediction is:
POST http://<host>:<port>/v1/models/<model_name>:predict
Or, if you want to target a specific version:
POST http://<host>:<port>/v1/models/<model_name>/versions/<version_number>:predict
- <host>: The hostname or IP address where TF Serving is running (e.g., localhost).
- <port>: The port mapped for the REST API (default is 8501).
- <model_name>: The name assigned to your model (e.g., my_model from the Docker command).
- <version_number>: Optional specific version to query. If omitted, TF Serving typically uses the latest version available.

The body of the POST request must be a JSON object. The structure depends on your model's signature, but a common format uses the instances key. The value associated with instances is a list, where each element represents a single input instance (or a batch of instances, depending on how your model expects input).
Suppose your model expects a single input tensor named input_features with a shape of (batch_size, 784) (like a flattened MNIST image):
{
  "instances": [
    [0.0, 0.1, ..., 0.9],
    [0.5, 0.2, ..., 0.0],
    ...
  ]
}
If your model has multiple named inputs defined in its signature (e.g., image_input and metadata_input), you use a different format, providing a dictionary for each instance:
{
  "instances": [
    {
      "image_input": [[0.0, ...], [0.1, ...]],
      "metadata_input": [1.0, 2.5]
    },
    {
      "image_input": [[0.5, ...], [0.3, ...]],
      "metadata_input": [0.5, 1.2]
    }
  ]
}
You can inspect your SavedModel's signature using the saved_model_cli tool to determine the expected input names and formats:
saved_model_cli show --dir /path/to/your/model/1 --tag_set serve --signature_def serving_default
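If you prefer to stay in Python, you can also load the SavedModel and inspect its serving signature programmatically. This is a small sketch, assuming the model directory from the Docker example:

import tensorflow as tf

# Load the SavedModel and look up its default serving signature.
loaded = tf.saved_model.load("/path/to/your/model/1")
serving_fn = loaded.signatures["serving_default"]

# Expected input names, shapes, and dtypes.
print(serving_fn.structured_input_signature)
# Output tensor names, shapes, and dtypes.
print(serving_fn.structured_outputs)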
Here's how you might send a request using Python's requests library:
import requests
import json
import numpy as np
# Assume TF Serving is running on localhost:8501 with model 'my_model'
url = "http://localhost:8501/v1/models/my_model:predict"
# Example: Create two dummy input instances (e.g., flattened 28x28 images)
# Replace with your actual data preprocessing
input_data = np.random.rand(2, 784).tolist()
# Construct the request payload
request_payload = json.dumps({"instances": input_data})
# Set headers
headers = {"content-type": "application/json"}
try:
    # Send the POST request
    response = requests.post(url, data=request_payload, headers=headers)
    response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)

    # Parse the JSON response
    predictions = response.json()['predictions']
    print("Predictions received:")
    print(predictions)

except requests.exceptions.RequestException as e:
    print(f"Error making request: {e}")
except json.JSONDecodeError:
    print(f"Error decoding JSON response: {response.text}")
except KeyError:
    print(f"Key 'predictions' not found in response: {response.json()}")
The REST API's main advantages are its simplicity, easy testing with standard tools (such as curl or a web browser), and wide compatibility.

gRPC is a modern, high-performance RPC framework developed by Google. It uses HTTP/2 for transport and Protocol Buffers (Protobufs) as its interface definition language and message interchange format. gRPC generally offers lower latency and higher throughput compared to REST, making it suitable for performance-sensitive applications.
To use gRPC, you typically need:
- The service definitions (.proto files).
- Client stubs generated from those .proto files.

Fortunately, the tensorflow-serving-api Python package provides the necessary libraries and pre-generated stubs.
pip install grpcio tensorflow-serving-api
Interacting via gRPC involves creating Protobuf request objects, establishing a channel to the server, creating a client stub, and making the remote procedure call.
import grpc
import numpy as np
import tensorflow as tf
# Import generated gRPC classes
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
# Assume TF Serving is running on localhost:8500 with model 'my_model'
server_address = 'localhost:8500'
model_name = 'my_model'
# Specify the signature name if not 'serving_default'
# signature_name = 'serving_default'
try:
    # Create a gRPC channel
    channel = grpc.insecure_channel(server_address)

    # Create a stub (client)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    # Example: Create two dummy input instances
    # Replace with your actual data preprocessing
    input_data = np.random.rand(2, 784).astype(np.float32)

    # Create a PredictRequest Protobuf object
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    # request.model_spec.signature_name = signature_name  # Uncomment if needed

    # Map the input data to the correct input tensor name in the signature
    # Use 'tf.make_tensor_proto' to convert NumPy array to TensorProto
    # Ensure the dtype matches the model's expected input type
    request.inputs['input_features'].CopyFrom(
        tf.make_tensor_proto(input_data, shape=input_data.shape, dtype=tf.float32)
    )

    # Set a timeout for the request (e.g., 10 seconds)
    timeout_seconds = 10.0
    result_future = stub.Predict.future(request, timeout_seconds)
    result = result_future.result()  # Wait for the response

    # Parse the response (which is a PredictResponse Protobuf object)
    # Access the output tensor by its name (e.g., 'output_scores')
    # Use 'tf.make_ndarray' to convert TensorProto back to NumPy array
    # Replace 'output_scores' with your actual output tensor name
    predictions = tf.make_ndarray(result.outputs['output_scores'])
    print("Predictions received:")
    print(predictions)

except grpc.RpcError as e:
    print(f"gRPC error: {e.code()}")
    print(f"Details: {e.details()}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
finally:
    # Ensure the channel is closed if it was opened
    if 'channel' in locals() and channel:
        channel.close()
Important: The input and output tensor names ('input_features' and 'output_scores' in the example) must exactly match the names defined in your SavedModel's signature definition. Use saved_model_cli to find these names.
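As with the REST API, you can pin a gRPC request to a specific model version by filling in the model_spec.version field, which is an Int64Value wrapper. A brief sketch, assuming version 1 of my_model is loaded and reusing the hypothetical tensor names from the example above:

import grpc
import numpy as np
import tensorflow as tf

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Using the channel as a context manager closes it automatically.
with grpc.insecure_channel("localhost:8500") as channel:
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = "my_model"
    # model_spec.version is an Int64Value wrapper, so set its .value field.
    request.model_spec.version.value = 1

    # One dummy instance; 'input_features' is the placeholder input name
    # used throughout this section.
    input_data = np.random.rand(1, 784).astype(np.float32)
    request.inputs["input_features"].CopyFrom(
        tf.make_tensor_proto(input_data, shape=input_data.shape, dtype=tf.float32)
    )

    # Blocking call with a 10-second deadline.
    result = stub.Predict(request, 10.0)
    print(tf.make_ndarray(result.outputs["output_scores"]))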
On the other hand, gRPC requires additional client-side dependencies (such as the tensorflow-serving-api package). The choice between the two protocols depends on your specific requirements.
Client applications can connect to TensorFlow Serving using either REST over HTTP/1.1 with JSON payloads (typically on port 8501) or the more performant gRPC over HTTP/2 with Protobuf payloads (typically on port 8500).
By mastering both REST and gRPC interactions with TensorFlow Serving, you gain the flexibility to deploy your models effectively, choosing the protocol that best balances performance needs with development simplicity for your specific use case. This sets the stage for building robust and scalable machine learning inference services.