Even with a solid understanding of TensorFlow's execution models, tracking down errors in complex models or within tf.function-decorated code can be challenging. Standard Python debugging techniques sometimes fall short when dealing with compiled graphs. This section introduces specific tools and strategies for debugging TensorFlow programs, particularly focusing on issues arising from graph execution, AutoGraph transformations, and gradient computations.
Your debugging approach often depends on whether the problematic code is running eagerly or within a tf.function graph.

Eager Execution: When running eagerly (the default in TensorFlow 2), your code executes line by line, much like standard Python. This means you can often use familiar Python debugging tools:

The print() function: Use standard Python print() statements to inspect the values of Python variables or the results of TensorFlow operations (which will be eager tensors).

Interactive debuggers (pdb or IDE debuggers): Set breakpoints, step through code, and inspect variables interactively. This is highly effective for understanding logic flow and variable states before graph tracing occurs.

Graph Execution (tf.function): Once code is wrapped in tf.function, TensorFlow traces it to create a static computational graph. Standard Python print() statements or pdb breakpoints placed inside the decorated function will typically only execute during the initial tracing phase, not during subsequent graph executions. This behavior can be misleading, and new strategies are needed for graph-mode debugging, as the short example below illustrates.
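As a quick illustration, here is a minimal sketch (with a hypothetical traced_fn) showing that a Python print inside a tf.function fires only while the function is traced, not on later calls that reuse the cached graph:

import tensorflow as tf

@tf.function
def traced_fn(x):
    # This Python print runs only while TensorFlow traces the function
    print("Tracing with x =", x)
    return x + 1

traced_fn(tf.constant(1))  # First call triggers tracing: the print appears
traced_fn(tf.constant(2))  # Same input signature, graph is reused: no print
traced_fn(3)               # A new Python scalar forces a retrace: the print appears again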
Debugging Inside tf.function

When you need to inspect values or control flow during the execution of a TensorFlow graph, use the following techniques:
Using tf.print

The tf.print function is the graph-aware equivalent of Python's print. It inserts a print operation directly into the TensorFlow graph, which ensures that tensor values are printed whenever the graph is executed, not just during tracing.
import tensorflow as tf

@tf.function
def problematic_function(x):
    # Use tf.print to inspect tensor values inside the graph
    tf.print("Inside tf.function, x =", x)
    y = x * 2
    tf.print("Intermediate value y =", y)
    # Potential issue: Integer division might truncate unexpectedly
    z = y // 3
    tf.print("Final value z =", z)
    return z

# Call the function
input_tensor = tf.constant([1, 5, 10], dtype=tf.int32)
result = problematic_function(input_tensor)
print("Result outside tf.function:", result)
# Example Output:
# Inside tf.function, x = [1 5 10]
# Intermediate value y = [2 10 20]
# Final value z = [0 3 6]
# Result outside tf.function: tf.Tensor([0 3 6], shape=(3,), dtype=int32)
tf.print is invaluable for observing tensor values at different stages within a graph's execution path. Remember that tf.print operations execute on the device where the computation is placed (CPU/GPU/TPU), so output might appear in different places (e.g., logs) depending on the execution environment, especially in distributed settings.
Temporarily Disabling tf.function

For complex debugging scenarios, it can be helpful to force the function to run eagerly, allowing you to use standard Python debuggers. You can achieve this globally:
import tensorflow as tf

# Disable tf.function globally
tf.config.run_functions_eagerly(True)

@tf.function
def my_complex_logic(a, b):
    # Now you can use pdb or print effectively here
    # import pdb; pdb.set_trace()
    print("Running eagerly, a:", a)
    c = tf.matmul(a, b)
    print("Running eagerly, c:", c)
    return c

# Calls will now execute eagerly
matrix_a = tf.constant([[1.0, 2.0], [3.0, 4.0]])
matrix_b = tf.constant([[5.0], [6.0]])
result = my_complex_logic(matrix_a, matrix_b)

# Remember to turn it back off when done debugging
tf.config.run_functions_eagerly(False)
Running eagerly simplifies debugging but comes at a cost: you give up the performance benefits of graph execution, so code runs noticeably slower. Use this technique judiciously to isolate problems, but always verify the fix works correctly with tf.function enabled.
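If you are unsure which mode is currently active, you can query it directly. A small sketch, assuming the global switch shown above:

import tensorflow as tf

tf.config.run_functions_eagerly(True)
print(tf.config.functions_run_eagerly())  # True while the global switch is on

@tf.function
def check_mode():
    # True when the body executes eagerly, False while it is traced into a graph
    print("Eager inside function:", tf.executing_eagerly())
    return tf.constant(0)

check_mode()  # Runs eagerly and prints True

tf.config.run_functions_eagerly(False)
check_mode()  # Now traced: the print fires once during tracing and reports False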
Visualizing Graphs with TensorBoard

Understanding how tf.function and AutoGraph translate your Python code into a computational graph can reveal unexpected structures or operations. TensorBoard provides a graph visualizer. To use it, create a summary file writer, enable tracing before calling your tf.function, and then export the recorded graph:
import tensorflow as tf
import datetime

@tf.function
def simple_graph(x, y):
    return tf.add(x, y)

# Set up logging
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = f'./logs/func/{stamp}'
writer = tf.summary.create_file_writer(logdir)

# Trace the function
input_tensor = tf.constant(1.0)
tf.summary.trace_on(graph=True, profiler=False)  # Start tracing for graph
result = simple_graph(input_tensor, input_tensor)  # Call the function to trace it
with writer.as_default():
    tf.summary.trace_export(name="simple_graph_trace", step=0)  # Export the graph
tf.summary.trace_off()  # Stop tracing

print(f"Graph trace written to {logdir}")
# Now run: tensorboard --logdir ./logs/func
Launching TensorBoard (tensorboard --logdir ./logs/func) and navigating to the "Graphs" tab will show the structure TensorFlow created. This helps identify whether control flow (tf.cond, tf.while_loop) was generated as expected, or whether specific operations are present.

[Figure: A simplified representation of the simple_graph function's computational graph, showing inputs flowing into the addition operation.]
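If TensorBoard is not available, you can also inspect the traced graph programmatically. A minimal sketch, reusing the simple_graph function defined above:

import tensorflow as tf

# Obtain the concrete function traced for scalar float32 inputs
# (assumes simple_graph from the previous snippet is in scope)
concrete = simple_graph.get_concrete_function(
    tf.TensorSpec(shape=(), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.float32))

# List every operation that tf.function placed in the graph
for op in concrete.graph.get_operations():
    print(op.name, op.type)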
The tf.debugging Module

TensorFlow provides a dedicated debugging module, tf.debugging, containing assertions that operate within the graph. These are useful for checking conditions during graph execution and raising errors if checks fail.
tf.debugging.check_numerics(tensor, message): Checks if a tensor contains any NaN (Not a Number) or Inf (Infinity) values. This is extremely useful for detecting numerical instability during training (e.g., exploding gradients).

tf.debugging.assert_equal(x, y): Asserts that two tensors x and y have the same values element-wise.

tf.debugging.assert_shapes(shapes, data=None): Asserts that the shapes of tensors match a specified list of shapes. This helps catch shape mismatch errors early. For example, tf.debugging.assert_shapes([(tensor1, (None, 10)), (tensor2, (5, None, 3))]) checks that tensor1 has shape (batch, 10) and tensor2 has shape (5, time, 3).

tf.debugging.Assert(condition, data): A general assertion that raises an error if the boolean condition tensor evaluates to False.

These assertions add operations to the graph that perform the checks during execution. The example below uses two of them; a short sketch of assert_shapes follows afterwards.
import tensorflow as tf

@tf.function
def safe_divide(numerator, denominator):
    tf.debugging.assert_greater(tf.abs(denominator), 1e-6,
                                message="Denominator close to zero!")
    result = numerator / denominator
    # Check for NaNs/Infs which might result from near-zero division
    tf.debugging.check_numerics(result, "Numerical issues in result")
    return result

# This will execute fine
print(safe_divide(tf.constant(10.0), tf.constant(2.0)))

# This will raise an InvalidArgumentError due to the assert_greater check
# try:
#     print(safe_divide(tf.constant(1.0), tf.constant(1e-8)))
# except tf.errors.InvalidArgumentError as e:
#     print(f"Caught expected error: {e}")
Debugging Gradients with tf.GradientTape

Issues related to automatic differentiation are common, especially when building custom training loops or complex models.

None Gradients

One frequent problem is receiving None when requesting a gradient from tf.GradientTape. This usually happens for one of these reasons:
Unwatched tensors: tf.GradientTape only watches trainable tf.Variable objects by default. If you need gradients with respect to a tf.Tensor, you must explicitly tell the tape to watch it using tape.watch(tensor) (a short sketch of this appears after the example below).

Non-differentiable operations: The path from the source to the target includes operations without gradients (e.g., tf.cast to an integer type, tf.round, boolean operations, indexing with non-constant tensors).

Non-floating-point types: Gradients are only computed for floating-point tensors (float32, float64).

Variables created inside the tape: Creating tf.Variables inside the gradient tape's context is generally discouraged and can lead to unexpected behavior or None gradients. Initialize variables outside the tape scope.

You can directly inspect the computed gradients to check if they are reasonable (e.g., not all zeros, not excessively large).
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x * x

# Compute the gradient dy/dx
grad = tape.gradient(y, x)

# Check if the gradient is None or inspect its value
if grad is not None:
    print(f"Gradient dy/dx at x={x.numpy()}: {grad.numpy()}")  # Should be 6.0
    tf.debugging.check_numerics(grad, "Checking gradient for NaN/Inf")
else:
    print("Gradient is None. Check tape watching and differentiability.")
Using tf.debugging.check_numerics on gradients themselves is a good practice to catch exploding gradients early in the training process.
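In a custom training loop this pattern extends to every gradient of the model. A minimal sketch, where the model, loss, and optimizer are illustrative assumptions:

import tensorflow as tf

# Small illustrative model; variables are created up front, outside any tape
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    # Fail fast if any gradient contains NaN or Inf values
    for g, v in zip(grads, model.trainable_variables):
        tf.debugging.check_numerics(g, message=f"Bad gradient for {v.name}")
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
print(train_step(x, y))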
General Debugging Strategies

Check shapes: Use tf.print(tf.shape(tensor)) or tf.debugging.assert_shapes liberally. Remember that shapes might be less defined during graph construction (None dimensions) than during execution.

Check data types (dtype): Ensure tensors being combined in operations have compatible and appropriate data types (usually float32). Use tf.cast explicitly if needed, but be aware it can break gradient flow if casting to non-float types.

Isolate graph problems: Verify the logic eagerly first, then re-enable tf.function and debug any graph-specific issues.

Convert incrementally: Instead of wrapping a large block of code in a single tf.function, convert smaller pieces first and verify they work before combining them. This helps pinpoint where AutoGraph might be struggling.

Increase AutoGraph verbosity: Raising the tf.autograph.set_verbosity level can provide detailed logs about the conversion process, potentially highlighting problematic Python constructs (see the sketch after this list).
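A minimal sketch of raising AutoGraph's logging level and inspecting the code it generates (the decorated function is illustrative):

import tensorflow as tf

# Higher levels produce more detailed conversion logs
tf.autograph.set_verbosity(3, alsologtostdout=True)

@tf.function
def clipped(x):
    # Python control flow that AutoGraph converts to tf.cond
    if tf.reduce_sum(x) > 10:
        return x / 2
    return x

clipped(tf.constant([5.0, 10.0]))

# Show the Python code AutoGraph generated for the function
print(tf.autograph.to_code(clipped.python_function))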
Debugging TensorFlow code, especially within tf.function graphs, requires adapting your strategies. Leveraging tf.print, tf.debugging, TensorBoard visualization, and the ability to toggle eager execution provides a powerful toolkit for identifying and resolving issues in advanced TensorFlow applications. These techniques form an important basis for building and maintaining the complex, high-performance models discussed in subsequent chapters.