趋近智
神经网络 (neural network)的结构常使用 Keras 的 Sequential 或 Functional API 来定义。检查这些已定义网络的结构通常会很有帮助。了解网络中的层、层间的连接、数据在网络中流动的形状以及参数 (parameter)数量,这对于调试、验证设计和估算计算复杂性很重要。Keras 为此提供了便利的工具。
获知模型概况最直接的方式是 summary() 方法。它提供模型逐层的文本形式的描述。我们来看一个简单的 Sequential 模型:
import keras
from keras import layers
# 定义一个简单的Sequential模型
model = keras.Sequential(
[
keras.Input(shape=(784,), name="input_layer"), # 指定形状的输入层
layers.Dense(128, activation="relu", name="hidden_layer_1"),
layers.Dense(64, activation="relu", name="hidden_layer_2"),
layers.Dense(10, activation="softmax", name="output_layer"),
],
name="simple_classifier",
)
# 打印模型概览
model.summary()
运行 model.summary() 会输出类似如下内容:
模型: "simple_classifier"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ 层 (类型) ┃ 输出形状 ┃ 参数数量 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ hidden_layer_1 (Dense) │ (None, 128) │ 100,480 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ hidden_layer_2 (Dense) │ (None, 64) │ 8,256 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ output_layer (Dense) │ (None, 10) │ 650 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
总参数: 109,386 (427.29 KB)
可训练参数: 109,386 (427.29 KB)
不可训练参数: 0 (0.00 B)
我们来分解一下这些输出内容:
simple_classifier)。Dense)。请注意,Input 对象本身不作为层列在概览表中,但其输入形状用于计算第一个连接层的参数 (parameter)。None 维度表示批次大小,它通常是可变的,直到训练时才确定。例如,(None, 128) 表示第一个 Dense 层输出的张量中,批次中的每个样本的形状都是 (128,)。Dense 层的这些参数是如何计算的:
hidden_layer_1:输入形状为 (784,)。该层有 128 个单元。权重数量为 。每个单元也有一个偏置项,所以总偏置数量 = 128。总参数 = 。hidden_layer_2:输入形状是前一层的输出形状,(128,)。该层有 64 个单元。权重 = 。偏置 = 64。总参数 = 。output_layer:输入形状为 (64,)。该层有 10 个单元。权重 = 。偏置 = 10。总参数 = 。summary() 方法非常有用,可用于快速检查各层连接是否符合预期、输出形状是否合理以及了解模型的大小。
尽管 summary() 很有用,但视觉图示通常能更清晰地展现模型结构,特别是对于使用函数式 API 构建的、涉及多输入、多输出或共享层的复杂结构。Keras 为此提供了 keras.utils.plot_model 函数。
要使用 plot_model,你可能需要安装额外的库:pydot 和 graphviz。通常可以使用 pip 安装它们:
pip install pydot graphviz
(注意:如果 Python 包不包含 Graphviz 二进制文件,你可能还需要在操作系统上单独安装它们。请查看 Graphviz 文档获取安装说明。)
依赖项准备好后,就可以绘制模型图:
# 假设变量 'model' 持有前面定义的 Keras 模型
keras.utils.plot_model(
model,
to_file="simple_classifier_model.png", # 将图保存到文件
show_shapes=True, # 显示形状信息
show_layer_names=True, # 显示层名称
show_layer_activations=True, # 显示激活函数
rankdir="TB" # 方向:TB=从上到下,LR=从左到右
)
这段代码会生成一个图像文件(simple_classifier_model.png),其中包含网络的图示。show_shapes、show_layer_names 和 show_layer_activations 参数 (parameter)会为图中的节点添加有用的信息。
以下是 plot_model 可能为我们的简单分类器生成的图示,使用 graphviz dot 语言表示:
一个表示简单分类器模型的图示。节点列出层名称、类型、输出形状和激活函数 (activation function)。箭头表示数据流向。
可视化模型尤其有益,在使用函数式 API 时,因为它清晰地显示了层之间的连接,这些连接可以是费线性的。这有助于确认你正确连接了各层,特别是在具有分支或多输入/输出的模型中。
summary() 和 plot_model 都是你 Keras 工具集中的重要工具,用于查看、理解和调试你构建的神经网络 (neural network)结构。
这部分内容有帮助吗?
summary() method, Keras Team, 2024 - Keras API Model.summary() 方法的官方参考文档,详细说明其输出以及如何解释网络架构和参数数量。plot_model() utility, Keras Team, 2024 - Keras API keras.utils.plot_model() 函数的官方参考文档,说明其用于可视化模型架构和依赖关系的使用方法。© 2026 ApX Machine LearningAI伦理与透明度•