Toward a More Neurally Plausible Neural Network Model of Latent Cause Inference

Humans spontaneously perceive a continuous stream of experience as discrete events. It has been hypothesized that this ability is supported by latent cause inference (LCI). We implemented this hypothesis using the Latent Cause Network (LCNet), a neural network model of LCI. LCNet interacts with a Bayesian LCI mechanism that activates a unique context vector for each inferred latent cause (LC). LCNet can also recall episodic memories of previously inferred LCs, so it does not need to perform LCI on every time step. These mechanisms make LCNet more neurally plausible and efficient than existing models. Across three simulations, we found that LCNet could 1) extract shared structure across LCs while avoiding catastrophic interference, 2) capture human data on curriculum effects on schema learning, and 3) infer the underlying event structure when processing naturalistic videos of daily activities. Our work provides a neurally plausible computational model that can operate in both laboratory experiment settings and naturalistic settings, opening up the possibility of providing a unified model of event cognition.

Humans spontaneously perceive a continuous stream of experience as discrete events (Kurby & Zacks, 2008; Zacks & Tversky, 2001; Radvansky & Zacks, 2017). This seems effortless for humans, but it is challenging for standard neural network models. When viewing each event as a task, event perception can be viewed as a continual learning problem (Parisi, Kemker, Part, Kanan, & Wermter, 2019; Flesch, Saxe, & Summerfield, 2023) with a blocked curriculum, where both the ongoing task identity and the number of tasks are unknown.

The Structured Event Memory (SEM) model
In the human brain, it has been hypothesized that the process of grouping observations into events is supported by latent cause inference (LCI) (Gershman, Norman, & Niv, 2015; Franklin, Norman, Ranganath, Zacks, & Gershman, 2020; Niv, 2019; Shin & DuBrow, 2020). Recently, the Structured Event Memory (SEM) model was proposed to explain how humans assign observations to latent causes (LCs) to facilitate learning, prediction, memory, and generalization in a context-dependent manner (Franklin et al., 2020). SEM uses a Bayesian non-parametric mechanism to perform LCI and uses separate neural networks to represent different LCs.
Given the current observation x_t, SEM runs LCI to either 1) assign x_t to an existing LC and use the corresponding network to process x_t, or 2) create a new LC for x_t by generating a new network to process x_t. Such fully-separated representations can convert a continual learning problem into multiple single-task problems, circumventing the catastrophic interference that would otherwise occur with blocked learning (French, 1999; McClelland, McNaughton, & O'Reilly, 1995; McCloskey & Cohen, 1989). SEM can explain a wide range of human data, such as how event structure affects memory reconstruction, and a variant of SEM ("SEM 2.0"; Bezdek, Nguyen, Gershman, et al., 2022) has recently been used to simulate data on how humans segment naturalistic videos of daily events (Bezdek, Nguyen, Hall, et al., 2022). However, using fully-separated representations is neurally implausible, and it impedes SEM's ability to extract the shared structure across LCs, since a given observation is only used to train one network. Additionally, SEM performs LCI all the time, which is also implausible due to the heavy computational demands.
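The assign-or-create decision described above can be sketched as follows. This is a minimal illustration rather than SEM's actual implementation: it assumes a Chinese-restaurant-process-style prior with a hypothetical concentration parameter `alpha`, and it takes the likelihood of the observation under each cause as given (in SEM these would come from the per-cause networks). The function name and the numbers in the usage lines are made up for illustration.

```python
def infer_latent_cause(counts, likelihoods, new_likelihood, alpha=1.0):
    """Return the index of the most probable latent cause (LC).

    counts[k]       : number of past observations assigned to LC k
    likelihoods[k]  : p(x_t | LC k), assumed to be supplied by the caller
    new_likelihood  : p(x_t | a brand-new LC), also assumed to be supplied
    alpha           : hypothetical concentration parameter of the prior

    A return value equal to len(counts) means "create a new LC".
    """
    total = sum(counts) + alpha
    # Prior: existing causes are weighted by their counts, a new cause by alpha.
    scores = [(n / total) * lik for n, lik in zip(counts, likelihoods)]
    scores.append((alpha / total) * new_likelihood)
    # Pick the cause with the highest unnormalized posterior score.
    return max(range(len(scores)), key=scores.__getitem__)


counts = [10, 5]
# x_t is well explained by LC 1, so it is assigned to that existing cause.
existing = infer_latent_cause(counts, [0.01, 0.60], new_likelihood=0.05)
# x_t is poorly explained by both existing causes, so a new cause is created.
created = infer_latent_cause(counts, [0.001, 0.001], new_likelihood=0.20)
print(existing, created)
```

Taking the maximum is one simple choice; sampling from the normalized scores would be an equally valid reading of "Bayesian non-parametric" inference.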

A more neurally plausible model of LCI
Inference. To address these limitations, we propose the Latent Cause Network (LCNet), a more biologically plausible neural network model of LCI. LCNet uses the same network to process all observations, and it uses different context vectors to achieve task/context-sensitivity (Figure 1A, G, Q), akin to the classic connectionist idea of context representation (Cohen, Dunbar, & McClelland, 1990; Rougier, Noelle, Braver, Cohen, & O'Reilly, 2005). Given the current observation x_t, it still performs LCI using the same Bayesian non-parametric mechanism. If LCI assigns x_t to an existing LC, the corresponding context vector is fed to the hidden layer of the network, or if LCI creates a new LC for x_t, a new context vector is generated. Importantly, context vectors are simply random vectors sampled from a Gaussian distribution. Since high-dimensional random vectors are approximately orthogonal, they can reduce catastrophic interference across LCs.
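The near-orthogonality of random Gaussian vectors can be checked numerically. The sketch below draws pairs of vectors with i.i.d. standard Gaussian entries and reports the mean absolute cosine similarity at several dimensionalities. The dimensionalities, the unit variance, and the function names are illustrative assumptions; the text above does not specify the size of LCNet's context vectors.

```python
import math
import random


def random_context_vector(dim, rng):
    """Sample a vector with i.i.d. standard Gaussian entries."""
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mean_abs_cosine(dim, n_pairs=200, seed=0):
    """Average |cosine similarity| over n_pairs independent random pairs."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_pairs):
        u = random_context_vector(dim, rng)
        v = random_context_vector(dim, rng)
        total += abs(cosine_similarity(u, v))
    return total / n_pairs


# The average overlap shrinks as the dimensionality grows.
for dim in (2, 32, 512):
    print(dim, round(mean_abs_cosine(dim), 3))
```

The average overlap shrinks roughly in proportion to one over the square root of the dimensionality, which is why higher-dimensional context vectors interfere less with one another.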
Episodic memory. LCNet has an episodic memory (EM) mechanism that retrieves previously inferred LCs instead of performing LCI on each time step. Concretely, EM is implemented as a lookup table with a narrow generalization gradient (McClelland, 2013; McClelland et al., 1995; O'Reilly & Norman, 2002; Norman & O'Reilly, 2003; Norman, Detre, & Polyn, 2008). Every time LCNet performs a full LCI, EM encodes the current observation and the inferred LC in its buffer. Later, given an observation x_t, if LCNet can retrieve an LC previously associated with x_t, it uses the retrieved LC instead of performing full inference (Figure 1G, Q). We found that having EM makes LCNet much more computationally efficient.
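The retrieve-or-infer loop can be sketched as a small lookup table. This is an illustration under stated assumptions, not LCNet's implementation: the Euclidean distance, the `radius` threshold (a small value standing in for the narrow generalization gradient), the class and function names, and the toy inference function are all hypothetical.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length sequences."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


class EpisodicMemory:
    """Lookup table of (observation, latent cause) pairs."""

    def __init__(self, radius=0.05):
        # Hypothetical match threshold; a small radius means only
        # near-identical observations trigger retrieval.
        self.radius = radius
        self.buffer = []

    def encode(self, observation, latent_cause):
        """Store an observation together with the LC inferred for it."""
        self.buffer.append((tuple(observation), latent_cause))

    def retrieve(self, observation):
        """Return the LC of the nearest stored item within radius, else None."""
        best_lc, best_dist = None, self.radius
        for stored, lc in self.buffer:
            dist = euclidean(stored, observation)
            if dist <= best_dist:
                best_lc, best_dist = lc, dist
        return best_lc


def assign_latent_cause(observation, memory, full_inference):
    """Use a retrieved LC if one exists; otherwise run full LCI and store it."""
    lc = memory.retrieve(observation)
    if lc is not None:
        return lc, False  # full inference was skipped
    lc = full_inference(observation)
    memory.encode(observation, lc)
    return lc, True  # full inference was performed


calls = []


def toy_inference(obs):
    """Stand-in for full LCI; records how often it is invoked."""
    calls.append(obs)
    return 0 if obs[0] < 0.5 else 1


memory = EpisodicMemory(radius=0.05)
stream = [(0.0, 0.0), (0.0, 0.01), (1.0, 1.0), (0.0, 0.0), (1.0, 0.99)]
results = [assign_latent_cause(x, memory, toy_inference) for x in stream]
print(results, "full inferences:", len(calls))
```

In this toy stream, full inference runs only for the first observation of each cluster; every later observation is served from memory.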

Function learning with shared structure
In Simulation 1, we compare the performance of a feedforward LCNet (Figure 1A), a feedforward version of SEM, and a regular feedforward network on learning four polynomial regression tasks. Given the input and a context-indicative signal (CIS) specifying which function is currently active, the model had to produce the corresponding output. Learning was fully blocked: all models learned one function per epoch. Unbeknownst to the model, each function (Figure 1C) is a sum of a component that is shared across all functions and an idiosyncratic component (Figure 1B). After learning, an LCI-lesioned LCNet reconstructs the shared component (Figure 1D), suggesting that the shared structure is encoded in the network weights (blue pathway in Figure 1A). We observed strong catastrophic interference for the regular network, almost no interference for LCNet, and no interference for SEM (Figure 1E). Moreover, learning one function sped up new learning for both the regular network and LCNet but not SEM (Figure 1F). These results show that LCNet can extract shared structure and reuse it when learning new, related tasks while avoiding catastrophic interference.

Curriculum effect on schema learning
In Simulation 2, LCNet (Figure 1G) was trained on a context-dependent sequential prediction task (Figure 1H). Given the current state, the model had to predict the upcoming state. Unbeknownst to the model/human, the first observation is context-indicative and determines the transition structure of the event graph (Figure 1H). Empirical results show that humans learned much better under the blocked curriculum (Figure 1I) than the interleaved curriculum (Figure 1J; Beukers et al., 2023). LCNet qualitatively replicated this pattern (Figure 1K vs. 1L), as LCI was much more accurate in the blocked condition (Figure 1M vs. 1N), quantified by cluster purity. Importantly, although LCI was more accurate when full inference was performed all the time (Figure 1O vs. Figure 1M), LCNet with EM avoided 94% of full LCI computations while still being able to capture human data.
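For readers unfamiliar with cluster purity, the sketch below uses the standard definition: each inferred LC is credited with its most frequent ground-truth label, and purity is the fraction of observations covered by those majority labels. The exact variant used in the simulation may differ, and the label sequences here are made up for illustration.

```python
from collections import Counter, defaultdict


def cluster_purity(inferred, truth):
    """Fraction of items whose true label is the majority label of their cluster."""
    if len(inferred) != len(truth) or not inferred:
        raise ValueError("inputs must be non-empty and of equal length")
    # For each inferred LC, count how often each ground-truth label occurs.
    members = defaultdict(Counter)
    for lc, label in zip(inferred, truth):
        members[lc][label] += 1
    # Credit each cluster with the size of its largest ground-truth group.
    majority_total = sum(max(c.values()) for c in members.values())
    return majority_total / len(inferred)


truth = ["A", "A", "A", "B", "B", "B"]
perfect = cluster_purity([0, 0, 0, 1, 1, 1], truth)  # every LC maps to one label
mixed = cluster_purity([0, 0, 1, 1, 1, 1], truth)    # LC 1 mixes "A" and "B"
print(perfect, round(mixed, 3))
```

Purity rewards over-segmentation (assigning every observation its own LC yields a purity of 1), so it is most informative when read alongside the number of inferred LCs.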

Scene prediction on naturalistic video stimuli
In Simulation 3, we test LCNet (Figure 1Q) on naturalistic video recordings of daily activities (Figure 1R, 1S) (Bezdek, Nguyen, Hall, et al., 2022). Since this dataset was generated in a controlled manner, the ground truth event labels are known (Figure 1R). We found that LCs inferred by LCNet share a significant amount of mutual information with the event label (Figure 1T) even though 35% of LCI was saved by using EM. Being able to process naturalistic data demonstrates the generality of our framework. In the future, we will test LCNet's ability to predict human event segmentation (comparing its performance to SEM 2.0; Bezdek, Nguyen, Gershman, et al., 2022), and we will use uncertainty to guide when to perform LCI (as in the CCN 2023 submission by Nguyen et al.). The model also generates time series predictions of when episodic retrieval happens, which we will test via human fMRI.
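The mutual information between inferred LCs and event labels can be estimated from co-occurrence counts. The sketch below uses the plug-in estimator in bits; the estimator, the units, and the toy label sequences are assumptions, and the significance test behind Figure 1T is not reproduced here.

```python
import math
from collections import Counter


def mutual_information(xs, ys):
    """Plug-in estimate of mutual information (in bits) between two label sequences."""
    if len(xs) != len(ys) or not xs:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(xs)
    count_x = Counter(xs)
    count_y = Counter(ys)
    count_xy = Counter(zip(xs, ys))
    mi = 0.0
    for (x, y), c in count_xy.items():
        # p(x, y) * log2( p(x, y) / (p(x) * p(y)) ), written in terms of counts.
        mi += (c / n) * math.log2(c * n / (count_x[x] * count_y[y]))
    return mi


inferred = [0, 0, 1, 1]
aligned = mutual_information(inferred, ["cook", "cook", "clean", "clean"])
unrelated = mutual_information(inferred, ["cook", "clean", "cook", "clean"])
print(aligned, unrelated)
```

One common way to assess significance is to compare the observed value against a null distribution obtained by shuffling one of the two sequences.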

Figure 1: A) Model architecture for Simulation 1. B) The target functions that the model had to learn. Each function is a sum of a shared term and an idiosyncratic term shown in panel C. D) An LCI-lesioned LCNet reconstructs the shared component, even though the shared component was never directly observed. E) Learning curves for each polynomial plotted separately over epochs; the model was only trained on the i-th polynomial at epoch i. F) Learning curves for each polynomial plotted separately over the number of samples. G) Model architecture for Simulation 2. H) The state-transition graph used in Beukers et al. (2023). I and J) Human behavioral data in the blocked (I) vs. interleaved (J) conditions. K and L) Model behavioral data in the blocked (K) vs. interleaved (L) conditions. M and N) The purity of the inferred LCs of the model relative to the ground truth in the blocked (M) vs. interleaved (N) conditions. O and P) The same as panel M and panel N, except that the memory module of the model was lesioned. Q) Model architecture for Simulation 3. R) The hierarchical structure of the META dataset (Bezdek, Nguyen, Hall, et al., 2022). S) Each frame in the META video recordings is compressed as a 30D feature vector. T) The LCs inferred by the model and the ground truth event labels shared a significant level of mutual information.