CoT-UNet++: A medical image segmentation method based on contextual transformer and dense connection

Abstract: Accurate delineation of individual teeth from CBCT images is a critical step in the diagnosis of oral diseases. Traditional manual methods are tedious and laborious, so automatic segmentation of individual teeth in CBCT images is important for assisting physicians in diagnosis and treatment. TransUNet, which combines the advantages of Transformer and CNN, has achieved success in medical image segmentation tasks. However, the skip connection adopted by TransUNet leads to unnecessarily restrictive fusion and also ignores the rich context between adjacent keys. To solve these problems, this paper proposes a context-transformed TransUNet++ (CoT-UNet++) architecture, which consists of a hybrid encoder, a dense connection, and a decoder. Specifically, a hybrid encoder is first used to obtain the contextual information between adjacent keys by CoTNet and the global context encoded by Transformer. Then the decoder upsamples the encoded features with cascaded upsamplers to recover the original resolution. Finally, multi-scale fusion between the encoded and decoded features at different levels is performed by dense concatenation to obtain more accurate location information. In addition, we employ a weighted loss function consisting of focal, Dice, and cross-entropy losses to reduce the training error and achieve pixel-level optimization. Experimental results demonstrate that the proposed CoT-UNet++ method outperforms the baseline models and obtains better performance in tooth segmentation.


Introduction
With the progress of society and the improvement of living standards, people are increasingly aware of oral health care. More people actively seek dental treatment, such as orthodontics and dental implants, to ensure the normal function of teeth and improve facial appearance [1,2]. Orthodontic treatment not only improves the alignment of teeth but also affects the related soft tissues and improves the aesthetic appearance of the patient's face. In addition, severe endodontic and periapical diseases need to be treated with the help of dental implant technology. During dental treatment, 3D cone-beam computed tomography (CBCT) images are commonly used to assist in diagnosis. CBCT lets the dentist observe the root alignment of the patient's teeth and provides comprehensive 3D information on the complete tooth. Traditional dental treatment relies on the subjective experience of the doctor to analyze the spatial alignment and morphology of the patient's crowns, root canals, and alveolar bone through CBCT images. Manually measuring teeth in CBCT images consumes a lot of the dentist's time and energy. Therefore, a tooth segmentation algorithm that can be used in the clinic is needed. Given a patient's CBCT image as input, such an algorithm can automatically and precisely segment the teeth.
Nowadays, artificial intelligence and deep learning have achieved breakthroughs and become the focus of public attention. Computer vision is the direction with the longest research history and the most accumulated technology in artificial intelligence, and it has shown strong applications [3][4][5][6][7][8][9][10][11][12] in various fields such as image classification, image segmentation, super-resolution, and face-related tasks, and many problems of medical image segmentation have been solved. Therefore, we applied computer vision to dental image processing to achieve high-precision segmentation of CBCT dental images and to exploit its value in the clinic.
Among the segmentation methods for medical images, U-Net and its variants, which consist of an encoder-decoder with skip connections, have shown good segmentation performance [13][14][15]. The skip connection combines coarse-grained and fine-grained feature maps for better segmentation of images. Based on this approach, great success has been achieved in a wide range of medical applications, such as cardiac segmentation in magnetic resonance (MR) images [16], organ segmentation in computed tomography (CT) images [17][18][19], and polyp segmentation [15]. Based on this structure, various methods such as UNet++ [20], UNet 3+ [21], DenseUNet [17], and KiU-Net [22] were subsequently proposed and have also been successful in the segmentation of various medical imaging datasets. UNet++ redesigns the fusion scheme of features at different scales based on UNet to achieve highly flexible feature fusion. Although the CNN-based methods have shown excellent performance, UNet has difficulty in handling long-range and global semantic information, especially when there is low contrast between the organ and its surroundings, due to the intrinsically local nature of the convolution operation. Therefore, researchers have built global context models by using a self-attention mechanism [23,24], but these still cannot fundamentally solve the above problem.
Many traditional tooth segmentation algorithms have been proposed, which reflects the importance of this application. Previous approaches used region growing [25], level sets and their variants [26,27], and statistical shape models [28,29], but these methods often fail. A more recent region candidate network [30] can effectively remove duplicate candidates and speed up training, but it also produces some erroneous segmentations.
In recent years, the Transformer architecture, which has achieved great success in the field of Natural Language Processing (NLP), has been gradually applied to the field of Computer Vision (CV). Designed for sequence-to-sequence prediction and relying entirely on self-attention mechanisms, it is powerful in modeling global context but pays insufficient attention to low-level detail in images. TransUNet [31], as a classical hybrid network, combines the advantages of UNet and Transformer: it not only exploits global information but also extracts low-level information to compensate for fine details, and it has shown strong performance in medical image segmentation. After this, new architectures combining CNN and Transformer modules have been continuously developed [32][33][34][35][36][37][38]. However, these traditional self-attention methods only use query-key pairs to compute the attention matrix without fully considering the rich context between keys, and Li et al. [39] proposed CoTNet to address this problem. CoTNet is obtained by replacing the 3 × 3 convolution in ResNet with the CoT block, and it achieves good performance in different tasks such as object detection and instance segmentation.
Inspired by the CoT block of CoTNet [39] and by UNet++ [20], we propose CoT-UNet++, which makes full use of the contextual information between keys. Moreover, we adopt a densely connected fusion scheme. This not only learns local contextual information but also exploits feature maps at all scales of the encoder and obtains accurate localization information. In summary, the contributions of this paper are as follows:
• A CoT block is utilized in the encoder to make full use of the contextual information between keys, thus enhancing visual representations.
• A dense connection is introduced to leverage multi-scale features that combine low-level detail and high-level semantics in full-scale feature maps.
• A weighted loss function consisting of cross-entropy loss, Dice loss, and focal loss is used to reduce the training error of medical image segmentation and achieve pixel-level optimization.

General architecture of CoT-UNet++
The overall structure of CoT-UNet++ is shown in Figure 1, which mainly consists of a hybrid encoder, dense connection, and a decoder.
As can be seen in Figure 1(b), CoT-UNet++ is a tightly mixed CNN and Transformer model that first feeds the input image into the hybrid encoder and uses CoTNet-50 to extract features at different resolutions. Part of the structure of CoTNet-50 is shown in Figure 1(a), and the specific structure is described in Section 3.2.2. The CoT block makes full use of the contextual information between keys and can facilitate self-attention learning. In the final stage of CoTNet-50, the feature map is mapped into a sequence and, together with the embedded position information, fed into the Transformer encoder for multi-head self-attention learning, which models long-range dependencies to obtain globally encoded features. The encoded features are then reshaped into feature maps for decoding, and a cascaded upsampler is used to decode the low-resolution features, with each upsampling block consisting of a 3 × 3 convolution and a ReLU layer.
To achieve flexible feature fusion, a dense connection is applied to fuse the feature maps of the encoder output and of the intermediate nodes with those obtained by decoding. This combines multiple low-level detailed feature maps from the encoding side with the feature maps containing high-level semantic information from the decoding side. The intermediate nodes use a multi-scale combination strategy of 1 × 1 and 3 × 3 convolutions to capture more diverse local features.

Multi-head self-attention
The traditional self-attention mechanism first transforms the input 2D feature map $X$ into three different matrices, the queries, keys and values, that is, $Q = XW_q$, $K = XW_k$ and $V = XW_v$, which are computed from the input with the weight matrices $W_q$, $W_k$ and $W_v$. As shown in Eq (1), the similarity matrix $QK^T$ is first obtained by multiplying $Q$ with the transpose of $K$. The weight matrix is obtained by scaling it by $\sqrt{d_k}$ and normalizing it with the softmax function. Finally, the attention matrix of the input $X$ is obtained by multiplying the weight matrix with $V$.
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \quad (1)$$
where $d_k$ denotes the dimensionality of the $Q$ or $K$ matrix.
The multi-head self-attention mechanism is an important part of the Transformer and is a combination of $h$ self-attention modules. $Q$, $K$ and $V$ each correspond to $h$ weight matrices, where $W_i^Q$, $W_i^K$ and $W_i^V$ denote the weight matrices of the $i$th self-attention head. They are multiplied with the input $X$ to obtain the projection matrices $Q_i = XW_i^Q$, $K_i = XW_i^K$ and $V_i = XW_i^V$ on different subspaces, and the attention matrix of each head is calculated as in Eq (2). Then all the output matrices are concatenated and multiplied with the learnable linear transformation matrix $W^O$ to obtain the final multi-head self-attention output, as shown in Eq (3).
$$\mathrm{head}_i = \mathrm{Attention}(Q_i, K_i, V_i) = \mathrm{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right)V_i \quad (2)$$
$$\mathrm{MSA}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^O \quad (3)$$
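The scaled dot-product attention and its multi-head extension described in this section can be illustrated with a minimal NumPy sketch. The sizes (5 tokens, model width 8, 2 heads) and the random weights are illustrative assumptions only and do not correspond to the configuration used in our experiments.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
    d_k = K.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)
    return weights @ V, weights

def multi_head_self_attention(X, W_q, W_k, W_v, W_o):
    # W_q, W_k, W_v are lists with one (d_model, d_k) matrix per head.
    heads = [attention(X @ wq, X @ wk, X @ wv)[0]
             for wq, wk, wv in zip(W_q, W_k, W_v)]
    # Concatenate the heads and apply the output projection W_o.
    return np.concatenate(heads, axis=-1) @ W_o

# Hypothetical toy sizes, chosen only for illustration.
rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 5, 8, 2
d_k = d_model // n_heads
X = rng.normal(size=(n_tokens, d_model))
W_q = [rng.normal(size=(d_model, d_k)) for _ in range(n_heads)]
W_k = [rng.normal(size=(d_model, d_k)) for _ in range(n_heads)]
W_v = [rng.normal(size=(d_model, d_k)) for _ in range(n_heads)]
W_o = rng.normal(size=(n_heads * d_k, d_model))

out = multi_head_self_attention(X, W_q, W_k, W_v, W_o)
_, weights = attention(X @ W_q[0], X @ W_k[0], X @ W_v[0])
```

Each row of `weights` sums to one, which reflects the softmax normalization of the similarity matrix.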

CoT block
In the traditional self-attention mechanism, all query-key pairs are learned independently without using the rich contextual information between them. For this reason, the CoT block is constructed to combine the contextual information of adjacent keys and self-attention learning into a single structure to enhance the representational power of the learned feature map. The specific structure of the CoT block is shown in Figure 1(a). It is assumed that for an input 2D feature map $X$, the queries, keys, and values are $Q$, $K$ and $V$, respectively. The traditional self-attention mechanism directly multiplies the query matrix with the transpose of the key matrix, $QK^T$, to derive the similarity matrix between them. In the CoT block, however, to obtain a contextual representation of each key, the adjacent keys are first convolved with a $k \times k$ convolution to obtain $K^1$, a representation of static contextual information. As shown in Eq (4), the learned static context $K^1$ is concatenated with $Q$, and then two consecutive 1 × 1 convolutions ($W_1$ with a ReLU activation function and $W_2$ without an activation function) are applied to obtain the attention matrix $A$.
$$A = [K^1, Q]W_1 W_2 \quad (4)$$
For each head, the local attention matrix at each spatial location is learned from $Q$ and the integrated contextual information $K^1$ instead of from isolated query-key pairs, which enhances self-attention learning. The obtained attention matrix $A$ is then multiplied with $V$ to obtain the dynamic contextual representation $K^2$, as shown in Eq (5).
$$K^2 = V \circledast A \quad (5)$$
Finally, the static and dynamic contextual information is combined as the final output of the CoT block, which is shown in Eq (6).
In summary, we obtain the above two kinds of contextual information by capturing information between adjacent keys. Visual representation learning is promoted by the static contextual information obtained through the 3 × 3 convolution and by the dynamic contextual information, which is learned by self-attention on top of the static context.
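The data flow of the CoT block can be sketched in NumPy as follows. This is a simplified illustration under several assumptions that are not taken from our implementation: the queries are the input itself, the values come from a 1 × 1 convolution, the local matrix multiplication is replaced by an elementwise product with a spatially softmax-normalized attention map, and the static and dynamic contexts are combined by a plain sum. All weights and sizes are hypothetical.

```python
import numpy as np

def conv2d(x, w):
    """Stride-1, zero-padded ('same') 2D convolution.
    x: (C_in, H, W); w: (C_out, C_in, k, k) with odd k."""
    c_out, _, k, _ = w.shape
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))
    H, W = x.shape[1:]
    out = np.zeros((c_out, H, W))
    for dy in range(k):
        for dx in range(k):
            window = xp[:, dy:dy + H, dx:dx + W]
            out += np.einsum("oc,chw->ohw", w[:, :, dy, dx], window)
    return out

def cot_block(x, w_key, w_val, w_1, w_2):
    q = x                          # assumption: queries are the input itself
    k1 = conv2d(x, w_key)          # static context K^1 from a k x k convolution
    v = conv2d(x, w_val)           # values from a 1 x 1 convolution (assumption)
    # Concatenate K^1 and Q, then two 1 x 1 convolutions (ReLU only after the first).
    hidden = np.maximum(conv2d(np.concatenate([k1, q]), w_1), 0.0)
    a = conv2d(hidden, w_2)
    # Simplification: softmax over spatial positions, then elementwise product.
    a_flat = a.reshape(a.shape[0], -1)
    a_flat = np.exp(a_flat - a_flat.max(axis=1, keepdims=True))
    attn = (a_flat / a_flat.sum(axis=1, keepdims=True)).reshape(a.shape)
    k2 = attn * v                  # simplified dynamic context K^2
    return k1 + k2, k1, k2         # assumption: combine the contexts by a sum

# Hypothetical toy sizes.
rng = np.random.default_rng(0)
C, H, W = 4, 6, 6
x = rng.normal(size=(C, H, W))
w_key = rng.normal(size=(C, C, 3, 3))
w_val = rng.normal(size=(C, C, 1, 1))
w_1 = rng.normal(size=(2, 2 * C, 1, 1))
w_2 = rng.normal(size=(C, 2, 1, 1))
y, k1, k2 = cot_block(x, w_key, w_val, w_1, w_2)
```

The output keeps the shape of the input, so a block of this form can stand in for a 3 × 3 convolution inside a residual stage.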
We partially modified ResNet-50 by combining the res4 and res5 layers into one layer and replacing their 3 × 3 convolutions with a CoT block to obtain CoTNet-50. The structure of CoTNet-50 is shown in Table 1.

Vision Transformer block
The output of the linear projection of the vectorized image patches is the patch embedding. The position embedding is then added to the patch embedding to preserve the position information. The resulting sequence of embedding vectors is used as the input to the Transformer encoder. The above process can be expressed by Eq (7).
$$z_0 = [x_p^1 E; x_p^2 E; \ldots; x_p^N E] + E_{pos} \quad (7)$$
where $E \in \mathbb{R}^{(P^2 \cdot C) \times D}$ is the embedding projection of the patches and $E_{pos} \in \mathbb{R}^{N \times D}$ is the position embedding. The Transformer encoder is composed of $L$ layers, each consisting of a Multi-head Self-Attention (MSA) block and a Multilayer Perceptron (MLP) block. The multilayer perceptron mainly consists of two fully connected layers with a ReLU activation layer between them. The output of the $l$th layer can be expressed by Eqs (8) and (9).
$$z'_l = \mathrm{MSA}(\mathrm{LN}(z_{l-1})) + z_{l-1} \quad (8)$$
$$z_l = \mathrm{MLP}(\mathrm{LN}(z'_l)) + z'_l \quad (9)$$
where $\mathrm{LN}(\cdot)$ denotes the layer normalization operator, $z_l$ is the encoded image representation, and $z_{l-1}$ is the output of the previous Transformer layer.

Dense connection
Let $x^{i,j}$ denote the output of node $X^{i,j}$, where $i$ is the index of the downsampling layer along the encoder and $j$ indexes the convolutional layer of the intermediate nodes along the skip connection. $x^{i,j}$ can be represented as follows:
$$x^{i,j} = \begin{cases} H\left(D\left(x^{i-1,j}\right)\right), & j = 0 \\ H\left(\left[\left[x^{i,k}\right]_{k=0}^{j-1}, U\left(x^{i+1,j-1}\right)\right]\right), & j > 0 \end{cases} \quad (10)$$
where the function $H(\cdot)$ is a convolution operation followed by an activation function, $D(\cdot)$ and $U(\cdot)$ denote the down-sampling and up-sampling layers, and $[\cdot]$ denotes the concatenation layer. As shown in Figure 1(b), for $j = 0$, node $x^{i,0}$ receives only one input, which is the output of the previous encoder node $x^{i-1,0}$. When $j = 1$, node $x^{i,1}$ receives two inputs, $x^{i,0}$ and the upsampled $x^{i+1,0}$, which come from two consecutive levels of the encoder. When $j > 1$, node $x^{i,j}$ receives $j + 1$ inputs, where the first $j$ inputs $x^{i,0}, \ldots, x^{i,j-1}$ are the outputs of the previous $j$ nodes at the same level of the skip connection, and the $(j+1)$th input is the upsampled output $x^{i+1,j-1}$ of the skip connection from the lower level. We use a convolutional block to complete the skip connection between each of the above nodes. This allows all the previous feature maps to be accumulated and to reach the current node, which therefore sees not only the final aggregated feature maps but also the maps of the intermediate nodes and the original feature maps of the same scale from the encoder. In this way, deep fusion of features at different levels is achieved, and the unnecessarily restrictive behavior of skip connections is further relaxed.
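The wiring rule of the dense connection can be made concrete with a short Python sketch that lists the inputs of a node for given indices. The function name, the tuple labels, and the treatment of the top-left node as receiving the input image are illustrative assumptions; the sketch only enumerates connections and performs no convolution.

```python
def node_inputs(i, j):
    """List the inputs of node x^{i,j} under the dense connection rule.

    Each input is a (kind, index) pair, where kind is "down" (down-sampled
    encoder output), "skip" (same-level node) or "up" (up-sampled node from
    the level below).
    """
    if j == 0:
        # Encoder node: one input from the previous encoder level.
        # Assumption: the first encoder node takes the input image.
        return [("down", (i - 1, 0))] if i > 0 else [("image", None)]
    # Dense skip: all j previous nodes at the same level ...
    same_level = [("skip", (i, k)) for k in range(j)]
    # ... plus the up-sampled output of the node one level below.
    return same_level + [("up", (i + 1, j - 1))]

inputs_j1 = node_inputs(1, 1)
inputs_j3 = node_inputs(0, 3)
```

For example, `node_inputs(0, 3)` contains three same-level inputs and one up-sampled input, that is, $j + 1 = 4$ inputs in total.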

Weighted loss function
We use a weighted loss function in the model. $\mathcal{L}_{CE}$ is the cross-entropy loss, which evaluates the loss incurred when classifying pixels during the segmentation of image data, and is defined as follows:
$$\mathcal{L}_{CE} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log \hat{y}_i + (1 - y_i)\log(1 - \hat{y}_i)\right] \quad (11)$$
where $N$ is the number of samples (pixels), $y_i$ denotes the label of sample $i$, which is 1 for the positive class and 0 for the negative class, and $\hat{y}_i$ denotes the predicted probability that sample $i$ belongs to the positive class. The base of $\log$ is $e$. $\mathcal{L}_{Dice}$ is the Dice loss, which evaluates the similarity between the predicted segmentation and the real segmentation (label). Its value lies in the range [0,1] and it is calculated as follows:
$$\mathcal{L}_{Dice} = 1 - \frac{2|X \cap Y|}{|X| + |Y|} \quad (12)$$
where $|X \cap Y|$ denotes the intersection of the real image $X$ and the predicted image $Y$, and $|X|$ and $|Y|$ denote the numbers of their elements.
The focal loss function [40] $\mathcal{L}_{Focal}$ focuses the training on hard samples and also alleviates the problem of unbalanced data samples. It is calculated as follows:
$$\mathcal{L}_{Focal} = -\frac{1}{N}\sum_{i=1}^{N}\left[\alpha\, y_i (1 - \hat{y}_i)^{\gamma} \log \hat{y}_i + (1 - \alpha)(1 - y_i)\, \hat{y}_i^{\gamma} \log(1 - \hat{y}_i)\right] \quad (13)$$
where $\alpha$ is the weighting factor and takes values in the range [0, 1]. For positive samples the weight is $\alpha$, and for negative samples it is $(1 - \alpha)$. $\gamma$ is a focusing parameter that takes values in [0,5].
For a positive sample, when $\hat{y}_i$ tends to 1, the sample is easily distinguishable and the modulation factor $(1 - \hat{y}_i)^{\gamma}$ tends to 0, which reduces the loss contribution of the easily distinguishable sample. When $\hat{y}_i$ is small, the sample is likely to be misclassified. In this case the modulation factor $(1 - \hat{y}_i)^{\gamma}$ tends to 1, so the loss is hardly affected. By reducing the loss contribution of the easy samples, the loss obtained for simple samples becomes smaller relative to the loss for samples with small prediction probability, thus strengthening the focus on difficult cases. We take a weighted combination of cross-entropy loss, Dice loss and focal loss, which can effectively segment the boundaries and overall contours of teeth in medical images. The weighted loss is defined as follows:
$$\mathcal{L} = A\mathcal{L}_{CE} + B\mathcal{L}_{Dice} + C\mathcal{L}_{Focal} \quad (14)$$
where A, B and C are the weight parameters of the three loss functions, respectively.
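The weighted combination of the three losses can be sketched in NumPy for the binary case. This is a minimal illustration rather than our training code: averaging over pixels, the clipping constant, the smoothing term in the Dice loss, and the use of soft probabilities in the Dice intersection are assumptions introduced here. The default values of $\alpha$, $\gamma$, A, B and C follow the settings reported in the Implementation section.

```python
import numpy as np

EPS = 1e-7  # assumption: clip probabilities to avoid log(0)

def cross_entropy_loss(y, p):
    p = np.clip(p, EPS, 1.0 - EPS)
    return float(-np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p)))

def dice_loss(y, p, smooth=1e-6):
    # Soft Dice: the intersection is computed on probabilities (assumption).
    intersection = np.sum(y * p)
    return float(1.0 - (2.0 * intersection + smooth) / (np.sum(y) + np.sum(p) + smooth))

def focal_loss(y, p, alpha=0.25, gamma=0.6):
    p = np.clip(p, EPS, 1.0 - EPS)
    positive = alpha * y * (1.0 - p) ** gamma * np.log(p)
    negative = (1.0 - alpha) * (1.0 - y) * p ** gamma * np.log(1.0 - p)
    return float(-np.mean(positive + negative))

def weighted_loss(y, p, A=0.4, B=0.4, C=0.2):
    # L = A * L_CE + B * L_Dice + C * L_Focal
    return A * cross_entropy_loss(y, p) + B * dice_loss(y, p) + C * focal_loss(y, p)

# Toy example with four pixels (hypothetical values).
y_true = np.array([1.0, 0.0, 1.0, 0.0])
y_prob = np.array([0.9, 0.2, 0.6, 0.4])
loss_value = weighted_loss(y_true, y_prob)
loss_perfect = weighted_loss(y_true, y_true)
loss_inverted = weighted_loss(y_true, 1.0 - y_true)
```

With $\gamma = 0$ and $\alpha = 0.5$ the focal term reduces to half of the cross-entropy loss, which is a convenient sanity check for an implementation of this kind.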

Dataset and pre-processing
There is no publicly available dataset of CBCT dental images for research, so we first constructed a dataset of CBCT dental images. We collected a large number of CBCT images from dental hospitals, and all dental CBCT images were obtained from patients under routine clinical care. Most of these patients needed dental treatment such as dental implants, orthodontics, and restorations. To determine the true labels of the teeth for model training and evaluation, we performed point-by-point labeling of the teeth in the CBCT images. The labels were outlined fully manually and checked by an experienced physician to ensure their accuracy. In total, the constructed dental dataset contains 20 groups of 300 dental CBCT images for the tooth segmentation experiments.
We normalized the data to a uniform size before training, and then augmented the data by cropping, rotating, and mirroring. We selected three different crop sizes, followed by rotations at different angles, and also used up-down mirroring and left-right mirroring, where each transformation was randomly applied to the original image to complete the augmentation of the dataset.
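The augmentation steps above can be sketched in NumPy as follows. The sketch applies the same random crop, rotation and mirroring to an image and its label mask so that they stay aligned. Restricting rotations to multiples of 90 degrees, the flip probability of 0.5, and the function name are simplifying assumptions made for this illustration and are not the exact settings of our pipeline.

```python
import numpy as np

def augment(image, mask, crop_size, rng):
    """Apply the same random crop, rotation and mirroring to image and mask."""
    H, W = image.shape[:2]
    top = rng.integers(0, H - crop_size + 1)
    left = rng.integers(0, W - crop_size + 1)
    img = image[top:top + crop_size, left:left + crop_size]
    msk = mask[top:top + crop_size, left:left + crop_size]
    # Assumption: rotate by a random multiple of 90 degrees.
    k = int(rng.integers(0, 4))
    img, msk = np.rot90(img, k), np.rot90(msk, k)
    if rng.random() < 0.5:      # up-down mirroring
        img, msk = np.flipud(img), np.flipud(msk)
    if rng.random() < 0.5:      # left-right mirroring
        img, msk = np.fliplr(img), np.fliplr(msk)
    return img, msk

# Toy example: the mask is a threshold of the image, so alignment is checkable.
rng = np.random.default_rng(0)
image = np.arange(64, dtype=float).reshape(8, 8)
mask = image > 30
aug_image, aug_mask = augment(image, mask, crop_size=6, rng=rng)
```

Because the mask in the toy example is a pixel-wise function of the image, `aug_mask` equals `aug_image > 30` after any combination of the transformations, which confirms that image and label remain aligned.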

Implementation
We performed experiments on the PyTorch platform. The hybrid encoder module used ResNet-50 [41] as the baseline model, where the 3 × 3 convolution was replaced with a CoT block to obtain CoTNet-50 for feature extraction. We adopted ViT [34] with 12 Transformer layers and a multi-head self-attention mechanism with 12 heads. We combined CoTNet-50 and ViT, denoted as C50-ViT, as the hybrid encoder. The input resolution of the image was 224 × 224 and the patch size was 16. The model was trained with the SGD optimizer with a learning rate of 0.001, momentum of 0.9, weight decay of 1e-4, 200 epochs and a batch size of 4. The parameters $\alpha$ and $\gamma$ were set to 0.25 and 0.6. The loss function weights A, B, and C were 0.4, 0.4 and 0.2, respectively. All experiments were conducted using an NVIDIA GeForce RTX 3060 GPU.

Evaluation metrics
Two commonly-used evaluation metrics in image segmentation are Dice Similarity Coefficient (DSC) and 95% Hausdorff Distance (95HD).
Dice Similarity Coefficient (DSC). In medical image segmentation, the Dice Similarity Coefficient is used to calculate the similarity between the model segmentation result and the real label, and its value lies in the range [0, 1]. The closer the DSC value is to 1, the closer the segmentation result is to the ground truth, which indicates better segmentation performance. The DSC is calculated as follows:
$$\mathrm{DSC} = \frac{2|X \cap Y|}{|X| + |Y|} \quad (15)$$
where $X$ denotes the set of true label pixels and $Y$ denotes the output segmentation result of the model.
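As a worked illustration of the DSC, the following NumPy sketch computes it for two small binary masks. Returning 1.0 when both masks are empty is a convention assumed here, and the masks are toy data.

```python
import numpy as np

def dice_coefficient(pred, target):
    """DSC = 2|X ∩ Y| / (|X| + |Y|) for binary masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    denominator = pred.sum() + target.sum()
    if denominator == 0:
        return 1.0  # assumption: two empty masks count as a perfect match
    return 2.0 * np.logical_and(pred, target).sum() / denominator

target = np.array([[1, 1, 0],
                   [0, 1, 0]])
pred = np.array([[1, 0, 0],
                 [0, 1, 1]])
dsc = dice_coefficient(pred, target)
```

Here both masks contain three foreground pixels and share two of them, so the DSC is $2 \cdot 2 / (3 + 3) \approx 0.667$.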

Hausdorff distance (HD). The Hausdorff distance measures the distance between two point sets in space, $X = \{x_1, x_2, \ldots, x_n\}$ and $Y = \{y_1, y_2, \ldots, y_m\}$, and can be used in image segmentation tasks to measure the maximum mismatch between two contours or shapes. 95HD takes the 95th percentile of the distances instead of the maximum, which makes it more robust to a small number of outliers than the Hausdorff distance, and it is commonly used in biomedical segmentation challenges. The Hausdorff distance is calculated as follows:
$$H(X, Y) = \max\left\{\max_{x \in X}\min_{y \in Y}\|x - y\|,\ \max_{y \in Y}\min_{x \in X}\|x - y\|\right\} \quad (16)$$
The Hausdorff distance represents the maximum value of the difference between the point set $X$ and the point set $Y$. Therefore, a smaller value represents a better match between the two contours or shapes, and a better segmentation by the model. We also used the following additional evaluation metrics to assess the segmentation results of the model.
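The Hausdorff distance and its 95th percentile variant can be sketched in NumPy by brute force over all point pairs. Several conventions exist for 95HD; taking the larger of the two directed 95th percentiles, as done here, is one common choice and is an assumption of this sketch. The point sets are toy data.

```python
import numpy as np

def directed_distances(A, B):
    """For each point in A, the Euclidean distance to its nearest point in B."""
    diff = A[:, None, :] - B[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)

def hausdorff(A, B):
    # Symmetric Hausdorff distance: the larger of the two directed maxima.
    return max(directed_distances(A, B).max(), directed_distances(B, A).max())

def hausdorff_95(A, B):
    # Assumed convention: the larger of the two directed 95th percentiles.
    return max(np.percentile(directed_distances(A, B), 95),
               np.percentile(directed_distances(B, A), 95))

X_pts = np.array([[0, 0], [1, 0], [2, 0]], dtype=float)
Y_pts = np.array([[0, 0], [1, 0], [2, 3]], dtype=float)
hd = hausdorff(X_pts, Y_pts)
hd95 = hausdorff_95(X_pts, Y_pts)
```

In this example the point (2, 3) lies at distance 3 from its nearest neighbor in the other set, so the Hausdorff distance is 3, while the percentile-based variant is no larger than that value.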
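In the following formulas, $k + 1$ is the number of classes (including the background) and $p_{ij}$ is the number of pixels of class $i$ that are predicted as class $j$.
Pixel Accuracy (PA). PA is the simplest metric and is the ratio of correctly labeled pixels to the total number of pixels. It is calculated as follows:
$$\mathrm{PA} = \frac{\sum_{i=0}^{k} p_{ii}}{\sum_{i=0}^{k}\sum_{j=0}^{k} p_{ij}} \quad (17)$$
Mean Pixel Accuracy (MPA). MPA is a simple enhancement of PA, which calculates the proportion of correctly classified pixels within each class and averages it over the classes. It is calculated as follows:
$$\mathrm{MPA} = \frac{1}{k + 1}\sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k} p_{ij}} \quad (18)$$
Mean Intersection over Union (mIoU). This metric provides a balanced evaluation of the segmentation results for all categories. The segmentation accuracy is measured by calculating the ratio of the intersection to the union of the segmentation results and the true labels. It is calculated as follows:
$$\mathrm{mIoU} = \frac{1}{k + 1}\sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k} p_{ij} + \sum_{j=0}^{k} p_{ji} - p_{ii}} \quad (19)$$
Average Symmetric Surface Distance (ASD). It calculates the average of the surface distances between the segmentation result and all points in the real label, and it is one of the evaluation criteria of the medical image segmentation competition CHAOS. Let the point set $S = \{s_1, s_2, \ldots\}$ be the set of surface points of the model segmentation result, and let the point set $G = \{g_1, g_2, \ldots\}$ be the set of surface points of the real label. The calculation of ASD can be expressed as Formula (20). The unit of ASD is the millimeter, and a smaller value means a better segmentation result.
$$\mathrm{ASD}(S, G) = \frac{1}{|S| + |G|}\left(\sum_{s \in S} d(s, G) + \sum_{g \in G} d(g, S)\right) \quad (20)$$
where $d(s, G) = \min_{g \in G}\|s - g\|$ is the distance from point $s$ to its nearest point in $G$.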
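The pixel-based metrics PA, MPA and mIoU can all be derived from a confusion matrix, as the following NumPy sketch shows. The function names and the toy labels are illustrative assumptions, and the sketch assumes that every class occurs at least once in the labels so that no division by zero arises.

```python
import numpy as np

def confusion_matrix(target, pred, num_classes):
    """Entry (i, j) counts the pixels of true class i predicted as class j."""
    index = target.ravel() * num_classes + pred.ravel()
    counts = np.bincount(index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

def pixel_accuracy(cm):
    return np.trace(cm) / cm.sum()

def mean_pixel_accuracy(cm):
    # Per-class accuracy (diagonal over row sums), averaged over classes.
    return np.mean(np.diag(cm) / cm.sum(axis=1))

def mean_iou(cm):
    intersection = np.diag(cm)
    union = cm.sum(axis=1) + cm.sum(axis=0) - intersection
    return np.mean(intersection / union)

# Toy example with two classes (0 = background, 1 = tooth).
target = np.array([0, 0, 1, 1, 1, 0])
pred = np.array([0, 1, 1, 1, 0, 0])
cm = confusion_matrix(target, pred, num_classes=2)
pa, mpa, miou = pixel_accuracy(cm), mean_pixel_accuracy(cm), mean_iou(cm)
```

For these toy labels the confusion matrix is [[2, 1], [1, 2]], which gives PA = MPA = 2/3 and mIoU = 0.5.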

The selection of loss function weights
We combine three different loss functions to obtain our weighted loss function. The optimal weights A, B and C are derived from multiple sets of experimental data, as shown in Table 2. We tested four different sets of weights, and the experimental results show that the best results are achieved when A, B and C are 0.4, 0.4 and 0.2. Therefore, this set of parameters is taken as the actual weight values of our network.

Overall performance comparison
To verify the segmentation effect of the proposed model, we quantitatively compared the performance of our proposed CoT-UNet++ with U-Net, UNet++, and TransUNet in terms of the Dice, 95HD, MPA, mIoU, TPR, and ASD metrics on our constructed dataset. The three comparison results were obtained by running the source code provided by the authors of the published literature, while simply replacing their datasets with our dental dataset. The results of the comparison of the four methods are shown in Table 3, and the best performance is highlighted in bold. As Table 3 shows, our method has the best performance on Dice and 95HD, with 92.06% and 1.06 respectively, which can be attributed to the appropriate combination of the CoT block, the dense connection, and the loss weighting. Compared with the baseline model TransUNet, the Dice score was improved by about 5.5%. There was a significant decrease in 95HD, and the other metrics were also improved.

Ablation experiment
To verify the effectiveness of each component of the proposed method, we conducted the ablation study in this section. Performance comparison of 50-layer networks. To quantitatively compare the above segmentation performance and the effectiveness of the weighted loss function, we further compared the performance of our proposed network using different encoders, with or without dense connectivity, and with (w) or without (w/o) the weighted loss function $\mathcal{L}$. The results are all presented in Table 4. The comparison between TransUNet and TransUNet++ shows that the latter has better values of all evaluation metrics than the former. This shows that the dense connection effectively captures different local features through multi-scale fusion of different features and obtains more accurate location information. CoT-UNet also improves on TransUNet in the segmentation metrics. This confirms that unifying context mining between keys and self-attention learning into a single model is an effective way to enhance representation learning and thus facilitate visual recognition. The proposed CoT-UNet++ model obtains the best performance, demonstrating the effectiveness of combining dense connectivity with an efficient encoder. In addition, comparing the results with and without the weighted loss function in the table, it can be seen that the metrics with the weighted loss function are all better than those without, which also shows the necessity of using the weighted loss function. Figure 3 shows the loss curves of TransUNet and our proposed CoT-UNet++. We can see that the loss value of CoT-UNet++ decreases faster than that of TransUNet, and the final stable value is smaller, which indicates that CoT-UNet++ trains more effectively.
Performance comparison of 101-layer networks. To verify the general applicability of our network, we also used ResNet-101 as the baseline model for the encoder. The 3 × 3 convolution in ResNet-101 is replaced by the CoT block to obtain CoTNet-101. Table 5 compares the performance of several different networks based on ResNet-101 and CoTNet-101. The table shows that the metrics of dental segmentation using 101-layer networks are generally better than those using 50 layers. Moreover, the use of the CoT block, dense connections, and the weighted loss function all improve the accuracy of segmentation, further validating the effectiveness of our proposed network. In summary, the segmentation results of our proposed model CoT-UNet++ for dental CBCT images are better than those of the compared state-of-the-art models and are much closer to the manual segmentation results from the dentist and physician. However, our model has a large number of parameters and a long training time. Moreover, the segmentation result is not satisfactory for poor-quality CBCT images. Therefore, the network needs to be improved to reduce the number of parameters and to focus more on boundary information.

Conclusions
In this paper, we propose the CoT-UNet++ algorithm for dental image segmentation. The CoT block is introduced to enhance visual representation in the encoder. To obtain more accurate localization of teeth in CBCT images, CoTNet-50 and ViT are used as a hybrid encoder, and the dense connection is utilized to fuse all the feature maps of the encoder at the same scale. Moreover, an effective weighted combination of loss functions captures the tooth structure at the pixel level, making the boundaries of segmented teeth more accurate. Experimental results show that our proposed method achieves better teeth segmentation accuracy than other related methods. However, the boundaries of individual teeth are still segmented inaccurately in some cases, and more attention should be paid to boundary information in future studies. In addition, tooth classification and 3D tooth reconstruction are worthy of further investigation.

Figure 1. The overall structure of CoT-UNet++. (a) The specific structure of the CoT block. (b) Overall architecture of CoT-UNet++.

For an image $x \in \mathbb{R}^{H \times W \times C}$, $H \times W$ is the resolution of the image and $C$ is the number of channels. The standard Transformer receives a 1D sequence as input, so the image is first reshaped into a sequence of flattened 2D patches $\{x_p^i \in \mathbb{R}^{P^2 \cdot C} \mid i = 1, \ldots, N\}$, where $P \times P$ is the resolution of each image patch and $N = HW/P^2$ is the number of patches, which also serves as the input sequence length for the Transformer. Then the vectorized patches $x_p^i$ are mapped into a $D$-dimensional embedding space by a linear projection.
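The patch extraction and embedding step can be illustrated with a short NumPy sketch. For a 224 × 224 input with patch size 16, the sequence length is $N = 224 \cdot 224 / 16^2 = 196$. The three-channel input, the embedding width of 32 and the random projection and position embedding are assumptions chosen only to keep the example small. In a hybrid encoder the patches are taken from a CNN feature map rather than from the raw image.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into N flattened patches of length P*P*C."""
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by P"
    x = image.reshape(H // P, P, W // P, P, C)
    x = x.transpose(0, 2, 1, 3, 4)           # (H/P, W/P, P, P, C)
    return x.reshape(-1, P * P * C)           # (N, P*P*C)

rng = np.random.default_rng(0)
image = np.arange(224 * 224 * 3, dtype=float).reshape(224, 224, 3)
patches = patchify(image, patch_size=16)

D = 32                                         # hypothetical embedding width
E = rng.normal(size=(patches.shape[1], D))     # patch embedding projection
E_pos = rng.normal(size=(patches.shape[0], D)) # position embedding
z0 = patches @ E + E_pos                       # input sequence of the encoder
```

The first row of `patches` is the flattened top-left 16 × 16 block of the image, and `z0` has one embedded vector per patch.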
First, we qualitatively compared the effectiveness of the CoT block and dense connections under the same loss function. TransUNet, TransUNet++, CoT-UNet, and CoT-UNet++ are selected for comparison, as shown in Figure 2. TransUNet++ only adds the dense connection to TransUNet without using CoT blocks, and its hybrid encoder is a combination of ResNet-50 and ViT. CoT-UNet uses the CoT block but not dense connections, and its hybrid encoder is a combination of CoTNet-50 and ViT. Our proposed CoT-UNet++ utilizes a dense connection and a hybrid encoder combining CoTNet-50 and ViT, with a weighted loss function. According to Figure 2, teeth segmented by CoT-UNet++ are more accurate, while the other methods tend to produce under-segmented or over-segmented results. In the segmentation results of TransUNet and TransUNet++, there is a certain degree of over-segmentation in the molar region: parts of the alveolar bone are identified as teeth, and two adjacent molars are not well separated from each other. In contrast, the CoT-UNet segmentation results show incomplete molar segmentation and holes in the middle of the teeth. Our proposed CoT-UNet++ overcomes these limitations and has better performance.

Figure 2. Segmentation results of different CBCT images of teeth with labels. From left to right are the segmentation results using TransUNet, TransUNet++, CoT-UNet and our proposed method, respectively.

Table 2. The results of the loss function using different weighted parameter values. The best performance is shown in bold.

Table 3. Quantitative comparison of segmentation performance using different methods on the CBCT dental dataset (Dice, 95HD, MPA, mIoU, TPR and ASD).

Table 4. Ablation experiments with our network CoT-UNet++ (50 layers) on the CBCT dental dataset. The best performance is shown in bold.

Table 5. Ablation experiments with our network CoT-UNet++ (101 layers) on the CBCT dental dataset. The best performance is shown in bold.