Now that we have reviewed the standard components that constitute a TFX pipeline, let's solidify this understanding through a practical example. In this section, we will construct a basic end-to-end machine learning pipeline using several core TFX components. This hands-on exercise will demonstrate how these components connect and pass data (artifacts) between each other, automating the workflow from data ingestion to model deployment.
We will use a simplified version of the widely used Chicago Taxi Trips dataset for this example. Our goal is not to build the most accurate prediction model, but rather to illustrate the mechanics of constructing and running a TFX pipeline.
Before starting, ensure you have TFX installed in your Python environment. You can typically install it using pip:
pip install tfx
Depending on how you orchestrate the pipeline, you might also need additional dependencies such as a specific Apache Beam runner. For this local example, Beam's default DirectRunner suffices. We also assume you have a local directory where you can store the pipeline artifacts and the input data.
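To confirm the installation and check which release you have, you can print the version:
python -c "import tfx; print(tfx.__version__)"
The code in this section assumes a TFX 1.x release; import paths and component signatures changed between the 0.x and 1.x series, so adjust accordingly if your version differs.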
Let's assume our raw taxi data (e.g., data.csv) is located in a known directory, ./data. Our pipeline definition and supporting code will reside in a Python script (e.g., taxi_pipeline.py), and TFX will generate outputs in a specified pipeline root directory (e.g., ./pipeline_output).
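For reference, the pipeline below uses the columns trip_miles, trip_seconds, pickup_community_area, dropoff_community_area, and tips. A minimal data.csv compatible with it might begin like this (the rows are illustrative values, not real trip records):
trip_miles,trip_seconds,pickup_community_area,dropoff_community_area,tips
2.5,780,8,32,3.00
0.9,240,32,28,0.00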
A TFX pipeline is defined programmatically in Python. You import the necessary components and link them together, specifying inputs and outputs. The pipeline definition describes the desired workflow graph.
Let's start by setting up the basic structure in our taxi_pipeline.py script. We need to define paths for our data, pipeline outputs, and any module files required by components like Transform and Trainer.
# taxi_pipeline.py
import os

from tfx.components import (
    CsvExampleGen, StatisticsGen, SchemaGen, ExampleValidator,
    Transform, Trainer, Evaluator, Pusher
)
from tfx.proto import example_gen_pb2, trainer_pb2, pusher_pb2
from tfx.orchestration import metadata
from tfx.orchestration.local.local_dag_runner import LocalDagRunner
from tfx.dsl.components.common import resolver
from tfx.dsl.input_resolution.strategies.latest_blessed_model_strategy import (
    LatestBlessedModelStrategy
)
from tfx.types import Channel
from tfx.types.standard_artifacts import Model, ModelBlessing

# Define paths
_pipeline_name = 'taxi_simple'
_data_root = './data'  # Directory containing data.csv
_module_file = './taxi_utils.py'  # File for Transform and Trainer code
_pipeline_root = os.path.join('./pipeline_output', _pipeline_name)
_serving_model_dir = os.path.join(_pipeline_root, 'serving_model')

# Ensure output directories exist (optional, TFX often handles this)
os.makedirs(_pipeline_root, exist_ok=True)
os.makedirs(_serving_model_dir, exist_ok=True)
Now, let's instantiate and connect the TFX components one by one.
Data Ingestion (CsvExampleGen): This component reads data from external sources. Here, we use CsvExampleGen to read our CSV file. It converts each row into the tf.Example format, the standard input representation for TensorFlow training, and writes the results as TFRecord files.
# In taxi_pipeline.py, continued...
# Output split configuration for CsvExampleGen
output_config = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=4),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
    ])
)

example_gen = CsvExampleGen(
    input_base=_data_root,
    output_config=output_config
)
We specify the input location (input_base) and configure output splits (output_config) for training and evaluation data; hash_buckets values of 4 and 1 produce an approximate 80/20 train/eval split. CsvExampleGen produces an examples artifact.
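If your data arrives already split (say, into ./data/train and ./data/eval subdirectories, a hypothetical layout), you can describe the existing splits with an input_config instead of having ExampleGen hash-split them. A minimal sketch:
# Alternative: declare pre-split input data instead of hash-based splitting
input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval/*')
])
example_gen = CsvExampleGen(input_base=_data_root, input_config=input_config)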
Data Validation (StatisticsGen, SchemaGen, ExampleValidator): These components analyze and validate the ingested data.
StatisticsGen computes statistics over the data.
SchemaGen infers a data schema based on the statistics.
ExampleValidator looks for anomalies by comparing statistics against the schema.
# In taxi_pipeline.py, continued...
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples']
)

schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True
)

example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema']
)
Notice how the output artifact of one component (e.g., example_gen.outputs['examples']) becomes the input for the next.
Feature Engineering (Transform): This component performs feature engineering using Apache Beam. It requires a separate Python file (_module_file) containing a preprocessing_fn. This function defines the transformations (e.g., scaling, one-hot encoding) to be applied consistently during training and serving.
# In taxi_utils.py (separate file)
import tensorflow as tf
import tensorflow_transform as tft

# Define features and the label
_NUMERIC_FEATURES = ['trip_miles', 'trip_seconds']
_CATEGORICAL_FEATURES = ['pickup_community_area', 'dropoff_community_area']
_LABEL_KEY = 'tips'


def _transformed_name(key):
    return key + '_xf'


def preprocessing_fn(inputs):
    """tf.transform's callback function for preprocessing."""
    outputs = {}
    # Scale numeric features to z-scores (mean 0, variance 1)
    for key in _NUMERIC_FEATURES:
        outputs[_transformed_name(key)] = tft.scale_to_z_score(inputs[key])
    # Generate vocabularies and map categorical features to integers.
    # vocab_filename is set explicitly so the Trainer can look the
    # vocabulary up by feature name later.
    for key in _CATEGORICAL_FEATURES:
        outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary(
            inputs[key], top_k=100, vocab_filename=key  # Example: use top 100 areas
        )
    # Keep the label as is (assuming it's numeric for regression/classification)
    # Or apply a transformation if needed (e.g., tft.bucketize for classification)
    outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY]
    return outputs
# In taxi_pipeline.py, continued...
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=_module_file  # Points to taxi_utils.py
)
Transform consumes the raw examples and schema, applies the preprocessing_fn, and produces transformed_examples and a transform_graph artifact used for consistent application during serving.
Model Training (Trainer): The Trainer component trains a TensorFlow model. Similar to Transform, it often uses a module file (_module_file) containing a run_fn (or, for legacy Estimator-based training, a trainer_fn) that defines the model architecture, optimizer, loss, and training logic.
# In taxi_utils.py (separate file), add the trainer functions
from tfx.components.trainer.fn_args_utils import FnArgs


def _input_fn(file_pattern, tf_transform_output, batch_size=64):
    """Loads the transformed examples as a batched tf.data.Dataset."""
    transformed_feature_spec = tf_transform_output.transformed_feature_spec().copy()
    # ExampleGen writes gzipped TFRecord files, so the reader must decompress.
    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=file_pattern,
        batch_size=batch_size,
        features=transformed_feature_spec,
        reader=lambda filenames: tf.data.TFRecordDataset(
            filenames, compression_type='GZIP'),
        label_key=_transformed_name(_LABEL_KEY))
    # Scalar features come out with shape (batch,); give them an explicit
    # feature dimension so they match the Keras Input layers below.
    # (Assumes all transformed features are scalars.)
    return dataset.map(lambda features, label: (
        {name: tf.reshape(tensor, [-1, 1]) for name, tensor in features.items()},
        tf.reshape(label, [-1, 1])))


def _build_keras_model(tf_transform_output):
    """Creates a Keras model for training."""
    feature_spec = tf_transform_output.transformed_feature_spec().copy()
    # Remove the label from the feature spec for the input layers
    feature_spec.pop(_transformed_name(_LABEL_KEY))
    inputs = {
        key: tf.keras.layers.Input(shape=spec.shape or (1,), name=key, dtype=spec.dtype)
        for key, spec in feature_spec.items()
    }
    # Simple example: concatenate numeric features with embeddings of categoricals
    numeric_inputs = [_transformed_name(key) for key in _NUMERIC_FEATURES]
    categorical_inputs = [_transformed_name(key) for key in _CATEGORICAL_FEATURES]
    # Create embeddings for categoricals (adjust embedding dim as needed)
    embedded_cats = []
    for key in categorical_inputs:
        vocab_size = tf_transform_output.vocabulary_size_by_name(key.replace('_xf', ''))
        embedding = tf.keras.layers.Embedding(
            input_dim=vocab_size + 1, output_dim=8)(inputs[key])
        embedded_cats.append(tf.keras.layers.Flatten()(embedding))
    # Concatenate all processed features
    features = tf.keras.layers.concatenate(
        [inputs[key] for key in numeric_inputs] + embedded_cats
    )
    # Simple DNN
    x = tf.keras.layers.Dense(64, activation='relu')(features)
    x = tf.keras.layers.Dense(32, activation='relu')(x)
    output = tf.keras.layers.Dense(1)(x)  # Output for regression (predict tips)
    model = tf.keras.Model(inputs=inputs, outputs=output)
    return model


def run_fn(fn_args: FnArgs):
    """Train the model based on given args."""
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    train_dataset = _input_fn(fn_args.train_files, tf_transform_output)
    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output)
    model = _build_keras_model(tf_transform_output)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='mean_squared_error',  # Assuming regression
        metrics=[tf.keras.metrics.RootMeanSquaredError()]
    )
    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps
    )
    # Save the model in SavedModel format
    model.save(fn_args.serving_model_dir, save_format='tf')
Note: The run_fn and especially the dataset loading (_input_fn) require careful implementation involving tf.data and the transform_graph. The _input_fn above parses the already transformed TFRecord files using the transformed feature spec; the exact details (shapes, compression, batching) depend on the specific data schema and TF version.
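Relatedly, if the exported model should accept raw serialized tf.Example records (so the Evaluator and a serving system can send untransformed data), a common pattern, also used in the official TFX taxi example, is to attach the transform graph as a serving signature. A sketch of that pattern:
# In taxi_utils.py (sketch): expose the transform graph as a serving signature
def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Returns a function that parses raw tf.Examples and applies the transform graph."""
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        feature_spec = tf_transform_output.raw_feature_spec()
        feature_spec.pop(_LABEL_KEY)  # The label is not an input at serving time
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        # Depending on feature shapes, the tft_layer outputs may need reshaping
        # to match the model's Input layers.
        return model(transformed_features)

    return serve_tf_examples_fn

# run_fn would then save the model with this signature attached:
# signatures = {
#     'serving_default': _get_serve_tf_examples_fn(model, tf_transform_output)
#         .get_concrete_function(
#             tf.TensorSpec(shape=[None], dtype=tf.string, name='examples'))
# }
# model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)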
# In taxi_pipeline.py, continued...
trainer = Trainer(
    module_file=_module_file,  # Points to taxi_utils.py containing run_fn
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=1000),  # Example steps
    eval_args=trainer_pb2.EvalArgs(num_steps=500)  # Example steps
)
Trainer uses the transformed examples, the transform graph (to ensure consistency), and the schema to train the model defined in run_fn. It outputs a trained model artifact.
Model Evaluation (Evaluator): This component performs a deep analysis of the trained model's performance on the evaluation dataset. It uses TensorFlow Model Analysis (TFMA).
# In taxi_pipeline.py, continued...
import tensorflow_model_analysis as tfma

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='tips')],  # Raw label name; Evaluator reads untransformed examples
    slicing_specs=[tfma.SlicingSpec()],  # Evaluate on the overall dataset
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(class_name='ExampleCount'),
            tfma.MetricConfig(
                class_name='RootMeanSquaredError',
                threshold=tfma.MetricThreshold(
                    value_threshold=tfma.GenericValueThreshold(
                        upper_bound={'value': 15.0}),  # Example threshold
                    change_threshold=tfma.GenericChangeThreshold(
                        direction=tfma.MetricDirection.LOWER_IS_BETTER,
                        absolute={'value': -1e-10})))  # Require improvement vs baseline
        ])
    ]
)

# Resolver to find the latest blessed model for comparison
model_resolver = resolver.Resolver(
    strategy_class=LatestBlessedModelStrategy,
    model=Channel(type=Model),
    model_blessing=Channel(type=ModelBlessing)
).with_id('latest_blessed_model_resolver')

evaluator = Evaluator(
    examples=example_gen.outputs['examples'],  # Use original examples for slicing
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],  # Compare against previous model
    eval_config=eval_config,
    example_splits=['eval']  # Evaluate on the 'eval' split
)
We define an EvalConfig specifying metrics (like RMSE) and thresholds. Evaluator compares the current model against a baseline (often the previously blessed model, found using a Resolver) and outputs evaluation results and a blessing artifact indicating whether the model passed the thresholds.
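After a run, the Evaluator's output can also be loaded and inspected outside the pipeline using TFMA. A minimal sketch, where evaluation_uri is a placeholder you would replace with the Evaluator's actual output directory under the pipeline root:
import tensorflow_model_analysis as tfma

# Hypothetical path; the numeric execution id differs per run
evaluation_uri = './pipeline_output/taxi_simple/Evaluator/evaluation/<execution_id>'
eval_result = tfma.load_eval_result(evaluation_uri)
print(eval_result.slicing_metrics)  # Metric values per slice
# In a notebook, tfma.view.render_slicing_metrics(eval_result) gives an interactive view.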
Model Deployment (Pusher): Based on the evaluation results, the Pusher component conditionally deploys the validated model to a specified serving location.
# In taxi_pipeline.py, continued...
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],  # Push only if blessed
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=_serving_model_dir
        )
    )
)
Pusher checks the model_blessing artifact from Evaluator. If the model is blessed, it copies the model artifact (the SavedModel) to the push_destination. Here, we push to a local filesystem directory.
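Because the pushed artifact is a standard SavedModel in a versioned subdirectory, it can be served directly, for example with TensorFlow Serving via Docker. A sketch (the mount path and model name are assumptions for this example):
docker run -p 8501:8501 \
  --mount type=bind,source=$(pwd)/pipeline_output/taxi_simple/serving_model,target=/models/taxi_simple \
  -e MODEL_NAME=taxi_simple -t tensorflow/serving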
With all components defined, we assemble them into a TFX Pipeline object and use an orchestrator to run it. For local execution, LocalDagRunner is suitable.
# In taxi_pipeline.py, continued...
from tfx.orchestration import pipeline

# Define the pipeline
components = [
    example_gen,
    statistics_gen,
    schema_gen,
    example_validator,
    transform,
    trainer,
    model_resolver,  # Make sure the resolver runs before Evaluator
    evaluator,
    pusher,
]

tfx_pipeline = pipeline.Pipeline(
    pipeline_name=_pipeline_name,
    pipeline_root=_pipeline_root,
    components=components,
    enable_cache=True,  # Use caching for unchanged components
    metadata_connection_config=metadata.sqlite_metadata_connection_config(
        os.path.join(_pipeline_root, 'metadata.sqlite'))  # Store metadata locally
)

# Run the pipeline locally
LocalDagRunner().run(tfx_pipeline)
A typical workflow graph for the simple TFX pipeline described. Arrows indicate the flow of artifacts between components.
To execute this pipeline, simply run the Python script:
python taxi_pipeline.py
TFX, using the LocalDagRunner, will execute each component in the correct order based on the defined dependencies. It will generate artifacts (data splits, statistics, schema, transformed data, model checkpoints, evaluation results, and the final pushed model) within the _pipeline_root directory structure. The metadata.sqlite file tracks all executions, components, and artifacts, providing lineage and enabling caching.
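You can query this metadata store directly with the ML Metadata (MLMD) library, for example to list every artifact the run produced. A minimal sketch:
from ml_metadata.metadata_store import metadata_store
from ml_metadata.proto import metadata_store_pb2

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = './pipeline_output/taxi_simple/metadata.sqlite'
connection_config.sqlite.connection_mode = 3  # READWRITE_OPENCREATE
store = metadata_store.MetadataStore(connection_config)

# Print the id and storage location of every recorded artifact
for artifact in store.get_artifacts():
    print(artifact.id, artifact.uri)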
After the pipeline finishes, explore the pipeline_output directory. You will find subdirectories for each component execution, containing their respective output artifacts. For example:
CsvExampleGen/examples/...: Contains the ingested data in TFRecord format.
StatisticsGen/statistics/...: Contains visualizations (e.g., using Facets) of data statistics.
SchemaGen/schema/...: Contains the inferred schema protobuf file.
Transform/transform_graph/...: Contains the TensorFlow graph for preprocessing.
Trainer/model/...: Contains the trained model in SavedModel format.
Evaluator/evaluation/...: Contains TFMA results viewable in a browser.
Pusher/pushed_model/...: Contains the final model copied for serving, if blessed.
Inspecting these artifacts helps understand what each component does and verify the pipeline's execution, as shown below.
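For instance, TensorFlow Data Validation can load the inferred schema for display. A sketch, where the path is a placeholder (the numeric execution id under SchemaGen/schema differs per run):
import tensorflow_data_validation as tfdv

# Hypothetical path; look up the real one under pipeline_output
schema = tfdv.load_schema_text(
    './pipeline_output/taxi_simple/SchemaGen/schema/<execution_id>/schema.pbtxt')
tfdv.display_schema(schema)  # Renders the schema as a table in a notebook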
This hands-on example provides a concrete starting point for building TFX pipelines. While simple, it demonstrates the core principles of component definition, artifact flow, and local orchestration. Real-world pipelines often involve more complex data, custom components, different orchestrators (like Kubeflow Pipelines or Apache Airflow), and more sophisticated model architectures and evaluation strategies, building upon the foundation established here.