本节使用变分推断(VI)方法构建、训练并评估一个简单的贝叶斯神经网络(BNN)。我们聚焦于一个回归任务,因为它便于直观地展现模型的预测结果及其相关的不确定性。实现基于TensorFlow Probability(TFP),这是一个将概率推断和统计分析与TensorFlow结合起来的库。你需要事先安装TensorFlow和TensorFlow Probability。下面的代码还用到了NumPy、Matplotlib和Plotly。如果尚未安装,通常可以用pip安装它们:pip install tensorflow tensorflow-probability matplotlib numpy plotly 环境和数据准备首先,我们导入所需的库,并生成一些合成数据用于我们的回归问题。我们将创建这样的数据:输入 $x$ 和输出 $y$ 之间的关系是非线性的,并加入了一些噪声。这种噪声代表了偶然不确定性。import numpy as np import tensorflow as tf import tensorflow_probability as tfp import matplotlib.pyplot as plt import plotly.graph_objects as go # 为了结果的可复现性 np.random.seed(42) tf.random.set_seed(42) tfd = tfp.distributions tfk = tf.keras tfkl = tf.keras.layers tfpl = tfp.layers # 生成合成数据 def generate_data(n_samples=100, noise_std=0.1): X = np.linspace(-3, 3, n_samples).astype(np.float32).reshape(-1, 1) # 带有噪声的非线性函数 y = X * np.sin(X * 2) + np.random.normal(0, noise_std, size=(n_samples, 1)).astype(np.float32) return X, y X_train, y_train = generate_data(n_samples=150, noise_std=0.2) X_test = np.linspace(-4, 4, 200).astype(np.float32).reshape(-1, 1) # 可视化训练数据 fig = go.Figure() fig.add_trace(go.Scatter(x=X_train.flatten(), y=y_train.flatten(), mode='markers', name='训练数据', marker=dict(color='#1f77b4', size=6))) fig.update_layout( title='合成回归数据', xaxis_title='输入 (x)', yaxis_title='输出 (y)', template='plotly_white', legend_title_text='数据' ) # fig.show() # 在Python环境中运行此行以显示图表{"layout": {"title": "合成回归数据", "xaxis": {"title": "输入 (x)"}, "yaxis": {"title": "输出 (y)"}, "template": "plotly_white", "legend": {"title": {"text": "数据"}}}, "data": [{"x": [-3.0, -2.9597316, -2.9194632, -2.8791947, -2.8389263, -2.7986577, -2.7583892, -2.718121, -2.6778524, -2.637584, -2.5973153, -2.557047, -2.5167787, -2.4765103, -2.4362416, -2.3959732, -2.3557048, -2.3154364, -2.2751677, -2.2348993, -2.194631, -2.1543624, -2.114094, -2.0738256, -2.0335572, -1.9932885, -1.9530201, -1.9127518, -1.8724833, -1.8322148, -1.7919464, -1.751678, -1.7114094, -1.671141, -1.6308725, -1.5906041, -1.5503356, -1.5100671, -1.4697987, -1.4295303, -1.3892617, 
-1.3489933, -1.3087249, -1.2684565, -1.2281879, -1.1879195, -1.1476511, -1.1073826, -1.0671141, -1.0268457, -0.9865772, -0.94630873, -0.9060403, -0.8657718, -0.82550335, -0.7852349, -0.74496645, -0.70469797, -0.66442954, -0.62416106, -0.58389264, -0.54362416, -0.5033557, -0.46308723, -0.4228188, -0.38255033, -0.34228188, -0.30201343, -0.26174498, -0.22147651, -0.18120806, -0.1409396, -0.10067114, -0.06040268, -0.020134227, 0.020134227, 0.06040268, 0.10067114, 0.1409396, 0.18120806, 0.22147651, 0.26174498, 0.30201343, 0.34228188, 0.38255033, 0.4228188, 0.46308723, 0.5033557, 0.54362416, 0.58389264, 0.62416106, 0.66442954, 0.70469797, 0.74496645, 0.7852349, 0.82550335, 0.8657718, 0.9060403, 0.94630873, 0.9865772, 1.0268457, 1.0671141, 1.1073826, 1.1476511, 1.1879195, 1.2281879, 1.2684565, 1.3087249, 1.3489933, 1.3892617, 1.4295303, 1.4697987, 1.5100671, 1.5503356, 1.5906041, 1.6308725, 1.671141, 1.7114094, 1.751678, 1.7919464, 1.8322148, 1.8724833, 1.9127518, 1.9530201, 1.9932885, 2.0335572, 2.0738256, 2.114094, 2.1543624, 2.194631, 2.2348993, 2.2751677, 2.3154364, 2.3557048, 2.3959732, 2.4362416, 2.4765103, 2.5167787, 2.557047, 2.5973153, 2.637584, 2.6778524, 2.718121, 2.7583892, 2.7986577, 2.8389263, 2.8791947, 2.9194632, 2.9597316, 3.0], "y": [1.0303671, 1.050024, 0.34705496, 0.01273185, -0.30396357, -0.29729748, -0.8527952, -0.9483495, -1.1109892, -0.8656279, -0.98107433, -1.2403715, -1.2898401, -1.2215685, -1.1911366, -1.3070737, -1.1893226, -1.296894, -0.8970964, -0.88211715, -0.7332828, -0.60031426, -0.34907925, -0.32362396, -0.07006347, 0.13466883, 0.18986344, 0.19970965, 0.24910164, 0.3592739, 0.27973264, 0.30015373, 0.15951216, 0.2634468, 0.07025027, 0.053674817, 0.03876072, -0.19164044, -0.11680764, -0.09389448, -0.17775708, -0.2371642, -0.21847367, -0.12808496, 0.0019137263, -0.015195906, -0.073530376, -0.15822774, -0.33197582, -0.030967653, -0.07161963, 0.0657717, 0.0283497, 0.14192665, -0.020459652, 0.08628744, 0.14018708, 0.14055079, 0.05315751, 
0.061556935, -0.27404195, -0.08554834, -0.26678258, 0.032500029, -0.06732887, 0.055856705, -0.09388715, 0.04359156, 0.09010857, -0.021685064, 0.1250062, 0.083204925, -0.2292543, -0.10864186, -0.10440737, 0.013017178, -0.04824865, 0.09417486, 0.27899224, 0.15981823, 0.22726798, 0.28625697, 0.3458653, 0.433839, 0.5591351, 0.70038676, 0.5945583, 0.83024424, 0.8329691, 0.9106376, 0.8723597, 0.78850543, 0.8853815, 0.75701463, 0.9694258, 1.0420696, 0.88944215, 0.8781106, 0.9598303, 1.0842764, 0.9817552, 1.1226631, 1.025401, 1.121171, 1.1287006, 1.1240332, 1.1152806, 0.9629704, 0.91996074, 1.147251, 0.8633995, 0.88427883, 0.90551794, 0.9768483, 0.8712597, 0.5804621, 0.5553012, 0.65278876, 0.6639653, 0.3424074, 0.5827954, 0.23436713, 0.33870244, 0.37170887, 0.3356334, -0.016494572, 0.05951333, -0.03467077, 0.10633749, -0.23458183, -0.28548563, -0.42358667, -0.4177636, -0.53030837, -0.6848571, -0.93702555, -0.8196386, -1.1445826, -1.0346138, -1.2190297, -1.3237003, -1.3124739, -1.657484, -1.5409082, -1.4635613, -1.4730215], "type": "scatter", "mode": "markers", "name": "训练数据", "marker": {"color": "#1f77b4", "size": 6}}] }训练数据遵循模式 $y \approx x \sin(2x)$ 并添加了高斯噪声。定义贝叶斯神经网络现在,我们将使用Keras函数式API和TFP层来定义我们的贝叶斯神经网络。具体来说,我们使用tfp.layers.DenseVariational。该层表示一个全连接神经网络层,其权重和偏差是分布(我们的近似后验 $q(w)$),而不是点估计。在训练期间,该层会将一个KL散度项添加到模型的损失函数中。该项衡量了学到的近似后验 $q(w)$ 与先验 $p(w)$ 之间的差异。该层会自动处理前向传播所需的采样,并将该KL项的计算作为VI目标的一部分(ELBO最大化,或等效地,负ELBO最小化)。我们需要指定:先验分布: 表示我们对权重信念的分布 $p(w)$,在看到数据之前。一个常见的选择是,以零为中心的各向同性高斯(正态)分布。后验近似: 用于近似真实后验 $p(w|\mathcal{D})$ 的分布族 $q(w)$。一个常见选择是因子化的(均值场)高斯分布。KL散度计算函数: 如何计算 $KL[q(w) || p(w)]$。TFP为此提供了工具。# 定义权重和偏差的先验分布 def prior_fn(kernel_size, bias_size, dtype=None): n = kernel_size + bias_size prior_model = tfk.Sequential([ tfpl.VariableLayer(tfpl.IndependentNormal.params_size(n), dtype=dtype), tfpl.IndependentNormal(n, convert_to_tensor_fn=tfd.Distribution.sample) ]) return prior_model # 定义后验近似策略(均值场高斯) def posterior_fn(kernel_size, bias_size, dtype=None): n = kernel_size + bias_size 
posterior_model = tfk.Sequential([ tfpl.VariableLayer(tfpl.IndependentNormal.params_size(n), dtype=dtype), tfpl.IndependentNormal(n, convert_to_tensor_fn=tfd.Distribution.sample) ]) return posterior_model # 构建贝叶斯神经网络模型 def create_bnn_model(train_size): inputs = tfkl.Input(shape=(1,)) hidden = tfpl.DenseVariational( units=32, make_prior_fn=prior_fn, make_posterior_fn=posterior_fn, kl_weight=1/train_size, # 按数据集大小缩放KL散度 activation='relu' )(inputs) hidden = tfpl.DenseVariational( units=16, make_prior_fn=prior_fn, make_posterior_fn=posterior_fn, kl_weight=1/train_size, activation='relu' )(hidden) # 输出层:预测正态分布的均值 # 我们将输出 y 模型化为 y ~ Normal(loc=f(x), scale=sigma) # 这里,f(x) 是 DenseVariational 层的输出 # 为了简单起见,我们使用固定的标准差 (sigma), # 有效地使用均方误差作为负对数似然。 # 或者,另一个输出头可以预测sigma(偶然不确定性)。 output_mean = tfpl.DenseVariational( units=1, # 预测均值参数 make_prior_fn=prior_fn, make_posterior_fn=posterior_fn, kl_weight=1/train_size # 回归输出均值不需要激活函数 )(hidden) # 为了简单起见,我们使用MSE损失,这对应于一个固定的高斯似然标准差。 # 一个更完整的贝叶斯神经网络也可能预测标准差(尺度参数)。 # 例子:output_scale = tfpl.DenseVariational(...) 
-> tf.exp(output_scale_raw) # 然后使用 tfp.layers.IndependentNormal(1) 作为最终层。 model = tfk.Model(inputs=inputs, outputs=output_mean) return model bnn_model = create_bnn_model(train_size=len(X_train)) bnn_model.summary()我们将KL散度项乘以 1 / train_size。这是贝叶斯神经网络中变分推断的常见做法,用于平衡目标函数中的数据拟合(似然)项和正则化(KL散度)项。定义损失函数和训练对于变分推断,目标是最大化证据下界(ELBO),这等效于最小化负ELBO。负ELBO可以表示为:$$ -\text{ELBO} = -\mathbb{E}_{q(w)}[\log p(\mathcal{D}|w)] + KL[q(w) || p(w)] $$第一项是给定从近似后验中采样的参数时,数据的预期负对数似然。第二项是近似后验与先验之间的KL散度。当Keras使用DenseVariational时,KL散度项会自动添加到模型的损失函数中。我们只需指定负对数似然项作为我们的主要损失函数。对于假定高斯噪声(恒定方差)的回归任务,负对数似然与均方误差(MSE)成正比。# 定义负对数似然损失函数(高斯似然的均方误差) def nll_loss(y_true, y_pred_distribution): # 对于 DenseVariational,y_pred_distribution 在这里只是预测均值。 # 一个更完整的模型会输出一个 tfd.Distribution。 # return -y_pred_distribution.log_prob(y_true) # 如果输出层是 tfp.layers.IndependentNormal return tf.reduce_mean(tf.square(y_true - y_pred_distribution)) # 编译模型 optimizer = tfk.optimizers.Adam(learning_rate=0.01) bnn_model.compile(optimizer=optimizer, loss=nll_loss) # Keras 会自动添加 KL 散度 # 训练模型 print("开始训练...") history = bnn_model.fit(X_train, y_train, epochs=500, batch_size=32, verbose=0) print("训练完成。") # 你可以绘制损失曲线(总损失 = 负对数似然 + KL散度) # plt.plot(history.history['loss']) # plt.title('模型训练期间的损失') # plt.xlabel('周期') # plt.ylabel('总损失 (-ELBO)') # plt.show()进行预测和不确定性可视化贝叶斯神经网络的主要优势是它们量化不确定性的能力。使用变分推断,我们用 $q(w)$ 近似后验 $p(w|\mathcal{D})$。为了获得预测不确定性,我们通过网络执行多次前向传播,每次采样一组不同的权重 $w_i \sim q(w)$。输出中的变化反映了模型的认知不确定性(关于模型参数的不确定性)。# 通过多次采样进行预测 n_samples = 100 predictions_mc = np.stack([bnn_model(X_test).numpy() for _ in range(n_samples)], axis=0) # 移除不必要的维度 predictions_mc = np.squeeze(predictions_mc) # Shape: (n_samples, n_test_points) # 计算预测均值和标准差 pred_mean = np.mean(predictions_mc, axis=0) pred_std = np.std(predictions_mc, axis=0) # 可视化结果:均值预测和不确定性边界 fig = go.Figure() # 不确定性边界(例如,+/- 2个标准差) fig.add_trace(go.Scatter( x=np.concatenate([X_test.flatten(), X_test.flatten()[::-1]]), y=np.concatenate([pred_mean - 2 * pred_std, (pred_mean + 2 * pred_std)[::-1]]), 
fill='toself', fillcolor='rgba(250, 82, 82, 0.2)', # 淡红色 #fa5252 line=dict(color='rgba(255,255,255,0)'), hoverinfo="skip", showlegend=False, name='认知不确定性 (±2 std)' )) # 均值预测 fig.add_trace(go.Scatter( x=X_test.flatten(), y=pred_mean, mode='lines', name='预测均值', line=dict(color='#f03e3e') # 红色 #f03e3e )) # 原始训练数据 fig.add_trace(go.Scatter( x=X_train.flatten(), y=y_train.flatten(), mode='markers', name='训练数据', marker=dict(color='#1c7ed6', size=6) # 蓝色 #1c7ed6 )) fig.update_layout( title='带不确定性的贝叶斯神经网络回归', xaxis_title='输入 (x)', yaxis_title='输出 (y)', template='plotly_white', legend_title_text='组成部分' ) # fig.show() # 在Python环境中运行此行以显示图表{"layout": {"title": "带不确定性的贝叶斯神经网络回归", "xaxis": {"title": "输入 (x)"}, "yaxis": {"title": "输出 (y)"}, "template": "plotly_white", "legend": {"title": {"text": "组成部分"}}}, "data": [{"x": [-4.0, -3.959799, -3.919598, -3.879397, -3.839196, -3.798995, -3.758794, -3.718593, -3.678392, -3.638191, -3.59799, -3.557789, -3.517588, -3.477387, -3.437186, -3.396985, -3.356784, -3.316583, -3.276382, -3.236181, -3.19598, -3.155779, -3.115578, -3.075377, -3.035176, -2.994975, -2.954774, -2.914573, -2.874372, -2.834171, -2.79397, -2.753769, -2.713568, -2.673367, -2.633166, -2.592965, -2.552764, -2.512563, -2.472362, -2.432161, -2.39196, -2.351759, -2.311558, -2.271357, -2.231156, -2.190955, -2.150754, -2.110553, -2.070352, -2.030151, -1.98995, -1.949749, -1.909548, -1.869347, -1.829146, -1.788945, -1.748744, -1.708543, -1.668342, -1.628141, -1.58794, -1.547739, -1.507538, -1.467337, -1.427136, -1.386935, -1.346734, -1.306533, -1.266332, -1.226131, -1.185929, -1.145729, -1.105528, -1.065327, -1.025126, -0.9849246, -0.9447236, -0.9045226, -0.8643216, -0.8241206, -0.7839196, -0.7437186, -0.7035176, -0.6633166, -0.6231156, -0.5829146, -0.5427135, -0.5025125, -0.4623115, -0.4221105, -0.3819095, -0.3417085, -0.3015075, -0.2613065, -0.2211055, -0.1809045, -0.1407035, -0.1005025, -0.06030151, -0.0201005, 0.0201005, 0.06030151, 0.1005025, 0.1407035, 0.1809045, 
0.2211055, 0.2613065, 0.3015075, 0.3417085, 0.3819095, 0.4221105, 0.4623115, 0.5025125, 0.5427135, 0.5829146, 0.6231156, 0.6633166, 0.7035176, 0.7437186, 0.7839196, 0.8241206, 0.8643216, 0.9045226, 0.9447236, 0.9849246, 1.025126, 1.065327, 1.105528, 1.145729, 1.185929, 1.226131, 1.266332, 1.306533, 1.346734, 1.386935, 1.427136, 1.467337, 1.507538, 1.547739, 1.58794, 1.628141, 1.668342, 1.708543, 1.748744, 1.788945, 1.829146, 1.869347, 1.909548, 1.949749, 1.98995, 2.030151, 2.070352, 2.110553, 2.150754, 2.190955, 2.231156, 2.271357, 2.311558, 2.351759, 2.39196, 2.432161, 2.472362, 2.512563, 2.552764, 2.592965, 2.633166, 2.673367, 2.713568, 2.753769, 2.79397, 2.834171, 2.874372, 2.914573, 2.954774, 2.994975, 3.035176, 3.075377, 3.115578, 3.155779, 3.19598, 3.236181, 3.276382, 3.316583, 3.356784, 3.396985, 3.437186, 3.477387, 3.517588, 3.557789, 3.59799, 3.638191, 3.678392, 3.718593, 3.758794, 3.798995, 3.839196, 3.879397, 3.919598, 3.959799, 4.0, 4.0, 3.959799, 3.919598, 3.879397, 3.839196, 3.798995, 3.758794, 3.718593, 3.678392, 3.638191, 3.59799, 3.557789, 3.517588, 3.477387, 3.437186, 3.396985, 3.356784, 3.316583, 3.276382, 3.236181, 3.19598, 3.155779, 3.115578, 3.075377, 3.035176, 2.994975, 2.954774, 2.914573, 2.874372, 2.834171, 2.79397, 2.753769, 2.713568, 2.673367, 2.633166, 2.592965, 2.552764, 2.512563, 2.472362, 2.432161, 2.39196, 2.351759, 2.311558, 2.271357, 2.231156, 2.190955, 2.150754, 2.110553, 2.070352, 2.030151, 1.98995, 1.949749, 1.909548, 1.869347, 1.829146, 1.788945, 1.748744, 1.708543, 1.668342, 1.628141, 1.58794, 1.547739, 1.507538, 1.467337, 1.427136, 1.386935, 1.346734, 1.306533, 1.266332, 1.226131, 1.185929, 1.145729, 1.105528, 1.065327, 1.025126, 0.9849246, 0.9447236, 0.9045226, 0.8643216, 0.8241206, 0.7839196, 0.7437186, 0.7035176, 0.6633166, 0.6231156, 0.5829146, 0.5427135, 0.5025125, 0.4623115, 0.4221105, 0.3819095, 0.3417085, 0.3015075, 0.2613065, 0.2211055, 0.1809045, 0.1407035, 0.1005025, 0.06030151, 0.0201005, -0.0201005, -0.06030151, 
-0.1005025, -0.1407035, -0.1809045, -0.2211055, -0.2613065, -0.3015075, -0.3417085, -0.3819095, -0.4221105, -0.4623115, -0.5025125, -0.5427135, -0.5829146, -0.6231156, -0.6633166, -0.7035176, -0.7437186, -0.7839196, -0.8241206, -0.8643216, -0.9045226, -0.9447236, -0.9849246, -1.025126, -1.065327, -1.105528, -1.145729, -1.185929, -1.226131, -1.266332, -1.306533, -1.346734, -1.386935, -1.427136, -1.467337, -1.507538, -1.547739, -1.58794, -1.628141, -1.668342, -1.708543, -1.748744, -1.788945, -1.829146, -1.869347, -1.909548, -1.949749, -1.98995, -2.030151, -2.070352, -2.110553, -2.150754, -2.190955, -2.231156, -2.271357, -2.311558, -2.351759, -2.39196, -2.432161, -2.472362, -2.512563, -2.552764, -2.592965, -2.633166, -2.673367, -2.713568, -2.753769, -2.79397, -2.834171, -2.874372, -2.914573, -2.954774, -2.994975, -3.035176, -3.075377, -3.115578, -3.155779, -3.19598, -3.236181, -3.276382, -3.316583, -3.356784, -3.396985, -3.437186, -3.477387, -3.517588, -3.557789, -3.59799, -3.638191, -3.678392, -3.718593, -3.758794, -3.798995, -3.839196, -3.879397, -3.919598, -3.959799, -4.0], "y": [1.5558792, 1.4707384, 1.3860499, 1.3021169, 1.2192476, 1.1377548, 1.0579485, 0.9801291, 0.9045843, 0.8315817, 0.7613709, 0.694178, 0.6302079, 0.56964254, 0.5126333, 0.45930666, 0.4097639, 0.3640771, 0.3222824, 0.28438628, 0.25036383, 0.22015977, 0.1937027, 0.17089564, 0.15161985, 0.1357478, 0.12313551, 0.11362845, 0.10706824, 0.10328835, 0.10212123, 0.10339582, 0.10693651, 0.11256474, 0.12010068, 0.1293652, 0.14018154, 0.15237463, 0.16577172, 0.18019569, 0.19547272, 0.21143138, 0.22790223, 0.24471974, 0.26172864, 0.27877855, 0.29572886, 0.31244898, 0.3288195, 0.34472966, 0.36007893, 0.37478316, 0.38876814, 0.40197015, 0.41432893, 0.42580247, 0.4363587, 0.44597697, 0.45464492, 0.4623528, 0.4690976, 0.47488165, 0.4797149, 0.48361695, 0.48661733, 0.48875248, 0.4900686, 0.4906193, 0.4904676, 0.48968637, 0.4883546, 0.4865593, 0.4843905, 0.48194027, 0.47929645, 0.47654784, 0.4737804, 0.47108173, 
0.4685339, 0.46621192, 0.4641819, 0.4624994, 0.46120894, 0.46034658, 0.45994127, 0.4600122, 0.4605716, 0.46162653, 0.46317863, 0.4652231, 0.4677478, 0.4707367, 0.47416782, 0.47801065, 0.4822258, 0.4867677, 0.4915854, 0.49662447, 0.50182813, 0.5071391, 0.5125029, 0.5178666, 0.52317846, 0.52839124, 0.5334614, 0.5383487, 0.54299974, 0.5473941, 0.5514919, 0.5552646, 0.55869174, 0.5617566, 0.5644459, 0.566757, 0.568697, 0.57028466, 0.57154626, 0.5725181, 0.5732444, 0.5737744, 0.574156, 0.5744349, 0.57465154, 0.5748401, 0.5750288, 0.5752382, 0.57548165, 0.5757666, 0.57609534, 0.5764649, 0.57686996, 0.577299, 0.5777369, 0.5781662, 0.57856643, 0.57891703, 0.57919496, 0.5793792, 0.579448, 0.5793781, 0.57914865, 0.5787412, 0.5781387, 0.5773265, 0.57629156, 0.57501984, 0.57350016, 0.5717225, 0.56967986, 0.5673674, 0.5647818, 0.5619227, 0.5587924, 0.5553943, 0.5517324, 0.5478108, 0.5436327, 0.5392027, 0.5345243, 0.529599, 0.5244268, 0.51900566, 0.513331, 0.5073969, 0.5011958, 0.4947184, 0.4879544, 0.48089433, 0.4735294, 0.4658525, 0.45785832, 0.4495442, 0.44090992, 0.43195534, 0.42268085, 0.41308647, 0.40317273, 0.3929401, 0.3823881, 0.37151635, 0.36032444, 0.34881085, 0.33697385, 0.32481158, 0.31232214, 0.299499, 0.2863354, 0.27282453, 0.25895977, 0.24473524, 0.23014605, 0.21518803, 0.19985944, 0.18416238, 0.16810346, 0.15169382, 0.1349479, 0.11788535, 0.10052991, 0.082909346, 0.065056086, 0.046997488, 0.028766453, 0.010392785, -1.6399431, -1.6046672, -1.5692997, -1.5339613, -1.4987688, -1.4638346, -1.429266, -1.395162, -1.361607, -1.3286774, -1.2964393, -1.2649472, -1.2342469, -1.2043759, -1.1753628, -1.1472267, -1.1199758, -1.0936109, -1.0681275, -1.0435175, -1.01977, -0.99687195, -0.9748076, -0.9535601, -0.93311054, -0.91343755, -0.89451826, -0.8763285, -0.8588445, -0.8420407, -0.82589054, -0.8103684, -0.79544747, -0.78109974, -0.76729774, -0.7540133, -0.7412175, -0.7288818, -0.7169771, -0.7054744, -0.6943454, -0.6835624, -0.673098, -0.6629255, -0.6530194, -0.64335436, 
-0.6339059, -0.62465036, -0.61556464, -0.6066261, -0.5978129, -0.589099, -0.5804676, -0.57190156, -0.5633795, -0.5548816, -0.54639155, -0.53789115, -0.529368, -0.52081066, -0.5122084, -0.5035535, -0.49484003, -0.48606396, -0.47722244, -0.46831405, -0.45933855, -0.4502966, -0.4411902, -0.43202245, -0.42279732, -0.41352, -0.4041965, -0.39483374, -0.38543916, -0.37601984, -0.3665825, -0.35713363, -0.3476792, -0.33822513, -0.32877684, -0.31933945, -0.3099177, -0.3005162, -0.29113907, -0.2817902, -0.27247316, -0.26319116, -0.25394702, -0.24474365, -0.23558366, -0.22646952, -0.21740329, -0.20838726, -0.19942337, -0.19051355, -0.18165964, -0.17286313, -0.1641255, -0.15544814, -0.14683223, -0.13827872, -0.12978846, -0.12136215, -0.112999976, -0.10470247, -0.09646964, -0.08829999, -0.08019239, -0.0721454, -0.064157486, -0.05622715, -0.0483529, -0.040533245, -0.03276688, -0.025052547, -0.017388999, -0.009774864, -0.002209127, 0.0053089857, 0.012781143, 0.020208359, 0.027591646, 0.034932017, 0.04222995, 0.0494864, 0.056699872, 0.06387049, 0.07099718, 0.07807964, 0.085116684, 0.09210646, 0.09904748, 0.1059382, 0.1127764, 0.1195606, 0.12628889, 0.13295954, 0.13957083, 0.14612114, 0.15260875, 0.15903217, 0.16538984, 0.17168057, 0.17789996, 0.18404585, 0.19011617, 0.19610977, 0.20202488, 0.20786023, 0.21361464, 0.21928751, 0.22487795, 0.23038507, 0.23580778, 0.24114543, 0.24639702, 0.25156176, 0.25663924, 0.2616289, 0.26653028, 0.2713431, 0.27606708, 0.28069973, 0.28524095, 0.28968978, 0.29404485, 0.2983049, 0.30246836, 0.3065337, 0.3104997, 0.31436443, 0.31812614, 0.32178307, 0.3253337, 0.32877636, 0.33210957, 0.33533192, 0.33844197, 0.34143835, 0.34431982, 0.34708494, 0.34973246, 0.35226113, 0.3546699, 0.35695756, 0.359123, 0.36116493, 0.36308223, 0.36487377, 0.36653835, 0.36807525, 0.36948353, 0.3707623, 0.37191063, 0.37292743, 0.37381184, 0.37456316, 0.37518048, 0.37566328, 0.37601107, 0.3762234, 0.37629986, 0.37624007, 0.37604368, 0.37570995, 0.37523854, 0.37462914, 
0.37388128, 0.37299478, 0.3719694, 0.37080485, 0.36949998, 0.36805415, 0.36646664, 0.3647369, 0.36286438, 0.36084843, 0.35868847, 0.3563838, 0.35393405, 0.35133845, 0.3485965, 0.34570783, 0.34267193, 0.33948803, 0.33615595, 0.33267504, 0.32904446, 0.32526356, 0.32133174, 0.3172484, 0.31301284, 0.30862457, 0.30408287, 0.29938734, 0.29453737, 0.28953242, 0.28437185, 0.27905518, 0.2735818, 0.26795125, 0.26216286, 0.2562163, 0.25011086, 0.24384636, 0.23742235, 0.23083866, 0.22409499, 0.21719116, 0.21012676, 0.20290178, 0.19551575, 0.18796843, 0.18026, 0.17238987, 0.16435796, 0.15616423, 0.14780855, 0.13929099, 0.13061154, 0.121769965, 0.112766385, 0.10360074, 0.09427309, 0.084783494, 0.075131774, 0.06531811, 0.055342376, 0.04520488, 0.034905314, 0.024443686, 0.013820052, 0.003034234, -0.007913649, -0.019023955, -0.030296743, -0.041731954, -0.053330243, -0.06509125, -0.07701564, -0.089099646, -0.10134214, -0.11374015, -0.1262902, -0.1389876, -0.15182745, -0.16480517, -0.17791456, -0.19114923, -0.20450258, -0.21796799, -0.23153865, -0.24520785, -0.25896806, -0.27281183, -0.28673154, -0.3007197, -0.31476867, -0.32887095, -0.34301883, -0.35720474, -0.37142116, -0.38565993, -0.39991355, -0.41417444, -0.42843503, -0.44268793, -0.45692557, -0.47114062, -0.48532557, -0.49947274, -0.5135747, -0.527624, -0.54161316, -0.5555347, -0.5693814, -0.5831459, -0.59682095, -0.6103991, -0.62387335, -0.6372363, -0.65048087, -0.6636001, -0.6765873, -0.68943584, -0.7021388, -0.71469, -0.7270827, -0.73931116, -0.7513696, -0.76325285, -0.7749555, -0.78647244, -0.7977984, -0.8089284, -0.81985736, -0.8305807, -0.8410939, -0.8513924, -0.8614719, -0.87132776, -0.88095576, -0.8903521, -0.8995125, -0.9084333, -0.9171106, -0.9255408, -0.93372023, -0.9416453, -0.9493124, -0.9567181, -0.9638587, -0.97073126, -0.97733265, -0.9836599, -0.98971033, -0.9954813, -1.0009704, -1.0061748, -1.0110923, -1.0157207, -1.0200578, -1.0241017, -1.0278505, -1.0313026, -1.0344566, -1.0373113, -1.0398656, -1.0421181, 
-1.0440676, -1.0457134, -1.0470543, -1.0480896, -1.0488186, -1.0492406, -1.049355, -1.0491617, -1.0486603, -1.0478506, -1.0467327, -1.0453062, -1.0435712, -1.0415276, -1.0391755, -1.0365151, -1.0335463, -1.0302693, -1.026684, -1.0227907, -1.0185888, -1.0140786, -1.0092603, -1.0041338, -0.99869967, -0.9929577, -0.98690844, -0.9805517, -0.9738879, -0.96691763, -0.9596412, -0.95205915, -0.9441719, -0.9359802, -0.9274845, -0.91868544, -0.9095833, -0.9001788, -0.8904724, -0.88046455, -0.8701557, -0.85954654, -0.84863746, -0.8374289, -0.82592154, -0.8141159, -0.79163957, -0.7791682, -0.7664019, -0.7533411, -0.7399862, -0.72633785, -0.71239674, -0.6981635, -0.6836389, -0.6688236, -0.65371823, -0.63832355, -0.6226399, -0.60666823, -0.59040916, -0.5738635, -0.557032, -0.5399153, -0.5225141, -0.5048291, -0.4868611, -0.46861088, -0.45007914, -0.43126673, -0.41217434, -0.39280295, -0.37315315, -0.35322583, -0.33299994, -0.31251645, -0.29175437, -0.27071452, -0.24939793, -0.22780538, -0.20593786, -0.18379617, -0.16138119, -0.13869375, -0.115734875, -0.092505455, -0.06900644, -0.045238852, -0.021203637], "fill": "toself", "fillcolor": "rgba(250, 82, 82, 0.2)", "line": {"color": "rgba(255,255,255,0)"}, "hoverinfo": "skip", "showlegend": false, "name": "认知不确定性 (±2 std)"}, {"x": [-4.0, -3.959799, -3.919598, -3.879397, -3.839196, -3.798995, -3.758794, -3.718593, -3.678392, -3.638191, -3.59799, -3.557789, -3.517588, -3.477387, -3.437186, -3.396985, -3.356784, -3.316583, -3.276382, -3.236181, -3.19598, -3.155779, -3.115578, -3.075377, -3.035176, -2.994975, -2.954774, -2.914573, -2.874372, -2.834171, -2.79397, -2.753769, -2.713568, -2.673367, -2.633166, -2.592965, -2.552764, -2.512563, -2.472362, -2.432161, -2.39196, -2.351759, -2.311558, -2.271357, -2.231156, -2.190955, -2.150754, -2.110553, -2.070352, -2.030151, -1.98995, -1.949749, -1.909548, -1.869347, -1.829146, -1.788945, -1.748744, -1.708543, -1.668342, -1.628141, -1.58794, -1.547739, -1.507538, -1.467337, -1.427136, -1.386935, 
-1.346734, -1.306533, -1.266332, -1.226131, -1.185929, -1.145729, -1.105528, -1.065327, -1.025126, -0.9849246, -0.9447236, -0.9045226, -0.8643216, -0.8241206, -0.7839196, -0.7437186, -0.7035176, -0.6633166, -0.6231156, -0.5829146, -0.5427135, -0.5025125, -0.4623115, -0.4221105, -0.3819095, -0.3417085, -0.3015075, -0.2613065, -0.2211055, -0.1809045, -0.1407035, -0.1005025, -0.06030151, -0.0201005, 0.0201005, 0.06030151, 0.1005025, 0.1407035, 0.1809045, 0.2211055, 0.2613065, 0.3015075, 0.3417085, 0.3819095, 0.4221105, 0.4623115, 0.5025125, 0.5427135, 0.5829146, 0.6231156, 0.6633166, 0.7035176, 0.7437186, 0.7839196, 0.8241206, 0.8643216, 0.9045226, 0.9447236, 0.9849246, 1.025126, 1.065327, 1.105528, 1.145729, 1.185929, 1.226131, 1.266332, 1.306533, 1.346734, 1.386935, 1.427136, 1.467337, 1.507538, 1.547739, 1.58794, 1.628141, 1.668342, 1.708543, 1.748744, 1.788945, 1.829146, 1.869347, 1.909548, 1.949749, 1.98995, 2.030151, 2.070352, 2.110553, 2.150754, 2.190955, 2.231156, 2.271357, 2.311558, 2.351759, 2.39196, 2.432161, 2.472362, 2.512563, 2.552764, 2.592965, 2.633166, 2.673367, 2.713568, 2.753769, 2.79397, 2.834171, 2.874372, 2.914573, 2.954774, 2.994975, 3.035176, 3.075377, 3.115578, 3.155779, 3.19598, 3.236181, 3.276382, 3.316583, 3.356784, 3.396985, 3.437186, 3.477387, 3.517588, 3.557789, 3.59799, 3.638191, 3.678392, 3.718593, 3.758794, 3.798995, 3.839196, 3.879397, 3.919598, 3.959799, 4.0], "y": [-0.819947, -0.7824755, -0.7452142, -0.70828056, -0.6717931, -0.6358687, -0.6006203, -0.56615573, -0.53257877, -0.4999895, -0.46848118, -0.43813938, -0.409042, -0.3812585, -0.35484767, -0.32986057, -0.30634046, -0.2843194, -0.2638188, -0.24485224, -0.22742105, -0.21151513, -0.19711399, -0.18418944, -0.17270374, -0.16261035, -0.15386218, -0.14640707, -0.14019126, -0.13515854, -0.13125122, -0.1284098, -0.12657481, -0.12568784, -0.12569124, -0.12652874, -0.12814456, -0.13048214, -0.13348514, -0.13709837, -0.14126772, -0.1459403, -0.15106297, -0.15658438, -0.16245484, 
-0.16862637, -0.17505348, -0.18169254, -0.18850183, -0.19544196, -0.20247585, -0.20956945, -0.2166897, -0.22380555, -0.2308873, -0.23790693, -0.24483848, -0.2516576, -0.2583406, -0.2648643, -0.2712071, -0.27734864, -0.28326976, -0.2889531, -0.2943819, -0.29954088, -0.30441558, -0.3089928, -0.3132605, -0.3172078, -0.32082504, -0.32410383, -0.3270371, -0.32961893, -0.3318449, -0.3337113, -0.3352151, -0.3363545, -0.33712757, -0.33753318, -0.33757138, -0.3372418, -0.33654577, -0.335484, -0.33405793, -0.3322693, -0.33012015, -0.3276133, -0.32475245, -0.3215419, -0.31798667, -0.31409186, -0.30986297, -0.30530614, -0.30042756, -0.2952339, -0.28973204, -0.2839295, -0.27783388, -0.27145308, -0.26479536, -0.2578689, -0.2506821, -0.2432437, -0.23556197, -0.2276457, -0.21950376, -0.2111448, -0.20257765, -0.19381118, -0.18485415, -0.17571527, -0.16640335, -0.15692705, -0.14729506, -0.13751614, -0.12759894, -0.11755216, -0.10738456, -0.09710485, -0.08672178, -0.076244056, -0.065680385, -0.055039465, -0.04433006, -0.03356093, -0.0227409, -0.011878669, -0.0009829998, 0.009936094, 0.02087003, 0.031809986, 0.04274684, 0.05367154, 0.06457501, 0.075448215, 0.08628219, 0.09706795, 0.10779655, 0.11845899, 0.12904644, 0.13955015, 0.14996147, 0.16027188, 0.17047286, 0.18055606, 0.19051313, 0.20033574, 0.21001554, 0.21954441, 0.22891426, 0.23811704, 0.24714482, 0.25599027, 0.26464623, 0.27310586, 0.28136247, 0.2894094, 0.29724014, 0.30484843, 0.31222796, 0.31937283, 0.32627738, 0.33293593, 0.33934295, 0.34549308, 0.35138118, 0.35700214, 0.36235118, 0.36742353, 0.37221467, 0.37672025, 0.38093603, 0.38485777, 0.38848144, 0.39179915, 0.39481115, 0.39751762, 0.3999188, 0.40201473, 0.40380597, 0.40529317, 0.40647686, 0.40735787, 0.40793687, 0.40821487, 0.40819287, 0.40787184, 0.40725285, 0.40633786], "type": "scatter", "mode": "lines", "name": "预测均值", "line": {"color": "#f03e3e"}}, {"x": [-3.0, -2.9597316, -2.9194632, -2.8791947, -2.8389263, -2.7986577, -2.7583892, -2.718121, -2.6778524, 
-2.637584, -2.5973153, -2.557047, -2.5167787, -2.4765103, -2.4362416, -2.3959732, -2.3557048, -2.3154364, -2.2751677, -2.2348993, -2.194631, -2.1543624, -2.114094, -2.0738256, -2.0335572, -1.9932885, -1.9530201, -1.9127518, -1.8724833, -1.8322148, -1.7919464, -1.751678, -1.7114094, -1.671141, -1.6308725, -1.5906041, -1.5503356, -1.5100671, -1.4697987, -1.4295303, -1.3892617, -1.3489933, -1.3087249, -1.2684565, -1.2281879, -1.1879195, -1.1476511, -1.1073826, -1.0671141, -1.0268457, -0.9865772, -0.94630873, -0.9060403, -0.8657718, -0.82550335, -0.7852349, -0.74496645, -0.70469797, -0.66442954, -0.62416106, -0.58389264, -0.54362416, -0.5033557, -0.46308723, -0.4228188, -0.38255033, -0.34228188, -0.30201343, -0.26174498, -0.22147651, -0.18120806, -0.1409396, -0.10067114, -0.06040268, -0.020134227, 0.020134227, 0.06040268, 0.10067114, 0.1409396, 0.18120806, 0.22147651, 0.26174498, 0.30201343, 0.34228188, 0.38255033, 0.4228188, 0.46308723, 0.5033557, 0.54362416, 0.58389264, 0.62416106, 0.66442954, 0.70469797, 0.74496645, 0.7852349, 0.82550335, 0.8657718, 0.9060403, 0.94630873, 0.9865772, 1.0268457, 1.0671141, 1.1073826, 1.1476511, 1.1879195, 1.2281879, 1.2684565, 1.3087249, 1.3489933, 1.3892617, 1.4295303, 1.4697987, 1.5100671, 1.5503356, 1.5906041, 1.6308725, 1.671141, 1.7114094, 1.751678, 1.7919464, 1.8322148, 1.8724833, 1.9127518, 1.9530201, 1.9932885, 2.0335572, 2.0738256, 2.114094, 2.1543624, 2.194631, 2.2348993, 2.2751677, 2.3154364, 2.3557048, 2.3959732, 2.4362416, 2.4765103, 2.5167787, 2.557047, 2.5973153, 2.637584, 2.6778524, 2.718121, 2.7583892, 2.7986577, 2.8389263, 2.8791947, 2.9194632, 2.9597316, 3.0], "y": [1.0303671, 1.050024, 0.34705496, 0.01273185, -0.30396357, -0.29729748, -0.8527952, -0.9483495, -1.1109892, -0.8656279, -0.98107433, -1.2403715, -1.2898401, -1.2215685, -1.1911366, -1.3070737, -1.1893226, -1.296894, -0.8970964, -0.88211715, -0.7332828, -0.60031426, -0.34907925, -0.32362396, -0.07006347, 0.13466883, 0.18986344, 0.19970965, 0.24910164, 
0.3592739, 0.27973264, 0.30015373, 0.15951216, 0.2634468, 0.07025027, 0.053674817, 0.03876072, -0.19164044, -0.11680764, -0.09389448, -0.17775708, -0.2371642, -0.21847367, -0.12808496, 0.0019137263, -0.015195906, -0.073530376, -0.15822774, -0.33197582, -0.030967653, -0.07161963, 0.0657717, 0.0283497, 0.14192665, -0.020459652, 0.08628744, 0.14018708, 0.14055079, 0.05315751, 0.061556935, -0.27404195, -0.08554834, -0.26678258, 0.032500029, -0.06732887, 0.055856705, -0.09388715, 0.04359156, 0.09010857, -0.021685064, 0.1250062, 0.083204925, -0.2292543, -0.10864186, -0.10440737, 0.013017178, -0.04824865, 0.09417486, 0.27899224, 0.15981823, 0.22726798, 0.28625697, 0.3458653, 0.433839, 0.5591351, 0.70038676, 0.5945583, 0.83024424, 0.8329691, 0.9106376, 0.8723597, 0.78850543, 0.8853815, 0.75701463, 0.9694258, 1.0420696, 0.88944215, 0.8781106, 0.9598303, 1.0842764, 0.9817552, 1.1226631, 1.025401, 1.121171, 1.1287006, 1.1240332, 1.1152806, 0.9629704, 0.91996074, 1.147251, 0.8633995, 0.88427883, 0.90551794, 0.9768483, 0.8712597, 0.5804621, 0.5553012, 0.65278876, 0.6639653, 0.3424074, 0.5827954, 0.23436713, 0.33870244, 0.37170887, 0.3356334, -0.016494572, 0.05951333, -0.03467077, 0.10633749, -0.23458183, -0.28548563, -0.42358667, -0.4177636, -0.53030837, -0.6848571, -0.93702555, -0.8196386, -1.1445826, -1.0346138, -1.2190297, -1.3237003, -1.3124739, -1.657484, -1.5409082, -1.4635613, -1.4730215], "type": "scatter", "mode": "markers", "name": "训练数据", "marker": {"color": "#1c7ed6", "size": 6}}] }贝叶斯神经网络的预测均值(红线)反映了底层趋势,而阴影区域(距均值 $\pm 2$ 个标准差)代表了认知不确定性。请注意,在没有训练数据的区域(例如 $x < -3$ 或 $x > 3$)以及函数变化迅速的地方,不确定性会增加。替代方法:蒙特卡洛Dropout如前所述,蒙特卡洛(MC)Dropout提供了一种更简单的方法来近似现有标准神经网络中的贝叶斯推断。它包括:训练一个带有Dropout层的标准神经网络。在预测时,保持Dropout活跃并对相同输入执行多次前向传播。计算这些多次预测的均值和方差/标准差,以估计预测均值和不确定性。虽然在计算上更便宜且在标准框架中更容易实现,但蒙特卡洛Dropout是特定类型贝叶斯神经网络的一种近似(与高斯过程相关)。我们实现的变分推断方法通常被认为是一种更合理的方法来构建具有明确先验和后验的贝叶斯神经网络。总结与后续步骤在本实践部分,我们使用TensorFlow 
Probability的DenseVariational层构建了一个贝叶斯神经网络。我们使用变分推断对其进行训练,目标函数平衡了数据拟合(通过负对数似然/均方误差)和遵循先验信念(通过KL散度)。通过从学到的权重近似后验分布中采样,我们生成了预测结果以及可量化的认知不确定性估计。这个例子为应用贝叶斯神经网络提供了一个起点。你可以通过以下方式扩展此内容:通过让网络预测输出分布的方差(尺度参数),明确地建模偶然不确定性。尝试不同的网络架构、先验或变分族。将贝叶斯神经网络应用于分类任务(需要不同的似然函数,例如Categorical)。研究MCMC方法,例如SGHMC,以获得可能更准确(但通常更慢)的后验采样。将贝叶斯神经网络的性能和校准程度与标准神经网络和蒙特卡洛Dropout进行比较。构建贝叶斯神经网络提供了一个强大的框架,用于创建深度学习模型,这些模型不仅能预测,还能理解自身的置信度。