Student becomes teacher: training faster deep learning lightweight networks for automated identification of optical coherence tomography B-scans of interest using a student-teacher framework

This work explores a student-teacher framework that leverages unlabeled images to train lightweight deep learning models with fewer parameters to perform fast automated detection of optical coherence tomography B-scans of interest. Twenty-seven lightweight models (LWMs) from four families of models were trained on B-scans (∼70 K) labeled by experts as either “abnormal” or “normal”, which established a baseline performance for the models. Then the LWMs were trained from random initialization using a student-teacher framework to incorporate a large number of unlabeled B-scans (∼500 K). A pre-trained ResNet50 model served as the teacher network. The ResNet50 teacher model achieved 96.0% validation accuracy and the validation accuracy achieved by the LWMs ranged from 89.6% to 95.1%. The best performing LWMs were 2.53 to 4.13 times faster than ResNet50 (0.109s to 0.178s vs. 0.452s). All LWMs benefitted from increasing the training set by including unlabeled B-scans in the student-teacher framework, with several models achieving validation accuracy of 96.0% or higher. The three best-performing models achieved sensitivity and specificity comparable to the teacher network in two hold-out test sets. We demonstrated the effectiveness of a student-teacher framework for training fast LWMs for automated B-scan of interest detection leveraging unlabeled, routinely-available data.

Despite these challenges, researchers are continuing with their efforts in applying artificial intelligence approaches in automated disease detection. Deep learning is particularly effective for this task, as these models can automatically identify features associated with a given diagnosis without requiring expert interpretation. Indeed, deep learning models have already been developed to identify pathologies such as age-related macular degeneration [6], macular edema [7], glaucomatous optic neuropathy [8], and diabetic retinopathy [9]. Recent studies have demonstrated the potential feasibility of automated multiple disease classification of OCT images using deep learning, as well as the challenges associated with this task [10][11][12]. Additionally, a deep learning system has been developed to identify OCT B-scans of interest, using a binary classification approach where the abnormal class was composed of multiple retinal pathologies [13,14].
Transfer learning, in which an existing classification model pre-trained on a very large, generic image dataset is fine-tuned for a specific medical imaging task, is the usual approach for developing these models. Recently, however, the utility of transfer learning for medical imaging has been called into question given the differing image statistics of medical images and the relatively few number of classes compared to typical object recognition tasks [15]. The state of the art deep learning models developed for object recognition typically have large numbers of parameters and require significant training time [12,16], and may be overparameterized for the task of identifying one or more pathologies on retinal OCT images [15]. Many clinical applications demand a model that performs quickly and is compatible with portable devices, as is the case for the ever-growing field of mobile health (mHealth) [17].
Lightweight models (LWMs), those with fewer parameters and operations, can provide potential advantages over large deep learning models and transfer learning approaches [15,18,19]. While many lightweight models have been shown to perform comparably to much larger models [15,19,20,21], training these models may require specialized approaches to achieve comparable accuracy to large, state-of-the-art models. Knowledge distillation, also known as a "student-teacher framework" [22], enables improved training of small (student) models by distilling knowledge from a larger (teacher) network. Various techniques are employed to achieve knowledge distillation. One method achieves distillation through semi-supervised learning, where a larger network generates labels for unlabeled data that are then used to train a smaller network [22,23]. Another method is soft-target training [24], where the teacher generates soft target labels for the training data and these soft targets are used in combination with one-hot encoded labels (hard targets). Combining the concepts of semi-supervised learning and soft-target learning, unlabeled images can be used to distill knowledge from the teacher network to the student networks through soft-target learning. In the case of multiclass classification, multiple specialist networks can be ensembled and their knowledge distilled to a student network [24].
In this study, we explore a student-teacher framework to perform semi-supervised learning with the goal of training LWMs to perform fast automated detection of OCT B-scans of interest. We take advantage of widely available unlabeled images, with the goal of developing a model that can provide fast and accurate results, suitable for a clinical setting.

Optical coherence tomography dataset
A total of 598 OCT macular volume cube scans (76,396 B-scans) were retrospectively obtained from nine clinical sites in the United States, Germany, Portugal and Singapore. All macular volume cube scans (512 × 128) were acquired using the CIRRUS 4000 and 5000 (ZEISS, Dublin, CA). Five hundred ninety-eight unique patients were included with a variety of retinal conditions. Two retina specialists first labeled each B-scan for image quality as "gradable" or "ungradable", resulting in the exclusion of 148 B-scans. Then the specialists annotated the remaining 76,396 B-scan images based on 8 different pathologies: intraretinal fluid, subretinal fluid, disruption of inner retinal layers, disruption of the vitreoretinal interface (VRI), retinal pigment epithelium (RPE) atrophy, RPE elevation, inner segment (IS)/outer segment (OS) disruption, or other retinal pathologies. A B-scan was labeled abnormal if at least one specialist annotated at least one of the 8 pathological categories for the image. The images were then split at the subject level into training (80%) and validation (20%) data sets. This resulted in a training set with 36,223 normal OCT images and 24,835 images of interest from 478 patients, and a validation set of 9,475 normal and 5,863 images of interest from 120 patients. In addition to the labeled images, a set of 478,588 unlabeled images from 3,148 patients collected using the CIRRUS 4000 and 5000 devices from multiple sites was available. See Table 1 for a summary of the number of patients and images. Two hold-out test sets were set aside for the final performance testing of the top-performing lightweight models; both the sets were collected from clinical sites that were not used for gathering training or validation data. Hold-out test set-1 contained 200 OCT macular volume cubes (25,600 B-scan images) retrospectively obtained from three clinical sites in the United States and Austria. 
All the OCT cubes in this hold-out data were acquired using the CIRRUS 4000 and 5000 (ZEISS, Dublin, CA) from three sites. Two retina specialists labeled each B-scan for the same 8 different retinal pathologies and a B-scan was labeled abnormal if at least one specialist annotated at least one of the pathologies in that image. The analysis of the annotations showed excellent agreement between the two specialists (Cohen's kappa = 0.8922). Hold-out test set-2 contained 225 OCT macular volume cubes (28,800 B-scan images) retrospectively obtained from six clinical sites in the United States and Austria. All the OCT cubes in this hold-out data were acquired using the CIRRUS 6000 (ZEISS, Dublin, CA) from seven sites. Three optometrists labeled each B-scan for the same 8 different retinal pathologies and a B-scan was labeled abnormal if at least 2 optometrists labeled at least one of the pathologies in that image. The analysis of the annotations showed moderate to substantial agreement between the three optometrists (Cohen's kappa = 0.4859 between optometrists 1 and 2, 0.7992 between optometrists 2 and 3, and 0.5406 between optometrists 1 and 3).
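The labeling rule and the agreement statistic described above can be illustrated with a minimal sketch. The function names and the data layout (one list of 8 booleans per grader) are assumptions made for illustration, and the sketch reads "at least 2 optometrists" as two graders each flagging any pathology, not necessarily the same one.

```python
from collections import Counter


def is_abnormal(grader_flags, min_graders=1):
    """Apply the B-scan labeling rule.

    grader_flags: one list per grader, each holding 8 booleans
    (one per pathology category). Hypothetical layout.
    min_graders: 1 for the two-specialist sets, 2 for hold-out test set-2.
    """
    graders_flagging = sum(1 for flags in grader_flags if any(flags))
    return graders_flagging >= min_graders


def cohens_kappa(labels_a, labels_b):
    """Cohen's kappa for two raters over the same items.

    Undefined (division by zero) when chance agreement equals 1,
    e.g. when both raters assign a single identical label to every item.
    """
    n = len(labels_a)
    observed = sum(a == b for a, b in zip(labels_a, labels_b)) / n
    counts_a, counts_b = Counter(labels_a), Counter(labels_b)
    categories = set(counts_a) | set(counts_b)
    expected = sum(counts_a[c] * counts_b[c] for c in categories) / (n * n)
    return (observed - expected) / (1 - expected)
```

With two graders, a scan flagged by only one of them is abnormal under `min_graders=1` but not under `min_graders=2`.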

Teacher network training
The deep learning architecture for the automatic detection of B-scans of interest was developed using a ResNet-50 network. A 3-channel ResNet-50 [25] neural network architecture was modified by adding inverted dropout followed by softmax activation. The training set was resized to 224 × 224 and augmented using rotation, horizontal flip, and vertical shift to form a set of B-scans for training. The modified 3-channel ResNet-50 was pre-trained on ImageNet images [26] and then fine-tuned on the resized B-scans with all layers unfrozen. Binary cross entropy was used as the loss function, and stochastic gradient descent with momentum 0.9 and initial learning rates ranging from 1e-2 to 1e-6 was used as the optimizer. The model was trained for 31 epochs with a batch size of 64.

Student network training
We selected 4 general families of lightweight deep learning architectures for our student networks: SqueezeNet [27], SqueezeNext [28], MobileNet [29][30][31], and ShuffleNet [32,33]. SqueezeNet was shown to obtain AlexNet-level accuracy on ImageNet with 1/50 the parameters, achieved by replacing 3 × 3 convolutions with 1 × 1 convolutions, decreasing the channels input to the 3 × 3 convolutions in "squeeze layers", and downsampling late in the network so that convolution layers have large activation maps. SqueezeNet can be implemented with residual connections; this modification yields the SqueezeResNet. SqueezeNext achieved AlexNet's accuracy with only 0.5 million model parameters (approximately 1/112 the parameters of AlexNet). SqueezeNext uses a SqueezeNet architecture as a baseline with the following changes: significantly reducing the total number of parameters used with the 3×3 convolutions, using separable 3 × 3 convolutions to further reduce the model size and removing the additional 1×1 branch after the squeeze module, and using an element-wise addition skip connection similar to that of ResNet architecture. The MobileNet model was shown to perform comparably to larger models, GoogleNet and VGG16, and outperform SqueezeNet in ImageNet classification. It utilizes depth-wise separable convolutions by factorizing a standard convolution into a depthwise convolution and a 1×1 "pointwise" convolution. The original MobileNet model can include faster down-sampling yielding the Fast MobileNet model. ShuffleNet utilizes two new operations, pointwise group convolution and channel shuffle, to greatly reduce computation cost while maintaining accuracy on ImageNet. ShuffleNet was shown to be approximately 13 times faster than AlexNet while maintaining similar accuracy, and performs comparably to MobileNet with a moderate speedup. A total of 27 LWMs from these four general architecture families were selected as the student networks. 
Details of the individual student models, including hyperparameters, can be found in Supplementary Table 1.
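To make the parameter savings of depthwise separable convolutions concrete, the following sketch counts weights for a single layer. The layer sizes are illustrative values chosen here, not taken from any of the student models, and biases are ignored.

```python
def standard_conv_params(k, c_in, c_out):
    """Weights in a standard k x k convolution (biases ignored)."""
    return k * k * c_in * c_out


def depthwise_separable_params(k, c_in, c_out):
    """Weights in a depthwise k x k convolution followed by a 1 x 1 pointwise one."""
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1 x 1 convolution that mixes channels
    return depthwise + pointwise


# Illustrative layer: 3 x 3 kernel, 128 input channels, 256 output channels.
standard = standard_conv_params(3, 128, 256)         # 294912
separable = depthwise_separable_params(3, 128, 256)  # 33920
reduction = standard / separable                     # roughly 8.7x fewer weights
```

The same counting shows why replacing a 3 × 3 convolution with a 1 × 1 convolution, as in SqueezeNet, cuts that layer's weights by a factor of 9.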
For all models, binary cross entropy was used as the loss function. Adam [34] was used as the optimizer, the batch size was set to 64, and models were trained for 20 epochs. A grid search was performed over initial learning rates ranging from 0.1 to 1e-6. The maximum validation accuracy was calculated over epochs and learning rates for each model. The average inference time across all the full 128 B-scan volumes (n=208) was computed on an NVIDIA Tesla P100 graphics processing unit.

Student-teacher framework - semi-supervised learning
The pretrained ResNet50 model (described in Section 2.2) was used as a teacher network for the student (lightweight) networks. Figure 1 describes the student-teacher framework. Inference with the teacher network was performed on the 478,588 unlabeled images, and each image was assigned the class with the highest softmax output from the ResNet50 model, generating hard targets for the unlabeled images. These images and the inferred labels were then pooled with the expert-labeled images, and this pooled dataset was used to train the student models in the student-teacher framework. Additionally, we explored the effect of the number of unlabeled images available for student-teacher training on the validation accuracy by subsampling the unlabeled image set at 25%, 50%, 75%, and 100%, where 0% is the baseline performance with only labeled images. The same training hyperparameters and grid search for optimal learning rate were used as in Section 2.3. The maximum validation accuracy was calculated over epochs and learning rates.
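The pseudo-labeling and pooling steps described above can be sketched as follows. This is a minimal illustration, assuming the teacher's softmax outputs are already available as lists of class probabilities; the function names, the random subsampling strategy, and the seed are assumptions, not details given in the text.

```python
import random


def hard_pseudo_labels(teacher_probs):
    """Assign each image the class with the highest teacher softmax output."""
    return [max(range(len(p)), key=p.__getitem__) for p in teacher_probs]


def build_pooled_set(labeled, unlabeled_images, teacher_probs, fraction=1.0, seed=0):
    """Pool expert-labeled pairs with a subsample of teacher-labeled pairs.

    labeled: list of (image, expert_label) pairs.
    fraction: share of the unlabeled set to include (0.0 is the labeled-only baseline).
    """
    pseudo = list(zip(unlabeled_images, hard_pseudo_labels(teacher_probs)))
    rng = random.Random(seed)  # fixed seed so the subsample is reproducible
    n_keep = round(fraction * len(pseudo))
    return list(labeled) + rng.sample(pseudo, n_keep)
```

Calling `build_pooled_set` with `fraction` set to 0.25, 0.5, 0.75, and 1.0 would reproduce the structure of the subsampling experiment, with 0.0 giving the labeled-only baseline.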

Alternative student-teacher framework - soft-target training
Fig. 1. A. The lightweight models were first trained on only the labeled images used to train the teacher network (ResNet50). The lightweight models were then trained with labeled and unlabeled images in the student-teacher framework and the three best performing models were selected. After model architecture search, six models were evaluated on the two hold-out test sets. B. Flow diagram of student-teacher framework. Labels for the unlabeled images are inferred (y_pseudo_unlabeled) by the teacher network (top network). The unlabeled images are combined with the expert-labeled images to train the lightweight student networks (bottom network). The inferred labels for the unlabeled images (y_pseudo_unlabeled) and the human-graded labels (y_true_labeled) are used in a binary cross entropy loss to train the lightweight networks (bottom network). The yellow arrows denote training with labeled images and the purple arrows denote the training for unlabeled images.
In this framework, the pretrained ResNet50 model was also used as a teacher network for the student (lightweight) networks. Rather than learning with hard targets alone, however, a combined loss function was used for soft and hard targets [24]. The soft targets were generated by performing inference on all the images (labeled or unlabeled) with a temperature-raised softmax [24]. The temperature (T) and alpha are hyperparameters that must be tuned. We selected the model that benefited most from the student-teacher learning described in Section 2.4 and performed the soft-target training for a number of temperature and alpha values. After selecting the temperature and alpha that yielded the highest validation accuracy, we performed the soft-target learning with the labeled and 100% of the unlabeled images. The maximum validation accuracy over epochs was reported. We refer to this framework as the soft-target student-teacher framework.
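The soft-target training in this section relies on a temperature-raised softmax and a loss that combines hard and soft targets. The sketch below shows one common way to write this, following the general idea in [24]. The exact weighting used in this study is not stated in the text, so the form `hard + alpha * T^2 * soft` is an assumption, as are the function names.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits divided by T; larger T gives softer probabilities."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target_probs, predicted_probs, eps=1e-12):
    """Cross entropy between a target distribution and a predicted one."""
    return -sum(t * math.log(p + eps) for t, p in zip(target_probs, predicted_probs))


def distillation_loss(student_logits, teacher_logits, hard_target, temperature, alpha):
    """Combined loss on hard targets (T = 1) and teacher soft targets (T > 1).

    Assumed weighting: hard_loss + alpha * T**2 * soft_loss. The T**2 factor
    keeps the soft-target gradients on a comparable scale as T changes.
    """
    hard_loss = cross_entropy(hard_target, softmax_with_temperature(student_logits))
    soft_targets = softmax_with_temperature(teacher_logits, temperature)
    soft_preds = softmax_with_temperature(student_logits, temperature)
    soft_loss = cross_entropy(soft_targets, soft_preds)
    return hard_loss + alpha * temperature ** 2 * soft_loss
```

For example, `distillation_loss([1.2, -0.3], [2.0, -1.0], [1.0, 0.0], temperature=3, alpha=1.5)` evaluates the loss for a single image with the temperature and alpha values reported in the results.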

Evaluation on the hold-out test sets
The three top-performing models were selected based on the maximum validation accuracy achieved with the student-teacher framework. These three models with and without student-teacher training (n=6) were evaluated on the two hold-out test sets. Sensitivity and specificity were calculated on the validation set and both hold-out test sets, and 95% confidence intervals were calculated with clustered bootstraps.
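A clustered bootstrap resamples whole clusters rather than individual B-scans, which respects the correlation between B-scans from the same source. The sketch below is a minimal illustration; the clustering unit (for example, patient or volume), the number of resamples, and the nearest-rank percentile interval are assumptions, since the text does not specify them.

```python
import math
import random


def sensitivity_specificity(y_true, y_pred):
    """Sensitivity and specificity for binary labels (1 = abnormal, 0 = normal)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    sens = tp / (tp + fn) if (tp + fn) else float("nan")
    spec = tn / (tn + fp) if (tn + fp) else float("nan")
    return sens, spec


def _interval(samples, level):
    """Nearest-rank percentile interval; NaN resamples are dropped."""
    vals = sorted(v for v in samples if not math.isnan(v))
    lo_idx = round((1 - level) / 2 * (len(vals) - 1))
    hi_idx = round((1 + level) / 2 * (len(vals) - 1))
    return vals[lo_idx], vals[hi_idx]


def clustered_bootstrap_ci(clusters, n_boot=1000, level=0.95, seed=0):
    """clusters: dict mapping cluster id -> list of (y_true, y_pred) pairs."""
    rng = random.Random(seed)
    ids = list(clusters)
    sens_samples, spec_samples = [], []
    for _ in range(n_boot):
        # Resample clusters with replacement, keeping each cluster's scans together.
        sampled = rng.choices(ids, k=len(ids))
        pairs = [pair for cid in sampled for pair in clusters[cid]]
        y_true, y_pred = zip(*pairs)
        sens, spec = sensitivity_specificity(y_true, y_pred)
        sens_samples.append(sens)
        spec_samples.append(spec)
    return {
        "sensitivity": _interval(sens_samples, level),
        "specificity": _interval(spec_samples, level),
    }
```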

Technical details
All analyses were performed using Python (v). Deep learning models were developed using Keras (v), Tensorflow (v), accelerated using NVIDIA CUDA (9.0.333), and trained on a server with dual Xeon 3.4 GHz processors, 256 GB of random access memory, and 8 x NVIDIA P100 GPUs. All LWMs were implemented in Keras adapted from code available from https://github.com/osmr/imgclsmob/tree/master/keras.

Student networks
The inference time versus baseline accuracy tradeoff for the teacher and the student models is presented in Fig. 2. The teacher network (ResNet50) achieves 96.0% validation accuracy with an average inference time of 0.452 seconds. The student models range from 89.6% to 95.1% validation accuracy and average inference time ranges from 0.066 to 0.261 seconds. The best performing (optimizing for inference time and validation accuracy) lightweight model is one of the SqueezeNet models with residual connections (SRN.1). SRN.1 has 95.1% validation accuracy and is 4.13 times faster than ResNet50 with only 1/32 of the parameters. See Table 1 for more details on abbreviations used.

Student-teacher framework
All student models benefit from training with the additional data provided by the unlabeled images and their inferred labels generated by the teacher network. Figure 3 demonstrates the gradual increase in validation accuracy observed when training with the labeled images and 0%, 25%, 50%, 75%, and 100% of the unlabeled images. With the student-teacher framework, the validation accuracies are boosted to between 94.8% and 96.3% across all student models. SRN.1, SN.1, SQN.1.3, and SQN.2.3 achieve a validation accuracy of 96.1%, slightly exceeding the validation accuracy of the teacher network. One of the MobileNet models, M.2.1, also narrowly beats the teacher network (ResNet50) with a validation accuracy of 96.3%. SRN.1, SN.1, and M.2.1 were selected as the best performing models as they balance low inference time with high validation accuracy. Figure 4 provides a different view of the results in Fig. 3 by plotting the percentage of unlabeled images used versus validation accuracy. This view makes the increase in validation accuracy with more unlabeled images easier to appreciate. The models, in general, see a more pronounced increase at 25% of the unlabeled images and then a gradual increase after. The 95% confidence intervals demonstrate that the lightweight models benefit from the addition of unlabeled images at all percentages. The soft-target student-teacher framework did not yield a similar increase in validation accuracy as compared to the semi-supervised learning student-teacher framework (Supplementary Table 2) for the MobileNet model (M.3.4) tested. This model was selected because it had the largest gain in validation accuracy with student-teacher training. The baseline validation accuracy using only the labeled data for training is 89.6%, and when using the soft-target training the validation accuracy rises to 90.3% (T=3, alpha=1.5), far less than the increase observed with the student-teacher framework (to 95.1%).
Likewise, the soft-target student-teacher framework (T=3, alpha=1.5) with labeled and unlabeled data achieves a lower validation accuracy of 94.4%.

Performance on hold-out test sets
The ResNet50 teacher model and the three best-performing lightweight models (SRN.1, SN.1, and M.2.1) trained with and without the student-teacher framework were evaluated on two hold-out test sets (

Discussion
In this work, we have demonstrated that lightweight deep learning networks designed for use on mobile devices can be applied to a medical imaging classification task. In particular, we have shown that four general families of LWMs, SqueezeNet, SqueezeNext, ShuffleNet, and MobileNet can be used to perform B-scan of interest identification with retina OCT images. Using a student-teacher framework, we demonstrate that all lightweight, student networks have improved validation accuracy with the addition of teacher-generated labeled data. Most notably, several models are able to perform comparably to the state-of-the-art, large teacher model (ResNet50) and run in a fraction of the time, making the LWMs ideal candidates to run in clinical settings where runtime is important and computing resources are often limited. The student-teacher framework is an appealing paradigm for semi-supervised learning in medical image classification as the generation of expertly-labeled image datasets is resource intensive.
All of the lightweight models benefited from the student-teacher training, with some models gaining as much as 4-5% in validation accuracy. The best-performing lightweight models, two SqueezeNet models (SRN.1 and SN.1) and one MobileNet model (M.2.1), and the ResNet50 teacher network were evaluated in two hold-out test sets. The results from both hold-out test sets demonstrate that the lightweight models, when trained with the student-teacher framework, perform comparably to the ResNet50 model. In both hold-out test sets, the lightweight models trained without the student-teacher framework perform similarly in terms of specificity to the other models, but have lower sensitivities, demonstrating the benefit of harnessing the unlabeled images for training the lightweight models. The sensitivity and specificity were lower for all models for the CIRRUS 6000 hold-out test set as compared to the CIRRUS 4000/5000 hold-out test set. This is not surprising as the models were trained with data from CIRRUS 4000/5000 instruments, but it does demonstrate that deep learning models have difficulty generalizing across imaging platforms. Additionally, the Cohen's kappa between raters was lower for the CIRRUS 6000 hold-out test set, suggesting that its ground-truth labels were noisier than those of the CIRRUS 4000/5000 hold-out test set. We plan to investigate increasing generalization by utilizing unlabeled B-scans from the CIRRUS 6000 in our student-teacher framework in future work.
We found the soft-target student-teacher framework was less effective than the semi-supervised approach for this binary classification problem. This is likely due to the fact that with only two classes there is little information added by encouraging the student model to mimic the teacher's behavior on both the soft and hard targets. It could be argued there are no "off-target" labels in the binary case, as binary cross entropy yields a "probability" p for one class and (1-p) for the other class. The student-teacher framework with unlabeled images presents a way to utilize the many unlabeled images available in medical imaging. We did not explore the size of the labeled image dataset, as we treated the teacher network as a fixed model. In future studies, it would be illuminating to explore the minimal requirements for the labeled imaging set size.
There are relatively few examples of LWMs being applied to the classification of medical images, but in general they have been shown to perform comparably to or better than larger, state-of-the-art models. Previously a lightweight deep learning model was trained to detect the 12 views of transthoracic echocardiography with a student-teacher framework with soft-target training [20]. The LWM had comparable performance to the best available deep learning architectures, but with 1% of the number of parameters and a 6X faster inference time. In addition, a lightweight model has been shown to be effective in detecting fundus image quality and was fast enough to run on mobile devices [18]. Two LWMs were shown to perform comparably to both a ResNet50 and an Inceptionv3 network on disease classification using chest x-ray and fundus photo images [15]. The authors suggest that many state-of-the-art algorithms designed for object recognition may be overparameterized for medical imaging classification tasks. A LWM (with only 6.9% of parameters compared to the ResNet-50 model) outperformed the ResNet-50 model on a multiclass classification task on OCT images [35], and a LWM trained to predict brain age from magnetic resonance images outperformed state-of-the-art, heavyweight models [19]. Most recently, lightweight models have been able to detect COVID-19 infections from chest X-ray images with high accuracy [21,36,37].
A popular approach to increasing the performance of deep learning classifiers is to use transfer learning [12,38], in which a large multiclass image classifier is pretrained on a natural image dataset and then retrained on the smaller medical imaging dataset of interest. The resulting model has far fewer parameters that need to be trained than the original model and is "fine-tuned" to the medical imaging task. There may be drawbacks to this approach, as large image classifier models such as ResNet are designed for many more classes than are typical for medical imaging tasks, and medical imaging pathologies often contain unique features that differ from standard natural images [15]. Additionally, most, if not all, natural image datasets are RGB images, while many medical images are grayscale, thus requiring the medical images to be duplicated across channels. The student-teacher approach avoids these drawbacks. The large teacher model, although similar to those used in transfer learning, is not retrained. Instead, the teacher model passes its classification knowledge to the lightweight student model by providing "examples" in the form of labeled training data. The student model is only trained on images that are specific to the task. Additionally, the expansion of the training examples available to the student networks should prevent overfitting and result in more generalized models.
There have been recent publications exploring knowledge distillation for classification, semantic segmentation, and instance segmentation with medical images. Knowledge distillation was demonstrated to outperform transfer learning in a diabetic retinopathy classification task using color fundus photography [39]. Multiple teacher networks were used to train a single student network to perform a multi-task classification with color fundus photographs [40]. A Multiple-Instance Learning (MIL) model combining pseudo-labeling via a mean-teacher was effective at detecting intraretinal fluid, subretinal fluid and pigment epithelial detachments while cutting the requirement for expensive labels by 94.22% [41]. Likewise, using limited labeled data, retinal layers were segmented from OCT images using a student-teacher framework with performance comparable to a model trained in a fully-supervised fashion [42]. A chain of student networks, each distilling knowledge to the next, was shown to need only 0.5% of the labeled data available to accurately detect colorectal cancer from histology slides [43]. Additionally enforcing consistency across perturbations with unlabeled data has been shown to improve multi-class classification results with chest X-ray [44] and instance segmentation results from microscopy and computed tomography images [45].
As discussed above, one limitation of our study is that we did not investigate reducing the number of labeled images because the labeled images went into training the teacher ResNet50 model as well as training the LWMs. A follow-up study with two sets of labeled images, one used for teacher training and one used for student training, would be an interesting paradigm in which to investigate the number of labeled images needed to achieve comparable accuracy. It is difficult to compare the models as they each have their own architectures and were only tested on one binary classification task; more research is warranted before making a more generalized statement about model superiority. A drawback to knowledge distillation is that there is the potential that the student network can learn to make systematic errors from the teacher model, so-called 'confirmation bias'. A recent learning framework was proposed to reduce confirmation bias by doing away with a teacher model and rather distilling knowledge between student networks [46]. Future directions of this work are to explore ensembling multiple teachers, deep supervision, and teacher-free knowledge distillation.
This work represents a divergent perspective on training lightweight deep learning models for medical imaging classification tasks. Expertly-labeled images are expensive to obtain; we exploited a large number of unlabeled images using a student-teacher network to perform semi-supervised learning. In the end, we were able to achieve comparable performance for fast LWMs as compared to the relatively slow and heavyweight ResNet50 model. The framework presented here could expand the horizons for running medical imaging classification algorithms on imaging instruments or mobile devices.