## Load required libraries.
library("rjags")

## Define the BNN model fitting function.
bnn <- function(X, Y, num_hidden = 3, n_chains = 3, n_adapt = 500,
  n_iter = 5000, n_burnin = 1000)
{
  ## Ensure X is a matrix.
  X <- as.matrix(X)

  ## Prepare data for JAGS.
  data_jags <- list(
    N = nrow(X),              ## Number of data points.
    num_inputs = ncol(X),     ## Number of input features.
    num_hidden = num_hidden,  ## Number of neurons in the hidden layer.
    X = X,                    ## Input features matrix.
    Y = Y                     ## Output variable.
  )

  ## JAGS model definition with two-parameter output (mean and precision).
  model_string <- "
  model {
    ## Priors for the input-to-hidden layer weights and biases.
    for(i in 1:num_inputs) {
      for(j in 1:num_hidden) {
        W1[i, j] ~ dnorm(0, 1)  ## Weights from input to hidden layer.
      }
    }
    for(j in 1:num_hidden) {
      b1[j] ~ dnorm(0, 1)       ## Biases for hidden layer.
    }

    ## Priors for the hidden-to-output layer weights and biases
    ## (mean and log-precision).
    for(j in 1:num_hidden) {
      W2_mean[j] ~ dnorm(0, 1)  ## Weights from hidden to output mean layer.
      W2_prec[j] ~ dnorm(0, 1)  ## Weights from hidden to output precision layer.
    }
    b2_mean ~ dnorm(0, 1)       ## Bias for output mean layer.
    b2_prec ~ dnorm(0, 1)       ## Bias for output precision layer.

    ## Likelihood: forward pass through the network.
    for(n in 1:N) {
      for(j in 1:num_hidden) {
        ## Hidden layer (sigmoid activation).
        z1[n, j] <- 1 / (1 + exp(-(inprod(X[n, ], W1[, j]) + b1[j])))
      }

      ## Output layer for mean and log-precision.
      mean_y_hat[n] <- inprod(z1[n, ], W2_mean) + b2_mean
      log_prec_y_hat[n] <- inprod(z1[n, ], W2_prec) + b2_prec
      prec_y_hat[n] <- exp(log_prec_y_hat[n])  ## Exponentiate to get precision.

      ## Observation model: normal with individual mean and precision
      ## for each data point.
      Y[n] ~ dnorm(mean_y_hat[n], prec_y_hat[n])

      ## Log-likelihood for each observation.
      log_lik[n] <- logdensity.norm(Y[n], mean_y_hat[n], prec_y_hat[n])
    }
  }
  "

  ## Save model string to file.
  writeLines(model_string, "bnn_model.jags")

  ## Compile the JAGS model.
  model <- jags.model("bnn_model.jags", data = data_jags,
    n.chains = n_chains, n.adapt = n_adapt)

  ## Burn-in period.
  update(model, n_burnin)

  ## Sample from the posterior.
  samples <- coda.samples(model,
    variable.names = c("W1", "b1", "W2_mean", "W2_prec",
      "b2_mean", "b2_prec", "log_lik"),
    n.iter = n_iter)

  ## Return the fitted model and samples as a list.
  rval <- list(model = model, samples = samples, data = data_jags)
  class(rval) <- "bnn"
  return(rval)
}
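
## Optional: a minimal convergence check for a fitted "bnn" object, sketched
## with the coda package (loaded along with rjags). The helper name
## check_convergence() and the default parameter subset are illustrative
## choices, not part of the model code above. Note that hidden units are
## exchangeable, so Rhat on raw weights can look poor even when the
## predictions themselves are stable.
check_convergence <- function(object, pattern = "W2_mean") {
  ## Select a manageable subset of monitored parameters by name.
  sel <- grep(pattern, colnames(object$samples[[1]]), fixed = TRUE)
  sub <- object$samples[, sel]
  list(
    gelman = gelman.diag(sub, autoburnin = FALSE, multivariate = FALSE),  ## Rhat per parameter.
    ess = effectiveSize(sub)  ## Effective sample size per parameter.
  )
}
## Usage, after fitting: check_convergence(b)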
## Prediction function for the bnn model.
predict.bnn <- function(object, X_new)
{
  ## Ensure X_new is a matrix.
  X_new <- as.matrix(X_new)

  ## Extract samples, combining all chains.
  samples_matrix <- as.matrix(object$samples)

  ## Extract weights and biases from the samples.
  W1_samples <- samples_matrix[, grep("W1", colnames(samples_matrix))]
  W2_mean_samples <- samples_matrix[, grep("W2_mean", colnames(samples_matrix))]
  W2_prec_samples <- samples_matrix[, grep("W2_prec", colnames(samples_matrix))]
  b1_samples <- samples_matrix[, grep("b1", colnames(samples_matrix))]
  b2_mean_samples <- samples_matrix[, grep("b2_mean", colnames(samples_matrix))]
  b2_prec_samples <- samples_matrix[, grep("b2_prec", colnames(samples_matrix))]

  ## Number of posterior samples.
  n_samples <- nrow(W1_samples)

  ## Function to compute mean and variance predictions using one posterior sample.
  predict_bnn <- function(W1, b1, W2_mean, W2_prec, b2_mean, b2_prec, X) {
    ## Compute hidden layer activations using the sigmoid function.
    z1 <- 1 / (1 + exp(-(X %*% matrix(W1, ncol = length(b1)) +
      matrix(b1, nrow = nrow(X), ncol = length(b1), byrow = TRUE))))

    ## Compute output layer for mean and log-precision.
    mean_y_hat <- z1 %*% W2_mean + b2_mean
    log_prec_y_hat <- z1 %*% W2_prec + b2_prec
    prec_y_hat <- exp(log_prec_y_hat)  ## Convert log-precision to precision (inverse variance).
    var_y_hat <- 1 / prec_y_hat        ## Convert precision to variance.

    return(list(mean = mean_y_hat, variance = var_y_hat))
  }

  ## Initialize matrices to store predictions for each sample.
  mean_predictions <- matrix(0, nrow = n_samples, ncol = nrow(X_new))
  variance_predictions <- matrix(0, nrow = n_samples, ncol = nrow(X_new))

  ## Compute predictions for each posterior sample.
  for(i in 1:n_samples) {
    W1 <- W1_samples[i, ]
    W2_mean <- W2_mean_samples[i, ]
    W2_prec <- W2_prec_samples[i, ]
    b1 <- b1_samples[i, ]
    b2_mean <- b2_mean_samples[i]
    b2_prec <- b2_prec_samples[i]
    preds <- predict_bnn(W1, b1, W2_mean, W2_prec, b2_mean, b2_prec, X_new)
    mean_predictions[i, ] <- preds$mean
    variance_predictions[i, ] <- preds$variance
  }

  ## Summary statistics: posterior mean of the predicted means, and the
  ## average observation variance (the aleatoric part of the uncertainty).
  mean_pred <- apply(mean_predictions, 2, mean)
  var_pred <- apply(variance_predictions, 2, mean)
  ## credible_intervals <- apply(mean_predictions, 2, quantile, probs = c(0.025, 0.975))

  ## Return posterior mean predictions and standard deviations.
  return(data.frame(mu = mean_pred, sigma = sqrt(var_pred)))
}

## Simulate data.
set.seed(123)
n <- 300
x <- runif(n, -3, 3)
y <- sin(x) + rnorm(n, sd = exp(-1 + cos(x)))

## Fit the BNN model.
b <- bnn(x, y, num_hidden = 10)

## Generate predictions.
p <- predict(b, x)

## 2.5%, 50%, and 97.5% quantiles of the predictive normal.
fit <- cbind(
  qnorm(0.025, mean = p$mu, sd = p$sigma),
  qnorm(0.5, mean = p$mu, sd = p$sigma),
  qnorm(1 - 0.025, mean = p$mu, sd = p$sigma)
)

## Plot predictions.
plot(x, y, main = "BNN Predictions")
i <- order(x)
matplot(x[i], fit[i, ], col = 4, type = "l", lwd = 2, lty = 1, add = TRUE)

## Log-likelihood trace: sum the pointwise log_lik over observations within
## each iteration (rowSums per chain), then concatenate the chains.
ll <- sapply(b$samples, function(x) {
  rowSums(x[, grep("log_lik", colnames(x), fixed = TRUE)])
})
ll <- as.numeric(ll)
plot(ll, type = "l", xlab = "Iteration", ylab = "Log-Likelihood")
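
## The monitored log_lik values also support model comparison. Below is a
## minimal WAIC sketch, assuming the standard formula
## WAIC = -2 * (lppd - p_waic) (Gelman et al. 2014); it is computed by hand
## here in base R rather than via a dedicated package.
log_lik_mat <- as.matrix(b$samples)
log_lik_mat <- log_lik_mat[, grep("log_lik", colnames(log_lik_mat), fixed = TRUE)]
lppd <- sum(log(colMeans(exp(log_lik_mat))))  ## Log pointwise predictive density.
p_waic <- sum(apply(log_lik_mat, 2, var))     ## Effective number of parameters.
waic <- -2 * (lppd - p_waic)
waic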