Alright, let's translate the theory of tree structures into practical code. In this section, we'll implement a basic Binary Search Tree (BST) to solidify our understanding of its operations and then visualize the structure of a simple Decision Tree, connecting the abstract data structure to its application in a popular machine learning model.
A Binary Search Tree provides efficient average-case time complexity for search, insertion, and deletion, typically O(log n), making it a fundamental structure to understand. Let's build one from scratch in Python.
First, we need a Node class to represent each element in the tree:
class Node:
    """Represents a node in a Binary Search Tree."""
    def __init__(self, key):
        self.left = None
        self.right = None
        self.val = key

    def __str__(self):
        # Helper for printing node values
        return str(self.val)
Each node stores a value (val) and references to its left and right children.
Now, let's create the BST class itself, starting with an insert method. Remember the insertion logic: if the tree is empty, the new node becomes the root. Otherwise, compare the new key with the current node's key. If it's smaller, go left; if it's larger, go right. Repeat until an empty spot (None) is found where the node can be inserted.
class BST:
    """Represents a Binary Search Tree."""
    def __init__(self):
        self.root = None

    def insert(self, key):
        """Inserts a key into the BST."""
        if self.root is None:
            self.root = Node(key)
        else:
            self._insert_recursive(self.root, key)

    def _insert_recursive(self, current_node, key):
        """Recursive helper for insertion."""
        if key < current_node.val:
            if current_node.left is None:
                current_node.left = Node(key)
            else:
                self._insert_recursive(current_node.left, key)
        elif key > current_node.val:  # Ignore duplicate keys for simplicity
            if current_node.right is None:
                current_node.right = Node(key)
            else:
                self._insert_recursive(current_node.right, key)
        # If key == current_node.val, we can choose to ignore it,
        # update the node, or handle duplicates differently.
        # Here, we ignore duplicates.

    # --- Search and Traversal Methods will go here ---
Next, let's implement the search operation. The logic mirrors insertion: compare the target key with the current node's key and traverse left or right accordingly. If we find the key, return True. If we reach a None reference, the key isn't in the tree, so return False.
# Add these methods inside the BST class
def search(self, key):
    """Searches for a key in the BST."""
    return self._search_recursive(self.root, key)

def _search_recursive(self, current_node, key):
    """Recursive helper for search."""
    if current_node is None:
        return False  # Reached end, key not found
    if key == current_node.val:
        return True  # Key found
    elif key < current_node.val:
        return self._search_recursive(current_node.left, key)
    else:  # key > current_node.val
        return self._search_recursive(current_node.right, key)
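As an aside, very deep (unbalanced) trees can run into Python's recursion limit. An equivalent iterative version of search, a sketch you could add to the BST class if you prefer, avoids that:

# Iterative alternative to search (a sketch; add inside the BST class if desired)
def search_iterative(self, key):
    """Searches for a key without recursion."""
    current = self.root
    while current is not None:
        if key == current.val:
            return True
        # Step left or right depending on the comparison
        current = current.left if key < current.val else current.right
    return False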
Finally, let's add an in_order_traversal method. This traversal visits the left subtree, then the current node, then the right subtree. For a BST, this conveniently returns the keys in sorted order.
# Add this method inside the BST class
def in_order_traversal(self):
    """Performs in-order traversal and returns a list of keys."""
    elements = []
    self._in_order_recursive(self.root, elements)
    return elements

def _in_order_recursive(self, current_node, elements):
    """Recursive helper for in-order traversal."""
    if current_node:
        self._in_order_recursive(current_node.left, elements)
        elements.append(current_node.val)
        self._in_order_recursive(current_node.right, elements)
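The other two depth-first traversals differ only in when the current node is visited. Minimal sketches following the same pattern as the in-order helper above (not used in the example below):

# Pre-order: visit node, then left, then right (useful for copying a tree)
def _pre_order_recursive(self, current_node, elements):
    if current_node:
        elements.append(current_node.val)
        self._pre_order_recursive(current_node.left, elements)
        self._pre_order_recursive(current_node.right, elements)

# Post-order: visit left, then right, then node (useful for freeing a tree)
def _post_order_recursive(self, current_node, elements):
    if current_node:
        self._post_order_recursive(current_node.left, elements)
        self._post_order_recursive(current_node.right, elements)
        elements.append(current_node.val)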
Let's put it all together and see it in action:
# --- Include the Node and BST class definitions from above here ---

# Example Usage
bst = BST()
keys_to_insert = [50, 30, 70, 20, 40, 60, 80]
for key in keys_to_insert:
    bst.insert(key)

print(f"Searching for 40: {bst.search(40)}")  # Output: True
print(f"Searching for 90: {bst.search(90)}")  # Output: False
print(f"In-order traversal: {bst.in_order_traversal()}")
# Output: [20, 30, 40, 50, 60, 70, 80]
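We have covered insertion and search, but the complexity discussion at the start of this section also mentioned deletion. Deletion is the trickiest of the three operations because a node with two children must be replaced by its in-order successor. A minimal sketch (these methods would go inside the BST class; they are not invoked in the example above):

def delete(self, key):
    """Deletes a key from the BST, if present."""
    self.root = self._delete_recursive(self.root, key)

def _delete_recursive(self, node, key):
    """Recursive helper for deletion; returns the new subtree root."""
    if node is None:
        return None
    if key < node.val:
        node.left = self._delete_recursive(node.left, key)
    elif key > node.val:
        node.right = self._delete_recursive(node.right, key)
    else:
        # Node with zero or one child: splice it out
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left
        # Two children: copy the in-order successor's value
        # (smallest key in the right subtree), then delete it there
        successor = node.right
        while successor.left is not None:
            successor = successor.left
        node.val = successor.val
        node.right = self._delete_recursive(node.right, successor.val)
    return node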
Visualizing the structure helps clarify the relationships between nodes. Here's a diagram representing the BST we just created:
The Binary Search Tree created by inserting the keys [50, 30, 70, 20, 40, 60, 80]. Notice how smaller values go left and larger values go right at each node.
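If you'd like to inspect the same shape directly from code, a small helper can print the tree sideways. This is a sketch of our own (the function name and rotated layout are our choices, not part of the BST class above):

def print_sideways(node, depth=0):
    """Prints the tree rotated 90 degrees: root at the left, right subtree on top."""
    if node is not None:
        print_sideways(node.right, depth + 1)
        print("    " * depth + str(node.val))
        print_sideways(node.left, depth + 1)

print_sideways(bst.root)
# Output (root at the far left, larger keys toward the top):
#         80
#     70
#         60
# 50
#         40
#     30
#         20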
This simple implementation demonstrates the core mechanics. Remember the discussion on balanced trees: if we inserted keys in sorted order (e.g., [20, 30, 40, 50, 60, 70, 80]), this simple BST would degrade into a linked-list structure with O(n) search time. Libraries and production systems use self-balancing trees (like AVL or Red-Black trees) to guarantee O(log n) performance.
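You can verify this degradation directly by measuring tree height. A quick check (the height helper is our own addition, not a method of the class above):

def height(node):
    """Returns the number of nodes on the longest root-to-leaf path."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))

balanced = BST()
for key in [50, 30, 70, 20, 40, 60, 80]:
    balanced.insert(key)

degenerate = BST()
for key in [20, 30, 40, 50, 60, 70, 80]:  # sorted order
    degenerate.insert(key)

print(height(balanced.root))    # 3: roughly log2 of 7 keys
print(height(degenerate.root))  # 7: every node has only a right child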
Decision Trees are a popular class of machine learning models that inherently use a tree structure to make predictions. Each internal node represents a test on a feature, each branch represents the outcome of the test, and each leaf node represents a class label or a continuous value.
While implementing the full training algorithm (which involves selecting the best splits based on criteria like Gini impurity or information gain) is beyond this section's scope, visualizing a trained tree helps connect the data structure concept to the model's operation. We'll use scikit-learn to train a simple classifier and visualize its structure.
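The impurity criterion itself is simple to state: for a node whose samples have class proportions p_i, the Gini impurity is 1 minus the sum of squared proportions. A minimal sketch of the computation (our own illustration, not scikit-learn's internal code):

def gini_impurity(class_counts):
    """Gini impurity for a node: 1 - sum of squared class proportions."""
    total = sum(class_counts)
    return 1.0 - sum((count / total) ** 2 for count in class_counts)

print(gini_impurity([50, 50]))  # 0.5 (maximally mixed, two classes)
print(gini_impurity([35, 0]))   # 0.0 (pure node)

These are exactly the gini values you will see annotated on the tree nodes below.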
Let's use a small, synthetic dataset for clarity:
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_graphviz, plot_tree
import graphviz  # Optional: for rendering the Graphviz DOT output

# Generate a simple binary classification dataset
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_clusters_per_class=1,
                           random_state=42, class_sep=1.5)

# Train a Decision Tree Classifier
# Limit depth for better visualization
dt_classifier = DecisionTreeClassifier(max_depth=3, random_state=42)
dt_classifier.fit(X, y)

print("Decision Tree trained successfully.")

# Option 1: Using scikit-learn's plot_tree (requires matplotlib)
plt.figure(figsize=(12, 8))
plot_tree(dt_classifier,
          filled=True,
          rounded=True,
          class_names=['Class 0', 'Class 1'],        # Use appropriate class names
          feature_names=['Feature 1', 'Feature 2'])  # Use appropriate feature names
# plt.show()  # Uncomment to display the plot directly

# Option 2: Exporting to Graphviz (more customizable)
# With out_file=None, the DOT source is returned as a string
dot_data = export_graphviz(dt_classifier, out_file=None,
                           feature_names=['Feature 1', 'Feature 2'],
                           class_names=['Class 0', 'Class 1'],
                           filled=True, rounded=True,
                           special_characters=True)

# You can render this dot_data using the graphviz library:
# graph = graphviz.Source(dot_data)
# graph.render("decision_tree")  # Saves decision_tree.pdf
# print("Graphviz DOT data generated (and optionally rendered to PDF).")
# Example Graphviz string (truncated for brevity, represents structure)
# Note: Actual output from export_graphviz will be more detailed
example_dot = """
digraph Tree {
node [shape=box, style="filled, rounded", color="black", fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="Feature 2 <= 0.05\\ngini = 0.5\\nsamples = 100\\nvalue = [50, 50]\\nclass = Class 0", fillcolor="#ffffff"] ;
1 [label="Feature 1 <= -0.8\\ngini = 0.18\\nsamples = 55\\nvalue = [50, 5]\\nclass = Class 0", fillcolor="#e8f4fd"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.0\\nsamples = 35\\nvalue = [35, 0]\\nclass = Class 0", fillcolor="#e58139"] ;
1 -> 2 ;
3 [label="Feature 2 <= -0.5\\ngini = 0.375\\nsamples = 20\\nvalue = [15, 5]\\nclass = Class 0", fillcolor="#f2c2a1"] ;
1 -> 3 ;
4 [label="Feature 1 <= 1.2\\ngini = 0.48\\nsamples = 45\\nvalue = [10, 35]\\nclass = Class 1", fillcolor="#baddf7"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
5 [label="Feature 2 <= 1.0\\ngini = 0.24\\nsamples = 25\\nvalue = [20, 5]\\nclass = Class 0", fillcolor="#ea9a5f"] ;
4 -> 5 ;
6 [label="gini = 0.0\\nsamples = 20\\nvalue = [0, 20]\\nclass = Class 1", fillcolor="#399de5"] ;
4 -> 6 ;
}
"""
# Displaying the example dot structure
print("\nExample Graphviz structure (illustrative):")
print(example_dot)
# Normally you'd render the real 'dot_data' with the graphviz library instead,
# as shown in the commented-out lines above.
An illustrative structure of a trained Decision Tree. Each node shows the splitting condition, impurity (gini), number of samples reaching the node, distribution of samples per class (value), and the predicted class for that node. Colors often indicate the majority class and impurity level.
Interpreting the nodes:
- Splitting condition (e.g., Feature 2 <= 0.05): the rule used to divide data at an internal node.
- Gini: the impurity of the samples reaching the node; 0.0 means the node is pure.
- Samples: the number of training samples that reach the node.
- Value: the distribution of samples per class ([50, 5] means 50 samples of Class 0 and 5 samples of Class 1).
- Class: the majority class at the node, used as the prediction at leaves.

This hands-on exercise demonstrates how tree structures are implemented and utilized. Building a BST highlights search and insertion mechanics, while visualizing a decision tree shows how this structure directly forms the basis of a machine learning model, partitioning the feature space to make predictions. Keep in mind the performance implications of balanced versus unbalanced trees as you consider using them in your ML workflows.
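As a final check that the tree really does partition the feature space, we can ask the trained classifier for predictions. A quick sketch, assuming the dt_classifier fitted above (the sample points are arbitrary, and the exact outputs depend on the fitted tree):

import numpy as np

# Two arbitrary points in feature space; each is routed from the root
# down to a leaf, and the leaf's majority class becomes the prediction
samples = np.array([[-1.0, -1.0], [1.0, 1.0]])
predictions = dt_classifier.predict(samples)
print(f"Predictions: {predictions}")  # e.g., [0 1]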