STQS: Interpretable multi-modal Spatial-Temporal-seQuential model for automatic Sleep scoring

Sleep scoring is an important step for the detection of sleep disorders and is usually performed by visual analysis. Since manual sleep scoring is time-consuming, machine-learning-based approaches have been proposed. Though efficient, these algorithms are black-box in nature and difficult for clinicians to interpret. In this paper, we propose a deep learning architecture for multi-modal sleep scoring, investigate the model's decision-making process, and compare the model's reasoning with the annotation guidelines in the AASM manual. Our architecture, called STQS, uses convolutional neural networks (CNN) to automatically extract spatio-temporal features from 3 modalities (EEG, EOG and EMG), a bidirectional long short-term memory (Bi-LSTM) to extract sequential information, and residual connections to combine spatio-temporal and sequential features. We evaluated our model on two large datasets, obtaining accuracies of 85% and 77% and macro F1 scores of 79% and 73% on SHHS and an in-house dataset, respectively. We further quantify the contribution of various architectural components and conclude that adding LSTM layers improves performance over a spatio-temporal CNN, while adding residual connections does not. Our interpretability results show that the output of the model is well aligned with the AASM guidelines and, therefore, that the model's decisions correspond to domain knowledge. We also compare multi-modal and single-channel models and suggest that future research should focus on improving multi-modal models.

Deep learning approaches for sleep scoring apply architectures and techniques from general-domain deep learning, e.g., by directly applying convolutional neural network (CNN)-based architectures [10], incorporating sequential information via long short-term memory (LSTM) [11] networks, or adding residual connections (RC) [11]. However, the individual contributions of these architectural components have not yet been investigated. Pioneering deep learning approaches for sleep scoring base their predictions on single-modality inputs, i.e., only EEG [11]. Subsequent work provides evidence that incorporating multiple modalities (e.g., EOG and EMG) improves performance [12]. Despite their strong performance, a general downside of deep learning models is their black-box nature, which hinders their adoption in clinical settings. Research in eXplainable Artificial Intelligence (XAI) [13] aims to make these black boxes more transparent, e.g., by using post-hoc interpretability methods to explain the outcome of such models [14].
In this work, we propose a deep learning architecture, Spatio-Temporal-seQuential-Sleep-scoring (STQS), for multi-modal, multi-channel input data. It is designed to account for spatial, temporal and sequential information in the signals, and we evaluate the contribution of its architectural components. We apply post-hoc interpretability methods to investigate the alignment of the model with scoring guidelines [2] and to evaluate the contribution of different modalities to the prediction. Specifically, our contributions are:
1. We show how to leverage spatio-temporal and sequential information from multi-modal, multi-channel input signals in a deep neural network, and evaluate the effect of adding sequential information, information transfer via residual connections and various class-imbalance techniques (results in Section 7.1).
2. We evaluate our model on a public benchmark dataset (SHHS, 5793 subjects) and an in-house dataset (1418 subjects) and compare it to multiple baselines (experiments in Section 7.1). We investigate the importance of multiple modalities using post-hoc interpretability methods (results in Section 7.2).
3. We show the model's alignment with the AASM sleep scoring guidelines [2] by applying 3 different methods of post-hoc interpretability: frequency-domain occlusion, time-domain occlusion, and pattern visualization of temporal filters in the CNN (results in Section 8).
The remainder of this paper is organized as follows. Sections 2 and 3 introduce the state-of-the-art and datasets. Section 4 explains STQS, Section 5 discusses our post-hoc interpretability approach and Section 6 reports the experimental setup. Prediction results of our model are reported in Section 7, while Section 8 investigates the model's reasoning. We discuss implications of our results in Section 9 and conclude in Section 10.

Related work
In this section, we describe manual sleep scoring in detail, review traditional and deep-learning based approaches to automatic sleep scoring and review explainable AI methods.

PSGs and manual sleep scoring
A PSG is a sleep study in which signals such as EEG, EOG, EMG, electrocardiograms and leg movements are recorded from a patient. Sleep is divided into 5 stages: Wake (W), Rapid Eye Movement (REM), Non-REM stage 1 (N1), Non-REM stage 2 (N2) and Non-REM stage 3 (N3). The analysis of sleep stages is crucial for the detection of sleep disorders, e.g., periodic leg movement syndrome. The signals are recorded over a whole night of sleep (about 8 h) and divided into 30 s epochs. Each epoch is annotated with a sleep stage according to the American Academy of Sleep Medicine (AASM) [2] or the Rechtschaffen and Kales (R&K) [3] sleep manual. This annotation process is called sleep scoring and is usually based on EEG, EOG and EMG signals only. The sleep stages are annotated based on distinctive characteristics of the signals (cf. Table 1 for the characteristics defined in the AASM manual). Per stage, the EEG signals vary in amplitude and frequency and exhibit distinctive patterns, e.g., K-complexes or vertex waves, whereas the EOG and EMG signals mainly vary in amplitude. Currently, sleep stages are annotated by visual inspection by sleep technologists, and annotating one PSG takes about 2-3 h.

Machine learning for sleep scoring
Traditional machine learning approaches rely on expert-defined features capturing temporal, frequency and non-linear properties of the data. Li et al. [5] combined random forests with rules derived from the R&K sleep manual, achieving an accuracy of 0.86 and a Cohen's kappa (κ) of 0.805 on a dataset of 198 subjects from the Cleveland Sleep Study [24,25]. Koley et al. [6] and Lajnef et al. [8] used similar expert-defined features to train support vector machines: [6] achieved a κ of 0.86 on 28 subjects from the Center of Sleep Disorder Diagnosis, India, and [8] achieved an accuracy of 0.88 on 15 subjects from DyCog Lab, France. Hassan et al. [26] used bootstrap aggregation to classify sleep stages based on statistical moments extracted by the tunable-Q wavelet transform, achieving an accuracy of 0.937 on the Sleep-EDF13 dataset. Alickovic et al. [27] used the discrete wavelet transform to extract features from the EEG channel and then trained an ensemble classifier called rotational support vector machine, achieving an accuracy of 0.91 on the Sleep-EDF dataset. Experiments by Khalighi et al. [28] indicate that the best sleep scoring performance (accuracy of 0.92) is obtained using 9 multi-modal EEG, EOG and EMG channels as input.
Using expert-defined features requires domain expertise. Additionally, the above approaches were evaluated on small datasets only. Malafeev et al. [9] compared traditional machine learning (a combination of random forests and hidden Markov models) to deep neural networks (a combination of convolutional neural networks and long short-term memory networks) and concluded that deep neural networks generalize better. Thus, we focus on deep learning approaches for sleep scoring.

Deep learning for sleep scoring
Deep learning approaches can be distinguished by the input modalities they use. In the following, we describe approaches using a single modality (usually EEG) and multi-modal approaches (usually EEG, EOG and EMG). Table 2 provides a comprehensive overview of the approaches, grouped by the modalities used.
Vilamala et al. [10] used spectrogram images of one EEG channel as input to a pre-trained VGGNet. Supratak et al. [11] developed a single-channel EEG model using a CNN, a bidirectional LSTM (Bi-LSTM) and residual connections. Their CNN uses two different filter sizes to capture both frequency information and temporal patterns. Mousavi et al. [19] also used a single EEG channel as input and a similar architecture, with an attention mechanism that learns which parts of the sequence to focus on and a novel loss function to address class imbalance. Biswal et al. [20] used raw and spectrogram inputs of multiple EEG channels in an architecture with a CNN, a Recurrent Neural Network (RNN) and residual connections.
Phan et al. [23] used time-frequency images from each modality (EEG, EOG, EMG) as input to a multi-task CNN model. Their experiments showed an increase of 4.1% in accuracy when combining EEG and EOG and an additional increase of 1% when adding EMG. Paisarnsrisomsuk et al. [21] used both EEG and EOG and found that EOG increased the accuracy by 1%. A similar observation was made by Yildirim et al. [22]. Chambon et al. [12] also observed increasing performance when adding modalities. They applied temporal and spatial filtering via a CNN to extract features from multi-modal inputs (EEG, EOG and EMG) and also encoded temporal context from neighbouring epochs into the CNN input. Their extensive experiments showed that using additional modalities (EOG and EMG) together with ≈6 EEG channels improved performance over using EEG channels only, whereas further increasing the number of input EEG channels did not improve performance. Further sleep scoring approaches include an architecture based on depthwise separable convolutional layers [15], achieving an accuracy of 0.85, and a model addressing database variability [29]. The latter paper reported the model's cross-dataset performance on a private and several public sleep datasets; the model achieved a kappa of 0.78 on the SHHS visit 2 dataset. While spiking neural networks (SNNs), i.e., third-generation neural networks, have been shown to learn predictions from EEG signals [30], they have not yet been applied to sleep stage detection. We instead focus on second-generation neural networks to make use of the large body of work on explainable AI for these types of networks. Since previous work found that multi-modal approaches outperform approaches based on a single input modality, we focus our work on multi-modal approaches. However, most previous multi-modal approaches rely on vanilla CNN architectures.
Therefore, we investigate architectural components proposed for single-modality networks in a multi-modal setting. More specifically, we investigate (i) spatio-temporal feature extraction with CNNs [12], (ii) Bi-LSTMs to encode sequential information [11] and (iii) residual connections to explicitly forward information from earlier to later layers in the network [11].

Explainable AI (XAI)
The purpose of XAI in our work is to justify the decisions of black-box models by comparing them to existing sleep scoring guidelines [2], such that they can be trusted by clinicians and adopted in clinical workflows [31]. In this paper, we use the term interpretability instead of explainability for XAI methods [31]. XAI methods can be divided into methods that create intrinsically interpretable models and post-hoc interpretability methods [31]. Examples of the former are Bayesian rule lists [32] and generalized additive models [33]; however, such models usually provide less accurate predictions. Post-hoc interpretability methods aim to explain black-box models such as deep neural networks; examples are sensitivity analysis [34], layer-wise relevance propagation (LRP) [35] and occlusion techniques [14]. These methods determine the importance of specific input features for a prediction. Occlusion techniques observe the sensitivity of the output to perturbations of some input features, whereas LRP calculates relevance scores of the input features through backward propagation. Vilamala et al. [10] used sensitivity maps to highlight the input features deemed important by their model for predicting a sleep stage. In this work, we apply occlusion and LRP to explain our deep learning models.

Datasets
From the Sleep Heart Health Study (SHHS) [36,37], we use the data from the first visit (SHHS-1). SHHS-1 contains 5793 PSG records (subjects aged ≥ 40). The PSGs consist of signals from 2 EEG sensors (C3-A2 and C4-A1) sampled at 125 Hz, 2 EOG sensors (left and right) sampled at 50 Hz and 1 EMG sensor sampled at 125 Hz. Sleep stages are annotated according to the R&K manual [3]: W, N1, N2, N3, N4, REM, Movement and Unscored. We unified the annotations of this dataset to comply with the AASM annotations [2] by combining N3 and N4 into a single stage N3 and removing the epochs annotated as Movement or Unscored.
The MST dataset was collected at Medisch Spectrum Twente, […] [2].
The datasets were pre-processed as follows. All EEG, EOG and EMG channels of both datasets were low-pass filtered at 30 Hz [12]; EEG and EOG channels were high-pass filtered at 0.16 Hz and EMG at 10 Hz [9]. All channels of both datasets were then resampled to 125 Hz. In SHHS, EEG and EMG were already recorded at 125 Hz; therefore, only the EOG channels needed to be upsampled from 50 Hz to 125 Hz using interpolation. In MST, all channels were downsampled from 250 Hz to 125 Hz using decimation. Resampling and filtering are common preprocessing steps in the sleep scoring literature [8,9,11,12]. Each PSG was divided into 30 s epochs (3750 samples at 125 Hz); trailing samples were cropped so that the signal length is a multiple of 30 s. Each channel in each epoch was standardized to mean 0 and standard deviation 1.
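As a rough illustration of the last preprocessing steps, the sketch below crops a channel to whole 30 s epochs, splits it into 3750-sample epochs and z-scores each epoch. It assumes a single channel that has already been filtered and resampled to 125 Hz and is stored as a plain Python list; the function names are hypothetical, and the use of the population standard deviation is our assumption (the text does not say which estimator was used).

```python
import statistics

FS = 125                     # sampling rate after resampling (Hz)
EPOCH_SEC = 30               # epoch length (s)
EPOCH_LEN = FS * EPOCH_SEC   # 3750 samples per epoch

def split_into_epochs(signal, epoch_len=EPOCH_LEN):
    """Crop trailing samples so the length is a multiple of epoch_len, then split."""
    n_epochs = len(signal) // epoch_len
    cropped = signal[: n_epochs * epoch_len]
    return [cropped[i * epoch_len:(i + 1) * epoch_len] for i in range(n_epochs)]

def standardize(epoch):
    """Scale one epoch of one channel to mean 0 and standard deviation 1."""
    mu = statistics.fmean(epoch)
    sigma = statistics.pstdev(epoch)      # population std (assumption)
    if sigma == 0:                        # flat segment: only remove the mean
        return [x - mu for x in epoch]
    return [(x - mu) / sigma for x in epoch]

def preprocess_channel(signal):
    """Epoch a filtered, resampled channel and standardize every epoch."""
    return [standardize(e) for e in split_into_epochs(signal)]
```

Standardizing per epoch and per channel, as described above, removes amplitude offsets between recordings but also discards absolute amplitude levels within a night.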
For training and testing our models, we randomly assign 81% of PSGs for training, 9% for validation and 10% for testing in both the SHHS and MST dataset.
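A minimal sketch of the PSG-level split, assuming PSG identifiers in a list and a fixed random seed (both hypothetical). Splitting whole recordings, rather than individual epochs, keeps all epochs of one night in the same partition.

```python
import random

def split_psgs(psg_ids, seed=0, train=0.81, val=0.09):
    """Randomly assign whole PSGs (not epochs) to train/validation/test."""
    ids = list(psg_ids)
    random.Random(seed).shuffle(ids)          # reproducible shuffle
    n_train = round(train * len(ids))
    n_val = round(val * len(ids))
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]
```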

Approach for sleep stage prediction
We begin this section with a formal definition of the automatic sleep scoring task, and then describe our model architecture in detail.

Problem definition
Let $X = (X_1, \dots, X_t, \dots, X_T)$ be a multi-modal PSG signal with corresponding sleep stages $Y = (y_1, \dots, y_t, \dots, y_T)$. $T$ is the total number of epochs and $X_t$ is the $t$-th 30 s epoch, annotated with sleep stage $y_t$. $X_t$ contains $C_m$ channels from each modality $m \in \{\text{EEG}, \text{EOG}, \text{EMG}\}$. The task is a multi-class classification problem: given $X_t$, predict the sleep stage $y_t \in S = \{W, N1, N2, N3, REM\}$. Thus, we aim to learn a prediction function $f$ that returns a probability distribution over the classes, such that $\hat{y}_t = \arg\max_{s \in S} f(X_t)_s$. The true class $y_t$ is represented as a one-hot vector of length $|S|$, with entry 1 for the true class and 0 otherwise. Our model minimizes the categorical cross-entropy loss between $f(X)$ and $Y$, defined as

$$\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \sum_{s \in S} y_{t,s} \log f(X_t)_s,$$

where $y_{t,s}$ and $f(X_t)_s$ denote the entries of $y_t$ and $f(X_t)$ for class $s$.
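To make the notation concrete, here is a small plain-Python sketch (hypothetical helper names) of a softmax output, the argmax prediction and the categorical cross-entropy, assuming the loss is averaged over the $T$ epochs. It only mirrors the definitions; it is not the STQS training code, which would use a framework loss.

```python
import math

STAGES = ["W", "N1", "N2", "N3", "REM"]

def softmax(logits):
    """Turn raw scores into a probability distribution over the stages."""
    m = max(logits)                       # shift by the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_stage(probs):
    """argmax over the class probabilities, returned as a stage label."""
    return STAGES[max(range(len(probs)), key=probs.__getitem__)]

def cross_entropy(prob_rows, true_stages):
    """Mean categorical cross-entropy over T epochs. With a one-hot y_t, only
    the true-class term -log f(X_t)_s survives in the inner sum."""
    total = 0.0
    for probs, stage in zip(prob_rows, true_stages):
        total -= math.log(probs[STAGES.index(stage)])
    return total / len(true_stages)
```

For instance, a uniform prediction over the 5 stages costs $\log 5 \approx 1.61$ per epoch, whatever the true stage.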

STQS architecture
For processing multi-modal PSG data, we first apply spatial and temporal filters on all modalities separately, and then combine all modalities. A sequential learning component is added to incorporate longer temporal contexts, i.e., the previous and subsequent epoch. The architecture of STQS is shown in Fig. 1 and described in more detail in the remainder of this section.

Input layer
Our model has a separate feature extraction pipeline for each input modality m. For each modality, a 30 s epoch has shape $1 \times C_m \times t$, where $t = 3750$ is the number of samples per epoch (30 s at 125 Hz).

CNNs for spatial filtering
Spatial filtering transforms the $C_m$ raw channels into $C_m$ spatially combined representations of all channels. Following [12], the spatial filtering component consists of a convolutional block with $C_m$ 2D convolutional filters, each of shape $C_m \times 1$, which learns the spatial relationship among the channels of a modality. This is followed by batch normalization [38] and a rectified linear unit activation (ReLU(x) = max(0, x)). Single-channel modalities (EMG in Fig. 1) are passed directly to the temporal filtering stage. After spatial filtering, the feature map of each modality has shape $C_m \times 1 \times t$, containing $C_m$ spatial representations, and is reshaped into $1 \times C_m \times t$ before being passed to the temporal filtering component.
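The spatial filtering step can be sketched in plain Python, assuming unpadded convolution: because each $C_m \times 1$ kernel covers all channels at a single time step, every output row is a learned weighted sum of the input channels plus a bias, computed sample by sample. The weights below are hypothetical and batch normalization is omitted. In PyTorch terms this would correspond to something like `Conv2d(1, C_m, kernel_size=(C_m, 1))`, which is our assumption about the implementation.

```python
def spatial_filter(x, weights, bias):
    """x: C_m channels x t samples; weights: C_m filters x C_m channel weights.
    Returns C_m x t (the C_m x 1 x t feature map without its singleton axis)."""
    n_samples = len(x[0])
    out = []
    for w_k, b_k in zip(weights, bias):
        # one filter = one weighted combination of all channels at each sample
        out.append([b_k + sum(w_kc * x_c[tau] for w_kc, x_c in zip(w_k, x))
                    for tau in range(n_samples)])
    return out

def relu(rows):
    """Elementwise max(0, x) over a 2D list."""
    return [[max(0.0, v) for v in row] for row in rows]
```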

CNNs for temporal filtering
The feature map from spatial filtering is further processed by two consecutive convolution-max pooling blocks to extract temporal features. The convolutional filters are of size $1 \times 64$ (i.e., ≈0.5 s at 125 Hz). Each convolutional block has 8 filters and is followed by max pooling with stride $1 \times 16$, which reduces the width and retains only the most salient features. The filter size was chosen such that it can capture patterns like the K-complex, which lasts at least 0.5 s. The resulting feature map has size $8 \times H_m \times W_m$ for each modality, with $H_m$ and $W_m$ being the height and width of the output of the last temporal filtering layer, respectively. This feature map is flattened for each modality, and dropout [39] with probability 0.5 is applied to prevent overfitting. The resulting feature vectors are concatenated into a single vector of length $\sum_m 8 \cdot H_m \cdot W_m$.
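The width $W_m$ is not stated explicitly. Assuming unpadded ("valid") convolutions with stride 1 and pooling windows of width 16 with stride 16 (our assumptions), the output sizes and the concatenated length can be computed as follows; other padding choices change the numbers.

```python
def conv_len(n, kernel):
    """Output length of an unpadded, stride-1 convolution."""
    return n - kernel + 1

def pool_len(n, kernel, stride):
    """Output length of unpadded max pooling."""
    return (n - kernel) // stride + 1

def temporal_output_width(t=3750, conv_k=64, pool_k=16, pool_s=16, blocks=2):
    """W_m after the convolution + max-pooling blocks."""
    w = t
    for _ in range(blocks):
        w = pool_len(conv_len(w, conv_k), pool_k, pool_s)
    return w

def concat_length(channels_per_modality, n_filters=8, t=3750):
    """Sum over modalities of 8 * H_m * W_m, with H_m = C_m because the
    1 x 64 kernels never mix the C_m rows."""
    w = temporal_output_width(t)
    return sum(n_filters * c_m * w for c_m in channels_per_modality.values())
```

Under these assumptions, 3750 samples shrink to 3687, 230, 167 and finally 10 positions, so SHHS (2 EEG, 2 EOG, 1 EMG channel) yields a vector of 8 · 5 · 10 = 400 features.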

Bi-LSTMs for sequential learning
Sleep technologists use information from previous and subsequent epochs to annotate an epoch. An example of such an annotation rule is "Epochs without k-complex or spindles following an N2 stage, will be annotated with N2 if the previous epoch contained a k-complex without arousal or a sleep spindle" [2]. To capture such rules, we add a sequential learning component. We use a bidirectional LSTM (Bi-LSTM) [40] layer with 20 hidden units per direction to learn sequential information in both the forward and backward direction of a sequence of eight 30 s epochs. Let $LSTM_f$ (shape 8 × 20, i.e., seq_len × features) and $LSTM_b$ (shape 8 × 20) be the feature vectors of the 8 epochs in the sequence in the forward and backward direction, respectively. The output of the Bi-LSTM is the side-by-side concatenation of $LSTM_f$ and $LSTM_b$ (cf. Fig. 2), a feature vector of shape 8 × 40. The hidden and cell states of the Bi-LSTM layer are initialized with the hidden and cell state values from the last element of the previous sequence of the same subject. This initialization, depicted in Fig. 2 with the dashed arrow, incorporates the global state of the signal into the prediction for each sequence. For the first sequence of a new subject, the hidden and cell states are initialized to zero.
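The state carry-over can be sketched as follows. The sketch shows one direction only, and `step` is a hypothetical stand-in for an LSTM cell (any recurrence works for the illustration); with PyTorch's `nn.LSTM`, the analogous move would be to feed the returned `(h_n, c_n)` back as the initial state of the next call for the same subject. The second helper shows the side-by-side concatenation of the two directions.

```python
def run_with_carried_state(sequences, subject_ids, step, zero_state):
    """Process sequences in order, carrying the recurrent state across the
    sequences of one subject and resetting it for a new subject.
    step(state, epoch) -> (new_state, output) is a hypothetical cell."""
    outputs = []
    state, prev_subject = zero_state, None
    for seq, subject in zip(sequences, subject_ids):
        if subject != prev_subject:
            state = zero_state              # first sequence of a new subject
        seq_out = []
        for epoch in seq:
            state, out = step(state, epoch)
            seq_out.append(out)
        outputs.append(seq_out)             # `state` now holds the last element's state
        prev_subject = subject
    return outputs

def concat_directions(forward_out, backward_out):
    """Side-by-side concatenation per epoch, e.g. 8 x 20 and 8 x 20 -> 8 x 40."""
    return [f + b for f, b in zip(forward_out, backward_out)]
```

With a toy running-sum cell, the second sequence of a subject continues from where the first ended, while the first sequence of the next subject starts again from zero.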

Residual connections
We use a residual connection block (RC) to add spatio-temporal and sequential features to the final prediction layer. The motivation for using RC is to improve predictions for stages for which temporal features are more important than sequential features. RCs add the multi-modal CNN features (spatio-temporal) to the Bi-LSTM (sequential) features elementwise. To add the feature vectors, the dimension of the CNN feature vector is reduced to the dimension of the Bi-LSTM feature vector using a fully connected layer of dimension 40, followed by batch normalization and ReLU activation.
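A minimal sketch of the residual combination, with hypothetical weights and a small width instead of 40; batch normalization is left out for brevity.

```python
def linear(x, weights, bias):
    """Fully connected layer: one output per weight row."""
    return [b + sum(w * v for w, v in zip(row, x)) for row, b in zip(weights, bias)]

def residual_combine(cnn_features, lstm_features, weights, bias):
    """Project CNN features to the Bi-LSTM width (FC + ReLU, batch norm omitted),
    then add the two feature vectors elementwise."""
    projected = [max(0.0, v) for v in linear(cnn_features, weights, bias)]
    if len(projected) != len(lstm_features):
        raise ValueError("projection width must equal the Bi-LSTM feature width")
    return [p + q for p, q in zip(projected, lstm_features)]
```

The projection is what makes the elementwise addition possible: the concatenated CNN vector is much longer than the 40 Bi-LSTM features, so it must be mapped to the same width first.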

Final prediction
A fully connected layer with a softmax activation function outputs the final prediction. Since there are 5 sleep stages, the output is of size 5.

Addressing class imbalance
The class imbalance in the datasets (cf. Table 3) is likely to decrease prediction performance for the minority classes [41], i.e., stages N1, N3 and REM. We use three class imbalance techniques. Two of them are weighted cost functions,

$$w_1(s) = 1 - \frac{N_s}{N}, \qquad w_2(s) = \frac{1}{N_s} \cdot \frac{N}{|S|},$$

where $w(s)$ is the weight of class $s \in S = \{W, N1, N2, N3, REM\}$, $N_s$ is the number of instances in $s$ and $N$ is the total number of instances. The weighted cost functions result in a higher error for misclassifications of rare classes. The third is oversampling: we randomly duplicate instances of all but the majority class until the dataset is balanced. We apply class imbalance techniques only to the CNN component of our model, so as not to lose the sequential arrangement of the epochs in a PSG.
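Both weightings and the oversampling step can be computed directly from the label counts. The sketch below is illustrative (hypothetical names); it takes $|S|$ as the number of classes present in the labels and draws duplicates uniformly at random with replacement, which is our assumption about how instances were chosen.

```python
import random
from collections import Counter

def class_weights(labels):
    """Return (w1, w2) as dicts class -> weight."""
    counts = Counter(labels)
    n, n_classes = len(labels), len(counts)
    w1 = {s: 1 - n_s / n for s, n_s in counts.items()}
    w2 = {s: n / (n_classes * n_s) for s, n_s in counts.items()}
    return w1, w2

def oversample(items, labels, seed=0):
    """Randomly duplicate instances of every non-majority class until each
    class has as many instances as the majority class."""
    rng = random.Random(seed)
    by_class = {}
    for item, label in zip(items, labels):
        by_class.setdefault(label, []).append(item)
    target = max(len(v) for v in by_class.values())
    out = []
    for label, members in by_class.items():
        extra = [rng.choice(members) for _ in range(target - len(members))]
        out.extend((m, label) for m in members + extra)
    rng.shuffle(out)                      # avoid class-sorted batches
    return out
```

Note that $w_2$ equals 1 for every class when the data is already balanced, while $w_1$ then equals $1 - 1/|S|$ for every class; both only matter through their relative sizes.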

Approach for post-hoc interpretability
We applied post-hoc interpretability methods to understand how the black-box models arrive at their predictions. Specifically, we are interested in the following 4 questions: (i) To what extent do the different modalities (EEG, EOG, EMG) contribute to the prediction? (ii) What are the prediction-relevant patterns in the EEG signal? (iii) Which frequency bands of an EEG signal are most important? (iv) How do different temporal filters in the CNN contribute to the final decision, and are there stage-specific filters? Additionally, we are interested in whether our findings align with the AASM guidelines (cf. Table 1).

Modality importance
We occlude different combinations of modalities (modality occlusion) by setting the amplitudes of all channels of those modalities to 0 while keeping the other modalities unchanged [31]. Each occluded test epoch is passed to the trained model for prediction. The influence of the modalities is analyzed by comparing the results with occlusion to the results without occlusion on the same epochs.
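A sketch of modality occlusion on one epoch stored as a dictionary of channel lists (a hypothetical data layout). `predict` stands for the trained model; the change rate computed here is a simpler summary than the confusion matrices analysed in Section 7.2.

```python
from itertools import combinations

MODALITIES = ("EEG", "EOG", "EMG")

def occlude(epoch, modalities_to_occlude):
    """epoch: dict modality -> list of channels (each a list of samples).
    Occluded modalities have all channel amplitudes set to 0; the others
    are returned unchanged."""
    return {m: [[0.0] * len(ch) for ch in chans] if m in modalities_to_occlude else chans
            for m, chans in epoch.items()}

def occlusion_settings(modalities=MODALITIES):
    """All non-empty combinations of modalities to occlude (7 for 3 modalities)."""
    return [set(c) for r in range(1, len(modalities) + 1)
            for c in combinations(modalities, r)]

def changed_fraction(predict, epochs, occluded_modalities):
    """Fraction of epochs whose prediction changes under occlusion."""
    changed = sum(predict(occlude(e, occluded_modalities)) != predict(e) for e in epochs)
    return changed / len(epochs)
```

"Keeping only EEG" corresponds to occluding the set {EOG, EMG}, and the last setting (all three modalities) reproduces the fully occluded input.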

Predictive patterns in EEG
To identify the most important EEG patterns, we occlude parts of the 30 s EEG epoch in the time domain (time-domain occlusion). We shift a 5 s occlusion window in steps of 1 s along the 30 s epoch and set the signal within the window to 0. As each epoch was standardized to mean 0, this choice represents an inactive signal at the epoch mean. For each location of the occlusion window, we record the model's prediction for the occluded epoch and compare it with the prediction for the original epoch. The motivation is that the prediction for an epoch is likely to change if the occlusion window "hides" a pattern that is important for the prediction.
The idea is sketched in Fig. 3.
An occlusion window $w_k$, $k \in \{1, \dots, 26\}$, is shifted over the epoch. To consider all consecutive 5 s segments, we use overlapping occlusion windows. However, any two consecutive occlusion windows (e.g., $w_1$, $w_2$) then overlap by 4 s, which makes it hard to uniquely quantify the contribution of a single occlusion window to the prediction. Therefore, we divide each epoch into non-overlapping 1 s patterns $x_p$ and calculate the pattern-based importance

$$I(x_p) = \frac{N^{pc}_p}{N_p},$$

where $N_p$ is the number of times $x_p$ was occluded while shifting the window and $N^{pc}_p$ is the number of times the prediction $\hat{y}^o_k$ changed from the prediction $\hat{y}$ when occluding $x_p$. Occluding only one channel might not change the prediction, even though the pattern is important for a sleep stage. Hence, we apply time-domain occlusion to both EEG channels in the SHHS dataset.
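The sliding-window procedure can be sketched as follows, assuming the importance of a 1 s pattern is the fraction of windows covering it whose occlusion changed the prediction. `predict` is a hypothetical wrapper around the trained model, and all channels passed in (here, both EEG channels) are zeroed inside each window.

```python
def occlusion_windows(epoch_sec=30, win_sec=5, shift_sec=1):
    """(start, end) seconds of the overlapping windows (26 for 30 s / 5 s / 1 s)."""
    return [(s, s + win_sec) for s in range(0, epoch_sec - win_sec + 1, shift_sec)]

def pattern_importance(predict, epoch, fs=125, epoch_sec=30, win_sec=5, shift_sec=1):
    """Per-second importance N^pc_p / N_p of the 1 s patterns x_p.
    epoch: list of channels; every channel is zeroed inside the window."""
    y_hat = predict(epoch)
    n_occluded = [0] * epoch_sec
    n_changed = [0] * epoch_sec
    for start, end in occlusion_windows(epoch_sec, win_sec, shift_sec):
        occluded = [ch[:start * fs] + [0.0] * ((end - start) * fs) + ch[end * fs:]
                    for ch in epoch]
        changed = predict(occluded) != y_hat
        for p in range(start, end):          # every 1 s pattern under this window
            n_occluded[p] += 1
            n_changed[p] += changed
    return [c / n for c, n in zip(n_changed, n_occluded)]
```

With a toy model that only reacts to a burst in second 10, the importance is 1.0 for that second and falls off linearly for neighbouring seconds that share only some of their windows with it.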

Predictive frequency bands in EEG
According to the AASM guidelines, some frequency bands of EEG signals are highly indicative of certain stages (cf. Table 1). To investigate whether our model corresponds to this domain knowledge, we occlude the EEG signal in the frequency domain (frequency-domain occlusion) [31] and investigate whether the prediction changes. Removing the most prominent frequency bands of a stage should result in maximum misclassification for that stage, and keeping only those frequency bands should result in maximum misclassification for the other stages. More concretely, we occluded the δ (delta), θ (theta), α (alpha), σ (sigma) and β (beta) frequency bands (cf. Table 1) one at a time. To occlude a frequency band, we either (i) kept only the frequencies in the band and removed the rest (band-pass filter), or (ii) removed all frequencies in the band (band-stop filter). We performed these experiments on both EEG channels in SHHS.
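The text does not specify how the band-pass and band-stop filters were designed. The sketch below uses an idealized frequency-domain mask (zeroing DFT bins) as a stand-in, with a naive O(n²) DFT so that it runs without dependencies; for real 3750-sample epochs an FFT or a proper digital filter would be preferable. The band edges in `BANDS` are illustrative assumptions, not necessarily the values in Table 1.

```python
import cmath
import math

# Illustrative band edges in Hz (assumption; Table 1 defines the bands used).
BANDS = {"delta": (0.5, 4.0), "theta": (4.0, 8.0), "alpha": (8.0, 12.0),
         "sigma": (12.0, 16.0), "beta": (16.0, 30.0)}

def dft(x):
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
            for k in range(n)]

def idft(spectrum):
    n = len(spectrum)
    return [(sum(spectrum[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n).real
            for t in range(n)]

def bin_frequency(k, n, fs):
    """Absolute frequency (Hz) of DFT bin k; bins above n/2 are negative frequencies."""
    return abs(k if k <= n // 2 else k - n) * fs / n

def occlude_band(x, fs, low, high, mode):
    """mode='bandpass' keeps only [low, high) Hz; mode='bandstop' removes it.
    Positive and negative frequencies are masked together, so the result is real."""
    if mode not in ("bandpass", "bandstop"):
        raise ValueError("mode must be 'bandpass' or 'bandstop'")
    n = len(x)
    masked = []
    for k, coeff in enumerate(dft(x)):
        inside = low <= bin_frequency(k, n, fs) < high
        keep = inside if mode == "bandpass" else not inside
        masked.append(coeff if keep else 0)
    return idft(masked)
```

For a mixture of a 2 Hz and a 10 Hz sine, band-stopping alpha leaves only the 2 Hz component, and band-passing alpha leaves only the 10 Hz component.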

Filter importance and visualization
We want to know how different temporal filters contribute to the final decision and whether specific filters are learnt for certain sleep stages. To this end, we applied LRP [35], which consists of a forward pass and a specific backward relevance calculation, to compute an importance score per filter and test instance. We focused on the filters in the first convolutional layer of temporal filtering (i.e., the 5th layer in the model), because feature complexity increases in deeper layers [42] and the patterns in the AASM guidelines are also basic features. Filter relevance is averaged per sleep stage, resulting in an importance score per filter and stage. We randomly selected 20 patients from SHHS 4 to calculate the importance scores. To identify the frequency patterns learned by a specific filter, we performed a power spectrum analysis on the filter activations and compared the dominant frequency components of both the raw data and the filter activations. We also generated white noise with a uniform distribution over all effective frequencies of the input signals to test how the filters react to all frequency bands. Because EEG channels contain important frequency information, while EOG and EMG are mainly distinguished by amplitude (cf. Table 1), we analysed the 8 EEG filters on SHHS.
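Given per-instance LRP relevance scores for the filters (computing LRP itself is out of scope here), the per-stage averaging reduces to the following sketch; the data layout and function names are hypothetical.

```python
def stage_filter_importance(relevances, stages):
    """relevances: one list of per-filter relevance scores per test instance;
    stages: the sleep stage of each instance.
    Returns stage -> mean relevance per filter."""
    totals, counts = {}, {}
    for scores, stage in zip(relevances, stages):
        acc = totals.setdefault(stage, [0.0] * len(scores))
        for i, r in enumerate(scores):
            acc[i] += r
        counts[stage] = counts.get(stage, 0) + 1
    return {s: [v / counts[s] for v in acc] for s, acc in totals.items()}

def top_filter(importance):
    """Index of the most relevant filter per stage (a stage-specific candidate)."""
    return {s: max(range(len(v)), key=v.__getitem__) for s, v in importance.items()}
```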

Experimental setup
In this section, we describe the model variants, the training process and evaluation metrics.

STQS architectures and baseline
We tested various combinations of the architectural components outlined in Section 4.2. The ST model performs spatio-temporal filtering and combines the modalities for the final prediction as described in Sections 4.2.3 and 4.2.6. In Fig. 1, ST corresponds to the top part (input, spatial and temporal filtering, modality concatenation), the dashed line and the prediction layer. Q denotes the sequential learning component described in Section 4.2.4, corresponding to the dashed line. RC denotes the residual connection block described in Section 4.2.5, corresponding to the dashed line in Fig. 1. The three class imbalance techniques introduced in Section 4.3 are denoted with superscripts (O for oversampling, W1 and W2 for the two weighted cost functions).
We further compared our STQS models to a baseline model, an MLP (Multi-Layer Perceptron) made up of an input layer, 3 FC blocks and an output layer (18,750, 10,000, 5000, 1000 and 5 neurons, respectively). The input is a concatenation of the features of all channels, each FC block consists of an FC layer followed by BatchNorm and ReLU, and the output layer consists of an FC layer with a softmax activation function. We evaluate the MLP only on SHHS due to the huge size of the input layer in the MST dataset. 5
We trained with the optimizer of [43], using learning rate λ = 10^−4, β1 = 0.9 and β2 = 0.999. 6 We did not use weight decay for regularization, as it resulted in worse performance than the class imbalance techniques. Weights of the batch normalization, convolutional and fully connected layers were initialized from a normal distribution 𝒩(0, 0.02). 7 For the Bi-LSTM layer, we used orthogonal weight initialization [44] and a sequence length of 8 epochs (4 min), as suggested by previous work [9]. 8 Our batch size of 192 epochs corresponds to 96 min, i.e., approximately one sleep cycle, so on average each batch contains most sleep stages. For the ST-Q variants, a batch (of 192 epochs) is reshaped into 24 × 8 before it is passed to the Bi-LSTM layer.
This means that the Bi-LSTM layer is trained with a batch of 24 sequences of length 8. Before the Bi-LSTM output is passed to the prediction layer, it is reshaped back to a batch of 192 epochs, so that all epochs in each sequence are used for prediction. The last sequence of each PSG was filled up to length 8 by copying epochs from the beginning of that PSG. Thus, both the first and the last sequence of a PSG contain the first few epochs; for evaluation, we only consider the Bi-LSTM output from the first sequence for these shared epochs. We applied the interpretability techniques to the ST O model, since ST-Q O learns sequential information from the neighbouring stages for prediction, and we cannot explain the predictions of this model solely based on stage-specific rules.
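The sequence bookkeeping can be sketched on epoch indices: split a PSG into sequences of 8, wrap the last sequence around to the beginning, map outputs back to epochs while keeping the first occurrence, and reshape a 192-epoch batch into 24 × 8. Names are hypothetical, and the sketch assumes each PSG has at least 8 epochs.

```python
def make_sequences(n_epochs, seq_len=8):
    """Split epoch indices of one PSG into consecutive sequences of seq_len."""
    idx = list(range(n_epochs))
    seqs = [idx[i:i + seq_len] for i in range(0, n_epochs, seq_len)]
    if seqs and len(seqs[-1]) < seq_len:
        # wrap around: fill the last sequence with the first epochs of the PSG
        seqs[-1] = seqs[-1] + idx[: seq_len - len(seqs[-1])]
    return seqs

def merge_outputs(seqs, outputs):
    """Map per-sequence outputs back to epochs; epochs that appear twice keep
    the output from the first sequence in which they occur."""
    per_epoch = {}
    for seq, out in zip(seqs, outputs):
        for epoch, o in zip(seq, out):
            per_epoch.setdefault(epoch, o)
    return [per_epoch[e] for e in sorted(per_epoch)]

def reshape(batch, rows, cols):
    """Row-major reshape of a flat list, e.g. 192 epochs -> 24 x 8."""
    if len(batch) != rows * cols:
        raise ValueError("batch size must equal rows * cols")
    return [batch[r * cols:(r + 1) * cols] for r in range(rows)]
```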

Training and testing
We train the spatio-temporal filters and the sequential filters successively. In Stage 1, we train the spatio-temporal filters: we shuffle the input data, additionally augment it by oversampling if that technique is used, and train only the ST part of the architecture. In Stage 2, we train the sequential parts of the model. Each input PSG is divided into non-overlapping sequences of eight 30 s epochs (cf. Fig. 2). The models are initialized with the weights learned in Stage 1 (except for the prediction layer) and all layers are trained. No class imbalance technique is used in this training stage.
We used early stopping on validation loss with a patience of 7, i.e., if the validation loss did not decrease for 7 training iterations, the training was stopped. The PyTorch implementation and trained models are available online. 9
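Early stopping with a patience of 7 can be sketched as below; `run_epoch` and `validate` are hypothetical callables for one training iteration and for computing the validation loss.

```python
def train_with_early_stopping(run_epoch, validate, patience=7, max_iters=1000):
    """Stop once the validation loss has not decreased for `patience`
    consecutive training iterations; return the best loss and its iteration."""
    best_loss, best_iter, waited = float("inf"), -1, 0
    for it in range(max_iters):
        run_epoch()
        loss = validate()
        if loss < best_loss:
            best_loss, best_iter, waited = loss, it, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best_loss, best_iter
```

Returning the iteration of the best loss makes it possible to restore the corresponding checkpoint afterwards.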

Evaluation metrics
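We calculated the overall accuracy ($a$) with a 95% confidence interval (CI), the balanced accuracy ($a_b$) to account for class imbalance, the macro-averaged F1 score ($F1_M$) and Cohen's kappa ($\kappa$). All values in Section 7 are reported as percentages for better readability. The metrics are defined as

$$a = \frac{\sum_{s \in S} TP_s}{N}, \qquad a_b = \frac{1}{|S|} \sum_{s \in S} \frac{TP_s}{N_s}, \qquad F1_M = \frac{1}{|S|} \sum_{s \in S} F1_s,$$

$$\kappa = \frac{p_o - p_e}{1 - p_e}, \qquad p_o = \frac{\sum_{s \in S} TP_s}{N}, \qquad p_e = \frac{1}{N^2} \sum_{s \in S} N^{tr}_s \, N^{pr}_s,$$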
where $TP_s$, $F1_s$ and $N_s$ are the number of true positives, the F1 score and the number of epochs of class $s$, respectively. $|S|$ denotes the number of classes and $N$ is the total number of epochs in the test set. $p_o$ is the relative agreement between ground truth and prediction, $p_e$ is the hypothetical probability of chance agreement, $N^{tr}_s$ is the number of epochs in the true class $s$ and $N^{pr}_s$ is the number of epochs in the predicted class $s$.
We also report predicted hypnograms for the two test PSGs with the highest and the lowest distance (among the test set) from their ground truth, calculated using the hypnogram distance, HD. HD assumes that confusing stages that are further apart should have more influence on the distance. Therefore, we encode the sleep stages as numbers (W as 0, N1 as 1, N2 as 2, N3 as 3, and REM as 4). The difference between N1 and N2 (in either direction) is then 1 unit, and the difference between W and REM is 4 units. The similarity of a predicted hypnogram to its ground truth is

$$sim_{PSG} = 1 - \frac{1}{4} \cdot \frac{1}{N_{PSG}} \sum_{n=1}^{N_{PSG}} |y_n - \hat{y}_n|,$$

where $y_n$ and $\hat{y}_n$ are the true and predicted sleep stage of epoch $n$, and $N_{PSG}$ denotes the total number of epochs in a PSG. The factor $\frac{1}{4}$ normalises the similarity $sim_{PSG}$ to the interval [0, 1].

4 We only used 20 patients due to limited RAM.
5 12 channels × 3750 features.
6 λ = 10^−3 decreased performance.
7 𝒩(0, 1) took longer to converge.
8 [9] did not test sequence lengths < 8. We tested a value of 5, i.e., 1 min before and after the to-be-classified epoch, as suggested by sleep technologists, but this did not improve performance.
9 Code is available on Github: https://github.com/ShreyasiPathak/STQS.
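The metrics and the hypnogram similarity can be computed from label sequences as sketched below. This assumes the standard definitions of balanced accuracy, macro F1 and Cohen's kappa, and a similarity of 1 minus the mean absolute stage-code difference divided by 4. It also assumes every stage occurs in the ground truth (otherwise the balanced-accuracy term divides by zero) and does not compute confidence intervals; all names are hypothetical.

```python
from collections import Counter

STAGE_CODE = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}
STAGES = tuple(STAGE_CODE)

def f1_score(tp, n_true, n_pred):
    # F1 = 2TP / (2TP + FP + FN) = 2TP / (N_true + N_pred)
    denom = n_true + n_pred
    return 2 * tp / denom if denom else 0.0

def sleep_metrics(y_true, y_pred, stages=STAGES):
    """Accuracy, balanced accuracy, macro F1 and Cohen's kappa."""
    n = len(y_true)
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    n_true, n_pred = Counter(y_true), Counter(y_pred)
    acc = sum(tp.values()) / n                          # also p_o for kappa
    bal_acc = sum(tp[s] / n_true[s] for s in stages) / len(stages)
    macro_f1 = sum(f1_score(tp[s], n_true[s], n_pred[s]) for s in stages) / len(stages)
    p_e = sum(n_true[s] * n_pred[s] for s in stages) / n ** 2
    kappa = (acc - p_e) / (1 - p_e)
    return acc, bal_acc, macro_f1, kappa

def hypnogram_similarity(y_true, y_pred):
    """1 - mean |code(y_n) - code(y_hat_n)| / 4: identical hypnograms give 1,
    swapping W and REM on every epoch gives 0."""
    diffs = [abs(STAGE_CODE[a] - STAGE_CODE[b]) for a, b in zip(y_true, y_pred)]
    return 1 - sum(diffs) / (4 * len(diffs))
```

For example, predicting ["W", "N2", "N2", "N3", "W"] for the ground truth ["W", "N1", "N2", "N3", "REM"] gives an accuracy of 0.6, a kappa of 0.5 and a similarity of 0.75: the REM epoch predicted as W alone costs 4 of the 5 distance units.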

Sleep stage prediction results
In this section, we first report the predictive performance of various architectural choices. Then we show the importance of the modalities for sleep stage prediction and their conformance to the AASM guidelines.

Model performance
An overall comparison of the model variants (cf. Section 6) on both datasets (cf. Section 3) can be found in Table 4. We report accuracy values with 95% CI (cf. Appendix Table 3). Overall, all sequential models (denoted with -Q-) outperform Biswal et al. [20] in terms of a and κ, while the amount of improvement depends on the class balancing technique. However, the overall performance of our best model is ≈2% lower than that of Sors et al. [18], except for the F1 scores of the W and REM stages, which are 1.5% and 3.7% higher, respectively. This may be due to the fact that Sors et al. use a single-channel model, whereas STQS is a multi-modal, multi-channel model. Therefore, we investigate the importance of the various modalities for sleep stage prediction in Section 7.2.
Comparing the ST model with its class imbalance counterparts ST O , ST W1 and ST W2 , we see that class imbalance techniques improve balanced accuracy, but not necessarily overall accuracy. This shows that class imbalance techniques improve the performance on rare classes, possibly at the expense of misclassifying more epochs in total. There is, however, no clear winner among the class imbalance techniques for the non-sequential models (ST O , ST W1 , ST W2 ): taking balanced accuracy for comparison, oversampling outperforms the weighted cost functions on SHHS, whereas the opposite is true for MST. The sequential model ST-Q (without class imbalance techniques) performs worse than its class imbalance counterparts (ST-Q O , ST-Q W1 , ST-Q W2 ), e.g., with a 1.5% lower accuracy on SHHS; among these, oversampling gives better results than the weighted cost functions. ST-Q-RC O and ST-Q O show similar performance: all performance metrics are equal or differ by at most 1%, indicating that combining information with residual connections does not improve the overall performance despite adding trainable parameters. In summary, we observe performance improvements by adding sequential information and by accounting for class imbalance, whereas adding residual connections to the sequential model does not yield any further improvement.
The confusion matrices of ST O, ST-Q O and ST-Q-RC O are shown in Fig. 5. All three models confuse N1 with N2 and N2 with N3, and ST O (Fig. 5a) additionally confuses REM with N1 and with N2. Among the three confusion matrices, ST O has the highest true positive rate (TPR) for N1 and N3, while ST-Q-RC O (Fig. 5c) has the highest TPR for REM. Although ST O has a higher TPR for N1 and N3, it also produces more false positives for these classes. That N1 and N3 are predicted more often is likely an effect of oversampling the dataset to account for rare classes.

Table 4. Performance comparison of different STQS model variants, reporting aggregated values and F1 scores per class. Best values among our models are marked in bold; best values compared with related work are marked in bold italics; "-" indicates that a value is not available in the respective publication. The model of Sors et al. [18] is shown for completeness, but its results are not directly comparable, as [18] excludes subjects from the dataset (cf. Section 9.5).
The results reported so far are per epoch. To investigate misclassifications at the PSG level, we identified the two PSGs in SHHS whose predictions are most and least similar to the ground truth (cf. Section 6.3). Fig. 4 shows the true (top) and predicted (bottom) hypnograms of these PSGs. The most similar hypnogram contains many W epochs, which leads to its high similarity score. Its N2-REM, W-N2 and N1-N2 transitions are predicted correctly, whereas not all of its N2-N3 transitions are identified. The most dissimilar hypnogram has more stage transitions, which presumably makes it harder to predict; not all of its W-N2, N2-N3 and W-REM transitions are predicted correctly.
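Selecting the most and least similar PSGs and checking which stage transitions were recovered can be sketched as below. Section 6.3 defines the similarity measure actually used; here we assume, purely for illustration, the fraction of epochs on which the predicted and true hypnograms agree, and we count a true transition as recovered only if both epochs around it are predicted correctly.

```python
from collections import Counter
from typing import Dict, Sequence, Tuple

Hypnogram = Sequence[str]

def similarity(true_hyp: Hypnogram, pred_hyp: Hypnogram) -> float:
    """Assumed similarity (not necessarily Section 6.3's): fraction of matching epochs."""
    return sum(t == p for t, p in zip(true_hyp, pred_hyp)) / len(true_hyp)

def transitions(hyp: Hypnogram) -> Counter:
    """Count stage changes between consecutive epochs, e.g. ('N2', 'N3')."""
    return Counter((a, b) for a, b in zip(hyp, hyp[1:]) if a != b)

def missed_transitions(true_hyp: Hypnogram, pred_hyp: Hypnogram) -> Counter:
    """True transitions whose two surrounding epochs are not both predicted correctly."""
    missed: Counter = Counter()
    for i in range(len(true_hyp) - 1):
        a, b = true_hyp[i], true_hyp[i + 1]
        if a != b and (pred_hyp[i], pred_hyp[i + 1]) != (a, b):
            missed[(a, b)] += 1
    return missed

def extremes(psgs: Dict[str, Tuple[Hypnogram, Hypnogram]]) -> Tuple[str, str]:
    """Return the ids of the most and the least similar PSGs."""
    scores = {pid: similarity(t, p) for pid, (t, p) in psgs.items()}
    return max(scores, key=scores.get), min(scores, key=scores.get)
```

Under such an epoch-matching measure, a recording dominated by long, easy W stretches scores high, which mirrors the observation about the most similar hypnogram.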

Modality importance
We investigated the contribution of each modality on the SHHS dataset using the confusion matrices obtained when occluding combinations of modalities (cf. Fig. 6). For instance, the contribution of EEG can be inferred either by occluding EEG or by keeping only EEG (occluding EOG and EMG). EEG is very important for the model: with EEG alone, the model achieves a high TPR for all stages except N1, which is misclassified as REM. EOG alone cannot classify the stages correctly. However, removing EOG lowers the TPR of all stages except REM. In fact, without EOG, REM achieves its highest TPR, but other stages, especially N1, are also misclassified as REM. Occluding EMG results in the lowest TPR for REM, while the TPR of the other stages remains almost unchanged. Further, EMG alone can only identify REM and W. If all channels are occluded, the model predicts W, which is the non-sleep stage and the class with the most training instances.
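The modality-occlusion experiment can be sketched as a loop over all subsets of modalities to occlude, computing one confusion matrix per subset. This minimal pure-Python version assumes that occlusion means replacing every channel of a modality with zeros, and `predict` is a hypothetical stand-in for the trained model that maps a dict of modality to channels onto a stage index. "Keeping only EEG" corresponds to the subset `("EOG", "EMG")`.

```python
from itertools import combinations
from typing import Callable, Dict, List, Sequence, Tuple

MODALITIES = ("EEG", "EOG", "EMG")
Epoch = Dict[str, List[List[float]]]  # modality -> list of channel signals

def occlude(epoch: Epoch, occluded: Sequence[str]) -> Epoch:
    """Replace all channels of the occluded modalities with zeros (assumed occlusion value)."""
    return {m: [[0.0] * len(ch) for ch in chans] if m in occluded else chans
            for m, chans in epoch.items()}

def modality_occlusion(epochs: Sequence[Epoch], labels: Sequence[int],
                       predict: Callable[[Epoch], int], n_classes: int = 5
                       ) -> Dict[Tuple[str, ...], List[List[int]]]:
    """Confusion matrix (rows = true, columns = predicted) for every occluded subset,
    from () (nothing occluded) up to all three modalities."""
    results = {}
    for r in range(len(MODALITIES) + 1):
        for occluded in combinations(MODALITIES, r):
            cm = [[0] * n_classes for _ in range(n_classes)]
            for x, y in zip(epochs, labels):
                cm[y][predict(occlude(x, occluded))] += 1
            results[occluded] = cm
    return results
```

Three modalities give 2^3 = 8 settings, including the fully occluded input that, in the experiment above, collapses to the majority stage W.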
In conclusion, the results show that (i) EMG and EEG are sufficient to correctly predict REM, (ii) EEG and EOG are sufficient to correctly classify N1 and N3, (iii) EEG alone is sufficient to classify W and N2, and (iv) EOG is necessary to reduce the misclassification of other stages as REM. These observations can be justified by the following facts in the AASM manual (cf. Table 1): (i) EMG amplitude is lowest in REM, which may make EMG important for identifying REM; (ii) W and N2 generally have very characteristic EEG signals, namely high-frequency activity in W and unique time-domain patterns in N2, which makes these two stages easily identifiable from EEG alone; and (iii) the misclassification of other stages as REM, most prominently N1 (whose EEG is quite similar to that of REM), is reduced by adding EOG, owing to the rapid eye movement patterns in REM.

Model interpretability results
In this section, we analyse important patterns in the time and frequency domain of EEG signals for prediction and investigate activation patterns of the temporal filters in the CNN.

Predictive patterns in EEG
To analyse which EEG patterns are most relevant for the prediction, we occluded the EEG signal in the time domain. Fig. 7 shows four correctly classified example epochs from the SHHS test set; lighter colours denote more important patterns. That is, when a yellow 1 s pattern is covered by the 5 s occlusion window, the prediction changes more often than for patterns in darker colours, and purple means that occlusion never changes the prediction. The annotations at the top of each figure are as follows: the bottom number shows the importance score PI_p of a pattern p, and the annotation above it shows the corresponding prediction on occlusion, ŷ^o_p. In the N1 example, multiple consecutive 1 s patterns containing vertex waves are the most important for the prediction. The N2 example highlights the k-complex in green, showing that it is among the important patterns but not the most important one. This suggests that the epoch contains other information that leads to predicting N2 even when the k-complex is occluded. For stage N3, high-amplitude patterns are important, while the prediction of REM is based on saw-tooth waves and high-amplitude patterns. These findings conform to the characteristic patterns described in the AASM guidelines (cf. Table 1).
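A per-pattern importance score of the kind visualised in Fig. 7 can be computed by sliding the 5 s occlusion window over the 30 s epoch and crediting every 1 s pattern under the window whenever the prediction changes. The sketch below is one plausible reading, not the paper's exact definition of PI_p: it assumes a 1 s stride, zero as the occlusion value, PI_p as the fraction of windows covering pattern p that change the prediction, and ŷ^o_p as the most frequent changed prediction. `predict` is a hypothetical model wrapper, and the 125 Hz default for `fs` is an assumed EEG sampling rate.

```python
from collections import Counter
from typing import Callable, List, Sequence, Tuple

def pattern_importance(signal: Sequence[float], predict: Callable[[List[float]], int],
                       fs: int = 125, epoch_s: int = 30, win_s: int = 5, step_s: int = 1
                       ) -> Tuple[List[float], List[int]]:
    # Assumptions: 1 s patterns, zero occlusion, PI_p = share of covering windows
    # that change the prediction; ŷ^o_p = most common changed prediction.
    base = predict(list(signal))
    covered = [0] * epoch_s              # windows covering each 1 s pattern
    changed = [0] * epoch_s              # ... of which changed the prediction
    new_preds = [Counter() for _ in range(epoch_s)]
    for start in range(0, epoch_s - win_s + 1, step_s):
        occluded = list(signal)
        a, b = start * fs, (start + win_s) * fs
        occluded[a:b] = [0.0] * (b - a)  # zero out the occlusion window
        pred = predict(occluded)
        for p in range(start, start + win_s):
            covered[p] += 1
            if pred != base:
                changed[p] += 1
                new_preds[p][pred] += 1
    pi = [c / n if n else 0.0 for c, n in zip(changed, covered)]
    y_occ = [cnt.most_common(1)[0][0] if cnt else base for cnt in new_preds]
    return pi, y_occ
```

Patterns near the epoch borders are covered by fewer windows than central ones, which is one reason to normalise by the number of covering windows rather than report raw counts.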
While the previous examples show pattern importance for individual epochs, we also investigated changes of prediction at an aggregated level. The Sankey diagram in Fig. 8 shows how many epochs per stage change their prediction on occlusion, based on 10 randomly selected PSGs. It shows the number of epochs in the ground truth s_T, in the prediction s_P and in the prediction on occlusion s_PO. For instance, 94,380 epochs are N2 in the ground truth, of which 77,142 are predicted as N2; on occlusion, 94,797 epochs are classified as N2.
Nearly all epochs predicted as N2 are still predicted as N2 on occlusion. That is, N2 epochs mostly do not contain a 5 s pattern that alone identifies N2. When N3 epochs are occluded, a considerable number are misclassified as N2, indicating that these epochs contain 5 s windows holding the information that distinguishes N3 from N2. Further, many true N2 epochs that are misclassified as N3 are correctly predicted as N2 on occlusion. This suggests that N2 and N3 share some patterns that cause these misclassifications; once these patterns are occluded, the epoch is correctly predicted.
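The flows in a Sankey diagram like Fig. 8 are counts of (ground truth, prediction) and (prediction, prediction on occlusion) pairs. A minimal sketch, assuming one occluded prediction per epoch is already available (how that prediction is chosen among the occlusion windows is an assumption not covered here):

```python
from collections import Counter
from typing import Dict, Sequence

def sankey_flows(y_true: Sequence[str], y_pred: Sequence[str],
                 y_occ: Sequence[str]) -> Dict[str, Counter]:
    """Node sizes and link weights for a three-column Sankey diagram."""
    return {
        "nodes_T": Counter(y_true),              # s_T: stage counts in the ground truth
        "nodes_P": Counter(y_pred),              # s_P: stage counts in the prediction
        "nodes_PO": Counter(y_occ),              # s_PO: stage counts on occlusion
        "T_to_P": Counter(zip(y_true, y_pred)),  # links from ground truth to prediction
        "P_to_PO": Counter(zip(y_pred, y_occ)),  # links from prediction to occluded prediction
    }
```

For example, `sankey_flows(...)["P_to_PO"][("N3", "N2")]` counts predicted-N3 epochs that become N2 on occlusion, the flow discussed above.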

Predictive frequency bands in EEG
To identify the EEG frequency bands most relevant for the prediction, we occluded the EEG signal in the frequency domain. Fig. 9 shows the confusion matrices obtained on occluding different frequency bands. With only δ frequencies present, most epochs are (correctly and incorrectly) predicted as N3, while without them N3 is almost never predicted. This shows that δ is the characteristic frequency band of N3, in agreement with the AASM manual (cf. Table 1). If only θ frequencies are present, mostly N1 and W are (correctly and incorrectly) predicted, while omitting θ decreases the TPR of N1 but increases the TPR of W. This shows that θ is a characteristic frequency band of N1, again consistent with the AASM guidelines. N2 and REM also reach a considerable TPR when only θ is present, but they retain a considerable TPR when θ is removed. This suggests that θ is one of several characteristic frequency bands of N2 and REM, also in accordance with the AASM guidelines (cf. Table 1). If only α frequencies are present, epochs are classified as W, whereas removing α increases the TPR of all other stages. This shows that α is a characteristic frequency band of W. A similar argument holds for β frequencies, suggesting that β is also characteristic of W, which aligns with the AASM guidelines (cf. Table 1). If only σ frequencies are kept, all stages except N2 are classified as W, and removing σ increases the TPR of all stages except W. The considerable number of true positives for N2 when only σ is present is attributable to sleep spindles, whose frequency lies in the σ band. This shows that σ is characteristic of W and also present in N2, again confirming the domain knowledge from the AASM guidelines (cf. Table 1).
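Frequency-domain occlusion can be sketched with an FFT mask that either keeps only one band or removes it. The band limits below are assumptions chosen to be consistent with the ranges quoted in this section (θ and α together 4-12 Hz, α and σ together 8-16 Hz); the limits and filter actually used for Fig. 9 may differ, and a brick-wall FFT mask is a simplification.

```python
import numpy as np

# Assumed EEG band limits in Hz (lower bound inclusive, upper bound exclusive).
BANDS = {
    "delta": (0.5, 4.0),
    "theta": (4.0, 8.0),
    "alpha": (8.0, 12.0),
    "sigma": (12.0, 16.0),
    "beta": (16.0, 30.0),
}

def band_occlude(x: np.ndarray, fs: float, band: str, keep_only: bool) -> np.ndarray:
    """keep_only=True keeps only `band` (e.g. the 'only δ present' case);
    keep_only=False removes `band` and keeps everything else."""
    spectrum = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(len(x), d=1.0 / fs)
    lo, hi = BANDS[band]
    in_band = (freqs >= lo) & (freqs < hi)
    mask = in_band if keep_only else ~in_band
    return np.fft.irfft(spectrum * mask, n=len(x))
```

Applied to a 30 s epoch containing a 2 Hz and a 10 Hz sinusoid, keeping only δ or removing α both return (numerically) the 2 Hz component alone.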

Filter importance and visualization
Fig. 10 shows the LRP [35] importance scores of the 8 EEG filters in the CNN's first temporal filtering layer. Most filters have relatively small absolute importance scores, i.e., single patterns are not very discriminative for a sleep stage. Single filters with high importance are found for W, N2 and N3, whereas few filters react significantly to N1 and REM, illustrating how hard these stages are to predict. To identify and visualise the frequency patterns learnt by the EEG filters, we selected the filters with high absolute importance scores and analysed the dominant frequency components of their activations. Fig. 11 shows the power spectral densities of the raw signal and of the filter activations for typical epochs: filters 1, 2 and 7 for W, filters 2, 3 and 4 for N2, and filters 2, 3 and 6 for REM. To visualise the frequency patterns learnt by all 8 EEG filters, Fig. 12 plots their responses to white noise in the 0.5-30 Hz range (the effective frequency range of the EEG inputs). Comparing the white-noise responses with the corresponding responses in Fig. 11, we see that each filter extracts the same frequency patterns regardless of the input, which indicates that every EEG filter has a fixed, distinct role in feature extraction. Moreover, since all frequencies in the white noise have the same amplitude, the contribution of a filter to a frequency band can be read off from the corresponding amplitude in Fig. 12.
In addition, the frequency patterns learnt by the EEG filters can be explained by comparing them with the AASM guidelines and the importance scores. Fig. 11a shows that filters 1 and 7 extract frequency components around 13 Hz, 0-2 Hz and 8-10 Hz from W epochs, while filter 2 extracts 0-6 Hz. According to the AASM guidelines, the EEG of W mainly contains α and σ frequencies (8-16 Hz). Consistently, filters 1 and 7 react positively and filter 2 reacts negatively when predicting W. Similar observations can be made for the non-wake stages. For N2 (cf. Fig. 11b), filters 2 and 3 react positively, as they mainly recognize frequencies between 0 and 14 Hz (θ frequencies, k-complexes and sleep spindles), while filter 4 reacts negatively, as it detects frequencies between 15 and 25 Hz. For REM (cf. Fig. 11c), filters 2 and 3 extract frequency components in 0-12 Hz and filter 6 mainly extracts 20-30 Hz. Filters 2 and 3 are highly important for REM, as its main components are θ and α frequencies (4-12 Hz). Additionally, comparing the importance scores of the same filter across stages quantifies the filter's contribution to predicting each stage. For example, filter 2 has learnt a 0-6 Hz pattern, which matches N2 better than REM, so filter 2 has a higher importance score for N2 than for REM.
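The white-noise probe used for Fig. 12 can be sketched as follows: build band-limited (0.5-30 Hz) white noise with a flat amplitude spectrum, pass it through a 1-D kernel, and inspect the power spectrum of the activation. Because the input amplitude is flat, the output spectrum approximates the squared magnitude response of the filter. The kernel below is a hypothetical 10 Hz wavelet standing in for a learned EEG filter, and the 125 Hz sampling rate is an assumption.

```python
import numpy as np

def white_noise_response(kernel: np.ndarray, fs: float = 125.0, duration_s: float = 300.0,
                         band=(0.5, 30.0), seed: int = 0):
    """Return (freqs, power) of a filter's activation to flat-spectrum noise in `band`.
    fs = 125 Hz is an assumed sampling rate."""
    n = int(fs * duration_s)
    freqs = np.fft.rfftfreq(n, d=1.0 / fs)
    rng = np.random.default_rng(seed)
    phases = rng.uniform(0.0, 2.0 * np.pi, size=freqs.size)
    in_band = (freqs >= band[0]) & (freqs <= band[1])
    # Unit amplitude and random phase inside the band, zero elsewhere.
    noise = np.fft.irfft(np.where(in_band, np.exp(1j * phases), 0.0), n=n)
    # CNN layers compute cross-correlation; this flips the kernel but keeps its magnitude response.
    activation = np.correlate(noise, kernel, mode="same")
    power = np.abs(np.fft.rfft(activation)) ** 2
    return freqs, power

def dominant_frequency(freqs: np.ndarray, power: np.ndarray) -> float:
    return float(freqs[np.argmax(power)])

# Usage with a hypothetical 10 Hz wavelet in place of a learned 63-tap EEG filter.
taps = np.arange(63) / 125.0
example_kernel = np.hanning(63) * np.cos(2 * np.pi * 10.0 * taps)
freqs, power = white_noise_response(example_kernel)
print(dominant_frequency(freqs, power))  # close to 10 Hz, i.e. in the α band
```

With a learned kernel, the pass-band read off this way is what allows statements such as "filter 2 extracts 0-6 Hz".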

Discussion
In this section, we discuss the contribution of the architectural components and the reasons for misclassifications, and put our work in the context of related work.

Model analysis
We developed a multi-modal sleep scoring model that learns from EEG, EOG and EMG. This was motivated by the fact that the AASM manual recommends using all three modalities for sleep scoring, as together they have unique characteristics for differentiating the stages. Sleep technologists also usually consider all three modalities for scoring. We show that automatic sleep scoring on raw multi-modal input signals can be performed with 85% accuracy.
The confusion matrix of the ST O model (cf. Fig. 5a) shows a higher TPR for N1 and N3 than that of ST-Q O (cf. Fig. 5b). Adding residual connections to combine spatio-temporal with sequential features (ST-Q-RC O model) did not increase the overall performance, but increased the TPR of N3 (cf. Fig. 5c). This indicates that N3 can be predicted quite well from spatio-temporal features alone, whereas sequential information seems to decrease performance on N3. This observation is in line with the AASM manual, which does not mention sequential rules for N3. The confusion matrix of ST-Q O (cf. Fig. 5b) shows that misclassifications generally occur between adjacent sleep stages (W-N1, N1-N2 and N2-N3). The confusion matrix of ST O (cf. Fig. 5a) also shows misclassifications based on feature similarities, e.g., REM-N1 and REM-N2. REM and N1 lie in similar frequency bands for EEG and EOG signals: for both, EEG activity lies in θ and α (cf. Table 1) and EOG activity in 0.1-0.4 Hz [45]. Similarly, REM and N2 both have EEG frequencies in θ, and the EMG amplitude in N2 can be as low as in REM (cf. Table 1). The AASM manual also states that an epoch is scored as REM when the majority of it has REM characteristics, even if it contains a k-complex suggestive of N2, which can make REM and N2 hard to distinguish.

Fig. 10. Filter importance of the 8 EEG filters at the first temporal filtering layer (layer 5 in Fig. 1) for SHHS.

Additionally tested experimental setup
We performed some additional experiments to decide on the design of our Bi-LSTM component, STQS architecture and training process.
Bi-LSTM Component: We compared initializing the hidden and cell states of a sequence with the states from the previous sequence against initializing them with zeros. The former performed better, so we only report those results. More specifically, the performance on REM and N2 increased, as for these stages more information from contiguous epochs is needed to decide the sleep stage of the current epoch [2], and some of these epochs may lie in neighbouring sequences. We also experimented with passing overlapping sequences to the Bi-LSTM, shifting the sequence window by one epoch for each new sequence instead of by the sequence length. For each sequence, only the hidden state of the middle 30 s epoch was used as output. The motivation for overlapping sequences was to provide more context for learning, i.e., an equal number of past and future epochs for the prediction of each epoch. This increased training time and gave results similar to those of non-overlapping sequences. We attribute the similarity mainly to the fact that, with non-overlapping sequences, the hidden and cell states of one sequence are passed on to the next, which has an effect similar to that of overlapping sequences.
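The two ways of cutting a night into Bi-LSTM input sequences can be sketched as index windows. This minimal illustration only builds epoch indices for a hypothetical sequence length. In the overlapping scheme, the epochs at the start and end of a recording have no full window and would need padding or separate handling, an assumption not covered here.

```python
from typing import List, Tuple

def non_overlapping(n_epochs: int, seq_len: int) -> List[List[int]]:
    """Consecutive chunks shifted by seq_len; every epoch in a chunk gets an output.
    The LSTM states at the end of one chunk can initialise the next chunk."""
    return [list(range(s, min(s + seq_len, n_epochs))) for s in range(0, n_epochs, seq_len)]

def overlapping_centred(n_epochs: int, seq_len: int) -> List[Tuple[List[int], int]]:
    """Windows shifted by one epoch; only the middle epoch of each window is predicted,
    so it sees seq_len // 2 past and seq_len // 2 future epochs.
    Border epochs without a full window are skipped (assumed to be handled separately)."""
    if seq_len % 2 == 0:
        raise ValueError("seq_len must be odd to have a middle epoch")
    half = seq_len // 2
    return [(list(range(c - half, c + half + 1)), c) for c in range(half, n_epochs - half)]
```

For an 8 h recording (960 epochs) and a hypothetical sequence length of 11, the overlapping scheme needs 950 windows of 11 epochs, against 88 chunks for the non-overlapping scheme, which is consistent with the longer training time reported above.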
STQS Architecture: We compared (i) a common Bi-LSTM applied after concatenating the CNN features of all modalities with (ii) separate Bi-LSTM layers after the CNN of each modality pipeline, whose feature maps are concatenated and passed on for prediction. Both versions reached a similar overall accuracy (85%). The separate Bi-LSTMs performed slightly better on N1 (44.1% vs. 41.3% F1 on SHHS), but required more training time due to the larger number of parameters. We therefore chose the model with a common Bi-LSTM layer.
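The parameter argument can be made concrete with the standard LSTM parameter count. The sketch assumes single-layer Bi-LSTMs with PyTorch's `nn.LSTM` parameterisation (input-hidden and hidden-hidden weights plus two bias vectors for the four gates), and hypothetical sizes of 128 CNN features per modality and 128 hidden units; the actual STQS sizes may differ.

```python
def lstm_params(input_size: int, hidden_size: int, bidirectional: bool = True) -> int:
    """Parameters of a single-layer nn.LSTM-style layer:
    W_ih (4h x in), W_hh (4h x h), b_ih (4h) and b_hh (4h) per direction."""
    per_direction = 4 * hidden_size * (input_size + hidden_size) + 2 * 4 * hidden_size
    return per_direction * (2 if bidirectional else 1)

d, h, n_modalities = 128, 128, 3             # hypothetical sizes
common = lstm_params(n_modalities * d, h)    # one Bi-LSTM on the concatenated features
separate = n_modalities * lstm_params(d, h)  # one Bi-LSTM per modality
print(common, separate)                      # 526336 792576
```

With these sizes the separate variant also produces 3 x 2h = 768 output features instead of 2h = 256, which further enlarges the input layer of the subsequent classifier.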
Training Process: We explored three two-step training processes: (i) training the ST architecture, then retraining the ST-Q architecture with the same parameters; (ii) using different learning rates for the two training steps [11]; and (iii) training the ST architecture, then freezing the ST weights, using ST to compute the feature maps and training only the Q part. Process (i) appeared to give the best performance, so we used it for training.

Effect of dataset and class imbalance techniques
Our model was trained and tested on two datasets, SHHS and MST. It performed better on SHHS (a: 0.85) than on MST (a: 0.77), which might be due to the larger amount of training data in SHHS. In addition, we think that, although all modalities are important for sleep scoring, the number of channels per modality may affect performance negatively. The MST dataset has 12 channels and SHHS has 5. Having to combine information from many channels in each modality pipeline may confuse the model (the channels of a modality may not support each other) rather than providing it with additional information. More evidence for this can be found in Section 9.5.

Fig. 11. Frequency spectrum of the raw data and corresponding EEG activation patterns for W, N2 and REM for SHHS.
We experimented with three class imbalance techniques and found class imbalance handling to be helpful in general; however, there seems to be an interaction with the dataset. Oversampling was best for SHHS, whereas a weighted cost function was best for MST. We think this may be related to the number of channels in the dataset. As the MST dataset contains many channels per modality, duplicating data may have added to the confusion and made it hard for the model to learn from this data. A weighted cost function, on the other hand, increases the contribution of rare classes to the loss, so the model updates its weights more strongly for these classes. We used two weighted cost functions and found contradictory results on the two datasets, so we cannot conclude which weighted cost function is best.

Post-hoc interpretability
Our post-hoc interpretability results show that the model's predictions conform to the AASM guidelines, as the model gives importance to the unique characteristics described for each stage.
However, we found that time-domain occlusion is not a reliable interpretability method, as for some epochs it did not find any important pattern. This suggests that these epochs contain no 5 s pattern that is solely responsible for the prediction. This can be explained by the fact that sleep stages can contain either multiple non-localized patterns throughout the epoch, like vertex or saw-tooth waves in N1 and REM, or a single localized pattern, like a k-complex in N2. In the former case, the model can either find all of the patterns to be important (Fig. 7 (N1)) or find no particular pattern to be important because the pattern occurs multiple times. Further, localized patterns are not always the only reason for the decision (Fig. 7 (N2)), again suggesting that no single pattern may be solely important.
We experimented with 5 s, 10 s and 15 s occlusion windows for time-domain occlusion. The results indicated a trade-off between finding the most important patterns and occluding enough information for the prediction to change. We chose 5 s, as the larger windows marked more patterns as important, defeating our objective of finding only the most important patterns. A limitation of our methods is that they do not quantify how much time-domain, frequency-domain or modality information the model uses for a prediction. Also, our modality occlusion does not measure the importance of the individual channels within each modality. We have developed methods for interpreting our ST models, but extensions to the sequential part of the architecture are left for future work.

Comparison with state-of-the-art
The performance of ST-Q O is on par with state-of-the-art multi-modal deep learning approaches to sleep scoring (cf. Table 2). Comparison with other work is not straightforward because of differences in datasets and their preprocessing. We therefore compare with works that are closely related to our model and use similar data. Our CNN architecture was motivated by [12]; therefore, we did not experiment with different filter sizes or other parameters of the CNN.
Compared with models evaluated on the same dataset (SHHS), we outperform the single-modality model of Biswal et al. [20], are on par with the single-modality model of Fernandez-Blanco et al. [15] (although their dataset contains more PSGs than ours; cf. Table 2), and perform 2% lower than the single-channel model of Sors et al. [18]. Due to differences in data preprocessing, dataset splits and class imbalance techniques, these results are not directly comparable. Our ST-Q model without any class imbalance technique has a 1.5% lower accuracy than ST-Q O on SHHS and can be used for comparison with state-of-the-art models that do not use any class imbalance technique, such as [18]. Still, the models are not directly comparable because of the other differences. To make our model directly comparable to [18] (cf. footnote 10), we trained and tested their model on our SHHS dataset split (81-9-10%, train-validation-test), including all PSGs and wake stages (unlike [18]), and without filtering the EEG signals (like [18]). To select the best model checkpoint, we used the same criterion as for our model (loss value) instead of accuracy. This resulted in 85.9 (a), 78.0 (F1_M) and 80.2 (κ) with the C4-A1 EEG channel as input (around 1% higher than ST-Q O) and 85.1 (a), 76.2 (F1_M) and 78.8 (κ) with the C3-A2 EEG channel (almost the same as ST-Q O). This shows that [18] performs slightly better than our model using only one EEG channel. Replacing their training method (model validation after training on a number of batches) with ours (model validation after each full pass through the training set) had no significant influence on the performance.
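A split at the PSG level, so that all epochs of one recording fall into the same partition, can be sketched as follows. Only the 81-9-10% fractions come from the text; the shuffling and rounding scheme, and the assumption that the split was made per PSG, are illustrative.

```python
import random
from typing import List, Sequence, Tuple

def split_psgs(psg_ids: Sequence[str], fractions=(0.81, 0.09, 0.10), seed: int = 0
               ) -> Tuple[List[str], List[str], List[str]]:
    """Assumed PSG-level split: shuffle recording ids once and cut them into
    train/validation/test lists; the test set takes the remainder."""
    ids = sorted(psg_ids)                 # deterministic starting order
    random.Random(seed).shuffle(ids)
    n_train = round(fractions[0] * len(ids))
    n_val = round(fractions[1] * len(ids))
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]
```

Splitting by recording rather than by epoch prevents epochs from the same night from appearing in both training and test data, which would otherwise inflate test scores.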
We hypothesize that the slight difference in performance is due to our multi-modal input versus their single-channel input. The performance with multi-modal, multi-channel input is almost equal to that with single-channel input, differing by only about 1%. This might suggest that one channel contains enough information to identify the stages. However, our multi-modal XAI experiments (cf. Section 7.2) show that the model has learnt stage-specific modality importance that conforms to the AASM guidelines. We hypothesize that the information condensed from all channels may have led to better predictions for some epochs (misclassified by a single-channel model) and, at the same time, to more confusion for other epochs (correctly classified by a single-channel model), resulting in a similar overall accuracy. To test this hypothesis, we counted the prediction mismatches between Sors et al. [18] (trained on our dataset split) and ST-Q O on our test set (580 PSGs). We found that 7.4% of the epochs were misclassified by ST-Q O but correctly classified by Sors' model, 7.0% were misclassified by Sors' model but correctly classified by ST-Q O, and for the remaining 85.6% of the epochs both models were either correct or incorrect together. This indicates that model ensembling and/or multi-modal models with attention mechanisms could improve prediction.
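The mismatch analysis reduces to three epoch fractions computed from two models' predictions and the ground truth. A minimal sketch with illustrative names:

```python
from typing import Dict, Sequence

def mismatch_breakdown(y_true: Sequence[int], pred_a: Sequence[int],
                       pred_b: Sequence[int]) -> Dict[str, float]:
    n = len(y_true)
    triples = list(zip(y_true, pred_a, pred_b))
    only_b_correct = sum(a != t and b == t for t, a, b in triples) / n
    only_a_correct = sum(a == t and b != t for t, a, b in triples) / n
    return {
        "a_wrong_b_right": only_b_correct,
        "b_wrong_a_right": only_a_correct,
        "same_correctness": 1.0 - only_a_correct - only_b_correct,  # both right or both wrong
    }
```

Since each model's accuracy is the shared "both right" fraction plus its own one-sided fraction, acc_A - acc_B = b_wrong_a_right - a_wrong_b_right. With A = ST-Q O and B = Sors et al., the fractions above give 7.0% - 7.4% = -0.4 percentage points, consistent with Sors' model being slightly more accurate.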

Comparison to human performance
We compared the agreement between our predictions and the ground truth with previously reported human inter-annotator agreement scores [47,46]. Since the agreement was reported for correctly classified instances, we compare these scores to the TPRs in the confusion matrix of ST-Q O (Fig. 5b) and report the comparison in Table 5. The results of Whitney et al. [46] are more comparable to our model than those of Rosenberg et al. [47], as the former report inter-annotator agreement on the SHHS dataset. Note, however, that the number of epochs differs among the three studies (cf. Table 5). The scores suggest that our model's overall agreement with the ground truth is comparable to the agreement between human scorers. Our model agrees better than the human scorers in both [47] and [46] for W and N3; for N2 and REM, it is comparable to [47] and better than [46]; and for N1, it is lower than [47] but better than [46].

The scores from [46] in Table 5 are averaged over the results of 3 pairs of annotators.

Footnote 10: We chose [18] over [20] because the code of [18] is available.

Conclusion and future work
Our model predicts sleep stages on SHHS with an overall accuracy of 85.5% and with F1 scores of 92.5%, 41.3%, 84.8%, 76.3% and 89.1% for W, N1, N2, N3 and REM, respectively. We created a multi-modal model that learns from EEG, EOG and EMG inputs. We evaluated various architecture choices and found that sequential learning (Bi-LSTM) improves predictive performance over spatio-temporal filtering (CNN) alone, while residual connections do not. Through various post-hoc interpretability techniques, we found that our model conforms to the AASM guidelines. Thus, our model could be used to support sleep technologists in annotating sleep stages and in explaining the reasons for the automated annotation.
Through our multi-modal versus single-channel experiments, we found that the single-channel model of Sors et al. [18] slightly outperforms our multi-modal model, while roughly 7% of the epochs are correctly classified only by Sors' model and another 7% only by ours. Moreover, through modality occlusion, we found that specific modalities are important for predicting specific stages. Therefore, we suggest that future work could investigate automatic channel selection for multi-modal sleep scoring models.
In a clinical setting, sleep scoring is used for diagnosing sleep disorders. Misclassifications that do not change this diagnosis should be considered less serious. Evaluating our model based on predictive performance for diagnosing disorders is left for future work.