Variational inference for continuous time causal learning

We use Variational Bayes to recover causal structures and model human learning in a context where variables are represented by continuous-time series. This approach allows us to model learning about hyperparameters which previous research has assumed to be known by participants. We test two factorisations for the approximate posterior: (1) a normative one, which keeps the true joint distribution, and (2) a mean-field approximation. Variational models remain effective at recovering the correct causal structure and also manage to estimate hyperparameters. Both factorisations reduce the model's ability to disentangle direct and indirect effects, a well-documented human error which we replicate here with 420 participants. This initial agreement of Variational Bayes with participants' behaviour encourages further modelling attempts, specifically those considering how interventions shape learning under variational constraints.


Introduction
Active causal learning is concerned with settings where learning about a causal structure happens through the feedback between the agent's interventions and the resulting observations. From a young age, humans actively learn whenever they manipulate objects to understand their inner workings, such as when learning to ride a bike or play a video game. To be successful, people must both infer the system's causal structure and evaluate and select policies. Here we focus on the former.
Bayesian methods are a standard framework for modelling human learning and reasoning under uncertainty (Oaksford & Chater, 2007). As exact inference is often intractable, two families of approximation methods are generally used: sampling methods such as Markov Chain Monte Carlo (MCMC), and variational methods, which approximate the posterior by making simplifying assumptions about its factorisation (Sanborn, 2017). In discrete time active learning, MCMC methods have already yielded significant insights into how people learn (Bramley, Dayan, Griffiths, & Lagnado, 2017; Davis & Rehder, 2020; Fränken, Theodoropoulos, & Bramley, 2022). Meanwhile, variational inference has primarily been studied in the context of active inference, a general theory of perception and action (Friston et al., 2016; Parr & Friston, 2019). While Variational Bayes has been criticised for neglecting core causal inference principles (Bruineberg, Dolega, Dewhurst, & Baltieri, 2021; Btesh, Bramley, & Lagnado, 2022), it provides a strong formal framework for modelling, and there is a paucity of direct applications to human learning.
In continuous time, where data are abundant, modelling using such approximations has yet to emerge. Davis and Rehder (2020) and Rehder, Davis, and Bramley (2022) have shown that participants are better fit by computational models which, like variational methods, make simplifying assumptions about the factorisation of the joint distribution, such as the local computations (LC) model. It assumes independence between all parameters in the generative model, at the cost of struggling to disambiguate direct from indirect effects. They circumvented the issue of intractability by assuming that participants hold fixed estimates of the parameters which do not pertain directly to the causal structure. In this paper, we relax this assumption and put variational inference to the test. We use existing data from four experiments conducted in the Ornstein-Uhlenbeck (OU) network setting, in which 420 participants were tasked with actively learning about five benchmark causal structures: collider, common cause, and standard, confounded and damped chains (see Figure 1). We compare the accuracy of the variational versions of the normative and LC (i.e., mean-field) models, and then discuss specific error patterns observed in graphs with indirect effects.

Models
Ornstein-Uhlenbeck (OU) networks are collections of OU processes, which are continuous-time stationary Gauss-Markov processes. Beyond their usual θ and σ parameters, respectively representing the strength of the drift term and the variance of the processes, the linear relationships between the OU processes are parametrised by a causal graph and its causality matrix Γ. This is an augmented version of the graph's adjacency matrix where the non-diagonal entries represent the strength of the pairwise causal links between variables, i.e. Γ_ij = x_i → x_j with x_i → x_j ∈ {−1, −0.5, 0, 0.5, 1}. If x_t represents the state of each variable at time t, their distribution at time t + dt is a multivariate Gaussian whose mean depends on x_t, θ and Γ, and whose variance depends on σ and dt. To avoid the marginalisation of θ and σ, previous work has assumed them to be known. Using variational Bayes, i.e. approximating the joint posterior by a simpler distribution over factors, we allow for estimation of θ and σ as well as Γ's non-diagonal entries.

Figure 2: Mean proximity to participants' judgements and mean accuracy for the standard and variational versions of the normative and LC models. Both measures use a normalised Euclidean distance between the maximum a posteriori graph for each tested model and, respectively, participants' judgements (first row) and the ground truth (second row).
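The generative dynamics can be sketched as a simple Euler-Maruyama simulation. This is a minimal illustration, not the authors' implementation: the exact way Γ and θ enter the drift, and how σ scales with dt, are assumptions based on the standard OU network parametrisation.

```python
import numpy as np

def simulate_ou_network(Gamma, theta=0.5, sigma=3.0, dt=0.2, T=100, seed=0):
    """Euler-Maruyama sketch of an OU network (assumed parametrisation).

    Gamma: causality matrix; off-diagonal Gamma[i, j] is the strength of
    the causal link x_i -> x_j, with values in {-1, -0.5, 0, 0.5, 1}.
    Each variable drifts towards the Gamma-weighted sum of its parents,
    at a rate theta, with Gaussian noise scaled by sigma.
    """
    rng = np.random.default_rng(seed)
    n = Gamma.shape[0]
    x = np.zeros((T, n))
    for t in range(1, T):
        # mean reversion towards the parents' combined influence
        drift = theta * (Gamma.T @ x[t - 1] - x[t - 1])
        noise = sigma * np.sqrt(dt) * rng.standard_normal(n)
        x[t] = x[t - 1] + drift * dt + noise
    return x

# Chain A -> B -> C with positive causal links
Gamma = np.array([[0., 1., 0.],
                  [0., 0., 1.],
                  [0., 0., 0.]])
traj = simulate_ou_network(Gamma)
```

Under this sketch, interventions would correspond to clamping one column of the state while the other variables continue to evolve.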
Variational Bayes casts inference as the optimisation problem of minimising the free energy F, a functional of the approximate posterior Q, F[Q] = E_Q[ln Q(Γ, θ, σ)] − E_Q[ln P(x, Γ, θ, σ)]. We test a normative factorisation, which keeps the joint factor Q(Γ, θ, σ), and a local computations, i.e. mean-field-like, factorisation, Q(Γ, θ, σ) = Q(θ) Q(σ) ∏_{i≠j} Q(Γ_ij). Because the factors, or beliefs, Q are assumed to be independent, given a fixed dataset any sequence of updates to one or several of them can only decrease the free energy (Fox & Roberts, 2012). All Qs are categorical distributions, with supports θ ∈ {0.1, 0.5, 1, 1.5} and σ ∈ {1, 3, 6, 12} for the hyperparameters.
This allows the use of closed-form expressions for the updates. All factors are updated at each newly observed datapoint. This approach generalises to continuous distributions when using gradient methods. We compare the performance of our variational normative and LC versions to equivalent Bayesian models which know the true values of θ and σ.
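The mean-field coordinate updates can be illustrated on a toy two-factor problem. This is a hedged sketch, not the paper's model: it updates categorical beliefs over θ and σ only, on a precomputed log-likelihood grid (`loglik` here is an invented toy surface, not derived from OU data), using the standard fixed point log q(z_k) = E_{q_{-k}}[log p] + const.

```python
import numpy as np

def mean_field_updates(loglik, q_theta, q_sigma, n_sweeps=10):
    """Coordinate updates for two independent categorical factors.

    loglik[i, j] = log p(data | theta_i, sigma_j) (plus log prior),
    precomputed on the grid. Each update sets
        log q(theta_i) = sum_j q(sigma_j) * loglik[i, j] + const,
    and symmetrically for sigma; each sweep cannot increase the
    free energy (Fox & Roberts, 2012).
    """
    for _ in range(n_sweeps):
        log_qt = loglik @ q_sigma              # expectation over q(sigma)
        q_theta = np.exp(log_qt - log_qt.max())
        q_theta /= q_theta.sum()
        log_qs = q_theta @ loglik              # expectation over q(theta)
        q_sigma = np.exp(log_qs - log_qs.max())
        q_sigma /= q_sigma.sum()
    return q_theta, q_sigma

thetas = np.array([0.1, 0.5, 1.0, 1.5])        # support used in the paper
sigmas = np.array([1.0, 3.0, 6.0, 12.0])
# toy log-likelihood surface peaked at theta = 0.5, sigma = 3
loglik = -((thetas[:, None] - 0.5) ** 2 / 0.1
           + (sigmas[None, :] - 3.0) ** 2 / 2.0)
q_theta = np.full(4, 0.25)                     # uniform initial beliefs
q_sigma = np.full(4, 0.25)
q_theta, q_sigma = mean_field_updates(loglik, q_theta, q_sigma)
```

The full models additionally carry one categorical factor per non-diagonal entry of Γ (or one joint factor in the normative case), but the update structure is the same.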

Results
Unsurprisingly, the increased complexity of having to estimate θ and σ led the mean accuracy across graphs of the variational normative (.90 ± .09) and local computations (.84 ± .09) models to be lower than that of their standard normative (.98 ± .04) and LC (.89 ± .07) counterparts, which know the true values of θ and σ. However, they remained effective at recovering causal structures compared to a chance level of .4. Figure 2 shows that while the dip in accuracy is sharp, the variational models remain at a similar distance to participants' judgements on average, specifically for chain and damped structures. This is because the variational models lose accuracy on indirect dynamics even in the normative case, a mistake often made by humans (A → C in Figure 3). Errors made by variational agents appear to stem from underestimating the value of θ: the mean estimate was .29 ± .22 instead of .5. The LC version recovered θ in only 34% of trials, and failure to recover it caused an average accuracy loss of .10, against 52% recovery and a mean accuracy loss of .12 for the normative version. Underestimating θ further caused the normative version to report the non-existent A → C link in chain structures 82% of the time (42% for participants), up from 23%, and to fail to report it in damped structures in 92% of trials (52% for participants), up from 20%, driving the trend observed in Figure 3.

Discussion
Evidently, while still being effective at recovering causal structures, simply switching to variational Bayes in continuous time is not sufficient to better capture participants' judgements. Errors in estimating hyperparameters led to errors in disambiguating indirect causal effects, a common human mistake which we replicated here. However, the variational models were overall as effective at matching participants' judgements as models which assumed knowledge of the hyperparameters. A parsimonious conclusion is thus that participants are good at estimating the impact of latent parameters in this context. Moreover, variational inference encourages us to think of agents as holding sets of independent beliefs which can be updated independently. Here, we have assumed that all beliefs are updated at each new video frame. Evidence suggests that humans do not learn in this way: they tend to decompose learning into subgoals which they then solve sequentially (Correa, Ho, Callaway, Daw, & Griffiths, 2022), or choose careful interventions to focus on specific parameters (Lagnado & Sloman, 2004; Bramley et al., 2017; Bramley, Gerstenberg, Tenenbaum, & Gureckis, 2018). Variational inference, by formalising that updating factors independently guarantees decreases in free energy, encourages future research to use models which focus on specific parts of the inference process, as humans have been shown to prefer.

Figure 1 :
Figure 1: Generic benchmark causal graphs with their respective numbers of observations. Green arrows represent positive causal links and red arrows negative ones. Also shown is an example of an interventional path of a participant in a Chain trial where θ = 0.5, σ = 3 and dt = 0.2; these were kept constant across trials. Shaded areas represent participant interventions on the node of the corresponding colour. Not all participants did all structures, hence the different sample sizes.

Figure 3 :
Figure 3: Chain (first row) and Damped (second row) causal graph estimates for each model and participants. The x-axis is the link and the y-axis its mean estimate. The link confounded by an indirect effect is always A → C.