趋近智
实现一个基本的二叉搜索树(BST)以演示其操作。然后可视化一个简单决策树的结构,呈现抽象数据结构在一个流行机器学习模型中的应用。
二叉搜索树为搜索、插入和删除操作提供了高效的平均时间复杂度,通常为 O(logn),使其成为一个需要掌握的基础结构。让我们用 Python 从头构建一个。
首先,我们需要一个 Node 类来表示树中的每个元素:
class Node:
"""表示二叉搜索树中的一个节点。"""
def __init__(self, key):
self.left = None
self.right = None
self.val = key
def __str__(self):
# 打印节点值的辅助函数
return str(self.val)
每个节点存储一个值 (val) 以及对其左右子节点的引用。
现在,让我们创建 BST 类本身,从 insert 方法开始。请记住插入逻辑:如果树为空,新节点成为根节点。否则,将新键与当前节点的键进行比较。如果新键较小,则向左;如果较大,则向右。重复此过程,直到找到可以插入节点的空位置 (None)。
class BST:
"""表示一个二叉搜索树。"""
def __init__(self):
self.root = None
def insert(self, key):
"""向BST中插入一个节点。"""
if self.root is None:
self.root = Node(key)
else:
self._insert_recursive(self.root, key)
def _insert_recursive(self, current_node, key):
"""插入的递归辅助函数。"""
if key < current_node.val:
if current_node.left is None:
current_node.left = Node(key)
else:
self._insert_recursive(current_node.left, key)
elif key > current_node.val: # 为简单起见忽略重复的键
if current_node.right is None:
current_node.right = Node(key)
else:
self._insert_recursive(current_node.right, key)
# 如果 current_node.val 相同,我们可以选择忽略它,
# 更新节点,或以不同方式处理重复项。
# 这里,我们忽略重复项。
# --- 搜索和遍历方法将放在这里 ---
接下来,让我们实现 search 操作。其逻辑与插入相似:将目标键与当前节点的键进行比较,并相应地向左或向右遍历。如果找到键,返回 True。如果达到 None 引用,则键不在树中,因此返回 False。
# 将这些方法添加到 BST 类中
def search(self, key):
"""在BST中搜索一个节点。"""
return self._search_recursive(self.root, key)
def _search_recursive(self, current_node, key):
"""搜索的递归辅助函数。"""
if current_node is None:
return False # 达到末尾,未找到
if key == current_node.val:
return True # 找到
elif key < current_node.val:
return self._search_recursive(current_node.left, key)
else: # current_node.val
return self._search_recursive(current_node.right, key)
最后,让我们添加一个 in_order_traversal(中序遍历)。这种遍历方式访问左子树,然后是当前节点,再是右子树。对于BST来说,这方便地按排序顺序打印键。
# 将此方法添加到 BST 类中
def in_order_traversal(self):
"""执行中序遍历并返回一个键列表。"""
elements = []
self._in_order_recursive(self.root, elements)
return elements
def _in_order_recursive(self, current_node, elements):
"""中序遍历的递归辅助函数。"""
if current_node:
self._in_order_recursive(current_node.left, elements)
elements.append(current_node.val)
self._in_order_recursive(current_node.right, elements)
让我们将所有代码整合起来并看它的运行效果:
# --- 在此处包含上面定义的 Node 和 BST 类 ---
# 示例用法
bst = BST()
keys_to_insert = [50, 30, 70, 20, 40, 60, 80]
for key in keys_to_insert:
bst.insert(key)
print(f"搜索 40:{bst.search(40)}") # 输出: True
print(f"搜索 90:{bst.search(90)}") # 输出: False
print(f"中序遍历:{bst.in_order_traversal()}")
# 输出: [20, 30, 40, 50, 60, 70, 80]
可视化结构有助于理解这些关系。这是表示我们刚刚创建的 BST 的图表:
通过插入键
[50, 30, 70, 20, 40, 60, 80]创建的二叉搜索树。请注意,在每个节点处,较小的值向左,较大的值向右。
这个简单的实现展示了核心机制。请记住关于平衡树的讨论;如果我们按排序顺序插入键(例如,[20, 30, 40, 50, 60, 70, 80]),这个简单的 BST 将退化为链表结构,搜索时间复杂度为 O(n)。库和生产系统使用自平衡树(如 AVL 树或红黑树)来保证 O(logn) 的性能。
决策树是一种流行的机器学习模型,它本身利用树结构进行预测。每个内部节点表示对一个特征的测试,每个分支表示测试结果,每个叶节点表示一个类别标签或一个连续值。
尽管实现完整的训练算法(涉及根据基尼不纯度或信息增益等标准选择最佳分裂)超出了本节的范围,但可视化训练好的树有助于将数据结构与模型的运作方式关联起来。我们将使用 scikit-learn 训练一个简单分类器并可视化其结构。
为了清晰起见,让我们使用一个小型的合成数据集:
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_graphviz, plot_tree
import graphviz # 可选:用于生成 Graphviz dot 文件
# 生成一个简单的二元分类数据集
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
n_redundant=0, n_clusters_per_class=1,
random_state=42, class_sep=1.5)
# 训练一个决策树分类器
# 限制深度以便更好地可视化
dt_classifier = DecisionTreeClassifier(max_depth=3, random_state=42)
dt_classifier.fit(X, y)
print("决策树训练成功。")
# 选项 1: 使用 scikit-learn 的 plot_tree(需要 matplotlib)
plt.figure(figsize=(12, 8))
plot_tree(dt_classifier,
filled=True,
rounded=True,
class_names=['类别 0', '类别 1'], # 使用适当的类别名称
feature_names=['特征 1', '特征 2']) # 使用适当的特征名称
# plt.show() # 取消注释可直接显示图表
# 选项 2: 导出到 Graphviz(可定制性更高)
# 生成一个 'decision_tree.dot' 文件,可选地生成一个 PDF
dot_data = export_graphviz(dt_classifier, out_file=None,
feature_names=['特征 1', '特征 2'],
class_names=['类别 0', '类别 1'],
filled=True, rounded=True,
special_characters=True)
# 您可以使用 graphviz 库渲染此 dot_data
# graph = graphviz.Source(dot_data)
# graph.render("decision_tree") # 保存 decision_tree.pdf
# print("Graphviz DOT 数据已生成(可选地渲染为 PDF)。")
# Graphviz 字符串示例(为简洁起见已截断,表示结构)
# 注意: export_graphviz 的实际输出将更详细
example_dot = """
digraph Tree {
node [shape=box, style="filled, rounded", color="black", fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="特征 2 <= 0.05\\ngini = 0.5\\n样本数 = 100\\n值 = [50, 50]\\n类别 = 类别 0", fillcolor="#ffffff"] ;
1 [label="特征 1 <= -0.8\\ngini = 0.18\\n样本数 = 55\\n值 = [50, 5]\\n类别 = 类别 0", fillcolor="#e8f4fd"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="真"] ;
2 [label="gini = 0.0\\n样本数 = 35\\n值 = [35, 0]\\n类别 = 类别 0", fillcolor="#e58139"] ;
1 -> 2 ;
3 [label="特征 2 <= -0.5\\ngini = 0.375\\n样本数 = 20\\n值 = [15, 5]\\n类别 = 类别 0", fillcolor="#f2c2a1"] ;
1 -> 3 ;
4 [label="特征 1 <= 1.2\\ngini = 0.48\\n样本数 = 45\\n值 = [10, 35]\\n类别 = 类别 1", fillcolor="#baddf7"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="假"] ;
5 [label="特征 2 <= 1.0\\ngini = 0.24\\n样本数 = 25\\n值 = [20, 5]\\n类别 = 类别 0", fillcolor="#ea9a5f"] ;
4 -> 5 ;
6 [label="gini = 0.0\\n样本数 = 20\\n值 = [0, 20]\\n类别 = 类别 1", fillcolor="#399de5"] ;
4 -> 6 ;
}
"""
# 显示示例 dot 结构
print("\nGraphviz 结构示例(说明性):")
# 通常您会使用 graphviz 库来渲染 dot_data
# 对于此示例,我们只显示一个简化的 DOT 字符串表示。
# print(example_dot) # 您可以打印或渲染 'dot_data' 变量
训练好的决策树的说明性结构。每个节点显示分裂条件、不纯度(gini)、到达该节点的样本数量、每个类别样本的分布(值)以及该节点的预测类别。颜色通常表示多数类别和不纯度水平。
节点解读:
特征 2 <= 0.05) 用于在内部节点划分数据的规则。[50, 5] 表示类别 0 有 50 个样本,类别 1 有 5 个样本)。这个动手练习展示了树结构如何被实现和应用。构建 BST 突出了搜索和插入机制,而可视化决策树则说明了这种结构如何直接构成机器学习模型的基础,通过划分特征空间来进行预测。在考虑在您的机器学习工作流中使用树时,请记住平衡树与非平衡树的性能影响。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造