Cross-modal attention network for retinal disease classification based on multi-modal images

Multi-modal eye disease screening improves diagnostic accuracy by providing lesion information from different sources. However, existing multi-modal automatic diagnosis methods tend to focus on the specificity of each modality and ignore the spatial correlation between images. This paper proposes a novel cross-modal retinal disease diagnosis network (CRD-Net) that mines the relevant features from images of different modalities to aid the diagnosis of multiple retinal diseases. Specifically, our model introduces a cross-modal attention (CMA) module to query and adaptively attend to the lesion-related features in the different modal images. In addition, we propose multiple loss functions to fuse features with modality correlation and train a multi-modal retinal image classification network to achieve a more accurate diagnosis. Experimental evaluation on three publicly available datasets shows that our CRD-Net outperforms existing single-modal and multi-modal methods.


Introduction
Retinal diseases often impair a patient's vision or even lead to blindness [1-3]. With the advancement of imaging technology, ophthalmologists often rely on medical images of different modalities for diagnosis, such as optical coherence tomography (OCT) [4] and color fundus photography (CFP) [5], which offer different perspectives of the retina. The large number of images provides more references for diagnosis, but reading them and determining the results also imposes a significant burden [6]. Therefore, to aid disease diagnosis, researchers have proposed various automatic algorithms for high-accuracy disease classification.
Depending on the input images, the algorithms can be classified as single-modal or multi-modal methods. The former take just one modality of retinal image as input [5,7-10]. For example, Shen et al. [7] proposed a structure-oriented Transformer framework to further construct the relationship between lesions and age-related macular degeneration (AMD) on OCT images. Li et al. [9] developed a clinically feasible deep learning system for predicting and stratifying the risk of glaucoma onset and progression from CFP and performed clinical validation in an external population cohort. Peng et al. [8] used Inception V3 [11] and ResNet-50 [12] to classify CFP images for early diagnosis of retinal diseases. However, ophthalmologists often determine the final diagnosis based on information from multiple modalities [13], because single-modal data provides only limited information, which easily results in biased assistance.
The views of lesions provided by different data help ophthalmologists build a comprehensive perception of the disease [13-16]. For example, OCT and CFP are two commonly applied modalities [17], as shown in Fig. 1. The CFP image provides a 2D projection of the retina with a wide field of view, and OCT captures a 3D cross-sectional view. Several automatic diagnosis algorithms using multi-modal images as input have emerged [14,16,18], based on either early input concatenation [18] or late feature fusion for classification [14,16]. Hua et al. [18] proposed a network to predict the severity of diabetic retinopathy (DR) based on the concatenation of CFP images and swept-source optical coherence tomography angiography images. Wang et al. [16] extracted features from CFP and OCT with two different branches and then classified using their late-fused features. Expanding on the work in [16], He et al. [14] applied attention mechanisms to the different branches of their model, which tends to emphasize modality-specific features. Thus, these algorithms tend to treat images of different modalities as independent images, although there is in fact a correspondence between images of the same patient in different modalities. When ophthalmologists diagnose from multi-modal images, they often repeatedly compare these images and observe their correlations with diseases before reaching a decision.
Therefore, we propose a new cross-modal retinal disease diagnosis network (CRD-Net), which mines the correlation between multi-modal images for retinal disease diagnosis. To this end, we propose a cross-modal attention (CMA) module to capture and emphasize the relationships between different modal images. The CMA module extracts relevant information while suppressing irrelevant features, thereby improving the overall discriminative power of the model. The contributions of the paper are as follows:
1. We propose a new cross-modal retinal disease diagnosis network (CRD-Net) for the diagnosis of multiple retinal diseases based on multi-modal images.
2. We propose a cross-modal attention (CMA) module for CRD-Net to capture relationships between different modal images.
3. We propose multiple loss functions based on different modal inputs to provide suitable constraints for the network.
4. Extensive experiments on three publicly available retinal disease datasets are conducted to demonstrate the effectiveness of our proposed CRD-Net.

Disease diagnosis based on multi-modal images
The success of deep learning based visual recognition in various applications has stimulated interest in solving retinal disease classification tasks from multi-modal images [14,15,19-23].
Yoo et al. [20] employed random forests and VGG networks to extract features from OCT and CFP images, and then experimented with feature concatenation to aid the multi-modal image diagnosis of AMD. Zou et al. [21] introduced a multi-modal evidence fusion pipeline for eye disease screening that provides a single-modal confidence measure and integrates multi-modal information from a multiple distributional fusion perspective. However, simple concatenation can result in the loss of relevant information between modalities. To reduce the loss of relevant features, researchers have explored different algorithms. Chen et al. [22] designed a vertical plane feature fusion method for multi-modal fusion to predict AMD using infrared reflectance and OCT images. Li et al. [15] proposed a multi-modal multi-instance learning framework that selectively fuses features of the CFP and OCT modalities. Song et al. [23] developed a multi-modal information bottleneck network (MMIB-Net) that leverages information bottleneck theory for feature representation in multiple modalities. Moreover, attention mechanisms are also used to retain modality-specific features for representation.

Attention mechanism
The concept of the attention mechanism is inspired by the way humans selectively focus on relevant information while ignoring irrelevant details. The basic idea is to introduce a learnable mechanism in the model to automatically select and weigh important features [24]. Most of the existing studies [14,25-31] fall into the category of self-attention [14]. Self-attention receives a feature map F and outputs a self-attention feature map F_SA of the same shape as F. To produce the self-attention feature map, the input feature map F is fed into three learnable 1 × 1 convolution operations to produce three tensors (i.e., queries Q, keys K, and values V). The attention weights are generated by calculating the dot products of Q with all K. The mathematical representation is as follows:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V \tag{1}$$

where T denotes the transpose of a matrix and d_k represents the feature dimension. Chen et al. [31] incorporated trained single-modal self-attention into the CNN layers of their multi-modal network for feature extraction and fine-tuning, but this approach requires separate pre-training of single-modal models before their integration into the multi-modal framework. He et al. [14] introduced the self-attention mechanism into multi-modal retinal aided diagnosis. They proposed a modality-specific attention network (MSAN) to obtain the specificity of CFP and OCT by applying different self-attention mechanisms to CFP and OCT, respectively. However, this method incurs a large computational overhead and mainly focuses on the specificity of each retinal modality itself. It lacks the ability to leverage relevant multi-modal information as ophthalmologists do. To address these limitations, our framework uses cross-modal attention to extract and fuse lesion-related features.
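The self-attention computation described above can be illustrated with a small, dependency-free Python sketch. It is only an illustration under stated assumptions: the feature map is flattened to a list of positions by channels, each 1 × 1 convolution is modelled as a per-position linear projection, and the identity matrices standing in for learned weights, as well as all function names, are hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix given as nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def attention(q, k, v):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(k[0])
    k_t = [list(col) for col in zip(*k)]
    scores = matmul(q, k_t)
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights

def self_attention(f, w_q, w_k, w_v):
    """Project the same feature map to Q, K, and V, then attend."""
    return attention(matmul(f, w_q), matmul(f, w_k), matmul(f, w_v))

# Toy feature map: 3 spatial positions, 2 channels (illustrative values).
F = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
I2 = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical stand-in for learned 1x1 conv weights
F_sa, W = self_attention(F, I2, I2, I2)
print(F_sa)
```

The output keeps the shape of the input (3 positions, 2 channels), and each row of the attention weights sums to one, so every output position is a convex combination of the value vectors.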

Framework structure
Clinical ophthalmologists typically consider multiple imaging modalities simultaneously and integrate relevant information between modalities to diagnose eye diseases. Our CRD-Net simulates this real-world diagnosis process, as shown in Fig. 2. First, because CFP and OCT differ noticeably in visual appearance, as shown in Fig. 1, ophthalmologists examine the CFP and OCT images separately to look for lesion information. We simulate this step with two CNNs that extract instance-level features for the two modalities and map them to specific feature representations F_x and F_y. Second, when ophthalmologists identify a lesion in one modality, they carry that information over to find the corresponding features in the other modality and integrate them. In CRD-Net, this process is carried out by the CMA module. Finally, the ophthalmologist makes a diagnosis based on the matching features of the lesions on CFP and OCT. Correspondingly, the combination of features from the different modalities is fed into the classifier to predict the class of retina-related disease in an end-to-end manner.
We provide a detailed introduction to each step of the proposed method in the following sections.

Cross modal attention
When clinicians read multi-modal medical images, they search for lesion information at the global scale of one image, and then carry the lesion information to the other modality to query the corresponding representation. In CRD-Net, the last layer of the backbone produces global-scale feature representations F_x and F_y. However, F_x and F_y tend to attend only loosely to important areas at the global scale. Attention mechanisms make the network focus more on disease-related features. Therefore, in the CMA module, we first attend to the lesion area of a specific modality through the attention mechanism, and then perform cross-modal fusion of the related features.
We use multi-head self-attention (MHSA) to focus on disease-related features, followed by a unified feedforward network (UFFN) and 1 × 1 convolutional layers for fine-tuning. The outputs are mapped to F′_x and F′_y. For efficiency at this stage, we construct MHSA and UFFN according to [32]. In this process, MHSA maps the backbone feature F (F_x or F_y) to an intermediate feature f′, and UFFN with the 1 × 1 convolution maps f′ to F′ (F′_x or F′_y). In the subsequent cross-modal step, the query, key, and value parameter matrices are obtained by convolution operations on F′_x and F′_y, with dimensions d_x and d_y, respectively. a_{x,y} and a_{y,x} represent the outputs of the CFP branch and the OCT branch, respectively.

Classifier and objective function
After the lesion correlation features are generated, we concatenate them to obtain the multi-modal features and train a fully connected classifier to make the final prediction. We also predict retinal diseases from the single OCT and CFP modalities, respectively. Because the features fed into the classifiers come from diverse sources, employing independent loss functions to constrain features from different sources can help the network learn modality-specific characteristics and enhance the synergy between modalities. To train the three classification branches, we design a multiple loss module in our CRD-Net. The use of multiple loss functions provides comprehensive constraints on the overall optimization process of the network, thereby reducing the bias induced by any single branch. We can represent the whole module as:

$$\Gamma = L_{OCT} + L_{Fundus} + L_{Both} \tag{6}$$

where L_OCT and L_Fundus represent the cross-entropy losses of the two single-modal branches, and L_Both represents the cross-entropy loss obtained by fusing both modalities. We add these three losses together and optimize using stochastic gradient descent (SGD). During training, the goal of the entire model is to minimize the primary loss function Γ, which is the sum of the three losses. The ablation experiments demonstrate the benefit of this constraint.
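As a rough illustration of the three-branch objective, the sketch below computes the cross-entropy of each branch for one sample and sums them. It assumes per-sample logits and an unweighted sum, as the text describes; the function names and the example logits are hypothetical, and the paper's actual PyTorch training code is not shown in the source.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log softmax(logits)[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def total_loss(logits_oct, logits_fundus, logits_both, target):
    """Sum of the OCT, fundus (CFP), and fused-branch losses."""
    l_oct = cross_entropy(logits_oct, target)
    l_fundus = cross_entropy(logits_fundus, target)
    l_both = cross_entropy(logits_both, target)
    return l_oct + l_fundus + l_both

# Illustrative logits for a 4-class problem; the true class is index 0.
loss = total_loss(
    logits_oct=[2.0, 0.1, 0.1, 0.1],
    logits_fundus=[0.5, 0.4, 0.3, 0.2],
    logits_both=[3.0, 0.0, 0.0, 0.0],
    target=0,
)
print(loss)
```

Because the three terms are simply added, gradients from the single-modal heads and from the fused head all reach the shared feature extractors with equal weight.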

Datasets
We evaluate the effectiveness of our method on three publicly available datasets (MMC-AMD [16,33], APTOS-2021 [34], and GAMMA [35]). Obtaining a single 2D OCT image is simpler than acquiring 3D OCT volume data. Therefore, the input to CRD-Net consists of one 2D CFP image paired with one 2D OCT image. The selection of the 2D OCT image depends on the dataset. For the MMC-AMD dataset, the 2D OCT images were selected by a professional ophthalmologist, focusing on the diseased areas of the patients; this process is described in Refs. [16,33]. For the APTOS-2021 and GAMMA datasets, we selected the center frame of the 3D OCT volume as input. The APTOS-2021 images were collected for macula-related diseases, and the macula is also the center of focus during capture, so the center of the macula is the main concern of ophthalmologists. The GAMMA dataset concerns glaucoma, which is mainly related to the optic disc, but ophthalmologists often check the condition of the macula as well. Therefore, following discussions with doctors, we selected the center frame of the 3D OCT volume as the input to the OCT branch of our framework.
The details are listed in Table 1 and illustrated as follows:

MMC-AMD:
The MMC-AMD dataset [16,33] comprises two modalities: CFP and OCT. To suit the task of aided diagnosis using multi-modal images of the same patient, we organized a total of 768 pairs of samples by associating CFP images with OCT images through matching patient identifiers. The dataset covers four categories: normal (195), dry AMD (57), polypoidal choroidal vasculopathy (PCV) (185), and wet AMD (331). We divide the dataset into training and test sets in an 8:2 ratio using a random partition. Specifically, 615 pairs of samples are allocated for training and 153 pairs for testing.

APTOS-2021:
The APTOS-2021 dataset [34] was presented by the Asia-Pacific Tele-Ophthalmology Society (APTOS) during the 2021 Asia-Pacific Tele-Ophthalmology Society big data competition. Its image modalities include CFP and OCT. To ensure that the primary research question is not affected by class imbalance, we select three diseases to validate the performance of our method, with a total of 1,497 pairs of samples. These diseases are wet AMD (596), PCV (406), and diabetic macular edema (DME) (495). Following the official data partition, 1,298 pairs of samples are allocated for training and 199 pairs for testing.

GAMMA:
The GAMMA dataset [35] originates from the 2021 MICCAI Contest: GAMMA Challenge Task 1: Multi-modal Glaucoma Grading. The dataset comprises 200 pairs of clinical multi-modal images, of which 100 pairs are for training and 100 pairs for testing. It includes the CFP and OCT modalities and is divided into three categories: Normal, Early Glaucoma, and Advanced Glaucoma. Because the labels of the challenge test set are unavailable, this study uses only the training set. We re-divide it into training and test sets in an 8:2 ratio using a random split. Specifically, 80 pairs of samples are allocated for training and 20 pairs are reserved for testing.
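The 8:2 random partition used for MMC-AMD and GAMMA can be sketched as follows. This is a minimal illustration, not the authors' script: the function name, the fixed seed, and the choice to round the training share up are assumptions (rounding up happens to reproduce the reported 615/153 and 80/20 splits, but the source does not state the rounding rule).

```python
import random

def split_pairs(pair_ids, train_parts=8, total_parts=10, seed=0):
    """Randomly split paired sample IDs into train and test sets at 8:2."""
    ids = list(pair_ids)
    random.Random(seed).shuffle(ids)  # hypothetical fixed seed for repeatability
    # Ceiling division in integer arithmetic (assumed rounding rule).
    n_train = (len(ids) * train_parts + total_parts - 1) // total_parts
    return ids[:n_train], ids[n_train:]

# MMC-AMD has 768 CFP/OCT pairs; the usable GAMMA set has 100 pairs.
mmc_train, mmc_test = split_pairs(range(768))
gamma_train, gamma_test = split_pairs(range(100))
print(len(mmc_train), len(mmc_test), len(gamma_train), len(gamma_test))
```

Splitting by pair ID keeps each CFP image together with its matched OCT image, so no pair is divided between the training and test sets.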

Implementation details
We implement the models with the PyTorch [36] framework, and all experiments are conducted on a GeForce RTX 2080 Ti GPU. The optimizer, weight decay, momentum, number of epochs, and batch size are SGD, 1e−4, 0.9, 150, and 8, respectively. The initial learning rate is 1e−4 and is multiplied by 0.1 at epochs 50 and 100. We adopt the same input resolution as [16] to capture more intricate image details. During training, we augment each CFP and OCT image using RandomHorizontalFlip with a probability of 0.5 and RandomRotation within a range of ±30°. We follow [14] for the standardisation and normalisation of the input images. Parameters for the comparative methods are configured based on the implementation details in the respective papers, and the best performance is reported. All experiments use weights pre-trained on ImageNet-1K.
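The step schedule described above can be written as a small function. This sketch assumes zero-indexed epochs and that the rate is multiplied by 0.1 at each milestone (behaviour comparable to a multi-step scheduler); the function name is hypothetical and the source does not show the actual scheduler code.

```python
def learning_rate(epoch, base_lr=1e-4, milestones=(50, 100), gamma=0.1):
    """Step schedule: multiply the rate by gamma at each milestone reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

# Learning rate at a few points of the 150-epoch run.
for epoch in (0, 49, 50, 99, 100, 149):
    print(epoch, learning_rate(epoch))
```

Under these assumptions the rate is 1e−4 for epochs 0 to 49, 1e−5 for epochs 50 to 99, and 1e−6 from epoch 100 to the end of training.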

Evaluation metrics
We comprehensively evaluate the performance of our method using six metrics: Recall, Precision (Pre), Specificity (Spe), F1 score, Kappa, and Accuracy (ACC). Sensitivity is equivalent to Recall. The Kappa coefficient and F1 score provide insights into the method's reliability. The notation for these evaluation metrics is as follows:

$$\mathrm{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} \tag{7}$$

where TP represents true positives, FP stands for false positives, FN denotes false negatives, TN represents true negatives, a_i indicates the actual number of samples in class i, b_i refers to the predicted number of samples in class i, C represents the number of classes, and N stands for the total number of samples. In particular, the multi-class precision, specificity, sensitivity, and F1 results in the following tables are the weighted averages across all classes.
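For readers who want to reproduce such metrics from a confusion matrix, the following sketch uses the standard textbook definitions: one-vs-rest TP, FP, FN, and TN per class, support-weighted averaging, and Cohen's kappa with expected agreement computed from a_i and b_i. These definitions and the function name are assumptions, since only the accuracy formula survives in the text above; multi-class accuracy is taken here as the fraction of correctly classified samples.

```python
def classification_metrics(cm):
    """cm[i][j] = number of samples of actual class i predicted as class j."""
    c = len(cm)
    n = sum(sum(row) for row in cm)
    actual = [sum(cm[i]) for i in range(c)]                           # a_i
    predicted = [sum(cm[i][j] for i in range(c)) for j in range(c)]   # b_i

    def ratio(num, den):
        return num / den if den else 0.0

    names = ("precision", "recall", "specificity", "f1")
    result = {name: 0.0 for name in names}
    for i in range(c):
        tp = cm[i][i]
        fn = actual[i] - tp
        fp = predicted[i] - tp
        tn = n - tp - fn - fp
        pre = ratio(tp, tp + fp)
        rec = ratio(tp, tp + fn)
        spe = ratio(tn, tn + fp)
        f1 = ratio(2 * pre * rec, pre + rec)
        weight = actual[i] / n  # weight each class by its number of samples
        for name, value in zip(names, (pre, rec, spe, f1)):
            result[name] += weight * value

    p_o = sum(cm[i][i] for i in range(c)) / n                        # observed agreement
    p_e = sum(a * b for a, b in zip(actual, predicted)) / (n * n)    # chance agreement
    result["accuracy"] = p_o
    result["kappa"] = ratio(p_o - p_e, 1 - p_e)
    return result

# Illustrative 2-class confusion matrix (not data from the paper).
metrics = classification_metrics([[3, 1], [1, 5]])
print(metrics)
```

With support weighting, the weighted recall always equals the overall accuracy, which is a quick sanity check when comparing reported numbers.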

Ablation study
This section presents our ablation experiments on the MMC-AMD dataset from four aspects: the composition of the CMA module, the position of the CMA module, the number of loss branches, and the different modules of the algorithm. We use the multi-modal CNN (MM-CNN) [16] as the baseline model and compare model performance when different modules are added at different positions. The "✓" in the following tables indicates that the network contains the corresponding module.

Ablation study of CMA module
The CMA module simulates, in two parts, how an ophthalmologist observes multi-modal images. To study the contributions of the different parts of the CMA module, we packaged the part before the UFFN as the specific modal attention (SMA) module, and the part after the UFFN as the cross-modal fusion (CMF) module. Thus, the main components of CMA are SMA and our proposed CMF. The ablation results for these two elements are analyzed in this section. To verify the effectiveness of SMA and CMF in our CMA, we removed SMA and CMF from CRD-Net in turn and conducted experiments, as shown in Table 2. The first row shows the experimental results of the baseline model.
Our SMA improves the performance of the network, as shown in the second row of Table 2, which means that SMA helps the network focus on lesion features. Our CMF provides the highest precision, as shown in the third row of Table 2, which illustrates the effectiveness of our proposed fusion strategy for cross-modal querying of relevant lesion features. The last row shows that our model achieves the best performance on the five metrics with both SMA and CMF. Thus, our CMA module helps to mine the relevant features of multi-modal images, which improves the accuracy and stability of aided diagnosis.

Ablation study of CMA module locations
The attention mechanism has limited inductive bias, so its performance depends on whether the input features have been fully extracted. Therefore, we analyze the impact of the location of the CMA module on network performance. We conducted two sets of experiments, introducing the CMA module into the third and fourth layers of the backbone, respectively.
The experimental results in Table 3 show the effect of the CMA module position on performance. When the CMA module is integrated into the fourth layer (CMA-Layer4), it achieves the best results on the six evaluation metrics. In the same setup, the performance decreases when the CMA module is integrated into the third layer (CMA-Layer3). Because of the limited inductive bias of the attention mechanism, the features at the third layer are not yet fully extracted, which biases the capture of target features. Therefore, we place the CMA module at the end (Layer4) of the backbone.

Ablation study of multiple losses
The objective function plays a key role in the optimization of the algorithm. Our objective function (Eq. (6)) combines losses from three branches. To explore the best combination of branch objective functions, we perform ablation experiments with different combinations of L_Fundus, L_OCT, and L_Both on the overall structure of the model. We run a total of four sets of experiments covering single-branch, double-branch, and triple-branch losses.
The results of combining different branch objective functions are shown in Table 4. The first row shows the baseline results. We find that using a combination of L_Fundus, L_OCT, and L_Both effectively improves classification performance. When only L_Both is used, all evaluation metrics exceed the baseline. However, adding either L_Fundus or L_OCT to L_Both leads to a decline in network performance compared to L_Both alone, as shown in the third and fourth rows of Table 4. Considering the model structure, adding only L_Fundus or L_OCT to L_Both implies that the network pays more attention to one of OCT or CFP, leading to a bias in network learning. We believe that adding only one loss branch to L_Both may distort the learning. When L_Fundus, L_OCT, and L_Both are combined, the network obtains the best results on the five metrics. This indicates that the combination of the three loss functions provides the most suitable constraints for the learning process of the network.

Ablation study of modules
Finally, we analyse the effectiveness of the proposed modules in our network. We conduct ablation experiments by adding CMA and the multiple loss functions to the baseline model, with the results shown in Table 5. When CMA or the multiple losses are added to the baseline model independently, the model performance declines. However, when the two modules act together, the model performance reaches its optimum. This suggests that the CMA module requires appropriate constraints to capture cross-modal lesion-related features. When the objective function works together with the CMA module, all six metrics achieve their best scores, indicating the overall effectiveness of the model structure.
The experimental results on the three datasets are shown in Tables 6, 7, and 8, respectively. The advantages of different methods vary across datasets, and our method achieves the best results on all three. Table 6 shows the comparison results on the MMC-AMD dataset, where the same method generally yields better classification results with OCT as input than with CFP. For AMD grading, OCT images are the gold standard for diagnosis. Our CRD-Net captures the relevant features of lesions on multi-modal images and achieves the best classification results: an ACC of 85.62%, precision of 81.77%, sensitivity of 87.75%, specificity of 95.11%, F1 score of 84.66%, and kappa of 79.57%. It outperforms OCT-VIT, whose ACC is approximately 83.01%. Our model uses multi-modal inputs, which allows more accurate diagnoses by focusing on the correlation of diseases across modalities. The same situation occurs on the APTOS-2021 dataset, as shown in Table 7. The situation changes on the GAMMA dataset, where using CFP as input generally gives better classification results than using OCT, because changes in the optic disc in fundus images are more likely to predict the onset of glaucoma than OCT images. Even in this case, our method still achieves the best performance, which demonstrates its robustness. Comparing the performance of different methods on the GAMMA and MMC-AMD datasets shows that different modal images have their own advantages for diagnosing ophthalmic diseases, which further supports the necessity of exploiting the correlation of multi-modal images to aid disease diagnosis.
To further analyze the classification results, we also present the detailed confusion matrices and the class activation mapping (CAM) visualisation, obtained with the Grad-CAM method [43], of our algorithm and the comparison algorithms. As shown in Figs. 3, 4, and 5, the overall number of images correctly classified by CRD-Net is higher than that of the other methods. Our classification of each class is superior to that of most other methods, without bias. The figures also show that our model achieves more comprehensive and robust classification for all classes, because it exploits the correlation between multi-modal images. In the context of medical diagnosis, the interpretability of model decisions is crucial. We therefore use the Grad-CAM results to further understand the mechanism of the network. The reddish areas contribute the most to the classification, followed by the yellowish pixels, while the bluish areas contribute the least. As shown in Fig. 6, the model exhibits strong attention (red) towards the diseased areas. Additionally, the model attends to the corresponding regions of the disease in both modalities. This shows that the multi-modal images input to the model contribute to network inference and that multi-modal clinical knowledge aids the network's decision-making process. Therefore, our algorithm achieves promising performance.
The results show that the proposed method has the smallest sample error and achieves the best classification performance compared to other methods.

Discussion
There are inherent advantages to designing an aided diagnostic algorithm around a single modality. It is often easier to obtain single-modal images of a patient than paired multi-modal images.
Although aided diagnosis methods based on single-modal medical images have achieved good performance, in clinical practice doctors can be more confident in their diagnosis by referring to different modalities of the same patient (e.g., CFP and ophthalmic OCT). Therefore, the design of aided diagnostic algorithms based on multi-modal images has attracted much attention. The same lesion often appears in different modal medical images with different expressions, and designing an algorithm that focuses on the information related to retinal diseases across modalities has become a new challenge. This work uses multi-modal images of patients, designs a cross-modal retinal disease diagnosis network that focuses on relevant information between different modalities, and uses multiple losses to help the model learn better features. However, although our method achieves good classification results on multiple clinical datasets, it still has limitations. Doctors often consider whether a patient suffers from multiple diseases simultaneously to avoid missed diagnoses. Our current algorithm cannot take this factor into account, because we could not obtain such data for validation, which may limit its application scenarios. Because our research uses publicly available datasets, patient demographic information was not disclosed; thus, we did not discuss patient demographics in the paper. We will continue to follow these datasets, and if relevant information is disclosed in the future, we can conduct experimental discussions on it.

Conclusion
We introduced CRD-Net, an algorithm that integrates the CMA module with multiple loss functions to achieve comprehensive optimization. Extensive experiments demonstrate the effectiveness of the proposed module and loss functions. Comparison experiments confirm that the relevant features from different modal images help to aid disease diagnosis.

Fig. 1. Examples of CFP and OCT images. CFP provides a two-dimensional projection view of the retina, and OCT provides detailed images of changes in different layers of the retina.

Fig. 2. CRD-Net architecture for multi-modal retinal image classification. Feature extraction is performed by two independent CNNs. Cross-modal interactive feature fusion is then performed by the CMA module. Finally, the interactive features are concatenated and a fully connected layer is used to predict the categories. We use MHSA to denote multi-head self-attention and UFFN to denote the unified feedforward network. F denotes F_x and F_y, f′ represents the intermediate feature between MHSA and UFFN, and F′ represents F′_x and F′_y.

After obtaining the modality-specific features F′_x and F′_y, we propose the modal interaction part of cross-modal attention, based on the way ophthalmologists observe multi-modal images, as described above. Specifically, in the CFP branch, F′_x is mapped to the key and value, and the feature F′_y from the OCT branch is mapped to the query. Since F′_x and F′_y contain modality-specific lesion information, by calculating the attention of F′_x on F′_y, we obtain the lesion features associated with OCT on the CFP branch. In the OCT branch, F′_y is mapped to the key and value, and the feature F′_x from the CFP branch is mapped to the query. Since F′_y and F′_x contain modality-specific lesion information, by calculating the attention of F′_y on F′_x, we obtain the lesion features associated with CFP on the OCT branch. In addition, we use residual connections in the module to ensure the efficiency of network propagation. The block outputs one cross-modal feature per branch: a_{x,y} for CFP and a_{y,x} for OCT.
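The cross-modal interaction described above can be sketched in dependency-free Python. This is an illustration under explicit assumptions rather than the authors' implementation: the attention follows the scaled dot-product form of Eq. (1), both feature maps are flattened to the same positions-by-channels shape, the residual connection is added to the branch's own feature (the source does not say where the residual is placed), and the identity projections and all names are hypothetical.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix given as nested lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def cross_attend(own, other, w_q, w_k, w_v):
    """Key/value from this branch, query from the other modality, plus residual."""
    q = matmul(other, w_q)
    k = matmul(own, w_k)
    v = matmul(own, w_v)
    d = len(k[0])
    scores = matmul(q, [list(col) for col in zip(*k)])  # Q K^T
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    attended = matmul(weights, v)
    # Assumed residual placement: add the branch's own feature back.
    return [[a + o for a, o in zip(a_row, o_row)]
            for a_row, o_row in zip(attended, own)]

# Toy modality-specific features: 2 positions x 2 channels (illustrative values).
F_x = [[1.0, 0.0], [0.0, 2.0]]  # CFP branch
F_y = [[0.5, 0.5], [1.0, 0.0]]  # OCT branch
I2 = [[1.0, 0.0], [0.0, 1.0]]   # hypothetical stand-in for learned projections

a_xy = cross_attend(F_x, F_y, I2, I2, I2)  # CFP branch output
a_yx = cross_attend(F_y, F_x, I2, I2, I2)  # OCT branch output
fused = [rx + ry for rx, ry in zip(a_xy, a_yx)]  # channel-wise concatenation
print(fused)
```

The concatenated rows in `fused` correspond to the multi-modal feature that would be passed to the fully connected classifier.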

Fig. 3. The confusion matrices of results on the MMC-AMD dataset. The four categories normal, dry_AMD, PCV, and wet_AMD are represented by I, II, III, and IV, respectively.

Fig. 4. The confusion matrices of results on the APTOS-2021 dataset. Wet_AMD, PCV, and DME are represented by I, II, and III, respectively. The results show that the proposed method has the smallest sample error and the best classification performance compared to other methods.

Fig. 5. The confusion matrices of results on the GAMMA dataset. Normal, Early Glaucoma, and Advanced Glaucoma are indicated by I, II, and III, respectively.

Fig. 6. Class activation map (CAM) visualization of the proposed CRD-Net method. The red area contributes the most to the classification, followed by the yellow pixels, while the blue area contributes the least to the classification. The results show that CRD-Net focuses well on the corresponding lesion locations of CFP and OCT.