1 Introduction

Large parts of contemporary research in machine learning presume the availability of massive datasets and focus on pure prediction. However, in problems where data is both complex and scarce, and where quantifying the uncertainty of predictions is important, rigorous probabilistic modeling is essential for achieving good generalization performance. The Bayesian formalism provides a natural setting for such modeling—by specifying a model \(p(\{y^n\}_n, y^*)\) jointly over observed data \(\{y^n\}_n \subset {\mathcal {Y}}\) and future data \(y^* \in {\mathcal {Y}}\), denoting by \({\mathcal {Y}}\) the space of outputs, obtaining the predictive distribution \(p(y^* \mid \{y^n\}_n)\) reduces to an integration problem. When using parametric models, prior assumptions on the parameters act as regularization preventing overfitting, but the requirement of an explicit parameterization may be problematic when outputs depend on inputs through highly nonlinear relationships. Nonparametric alternatives, on the other hand, come with issues of their own—simple models such as Gaussian processes struggle with multivariate, skewed and multimodal outputs, while more sophisticated nonparametric Bayesian regression models based on Dirichlet processes require ad hoc constructions to incorporate input dependency and specially developed Markov chain Monte Carlo algorithms for posterior inference. To address these challenges, we propose in this paper a novel nonparametric mixture-of-experts model designed to model any predictive density while being robust against small datasets and overfitting, equipped with a mean-field variational Bayes algorithm for posterior inference.

For modeling irregular densities, Gaussian mixtures are a popular choice as this class of distributions is dense in the set of all probability distributions (Nguyen and McLachlan 2019), thus being able to approximate any target density arbitrarily well. Parameter inference for finite Gaussian mixtures is particularly tractable with algorithms such as expectation–maximization (EM), mean-field variational Bayes and Gibbs sampling (Murphy 2012). In a supervised learning context, if a discriminative approach is preferred, an extension from mixture models are mixture-of-experts models, first introduced by Jacobs et al. (1991), which allow the mixture components—the experts—as well as the mixture weights—the gate—to depend on inputs \(x \in {\mathcal {X}}\) from an input space \({\mathcal {X}}\). A simple example is the model

$$\begin{aligned} z \mid x, \theta &\sim {\text {Cat}}\!\left( {\text {softmax}}(\Phi x) \right) , \\ y \mid x, z = c, \theta &\sim {\text {N}}(\mu _c, \Sigma _c), \end{aligned}$$

where \(\theta = (\{(\mu _c, \Sigma _c)\}_c, \Phi )\), \(\Phi \) is some transformation matrix and z is a categorical latent variable with probability masses given by the softmax of \(\Phi x\). Note that here, only the gate is input-dependent, which is sometimes known as a gating network mixture-of-experts model (Murphy and Murphy 2020)—however, as a more complex example, we could also let the experts be input-dependent by, for example, defining \(\mu _c\) and \(\Sigma _c\) through parameterized transformations of x; this is known as a full mixture-of-experts model. More advanced designs have been investigated in the literature—for example, Jordan and Jacobs (1994) and Bishop and Svensén (2003) use a hierarchical gating network for increased model complexity, and Xu et al. (1995), Ueda and Ghahramani (2002), Baldacchino et al. (2016) and Ingrassia et al. (2012) use a gate which itself is Gaussian or mixture-Gaussian. While some of these (Jordan and Jacobs 1994; Xu et al. 1995) use EM for maximum-likelihood parameter inference, others (Waterhouse et al. 1996; Ueda and Ghahramani 2002; Bishop and Svensén 2003; Baldacchino et al. 2016) have employed variational Bayesian methods to approximate a posterior over the parameters, thereby reducing the imminent risk of overfitting and collapsing variances associated with maximum-likelihood (Bishop 2006; Waterhouse et al. 1996). For an extensive review of various forms of mixture-of-experts models and methods for parameter inference, see Yuksel et al. (2012).
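
To make the gating-network construction concrete, the following minimal NumPy sketch samples from the simple model in the display above; the parameter values \(\Phi \), \(\mu _c\) and \(\Sigma _c\) are arbitrary placeholders rather than fitted quantities.

```python
# Minimal sketch of the gating-network mixture-of-experts in the display above.
# All parameter values (Phi, mu_c, Sigma_c) are illustrative placeholders.
import numpy as np

rng = np.random.default_rng(0)

def softmax(a):
    a = a - a.max()
    e = np.exp(a)
    return e / e.sum()

def sample_y(x, Phi, mus, Sigmas):
    """Draw y ~ p(y | x, theta) for the simple gating-network model."""
    gate = softmax(Phi @ x)                              # input-dependent mixture weights
    c = rng.choice(len(mus), p=gate)                     # z | x ~ Cat(softmax(Phi x))
    return rng.multivariate_normal(mus[c], Sigmas[c])    # y | z = c ~ N(mu_c, Sigma_c)

# Toy setting: two-dimensional inputs and outputs, C = 3 experts.
Phi = rng.normal(size=(3, 2))
mus = [np.array([0.0, 0.0]), np.array([3.0, 0.0]), np.array([0.0, 3.0])]
Sigmas = [0.5 * np.eye(2) for _ in range(3)]
print(sample_y(np.array([1.0, -0.5]), Phi, mus, Sigmas))
```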

Common for the above models is that they are parametric, assuming conditional independence across the data pairs \(\{(x^n, y^n)\}_n \subset {\mathcal {X}} \times {\mathcal {Y}}\) given some finite-dimensional parameter \(\theta \). For exchangeable data, such a model is guaranteed to exist without loss of generality in the infinite-data limit by de Finetti’s theorem (Schervish 2012). However, due to the form of the posterior predictive distribution of the output \(y^* \in {\mathcal {Y}}\) corresponding to a new input \(x^*\in {\mathcal {X}}\), written as

$$\begin{aligned} p(y^* \mid x^*, \{(x^n, y^n)\}_n) = {{\mathbb {E}}}^{p(\theta \; \mid \; \{(x^n, y^n)\}_n)} p(y^* \mid x^*, \theta ), \end{aligned}$$

all knowledge learned from the training data must be encapsulated in the posterior \(p(\theta \mid \{(x^n, y^n)\}_n)\). While this does not present any practical limitations for most types of datasets, problems typically arise when the relationships between inputs and outputs are complex, when the dimension \({\text {dim}} {\mathcal {X}}\) of the inputs is moderate or high, and when data is scarce—in general, the former two situations force \(\theta \) to also be of relatively high dimension which, in combination with the latter, renders inference especially hard. Nonparametric models, in contrast, grow in complexity with increasing data availability and are often based on comparing similarities between data points rather than learning an explicit input-to-output mapping, so they do not suffer from the same problem. Such similarity-based models are most easily understood from a predictive perspective—when inferring \(y^*\) given \(x^*\) and training data \(\{(x^n, y^n)\}_n\), we first determine similarities between \(x^*\) and each \(x^n\) and base our prediction on these similarities and the \(y^n\). The simplest example is the k-nearest neighbors algorithm, which has been extended to a probabilistic framework by Holmes and Adams (2002), who replace the pure averaging over the k-nearest neighbors by a softmax weighting and sample from the posterior of k and the softmax sharpness parameter using Markov chain Monte Carlo (MCMC) methods. The approach has since been revisited and refined by Cucala et al. (2009) and, subsequently, Friel and Pettitt (2011). For regression problems, similar models include kernel machines (Bishop 2006) and the more probabilistically rigorous Gaussian processes (Rasmussen and Williams 2006). Although several extensions of Gaussian processes (Bonilla et al. 2008; Rasmussen and Ghahramani 2002) have been proposed to address, for example, their limitation to univariate outputs and Gaussian likelihoods, no combined method yet exists for handling arbitrary predictive distributions.

Other nonparametric Bayesian methods for conditional density estimation involve Dirichlet process mixture models (DPMMs) (Ferguson 1973; Sethuraman 1994; Antoniak 1974). Here, the Dirichlet process is used as a nonparametric prior in an infinite mixture model

$$\begin{aligned} G &\sim {\text {DP}}(\alpha , G_0), \\ \theta \mid G &\sim G, \\ y \mid \theta &\sim p(y \mid \theta ), \end{aligned}$$

where \({\text {DP}}(\alpha , G_0)\) is the Dirichlet process with concentration parameter \(\alpha \) and base measure \(G_0\) and \(p(y \mid \theta )\) refers to some likelihood model. Several methods exist for using the DPMM in a regression setting, which have been reviewed by Müller et al. (2015). In particular, while simpler approaches are based on estimating the regression means and residuals separately, in a fully nonparametric regression method, the dependency on an input x may be introduced by replacing G with an x-indexed analogue \(G_x\) and instead placing the nonparametric prior over \(\{G_x\}_{x \in {\mathcal {X}}}\). Known as the dependent Dirichlet process (DDP) (MacEachern 1999), this model has been investigated with various choices of priors for \(\{G_x\}_{x \in {\mathcal {X}}}\)—for example, De Iorio et al. (2004) proposed the ANOVA-DDP model in which locations in the stick-breaking representation of the Dirichlet process (Sethuraman 1994) are linear in the inputs, which was further developed into the linear DDP model (De Iorio et al. 2009); Dunson and Park (2008) proposed a kernelized stick-breaking process in which the breaking probabilities are input-dependent and similarity-based; Jara and Hanson (2011) instead proposed to replace the usual beta priors in the stick-breaking representation of the Dirichlet process by logistic-transformed Gaussian processes, deviating from the DPMM. Common drawbacks of these models, however, are that they require the modeling of an input-to-parameter process \(G_x\), which is in general multivariate, that the construction of a prior over \(\{G_x\}_{x \in {\mathcal {X}}}\) essentially amounts to yet another multivariate regression problem, and that the available posterior MCMC inference algorithms exploit the respective specific forms of the priors. For instance, if using a normal likelihood \(p(y \mid \theta ) = {\text {N}}(y \mid \mu , \Sigma )\) with \(\theta = (\mu , \Sigma )\), it is not clear how one would go about specifying the \(({\text {dim}} {\mathcal {Y}} + ({\text {dim}} {\mathcal {Y}})^2)\)-dimensional process \(G_x\) and its prior in such a way that the input-to-parameter relation is adequately modeled and posterior inference is feasible. As an alternative perspective, known as the conditional DPMM (Müller et al. 1996; Cruz-Marcelo et al. 2013), one may model the data generatively and fit a DPMM to the set \(\{(x^n, y^n)\}_n \subset {\mathcal {X}} \oplus {\mathcal {Y}}\) of concatenated input–output pairs, from which conditional distributions are easily obtained due to the mixture-Gaussian form of the likelihood. While MCMC methods, which are known to scale poorly with model size and data dimensionality, are used for posterior inference in all previous cases, an alternative variational Bayes algorithm (Blei and Jordan 2006) is possible in this model. On the other hand, the generative approach of the conditional DPMM relies on modeling both the input and output spaces using a mixture model, which may be unsuitable if the dimensionality of the inputs is large compared to that of the outputs. Moreover, the conditional DPMM relies on an approximate independence assumption across the data (Müller et al. 2015), and the model is therefore not similarity-based in the sense that \(x^*\) is compared with each \(x^n\) when predicting \(y^*\).

In this paper, we present a new similarity-based mixture-of-experts model combining the advantages of similarity-based nonparametric models with the flexibility of Gaussian mixtures, all while avoiding generative modeling and MCMC posterior inference. Specifically, the model uses multivariate Gaussian experts and a gating network comprising two layers of transitions when making predictions: a first layer where similarities between \(x^*\) and each \(x^n\) are used to compute transition probabilities, and a second layer where the conditional probabilities of belonging to each expert c given each \(y^n\) are computed. As such, in contrast to the aforementioned DPMM-based nonparametric Bayesian regression methods, in which input-to-parameter relations must be explicitly modeled by \(G_x\) and its prior, all input–output mappings are handled through the observed data pairs \(\{(x^n, y^n)\}_n\). The model entails a predictive likelihood of the form of a multivariate Gaussian mixture, capable of modeling marginals as well as dependencies across the components of the output variable. Parameters are inferred using a mean-field variational Bayes algorithm, where local approximations of the likelihood are introduced to make variational posteriors analytically tractable and where the corresponding variational parameters may be updated using EM iterations. A gradient-based way of optimizing the similarity metric of the first transition is presented, which uses the reparameterization trick (Kingma and Welling 2014) to make stochastic gradients available through Monte Carlo sampling. The result is an algorithm generally suitable for datasets for which input–output relationships are intricate, the number of observations is relatively small, the output variables are multivariate, the predictive distributions may be multimodal as well as skewed, and for which it is desirable to estimate the predictive distributions as exactly as possible. The proposed method was tested on two artificially generated datasets, a dataset from medical physics containing dose statistics from radiation therapy treatments, the California housing dataset (Pace and Barry 1997) comprising geographic location and features of housing districts in California, and a dataset (Vågen et al. 2010) of soil functional properties and infrared spectroscopy measurements in Africa. In terms of both visual inspection and validation metrics, our model outperformed a conditional DPMM on all but the radiation therapy dataset, where it gave similar results, and performed consistently better than a Gaussian process baseline model. In particular, the experiments serve to illustrate the advantages of the proposed mixture-of-experts model on data with relatively high-dimensional inputs and irregularly shaped output distributions.

The remainder of this paper is organized as follows: the assumptions of the model are accounted for in Sect. 2, details of the mean-field variational Bayes algorithm are given in Sect. 3, results from the computational study are presented in Sect. 4 and discussed in Sect. 5, and derivations of some key facts and other algorithm details are given in Appendices A, B, C, D and E.

2 Model

Let \(\{(x^n, y^n)\}_{n = 1}^N \subset {\mathcal {X}} \times {\mathcal {Y}}\) be exchangeable pairs of random variables consisting of inputs \(x^n\) and outputs \(y^n\), where \({\mathcal {X}}\) and \({\mathcal {Y}}\) are some appropriate vector spaces of random variables. Given observations of these data pairs, which are referred to as the training data, our main task is to be able to, for each new input \(x^* \in {\mathcal {X}}\), obtain the predictive distribution \(p(y^* \mid x^*, \{(x^n, y^n)\}_n)\) of the corresponding output \(y^* \in {\mathcal {Y}}\). All random variables are modeled on a common probability space with probability measure \({{\mathbb {P}}}\), using p for densities or masses induced by \({\mathbb {P}}\). We will frequently use \(p(\{z_i\}_{i=1}^I)\) as shorthand for the joint density \(p(z_1, z_2, \dots , z_I)\) and \({{\mathbb {E}}}^{p(z)}\) for the marginalization over a random variable \(z \sim p(z)\).

Before introducing our model, it will be instructive to first formulate a conventional kernel regression model using latent variables:

Example 1

(Nadaraya–Watson estimator) Consider a model where prediction of \(y^*\) given \(x^*\) may be written as a noisy linear combination

$$\begin{aligned} y^* = \sum _n \frac{k(x^*, x^n)}{\sum _{n'} k(x^*, x^{n'})} y^n + \varepsilon ^* \end{aligned}$$
(1)

of the training outputs \(y^n\), where \(\varepsilon ^* \sim {\text {N}}(0, \epsilon ^2)\) is the regression error and \(k: {\mathcal {X}}^2 \rightarrow {\mathbb {R}}\) is some positive definite kernel—this is sometimes known as the Nadaraya–Watson estimator (Bishop 2006). For simplicity, we will use the radial basis function kernel \(k(x, x') = {{\text {exp}}(-\Vert x - x' \Vert ^2 / (2\ell ^2))}\) with standard Euclidean norm \(\Vert \cdot \Vert \) and lengthscale parameter \(\ell \). We can now reformulate this using latent variables. In particular, let \(u^*\) be a random variable supported on \(\{1, \dots , N\}\) such that

$$\begin{aligned} p(u^* = n \mid x^*, \{(x^{n'}, y^{n'})\}_{n'}) \propto {\text {N}}(x^* \mid x^n, \ell ^2 I), \end{aligned}$$

I being the identity matrix, which may be interpreted as \(u^*\) “choosing” one of the observed data points n and setting up a normal distribution centered around \(x^n\). If, conditional upon the choice \(u^* = n\), the prediction of \(y^*\) is \(y^n\) with some normal uncertainty of variance \(\epsilon ^2\)—that is, \(p(y^* \mid u^* = n, x^*, \{(x^{n'}, y^{n'})\}_{n'}) = {\text {N}}(y^* \mid y^n, \epsilon ^2)\)—we obtain

$$\begin{aligned} p(y^* \mid x^*, \{(x^n, y^n)\}_n) = \sum _{n} \frac{{\text {N}}(x^* \mid x^n, \ell ^2 I)}{\sum _{n'} {\text {N}}(x^* \mid x^{n'}, \ell ^2 I)} \, {\text {N}}(y^* \mid y^n, \epsilon ^2). \end{aligned}$$

From the radial basis form of k, it is easy to see that this is equivalent to the conventional kernel regression formulation (1).
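
As a concrete illustration of this latent-variable view, the following sketch evaluates the resulting predictive density—a mixture of normals centered at the training outputs with softmax kernel weights—on a grid; the data and the values of \(\ell \) and \(\epsilon \) are toy choices.

```python
# Sketch of the latent-variable view in Example 1: the predictive density is a
# mixture of N(y* | y^n, eps^2) with weights given by normalized RBF kernels.
# Data, lengthscale and noise level are toy values.
import numpy as np

def nw_predictive_density(y_grid, x_star, X, Y, lengthscale, eps):
    # First transition: softmax of -||x* - x^n||^2 / (2 l^2) over training inputs.
    sq_dists = np.sum((X - x_star) ** 2, axis=1)
    logw = -sq_dists / (2.0 * lengthscale ** 2)
    w = np.exp(logw - logw.max())
    w /= w.sum()
    # Mixture of normals centered at the training outputs y^n.
    dens = np.zeros_like(y_grid)
    for wn, yn in zip(w, Y):
        dens += wn * np.exp(-(y_grid - yn) ** 2 / (2 * eps ** 2)) / np.sqrt(2 * np.pi * eps ** 2)
    return dens

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 2))
Y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=50)
grid = np.linspace(-2.0, 2.0, 200)
density = nw_predictive_density(grid, np.array([0.3, -0.2]), X, Y, lengthscale=0.5, eps=0.1)
```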

2.1 Main setup

To assemble our model, we will start from a predictive perspective. We use C normal distributions \(\{{\text {N}}(\mu _c, \Sigma _c)\}_{c=1}^C\) as experts, where the means \(\mu _c\) and covariances \(\Sigma _c\) are parameters not depending on the inputs. The gate then consists of two layers of transitions—one between the new input \(x^*\) and each training input \(x^n\), and one between each training output \(y^n\) and each expert c—which are represented by, respectively, latent variables \(u^*\) and \(z^*\) supported on \(\{1, \dots , N\}\) and \(\{1, \dots , C\}\). In particular, we let

$$\begin{aligned} p(u^* = n \mid x^*, \{(x^{n'}, y^{n'})\}_{n'}, \theta ) \propto {\text {N}}(x^* \mid x^n, \Lambda ^{-1}), \end{aligned}$$

where \(\Lambda \) is a precision matrix (as a generalization from \(\ell ^{-2} I\) in Example 1), and

$$\begin{aligned} p(z^* = c \mid u^* = n, x^*, \{(x^{n'}, y^{n'})\}_{n'}, \theta ) \propto {\text {N}}(y^n \mid \mu _c, \Sigma _c), \end{aligned}$$

using \(\theta = (\Lambda , \{(\mu _c, \Sigma _c)\}_c)\) for the collection of all parameters. Furthermore, we let

$$\begin{aligned} p(y^* \mid z^* = c, u^* = n, x^*, \{(x^{n'}, y^{n'})\}_{n'}, \theta ) = {\text {N}}(y^* \mid \mu _c, \Sigma _c) \end{aligned}$$

for all n and c. The complete-data predictive likelihood is then written as

$$\begin{aligned} \begin{aligned}&p(y^*, u^*, \, z^* \mid x^*, \{(x^n, y^n)\}_n, \theta ) \\&\quad = p(y^* \mid z^*, \theta ) p(z^* \mid u^*, \{y^{n}\}_{n}, \theta ) p(u^* \mid x^*, \{x^n\}_n, \theta ) \\&\quad = \prod _n \prod _c \Bigg ( \frac{{\text {N}}(x^* \mid x^n, \Lambda ^{-1})}{\sum _{n'} {\text {N}}(x^* \mid x^{n'}, \Lambda ^{-1})} \\&\qquad \times \frac{{\text {N}}(y^n \mid \mu _c, \Sigma _c)}{\sum _{c'} {\text {N}}(y^n \mid \mu _{c'}, \Sigma _{c'})} {\text {N}}(y^* \mid \mu _c, \Sigma _c) \Bigg )^{1_{z^* \; = \; c} 1_{u^* \; = \; n}}, \end{aligned} \end{aligned}$$
(2)

leading to the observed-data predictive likelihood

$$\begin{aligned} p(y^* \mid x^*, \{(x^n, y^n)\}_n) = \sum _c \left( \sum _n \frac{{\text {N}}(x^* \mid x^n, \Lambda ^{-1})}{\sum _{n'} {\text {N}}(x^* \mid x^{n'}, \Lambda ^{-1})} \frac{{\text {N}}(y^n \mid \mu _c, \Sigma _c)}{\sum _{c'} {\text {N}}(y^n \mid \mu _{c'}, \Sigma _{c'})} \right) {\text {N}}(y^* \mid \mu _c, \Sigma _c), \end{aligned}$$

which is a mixture of Gaussians. The first and second fraction within the parentheses of the above display are recognized as probabilities associated with first and second transitions, respectively. The precision matrix \(\Lambda \) naturally induces a Mahalanobis-form distance metric \(d_{{\mathcal {X}}}(x, x') = (x - x')^{{\text {T}}} \Lambda (x - x')\) on \({\mathcal {X}}\), whereby the first transitions are obtained as a softmax transformation of the vector \((-d_{{\mathcal {X}}}(x^*, x^n) / 2)_n\). Note, in particular, that this is a gating network mixture-of-experts model as the mean–covariance pairs \((\mu _c, \Sigma _c)\) are not input-dependent. This is a deliberate choice in light of our preference to avoid stipulating an input-to-parameter map, which is in general a high-dimensional and potentially complex transformation, as discussed in Sect. 1—instead, we let all input–output transitions occur through the observed data pairs \(\{(x^n, y^n)\}_n\).
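
For concreteness, the observed-data predictive likelihood above can be evaluated for a fixed parameter value \(\theta \) as in the following sketch, where the two gating transitions are computed as softmax transformations in log space; the function and variable names are ours and the implementation is illustrative only.

```python
# Sketch: evaluating the observed-data predictive likelihood for a fixed
# theta = (Lambda, {(mu_c, Sigma_c)}_c). The two normalized fractions are the
# first and second gating transitions, computed via log-space softmaxes.
import numpy as np
from scipy.stats import multivariate_normal
from scipy.special import softmax

def predictive_density(y_star, x_star, X, Y, Lambda, mus, Sigmas):
    C = len(mus)
    Lambda_inv = np.linalg.inv(Lambda)
    # First transition: softmax of -d_X(x*, x^n)/2 over the training inputs.
    log_first = np.array([multivariate_normal.logpdf(x_star, mean=xn, cov=Lambda_inv) for xn in X])
    first = softmax(log_first)                           # shape (N,)
    # Second transition: responsibilities of each expert c for each y^n.
    log_resp = np.array([[multivariate_normal.logpdf(yn, mean=mus[c], cov=Sigmas[c])
                          for c in range(C)] for yn in Y])
    second = softmax(log_resp, axis=1)                   # shape (N, C)
    # Mixture weights over experts, then the Gaussian mixture density at y*.
    weights = first @ second                             # shape (C,)
    expert_dens = np.array([multivariate_normal.pdf(y_star, mean=mus[c], cov=Sigmas[c])
                            for c in range(C)])
    return float(weights @ expert_dens)
```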

Having established the predictive pipeline, we now face a problem similar to the one Holmes and Adams (2002) originally faced in the probabilistic k-nearest neighbors context: constructing a joint complete-data likelihood \(p(\{(y^n, u^n, z^n)\}_n \mid \{x^n\}_n, \theta )\) given the training inputs. It is intuitively reasonable to require that the full conditionals \(p(y^n, u^n, z^n \mid x^n, \{(x^{n'}, y^{n'})\}_{n' \ne n}, \theta )\) have forms analogous to the predictive likelihood (2). As pointed out by Cucala et al. (2009), however, directly defining the joint likelihood as the product of all full conditionals as in Holmes and Adams (2002) will lead to an improperly normalized density. Noting that the asymmetry of the neighborhood relationships is the main cause of difficulties in specifying a well-defined model, Cucala et al. (2009) instead use a symmetrized version and view the training data as a Markov random field (Murphy 2012), defining a joint distribution up to a normalizing constant corresponding to a Boltzmann-type model. The idea is reused in Friel and Pettitt (2011), who replace the arguably superficial symmetrized neighborhood relationships with distance metrics. Both papers present a pseudolikelihood approximation (Besag 1974) of the joint distribution as a possible approach, although focus is reserved for other methods of handling the intractable normalizing constant.

Historically, there has been considerable debate regarding the use of a proper joint likelihood versus a conditionally specified model, for which a joint likelihood may not even exist (Besag and Kooperberg 1995; Cucala et al. 2009). As argued by Besag (1974), the appeal of the latter choice is mainly due to the increased interpretability of specifying relationships directly through full conditionals, whereby the pseudolikelihood is always available as a viable surrogate for the joint likelihood. Thus, despite a certain lack of statistical coherence, this methodology allows the translation of any predictive pipeline, such as that of the k-nearest neighbor algorithm, into a probabilistic framework. Subsequent investigations of the model by Holmes and Adams (2002) gave different motivations for using this approach—Manocha and Girolami (2007) noticed the resemblance of parameter inference in such a model with leave-one-out cross-validation, whereas Yoon and Friel (2015) explained that the Boltzmann-type model by Cucala et al. (2009) would lead to difficulties due to the intractable normalizing constant. The interpretation of using such a pseudolikelihood is that dependencies between the data points beyond first order are ignored (Friel and Pettitt 2011), which may be a reasonable simplification in many applications.

Given this, we settle on the pseudolikelihood approach and proceed to define the full conditionals of our model as

$$\begin{aligned}&p(y^n, u^n, \, z^n \mid x^n, \{(x^{n'}, y^{n'})\}_{n' \ne n}, \theta ) \\&\quad = \prod _{n' \ne n} \prod _c \Bigg ( \frac{{\text {N}}(x^n \mid x^{n'}, \Lambda ^{-1})}{\sum _{n''} {\text {N}}(x^n \mid x^{n''}, \Lambda ^{-1})} \\&\qquad \times \frac{{\text {N}}(y^{n'} \mid \mu _c, \Sigma _c)}{\sum _{c'} {\text {N}}(y^{n'} \mid \mu _{c'}, \Sigma _{c'})} {\text {N}}(y^n \mid \mu _c, \Sigma _c) \Bigg )^{1_{z^n \; = \; c} 1_{u^n \; = \; n'}}, \end{aligned}$$

recognized as the analogue of (2), where the latent variables \(u^n\), \(z^n\) are now supported on \(\{1, \dots , N\} {\setminus } \{n\}\) and \(\{1, \dots , C\}\), respectively. The joint complete-data pseudolikelihood is then written as

$$\begin{aligned} p_{\textrm{pseudo}}(\{(y^n, u^n, z^n)\}_n \mid \{x^n\}_n, \theta ) = \prod _n p(y^n, u^n, z^n \mid x^n, \{(x^{n'}, y^{n'})\}_{n' \ne n}, \theta ) \end{aligned}$$
(3)

—while it is important to remember that this is not a proper joint likelihood, we will drop the subscript in the following for convenience of notation. To complete our Bayesian model, we assume that the precision matrix \(\Lambda \) and the mean–covariance pairs \((\mu _c, \Sigma _c)\) are a priori independent with Wishart and normal–inverse-Wishart distributions, respectively. This is written as

$$\begin{aligned} p(\theta ) = {\text {W}}(\Lambda \mid \Lambda _0, \eta _0) \prod _c {\text {N}}\!\left( \mu _c \; \bigg | \; \mu _0, \frac{1}{\kappa _0} \Sigma _c \right) {\text {IW}}(\Sigma _c \mid \Sigma _0, \nu _0), \end{aligned}$$

where \(\Lambda _0\), \(\eta _0\), \(\mu _0\), \(\kappa _0\), \(\Sigma _0\) and \(\nu _0\) are hyperparameters. In order to ensure that the distributions are well-defined, we must require that \(\Lambda _0\) and \(\Sigma _0\) be positive definite, that \(\eta _0 > {\text {dim}} {\mathcal {X}} - 1\) and \(\nu _0 > {\text {dim}} {\mathcal {Y}} - 1\), and that \(\kappa _0 > 0\) (Anderson 1984). The discussion of general-purpose default methods to set hyperparameters is deferred to Appendix E—in particular, the number C of experts should generally be high to allow flexibility in the set of possible predictive distributions. Using (3), this leads to the log-pseudojoint

$$\begin{aligned} \begin{aligned}&\log \, p(\{(y^n, u^n, z^n)\}_n, \theta \mid \{x^n\}_n) \\&\quad = \log {\text {W}}(\Lambda \mid \Lambda _0, \eta _0) + \sum _c \left( \log {\text {N}}\!\left( \mu _c \; \bigg | \; \mu _0, \frac{1}{\kappa _0} \Sigma _c\right) \right. \\&\qquad \left. + \log {\text {IW}}(\Sigma _c \mid \Sigma _0, \nu _0) \right) \\&\qquad + \sum _n \sum _{n' \ne n} \sum _c 1_{z^n \; = \; c} 1_{u^n \; = \; n'} \Bigg ( \log {\text {N}}(y^n \mid \mu _c, \Sigma _c) \\&\qquad + \log {\text {N}}(y^{n'} \mid \mu _c, \Sigma _c) - \log \!\Bigg ( \sum _{c'} {\text {N}}(y^{n'} \mid \mu _{c'}, \Sigma _{c'}) \Bigg ) \\&\qquad + \log {\text {N}}(x^n \mid x^{n'}, \Lambda ^{-1}) - \log \!\Bigg ( \sum _{n'' \ne n} {\text {N}}(x^n \mid x^{n''}, \Lambda ^{-1}) \Bigg ) \Bigg ). \end{aligned} \end{aligned}$$
(4)

3 Parameter inference

3.1 Mean-field variational Bayes

Motivated by the Gaussian mixture form of the full conditionals, we will base our parameter inference on a mean-field variational Bayes algorithm. In general, given some observed variables y and hidden variables z, the main idea of variational inference is to approximate an intractable posterior \(p(z \mid y)\) by a variational posterior q(z) from some family \({\mathcal {Q}}\) of distributions. The optimal \(q^*(z) \in {\mathcal {Q}}\) is found by minimizing the Kullback–Leibler (KL) divergence \(d_{{\text {KL}}}(q \; \Vert \; p(\cdot \mid y))\) according to

$$\begin{aligned} q^*(z)&= \mathop {\text {arg\,min}}\limits _{q(z) \in {\mathcal {Q}}} \; -{{\mathbb {E}}}^{q(z)} \log \!\left( \frac{p(z \mid y)}{q(z)} \right) \\&= \mathop {\text {arg\,max}}\limits _{q(z) \in {\mathcal {Q}}} \; {\text {ELBO}}(q), \end{aligned}$$

where the evidence lower bound (ELBO) is defined as

$$\begin{aligned} {\text {ELBO}}(q) = {{\mathbb {E}}}^{q(z)} \log \!\left( \frac{p(y, z)}{q(z)} \right) , \end{aligned}$$

its name originating from the fact that the bound \(\log p(y) \ge {\text {ELBO}}(q)\) always holds. By mean-field variational inference, we mean using as \({\mathcal {Q}}\) the family of all distributions of the fully factorized form

$$\begin{aligned} q(z) = \prod _i q(z_i), \end{aligned}$$

where we have block-decomposed z as \(z = (z_i)_i\). In this case, one can show that the optimal distribution \(q^*(z_i)\) for each component i has log-density equal up to an additive constant to the log-joint with \(\{z_{i'}\}_{i' \ne i}\) marginalized out—that is,

$$\begin{aligned} \log q^*(z_i) {\mathop {=}\limits ^{\textrm{c}}} {{\mathbb {E}}}^{q(\{z_{i'}\}_{i' \ne i})} \log p(y, z). \end{aligned}$$

This is used to construct a coordinate ascent algorithm, where a sequence \(\{q^l\}_{l \ge 0}\) of distributions is devised to approximate the optimal \(q^*\) increasingly well, and updates are performed according to the above by, at each l, iterating through the components of z and always taking expectations with respect to the latest versions of the \(q(z_i)\). For a detailed review of variational inference methods, see Blei et al. (2017).

In our case, counting both the latent variables \(u^n\), \(z^n\) and the parameters \(\theta \) as hidden variables, we propose the factorization

$$\begin{aligned} q(\theta , \{(u^n, z^n)\}_n) = q(\Lambda ) q(\{(\mu _c, \Sigma _c)\}_c) q(\{(u^n, z^n)\}_n). \end{aligned}$$

An issue with the terms containing logarithms of summed normal densities in (4) is that they do not lead to tractable variational posteriors. Xu et al. (1995) proposed a generative model, which was also used by Baldacchino et al. (2016), where each \(p(x^n)\) has the form of a Gaussian mixture, leading to analytically tractable variational posteriors at the cost of introducing additional parameters to the model, and thus also accompanying hyperparameters. To avoid this, we settle for a slightly different approach and instead approximate the log-sum-exp function \({\text {LSE}}\) by the linearization

$$\begin{aligned} {\text {LSE}}(\xi ) = \log \!\left( \sum _i e^{\xi _i} \right) \approx \log \!\left( \sum _i e^{\eta _i} \right) + \sum _i \frac{e^{\eta _i}}{\sum _{i'} e^{\eta _{i'}}} (\xi _i - \eta _i), \end{aligned}$$
(5)

where \(\xi = (\xi _i)_i\) and \(\eta = (\eta _i)_i\). In particular, we let \(s = (s_{nc})_{n, c}\) and \(t = (t_{nn'})_{n, n' \ne n}\) be the softmax-transformed locations of the linearizations—that is, for each n, we have

$$\begin{aligned} \log \!\left( \sum _c {\text {N}}(y^n \mid \mu _c, \Sigma _c) \right) {\mathop {\approx }\limits ^{\textrm{c}}} \sum _c s_{nc} \log {\text {N}}(y^n \mid \mu _c, \Sigma _c) \end{aligned}$$
(6)

and

$$\begin{aligned} \log \!\left( \sum _{n' \ne n} {\text {N}}(x^n \mid x^{n'}, \Lambda ^{-1}) \right) {\mathop {\approx }\limits ^{\textrm{c}}} \sum _{n' \ne n} t_{nn'} \log {\text {N}}(x^n \mid x^{n'}, \Lambda ^{-1}). \end{aligned}$$
(7)

Here, \(s_n = (s_{nc})_c\) and \(t_n = (t_{nn'})_{n' \ne n}\), and \({\mathop {\approx }\limits ^{\textrm{c}}}\) denotes approximate equality up to a constant not depending on \(\theta \). We regard s and t as variational parameters—the presentation of how they may be optimized is deferred to Sect. 3.2. It will be shown that, provided that the values of s and t are appropriate, our model will become conditionally conjugate under the assumptions above, resulting in variational posteriors from the same parametric families as our priors.
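
The following brief sketch illustrates the linearization (5): a first-order expansion of the log-sum-exp function around \(\eta \) whose slope is the softmax of \(\eta \), exact at \(\xi = \eta \) and accurate nearby. The values of \(\xi \) and \(\eta \) below are arbitrary.

```python
# Sketch of the linearization (5): a first-order expansion of LSE around eta
# with slope softmax(eta). The values of xi and eta are arbitrary.
import numpy as np
from scipy.special import logsumexp, softmax

def lse_linearized(xi, eta):
    return logsumexp(eta) + softmax(eta) @ (xi - eta)

xi = np.array([0.3, -1.2, 2.0])
eta = np.array([0.5, -1.0, 1.8])                  # linearization point
print(logsumexp(xi), lse_linearized(xi, eta))     # close when xi is near eta
```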

Letting \(\{q^l\}_{l \ge 0}\) be a sequence of variational posterior approximations, we use the assumed factorization to obtain update equations according to

$$\begin{aligned} \log q^{l+1}(\{(u^n, z^n)\}_n)&{\mathop {=}\limits ^{\textrm{c}}} {{\mathbb {E}}}^{q^l(\Lambda ) q^l(\{(\mu _c, \Sigma _c)\}_c)} \log p(\{(y^n, u^n, z^n)\}_n, \theta \mid \{x^n\}_n), \\ \log q^{l+1}(\{(\mu _c, \Sigma _c)\}_c)&{\mathop {=}\limits ^{\textrm{c}}} {{\mathbb {E}}}^{q^{l+1}(\{(u^n, z^n)\}_n) q^l(\Lambda )} \log p(\{(y^n, u^n, z^n)\}_n, \theta \mid \{x^n\}_n) \end{aligned}$$

and

$$\begin{aligned} \log q^{l+1}(\Lambda ) {\mathop {=}\limits ^{\textrm{c}}} {{\mathbb {E}}}^{q^{l+1}(\{(u^n, z^n)\}_n) q^{l+1}(\{(\mu _c, \Sigma _c)\}_c)} \log p(\{(y^n, u^n, z^n)\}_n, \theta \mid \{x^n\}_n). \end{aligned}$$

These will be referred to as the E-step, the M-step I and the M-step II, respectively, due to the strong resemblance with the standard EM algorithm for fitting Gaussian mixtures (Murphy 2012) and the expectation–conditional maximization algorithm (Meng and Rubin 1993). Let also \(s^l\), \(t^l\) be the values of the variational parameters s and t at iteration l, and let

$$\begin{aligned} \omega _{c, nn'}^l = q^l(z^n = c, u^n = n'), \quad \Omega _{nn'}^l = \sum _c \omega _{c, nn'}^l, \end{aligned}$$

for each c, n and \(n'\). We can then summarize the details of the updates in Propositions 1, 2 and 3 below, for which derivations are given in Appendix A. The overall algorithm, including the steps described in Sect. 3.2, is outlined in Algorithm 1.

Proposition 1

(Variational posterior, E-step) For all l, we have the decomposition

$$\begin{aligned} q^l(\{(u^n, z^n)\}_n)&= \prod _n q^l(u^n, z^n) \\&= \prod _n \prod _{n' \ne n} \prod _c \left( \omega _{c, nn'}^l \right) ^{1_{z^n \; = \; c} 1_{u^n \; = \; n'}}. \end{aligned}$$

Updates are given by

$$\begin{aligned}&\log \, \omega _{c, nn'}^{l+1} {\mathop {=}\limits ^{\textrm{c}}} \sum _{i=1}^{\dim {\mathcal {Y}}} \left( \psi \!\left( \frac{\nu _c^l + 1 - i}{2} \right) \right. \\&\qquad \left. -\frac{1}{2}\sum _{c'} s_{n'c'}^l \psi \!\left( \frac{\nu _{c'}^l + 1 - i}{2} \right) \right) \\&\qquad -\log \det \Sigma _c^l + \frac{1}{2} \sum _{c'} s_{n'c'}^l \log \det \Sigma _{c'}^l \\&\qquad + \dim {\mathcal {Y}} \left( -\frac{1}{\kappa _c^l} + \frac{1}{2} \sum _{c'} s_{n'c'}^l \frac{1}{\kappa _{c'}^l} \right) \\&\qquad -\frac{1}{2} \nu _c^l (y^n - \mu _c^l)^{{\text {T}}} \Sigma _c^{-l} (y^n - \mu _c^l) \\&\qquad -\frac{1}{2} \nu _c^l (y^{n'} - \mu _c^l)^{{\text {T}}} \Sigma _c^{-l} (y^{n'} - \mu _c^l) \\&\qquad + \frac{1}{2} \sum _{c'} s_{n'c'}^l (y^{n'} - \mu _{c'}^l)^{{\text {T}}} \Sigma _{c'}^{-l} (y^{n'} - \mu _{c'}^l) \\&\qquad -\frac{1}{2} \eta ^l (x^n - x^{n'})^{{\text {T}}} \Lambda ^l (x^n - x^{n'}) \\&\qquad + \frac{1}{2} \eta ^l \sum _{n'' \ne n} t_{nn''}^l (x^n - x^{n''})^{{\text {T}}} \Lambda ^l (x^n - x^{n''}), \end{aligned}$$

where \(\psi \) is the digamma function.

Remark

The values of \(\omega _{c, nn'}^l\) may be found using the fact that \(\sum _c \sum _{n' \ne n} \omega _{c, nn'}^l = 1\).
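As an illustrative sketch of this normalization, the unnormalized log-weights from Proposition 1 can be jointly softmax-transformed over all pairs \((c, n')\) for each n; the array layout below, with the diagonal \(n' = n\) excluded, is our own choice.

```python
# Sketch of the normalization in the remark above: the E-step determines
# log omega only up to an additive constant, and the joint responsibilities are
# recovered by a softmax over all (c, n') pairs for each n.
import numpy as np
from scipy.special import softmax

def normalize_omega(log_omega_unnorm):
    """log_omega_unnorm: array of shape (N, N, C) with entry (n, n', c);
    entries with n' = n are assumed to be set to -inf."""
    N = log_omega_unnorm.shape[0]
    flat = log_omega_unnorm.reshape(N, -1)
    return softmax(flat, axis=1).reshape(log_omega_unnorm.shape)

N, C = 5, 3
rng = np.random.default_rng(2)
log_omega = rng.normal(size=(N, N, C))
log_omega[np.arange(N), np.arange(N), :] = -np.inf    # exclude n' = n
omega = normalize_omega(log_omega)
print(np.allclose(omega.sum(axis=(1, 2)), 1.0))       # sums to one over (n', c)
```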

Proposition 2

(Variational posterior, M-step I) For all l, we have the decomposition

$$\begin{aligned} q^l(\{(\mu _c, \Sigma _c)\}_c) = \prod _c q^l(\mu _c \mid \Sigma _c) q^l(\Sigma _c), \end{aligned}$$

with \(q^l(\mu _c \mid \Sigma _c) = {\text {N}}\!\left( \mu _c \mid \mu _c^l, \kappa _c^{-l} \Sigma _c \right) \) and \(q^l(\Sigma _c) = {\text {IW}}(\Sigma _c \mid \Sigma _c^l, \nu _c^l)\). Defining

$$\begin{aligned} r_{nc}^l = \sum _{n' \ne n} \left( \omega _{c, nn'}^l + \omega _{c, n'n}^l - s_{nc}^l \Omega _{n'n}^l \right) , \quad R_c^l = \sum _n r_{nc}^l, \end{aligned}$$

updates are given by

$$\begin{aligned} \kappa _c^{l+1} = \kappa _0 + R_c^{l+1}, \quad \mu _c^{l+1}=\frac{\kappa _0 \mu _0 + \sum _{n} r_{nc}^{l+1} y^n }{\kappa _c^{l+1}} \end{aligned}$$

and

$$\begin{aligned} \nu _c^{l+1}&= \nu _0 + R_c^{l+1}, \\ \Sigma _c^{l+1}&= \Sigma _0 + \kappa _0 \mu _0 \mu _0^{{\text {T}}} - \kappa _c^{l+1} \mu _c^{l+1} \mu _c^{(l+1)\text {T}} + \sum _n r_{nc}^{l+1} y^n y^{n\text {T}}. \end{aligned}$$

Moreover, each \(\Sigma _c^l\) is positive definite if \(r_{nc}^l > 0\) for all n.
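
A direct, illustrative NumPy translation of these updates is sketched below, assuming the responsibilities \(\omega \) are stored as an array with entry \((c, n, n')\) and zeros on the \(n = n'\) diagonal; all names are ours.

```python
# Sketch of the M-step I updates in Proposition 2. Assumes omega has shape
# (C, N, N) with entry (c, n, n') and zeros on the n = n' diagonal, s has shape
# (N, C), and Y stacks the outputs y^n row-wise.
import numpy as np

def m_step_I(omega, s, Y, kappa0, mu0, nu0, Sigma0):
    C = omega.shape[0]
    Omega = omega.sum(axis=0)                            # Omega_{n n'} = sum_c omega_{c, n n'}
    # r_{nc} = sum_{n' != n} (omega_{c, n n'} + omega_{c, n' n} - s_{nc} Omega_{n' n})
    r = (omega.sum(axis=2) + omega.sum(axis=1)).T - s * Omega.sum(axis=0)[:, None]
    R = r.sum(axis=0)                                    # R_c = sum_n r_{nc}
    kappa = kappa0 + R
    mu = (kappa0 * mu0[None, :] + r.T @ Y) / kappa[:, None]
    nu = nu0 + R
    Sigma = np.empty((C,) + Sigma0.shape)
    for c in range(C):
        Sigma[c] = (Sigma0 + kappa0 * np.outer(mu0, mu0)
                    - kappa[c] * np.outer(mu[c], mu[c])
                    + (Y * r[:, c:c + 1]).T @ Y)
    return kappa, mu, nu, Sigma
```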

Proposition 3

(Variational posterior, M-step II) For all l, we have

$$\begin{aligned} q^l(\Lambda ) = {\text {W}}(\Lambda \mid \Lambda ^l, \eta _0),\end{aligned}$$

and updates are given by

$$\begin{aligned} \Lambda ^{-(l+1)} = \Lambda _0^{-1} + \sum _n \sum _{n' \ne n} (\Omega _{nn'}^{l+1} - t_{nn'}^{l+1})(x^n - x^{n'})(x^n - x^{n'})^{{\text {T}}}. \end{aligned}$$
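
The corresponding update may be sketched as follows, again assuming zero diagonals in \(\Omega \) and t; as discussed in Sect. 3.2.2, the resulting matrix is not guaranteed to be positive definite.

```python
# Sketch of the M-step II update in Proposition 3, assuming Omega and t are
# (N, N) arrays with zeros on the diagonal and X stacks the inputs x^n row-wise.
import numpy as np

def m_step_II(Omega, t, X, Lambda0_inv):
    diffs = X[:, None, :] - X[None, :, :]                    # differences x^n - x^{n'}
    weights = Omega - t                                      # weight of each (n, n') pair
    S = np.einsum('nm,nmi,nmj->ij', weights, diffs, diffs)   # weighted outer products
    return np.linalg.inv(Lambda0_inv + S)                    # new scale matrix Lambda^{l+1}
```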

3.2 Optimization of variational parameters

3.2.1 Linear programming for M-step I

Proposition 2 provides a sufficient condition for all scale matrices \(\Sigma _c^l\) to be positive definite, which is required for the inverse-Wishart distribution to be well-defined. We shall combine this with the local variational approximation method outlined in Bishop (2006) and Watanabe et al. (2011) to set up an optimization problem to be solved at each iteration for updating each row \(s_n\) in the second transition variational parameters s.

Recall, in particular, that the purpose of introducing s and the associated linearization (6) is to render variational posteriors analytically tractable through a mean-field approximation. We would like to optimize the accuracy of the approximation by maximizing the linearized complete-data pseudolikelihood \(p(\{(y^n, u^n, z^n)\}_n \mid \{x^n\}_n, \theta )\), which depends on s. Although the exact solution is trivial with \(\theta \) fixed, since \(\theta \) is uncertain, we instead resort to marginalizing out \(\theta \) according to the latest variational posterior \(q^l(\theta )\) and maximize the corresponding log-pseudoevidence

$$\begin{aligned} \log {{\mathbb {E}}}^{q^l(\theta )} p(\{(y^n, u^n, z^n)\}_n \mid \{x^n\}_n, \theta ). \end{aligned}$$

The form of this objective motivates the use of an EM algorithm. These EM updates will be interleaved with the variational Bayes updates, similarly to the scheme in Bishop and Svensén (2003), using the previous iterate as initialization and performing a single iteration at each l. In addition, we will constrain each M-step update such that the condition \(r_{nc}^{l+1} > 0\) from Proposition 2 holds. Removing terms not depending on s and noting that the problem separates in such a way that each \(s_n\) can be optimized separately, we can devise an update equation where \(s^{l+1}\) is set to the solution of the linear program

$$\begin{aligned} \begin{aligned}&\underset{s_n \;: \; s_n \ge 0, \; 1^{{\text {T}}} s_n = 1}{\text {minimize}} \quad \sum _c s_{nc} {{\mathbb {E}}}^{q^l(\mu _c, \Sigma _c)} \log {\text {N}}(y^n \mid \mu _c, \Sigma _c), \\&\;\; \text {subject to} \quad s_{nc} \le \frac{\sum _{n' \ne n}(\omega _{c, nn'}^l + \omega _{c, n'n}^l)}{\sum _{n' \ne n} \Omega _{n'n}^l} \quad \text {for each } c, \end{aligned} \end{aligned}$$
(8)

with

$$\begin{aligned} \begin{aligned}&{{\mathbb {E}}}^{q^l(\mu _c, \Sigma _c)} \log {\text {N}}(y^n \mid \mu _c, \Sigma _c) \\&\quad = \frac{1}{2} \sum _{i=1}^{\dim {\mathcal {Y}}} \psi \!\left( \frac{\nu _c^l + 1 - i}{2} \right) - \frac{1}{2} \log \det \Sigma _c^l\\&\qquad - \frac{1}{2} \frac{\dim {\mathcal {Y}}}{\kappa _c^l} - \frac{1}{2} \nu _c^l (y^n - \mu _c^l)^{{\text {T}}} \Sigma _c^{-l} (y^n - \mu _c^l). \end{aligned} \end{aligned}$$

A detailed motivation for (8) is given in Appendix B. Note that since

$$\begin{aligned} \sum _c \frac{\sum _{n' \ne n}(\omega _{c, nn'}^l + \omega _{c, n'n}^l)}{\sum _{n' \ne n} \Omega _{n'n}^l} = \frac{1 + \sum _{n' \ne n} \Omega _{n'n}^l}{\sum _{n' \ne n} \Omega _{n'n}^l} > 1, \end{aligned}$$

the constraints are always feasible.
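
For illustration, each such linear program can be solved with an off-the-shelf solver; the sketch below uses SciPy's linprog, with the expectation terms and constraint bounds precomputed and passed in as vectors whose names are ours.

```python
# Sketch: solving the per-n linear program (8) for s_n with scipy's linprog.
# expected_lognorm[c] = E^{q(mu_c, Sigma_c)} log N(y^n | mu_c, Sigma_c);
# upper_bounds[c] = the right-hand side of the constraint in (8).
import numpy as np
from scipy.optimize import linprog

def update_s_row(expected_lognorm, upper_bounds):
    C = len(expected_lognorm)
    res = linprog(c=expected_lognorm,                  # minimize sum_c s_nc * E[log N]
                  A_eq=np.ones((1, C)), b_eq=[1.0],    # 1^T s_n = 1
                  bounds=[(0.0, ub) for ub in upper_bounds],  # 0 <= s_nc <= bound
                  method="highs")
    return res.x
```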

3.2.2 Stochastic gradient method for M-step II

For the first transition variational parameters t, note that, unlike Proposition 2, Proposition 3 does not come with a guarantee that the double sum added to \(\Lambda _0^{-1}\) is positive definite—thus, the Wishart distribution constituting the variational posterior of \(\Lambda \) may not be well-defined. A solution is to use the unconstrained analogue

$$\begin{aligned} \underset{t_n \;: \; t_n \ge 0, \; 1^{{\text{ T }}} t_n = 1}{\text{ minimize }} \quad \sum _{n' \ne n} t_{nn'} {{\mathbb {E}}}^{q^l(\Lambda )} \log {\text{ N }}(x^n \mid x^{n'}, \Lambda ^{-1}) \end{aligned}$$
(9)

of the optimization problem (8), where, if the resulting inverse scale matrix \(\Lambda ^{-(l+1)}\) is not positive definite, its eigenvalues will be thresholded to be at least some small positive number. This is, however, somewhat unsatisfactory in that the eigenvalue thresholding step would become another layer of indirection in addition to the introduction of t. An alternative method would be to instead specify beforehand, motivated by Proposition 3, that \(q^l(\Lambda ) = {\text {W}}(\Lambda \mid \Lambda ^l, \eta _0)\) and find the variational parameter \(\Lambda ^l\) by maximizing \({\text {ELBO}}(q^l)\) without the linearization of the log-sum-exp expression associated with the first transition, according to the generic approach described in Blei et al. (2017). While it is always possible to approximate analytically intractable expectations by Monte Carlo methods, since the ELBO is an expectation taken with respect to \(q^l\) itself, one will in general need additional manipulations to obtain the corresponding estimate of the gradient. Various methods for such situations exist, examples including the score function gradient, which works for very general cases and can be used in conjunction with variance reduction techniques such as Rao–Blackwellization and control variates (Ranganath et al. 2014), and the reparameterization trick (Kingma and Welling 2014).

Since \(q^l(\Lambda )\) depends on \(\Lambda ^l\) continuously, we will use the reparameterization trick, which is understood to be the preferred method in such cases due to its lower variance compared to that of the score function gradient, even when using variance reduction (Kucukelbir et al. 2017). With the Cholesky factorization \(\Lambda ^l = L^l L^{l\textrm{T}}\), we will reparameterize according to the Bartlett decomposition (Anderson 1984), which states that \(\Lambda \sim {\text {W}}(\Lambda ^l, \eta _0)\) may be written as

$$\begin{aligned} \Lambda = L^l A A^{{\text {T}}} L^{l \textrm{T}}, \end{aligned}$$

where \(A = (A_{ii'})_{i, i' = 1}^{{\text {dim}} {\mathcal {X}}}\) is a lower triangular random matrix with all independent entries such that \(A_{ii}^2 \sim \chi ^2_{\eta _0 + 1 - i}\) and \(A_{ii'} \sim {\text {N}}(0, 1)\) for \(i > i'\). Extracting all terms depending on \(\Lambda ^l\) in the ELBO (see Appendix C) and using the Bartlett decomposition, we get the following result:

Proposition 4

(Stochastic gate optimization) If \(q^l(\Lambda ) = {\text {W}}(\Lambda \mid \Lambda ^l, \eta _0)\), where \(\Lambda ^l = L^l L^{l\textrm{T}}\) is the Cholesky factorization of \(\Lambda ^l\), updates are given by

$$\begin{aligned} L^{l+1} = \mathop {\text {arg\,min}}\limits _{L} \Bigg \{&-\eta _0 \log \det L \\&+ \sum _n {{\mathbb {E}}}^{p(A)} \log \Bigg ( \sum _{n' \ne n} \exp \Bigg ( -\frac{1}{2}(x^n - x^{n'})^{{\text {T}}} L A A^{{\text {T}}} L^{{\text {T}}} (x^n - x^{n'}) \Bigg ) \Bigg ) \\&+ \frac{\eta _0}{2} {\text {tr}}\!\Bigg ( L^{{\text {T}}}\Bigg ( \Lambda _0^{-1} + \sum _n \sum _{n' \ne n} \Omega _{nn'}^l (x^n - x^{n'})(x^n - x^{n'})^{{\text {T}}} \Bigg ) L \Bigg ) \Bigg \}. \end{aligned}$$

We will refer to the matrix in the parentheses between \(L^{{\text {T}}}\) and L as the center matrix. Note that since p(A) does not depend on L, a standard Monte Carlo approximation allows both the function value and its gradient to be estimated. This information may then be used in any gradient-based optimization algorithm to find the update \(L^{l+1}\) and thus \(\Lambda ^{l+1}\), serving as an alternative to Proposition 3. In this case, it is straightforward to derive an analogue to Proposition 1 by replacing the terms corresponding to the first transitions with their equivalents from the Monte Carlo estimation. Optionally, as evaluation of the involved quantities is relatively computationally expensive, we may reduce the variance of the standard Monte Carlo estimator by replacing it with an unbiased estimator with control variates introduced—for details, see Appendix D.1. Moreover, the center matrix, containing a double sum of outer vector products, may be evaluated in the alternative form displayed in Appendix D.2.
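
To illustrate the reparameterization, the sketch below draws the Bartlett factor A and forms a plain Monte Carlo estimate of the objective in Proposition 4 for a given Cholesky factor L. It is a NumPy illustration of the objective value only, with gradients left to an automatic differentiation framework as in the implementation described in Sect. 4; all names are ours.

```python
# Sketch of the stochastic objective in Proposition 4: Lambda ~ W(Lambda^l, eta0)
# is reparameterized via the Bartlett decomposition Lambda = L A A^T L^T, and the
# expectation over A is replaced by a plain Monte Carlo average.
import numpy as np
from scipy.special import logsumexp

def sample_bartlett_A(dim, eta0, rng):
    """One draw of the lower-triangular Bartlett factor A."""
    A = np.zeros((dim, dim))
    for i in range(dim):
        A[i, i] = np.sqrt(rng.chisquare(eta0 - i))   # A_ii^2 ~ chi^2_{eta0 + 1 - (i+1)}
        A[i, :i] = rng.normal(size=i)                # off-diagonal entries ~ N(0, 1)
    return A

def prop4_objective(L, X, Omega, Lambda0_inv, eta0, rng, num_samples=8):
    diffs = X[:, None, :] - X[None, :, :]            # differences x^n - x^{n'}
    # Monte Carlo estimate of the expected log-sum-exp term.
    lse_term = 0.0
    for _ in range(num_samples):
        A = sample_bartlett_A(X.shape[1], eta0, rng)
        M = L @ A                                    # so that Lambda = M M^T
        quad = np.einsum('nmi,ij,nmj->nm', diffs, M @ M.T, diffs)
        np.fill_diagonal(quad, np.inf)               # exclude the terms n' = n
        lse_term += logsumexp(-0.5 * quad, axis=1).sum()
    lse_term /= num_samples
    # Center matrix between L^T and L.
    center = Lambda0_inv + np.einsum('nm,nmi,nmj->ij', Omega, diffs, diffs)
    return (-eta0 * np.sum(np.log(np.diag(L)))       # -eta0 log det L for triangular L
            + lse_term
            + 0.5 * eta0 * np.trace(L.T @ center @ L))
```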

Algorithm 1

3.3 Posterior predictive distribution

Having obtained the variational posterior approximation \(q(\theta )\), prediction of the corresponding output \(y^*\) given an out-of-sample input \(x^*\) is determined by the posterior predictive distribution

$$\begin{aligned} \begin{aligned}&p(y^* \mid x^*, \{(x^n, y^n)\}_{n}) \\&\quad = {{\mathbb {E}}}^{p(\theta \; \mid \; \{(x^n, y^n)\}_{n})} p(y^* \mid x^*, \{(x^n, y^n)\}_{n}, \theta ) \\&\quad \approx {{\mathbb {E}}}^{q(\theta )} p(y^* \mid x^*, \{(x^n, y^n)\}_{n}, \theta ) \\&\quad = \sum _c \sum _{n} {{\mathbb {E}}}^{q(\Lambda )} \frac{{\text {N}}(x^* \mid x^n, \Lambda ^{-1})}{\sum _{n'} {\text {N}}(x^* \mid x^{n'}, \Lambda ^{-1})} \\&\qquad \times {{\mathbb {E}}}^{q(\mu _c, \Sigma _c)} \frac{{\text {N}}(y^n \mid \mu _c, \Sigma _c)}{\sum _{c'} {\text {N}}(y^n \mid \mu _{c'}, \Sigma _{c'})} {\text {N}}(y^* \mid \mu _c, \Sigma _c). \end{aligned} \end{aligned}$$

As the marginalization integrals with respect to \(q(\Lambda )\) and \(q(\mu _c, \Sigma _c)\) are intractable, we will resort to Monte Carlo approximations. Note, in particular, that one may sample from \(q(\Lambda )\) and each \(q(\mu _c, \Sigma _c)\) independently due to the assumed factorization. Letting \(\{\Lambda ^k\}_{k=1}^{K_{\textrm{g}}}\) and \(\{(\mu _c^k, \Sigma _c^k)\}_{k=1}^{K_{\textrm{e}}}\) be independent samples from \(q(\Lambda )\) and each of the \(q(\mu _c, \Sigma _c)\), respectively, we can write

$$\begin{aligned} p(y^* \mid x^*, \{(x^n, y^n)\}_{n}) \approx \frac{1}{K_{\textrm{e}} K_{\textrm{g}}} \sum _{k=1}^{K_{\textrm{e}}} \sum _{k'=1}^{K_{\textrm{g}}} \sum _c \sum _n \frac{{\text {N}}(x^* \mid x^n, \Lambda ^{-k'})}{\sum _{n'} {\text {N}}(x^* \mid x^{n'}, \Lambda ^{-k'})} \frac{{\text {N}}(y^n \mid \mu _c^k, \Sigma _c^k)}{\sum _{c'} {\text {N}}(y^n \mid \mu _{c'}^k, \Sigma _{c'}^k)} {\text {N}}(y^* \mid \mu _c^k, \Sigma _c^k). \end{aligned}$$

This results in a \(K_{\textrm{e}} C\)-component Gaussian mixture as an approximate posterior predictive distribution.
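
For illustration, this Monte Carlo approximation can be assembled by averaging the fixed-\(\theta \) predictive density over the posterior samples, as in the sketch below, which reuses the predictive_density function from the sketch in Sect. 2.1 and assumes the samples are supplied as lists.

```python
# Sketch of the Monte Carlo posterior predictive: average the fixed-theta
# predictive density over independent draws of Lambda and of the expert
# parameters from their variational posteriors. Requires predictive_density
# from the earlier sketch; expert_samples is a list of (mus, Sigmas) tuples.
import numpy as np

def mc_posterior_predictive(y_star, x_star, X, Y, Lambda_samples, expert_samples):
    vals = [predictive_density(y_star, x_star, X, Y, Lam, mus, Sigmas)
            for Lam in Lambda_samples
            for (mus, Sigmas) in expert_samples]
    return float(np.mean(vals))                     # 1/(K_e K_g) double sum
```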

Table 1 Comparison of evaluated performance metrics between the proposed similarity-based mixture-of-experts (SBMoE) model, the conditional Dirichlet process mixture model (C-DPMM) and the baseline Gaussian process (GP) model. For the synthetic datasets, the KL divergence, the Hellinger distance and the total variation distance were used, whereas the mean negative log-likelihood was used for the other datasets due to the true predictive distribution being unknown

4 Computational study

To illustrate the advantages of the proposed model, we evaluate its performance when applied to five datasets. The first two are artificially generated, intended to demonstrate the ability of the mixture-of-experts method to adapt to complex dependencies between the inputs and their corresponding predictive distributions and capture characteristics such as skewness and multimodality. The third is a dataset of dose statistics of radiation therapy treatment plans for postoperative prostate cancer, which is small, noisy and relatively high-dimensional, intended to showcase the stability of the estimated uncertainties even in such cases. The fourth dataset is the California housing dataset (Pace and Barry 1997), consisting of features for California housing districts along with geographic location—here, focusing on multimodal outputs, the task is to predict the latitude–longitude pairs for each district given its other features. The fifth dataset consists of infrared spectroscopy measurements for prediction of soil functional properties at different locations in Africa, aimed at illustrating the case of probabilistically predicting more than two output variables.

In the following computational study, attention was restricted to the stochastic gradient method described by Proposition 4 in Sect. 3.2.2 for fitting the variational parameters associated with the gate, as the projection method described in the same section was found to be insufficiently accurate. The method was implemented in Python—in particular, the stochastic gradient optimization used the automatic differentiation features and the Adam optimizer (Kingma and Ba 2020) in Tensorflow 2.3.1. The hyperparameters of the priors were set and the variational parameters were initialized according to the default implementation outlined in Appendix E—in particular, an agglomerative clustering algorithm (Pedregosa et al. 2011) run on the standardized outputs \(\{y^n\}_n\) was used to initialize the expert means. For all experiments, the maximum number of iterations in the mean-field variational Bayes algorithm was set to 20, each outer iteration comprising 50 stochastic gradient sub-iterations for the gate optimization. An early-stop mechanism was employed, terminating the algorithm if the inner optimization leaves the gate scale matrix non-significantly changed for three consecutive iterations—here, a non-significant change is defined as a run for which the inner iteration number and objective value have positive estimated correlation and for which non-correlation cannot be rejected at significance level 0.01. All of the reported execution times are with respect to a computational service setup with a 12-core Intel Cascade Lake CPU platform, 85 GB of RAM and an NVIDIA A100 GPU with 40 GB of memory.

A comparison was made between the results obtained from the mixture-of-experts model and those from a conditional Gaussian DPMM (Cruz-Marcelo et al. 2013), with posterior inference performed using a variational Bayes algorithm (Blei and Jordan 2006; Pedregosa et al. 2011), as well as a Gaussian process baseline model (Rasmussen and Williams 2006; GPy 2012). For the DPMM, we used a Gaussian likelihood \(p(y \mid \theta ) = {\text {N}}(y \mid \mu , \Sigma )\) with \(\theta = (\mu , \Sigma )\), \(\theta \mid G \sim G\) and \(G \sim {\text {DP}}(\alpha , G_0)\), where \(\alpha = 1\) and \(G_0\) was set to a conjugate normal–inverse-Wishart distribution, with hyperparameters selected analogously to the priors in our mixture-of-experts model. Using the stick-breaking representation of the Dirichlet process and a mean-field variational Bayes algorithm, the posterior predictive distribution may be Monte Carlo–approximated by a truncated Gaussian mixture (Blei and Jordan 2006). In particular, regarding each input–output concatenation \((x^n, y^n)\) in our dataset \(\{(x^n, y^n)\}_n\) as a data point in the DPMM, the posterior predictive distribution \(p(x^*, y^* \mid \{(x^n, y^n)\}_n)\) of a new pair \((x^*, y^*)\) is mixture-Gaussian, and the associated conditional distribution \(p(y^* \mid x^*, \{(x^n, y^n)\}_n)\) is thus again mixture-Gaussian (Cruz-Marcelo et al. 2013). On the other hand, the Gaussian process baseline model used a radial basis function covariance \(k: {\mathcal {X}}^2 \rightarrow {\mathbb {R}}\) with automatic relevance determination (Neal 1994), written as

$$\begin{aligned} k(x, x') = \sigma ^2 {\text {exp}}\!\left( -\sum _i w_i (x_i - x'_i)^2 \right) , \end{aligned}$$

where the hyperparameters \(\{w_i\}_i\) and \(\sigma ^2\) were fitted by evidence maximization. For the second and third datasets, where the output is multidimensional, Gaussian processes were fitted for the marginal distributions of the components separately, assuming independence between the components—this is done purely for simplicity, as the Gaussian process model mainly serves as a baseline method in our experiments. Apart from the KL divergence, the Hellinger distance \(d_{{\text {H}}}(p, q) = \Vert \sqrt{p} - \sqrt{q} \Vert _{L^2} / \sqrt{2}\) and the total variation distance \(d_{{\text {TV}}}(p, q) = \Vert p - q \Vert _{L^1} / 2\) were used as evaluation metrics for the test cases using synthetic data. The results of the comparison are shown in Table 1.
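
For reference, the sketch below shows how these distances may be evaluated numerically for densities tabulated on a regular grid (here in one dimension); the densities and grid are placeholders.

```python
# Sketch: numerical evaluation of the Hellinger and total variation distances
# between two densities tabulated on a regular grid, as used for the synthetic
# test cases. Grid spacing and densities are placeholders.
import numpy as np

def hellinger(p, q, dx):
    return np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2) * dx)

def total_variation(p, q, dx):
    return 0.5 * np.sum(np.abs(p - q)) * dx

grid = np.linspace(-5.0, 5.0, 1001)
dx = grid[1] - grid[0]
p = np.exp(-0.5 * grid ** 2) / np.sqrt(2 * np.pi)            # N(0, 1)
q = np.exp(-0.5 * (grid - 1.0) ** 2) / np.sqrt(2 * np.pi)    # N(1, 1)
print(hellinger(p, q, dx), total_variation(p, q, dx))
```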

4.1 One-dimensional synthetic data

In this example, the observed data \(\{(x^n, y^n)\}_{n=1}^N\) consists of \(N = 2000\) exchangeable copies of \((x, y) = ((x_1, x_2), y)\) following the generative model

$$\begin{aligned} x&\sim {\text {LN}}\!\left( \begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1 & 0.5 \\ 0.5 & 1 \end{bmatrix} \right) , \\ \tau&\sim {\text {Be}}(0.3), \\ \zeta \mid x&\sim \Gamma (x_1, x_2), \\ y&= \log (\zeta + 0.4\tau + 0.1), \end{aligned}$$

with \(\tau \) independent of x and \(\zeta \), where we use \({\text {LN}}\), \({\text {Be}}\) and \(\Gamma \) to denote the log-normal, the Bernoulli and the gamma distribution (with shape–rate parameterization), respectively. As such, the predictive distributions are both skewed and possibly bimodal. The intent of choosing a relatively large data size was to demonstrate the capability of the proposed mixture-of-experts model to approach the complex predictive distributions with growing data—for smaller datasets, while the methods still work well, we found it hard to evaluate and compare the methods in any meaningful way. Using \(C = 32\) experts for our proposed mixture-of-experts model (note that the expert locations are fixed with respect to x, thus the need for a relatively high number), approximate posteriors were computed using the proposed mean-field variational Bayes algorithm. The training of the mixture-of-experts, Gaussian process and conditional DPMM models took 553, 14 and 75 s on the aforementioned computational setup. A comparison of the estimated predictive distributions is shown in Fig. 1 alongside a kernel density estimate of the true predictive density using 5000 samples from the true predictive distribution, a Gaussian kernel and Scott’s rule (Scott 1992) for bandwidth selection. The evaluated performance metrics are summarized in Table 1. It is apparent that our mixture-of-experts model performed better than the conditional DPMM—which was, in turn, better than the Gaussian process model—in terms of performance metrics as well as visual inspection. Specifically, the flexibility of our model allows for much better fit to the bimodal and skewed distributions even compared to the conditional DPMM, which was often not able to capture bimodal behavior despite also having mixture-Gaussian predictive distributions.
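
For concreteness, the generative model above may be sampled as in the following sketch, using the shape–rate parameterization of the gamma distribution as stated; the random seed is our own choice.

```python
# Sketch of the one-dimensional synthetic generative model: log-normal inputs,
# a Bernoulli mode indicator and a gamma variable with input-dependent
# shape and rate, combined into a skewed, possibly bimodal output.
import numpy as np

rng = np.random.default_rng(0)

def sample_dataset(N):
    cov = np.array([[1.0, 0.5], [0.5, 1.0]])
    x = np.exp(rng.multivariate_normal(np.zeros(2), cov, size=N))   # LN(0, cov)
    tau = rng.binomial(1, 0.3, size=N)                               # Be(0.3)
    zeta = rng.gamma(shape=x[:, 0], scale=1.0 / x[:, 1])             # Gamma(x1, x2), shape-rate
    y = np.log(zeta + 0.4 * tau + 0.1)
    return x, y

X, Y = sample_dataset(2000)
```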

Fig. 1

Comparison between a kernel density estimate of the true predictive density \(p(y^* \mid x^*, \{(x^n, y^n)\}_n)\) and the estimated counterparts from the mixture-of-experts model, the conditional DPMM and the Gaussian process, shown for six out-of-sample inputs \(x^*\) of the one-dimensional synthetic data. The gray bars below the graph represent the training data points \(\{y^n\}_n\), and the blue bars represent samples from the true predictive distribution

4.2 Two-dimensional synthetic data

To illustrate the advantages of our model in estimating predictive distributions with inter-component dependencies, we now use a 32-dimensional input \(x = (x_i)_{i=1}^{32}\) and a two-dimensional output \(y = (y_1, y_2)\). Define the transformation \(\varphi (\xi ) = ({\text {N}}(0.7 i \mid \xi , 4))_{i=0}^7\), and denote by \({\mathcal {F}}\) the discrete Fourier transform. The observed data consists of \(N = 2000\) data points drawn from the generative model

$$\begin{aligned}&\eta = (\eta _1, \eta _2) \sim {\text {LN}}\!\left( \begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1 & 0.5 \\ 0.5 & 1 \end{bmatrix} \right) , \\&x = \left( {\text {Re}} {\mathcal {F}}(\varphi (\eta _1)), {\text {Im}} {\mathcal {F}} (\varphi (\eta _1)), {\text {Re}} {\mathcal {F}} (\varphi (\eta _2)), {\text {Im}} {\mathcal {F}} (\varphi (\eta _2)) \right) , \\&\tau \sim {\text {Be}}(0.5), \\&\zeta = (\zeta _{ii'})_{i, i'=1}^2 \mid \eta \sim {\text {IW}}\!\left( \begin{bmatrix} \eta _1 & 0 \\ 0.5 & \eta _2 \end{bmatrix} \begin{bmatrix} \eta _1 & 0 \\ 0.5 & \eta _2 \end{bmatrix}^{{\text {T}}}\!, 3 \right) , \\&y = (\log \zeta _{11}, (2\tau - 1) \log \zeta _{22}), \end{aligned}$$

where \(\tau \) is independent of \(\eta \) and \(\zeta \). In particular, the log-normal generator \(\eta \) is observed through its obfuscated counterpart x—the map \(\varphi \) creates spikes centered at its inputs \(\eta _1\), \(\eta _2\) represented as eight-dimensional vectors, whose real and imaginary Fourier components are concatenated into a 32-dimensional x. Given \(\eta \), we then draw \(\zeta \) as an inverse-Wishart matrix whose diagonal components are log-transformed and put into y, with the sign of the second component flipped with probability 0.5. Again, although all methods remained stable for smaller datasets, a relatively large data size was found necessary for the comparison to clearly showcase key qualities of the methods. Using \(C = 64\) classes and running the mean-field variational Bayes algorithm, we obtain the results visualized in Fig. 2, where estimated densities for our mixture-of-experts model, the conditional DPMM and the per-component Gaussian process model are shown in comparison to a kernel density estimate of the true predictive density, using 10000 samples from the true predictive distribution and again a Gaussian kernel and Scott’s rule for bandwidth selection. The training of the mixture-of-experts, Gaussian process and conditional DPMM models took 1116, 311 and 245 s. The evaluated performance metrics are summarized in Table 1. Again, besides having better validation metrics, our model is clearly more successful than the Gaussian process model in capturing the characteristics of the bimodal shape of the true predictive distribution. More surprisingly, the conditional DPMM, while also being able to recognize bimodal distributions, yielded much worse performance metrics than both our model and the Gaussian process model. This may be explained by the conditional DPMM’s generative modeling using mixture models, which presents difficulties in settings in which inputs are of relatively high dimension but sparsely distributed—note that each 32-dimensional x is uniquely determined by the associated two-dimensional \(\eta \). In turn, this may lead to the underestimation of uncertainty we observe for the conditional DPMM, causing the considerably worse performance metrics. In contrast, the mixture-of-experts and Gaussian process models are both discriminative and similarity-based and therefore relatively insensitive to the complexities of the input space.
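
The corresponding sampler is sketched below, where we interpret the second argument of \({\text {N}}(\cdot \mid \xi , 4)\) as a variance and use SciPy's inverse-Wishart; the seed is our own choice.

```python
# Sketch of the two-dimensional synthetic generative model: phi builds spikes
# from the log-normal generator eta, the DFT components form the 32-dimensional
# input, and y collects log-transformed inverse-Wishart diagonal entries with a
# random sign flip of the second component.
import numpy as np
from scipy.stats import norm, invwishart

rng = np.random.default_rng(0)

def phi(xi):
    return norm.pdf(0.7 * np.arange(8), loc=xi, scale=2.0)   # variance 4

def sample_pair():
    cov = np.array([[1.0, 0.5], [0.5, 1.0]])
    eta = np.exp(rng.multivariate_normal(np.zeros(2), cov))
    f1, f2 = np.fft.fft(phi(eta[0])), np.fft.fft(phi(eta[1]))
    x = np.concatenate([f1.real, f1.imag, f2.real, f2.imag])  # 32-dimensional input
    tau = rng.binomial(1, 0.5)
    S = np.array([[eta[0], 0.0], [0.5, eta[1]]])
    zeta = invwishart.rvs(df=3, scale=S @ S.T, random_state=rng)
    y = np.array([np.log(zeta[0, 0]), (2 * tau - 1) * np.log(zeta[1, 1])])
    return x, y
```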

Fig. 2
figure 2

Comparison between a kernel density estimate of the true predictive density \(p(y^* \mid x^*, \{(x^n, y^n)\}_n)\) and the estimated counterparts from the mixture-of-experts model, the conditional DPMM and the Gaussian process, shown for six out-of-sample inputs \(x^*\) of the two-dimensional synthetic data. The gray dots represent the training data points \(\{y^n\}_n\), and the blue dots represent samples from the true predictive distribution

4.3 Radiation therapy data

We also demonstrate our method on a dataset of historically delivered radiation therapy treatment plans for postoperative prostate cancer patients from the Iridium Cancer Network in Antwerp, Belgium. Prediction of radiation dose statistics based on patient geometry is of particular interest in fields such as automated treatment planning and quality assurance (Ge and Wu 2019). The dataset consists of 94 treatment plans, for which features have been extracted from patient images and dose statistics have been computed; 84 of these were used as training data. For our purposes, attention was restricted to the bladder region—the features we used are signed distance transforms computed from the bladder region to itself, the prostate planning target volume, the seminal vesicles planning target volume and the rectum region, binned into a total of 44 features; the dose statistics we used are the dose-at-volume values at 40 and 20 percent in the bladder region, corresponding to the 0.6- and 0.8-level quantiles of the dose delivered to the region. Using \(C = 32\) mixture classes, training the mixture-of-experts, Gaussian process and conditional DPMM models took 393, 2 and 1 s, respectively. The results of the comparison to the conditional DPMM and the Gaussian process model are presented in Table 1 and Figs. 3 and 4. In terms of mean negative log-likelihood over the test dataset, the mixture-of-experts model performed comparably to the conditional DPMM and only slightly better than the Gaussian process model. More importantly, however, it managed to follow the S-shaped distribution of the data points in its predictions, whereas the Gaussian process model sometimes predicted far outside the range of reasonable values. From Fig. 4, we also see that the learned distance metric \(d_{{\mathcal {X}}}\) produced a varying number of near neighbors, implying an accordingly varying degree of uncertainty. It is especially interesting that this was possible using such a small training dataset with relatively high input dimension, which we interpret as a sign of robustness against data scarcity.
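To make the relationship between the targets and the delivered dose concrete, below is a minimal sketch of computing dose-at-volume as a quantile of a voxel-wise dose distribution; the function name and the synthetic gamma-distributed doses are purely illustrative.

```python
import numpy as np

def dose_at_volume(dose_voxels: np.ndarray, volume_fraction: float) -> float:
    """Highest dose level received by at least the given fraction of the
    region, i.e. the (1 - volume_fraction)-level quantile of the voxel doses."""
    return float(np.quantile(dose_voxels, 1.0 - volume_fraction))

# Synthetic voxel doses for illustration; dose-at-volume at 40 and 20 percent
# corresponds to the 0.6- and 0.8-level quantiles, as for the bladder region above.
doses = np.random.default_rng(0).gamma(shape=20.0, scale=2.0, size=10_000)
dv40, dv20 = dose_at_volume(doses, 0.40), dose_at_volume(doses, 0.20)
```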

Fig. 3
figure 3

Comparison between the estimated predictive densities \(p(y^* \mid x^*, \{(x^n, y^n)\}_n)\) from the mixture-of-experts model, the conditional DPMM and the Gaussian process, shown for six random out-of-sample inputs \(x^*\) of the radiation therapy data. The gray dots represent the training data points \(\{y^n\}_n\), and the blue diamond shows the ground truth \(y^*\). Axes represent fraction of the prescription dose

Fig. 4
figure 4

Visualization of the computed distances \(d_{{\mathcal {X}}}(x^*, x^n)\) to each training input \(x^n\), shown for the same out-of-sample inputs \(x^*\) of the radiation therapy data as in Fig. 3. A darker purple color corresponds to a smaller distance. Axes represent fraction of the prescription dose

4.4 California housing data

Here, we consider the California housing price dataset (Pace and Barry 1997), which comprises variables for 20640 housing districts in California, United States. In particular, for each district, we have the longitude–latitude pair of its geographical location, the median age, the total number of rooms, the total number of bedrooms, the population, the number of households, the median income, the median house value and the ocean proximity. After removing irregular data points, dropping the categorical ocean proximity variable and replacing the total numbers of rooms and bedrooms by the corresponding means per household, 18299 data points remain, with seven input and two output dimensions. Furthermore, subsampling uniformly at random, we use 3000 training and 100 test data points for the numerical experiments. For our purposes, we try to predict the longitude–latitude pair of each out-of-sample district given the other seven features. This is a nontrivial task for several reasons: the signal-to-noise ratio of the available input features is relatively low, the map from features to longitude–latitude pair is likely rather complex, and the predictive distributions are in general multimodal, as similar housing districts may occur in areas far away from each other. Using \(C = 32\), training of the mixture-of-experts model took 194 s, compared to 196 s for the Gaussian process and 229 s for the conditional DPMM. Table 1 contains performance metrics for the different methods, and Fig. 5 shows example predictions for six random out-of-sample test inputs. We see that the mixture-of-experts model and the conditional DPMM, in contrast to the Gaussian process, are able to produce multimodal predictive distributions following the geographical shape of the data, successfully identifying housing district clusters such as cities. Moreover, the mixture-of-experts model performs better in terms of mean negative log-likelihood than the other models. In particular, as was the case for the two-dimensional synthetic dataset, the conditional DPMM tends to slightly underestimate the predictive uncertainties.
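As an illustration of the preprocessing described above, the following sketch assumes a CSV file (hypothetically named housing.csv) with the column names used in commonly distributed versions of the dataset; the actual source and column labels may differ.

```python
import numpy as np
import pandas as pd

# Hypothetical file; any CSV variant with the columns described above will do.
df = pd.read_csv("housing.csv").dropna()

# Replace total room/bedroom counts by per-household means; the categorical
# ocean proximity column is simply not selected into the feature matrix.
df["rooms_per_household"] = df["total_rooms"] / df["households"]
df["bedrooms_per_household"] = df["total_bedrooms"] / df["households"]
X = df[["housing_median_age", "rooms_per_household", "bedrooms_per_household",
        "population", "households", "median_income", "median_house_value"]].to_numpy()
Y = df[["longitude", "latitude"]].to_numpy()

# Uniform subsampling into 3000 training and 100 test points
rng = np.random.default_rng(0)
idx = rng.permutation(len(df))
X_train, Y_train = X[idx[:3000]], Y[idx[:3000]]
X_test, Y_test = X[idx[3000:3100]], Y[idx[3000:3100]]
```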

Fig. 5
figure 5

Predictive distributions of geographic location for the SBMoE, C-DPMM and GP models versus ground truth for the California housing data

4.5 Africa soil data

In this last experiment, we consider the prediction of soil functional properties from infrared spectroscopy measurements as a means of rapid, low-cost analysis of soil samples. Especially in data-sparse regions such as Africa, such tools are important for planning sustainable agricultural intensification and natural resource management. The dataset (Vågen et al. 2010) contains 1157 data points, each having 3578 dimensions of diffuse-reflectance infrared spectroscopy measurements at different wavelengths, 16 features related to the geography of the location at which the sample was collected, and five output variables comprising the soil organic carbon, the pH value and the calcium, phosphorus and sand content—for more details, see Vågen et al. (2010). Using 1000 data points for training and reserving the rest for testing, the 3578-dimensional spectroscopy measurements are dimensionality-reduced by a multi-level discrete wavelet transform (Mallat 2008) with a Daubechies 4 wavelet into 74-dimensional representations—thus, the problem considered amounts to having \({\text {dim}} {\mathcal {X}} = 90\) and \({\text {dim}} {\mathcal {Y}} = 5\). With \(C = 64\) classes, training the mixture-of-experts model, the Gaussian process and the conditional DPMM took 803, 1279 and 24 s, respectively. Table 1 shows the evaluated performance metrics, and Fig. 6 shows the single and pairwise marginal predictive distributions for a randomly chosen test input. Again, the mixture-of-experts model performs better than the competitor models, with the conditional DPMM obtaining a particularly poor score. We can see from the example prediction in Fig. 6 that our mixture-of-experts model is able to follow the data distribution well and predict the ground truth accurately, with varying degrees of uncertainty for the different output dimensions—in comparison, the Gaussian process is only able to output an ill-fitted Gaussian predictive distribution, and the predictions from the conditional DPMM are mostly far away from the ground truth. Indeed, it is reasonable to believe that the particularly poor performance of the conditional DPMM is due to the relatively high dimension of the data: fitting a mixture model in \({\text {dim}} {\mathcal {X}} + {\text {dim}} {\mathcal {Y}} = 95\) dimensions and conditioning on a 90-dimensional input is likely to be extremely sensitive to calibration errors.
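A minimal sketch of this dimensionality reduction is shown below, assuming the PyWavelets package; the decomposition level and padding mode are not specified in our description, so they are chosen here for illustration only and will not in general reproduce exactly 74 dimensions.

```python
import numpy as np
import pywt  # PyWavelets

# Illustrative 3578-dimensional spectroscopy measurement
spectrum = np.random.default_rng(0).normal(size=3578)

# Multi-level discrete wavelet transform with a Daubechies 4 wavelet; keeping
# only the coarsest approximation coefficients yields a low-dimensional
# representation whose length depends on the level and padding mode.
coeffs = pywt.wavedec(spectrum, wavelet="db4", level=5)
reduced = coeffs[0]
print(reduced.shape)
```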

Fig. 6
figure 6

Predictive distributions for the SBMoE, C-DPMM and GP models versus ground truth for the Africa soil data

5 Discussion

Motivated by the disadvantages of commonly used parametric models when applied to data with complex input–output relationships, in this paper, we have built upon previous work on probabilistic k-nearest neighbors by Holmes and Adams (2002), Cucala et al. (2009) and Friel and Pettitt (2011) by extending their ideas to a regression setting, using Gaussian mixtures as a flexible means of representing predictive distributions. We use a conditionally specified model in which full conditionals are defined and the joint likelihood is replaced by a pseudolikelihood in order to render further computations tractable. In particular, in contrast to Friel and Pettitt (2011), we also regard the precision matrix \(\Lambda \) in the distance metric \(d_{{\mathcal {X}}}\) as a model parameter on which posterior inference may be performed. With a nonparametric approach, we avoid the need to explicitly learn the mapping between input and output as, for example, in conventional mixture-of-experts models such as those described in Jacobs et al. (1991), Bishop and Svensén (2003) and Xu et al. (1995), all while maintaining a Bayesian approach. Also in contrast to, e.g., the probabilistic k-nearest neighbor models by Holmes and Adams (2002), Cucala et al. (2009) and Friel and Pettitt (2011) as well as the DPMM-based models by, e.g., De Iorio et al. (2009), Dunson and Park (2008), Jara and Hanson (2011) and Cruz-Marcelo et al. (2013), we exploit the mixture-Gaussian structure of our model and base our parameter inference on a mean-field variational Bayes algorithm, where local variational approximations are introduced and the associated variational parameters are optimized by linear programming and stochastic gradient descent methods. Thus, we avoid the need for input-to-parameter modeling present in DPMM-based models, which in our case is a multivariate regression problem from dimension \({\text {dim}} {\mathcal {X}}\) to \({\text {dim}} {\mathcal {Y}} + ({\text {dim}} {\mathcal {Y}})^2\)—in principle, potentially even harder than the original problem from dimension \({\text {dim}} {\mathcal {X}}\) to \({\text {dim}} {\mathcal {Y}}\). Furthermore, we avoid the use of MCMC methods, which scale poorly with model size and data dimensionality. In contrast to the conditional DPMM, which models the distribution of the inputs generatively by a mixture model on \({\mathcal {X}} \oplus {\mathcal {Y}}\), our method relies on discriminative modeling by comparing similarities between inputs. The computational study demonstrates, on several synthetic as well as real-world datasets, the ability of our model to capture irregular, multivariate predictive densities while quantifying uncertainties accurately for medium-sized to very small datasets. Specifically, we showed that our mixture-of-experts model outperforms the conditional DPMM and the Gaussian process baseline in certain settings—for example, those in which the input space is relatively high-dimensional yet sparse and thus unsuitable for generative modeling using mixture models. While the mixture-of-experts model is generally more computationally demanding than the conditional DPMM and the Gaussian process, the execution times for training it on the given computational setup were short for a high-dimensional Bayesian model.

Based on the method we describe in this paper, there are ample opportunities for interesting future work. A drawback of the method is its lack of scalability to massive data, an issue common to all nonparametric models in which the training data is used explicitly in the predictive pipeline. In our case, the main difficulty stems from the need to compute the posterior probabilities \(\omega _{c, nn'}^l\) at each iteration l, which requires storage in an array of \(C N^2\) elements. One possible approach to addressing this is to use Bayesian coresets (Huggins et al. 2016), small weighted subsets of the data constructed to summarize it optimally; another is to take inspiration from the sparse pseudo-input Gaussian process literature (Quiñonero-Candela and Rasmussen 2005). A further drawback is the subtle but important assumption that the input space \({\mathcal {X}}\) is a vector space—while this holds for all datasets considered in our computational study, handling e.g. image data would require some form of preprocessing by a feature-extracting transformation \(\phi : {\mathcal {I}} \rightarrow {\mathcal {X}}\) from the image space \({\mathcal {I}}\) to the vector space \({\mathcal {X}}\). If, for example, \(\phi \) is a neural network, one can optimize the network weights using the same optimization problem as in Proposition 4 by regarding the weights as hyperparameters. One may also investigate sparse representations of the scale matrix \(\Lambda ^l\) for the case of \({\text {dim}} {\mathcal {X}}\) being too large to store the full Cholesky factorization of \(\Lambda ^l\). Lastly, to address the problem of choosing an appropriate number of experts, yet another direction of future work would be to explore the possibility of combining the present model with Dirichlet process mixtures, further conforming to the philosophy of Bayesian nonparametrics.
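To make the storage bottleneck concrete, the following back-of-the-envelope calculation uses the sizes from the two-dimensional synthetic experiment; the assumption of double-precision storage is ours.

```python
# Memory footprint of the C x N x N array of posterior probabilities omega,
# assuming 8-byte double-precision entries.
C, N = 64, 2000                        # sizes from the 2-D synthetic experiment
bytes_needed = C * N * N * 8
print(f"{bytes_needed / 1e9:.2f} GB")  # roughly 2 GB
```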