Scalable gradients enable Hamiltonian Monte Carlo sampling for phylodynamic inference under episodic birth-death-sampling models

Birth-death models play a key role in phylodynamic analysis for their interpretation in terms of key epidemiological parameters. In particular, models with piecewise-constant rates varying at different epochs in time, which we refer to as episodic birth-death-sampling (EBDS) models, are valuable for their reflection of changing transmission dynamics over time. A persistent challenge with current inference procedures for time-varying models, however, is their lack of computational efficiency. This limitation hinders the full utilization of these models in large-scale phylodynamic analyses, especially when dealing with high-dimensional parameter vectors that exhibit strong correlations. We present here a linear-time algorithm to compute the gradient of the birth-death model sampling density with respect to all time-varying parameters, and we implement this algorithm within a gradient-based Hamiltonian Monte Carlo (HMC) sampler to alleviate the computational burden of conducting inference under a wide variety of structures of, as well as priors for, EBDS processes. We assess this approach using three real-world examples: the HIV epidemic in Odesa, Ukraine, seasonal influenza A/H3N2 virus dynamics in New York State, USA, and the Ebola outbreak in West Africa. HMC sampling exhibits a substantial efficiency boost, delivering a 10- to 200-fold increase in minimum effective sample size per unit time in comparison to a Metropolis-Hastings-based approach. Additionally, we show the robustness of our implementation both in allowing for flexible prior choices and in modeling the transmission dynamics of various pathogens by accurately capturing the changing trend of the viral effective reproductive number.


Implementation Algorithm
We implement a recursive algorithm to compute the necessary gradient of the log-likelihood with respect to the rate parameters. Intermediate quantities are stored between epochs to alleviate the computational burden. The detailed algorithm, based on the equations listed in Section 2.5 and the preceding sections of this supplement, is shown below.
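To make the store-and-reuse pattern concrete, here is a minimal Python sketch; it is our illustration rather than the BEAST implementation, with a toy per-epoch recursion standing in for Equation (21) and a toy likelihood standing in for Equation (1). Only the caching of per-epoch boundary values and the single reverse sweep reflect the algorithm above.

import numpy as np

def forward_pass(rates, durations):
    # Store the intermediate quantity at every epoch boundary, so no
    # gradient entry has to re-run the recursion from scratch.
    K = len(rates)
    p = np.empty(K + 1)
    p[0] = 1.0  # boundary condition at the most recent epoch
    for k in range(K):
        p[k + 1] = p[k] * np.exp(-rates[k] * durations[k])  # toy update
    return p

def log_likelihood(rates, durations):
    # Toy likelihood assembled from the stored boundary values.
    return np.log(forward_pass(rates, durations)[1:]).sum()

def gradient(rates, durations):
    # One reverse sweep over the cached forward pass: O(K) in total,
    # versus O(K^2) if the recursion were re-run once per parameter.
    K = len(rates)
    p = forward_pass(rates, durations)
    grad = np.empty(K)
    adj = 0.0  # accumulated d(logL)/d(p_k)
    for k in range(K, 0, -1):
        adj += 1.0 / p[k]  # direct contribution of the log(p_k) term
        grad[k - 1] = adj * (-durations[k - 1] * p[k])  # d(p_k)/d(rate_{k-1})
        adj *= p[k] / p[k - 1]  # chain rule through p_k = p_{k-1} e^{-r d}
    return grad

if __name__ == "__main__":
    rng = np.random.default_rng(1)
    r = rng.uniform(0.5, 2.0, size=6)
    d = rng.uniform(0.1, 1.0, size=6)
    eps = 1e-7
    num = np.array([(log_likelihood(r + eps * np.eye(6)[j], d)
                     - log_likelihood(r, d)) / eps for j in range(6)])
    print(np.allclose(num, gradient(r, d), rtol=1e-4))  # True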

HIV dynamics in Odesa, Ukraine
We follow the prior settings on the compound parameters from previous work [1] and approximately match their priors by adopting the following prior distributions on each of the rate parameters. Note that the sampling proportion was fixed to 0 before the first sampling date in their study, so we also set the sampling rate to 0 for the last two epochs for consistency.

Ebola epidemic in West Africa
We assume a constant death rate µ for this data set, and we employ an empirical Bayes approach proposed by Magee et al. (2020) to set the prior on the first log-birth-rate and log-sampling-rate in our Bayesian bridge MRF models [3]. The prior for the constant death rate is obtained from an estimate of the plausible duration of the infectious period, with a 95% confidence interval covering 8 to 40 days [4]. The detailed prior distributions can be found in the table below:

[Table: prior distributions for the Ebola analysis; rows for Birth Rate and Death Rate.]
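As a back-calculation illustrating how such an interval translates into a rate prior (our illustration; the analysis may use a different construction): if the death rate is taken as the reciprocal of the infectious period, then a duration of 8 to 40 days corresponds to

µ ∈ [365/40, 365/8] ≈ [9.1, 45.6] per year,

and a prior on µ (or on log µ) can then be chosen so that its central 95% mass covers roughly this range.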

Computational complexity of the nodewise likelihood
The computational complexity of evaluating node-based representations of the likelihood is much less explicit. First, we need to write out an equivalent node-wise expression for the likelihood of Equation 1. It will be helpful to distinguish different types of samples. In particular, let us denote the serially-sampled tips by ū_ψ, with a particular serially-sampled tip being ū_ψi. With a slight abuse of notation, let us denote the intensively-sampled tips by ū_ρ, with ū_ρi denoting the vector of intensively-sampled tips at the ith intensive-sampling event. Then we can write the likelihood node-wise.

The complexity here is not immediately apparent for a number of reasons. For one, the complexity appears to depend on the relative proportion of samples of different types, which affects the number of values of p_k(t) and q_k(t) that must be computed. Importantly, the complexity of computing those p_k(t) and q_k(t) is not immediately apparent either, and these costs are somewhat hard to disentangle, as p_k(t_i) builds recursively on p_{k-1}(t_i) and q_k(t) depends on p_k(t).

Node lookups
Regardless of such ambiguities, all nodes in the tree require an interval lookup. For births, the lookup is required to find the correct λ_k term to use. For samples, the lookup either finds the appropriate sampling rate (for serial samples) or determines to which intensive-sampling event a sample belongs (for intensive samples). The time requirement here depends on the algorithm; for a binary search it is O(log K), making the total lookup cost O(N log K).
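A minimal sketch of such a lookup in Python (the epoch boundaries and event times here are invented for illustration; the production implementation lives in BEAST's Java code base):

import bisect

# Epoch start times t_1 < t_2 < ... < t_K, so epoch k covers [t_k, t_{k+1}).
epoch_starts = [0.0, 2.5, 5.0, 7.5, 10.0]

def epoch_index(event_time):
    # Binary search: O(log K) per node, hence O(N log K) over the tree.
    return bisect.bisect_right(epoch_starts, event_time) - 1

# A birth at time 6.1 uses lambda_k for k = 2; a serial sample at
# time 3.3 uses psi_k for k = 1 (0-based epoch indices).
print(epoch_index(6.1), epoch_index(3.3))  # 2 1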
How many computations of q_k(t) are required?
In the worst, but most common, case there are no intensive-sampling events, and q_k(t) must be computed at the times of all samples, all births, and all epoch times (note that even when ρ_i is 0, there is a term L(t_i) log(q_{i-1}(t_i)) that must be computed in the final summation).
In the best case, all samples occur at intensive-sampling events, and q_k(t) only needs to be computed at the times of all births and all epoch times. Both cases are O(N + K), though there is a factor of two's worth of variation in front of the N depending on which side of this spectrum a tree falls. Calling the cost of computing q_k(t) Q, the contribution to the complexity here is O(Q(N + K)).

How many computations of p_k(t) are required?
The likelihood contains a number of explicit computations of p_k(t) in the terms pertaining to (both serially- and intensively-) sampled tips. When all samples are serial samples, there are O(N) direct computations of p_k(t), while when all samples are intensive samples, there are O(K). Taking the cost of computing p_k(t) to be P, the addition to the cost here is between O(PN) and O(PK).
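The two counting arguments above can be tallied in a few lines (a toy count in Python; only the event categories, not the values, matter):

def count_evaluations(births, serial_samples, intensive_events, epoch_times):
    # q_k(t) is needed at every birth, every serial sample, and every
    # epoch time; p_k(t) is needed once per serial sample and once per
    # intensive-sampling event.
    n_q = len(births) + len(serial_samples) + len(epoch_times)
    n_p = len(serial_samples) + len(intensive_events)
    return n_q, n_p

# All-serial tree (worst case) versus all-intensive tree (best case),
# with invented sizes: N = 100 tips, K = 10 epochs.
print(count_evaluations([0] * 99, [0] * 100, [], [0] * 10))  # (209, 100)
print(count_evaluations([0] * 99, [], [0] * 10, [0] * 10))   # (109, 10)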
What is the cost of computing p_k(t) and q_k(t)?
We have thus far shown that the cost of computing the nodewise likelihood appears to lie between O(N log K + Q(N + K) + PN) and O(N log K + Q(N + K) + PK). But this is not particularly revealing without considering P and Q.
While q_k(t) depends on p_l(t) for l < k through A and B, once A_k and B_k have been computed, let us assume (as we did when evaluating the cost of the interval-wise likelihood) that the cost of q_k(t) is O(1). In other words, let us assume that O(Q(N + K)) = O(P(N + K)). This makes the implied cost of the nodewise likelihood between O(N log K + P(N + K) + PN) and O(N log K + P(N + K) + PK), both of which simplify to O(N log K + P(N + K)).
Naïvely, we might choose to compute p_k(t) recursively every time we need it, which is O(K^2).
In this case, the implied cost of the nodewise likelihood is O(N log K + NK + K^2).

Precomputing A and B
One can instead choose to precompute the A_k and B_k, as once these are computed the cost to compute p_k(t) and q_k(t) becomes O(1). Working backwards from the present allows recomputation to be avoided. As we did when we approximated the cost of the interval-wise likelihood, we take the cost of each update (computing A_k and B_k from the previous epoch's values) to be O(1). Thus, the cost of the precomputation is O(K). This puts the implied cost of computing the nodewise likelihood at O(N log K + N + K).
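As an illustration of this precomputation pattern, the sketch below uses the standard piecewise-constant birth-death-sampling recursions of BDSKY-style implementations [5]; the index and time conventions (time measured backwards from the present, t[0] = 0) are assumptions of the sketch rather than a transcription of our code.

import numpy as np

def precompute_AB(lam, mu, psi, rho, t):
    # One O(K) pass from the present toward the past: A_k depends only
    # on epoch k's rates, and B_k additionally needs p_{k-1} evaluated
    # at the epoch boundary, which we carry along as p_prev.
    K = len(lam)
    A = np.sqrt((lam - mu - psi) ** 2 + 4.0 * lam * psi)
    B = np.empty(K)
    p_prev = 1.0  # boundary condition at the present
    for k in range(K):
        B[k] = ((1.0 - 2.0 * (1.0 - rho[k]) * p_prev) * lam[k]
                + mu[k] + psi[k]) / A[k]
        p_prev = p(k, t[k + 1], lam, mu, psi, A, B, t)  # value at next boundary
    return A, B

def p(k, s, lam, mu, psi, A, B, t):
    # O(1) once A_k and B_k are known.
    e = np.exp(-A[k] * (s - t[k]))
    frac = ((1.0 + B[k]) - (1.0 - B[k]) * e) / ((1.0 + B[k]) + (1.0 - B[k]) * e)
    return (lam[k] + mu[k] + psi[k] - A[k] * frac) / (2.0 * lam[k])

def q(k, s, A, B, t):
    # Likewise O(1) given the cached A_k and B_k.
    e = np.exp(-A[k] * (s - t[k]))
    return 4.0 * e / ((1.0 - B[k]) * e + (1.0 + B[k])) ** 2

# Three epochs with invented rates; epoch k covers [t[k], t[k+1]).
t = np.array([0.0, 1.0, 2.5, 6.0])
lam = np.array([2.0, 1.5, 1.0])
mu = np.array([0.5, 0.5, 0.5])
psi = np.array([0.3, 0.2, 0.1])
rho = np.array([0.0, 0.0, 0.0])
A, B = precompute_AB(lam, mu, psi, rho, t)
print(p(1, 1.7, lam, mu, psi, A, B, t), q(1, 1.7, A, B, t))

With the A_k and B_k cached, every subsequent p or q evaluation is constant-time, which is exactly what reduces the nodewise cost to O(N log K + N + K).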

Counting lineages at epoch times
Regardless of whether the model includes intensive sampling (that is, regardless of whether ρ = 0), one must compute L(t_i) for all epoch times. This can be solved in essentially the same way as the subintervals are obtained, at a cost of O(N + N log N). Alternately, it can be obtained by counting the number of births and sampled tips older (or younger) than each epoch time, at a cost of O(KN). This makes the lower end of the computational cost once again a range.

In practice, the constants in front of all the sorting and node-lookup terms appear to be so small as to be unnoticeable in real-world computation. We demonstrate this in our timing experiments in the next section. Thus, for all practical purposes, the likelihood appears to be O(N + K) regardless of representation, as long as one avoids recursive computation of p_k(t).
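A sketch of the sort-based approach (invented event times; the sketch assumes a single root lineage, forward time, and removal upon sampling, i.e. r = 1 — other conventions change the bookkeeping but not the complexity): after one O(N log N) sort, the lineage count at each epoch time comes from two binary searches.

import bisect

def lineages_at(birth_times, sample_times, epoch_times):
    # L(t) = 1 + #births before t - #samples before t.
    births = sorted(birth_times)
    samples = sorted(sample_times)
    return [1 + bisect.bisect_left(births, t) - bisect.bisect_left(samples, t)
            for t in epoch_times]

# Three births (hence four tips) and four sampled tips:
print(lineages_at([0.5, 1.2, 2.0], [1.5, 2.6, 2.9, 3.0], [1.0, 2.5, 3.5]))
# [2, 3, 0]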

Timing Experiments
With the reformulation of the likelihood and the derivation of analytical gradients, our method gains notably in speed, as we highlight in this section. For a comprehensive assessment, we compare our approach with four other specialized packages for EBDS model inference. For likelihood calculations, these include the BDSKY [5] package within BEAST2 [6], the TreePar [5] package in R [7], and RevBayes [8]. Furthermore, we present a benchmark comparing the gradient calculation efficiency of the automatic differentiation implemented in the VBSKY [9] package using the JAX library [10], isolated from its variational inference procedure, against our analytical gradients implemented in BEAST.
To assess the scalability of the aforementioned methods in terms of likelihood and gradient calculation, we simulated a set of trees under the EBDS model with an increasing number of tips. To investigate scalability with respect to the number of sequences, we fix the number of epochs to 5 for both likelihood and gradient calculations.
Regarding scalability with respect to the number of epochs, we adjust the model by progressively increasing the number of epochs. To keep other variables constant, we maintain the tree topology and set the number of tips at 12 for likelihood computation (in scenarios where K >> N, this allows us to negate the effect of N in O(N + K)). For gradient calculations, we set the number of tips to 8198 (to minimize the impact of ...). For methods that employ just-in-time (JIT) compilation, including BEAST, BEAST2 and VBSKY, we run a short MCMC chain or variational inference algorithm that computes the likelihood or gradient across 100,000 iterations and take the average run time.

In our analysis, we observe that for likelihood computations the implementations in BEAST, BEAST2, and RevBayes offer similar speed when varying both the number of sequences and the number of epochs. In contrast, the TreePar package consistently lags, being several hundred times slower than its counterparts across all tested scenarios. It is also the sole implementation that exhibits quadratic scaling with the number of epochs. The algorithms of BEAST, BEAST2, and RevBayes demonstrate approximately linear scaling relative to both tree size and the number of model epochs. It is worth noting that RevBayes delivers the quickest calculation speed, which might be attributed to the inherent speed advantages of precompiled code, particularly for the quick likelihood calculations in our context. Results for TreePar with more than 512 epochs are not included, as TreePar fails to process such large models.

In terms of gradient calculations, our analytical gradients deployed within BEAST are remarkably faster than the VBSKY approach using automatic differentiation. The gradient computation scales approximately linearly with the number of sequences for both BEAST and VBSKY. However, while this linearity persists for BEAST with respect to an increasing number of epochs, VBSKY departs from linear scaling. Notably, the run time for the VBSKY method escalates from 0.02 seconds with 400 epochs to 0.04 seconds with 450 epochs. We further confirm that this slowdown in VBSKY is not due to memory issues or JIT compilation difficulty. However, without the ability to modify or closely examine the automatic differentiation library employed by VBSKY, identifying the specific causes of this non-linear scaling remains out of reach. Our analysis therefore demonstrates that analytically calculating the gradients of the EBDS likelihood is critical for improving the running time of gradient-based methods.
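For readers who want to reproduce the flavor of the gradient benchmark, the following Python sketch — our invention, not VBSKY code — times a JIT-compiled JAX gradient of a toy piecewise-rate log-likelihood, with a warm-up call to exclude compilation time and block_until_ready to exclude asynchronous-dispatch artifacts.

import time
import jax
import jax.numpy as jnp

def toy_loglik(rates, durations):
    # Stand-in for the EBDS log-likelihood: any smooth function of the
    # per-epoch rates suffices for timing purposes.
    return jnp.sum(jnp.log(rates) - rates * durations)

grad_fn = jax.jit(jax.grad(toy_loglik))

key = jax.random.PRNGKey(0)
rates = jax.random.uniform(key, (400,), minval=0.5, maxval=2.0)
durations = jnp.ones(400)

grad_fn(rates, durations).block_until_ready()  # warm-up: trigger compilation

n_iter = 1000
start = time.perf_counter()
for _ in range(n_iter):
    grad_fn(rates, durations).block_until_ready()
print((time.perf_counter() - start) / n_iter, "seconds per gradient")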

XML specification for the EBDS model using HMC sampler
BEAST data, likelihood, prior, and sampling specification relies on extensible markup language (XML) elements. Comprehensive instructions for incorporating XML elements together to drive BEAST are provided in the How-to Guides on the BEAST community website https://beast.community/. In the instructions here, we address the construction and use of the XML elements for the EBDS model and the HMC transition kernels, which are the key components for applying the results presented in this study.
The EBDS model XML element is specified as shown in the block reproduced below, after the figures and tables. Following the incorporation of XML elements for the selected prior distributions and models pertinent to joint phylogeny inference, we can shift our focus to the transition kernel, or "operator", element. This block contains a unique operator, the hamiltonianMonteCarloOperator, which differs from standard operators in that it requires a jointGradient object instead of a typical parameter object; BEAST internally retrieves the parameter from its gradient object. As we are dealing with gradients with respect to several different parameters, we can define compoundGradient elements as needed in advance. Specifically, we define the compound gradient using two new elements: the gradient element for the priors and the speciationLikelihoodGradient element for the EBDS model parameters. A sample HMC operator element is shown below.
The implementation of HMC necessitates user specification of two critical parameters: the step size stepSize and the number of steps nSteps. Notably, BEAST features intrinsic auto-tuning capabilities, facilitating parameter tuning during an active analysis. We can enable this by specifying autoOptimize="true" and declaring a value for targetAcceptanceProbability. We also include a preconditioner element here, which is not mandatory for HMC. However, as highlighted in the main text, preconditioning the mass matrix based on the Hessian of the log-prior significantly improves the efficiency of our HMC sampler. Removing this element yields a standard HMC transition kernel.

Table B: Prior specifications for the EBDS model in the influenza virus analysis

Parameter | Prior | Description
λ*_1 | Normal (Mean = 1.26, SD = 0.58) | Log-scale birth rate at present
µ*_k | Normal (Mean = 3.02, SD = 0.41) | Log-scale death rate for all epochs
ψ*_1 | Normal (Mean = 1.27, SD = 0.58) | Log-scale sampling rate at present
t_or | Normal (Mean = 1.89, SD = 15.0) | Age of phylogeny
α | Fixed to 0.25 | Exponent of the MRF
ϕ | Gamma (Shape = 1.0, Scale = 1.0) | Transformed global scale of the MRF
ν_k | Exponentially tilted stable distribution | Local scale of the Bayesian bridge MRF
ξ | Fixed to 2.0 | Slab width of the Bayesian bridge MRF

Figure A: HIV virus: Median (solid line) and 95% credible intervals indicated by the shaded areas of the (a) birth rate, (b) death rate, and (c) sampling rate estimates through time.

Figure B: Influenza virus: Median (solid line) and 95% credible intervals indicated by the shaded areas of the (a) birth rate, (b) death rate, and (c) sampling rate estimates through time.

Figure C: Ebola virus: Median (solid line) and 95% credible intervals indicated by the shaded areas of the (a) birth rate, (b) death rate, and (c) sampling rate estimates through time.

Figure D: Speed of implementations for likelihood calculations with an increasing number of sequences (left plot) or number of epochs (right plot) under the EBDS model. Note that time and the number of sequences/epochs are laid out on a base-2 logarithmic scale.

Figure E: Speed of implementations for gradient calculations with an increasing number of sequences (left plot) or number of epochs (right plot) under the EBDS model.

<newBirthDeathSerialSampling id="bdss" units="years"
    hasFinalSample="false" conditionOnSurvival="false">
  <birthRate>
    <parameter idref="bdss.birthRate"/>
  </birthRate>
  <deathRate gradientFlag="false">
    <parameter idref="bdss.deathRate"/>
  </deathRate>
  <samplingRate>
    <parameter idref="bdss.samplingRate"/>
  </samplingRate>
  <samplingProbability gradientFlag="false">
    <compoundParameter id="bdss.samplingProbability">
      <parameter id="samplingAtPresent" value="0" dimension="1"
                 lower="0.0" upper="1.0"/>
      <parameter id="otherSampling" value="0" dimension="DIM - 1"
                 lower="0.0" upper="1.0"/>
    </compoundParameter>
  </samplingProbability>
  <treatmentProbability gradientFlag="false">
    <parameter id="bdss.treatmentProbability" value="1.0" dimension="DIM"
               lower="0.0" upper="1.0"/>
  </treatmentProbability>
  <origin>
    <parameter id="bdss.origin" value="ORIGIN TIME" lower="0.0"/>
  </origin>
  <cutOff>
    <parameter value="CUT OFF TIME"/>
  </cutOff>
  <numGridPoints>
    <parameter value="DIM"/>
  </numGridPoints>
</newBirthDeathSerialSampling>

The option conditionOnSurvival controls whether the model conditions on the survival of at least one individual at the present time. Note that if we remove the compound parameter declarations, we need to set the initial parameter values for birthRate, deathRate, and samplingRate in this EBDS model block. Users can refer to the XML files for the HIV examples in our provided GitHub repository to make the corresponding changes. For our analyses, we assume no intensive-sampling events at the epoch switching times, so we set the samplingProbability at these times to 0. Users can modify this parameter to a value between 0 and 1 to incorporate intensive-sampling events. We need to input a starting value for the parameter origin that matches the length of the starting phylogeny or the fixed tree. The parameter cutOff governs the end point of our last epoch, and numGridPoints specifies the number of epochs. The current setup of this XML block assumes equidistant epoch switching times. Users can modify the grid points by adding additional parameter grids in this EBDS model block, as illustrated in the HIV example XML in our repository.

Algorithm: recursive computation of the EBDS likelihood

for k = 0, ..., K - 1 do
    for j = 0, ..., m_{k+1} - 1 do
        if s_{j+1} is a serial sampling event then
            Calculate p_{k+1}(s_{j+1}) via Equation (21)
        end
        if j < m_{k+1} - 1 then
            Calculate I_k(E_j) via Equation (2)
        end
    end
    Calculate and store p_{k+1}(t_{k+1}) via Equation (21)
end
/* Likelihood */
Calculate P[T | λ, µ, ψ, ρ, r, t] via Equation (1)

Gradient Derivation

For ∂ log P_k(j) / ∂θ_k
