SYNAPTIC PLASTICITY WITH CORRELATED FEEDBACK: KNOWING HOW MUCH TO LEARN

Learning synaptic weights is difficult. Connections receive global error signals, which are low-dimensional and noisy, and local signals, which lack information about relative contributions to the error. In this setting, it makes sense for connections to learn not just their ‘best guess’ of their weight, but also how confident they should be in this guess: i.e. to infer a distribution over their target weight. This idea was developed in (Aitchison & Latham, 2015) and (Aitchison, Pouget, & Latham, 2017); similar concepts appear in (Hiratani & Fukai, 2018). In those works the update equations are discrete in time and the likelihoods used in inference are Markov. This is not the case in biology, where signals are continuous in time and temporally correlated. Here we consider a non-Markov setting, deriving coupled ODEs which describe how the parameters of the posterior evolve as more data are observed. We use a local temporal-smoothing method to handle the continuous feedback and discrete presynaptic spike events, and we show that the smoothing window can be chosen in a principled way: by maximising the per-spike decrease in uncertainty. Our algorithm outperforms the leaky delta rule with an optimised learning rate and, more importantly, for the simple model described, it accurately predicts the posterior variance.


Introduction
At a high level, our hypothesis is that connections aim to infer distributions over their target weights from sequentially observed data. To do this, they must update their estimates as new data is observed. The mathematically correct way to do this is by Bayes' rule: p(w*(t) | D(0,t)) ∝ p(D(t−∆, t) | w*(t), D(0, t−∆)) p(w*(t) | D(0, t−∆)).
Here, w*(t) is a set of target connection strengths at time t and D(0,t) is the set of data signals observed on the time interval (0,t).
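As a concrete illustration of such sequential updating, here is a minimal sketch for the conjugate case, assuming a Gaussian prior over a single scalar target weight and Gaussian observations of it (a deliberate simplification; the paper's likelihood is instead derived from the windowed error signal):

```python
def bayes_update(mu, var, obs, obs_var):
    """Conjugate Gaussian update: combine the prior N(mu, var) with a
    noisy observation obs ~ N(w*, obs_var) of the target weight."""
    new_var = 1.0 / (1.0 / var + 1.0 / obs_var)
    new_mu = new_var * (mu / var + obs / obs_var)
    return new_mu, new_var

# Sequentially absorb observations; the variance shrinks with each datum.
mu, var = 0.0, 1.0
for obs in [0.8, 1.1, 0.9]:
    mu, var = bayes_update(mu, var, obs, obs_var=0.5)
```

Each update tightens the posterior, so later observations move the mean less than earlier ones: exactly the "effective learning rate from the variance" behaviour described below.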
Such Bayesian approaches to synaptic learning rules have been described in (Aitchison & Latham, 2015) and (Aitchison et al., 2017). A key point in these works is that the resultant update rules for the expected target weight do not require an arbitrary learning rate, η, as is the case in most classical Hebbian rules. Rather, the effective learning rate depends on the posterior variance, which comes from doing inference. This means more uncertain connections naturally learn faster than more certain ones. A different approach is taken in (Hiratani & Fukai, 2018), which obtains an approximation to the posterior by means of a particle filter. In this model, each of the multiple synapses which constitute a connection between two neurons is considered a particle. During learning, the strength and location of individual synapses are varied such that their cumulative effect captures the desired statistics.

This work is licensed under the Creative Commons Attribution 3.0 Unported License. To view a copy of this license, visit http://creativecommons.org/licenses/by/3.0

In all the above-mentioned works time is discretised, updates are instantaneous and likelihoods are Markov; that is, they have the property p(D(t) | w*(t), D(0, t−1)) = p(D(t) | w*(t)). (1) In biological systems, where fluctuations in potentials have decay times and target weight values may drift in time, this Markov assumption does not hold. We investigate this trickier setting below.

The model & the problem
We consider a simple model where one postsynaptic (point) neuron receives input from N independent presynaptic neurons. The i-th presynaptic neuron fires Poisson spikes with rate λ_i. A spike on neuron i at time t^k_i causes a transient postsynaptic potential in the postsynaptic neuron, with peak amplitude w_i(t^k_i) and temporal profile α(t − t^k_i). Let {w*_i(t)} be a set of unknown target weight functions (functions of time). This defines an error signal, δ(t) = Σ_j (w_j(t) − w*_j(t)) S_j(t), where S_j(t) is the smoothed spike train from neuron j.
We are agnostic as to what overall task defines the w*_i's. In our model context, we can think of the w*_i's as latent variables. Our goal is to obtain (variational) posteriors over these functions in an online way.
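The model above is easy to simulate. The sketch below generates Poisson spike trains, smooths them with a normalised alpha-function PSP, and evaluates the error signal; the function names and the additive form of the error are our assumptions based on the definitions in the text:

```python
import math
import random

def alpha_psp(t, tau=1.0):
    """Normalised alpha-function PSP shape; zero before the spike, peak 1 at t = tau."""
    return (t / tau) * math.exp(1.0 - t / tau) if t >= 0.0 else 0.0

def poisson_spikes(rate, T, rng):
    """Spike times of a homogeneous Poisson process on (0, T)."""
    t, out = 0.0, []
    while True:
        t += rng.expovariate(rate)
        if t > T:
            return out
        out.append(t)

def smoothed_train(spikes, t, tau=1.0):
    """S_j(t): the presynaptic spike train convolved with the PSP shape."""
    return sum(alpha_psp(t - tk, tau) for tk in spikes)

def error_signal(weights, targets, all_spikes, t):
    """delta(t) = sum_j (w_j - w_j*) S_j(t)."""
    return sum((w - wt) * smoothed_train(s, t)
               for w, wt, s in zip(weights, targets, all_spikes))

rng = random.Random(0)
N, T = 5, 20.0
rates = [math.exp(rng.gauss(0.0, 0.5)) for _ in range(N)]  # log-normal rates
spikes = [poisson_spikes(r, T, rng) for r in rates]
w = [0.0] * N          # implemented weights
w_star = [1.0] * N     # (here static) target weights
delta = error_signal(w, w_star, spikes, t=10.0)
```

Note that the error vanishes exactly when every implemented weight matches its target, which is what makes δ a usable learning signal.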
To derive update equations for the distribution over the target weight w*_i, we consider the problem from the point of view of connection i; hence we will talk about learning from the point of view of a connection, rather than a neuron. For connection i, the contribution to the error from the other connections can be treated as additive "noise". This noise, consisting of the signals from the other connections, has temporal correlations described by a kernel. In general, inverting this kernel is not possible and, in some cases, it is not even well defined. To overcome this problem, we consider an approximation which lets us parameterise how much signal to extract from the feedback in a "windowed" manner after each spike, avoiding the kernel inversion altogether.

Approximate Version
Consider time-averaging the error signal for a short interval ε after spike k arrives: δ̄_k = (1/ε) ∫ δ(t) dt, with the integral taken over (t^k_i, t^k_i + ε), where we have assumed that the change in weights is negligible over this period. We have now introduced a free model parameter, ε, which we will set in a principled way later. For inference, we assume the only global information available to all connections is the error, δ. Each connection also receives local information: its own spike times (t^k_i), its own implemented weights (w_i(t^k_i)) and the PSP shape, α(t). We also assume connections know the prior over target weight distributions.
Hence we can write D(0:t) = {δ(0:t), w_i(t^k_i), t^k_i} for all t^k_i < t. Given a prior distribution p(w*_i(t^k_i) | D_i(0:t^k_i)), how should connection i update this distribution when spike k arrives and it observes {δ(t^k_i), w_i(t^k_i), t^k_i}? We propose that the connection integrates the error signal for the period ε after the spike.
From equation (7) we obtain our likelihood distribution, and it can be shown that, for our model problem, it is approximately Gaussian, where Avg_{j≠i}[·] indicates the empirical average over the other connections and C[ε, α] is shorthand for a constant scale which depends only on ε and α. If we assume that the prior is also Gaussian, N(µ_{i,k}, σ²_{i,k}) (the second subscript indicates how many spikes have been seen), we can use standard results to obtain discrete increments of the mean and variance for each spike (omitted due to space constraints; see the supplementary material).
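Under this Gaussian approximation the per-spike update has a closed form. The sketch below assumes a simplified scalar likelihood, δ̄_k ≈ gain · (w_i − w*_i) + noise, where `gain` and `noise_var` stand in for the C[ε, α] and averaged-noise terms (these names and the simplification are ours, not the paper's exact expressions):

```python
def per_spike_update(mu, var, w_i, delta_bar, gain, noise_var):
    """Gaussian posterior update for w_i* from the windowed error delta_bar,
    assuming delta_bar ~ N(gain * (w_i - w_i*), noise_var)."""
    obs = w_i - delta_bar / gain       # implied observation of w_i*
    obs_var = noise_var / gain ** 2    # its effective observation variance
    new_var = 1.0 / (1.0 / var + 1.0 / obs_var)
    new_mu = new_var * (mu / var + obs / obs_var)
    return new_mu, new_var

# Zero averaged error with w_i at the posterior mean leaves the mean fixed
# but still shrinks the variance: every spike is informative.
mu_new, var_new = per_spike_update(mu=0.5, var=1.0, w_i=0.5,
                                   delta_bar=0.0, gain=1.0, noise_var=1.0)
```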
Choosing ε: Equation (11) is the discrete change in the posterior variance after a spike arrives and is always negative. It can be shown that this quantity tends to zero in the limits ε → 0 and ε → ∞. Hence we can set ε to the value for which it is most negative, i.e. for which each spike maximally decreases the variance (see the supplementary material). Now that we have "fixed" ε, we drop it from our notation for simplicity.
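This choice of ε can be made numerically. The sketch below uses a toy model of the effective noise variance in the windowed observation (the a/ε and b·ε terms are illustrative assumptions capturing the two limits, not the paper's derived expressions) and grid-searches for the window that maximises the per-spike variance decrease:

```python
def noise_var(eps, a=1.0, b=0.5):
    """Toy effective-noise model (assumed): very short windows leave the
    noise unaveraged (a / eps); very long windows wash out the signal (b * eps)."""
    return a / eps + b * eps

def variance_decrease(eps, prior_var=1.0):
    """Magnitude of the per-spike drop in posterior variance under a
    Gaussian update: prior_var**2 / (prior_var + noise_var)."""
    return prior_var ** 2 / (prior_var + noise_var(eps))

# Grid search: the decrease vanishes at both limits, so an interior
# maximiser exists; here it sits where a/eps + b*eps is smallest.
eps_grid = [0.01 * k for k in range(1, 500)]
eps_star = max(eps_grid, key=variance_decrease)
```

With these toy constants the optimum lands near √(a/b), illustrating the interior maximum the text argues for.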
Drift and biophysical approximations: So far we know how to update the posterior after each spike, but what happens in between spikes? Far from a spike, the error signal tells the connection little about its own contribution to the error. Since the connection is not receiving information, we should expect its uncertainty (i.e. variance) to increase and the posterior to regress to the original prior distribution, p(w*_i). This effect is captured by the drift terms in equations (15) and (16).
Equation (11) and its analogue for µ give us discrete, per-spike updates. So, ignoring drift for now, we could write the value of µ_{i,k}(t) for t ≫ t^k_i as µ_{i,k}(t) = µ_{i,k−1} + ∆µ_{i,k} σ(t − t^k_i), where σ(t) is a sigmoidal function satisfying 0 ≤ σ(t) < 1. The integral of the (normalised) PSP shape, α, is just such a sigmoid, so we can "smooth" our updates out in time by multiplying them by the PSPs.
In the end, we obtain equations (15) and (16) as our final learning rules. Equation (15) is made of two parts. The first two terms are the drift component: the prior pulls the mean back towards the prior mean, µ_{i,0}. The third term looks Hebbian: a product of a learning rate, the error and the PSP, S_i(t). Each connection's learning rate depends on the ratio of its variance to the variance of the smoothed noise, so connections with high uncertainty (large variance) learn faster. The learning rule for the precision, equation (16), also has a drift component, pulling the variance up to the prior variance σ²_0 (which is typically large), and a data-driven component, which is always negative but approaches zero for highly confident connections.
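Since equations (15) and (16) are not reproduced in this extract, the following Euler sketch implements only the structure the prose describes: drift toward the prior plus a PSP-gated, uncertainty-scaled data term. The exact coefficients are our assumptions:

```python
def learning_step(mu, var, mu0, var0, delta, S, noise_var, tau=10.0, dt=0.01):
    """One Euler step of the assumed learning rules. Drift relaxes the
    posterior back to the prior (mu0, var0) with time constant tau; the
    data terms are gated by the PSP S(t). The effective learning rate is
    var / noise_var, so uncertain connections learn faster.
    (Structure from the prose; not the paper's exact eqs. 15-16.)"""
    dmu = (mu0 - mu) / tau - (var / noise_var) * delta * S
    dvar = (var0 - var) / tau - (var ** 2 / noise_var) * S ** 2
    return mu + dt * dmu, max(var + dt * dvar, 1e-12)

# With no presynaptic activity (S = 0) only drift acts: the mean relaxes
# toward the prior mean while the variance holds at the prior variance.
mu, var = 0.0, 1.0
for _ in range(1000):
    mu, var = learning_step(mu, var, mu0=1.0, var0=1.0,
                            delta=0.0, S=0.0, noise_var=1.0)
```

During activity the variance term is strictly negative, reproducing the "each spike reduces uncertainty" behaviour of the discrete updates.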

Results
A natural benchmark is to test this algorithm against the leaky delta rule (LDR; equation (17)). The learning rate, η, is numerically optimised to minimise the mean-square error of the output for each set of simulation parameters.
This rule only tracks a point estimate of the expected weight, not the variance. For our simulations we used a log-normal distribution of presynaptic firing rates, to mimic distributions commonly found in biology (Shafi et al., 2007).
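For reference, a leaky delta rule step might look like the following sketch. Equation (17) is not reproduced in this extract, so the leak-plus-delta form, the baseline w0 and the parameter names are our assumptions:

```python
def ldr_step(w, delta, S, eta, tau=10.0, dt=0.01, w0=0.0):
    """One Euler step of a leaky delta rule (assumed form of eq. 17):
    a leak toward a baseline plus an error-times-activity update with a
    fixed, hand-tuned learning rate eta."""
    return w + dt * ((w0 - w) / tau - eta * delta * S)

# With no error signal the weight just leaks toward the baseline.
w_leak = ldr_step(1.0, delta=0.0, S=0.0, eta=0.1)
# A negative error during presynaptic activity pushes the weight up.
w_up = ldr_step(0.0, delta=-1.0, S=1.0, eta=0.5)
```

The contrast with the Bayesian rule is that η here is a single fixed constant, shared across time and connections, rather than an uncertainty-dependent quantity.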
Due to space limits, we present just the take-home point: Figure 1 shows that the Bayesian rule does significantly better than the optimised Hebbian rule in mean squared error of the output (a), as well as achieving a smaller error between the posterior mean and the ground-truth target values (b).
If the Bayesian rule is doing a reasonable job of inference, we expect the true target value to lie within 1σ(t) of µ(t) about 68% of the time. In the simulation presented here, we find that our means are within this interval 72.26 ± 3.64% of the time.
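This calibration check is simple to compute. The sketch below measures the fraction of connections whose true target lies within one posterior standard deviation of the posterior mean, and sanity-checks it on synthetic, exactly-Gaussian posteriors (the names and the synthetic setup are ours):

```python
import random

def coverage(mus, sigmas, targets):
    """Fraction of targets within one posterior standard deviation of the
    posterior mean; ~68% if the posteriors are well calibrated."""
    hits = sum(abs(m - t) <= s for m, s, t in zip(mus, sigmas, targets))
    return hits / len(mus)

# Synthetic check: means scattered around the targets with std 0.3,
# reported posterior std also 0.3, so coverage should be near 0.68.
rng = random.Random(1)
targets = [rng.gauss(0.0, 1.0) for _ in range(20000)]
mus = [t + rng.gauss(0.0, 0.3) for t in targets]
sigmas = [0.3] * len(targets)
frac = coverage(mus, sigmas, targets)
```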
Figure 1: (a) Mean square error in the output, in units of the inference algorithm's mean square error. The leaky delta rule is approximately 50% worse than the inference method; both methods are much better than the control, achieving approximately 10^−2 of the control error. (b) Square difference between the posterior means and the target weights, averaged over the 100 neurons. The shaded area is one standard deviation computed over neurons. The inference method achieves a lower average and a much tighter variance.