BNNpriors: A library for Bayesian neural network inference with different prior distributions

Bayesian neural networks have shown great promise in many applications where calibrated uncertainty estimates are crucial and can often also improve predictive performance. However, it remains challenging to choose a good prior distribution over their weights. While isotropic Gaussian priors are often chosen in practice due to their simplicity, they do not reflect our true prior beliefs well and can lead to suboptimal performance. Our new library, BNNpriors, enables state-of-the-art Markov chain Monte Carlo inference on Bayesian neural networks with a wide range of predefined priors, including heavy-tailed ones, hierarchical ones, and mixture priors. Moreover, it follows a modular approach that eases the design and implementation of new custom priors. It has facilitated foundational discoveries on the nature of the cold posterior effect in Bayesian neural networks and will hopefully catalyze future research as well as practical applications in this area.


Introduction
While standard neural networks fit a single point estimate θ̂ of their weights (typically using stochastic gradient descent), Bayesian neural networks (BNNs) instead infer a posterior distribution p(θ|D) over their weights given some training data D [1,2]. They perform this inference using Bayes' theorem, that is, $p(\theta \mid D) = Z^{-1}\, p(D \mid \theta)\, p(\theta)$ with normalizing constant $Z = \int p(D \mid \theta)\, p(\theta)\, d\theta$ [3]. Using this posterior, BNNs can then offer a predictive distribution over outputs by averaging the network's predictions over the posterior weights, which can help quantify their uncertainty about specific predictions [4,5]. In safety-critical applications like medicine or autonomous driving, such calibrated uncertainty estimates can be crucial [6].
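To make the predictive distribution concrete: given S weight samples θ_s from p(θ|D), the predictive distribution for a classifier can be approximated by averaging the per-sample class probabilities, p(y|x,D) ≈ (1/S) Σ_s p(y|x,θ_s). The following is a minimal, framework-free sketch that assumes the logits for each posterior sample have already been computed; the function names are hypothetical and not part of BNNpriors.

```python
import math


def softmax(logits):
    """Numerically stable softmax of a list of class logits."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predictive_distribution(logits_per_sample):
    """Monte Carlo estimate of p(y | x, D) ~= (1/S) * sum_s p(y | x, theta_s).

    `logits_per_sample` holds one list of class logits per posterior sample.
    """
    probs = [softmax(logits) for logits in logits_per_sample]
    num_samples, num_classes = len(probs), len(probs[0])
    return [sum(p[c] for p in probs) / num_samples for c in range(num_classes)]


def predictive_entropy(probs):
    """Entropy of the predictive distribution, a simple uncertainty measure."""
    return -sum(p * math.log(p) for p in probs if p > 0)


# Two posterior samples that disagree give an uncertain (high-entropy) prediction,
# two samples that agree give a confident one.
disagree = predictive_distribution([[2.0, 0.0], [0.0, 2.0]])
agree = predictive_distribution([[2.0, 0.0], [2.0, 0.0]])
print(disagree, predictive_entropy(disagree))
print(agree, predictive_entropy(agree))
```

Averaging probabilities rather than logits is what makes disagreement between samples show up as predictive uncertainty.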
The distribution p(θ) in the equation above is called the prior distribution and has to be carefully chosen to encode our true prior beliefs in order to build a successful Bayesian model [7]. However, this is challenging in BNNs, since we often do not know how to properly formulate prior beliefs over the weight space. In practice, isotropic Gaussian priors are thus often used for their simplicity and computational appeal [e.g., 8,9].
Recently, it has been hypothesized that these isotropic Gaussian priors are the culprit for the cold posterior effect, that is, the fact that tempered posteriors at lower temperatures can perform better than the true Bayesian posterior [10]. This hypothesis has been partially confirmed: it has been shown that the performance of BNNs can indeed be improved and the cold posterior effect alleviated when using different priors [11]. Specifically, heavy-tailed priors seem to work well for fully-connected BNNs and correlated priors for convolutional BNNs. These priors are usually not available in standard BNN frameworks, which is why our novel BNN library, BNNpriors, introduces them and makes them easily usable and extendable for research and application purposes.
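The practical difference between an isotropic Gaussian prior and a heavy-tailed one can be seen by comparing their log densities for a single weight. The sketch below uses the Student-t distribution as one common example of a heavy-tailed prior; the choice of 3 degrees of freedom and unit scale is an arbitrary illustration, not a recommendation from [11].

```python
import math


def gaussian_logpdf(x, scale=1.0):
    """Log density of a zero-mean Gaussian N(0, scale^2)."""
    return -0.5 * math.log(2 * math.pi * scale**2) - 0.5 * (x / scale) ** 2


def student_t_logpdf(x, df=3.0, scale=1.0):
    """Log density of a zero-mean Student-t with `df` degrees of freedom."""
    return (
        math.lgamma((df + 1) / 2)
        - math.lgamma(df / 2)
        - 0.5 * math.log(df * math.pi * scale**2)
        - (df + 1) / 2 * math.log1p((x / scale) ** 2 / df)
    )


# The Gaussian log density falls off quadratically, the Student-t one only
# logarithmically, so large weights are far less penalized by the Student-t prior.
for x in (0.0, 3.0, 10.0):
    print(f"x={x:5.1f}  gaussian={gaussian_logpdf(x):9.3f}  student_t={student_t_logpdf(x):9.3f}")
```

As the degrees of freedom grow, the Student-t log density approaches the Gaussian one, so the tail heaviness can be tuned continuously.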

Description
The BNNpriors library is written in Python and uses PyTorch [12] for defining the neural network models and performing automatic differentiation. It offers different inference schemes, including stochastic gradient Langevin dynamics (SGLD) [13] and Hamiltonian Monte Carlo (HMC) [14], but we recommend gradient-guided Monte Carlo [15], which is called VerletSGLDReject in the library. Moreover, the inference supports cyclical learning rate schedules [16], which help the sampler cover several posterior modes, as well as learning a preconditioner matrix [10] that scales the momenta of the individual weights. To the best of our knowledge, this combination provides the most accurate scalable BNN posterior inference with stochastic gradients available today.
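A cyclical learning rate schedule repeatedly restarts the step size at a large value (so the chain can move towards a new posterior mode) and then decays it (so the chain samples within that mode). The sketch below uses a cosine cycle as one common form of such a schedule; it is not taken from the BNNpriors source, and the function name and arguments are hypothetical.

```python
import math


def cyclical_step_size(step, total_steps, num_cycles, initial_step_size):
    """Cosine cyclical step size for a 0-indexed `step`.

    Each cycle restarts at `initial_step_size` (large steps to move towards a
    new mode) and decays towards zero (small steps to sample within the mode).
    """
    cycle_length = math.ceil(total_steps / num_cycles)
    position = (step % cycle_length) / cycle_length  # fraction of the cycle, in [0, 1)
    return initial_step_size / 2 * (math.cos(math.pi * position) + 1)


schedule = [cyclical_step_size(k, total_steps=100, num_cycles=4, initial_step_size=0.1)
            for k in range(100)]
print([round(s, 4) for s in schedule[:3]], [round(s, 4) for s in schedule[24:27]])
```

In cyclical samplers, samples are typically collected only during the low-step-size end of each cycle.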
To perform inference on the standard Bayes posterior, the temperature parameter of the sampler has to be set to T = 1. However, one can also sample cold posteriors [10] by setting the temperature to T < 1. In order to assess the accuracy of the inference in every single experiment, the sampler estimates diagnostics such as the kinetic temperature and the configurational temperature [10]. In expectation, these should coincide with the temperature parameter T, so deviations can reveal inference problems, for instance, if the learning rate is set too high or the Markov chain has not converged. Some example temperature diagnostics of an accurate inference run are shown in Figure 1.

Figure 1: Example plot of kinetic temperature estimates for Markov chains sampled at different temperatures T. We see that the kinetic temperatures coincide with the true temperatures and that the inference is thus accurate. (Reproduced with permission from [11], best viewed in color)
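The two temperature diagnostics can be written as simple averages. Assuming a diagonal mass (preconditioner) matrix M, the kinetic temperature is (1/d) Σ_i m_i²/M_i over the d momenta, and the configurational temperature is (1/d) Σ_i θ_i ∂U/∂θ_i, where U(θ) = −log p(D|θ) − log p(θ) is the potential and the tempered target is proportional to exp(−U/T). In a real run, θ and m come from the sampler; the minimal sketch below instead draws exact samples from a tempered Gaussian toy model to check that both estimators recover T. All function names are hypothetical and this is not the library's implementation.

```python
import math
import random


def kinetic_temperature(momenta, masses):
    """T_kin = (1/d) * sum_i m_i^2 / M_i, assuming a diagonal mass matrix."""
    return sum(m * m / mass for m, mass in zip(momenta, masses)) / len(momenta)


def configurational_temperature(params, grad_potential):
    """T_conf = (1/d) * sum_i theta_i * dU/dtheta_i."""
    return sum(t * g for t, g in zip(params, grad_potential)) / len(params)


def toy_check(temperature, dim=100_000, seed=0):
    """Estimate both temperatures from exact samples of a tempered Gaussian.

    Toy potential U(theta) = sum_i theta_i^2 / (2 sigma_i^2). Under the tempered
    target exp(-U / T), theta_i ~ N(0, T sigma_i^2); momenta are m_i ~ N(0, T M_i).
    """
    rng = random.Random(seed)
    sigmas = [rng.uniform(0.5, 2.0) for _ in range(dim)]
    masses = [rng.uniform(0.5, 2.0) for _ in range(dim)]
    params = [rng.gauss(0.0, s * math.sqrt(temperature)) for s in sigmas]
    momenta = [rng.gauss(0.0, math.sqrt(temperature * m)) for m in masses]
    grads = [t / (s * s) for t, s in zip(params, sigmas)]  # dU/dtheta_i
    return (kinetic_temperature(momenta, masses),
            configurational_temperature(params, grads))


for T in (1.0, 0.1):
    print(T, toy_check(T))
```

If a sampler's estimates drift away from the requested T (for example, with too large a step size), that mismatch is exactly the kind of inference problem the diagnostics are meant to expose.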
The BNNs in our framework are built from normal PyTorch modules (torch.nn.Module), with the difference that their weights are not instances of the torch.nn.Parameter class, but of our bnn_priors.prior.Prior class. Our library includes a range of predefined priors within a modular taxonomy, such that new priors can easily be defined and can inherit from existing superclasses, such as location-scale distributions or multivariate distributions. Moreover, the hyperparameters of our priors can themselves be Prior objects, which allows for the definition of hierarchical prior models. Also, the Mixture class allows one to define mixture priors from any of the other existing priors. A short overview of some popular prior distributions included in our library is given in Table 1.
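The taxonomy described above can be illustrated with a small, framework-free sketch. All class and helper names below are hypothetical and do not reproduce the actual bnn_priors.prior API, which builds on PyTorch modules. The sketch only shows three ideas: location-scale subclasses share the scaling logic and define just a standard density; a hyperparameter that is itself a Prior yields a hierarchical model, whose own log density is added to the total; and a Mixture combines existing priors.

```python
import math


class Prior:
    """Hypothetical base class: holds a current value and scores it."""

    def __init__(self, value):
        self.value = value

    def log_prob(self):
        raise NotImplementedError


def current(hyper):
    """Current value of a hyperparameter that is either a number or a Prior."""
    return hyper.value if isinstance(hyper, Prior) else hyper


def hyper_log_prob(hyper):
    """A hyperparameter that is itself a Prior adds its own log density."""
    return hyper.log_prob() if isinstance(hyper, Prior) else 0.0


class LocationScale(Prior):
    """Shared logic for densities p(x) = f((x - loc) / scale) / scale."""

    def __init__(self, value, loc=0.0, scale=1.0):
        super().__init__(value)
        self.loc, self.scale = loc, scale

    def standard_log_prob(self, z):
        raise NotImplementedError  # subclasses only define log f(z)

    def log_prob(self):
        scale = current(self.scale)  # assumed positive; not enforced here
        z = (self.value - current(self.loc)) / scale
        return (self.standard_log_prob(z) - math.log(scale)
                + hyper_log_prob(self.loc) + hyper_log_prob(self.scale))


class Normal(LocationScale):
    def standard_log_prob(self, z):
        return -0.5 * math.log(2 * math.pi) - 0.5 * z * z


class StudentT(LocationScale):
    def __init__(self, value, loc=0.0, scale=1.0, df=3.0):
        super().__init__(value, loc, scale)
        self.df = df

    def standard_log_prob(self, z):
        d = self.df
        return (math.lgamma((d + 1) / 2) - math.lgamma(d / 2)
                - 0.5 * math.log(d * math.pi) - (d + 1) / 2 * math.log1p(z * z / d))


class Mixture(Prior):
    """Weighted mixture of existing priors, all scored at the mixture's value."""

    def __init__(self, value, components, weights):
        super().__init__(value)
        self.components, self.weights = components, weights

    def log_prob(self):
        terms = []
        for weight, component in zip(self.weights, self.components):
            component.value = self.value  # evaluate each component at the same point
            terms.append(math.log(weight) + component.log_prob())
        shift = max(terms)
        return shift + math.log(sum(math.exp(t - shift) for t in terms))  # log-sum-exp


# Hierarchical: the weight's scale is itself a Prior, so its density joins the total.
weight = Normal(value=0.3, scale=Normal(value=1.5, loc=1.0, scale=0.5))
# Mixture of a narrow Gaussian and a heavy-tailed Student-t component.
mixed = Mixture(value=0.3, components=[Normal(0.0, scale=0.1), StudentT(0.0)], weights=[0.5, 0.5])
print(weight.log_prob(), mixed.log_prob())
```

Adding a new location-scale prior in this design only requires implementing `standard_log_prob`, which mirrors the kind of reuse the modular taxonomy is meant to enable.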

Usage
A basic BNN inference experiment would use the train_bnn.py script in our library, which takes a number of arguments, including the choice of model, prior, and dataset, as well as training parameters such as the number of samples, number of training epochs, temperature, and learning rate. This script then creates an output directory with weight samples from the BNN posterior, training curves of different performance metrics, and the aforementioned inference diagnostics. The generated samples can be used with the test_bnn.py script to make predictions on different evaluation datasets and compute different performance metrics, including accuracy, log likelihood, calibration error, and out-of-distribution detection. When the experiment is run with different temperature parameters T, one can also plot these performance metrics against the temperature to create tempering curves that show the cold posterior effect (or its absence). Some example tempering curves with different priors on different datasets are shown in Figure 2.

Table 1: Overview of some popular prior distributions included in the library (columns: Prior, Density p(θ), Hierarchical).
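Three of the evaluation metrics named above can be computed directly from the predictive class probabilities. The plain-Python sketch below is an illustration only, not the code of test_bnn.py: it uses expected calibration error (ECE) with equal-width confidence bins as one standard definition of calibration error, the bin count of 10 is an arbitrary choice, and out-of-distribution detection is omitted.

```python
import math


def evaluate(probs, labels, num_bins=10):
    """Return (accuracy, mean log likelihood, expected calibration error).

    `probs[i]` is the predictive class distribution for test point i (for a BNN,
    e.g. averaged over posterior samples) and `labels[i]` its true class index.
    """
    n = len(labels)
    confidences, correct, log_lik = [], [], 0.0
    for p, y in zip(probs, labels):
        pred = max(range(len(p)), key=lambda c: p[c])
        confidences.append(p[pred])
        correct.append(1.0 if pred == y else 0.0)
        log_lik += math.log(max(p[y], 1e-12))  # clip to avoid log(0)
    # ECE: group points into equal-width confidence bins (lo, hi] and average
    # |accuracy - mean confidence| over bins, weighted by bin size.
    ece = 0.0
    for b in range(num_bins):
        lo, hi = b / num_bins, (b + 1) / num_bins
        members = [i for i, c in enumerate(confidences) if lo < c <= hi]
        if members:
            bin_acc = sum(correct[i] for i in members) / len(members)
            bin_conf = sum(confidences[i] for i in members) / len(members)
            ece += len(members) / n * abs(bin_acc - bin_conf)
    return sum(correct) / n, log_lik / n, ece


acc, ll, ece = evaluate([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], labels=[0, 1, 1])
print(acc, ll, ece)
```

Computing these metrics once per temperature T gives the points of a tempering curve.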

Impact
This library has enabled the study of the cold posterior effect as a function of the choice of prior [11]. It has led to the first observation that the cold posterior effect in BNNs can indeed be caused by a misspecified prior (e.g., an isotropic Gaussian) and that the performance of the BNN posterior on several metrics can be improved by choosing different priors. It has helped identify heavy-tailed priors for fully-connected BNNs and correlated priors for convolutional BNNs as better choices. We expect that these insights will spur not only a series of further studies into the role of priors in BNNs, but also the use of such priors in real-world applications. Both of these can be catalyzed by our BNNpriors library.
With respect to the inference, our library has been used to show that gradient-guided Monte Carlo (GGMC) inference can use stochastic gradients and still yield nonzero Metropolis-Hastings acceptance probabilities in BNNs [15], while this is not true for the scheme known as stochastic gradient HMC (SGHMC). Based on these observations, and in combination with the cyclical learning rates, the preconditioning, and the temperature diagnostics, we believe that our library offers the most comprehensive state-of-the-art BNN inference, even without considering the priors. We hope that this inference can lead to more accurate studies of BNN posteriors and better performance in real-world applications.

Figure 2: Example tempering curves for different priors on different datasets. We can see that some priors perform much better than others and can also alleviate the cold posterior effect. (Reproduced with permission from [11], best viewed in color)
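The role of a Metropolis-Hastings correction can be illustrated with standard HMC: a velocity Verlet (leapfrog) integrator proposes a new state, and the proposal is accepted with probability min(1, exp(−ΔH/T)), where H is potential plus kinetic energy and T is the temperature. This is a simplified sketch with exact gradients on a one-dimensional Gaussian toy potential; it is not the GGMC algorithm of [15] or the VerletSGLDReject implementation, and all names are hypothetical.

```python
import math
import random


def hmc_step(theta, potential, grad, rng, step_size=0.1, num_leapfrog=10,
             temperature=1.0, mass=1.0):
    """One HMC transition targeting p_T(theta) ~ exp(-U(theta) / T) in 1-D."""
    momentum = rng.gauss(0.0, math.sqrt(temperature * mass))  # m ~ N(0, T M)
    h_old = potential(theta) + momentum**2 / (2 * mass)
    new_theta, m = theta, momentum
    # Velocity Verlet (leapfrog): half kick, alternating drifts and kicks, half kick.
    m -= 0.5 * step_size * grad(new_theta)
    for i in range(num_leapfrog):
        new_theta += step_size * m / mass
        if i < num_leapfrog - 1:
            m -= step_size * grad(new_theta)
    m -= 0.5 * step_size * grad(new_theta)
    h_new = potential(new_theta) + m**2 / (2 * mass)
    # Metropolis-Hastings test: accept with probability min(1, exp(-(H_new - H_old) / T)).
    accept_prob = math.exp(min(0.0, -(h_new - h_old) / temperature))
    if rng.random() < accept_prob:
        return new_theta, True
    return theta, False


def run_chain(num_steps, temperature, seed=0):
    """Sample the toy potential U(theta) = theta^2 / 2, i.e. N(0, T) at temperature T."""
    def potential(t):
        return 0.5 * t * t

    def grad(t):
        return t

    rng = random.Random(seed)
    theta, samples, accepted = 0.0, [], 0
    for _ in range(num_steps):
        theta, was_accepted = hmc_step(theta, potential, grad, rng, temperature=temperature)
        samples.append(theta)
        accepted += was_accepted
    return samples, accepted / num_steps


samples, acceptance_rate = run_chain(20_000, temperature=0.5)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"acceptance={acceptance_rate:.3f} mean={mean:.3f} var={var:.3f} (target var 0.5)")
```

With exact gradients and a small step size, almost all proposals are accepted; the observation in [15] concerns keeping this acceptance probability away from zero when the gradients are stochastic.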

Limitations
While our current library is of course limited in the number of predefined priors it contains, we do not expect this to pose a problem in practice, since new priors can easily be defined in our modular framework. A more serious limitation could be that for hierarchical priors, we currently only support joint inference of the hyperparameters with the BNN parameters. It would be interesting to extend our inference framework to also allow for Gibbs sampling or reversible-jump Markov chain Monte Carlo [17] in these models. Moreover, while our library should generally be usable with any kind of BNN model that is definable in PyTorch, we have tested it on neither recurrent neural networks [18] nor attention-based ones [19]. Using our library on such models might require a certain amount of manual tuning. Finally, to aid truly Bayesian model selection of priors, it would be useful to be able to estimate marginal likelihoods from our Markov chains. This is generally a challenging problem, but there are a few promising solutions [20,21]. With these estimates, it could even be possible to learn useful BNN priors entirely from scratch [22]. We hope to add all these features to our library in the future, and we also welcome open-source contributions from the community.
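As an illustration of the Gibbs extension mentioned above (which the library does not currently support), consider a hierarchical Gaussian prior w_i ~ N(0, s²) whose variance has a conjugate inverse-gamma hyperprior s² ~ InvGamma(a, b). Given the current weights, the variance can then be resampled in closed form from InvGamma(a + n/2, b + Σ_i w_i²/2), and in a full Gibbs scheme this step would alternate with weight updates from the BNN sampler. The conjugate model and the function name are assumptions chosen for illustration; here the weights are fixed synthetic draws rather than BNN weights.

```python
import random


def gibbs_update_variance(weights, shape, scale, rng):
    """Sample s^2 | w for w_i ~ N(0, s^2) with hyperprior s^2 ~ InvGamma(shape, scale)."""
    post_shape = shape + len(weights) / 2
    post_scale = scale + sum(w * w for w in weights) / 2
    # If X ~ Gamma(post_shape, rate=post_scale), then 1/X ~ InvGamma(post_shape, post_scale).
    # random.gammavariate takes (shape, scale), so the rate enters as 1/post_scale.
    return 1.0 / rng.gammavariate(post_shape, 1.0 / post_scale)


rng = random.Random(0)
weights = [rng.gauss(0.0, 2.0) for _ in range(10_000)]  # stand-in for current BNN weights, true s^2 = 4
draws = [gibbs_update_variance(weights, shape=2.0, scale=1.0, rng=rng) for _ in range(200)]
print(sum(draws) / len(draws))
```

Unlike joint gradient-based inference, this update needs no step size for the hyperparameter, but it relies on conjugacy and would not directly apply to non-conjugate hyperpriors.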