A simple Bayesian Neural Network (BNN) is built, trained, and evaluated using Variational Inference (VI). The focus is on a regression task, allowing for intuitive visualization of the model's predictions and its associated uncertainty. TensorFlow Probability (TFP), a library that integrates probabilistic reasoning and statistical analysis with TensorFlow, is used for the implementation.

You should have TensorFlow and TensorFlow Probability installed. If not, you can typically install them using pip:

```bash
pip install tensorflow tensorflow-probability matplotlib numpy
```

## Setting Up the Environment and Data

First, let's import the necessary libraries and generate some synthetic data for our regression problem. We'll create data where the relationship between the input $x$ and output $y$ is non-linear, with some added noise. This noise represents the aleatoric uncertainty.

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# For reproducibility
np.random.seed(42)
tf.random.set_seed(42)

tfd = tfp.distributions
tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers

# Generate synthetic data
def generate_data(n_samples=100, noise_std=0.1):
    X = np.linspace(-3, 3, n_samples).astype(np.float32).reshape(-1, 1)
    # Non-linear function with noise
    y = X * np.sin(X * 2) + np.random.normal(
        0, noise_std, size=(n_samples, 1)).astype(np.float32)
    return X, y

X_train, y_train = generate_data(n_samples=150, noise_std=0.2)
X_test = np.linspace(-4, 4, 200).astype(np.float32).reshape(-1, 1)

# Visualize the training data
fig = go.Figure()
fig.add_trace(go.Scatter(x=X_train.flatten(), y=y_train.flatten(),
                         mode='markers', name='Training Data',
                         marker=dict(color='#1f77b4', size=6)))
fig.update_layout(
    title='Synthetic Regression Data',
    xaxis_title='Input (x)',
    yaxis_title='Output (y)',
    template='plotly_white',
    legend_title_text='Data'
)
# fig.show()  # Use this in a Python environment to display
```
The training data follows the pattern $y \approx x \sin(2x)$ with added Gaussian noise.

## Defining the Bayesian Neural Network

Now, we'll define our BNN using the Keras functional API and TFP layers. Specifically, we use `tfp.layers.DenseVariational`. This layer represents a densely-connected neural network layer where weights and biases are distributions (our approximate posterior $q(w)$) rather than point estimates.

During training, this layer adds a KL divergence term to the model's loss. This term measures the difference between the learned approximate posterior $q(w)$ and the prior $p(w)$. The layer automatically handles the sampling needed for the forward pass and the calculation of this KL term as part of the VI objective (ELBO maximization, or equivalently, negative ELBO minimization).

We need to specify:

1. **Prior distribution:** the distribution $p(w)$ representing our beliefs about the weights before seeing data. A standard choice is an isotropic Gaussian (Normal) distribution centered at zero.
2. **Posterior approximation:** the family of distributions $q(w)$ used to approximate the true posterior $p(w|\mathcal{D})$. A common choice is a factorized (mean-field) Gaussian distribution.
3. **KL divergence calculation:** how to compute $KL[q(w) \,\|\, p(w)]$. TFP provides utilities for this.
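Before wiring these into the model, it may help to see what "mean-field Gaussian" means in TFP terms. The snippet below is purely illustrative (it is not part of the model): it builds an independent Normal over three hypothetical weights and shows why `IndependentNormal.params_size(n)` is $2n$, one location and one scale per weight.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# A mean-field Gaussian over n = 3 weights: one loc and one scale per
# weight, with no correlations between weights.
q = tfd.Independent(
    tfd.Normal(loc=[0.0, 0.1, -0.2], scale=[1.0, 0.5, 0.3]),
    reinterpreted_batch_ndims=1)

w = q.sample()  # one joint draw of all 3 weights, shape (3,)

# Each weight contributes a (loc, scale) pair, hence 2 * n parameters.
print(tfp.layers.IndependentNormal.params_size(3))  # -> 6
```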
```python
# Define the prior distribution for weights and biases.
# A fixed isotropic Gaussian centered at zero, matching the prior
# described above; it contains no trainable variables.
def prior_fn(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    prior_model = tfk.Sequential([
        tfpl.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=tf.zeros(n, dtype=dtype), scale=1.0),
            reinterpreted_batch_ndims=1))
    ])
    return prior_model

# Define the posterior approximation strategy (mean-field Gaussian).
# VariableLayer holds the 2n trainable parameters (a loc and a scale per
# weight); IndependentNormal turns them into a distribution.
def posterior_fn(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    posterior_model = tfk.Sequential([
        tfpl.VariableLayer(tfpl.IndependentNormal.params_size(n), dtype=dtype),
        tfpl.IndependentNormal(n, convert_to_tensor_fn=tfd.Distribution.sample)
    ])
    return posterior_model

# Build the BNN model
def create_bnn_model(train_size):
    inputs = tfkl.Input(shape=(1,))
    hidden = tfpl.DenseVariational(
        units=32,
        make_prior_fn=prior_fn,
        make_posterior_fn=posterior_fn,
        kl_weight=1 / train_size,  # Scale KL divergence by dataset size
        activation='relu'
    )(inputs)
    hidden = tfpl.DenseVariational(
        units=16,
        make_prior_fn=prior_fn,
        make_posterior_fn=posterior_fn,
        kl_weight=1 / train_size,
        activation='relu'
    )(hidden)
    # Output layer: predict the mean of a Normal distribution.
    # We model the output as y ~ Normal(loc=f(x), scale=sigma), where f(x)
    # is the output of this layer. We use a fixed standard deviation sigma
    # for simplicity, which makes Mean Squared Error the negative
    # log-likelihood (up to scaling). Alternatively, another output head
    # could predict sigma to capture aleatoric uncertainty.
    output_mean = tfpl.DenseVariational(
        units=1,  # Predicting the mean parameter
        make_prior_fn=prior_fn,
        make_posterior_fn=posterior_fn,
        kl_weight=1 / train_size
        # No activation for a regression output mean
    )(hidden)
    model = tfk.Model(inputs=inputs, outputs=output_mean)
    return model

bnn_model = create_bnn_model(train_size=len(X_train))
bnn_model.summary()
```

We scale the KL divergence term by `1 / train_size`. This is common practice in VI for BNNs, balancing the data-fit (likelihood) term and the regularization (KL divergence) term in the objective function.

## Defining the Loss Function and Training

For VI, the objective is to maximize the Evidence Lower Bound (ELBO), which is equivalent to minimizing the negative ELBO:

$$ -\text{ELBO} = -\mathbb{E}_{q(w)}[\log p(\mathcal{D}|w)] + KL[q(w) \,\|\, p(w)] $$

The first term is the expected negative log-likelihood of the data given parameters sampled from the approximate posterior. The second term is the KL divergence between the approximate posterior and the prior.

When using Keras with `DenseVariational`, the KL divergence term is automatically added to the model's loss, so we only need to specify the negative log-likelihood term as our main loss function. For regression with assumed Gaussian noise of constant variance, the negative log-likelihood is proportional to the Mean Squared Error (MSE).
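To connect the `kl_weight=1/train_size` choice to this objective, divide the negative ELBO by the dataset size $N$, using $\log p(\mathcal{D}|w) = \sum_i \log p(y_i \mid x_i, w)$ for i.i.d. data:

$$ \frac{-\text{ELBO}}{N} = -\frac{1}{N}\sum_{i=1}^{N} \mathbb{E}_{q(w)}\left[\log p(y_i \mid x_i, w)\right] + \frac{1}{N} KL[q(w) \,\|\, p(w)] $$

Keras averages the likelihood loss over the examples in a batch, so weighting the automatically added KL term by $1/N$ keeps both terms on the same per-example scale.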
```python
# Define the negative log-likelihood loss function
# (MSE for a fixed-variance Gaussian likelihood).
def nll_loss(y_true, y_pred):
    # In this model, y_pred is just the predicted mean (a tensor).
    # If the output layer were tfp.layers.IndependentNormal, we could
    # use the exact negative log-likelihood instead:
    # return -y_pred.log_prob(y_true)
    return tf.reduce_mean(tf.square(y_true - y_pred))

# Compile the model. Keras adds the KL divergence term automatically.
optimizer = tfk.optimizers.Adam(learning_rate=0.01)
bnn_model.compile(optimizer=optimizer, loss=nll_loss)

# Train the model
print("Starting training...")
history = bnn_model.fit(X_train, y_train, epochs=500, batch_size=32, verbose=0)
print("Training finished.")

# You can plot the loss curve (total loss = NLL + KL divergence):
# plt.plot(history.history['loss'])
# plt.title('Model Loss During Training')
# plt.xlabel('Epoch')
# plt.ylabel('Total Loss (-ELBO)')
# plt.show()
```

## Making Predictions and Visualizing Uncertainty

A main advantage of BNNs is their ability to quantify uncertainty. With VI, we approximate the posterior $p(w|\mathcal{D})$ with $q(w)$. To get predictive uncertainty, we perform multiple forward passes through the network, each time sampling a different set of weights $w_i \sim q(w)$. The variation in the outputs reflects the model's epistemic uncertainty (uncertainty about the model parameters).

```python
# Make predictions by sampling multiple times: each forward pass through
# DenseVariational layers draws fresh weights from q(w).
n_samples = 100
predictions_mc = np.stack(
    [bnn_model(X_test).numpy() for _ in range(n_samples)], axis=0)

# Squeeze unnecessary dimensions
predictions_mc = np.squeeze(predictions_mc)  # Shape: (n_samples, n_test_points)

# Calculate predictive mean and standard deviation
pred_mean = np.mean(predictions_mc, axis=0)
pred_std = np.std(predictions_mc, axis=0)

# Visualize the results: mean prediction and uncertainty bounds
fig = go.Figure()

# Uncertainty bounds (e.g., +/- 2 standard deviations)
fig.add_trace(go.Scatter(
    x=np.concatenate([X_test.flatten(), X_test.flatten()[::-1]]),
    y=np.concatenate([pred_mean - 2 * pred_std, (pred_mean + 2 * pred_std)[::-1]]),
    fill='toself',
    fillcolor='rgba(250, 82, 82, 0.2)',  # Faint red (#fa5252)
    line=dict(color='rgba(255,255,255,0)'),
    hoverinfo="skip",
    showlegend=False,
    name='Epistemic Uncertainty (±2 std)'
))

# Mean prediction
fig.add_trace(go.Scatter(
    x=X_test.flatten(),
    y=pred_mean,
    mode='lines',
    name='Predictive Mean',
    line=dict(color='#f03e3e')  # Red
))

# Original training data
fig.add_trace(go.Scatter(
    x=X_train.flatten(),
    y=y_train.flatten(),
    mode='markers',
    name='Training Data',
    marker=dict(color='#1c7ed6', size=6)  # Blue
))

fig.update_layout(
    title='BNN Regression with Uncertainty',
    xaxis_title='Input (x)',
    yaxis_title='Output (y)',
    template='plotly_white',
    legend_title_text='Components'
)
# fig.show()  # Use this in a Python environment to display
```
1.7919464, 1.8322148, 1.8724833, 1.9127518, 1.9530201, 1.9932885, 2.0335572, 2.0738256, 2.114094, 2.1543624, 2.194631, 2.2348993, 2.2751677, 2.3154364, 2.3557048, 2.3959732, 2.4362416, 2.4765103, 2.5167787, 2.557047, 2.5973153, 2.637584, 2.6778524, 2.718121, 2.7583892, 2.7986577, 2.8389263, 2.8791947, 2.9194632, 2.9597316, 3.0], "y": [1.0303671, 1.050024, 0.34705496, 0.01273185, -0.30396357, -0.29729748, -0.8527952, -0.9483495, -1.1109892, -0.8656279, -0.98107433, -1.2403715, -1.2898401, -1.2215685, -1.1911366, -1.3070737, -1.1893226, -1.296894, -0.8970964, -0.88211715, -0.7332828, -0.60031426, -0.34907925, -0.32362396, -0.07006347, 0.13466883, 0.18986344, 0.19970965, 0.24910164, 0.3592739, 0.27973264, 0.30015373, 0.15951216, 0.2634468, 0.07025027, 0.053674817, 0.03876072, -0.19164044, -0.11680764, -0.09389448, -0.17775708, -0.2371642, -0.21847367, -0.12808496, 0.0019137263, -0.015195906, -0.073530376, -0.15822774, -0.33197582, -0.030967653, -0.07161963, 0.0657717, 0.0283497, 0.14192665, -0.020459652, 0.08628744, 0.14018708, 0.14055079, 0.05315751, 0.061556935, -0.27404195, -0.08554834, -0.26678258, 0.032500029, -0.06732887, 0.055856705, -0.09388715, 0.04359156, 0.09010857, -0.021685064, 0.1250062, 0.083204925, -0.2292543, -0.10864186, -0.10440737, 0.013017178, -0.04824865, 0.09417486, 0.27899224, 0.15981823, 0.22726798, 0.28625697, 0.3458653, 0.433839, 0.5591351, 0.70038676, 0.5945583, 0.83024424, 0.8329691, 0.9106376, 0.8723597, 0.78850543, 0.8853815, 0.75701463, 0.9694258, 1.0420696, 0.88944215, 0.8781106, 0.9598303, 1.0842764, 0.9817552, 1.1226631, 1.025401, 1.121171, 1.1287006, 1.1240332, 1.1152806, 0.9629704, 0.91996074, 1.147251, 0.8633995, 0.88427883, 0.90551794, 0.9768483, 0.8712597, 0.5804621, 0.5553012, 0.65278876, 0.6639653, 0.3424074, 0.5827954, 0.23436713, 0.33870244, 0.37170887, 0.3356334, -0.016494572, 0.05951333, -0.03467077, 0.10633749, -0.23458183, -0.28548563, -0.42358667, -0.4177636, -0.53030837, -0.6848571, -0.93702555, -0.8196386, -1.1445826, -1.0346138, -1.2190297, -1.3237003, -1.3124739, -1.657484, -1.5409082, -1.4635613, -1.4730215], "type": "scatter", "mode": "markers", "name": "Training Data", "marker": {"color": "#1c7ed6", "size": 6}}] }BNN predictive mean (red line) captures the underlying trend, while the shaded area (±2 standard deviations from the mean) represents epistemic uncertainty. Notice the uncertainty increases in regions with no training data (e.g., $x < -3$ or $x > 3$) and also where the function changes rapidly.Alternative: MC DropoutAs discussed previously, Monte Carlo (MC) Dropout offers a simpler way to approximate Bayesian inference in existing standard NNs. It involves:Training a standard neural network with dropout layers.At prediction time, keeping dropout active and performing multiple forward passes for the same input.Calculating the mean and variance/standard deviation of these multiple predictions to estimate the predictive mean and uncertainty.While computationally cheaper and easier to implement in standard frameworks, MC Dropout is an approximation to a specific type of BNN (related to Gaussian Processes). The VI approach we implemented is often considered a more principled way to construct BNNs with explicit priors and posteriors.Summary and Next StepsIn this practical section, we constructed a Bayesian Neural Network using TensorFlow Probability's DenseVariational layers. 
## Summary and Next Steps

In this practical section, we constructed a Bayesian Neural Network using TensorFlow Probability's `DenseVariational` layers. We trained it with Variational Inference, where the objective balanced fitting the data (via the negative log-likelihood/MSE) against adhering to prior beliefs (via the KL divergence). By sampling from the learned approximate posterior distribution over weights, we generated predictions along with quantifiable epistemic uncertainty estimates.

This example provides a foundation for applying BNNs. You could extend it by:

- Modeling aleatoric uncertainty explicitly by having the network predict the variance (scale parameter) of the output distribution (see the sketch at the end of this section).
- Trying different network architectures, priors, or variational families.
- Applying BNNs to classification tasks (which require a different likelihood, such as Categorical).
- Exploring MCMC methods like SGHMC for potentially more accurate (but often slower) posterior sampling.
- Comparing the performance and calibration of the BNN against a standard NN and MC Dropout.

Building BNNs provides a powerful framework for creating deep learning models that not only predict but also understand their own confidence.
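As a possible starting point for the first extension, here is a sketch of an output head that predicts both the mean and the scale, reusing `prior_fn` and `posterior_fn` from earlier; `create_heteroscedastic_bnn` and `het_model` are illustrative names, not part of the code above. The final `tfp.layers.IndependentNormal` layer turns the two predicted parameters into a distribution, so the loss can be the exact negative log-likelihood.

```python
# Sketch: predict both loc and scale so the model also captures
# aleatoric uncertainty. Reuses prior_fn / posterior_fn from above.
def create_heteroscedastic_bnn(train_size):
    inputs = tfkl.Input(shape=(1,))
    x = tfpl.DenseVariational(
        units=32, make_prior_fn=prior_fn, make_posterior_fn=posterior_fn,
        kl_weight=1 / train_size, activation='relu')(inputs)
    # Two outputs per example: raw loc and raw scale parameters.
    params = tfpl.DenseVariational(
        units=tfpl.IndependentNormal.params_size(1),
        make_prior_fn=prior_fn, make_posterior_fn=posterior_fn,
        kl_weight=1 / train_size)(x)
    # IndependentNormal maps (loc, raw_scale) to a Normal distribution,
    # applying a softplus-style transform to keep the scale positive.
    outputs = tfpl.IndependentNormal(1)(params)
    return tfk.Model(inputs=inputs, outputs=outputs)

het_model = create_heteroscedastic_bnn(train_size=len(X_train))
# With a distribution-valued output, the loss is the exact negative
# log-likelihood rather than an MSE surrogate.
het_model.compile(optimizer=tfk.optimizers.Adam(learning_rate=0.01),
                  loss=lambda y, dist: -dist.log_prob(y))
# het_model.fit(X_train, y_train, epochs=500, batch_size=32, verbose=0)
```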