Factorizer: A Scalable Interpretable Approach to Context Modeling for Medical Image Segmentation

Convolutional Neural Networks (CNNs) with U-shaped architectures have dominated medical image segmentation, which is crucial for various clinical purposes. However, the inherent locality of convolution prevents CNNs from fully exploiting global context, which is essential for better recognition of some structures, e.g., brain lesions. Transformers have recently shown promising performance on vision tasks, including semantic segmentation, mainly due to their capability of modeling long-range dependencies. Nevertheless, the quadratic complexity of attention forces existing Transformer-based models to apply self-attention layers only after reducing the image resolution, which limits their ability to capture global context present at higher resolutions. Therefore, this work introduces a family of models, dubbed Factorizer, which leverages the power of low-rank matrix factorization for constructing an end-to-end segmentation model. Specifically, we propose a linearly scalable approach to context modeling, formulating Nonnegative Matrix Factorization (NMF) as a differentiable layer integrated into a U-shaped architecture. The shifted window technique is also utilized in combination with NMF to effectively aggregate local information. Factorizers compete favorably with CNNs and Transformers in terms of accuracy, scalability, and interpretability, achieving state-of-the-art results on the BraTS dataset for brain tumor segmentation and the ISLES'22 dataset for stroke lesion segmentation. Highly meaningful NMF components give Factorizers an additional interpretability advantage over CNNs and Transformers. Moreover, our ablation studies reveal a distinctive feature of Factorizers that enables a significant speed-up in inference for a trained Factorizer without any extra steps and without sacrificing much accuracy. The code and models are publicly available at https://github.com/pashtari/factorizer.


Introduction
Medical image segmentation is an essential prerequisite for the analysis of anatomical structures for various clinical purposes, including diagnosis and treatment planning. In recent years, the vast majority of effective segmentation models have been based on Convolutional Neural Networks (CNNs), particularly U-Net [1], which consists of encoder and decoder parts with skip connections in between. In a typical U-Net [1,2], the encoder learns a low-resolution contextual representation consisting of progressively downsampled feature maps, while the decoder progressively upsamples the low-resolution feature maps to propagate contextual information to the higher-resolution layers. Moreover, skip connections between encoder and decoder layers of equal resolution help recover spatial information lost during downsampling.
Currently, U-Net models mostly rely on convolution operations with small receptive fields, which can exploit only local context at each resolution. Hence, they generally fail to effectively model long-range spatial dependencies, which are often necessary, for example, for better recognition of brain lesions, which can be very infiltrative and extensive, and thus vary dramatically in shape and size. Moreover, capturing even small focal tumors within the receptive field is extremely difficult without any notion of the global context of normal brain anatomy, since such tumors can occur anywhere in the brain. Several works [3,4] have employed dilated convolution to expand the receptive fields. Nevertheless, the learning capabilities of convolutional layers are still limited due to their inherent locality. As a solution, integrating self-attention modules into CNNs [5,6] has been proposed to enhance the capability of modeling non-local context. Transformers [7] have achieved state-of-the-art performance on various natural language processing tasks. The attention mechanism enables Transformers to effectively model the pairwise interactions between the words in a sentence. Recently, Transformer-based models have been applied to vision tasks and demonstrated promising results. Specifically, Vision Transformer (ViT) [8] outperformed state-of-the-art CNNs on image recognition through large-scale pre-training and fine-tuning of a pure Transformer. Unlike CNNs, ViT encodes images as a sequence of 1D patch embeddings (known as tokens) and dynamically highlights the important tokens using self-attention layers, which in turn increases its capability of learning long-range dependencies. Due to the lack of a locality inductive bias, ViT is data-hungry and generally requires a larger dataset to perform as effectively as its CNN counterparts, which leads to poor performance when it is trained on insufficient data, as is usually the case in medical imaging.
Furthermore, the quadratic complexity of self-attention makes Transformers computationally intractable on long sequences of patches. Therefore, existing models use self-attention layers only after reducing the image resolution, thereby failing to fully exploit the global context at higher resolutions.
This work proposes a family of architectures, dubbed Factorizer, which leverages the power of low-rank matrix approximation (LRMA) to construct an end-to-end medical image segmentation model. Among LRMA methods, Nonnegative Matrix Factorization (NMF) has demonstrated a remarkable ability to compress data and automatically extract easy-to-interpret sparse factors [9,10]. Hence, we propose a linearly scalable alternative to self-attention by formulating an NMF algorithm as a differentiable layer. Moreover, a series of matricization operations is introduced, which enables NMF to effectively exploit both global and local contexts. The Factorizer block is constructed by replacing the self-attention layer of a ViT block with our NMF-based modules and is then integrated into a U-shaped architecture with skip connections.
We evaluated the effectiveness of our approach for the segmentation of brain tumors and stroke lesions in MRI data. Factorizers achieved competitive results on the BraTS [11,12] and ISLES'22 [13] datasets, outperforming state-of-the-art CNN- and Transformer-based methods. Our experiments showed that NMF components are highly meaningful, which gives Factorizers a great advantage over CNNs and Transformers in terms of interpretability. Furthermore, our ablation studies revealed a distinctive feature of Factorizers that enables us to easily speed up inference for a trained Factorizer model with no extra steps and without sacrificing much accuracy.
Contributions. The main contributions of this work are as follows:
• To the best of our knowledge, this work presents the first end-to-end deep model with matrix factorization layers for medical image segmentation.
• A differentiable NMF layer is constructed using a block coordinate descent solver to efficiently model contextual information.
• A Shifted Window (SW) Matricize operation is introduced and combined with NMF to fully exploit local contexts.
• Scalable interpretable U-shaped segmentation models based on NMF are proposed.
• The proposed models achieve state-of-the-art results on the BraTS and ISLES'22 datasets.
Notation. We denote vectors by boldface lower-case letters, e.g., $\mathbf{x}$, matrices by boldface upper-case letters, e.g., $\mathbf{X}$, and tensors by boldface calligraphic letters, e.g., $\boldsymbol{\mathcal{X}}$. Elements of a matrix (tensor) are denoted by $\mathbf{X}[i, j]$ ($\boldsymbol{\mathcal{X}}[i_1, \dots, i_N]$). The $i$th row and $j$th column of a matrix are denoted by $\mathbf{X}[i, :]$ and $\mathbf{X}[:, j]$, respectively. A sequence of $N$ vectors (a.k.a. tokens) is denoted by $(\mathbf{x}_n)_{n=1}^{N}$. We use $[\mathbf{x}_1 | \dots | \mathbf{x}_N]$ to denote a matrix $\mathbf{X}$ created by stacking the $\mathbf{x}_i$s along the columns. We denote the inner product between matrices by $\langle \mathbf{X}, \mathbf{Y} \rangle = \sum_{i,j} \mathbf{X}[i, j]\,\mathbf{Y}[i, j]$ and the $\ell_2$-norm of a matrix by $\|\mathbf{X}\| = \sqrt{\langle \mathbf{X}, \mathbf{X} \rangle}$.
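As a quick numerical check of the notation, the NumPy sketch below computes the matrix inner product and the induced $\ell_2$-norm for two small matrices; the helper names `inner` and `norm` are illustrative and not part of the paper. The norm defined this way coincides with the Frobenius norm.

```python
import numpy as np


def inner(X: np.ndarray, Y: np.ndarray) -> float:
    """Matrix inner product <X, Y> = sum_{i,j} X[i, j] * Y[i, j]."""
    return float(np.sum(X * Y))


def norm(X: np.ndarray) -> float:
    """l2-norm of a matrix, ||X|| = sqrt(<X, X>) (the Frobenius norm)."""
    return float(np.sqrt(inner(X, X)))


X = np.array([[1.0, 2.0], [3.0, 4.0]])
Y = np.array([[0.5, -1.0], [2.0, 0.0]])
print(inner(X, Y))  # 0.5 - 2.0 + 6.0 + 0.0 = 4.5
print(norm(X))      # sqrt(1 + 4 + 9 + 16) = sqrt(30)
```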

Related Work
CNN-based Segmentation Models. Convolutional neural networks (CNNs) have dominated medical image segmentation. In particular, following an encoder-decoder architecture with skip connections, U-Net [1] has achieved state-of-the-art results on various medical image datasets. The simplicity and effectiveness of the U-shaped architecture have led to the emergence of numerous U-Net variants in the field. Çiçek et al. [2] extended U-Net by replacing all 2D operations with their 3D counterparts. UNet++ [14] follows a deeply supervised encoder-decoder network consisting of sub-networks connected through a series of nested, dense skip connections. nnU-Net [15] proved effective in various medical image segmentation tasks by making only minor modifications to the standard 3D U-Net [2] and defining a recipe to automatically configure key design choices. Myronenko [16] proposed a U-Net-like architecture with ResNet blocks, a.k.a. ResSegNet, which ranked first in the Brain Tumor Segmentation Challenge (BraTS) 2018. Ashtari et al. [17] proposed a lightweight CNN for glioma segmentation, with low-rank constraints imposed on the kernel weights of the convolutional layers to reduce overfitting. Despite their success, these networks generally fail to effectively model long-range spatial dependencies, which are often necessary for better recognition of some region semantics such as tumors, since they rely on convolution operations with small kernel sizes that aggregate only local information in an image.
Visual Transformers. Transformers with attention mechanisms [7], originally introduced for language modeling, have recently shown promising results on computer vision tasks. In particular, the pioneering Vision Transformer (ViT) model [8] outperformed state-of-the-art CNNs on image recognition through large-scale pre-training and fine-tuning of a pure Transformer applied to sequences of image patches. In contrast to CNNs, ViT lacks inductive biases such as locality and therefore generally performs worse than its CNN counterparts (e.g., ResNets [18]) when trained from scratch on small- or mid-size datasets, which is usually the case in medical imaging.
Efforts have been made to mitigate this limitation. For example, Tokens-to-Token ViT (T2T-ViT) [19] introduces a hierarchical architecture to ViT by progressively combining neighboring tokens into a single token to reduce the sequence length and aggregate local context. Liu et al. [20] proposed a hierarchical Transformer, called Swin Transformer, adopting the shifted windowing scheme, which brings more efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection. As another example of a hierarchical Transformer, Pyramid Vision Transformer (PvT) [21] significantly reduces computational and memory overhead by reducing the sequence length at each stage through non-overlapping patch embedding and learning low-resolution key-value pairs via spatial-reduction attention (SRA) in each block. Convolutional vision Transformer (CvT) [22] incorporates depthwise convolutions into self-attention layers and uses strided convolution for simultaneously tokenizing and downsampling the image, exploiting the excellent capability of convolution at capturing low-level local features.
Transformer-based methods have recently been proposed for the task of 2D image segmentation. Segmentation Transformer (SETR) [23] uses a ViT encoder and a decoder with progressive upsampling (which alternates Conv layers and upsampling operations) and multi-level feature aggregation. SegFormer [24] consists of a PvT-based encoder and a lightweight Multilayer Perceptron (MLP) decoder with upsampling operations. Chen et al. [25] proposed a model for multi-organ segmentation by incorporating ViT into the bridge of a 2D convolutional U-Net architecture. Zhang et al. [26] proposed to combine a shallow CNN with a Transformer in a parallel style. Valanarasu et al. [27] proposed a Transformer model with an axial attention mechanism for the segmentation of 2D medical images. Cao et al. [28] proposed a U-shaped architecture based purely on Swin Transformer.
For 3D medical image segmentation, Xie et al. [29] proposed a model comprising a CNN backbone to extract features, a Transformer to model long-range dependencies, and a CNN decoder to construct the segmentation map. More recently, Hatamizadeh et al. [30] proposed UNETR, which utilizes ViT as the main encoder but directly connects it to the convolutional decoder via skip connections, as opposed to using a Transformer only in the bridge. nnFormer [31] uses an initial convolutional tokenizer and interleaves local and global self-attention blocks with convolutional downsamplers. Since self-attention is prohibitively expensive on long sequences, all these models apply a Transformer at a low-resolution stage after either patch embedding or a CNN backbone, which prevents them from fully exploiting the global context at higher resolutions. In contrast, our proposed approach based on NMF offers a scalable alternative to the attention mechanism, which enables the exploitation of the global context at the highest-resolution stage of a 3D network.
Matrix Factorization Models. In the context of machine learning, low-rank matrix factorization (MF) methods have proven extremely useful for representation learning, dimensionality reduction, and collaborative filtering. NMF has been used for unsupervised and semi-automated segmentation of brain tumors on multiparametric MRI data [32,33]. However, only a few works have incorporated MF into an end-to-end deep model to perform a computer vision task. Most notably, Geng et al. [34] proposed a framework, called Hamburger, where the global context is modeled as solving a low-rank matrix completion problem by suitable optimization algorithms that guide the design of layers able to capture global information. They demonstrated the effectiveness of a Hamburger model based on NMF with a multiplicative update (MU) solver for semantic segmentation. Our approach to context modeling is based on NMF with a Hierarchical Alternating Least Squares (HALS) solver and introduces a series of matricization operations which enable NMF to effectively exploit both global and local contexts. Moreover, our proposed block is inspired by the overall design of the ViT block and incorporated into a U-shaped architecture.

Matrix Factorization for Context Modeling
Here we provide the motivation behind incorporating matrix factorization into deep learning by first presenting an alternative view of the attention mechanism and then showing how it relates to the matrix factorization approach to modeling contextual information.
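Revisiting Attention Mechanism. The attention mechanism is the key component that enables Transformers to model complex dependencies between the elements of a sequence. Consider an input sequence of $C$-dimensional tokens $(\mathbf{x}_n)_{n=1}^{N}$, stacked into the rows of the matrix $\mathbf{X} = [\mathbf{x}_1 | \dots | \mathbf{x}_N]^T \in \mathbb{R}^{N \times C}$. In a self-attention layer, the input is first projected by three learnable weight matrices $\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{C \times E}$ to get three different matrices:
$$\mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V.$$
The output is then defined by
$$\mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{Softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{E}}\right)\mathbf{V}, \quad (1)$$
where Softmax is taken row-wise. Note that attention takes $O(N^2 E)$ time and needs $O(N^2)$ memory to store the attention map, scaling quadratically with the sequence length, which is prohibitively expensive for large inputs, such as high-resolution or 3D images. Taking a closer look at equation (1), we notice that attention can be viewed as a special case of nonparametric regression. To reveal this, let us consider the key-value pairs $\{(\mathbf{k}_n, \mathbf{v}_n)\}_{n=1}^{N}$ as a training set and the queries $\{\mathbf{q}_n\}_{n=1}^{N}$ as a test set, where the $\mathbf{k}_n$s and $\mathbf{q}_n$s are feature vectors, and the $\mathbf{v}_n$s are vectors of target variables. The predictions for the queries using Nadaraya-Watson kernel regression [35] are given by
$$\hat{\mathbf{v}}_n = \sum_{m=1}^{N} \frac{\phi(\mathbf{q}_n, \mathbf{k}_m)}{\sum_{m'=1}^{N} \phi(\mathbf{q}_n, \mathbf{k}_{m'})}\, \mathbf{v}_m,$$
where the kernel $\phi(\cdot, \cdot)$ is, in general, a similarity measure. Setting $\phi(\mathbf{a}, \mathbf{b}) = \exp(\mathbf{a}^T\mathbf{b}/\sqrt{E})$ simply yields the attention formula.
Figure 1: The overall architecture of Factorizer.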
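The equivalence between attention and kernel regression can be checked numerically. The NumPy sketch below (sizes, random weights, and function names are illustrative assumptions, not the paper's code) computes attention once with the matrix formula (1) and once as Nadaraya-Watson regression with the exponential kernel $\phi(\mathbf{a}, \mathbf{b}) = \exp(\mathbf{a}^T\mathbf{b}/\sqrt{E})$; the two results should agree up to floating-point error.

```python
import numpy as np


def softmax_attention(Q, K, V):
    """Equation (1): Softmax(Q K^T / sqrt(E)) V with a row-wise softmax."""
    E = Q.shape[1]
    scores = Q @ K.T / np.sqrt(E)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability only
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V


def nadaraya_watson(queries, keys, values, kernel):
    """v_hat_n = sum_m phi(q_n, k_m) v_m / sum_m' phi(q_n, k_m')."""
    out = np.zeros((queries.shape[0], values.shape[1]))
    for n, q in enumerate(queries):
        w = np.array([kernel(q, k) for k in keys])
        out[n] = w @ values / w.sum()
    return out


rng = np.random.default_rng(0)
N, C, E = 6, 8, 4
X = rng.standard_normal((N, C))
W_Q, W_K, W_V = (0.3 * rng.standard_normal((C, E)) for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V


def exp_kernel(a, b):
    return np.exp(a @ b / np.sqrt(E))


attn = softmax_attention(Q, K, V)
nw = nadaraya_watson(Q, K, V, exp_kernel)
print(np.max(np.abs(attn - nw)))  # close to zero
```

Each row of the attention output is a kernel-weighted average of the value vectors, which is why the $N \times N$ weight matrix, and hence the quadratic cost, appears.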
In this view, attention uses queries, keys, and values to model interactions between the sequence elements. Notably, one can take a matrix completion approach to regress on the queries via low-rank approximation of a block matrix of queries, keys, and values with missing entries, that is,
$$\begin{bmatrix} \mathbf{K} & \mathbf{V} \\ \mathbf{Q} & ? \end{bmatrix} \approx \begin{bmatrix} \mathbf{B} \\ \mathbf{F} \end{bmatrix} \begin{bmatrix} \mathbf{G}^T & \mathbf{H}^T \end{bmatrix},$$
where $\mathbf{B}, \mathbf{F} \in \mathbb{R}^{N \times R}$ and $\mathbf{G}, \mathbf{H} \in \mathbb{R}^{E \times R}$, and $R \le \min(N, E)$ is the rank. The reconstructed matrix $\mathbf{F}\mathbf{H}^T$ then gives an estimate of the missing block as the output. This can be viewed as a joint matrix factorization of $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$, i.e., $\mathbf{K} \approx \mathbf{B}\mathbf{G}^T$, $\mathbf{V} \approx \mathbf{B}\mathbf{H}^T$, and $\mathbf{Q} \approx \mathbf{F}\mathbf{G}^T$. In this work, we further simplify the procedure by considering only a single linear map to generate a single matrix $\mathbf{Z} = \mathbf{X}\mathbf{W}$, where $\mathbf{W} \in \mathbb{R}^{C \times E}$ is a learnable weight matrix, and by using a regular matrix factorization rather than a joint one, that is, $\mathbf{Z} \approx \mathbf{F}\mathbf{G}^T$; the output is then the reconstructed matrix $\mathbf{F}\mathbf{G}^T$. Note that this is an unsupervised scheme, unlike the joint factorization, which forms a regression model. Depending on the matrix factorization algorithm, a suitable activation function may be applied before the factorization to constrain the input. In particular, in the case of NMF, the ReLU of the matrix must first be taken to make all of its entries nonnegative. As we will see in Section 4, our results suggest that NMF can potentially be an efficient yet effective alternative to attention-based context modeling for medical image segmentation.
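The joint-factorization view of attention can be sketched as a small matrix completion routine. The paper does not prescribe a solver for this joint scheme, so the sketch below makes its own assumptions: a spectral (truncated-SVD) initialization from $\mathbf{K}$, followed by unconstrained alternating least squares (no nonnegativity) in which each factor is updated by an exact least-squares solve, applied to synthetic rank-2 matrices with a little additive noise. All names are illustrative.

```python
import numpy as np


def complete_missing_block(Q, K, V, R, iters=50):
    """Estimate the missing block of [[K, V], [Q, ?]] as F @ H.T.

    Block coordinate descent on
        ||K - B G^T||^2 + ||V - B H^T||^2 + ||Q - F G^T||^2,
    where every update is an exact least-squares solve (no nonnegativity).
    """
    # Spectral initialization from the observed K block (rank-R truncated SVD)
    U, s, Vt = np.linalg.svd(K, full_matrices=False)
    B = U[:, :R] * np.sqrt(s[:R])
    G = Vt[:R].T * np.sqrt(s[:R])
    H = np.linalg.lstsq(B, V, rcond=None)[0].T
    F = np.linalg.lstsq(G, Q.T, rcond=None)[0].T
    losses = []
    for _ in range(iters):
        # B is shared by the K and V blocks: [K V] ~ B [G; H]^T
        B = np.linalg.lstsq(np.vstack([G, H]), np.hstack([K, V]).T, rcond=None)[0].T
        # G is shared by the K and Q blocks: [K; Q] ~ [B; F] G^T
        G = np.linalg.lstsq(np.vstack([B, F]), np.vstack([K, Q]), rcond=None)[0].T
        # H only appears in the V block and F only in the Q block
        H = np.linalg.lstsq(B, V, rcond=None)[0].T
        F = np.linalg.lstsq(G, Q.T, rcond=None)[0].T
        losses.append(sum(np.linalg.norm(A - P) ** 2
                          for A, P in [(K, B @ G.T), (V, B @ H.T), (Q, F @ G.T)]))
    return F @ H.T, losses


# Synthetic rank-2 query/key/value matrices with a little noise
rng = np.random.default_rng(1)
N, E, R = 20, 8, 2
B0, F0 = rng.standard_normal((N, R)), rng.standard_normal((N, R))
G0, H0 = rng.standard_normal((E, R)), rng.standard_normal((E, R))


def noise():
    return 1e-3 * rng.standard_normal((N, E))


K, V, Q = B0 @ G0.T + noise(), B0 @ H0.T + noise(), F0 @ G0.T + noise()
estimate, losses = complete_missing_block(Q, K, V, R)
target = F0 @ H0.T
print(np.linalg.norm(estimate - target) / np.linalg.norm(target))  # small relative error
```

Because every step minimizes the objective exactly over one factor, the recorded loss never increases, which mirrors the block coordinate descent property used later for NMF.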

Overall Architecture
As shown in Figure 1, a Factorizer model follows a U-Net-style architecture consisting of encoder and decoder parts with skip connections in between at equal resolutions. Given an input image $\boldsymbol{\mathcal{X}} \in \mathbb{R}^{C_{in} \times H \times W \times D}$ with $C_{in}$ channels and resolution $(H, W, D)$, the network outputs a logit map of size $(C_{out}, H, W, D)$, where $C_{out}$ is the number of foreground classes. A single 3D convolution with a kernel size of (3, 3, 3) is used as the stem to increase the number of channels to $C = 32$. Note, however, that unlike ViT, Factorizer does not flatten the spatial dimensions at the initial stage to generate a sequence of tokens. The network has four stages, with the resolution decreasing to 1/16 in the bridge. At each stage of the encoder (decoder), the input tensor is downsampled (upsampled) by a factor of two while the number of channels is doubled (halved). Convolution (transposed convolution) with a kernel size of (2, 2, 2) and a stride of 2 is used for downsampling (upsampling). In the bridge, learnable position embeddings are added to the input right after downsampling. We use deep supervision [36] at the three highest resolutions in the decoder, applying pointwise convolutions (i.e., convolutions with a kernel size of (1, 1, 1)) to get the output and two auxiliary low-resolution logit tensors.
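The resolution and channel bookkeeping of this design can be tabulated with a few lines of Python. The sketch assumes that the 1/16 bridge resolution comes from four stride-2 downsamplings after the stem and that the channel count doubles at each of them starting from $C = 32$; `stage_shapes` is a hypothetical helper, and the (128, 128, 128) input simply matches the BraTS patch size used later.

```python
def stage_shapes(in_size=(128, 128, 128), base_channels=32, num_down=4):
    """(channels, H, W, D) at each resolution level of the U-shaped network.

    Level 0 is the output of the stem; each later level follows one
    kernel-2/stride-2 downsampling that halves the spatial size and doubles
    the channels. The last level is the bridge.
    """
    shapes = []
    size, channels = tuple(in_size), base_channels
    for _ in range(num_down + 1):
        shapes.append((channels, *size))
        # Output size of a conv with kernel 2, stride 2, no padding: (n - 2) // 2 + 1
        size = tuple((n - 2) // 2 + 1 for n in size)
        channels *= 2
    return shapes


shapes = stage_shapes()
for level, shape in enumerate(shapes):
    print(level, shape)
# Deep supervision heads sit at the three highest-resolution decoder levels
print("deep supervision resolutions:", [s[1:] for s in shapes[:3]])
```

Under these assumptions, the bridge of a (128, 128, 128) patch is an 8 × 8 × 8 grid with 512 channels, and the two auxiliary logit tensors are produced at 1/2 and 1/4 of the input resolution.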

Factorizer Block
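A Factorizer block is constructed by replacing the multi-head self-attention module in a ViT block [8] with a Wrapped NMF module (described in Section 3.4). As shown in Figure 2a, a Factorizer block comprises a Wrapped NMF module and an MLP, each of which comes after Layer Normalization and before a residual connection, that is,
$$\boldsymbol{\mathcal{Y}} = \boldsymbol{\mathcal{X}} + \mathrm{WrappedNMF}(\mathrm{LayerNorm}(\boldsymbol{\mathcal{X}})), \qquad \boldsymbol{\mathcal{X}}_{out} = \boldsymbol{\mathcal{Y}} + \mathrm{MLP}(\mathrm{LayerNorm}(\boldsymbol{\mathcal{Y}})),$$
where the MLP has two linear layers with a Gaussian Error Linear Unit (GELU) nonlinearity in between:
$$\mathrm{MLP}(\boldsymbol{\mathcal{X}}) = \mathrm{Linear}_2(\mathrm{GELU}(\mathrm{Linear}_1(\boldsymbol{\mathcal{X}}))).$$
The numbers of input and output channels are the same, but the number of inner channels is double that of the input channels.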
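The pre-norm residual structure of the block can be sketched in NumPy on a matrix of voxel features of shape (num_voxels, C). This is a sketch under stated assumptions, not the released implementation: `context_fn` is a hypothetical placeholder for the Wrapped NMF module, LayerNorm omits its learnable scale and shift, and GELU uses the common tanh approximation.

```python
import numpy as np


def layer_norm(X, eps=1e-5):
    """LayerNorm over the channel (last) axis; learnable scale/shift omitted."""
    mu = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)


def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def mlp(X, W1, b1, W2, b2):
    """Linear_1 (C -> 2C), GELU, Linear_2 (2C -> C), applied to every voxel."""
    return gelu(X @ W1 + b1) @ W2 + b2


def factorizer_block(X, context_fn, mlp_params):
    """Pre-norm residual block on voxel features X of shape (num_voxels, C).

    context_fn is a placeholder for the Wrapped NMF module.
    """
    Y = X + context_fn(layer_norm(X))
    return Y + mlp(layer_norm(Y), *mlp_params)


rng = np.random.default_rng(0)
num_voxels, C = 10, 8
X = rng.standard_normal((num_voxels, C))
params = (0.1 * rng.standard_normal((C, 2 * C)), np.zeros(2 * C),
          0.1 * rng.standard_normal((2 * C, C)), np.zeros(C))
out = factorizer_block(X, context_fn=np.zeros_like, mlp_params=params)
print(out.shape)  # (10, 8)
```

With both residual branches switched off (zero context and zero MLP weights), the block reduces to the identity, which is the usual motivation for this pre-norm residual layout.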

Wrapped NMF Module
The major component of Factorizer is the Wrapped NMF module, which relies on matricization (i.e., an operation that turns a tensor into a matrix) and NMF. As shown in Figure 2b, a Wrapped NMF module first applies a pointwise convolution to linearly project each voxel. The output is then reshaped into a batch of matrices using a Matricize operation, which is detailed in Section 3.4.1. The resulting matrices are passed through ReLU to clamp all their elements to nonnegative values and then low-rank approximated using NMF. The reconstructed matrices are reshaped back to their original size using the Dematricize operation, i.e., the inverse of Matricize. Finally, another pointwise convolution is applied, yielding the output. More formally, Wrapped NMF can be computed as follows:
$$\boldsymbol{\mathcal{X}}_1 = \mathrm{PointwiseConv}_1(\boldsymbol{\mathcal{X}}), \quad \boldsymbol{\mathcal{X}}_2 = \mathrm{Dematricize}(\mathrm{NMF}(\mathrm{ReLU}(\mathrm{Matricize}(\boldsymbol{\mathcal{X}}_1)))), \quad \mathrm{WrappedNMF}(\boldsymbol{\mathcal{X}}) = \mathrm{PointwiseConv}_2(\boldsymbol{\mathcal{X}}_2),$$
where NMF(·) is the NMF layer described in Section 3.4.2. The intermediate tensors $\boldsymbol{\mathcal{X}}_i$ and the output have the same size as the input $\boldsymbol{\mathcal{X}}$.
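The data flow of Wrapped NMF can be sketched end to end in NumPy on a (B, C, H, W, D) tensor. Several choices here are assumptions for illustration only: Global Matricize is used as the Matricize operation, the NMF layer is the rank-one alternating update described in the NMF section with an all-ones initialization and T = 5 iterations, and the function names are hypothetical.

```python
import numpy as np


def pointwise_conv(X, W, b):
    """1x1x1 convolution = per-voxel linear map. X: (B, C, H, W, D), W: (C_out, C)."""
    return np.einsum("oc,bchwd->bohwd", W, X) + b[None, :, None, None, None]


def global_matricize(X, E):
    """(B, C, H, W, D) -> (B*C/E, E, H*W*D)."""
    B, C, H, W, D = X.shape
    return X.reshape(B * (C // E), E, H * W * D)


def rank1_nmf(Z, iters=5, eps=1e-8):
    """Batched rank-1 NMF via f = Z g / ||g||^2, g = Z^T f / ||f||^2.

    Returns the reconstructions f g^T with the same shape as Z.
    """
    f = np.ones(Z.shape[:2])                 # (batch, M)
    g = np.ones((Z.shape[0], Z.shape[2]))    # (batch, N)
    for _ in range(iters):
        f = np.einsum("bmn,bn->bm", Z, g) / (np.sum(g * g, axis=1, keepdims=True) + eps)
        g = np.einsum("bmn,bm->bn", Z, f) / (np.sum(f * f, axis=1, keepdims=True) + eps)
    return f[:, :, None] * g[:, None, :]


def wrapped_nmf(X, W1, b1, W2, b2, E):
    X1 = pointwise_conv(X, W1, b1)                 # linear projection of each voxel
    Z = np.maximum(global_matricize(X1, E), 0.0)   # Matricize, then ReLU
    X2 = rank1_nmf(Z).reshape(X1.shape)            # NMF, then Dematricize
    return pointwise_conv(X2, W2, b2)              # output projection


rng = np.random.default_rng(0)
X = rng.standard_normal((1, 8, 4, 4, 4))
W1, W2 = rng.standard_normal((8, 8)), rng.standard_normal((8, 8))
b1, b2 = np.zeros(8), np.zeros(8)
print(wrapped_nmf(X, W1, b1, W2, b2, E=2).shape)  # (1, 8, 4, 4, 4)
```

Since the factors are recomputed from each input, the layer acts as an input-dependent filter, while its only learnable parameters are those of the two pointwise convolutions.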

Matricize
Before applying any matrix factorization method, an input batch of multi-dimensional images, denoted by $\boldsymbol{\mathcal{X}} \in \mathbb{R}^{B \times C \times H \times W \times D}$, must be turned into a batch of matrices, say $\boldsymbol{\mathcal{Z}} \in \mathbb{R}^{B' \times M \times N}$. This operation is called Matricize. In this work, we propose three Matricize operations: Global, Local, and Shifted Window (SW) Matricize.
Global Matricize. This operation simply flattens the spatial dimensions and divides the channels into multiple groups (analogous to the heads of Multi-Head Self-Attention). More specifically, Global Matricize reshapes a batch of $C$-channel 3D images $\boldsymbol{\mathcal{X}} \in \mathbb{R}^{B \times C \times H \times W \times D}$ into a batch of matrices denoted by a 3D tensor $\boldsymbol{\mathcal{Z}} \in \mathbb{R}^{B(C/E) \times E \times HWD}$, where $E$ is called the head dimension, i.e., the number of channels per head (or matrix). This operation is naturally suited to modeling global context and imposes no locality inductive bias.
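In NumPy terms, Global Matricize is a single reshape, because the channel axis already sits right before the spatial axes. The sketch below (function names are illustrative) shows the resulting shape and which channels of which image end up in a given matrix.

```python
import numpy as np


def global_matricize(X, E):
    """(B, C, H, W, D) -> (B*C/E, E, H*W*D): flatten space, split channels into heads."""
    B, C, H, W, D = X.shape
    if C % E:
        raise ValueError("the number of channels C must be divisible by the head dimension E")
    return X.reshape(B * (C // E), E, H * W * D)


def global_dematricize(Z, shape):
    """Inverse of global_matricize: reshape the batch of matrices back to (B, C, H, W, D)."""
    return Z.reshape(shape)


X = np.random.default_rng(0).standard_normal((2, 8, 4, 4, 4))
Z = global_matricize(X, E=2)
print(Z.shape)  # (8, 2, 64): 2 images x 4 heads, each a 2 x 64 fat matrix
# Matrix b * (C // E) + h holds channels h*E, ..., h*E + E - 1 of image b
```

Each matrix spans every voxel of the image, so a factorization of it can relate voxels that are arbitrarily far apart.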
Local Matricize. Global Matricize lacks a notion of locality, which is typically useful for images, especially in low-data regimes. Local Matricize is proposed to mitigate this shortcoming by splitting an input $\boldsymbol{\mathcal{X}} \in \mathbb{R}^{B \times C \times H \times W \times D}$ into a grid of non-overlapping patches of size $(E, P, P, P)$. These patches are flattened spatially and then concatenated along the batch dimension, yielding a batch of matrices represented by the tensor $\boldsymbol{\mathcal{Z}} \in \mathbb{R}^{BCHWD/(EP^3) \times E \times P^3}$. The entire procedure can be summarized in Einstein notation:
$$\boldsymbol{\mathcal{Z}}[(b, g_c, g_h, g_w, g_d), e, (p_h, p_w, p_d)] = \boldsymbol{\mathcal{X}}[b, (g_c, e), (g_h, p_h), (g_w, p_w), (g_d, p_d)],$$
where the indices $(g_c, g_h, g_w, g_d)$ and $(e, p_h, p_w, p_d)$ correspond to the grid and patch dimensions, respectively, and a parenthesized group of indices denotes a single flattened axis. In PyTorch and TensorFlow, such Einstein operations can be implemented simply using the rearrange function from the einops library. Once NMF is applied to the resulting batch of matrices, the inverse operation of Local Matricize, called Local Dematricize, must be applied to transform them back to the initial shape. Local Dematricize can be easily obtained by composing the inverses of the sub-operations in reverse order. In practice, Local Dematricize can also be formulated and implemented as an Einstein operation. A PyTorch-style pseudocode of the Local (De)matricize module is provided in Algorithm 3. Note that while Local Matricize seems to reshape an image in a similar way to the input patchifier of ViT, it concatenates patches along the batch dimension rather than the channel dimension. In fact, Local Matricize can be used when modeling within-patch interactions is desirable, which differs from ViT-based approaches, where interactions between embedded patches are typically modeled.
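The Einstein operation above can be written with plain NumPy reshapes and transposes, as sketched below; the function names are illustrative. With einops installed, `einops.rearrange(X, "b (gc e) (gh ph) (gw pw) (gd pd) -> (b gc gh gw gd) e (ph pw pd)", e=E, ph=P, pw=P, pd=P)` should produce the same array, but the sketch avoids that dependency.

```python
import numpy as np


def local_matricize(X, E, P):
    """'b (gc e) (gh ph) (gw pw) (gd pd) -> (b gc gh gw gd) e (ph pw pd)'."""
    B, C, H, W, D = X.shape
    gc, gh, gw, gd = C // E, H // P, W // P, D // P
    Y = X.reshape(B, gc, E, gh, P, gw, P, gd, P)
    # axes: 0=b 1=gc 2=e 3=gh 4=ph 5=gw 6=pw 7=gd 8=pd
    Y = Y.transpose(0, 1, 3, 5, 7, 2, 4, 6, 8)
    return Y.reshape(B * gc * gh * gw * gd, E, P ** 3)


def local_dematricize(Z, shape, E, P):
    """Inverse rearrangement back to (B, C, H, W, D)."""
    B, C, H, W, D = shape
    gc, gh, gw, gd = C // E, H // P, W // P, D // P
    Y = Z.reshape(B, gc, gh, gw, gd, E, P, P, P)
    # axes: 0=b 1=gc 2=gh 3=gw 4=gd 5=e 6=ph 7=pw 8=pd
    Y = Y.transpose(0, 1, 5, 2, 6, 3, 7, 4, 8)
    return Y.reshape(shape)


X = np.random.default_rng(0).standard_normal((2, 8, 8, 8, 8))
Z = local_matricize(X, E=4, P=4)
print(Z.shape)  # (32, 4, 64) = (B*C*H*W*D / (E*P^3), E, P^3)
```

The first matrix is the (4, 4, 4) corner patch of the first four channels of the first image, flattened to 4 × 64, which makes the within-patch nature of the factorization explicit.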
Shifted Window Matricize. While Local Matricize introduces locality to a model, voxels close to the boundaries of the partitioning windows are not represented effectively. Since patches are low-rank approximated independently in an NMF layer, two neighboring boundary voxels from two adjacent patches are likely to end up with excessively different feature maps at the output of the Factorizer block, which in turn degrades the prediction performance for such voxels. To mitigate this problem, we utilize the shifted window approach proposed by Liu et al. [20] and introduce a Shifted Window (SW) Matricize operation, which makes the output feature maps smoother around boundaries.
An illustration of SW Matricize is provided in Figure 2c. Here, in addition to regular patches similar to those extracted by Local Matricize, shifted window (SW) patches are also included. Let $\boldsymbol{\mathcal{X}} \in \mathbb{R}^{B \times C \times H \times W \times D}$ be the input and $(P, P, P)$ the patch size. To extract SW patches, the input is first shifted by an offset of $(P/2, P/2, P/2)$ along the spatial dimensions such that voxels shifted beyond the image boundaries are reintroduced at the first positions, yielding a tensor of the same size $(B, C, H, W, D)$. We call this operation Roll (which can be implemented using the roll function in PyTorch and TensorFlow). The resulting rolled tensor is then reshaped by Local Matricize to get SW patches. Finally, the batches of regular and SW patches are concatenated along the batch dimension. Formally, SW Matricize is computed as follows:
$$\mathrm{SWMatricize}_P(\boldsymbol{\mathcal{X}}) = \mathrm{Concat}\big(\mathrm{LocalMatricize}_P(\boldsymbol{\mathcal{X}}),\ \mathrm{LocalMatricize}_P(\mathrm{Roll}_{P/2}(\boldsymbol{\mathcal{X}}))\big),$$
where Concat concatenates along the batch dimension, $\mathrm{LocalMatricize}_P$ denotes Local Matricize with a patch size of $(P, P, P)$, and $\mathrm{Roll}_{P/2}$ denotes the roll operator with a shift of $(P/2, P/2, P/2)$. To build SW Dematricize, we reconstruct the image from the regular and SW patches independently (rolling the latter back by $-P/2$ to undo the shift) and then average the two reconstructions to obtain smoother and more accurate feature maps. Further details are provided in Appendix B, which includes a PyTorch implementation of the SW (De)matricize module presented in Algorithm 3.
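SW (De)matricize can be sketched in NumPy on top of Local (De)matricize, using `np.roll` for the Roll operator. This is an illustrative sketch rather than the paper's Algorithm 3: the function names are hypothetical, and the shift direction follows the description above (positive roll along H, W, and D).

```python
import numpy as np


def local_matricize(X, E, P):
    B, C, H, W, D = X.shape
    Y = X.reshape(B, C // E, E, H // P, P, W // P, P, D // P, P)
    Y = Y.transpose(0, 1, 3, 5, 7, 2, 4, 6, 8)  # (b, gc, gh, gw, gd, e, ph, pw, pd)
    return Y.reshape(-1, E, P ** 3)


def local_dematricize(Z, shape, E, P):
    B, C, H, W, D = shape
    Y = Z.reshape(B, C // E, H // P, W // P, D // P, E, P, P, P)
    Y = Y.transpose(0, 1, 5, 2, 6, 3, 7, 4, 8)  # back to (b, gc, e, gh, ph, gw, pw, gd, pd)
    return Y.reshape(shape)


def sw_matricize(X, E, P):
    """Regular patches plus patches of the input rolled by P/2 along H, W, D."""
    s = P // 2
    regular = local_matricize(X, E, P)
    shifted = local_matricize(np.roll(X, shift=(s, s, s), axis=(2, 3, 4)), E, P)
    return np.concatenate([regular, shifted], axis=0)


def sw_dematricize(Z, shape, E, P):
    """Reconstruct from each half, undo the roll on the shifted half, and average."""
    s = P // 2
    regular, shifted = np.split(Z, 2, axis=0)
    X_regular = local_dematricize(regular, shape, E, P)
    X_shifted = np.roll(local_dematricize(shifted, shape, E, P),
                        shift=(-s, -s, -s), axis=(2, 3, 4))
    return 0.5 * (X_regular + X_shifted)


X = np.random.default_rng(0).standard_normal((1, 4, 8, 8, 8))
Z = sw_matricize(X, E=4, P=4)
print(Z.shape)  # (16, 4, 64): 8 regular + 8 shifted-window patches
```

Any per-patch operation applied between the two calls, such as rank-one NMF, sees every voxel inside two differently aligned windows, and the final average blends the two views.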

Nonnegative Matrix Factorization
Once the input is somehow transformed into a batch of matrices, and its negative elements are clipped to zero by ReLU, it is ready to be low-rank approximated by Nonnegative Matrix Factorization (NMF). This is the main component of a Factorizer model that contributes most to modeling local or global context in an image.
NMF [37] seeks to approximate a given nonnegative matrix $\mathbf{X} \in \mathbb{R}_{\ge 0}^{M \times N}$ by the product of two nonnegative low-rank factors:
$$\mathbf{X} \approx \mathbf{F}\mathbf{G}^T,$$
where $\mathbf{F} \in \mathbb{R}_{\ge 0}^{M \times R}$ and $\mathbf{G} \in \mathbb{R}_{\ge 0}^{N \times R}$ are factor matrices, and the positive integer $R \le \min(M, N)$ is the rank. Once the factors $\mathbf{F}$ and $\mathbf{G}$ have been computed, the NMF layer in a Factorizer block outputs the reconstructed matrix $\hat{\mathbf{X}} = \mathbf{F}\mathbf{G}^T$. Note that, similar to self-attention, the NMF layer can be viewed as an adaptive filter, meaning that the computation of the factors involves the input, as opposed to convolution, whose kernel weights are independent of the input and fixed after training. Various loss functions have been used to form the objective function and measure the quality of an approximation. Depending on the loss function, constraints, and regularization, many variants of NMF have been proposed. In this work, we use a standard NMF with the squared error, which is the most widely used variant for images, and find the factors by solving the following problem:
$$\underset{\mathbf{F} \ge 0,\ \mathbf{G} \ge 0}{\text{minimize}} \quad \|\mathbf{X} - \mathbf{F}\mathbf{G}^T\|^2. \quad (10)$$
In general, problem (10) is nonconvex, and finding its global minima is NP-hard. However, numerous iterative algorithms have been proposed to find a "good local minimum". The majority of existing methods are based on a block coordinate descent (BCD) scheme (a.k.a. alternating optimization), where the objective function is iteratively minimized with respect to one factor while the other factor is kept fixed. That is, the convex subproblems are solved alternately, either exactly or approximately. This ensures that the objective function value does not increase after each update and guarantees convergence to a stationary point under some mild conditions [38]. Among numerous BCD-based algorithms for NMF, Multiplicative Update (MU) [37] is the best known, owing to its ease of implementation and scalability. MU preserves nonnegativity by multiplying the current values of a factor matrix element-wise by a nonnegative scaling factor. The pseudocode of MU is outlined in Algorithm 1.
Algorithm 1: Multiplicative update for NMF (⊙ and ⊘ denote element-wise matrix product and division, respectively).
However, the slow convergence of MU has been pointed out [10, Chapter 8], and hence more effective algorithms with faster convergence, such as Hierarchical Alternating Least Squares (HALS) [39], have been introduced. HALS updates a factor, say $\mathbf{F} = [\mathbf{f}_1 | \dots | \mathbf{f}_R]$, with inner iterations, in which the columns $\mathbf{f}_r$ are updated by solving the following subproblem:
$$\mathbf{f}_r \leftarrow \underset{\mathbf{f} \ge 0}{\arg\min}\ \|\mathbf{E}_r - \mathbf{f}\mathbf{g}_r^T\|^2, \quad (12)$$
where $\mathbf{G} = [\mathbf{g}_1 | \dots | \mathbf{g}_R]$ and $\mathbf{E}_r = \mathbf{X} - \sum_{\ell \ne r} \mathbf{f}_\ell \mathbf{g}_\ell^T$ is the residual matrix, which is, in fact, approximated by a rank-one matrix. An encouraging aspect of HALS is that each subproblem (12) can easily be shown to have a closed-form solution:
$$\mathbf{f}_r \leftarrow \max\left(\mathbf{0}, \frac{\mathbf{E}_r \mathbf{g}_r}{\|\mathbf{g}_r\|^2}\right),$$
where the maximum is taken element-wise. The update formula for the columns of $\mathbf{G}$ can be derived similarly. Note that HALS is a $2R$-block coordinate descent procedure, where at each outermost iteration, first the columns of $\mathbf{F}$ and then the columns of $\mathbf{G}$ are updated. Algorithm 2 provides pseudocode of HALS (further details on MU and HALS can be found in [10, Chapter 8]). A special case of NMF is $R = 1$, i.e., $\mathbf{X} \approx \mathbf{f}\mathbf{g}^T$, where $\mathbf{f} \in \mathbb{R}_{\ge 0}^{M}$ and $\mathbf{g} \in \mathbb{R}_{\ge 0}^{N}$, for which one can easily derive that MU and HALS both reduce to the same update rule:
$$\mathbf{f} \leftarrow \frac{\mathbf{X}\mathbf{g}}{\|\mathbf{g}\|^2}, \qquad \mathbf{g} \leftarrow \frac{\mathbf{X}^T\mathbf{f}}{\|\mathbf{f}\|^2}.$$
In this paper, all Factorizer models are trained with $R = 1$, and compression ratios (and, indirectly, reconstruction errors) are controlled by setting the head dimension (i.e., the number of rows in a matrix, as discussed in Section 3.4.1) to sufficiently small values. This means that a Matricize operation transforms an image into a batch of fat matrices (i.e., with many more columns than rows), for which rank-one approximation suffices in practice. This greatly simplifies a Factorizer model and improves interpretability while yielding better segmentation performance. However, in one of our ablation studies (see Section 4.5), we experimented with both MU and HALS for $R > 1$ to investigate the impact of the rank in the inference phase. The computational complexity of both MU and HALS is $O(MNR)$ per iteration [10, Chapter 8], so the Wrapped NMF layer scales linearly with the number of voxels and is much cheaper than self-attention with its quadratic complexity (which is computationally intractable on long sequences), and even cheaper than Performer [40], an efficient approximation of attention.
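The two solvers can be compared in a short NumPy sketch. It is a plain reading of the update rules above, not the paper's Algorithms 1 and 2: the small `eps` guarding each denominator, the random initialization, the fat 8 × 500 test matrix, and all function names are assumptions. It also checks numerically that, for $R = 1$, MU and HALS started from the same point produce the same reconstruction.

```python
import numpy as np


def mu_nmf(X, F, G, iters=5, eps=1e-12):
    """Multiplicative updates for min ||X - F G^T||^2 subject to F, G >= 0."""
    F, G = F.copy(), G.copy()
    for _ in range(iters):
        F *= (X @ G) / (F @ (G.T @ G) + eps)
        G *= (X.T @ F) / (G @ (F.T @ F) + eps)
    return F, G


def hals_nmf(X, F, G, iters=5, eps=1e-12):
    """HALS: closed-form nonnegative update of one column at a time."""
    F, G = F.copy(), G.copy()
    for _ in range(iters):
        for r in range(F.shape[1]):  # columns of F, one after another
            E_r = X - F @ G.T + np.outer(F[:, r], G[:, r])  # residual without component r
            F[:, r] = np.maximum(0.0, E_r @ G[:, r] / (G[:, r] @ G[:, r] + eps))
        for r in range(G.shape[1]):  # then columns of G
            E_r = X - F @ G.T + np.outer(F[:, r], G[:, r])
            G[:, r] = np.maximum(0.0, E_r.T @ F[:, r] / (F[:, r] @ F[:, r] + eps))
    return F, G


def sq_error(X, F, G):
    return float(np.linalg.norm(X - F @ G.T) ** 2)


rng = np.random.default_rng(0)
X = rng.random((8, 500))                       # fat nonnegative matrix (M = 8 rows)
F0, G0 = rng.random((8, 3)), rng.random((500, 3))
print(sq_error(X, F0, G0), sq_error(X, *mu_nmf(X, F0, G0)), sq_error(X, *hals_nmf(X, F0, G0)))

# With R = 1, both solvers follow the same rank-one update rule
f0, g0 = rng.random((8, 1)), rng.random((500, 1))
Fm, Gm = mu_nmf(X, f0, g0)
Fh, Gh = hals_nmf(X, f0, g0)
print(np.allclose(Fm @ Gm.T, Fh @ Gh.T))  # True
```

The column loops in `hals_nmf` are inherently sequential, whereas `mu_nmf` consists only of matrix products and element-wise operations, which illustrates the GPU-friendliness trade-off discussed next.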
It is worth mentioning that not all NMF algorithms and settings can be used in the NMF layer of a Factorizer block. The selected algorithm must have certain properties so that the Factorizer model can ultimately be trained end-to-end on GPU(s) with a gradient descent-based optimizer in an existing deep learning framework, such as PyTorch. Firstly, the algorithm should be backpropagation-friendly and amenable to automatic differentiation, that is, $\partial\hat{\mathbf{X}}/\partial\mathbf{X}$ should not only be well-defined and reasonably smooth but also computable by such a framework. A related aspect is that the gradient $\partial\hat{\mathbf{X}}/\partial\mathbf{X}$, as explained in [34], starts to vanish during backpropagation after some iterations. Therefore, the number of outer iterations $T$ in an NMF algorithm should be limited in order to have stable gradients. For MU and HALS, $T = 5$ is a reasonable choice in practice. Finally, the update rules should also be friendly to GPU parallel processing so that Factorizers can be trained in a reasonable amount of time. Taking all these factors into account, MU and HALS are appropriate choices. While HALS has better convergence properties, MU is more favorable for GPU training because the inner iterations of HALS (over the columns of a factor) are sequential and cannot be parallelized.

Datasets
We evaluate the effectiveness of our models on the Brain Tumour Segmentation (BraTS) dataset [11,12] from Medical Segmentation Decathlon [41] and Ischemic Stroke Lesion Segmentation (ISLES) 2022 dataset [13] from a MICCAI 2022 challenge.
BraTS. This dataset consists of 484 multiparametric MRI (mpMRI) scans from patients diagnosed with either low-grade glioma or high-grade glioma (glioblastoma). Each scan comes with four 3D MRI sequences, namely T2 Fluid-Attenuated Inversion Recovery (FLAIR), native T1-weighted (T1), post-Gadolinium contrast T1-weighted (T1Gd), and T2-weighted (T2). Once the images are preprocessed (i.e., rigidly co-registered to the same anatomical template, resampled to the same isotropic resolution of 1 mm³, and skull-stripped), the ground truths are manually created by experts who label each voxel as enhancing tumor (ET), edema (ED), necrotic and non-enhancing tumor (NCR/NET), or everything else. For evaluation, however, three nested subregions are used, namely enhancing tumor (ET), tumor core (TC, i.e., the union of ET and NCR/NET), and whole tumor (WT, i.e., the union of TC and ED) (see the sample ground truths in Figure 5).
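Converting a voxel-wise label map into the three nested evaluation regions takes only a few Boolean operations. The integer codes below are an assumption based on the Medical Segmentation Decathlon release of this dataset (1 = edema, 2 = non-enhancing tumor, 3 = enhancing tumor); check them against the dataset's own metadata before reuse. The function name is illustrative.

```python
import numpy as np

# Assumed label convention (Medical Segmentation Decathlon, Task01_BrainTumour):
# 0 = background, 1 = edema (ED), 2 = non-enhancing tumor (NCR/NET), 3 = enhancing tumor (ET)
EDEMA, NON_ENHANCING, ENHANCING = 1, 2, 3


def brats_regions(label):
    """Map an integer label map to the nested ET, TC, WT masks (stacked as channels)."""
    et = label == ENHANCING
    tc = et | (label == NON_ENHANCING)   # tumor core = ET + NCR/NET
    wt = tc | (label == EDEMA)           # whole tumor = TC + ED
    return np.stack([et, tc, wt]).astype(np.uint8)


label = np.array([[0, 1], [2, 3]])
print(brats_regions(label)[:, 1, 1])  # an ET voxel belongs to all three regions: [1 1 1]
```

Because the regions are nested (ET ⊆ TC ⊆ WT), they overlap, so each region is naturally predicted as its own binary output channel.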

ISLES'22. This dataset is from the ISLES'22 challenge, which aims to evaluate automated methods for acute and sub-acute stroke lesion segmentation in 3D multiparametric MRI data, namely Diffusion-Weighted Imaging (DWI), Apparent Diffusion Coefficient (ADC), and FLAIR sequences. The DWI and ADC images of a patient are aligned, while the FLAIR image in its native space has a different voxel size and must be registered to the DWI space. As DWI and ADC are the most informative modalities for stroke lesions, FLAIR is ignored in this paper to avoid the complication of FLAIR-DWI registration and to simplify the pipelines. The dataset consists of 250 cases, each of which is skull-stripped and includes an expert-level annotation of the stroke lesions.

Setup
All the models were implemented using PyTorch [42] and MONAI [43] frameworks and trained on NVIDIA P100 GPUs. We followed the same training workflow in all the experiments. In the following, we first provide the details of this workflow and baseline models, then present the evaluation protocol and the results.
Preprocessing. For each scan in a dataset, a multi-channel 3D input image was first constructed by concatenating the modalities, i.e., FLAIR, T1, T1Gd, and T2 for BraTS, and DWI and ADC for ISLES'22. The image and its ground truth were then cropped to the smallest bounding box containing all nonzero voxels of the image. The image was normalized channel-wise using a z-score to have intensities with zero mean and unit variance. Random patches of size (128, 128, 128) for BraTS and (64, 64, 64) for ISLES'22 were extracted during training. To reduce overfitting, we used data augmentation techniques, including random affine transformation, random flipping along each spatial dimension, additive Gaussian noise, random Gaussian smoothing, random intensity scaling, random intensity shifting, and random gamma transformation. Further details are provided in Appendix A.
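A minimal NumPy sketch of the three deterministic steps above (foreground cropping, channel-wise z-score normalization, and random patch extraction) is given below; the random augmentations are omitted. The helper names are hypothetical, and computing the z-score statistics over all voxels of the cropped volume (rather than over nonzero voxels only) and zero-padding volumes smaller than the patch are our assumptions.

```python
import numpy as np


def crop_foreground(image, label):
    """Crop image (C, D, H, W) and label (D, H, W) to the bounding box of nonzero voxels."""
    mask = np.any(image != 0, axis=0)
    coords = np.argwhere(mask)
    lo, hi = coords.min(axis=0), coords.max(axis=0) + 1
    box = tuple(slice(a, b) for a, b in zip(lo, hi))
    return image[(slice(None),) + box], label[box]


def zscore(image, eps=1e-8):
    """Channel-wise z-score normalization to zero mean and unit variance."""
    mean = image.mean(axis=(1, 2, 3), keepdims=True)
    std = image.std(axis=(1, 2, 3), keepdims=True)
    return (image - mean) / (std + eps)


def random_patch(image, label, size, rng):
    """Extract a random patch of spatial `size`, zero-padding if the volume is smaller."""
    pad = [(0, 0)] + [(0, max(s - d, 0)) for s, d in zip(size, image.shape[1:])]
    image = np.pad(image, pad)
    label = np.pad(label, pad[1:])
    starts = [rng.integers(0, d - s + 1) for s, d in zip(size, image.shape[1:])]
    box = tuple(slice(a, a + s) for a, s in zip(starts, size))
    return image[(slice(None),) + box], label[box]
```

For BraTS, `random_patch(zscore(img), lab, (128, 128, 128), np.random.default_rng())` would produce one training patch.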
Training. All models were trained for 100,000 steps with a batch size of 2 (one sample per GPU) using the AdamW optimizer with a base learning rate of $10^{-4}$, a weight decay of $10^{-2}$, 2,000 warmup steps, and a cosine annealing scheduler. The total loss $\mathcal{L}_{\text{total}}$ is computed from the three deep supervision outputs and the corresponding downsampled ground truths according to
$$\mathcal{L}_{\text{total}} = \sum_{i=1}^{3} \lambda_i \, \mathcal{L}(G_i, P_i),$$
where $\lambda_1 = 1$, $\lambda_2 = 0.5$, and $\lambda_3 = 0.25$; $G_i$ and $P_i$ correspond to the deep supervision at the $i$th highest resolution; and the loss function $\mathcal{L}(\cdot, \cdot)$ is a combination of the soft Dice loss [44] and the cross-entropy loss, defined as $\mathcal{L}(G, P) = \mathcal{L}_{\text{Dice}}(G, P) + \mathcal{L}_{\text{CE}}(G, P)$, where
$$\mathcal{L}_{\text{Dice}}(G, P) = 1 - \frac{2\langle G, P \rangle + \epsilon}{\lVert G \rVert_1 + \lVert P \rVert_1 + \epsilon},$$
where $G \in \{0, 1\}^{J \times N}$ and $P \in [0, 1]^{J \times N}$ represent the one-hot encoded ground truth and the predicted probability map for each voxel, respectively, with $J$ denoting the number of foreground classes and $N$ the number of voxels in the patch, and $\langle \cdot, \cdot \rangle$ denoting the Frobenius inner product. The small constant $\epsilon = 10^{-5}$ is commonly used to smooth the soft Dice loss and avoid division by zero.
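The loss and learning-rate schedule can be sketched in PyTorch as below. Two choices here are assumptions rather than details confirmed by the text: the J outputs are treated as independent sigmoid channels with binary cross-entropy (consistent with thresholding probabilities at 0.5 at inference), and the ground truth is downsampled with nearest-neighbour interpolation for deep supervision. The warmup is linear and the cosine decays to zero at step 100,000; all function names are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def dice_ce_loss(logits, target, eps=1e-5):
    """L(G, P) = L_Dice(G, P) + L_CE(G, P) for J sigmoid output channels.

    logits, target: (B, J, D, H, W); target holds {0, 1} masks as floats.
    """
    probs = torch.sigmoid(logits)
    dims = tuple(range(1, logits.dim()))  # sum over classes and voxels of each sample
    inter = (probs * target).sum(dims)
    denom = probs.sum(dims) + target.sum(dims)
    dice = 1.0 - (2.0 * inter + eps) / (denom + eps)
    ce = F.binary_cross_entropy_with_logits(logits, target)
    return dice.mean() + ce


def deep_supervision_loss(outputs, target, weights=(1.0, 0.5, 0.25)):
    """L_total = sum_i lambda_i * L(G_i, P_i); outputs[0] is the highest resolution."""
    total = 0.0
    for lam, logits in zip(weights, outputs):
        # G_i: ground truth resized (nearest neighbour) to the resolution of output i
        g = F.interpolate(target, size=logits.shape[2:], mode="nearest")
        total = total + lam * dice_ce_loss(logits, g)
    return total


def warmup_cosine(step, warmup=2000, total=100_000):
    """LR multiplier: linear warmup over `warmup` steps, then cosine decay to 0 at `total`."""
    if step < warmup:
        return (step + 1) / warmup
    progress = min((step - warmup) / max(1, total - warmup), 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```

With `optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)`, the multiplier plugs into `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)`, stepped once per training step.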

Inference.
At inference, each test image was first z-score normalized, and the prediction was then made using a sliding-window approach with 50% overlap and a window size equal to the patch size used in training (e.g., 128×128×128 for BraTS). Finally, the resulting probabilities were thresholded at 0.5 to obtain a binary segmentation map.
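The sliding-window procedure can be sketched in NumPy as follows. Overlapping window predictions are averaged uniformly (some implementations weight them with a Gaussian instead), and the last window along each axis is shifted to end flush with the volume; the function names and the `predict` callable are hypothetical stand-ins for the trained network.

```python
import itertools

import numpy as np


def window_starts(dim, win, step):
    """Start indices covering [0, dim) with windows of size `win` (last one flush with the end)."""
    if dim <= win:
        return [0]
    starts = list(range(0, dim - win + 1, step))
    if starts[-1] != dim - win:
        starts.append(dim - win)
    return starts


def sliding_window_predict(image, predict, win=(128, 128, 128), overlap=0.5, threshold=0.5):
    """Average overlapping window predictions, then threshold.

    image: (C, D, H, W); predict maps a (C, *win) patch to (J, *win) probabilities.
    Assumes every spatial dimension is >= the window size (pad beforehand otherwise).
    """
    spatial = image.shape[1:]
    steps = [max(1, int(w * (1 - overlap))) for w in win]
    grids = [window_starts(d, w, s) for d, w, s in zip(spatial, win, steps)]
    probs_sum = None
    counts = np.zeros(spatial, dtype=np.float32)
    for start in itertools.product(*grids):
        box = tuple(slice(a, a + w) for a, w in zip(start, win))
        p = predict(image[(slice(None),) + box])
        if probs_sum is None:
            probs_sum = np.zeros((p.shape[0],) + spatial, dtype=np.float32)
        probs_sum[(slice(None),) + box] += p
        counts[box] += 1  # how many windows covered each voxel
    probs = probs_sum / counts
    return (probs > threshold).astype(np.uint8)
```

With `overlap=0.5`, consecutive windows advance by half a window, so interior voxels are predicted several times and the averaging smooths seams between windows.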
Evaluation Metrics. The Dice score and the 95th-percentile Hausdorff Distance (HD95) were used to assess the performance of the models in our experiments. For each segmentation region, the Dice score measures the voxel-wise overlap between the ground truth and the prediction, defined as
$$\mathrm{Dice}(g, y) = \frac{2\sum_{n=1}^{N} g_n y_n}{\sum_{n=1}^{N} g_n + \sum_{n=1}^{N} y_n}, \tag{18}$$
where $g_n \in \{0, 1\}$ and $y_n \in \{0, 1\}$ represent the ground truth and the binary prediction for voxel $n$, respectively, and $N$ is the number of voxels. If neither the ground truth nor the prediction contains any nonzero value, i.e., the denominator of equation (18) is zero, the Dice score is defined as 1. The Hausdorff Distance (HD) evaluates the distance between the boundaries of the ground truth and the prediction. It is defined as
$$\mathrm{HD}(G, Y) = \max\left\{ \max_{g \in G} \min_{y \in Y} \lVert g - y \rVert_2,\; \max_{y \in Y} \min_{g \in G} \lVert y - g \rVert_2 \right\},$$
where $G$ and $Y$ denote the sets of all voxels on the surfaces of the ground truth and the prediction, respectively. HD95 is a variant that is more robust to outliers: it takes the 95th percentile rather than the maximum of the surface distances.
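Both metrics can be sketched with NumPy and SciPy as below. Surfaces are taken as mask voxels removed by a binary erosion, distances are brute-force pairwise Euclidean distances in millimetres (fine for small masks, slow for full volumes), and HD95 uses the 95th percentile of the pooled directed distances in both directions, which is one common convention among several. The helper names are hypothetical, and HD is left undefined here when either mask is empty.

```python
import numpy as np
from scipy.ndimage import binary_erosion
from scipy.spatial.distance import cdist


def dice(g, y):
    """Dice(g, y) = 2 sum(g*y) / (sum g + sum y); defined as 1 when both masks are empty."""
    g, y = g.astype(bool), y.astype(bool)
    denom = g.sum() + y.sum()
    if denom == 0:
        return 1.0
    return 2.0 * np.logical_and(g, y).sum() / denom


def surface(mask):
    """Coordinates of mask voxels that touch the background (mask minus its erosion)."""
    mask = mask.astype(bool)
    return np.argwhere(mask & ~binary_erosion(mask))


def hausdorff(g, y, spacing=(1.0, 1.0, 1.0), q=100.0):
    """Symmetric surface distance: q=100 gives HD, q=95 gives HD95 (pooled percentile).

    Both masks must be non-empty.
    """
    sg = surface(g) * np.asarray(spacing)
    sy = surface(y) * np.asarray(spacing)
    d = cdist(sg, sy)  # pairwise Euclidean distances between surface voxels
    directed = np.concatenate([d.min(axis=1), d.min(axis=0)])
    return float(np.percentile(directed, q))
```

With `q=100`, the percentile of the pooled directed distances equals the larger of the two directed maxima, which is exactly the HD definition above.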
Models. In all the experiments, for Factorizer models with the overall architecture illustrated in Figure 1, the number of output channels of the stem was C = 32. For Local and Swin Factorizer, we used a large window size of (8, 8, 8) on BraTS and (4, 4, 4) on ISLES'22 to aggregate local information, in contrast to typical CNNs, which mostly consist of convolutions with a small receptive field of size (3, 3, 3). For all Factorizer models, a head dimension of E = 8 on BraTS and E = 4 on ISLES'22 was used. In the NMF modules, the factor matrices were initialized from the uniform distribution U(0, 1). In training, we used a rank-one approximation (R = 1) with T = 5 outer iterations of HALS (which is equivalent to MU for R = 1), as described in Section 3.4.2.
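The local matricization with a window of (w1, w2, w3) and head dimension E can be sketched in PyTorch as a reshape and permute. The orientation of each matrix (E rows, one column per voxel in the window) is our assumption, as are the function names; the spatial size must be divisible by the window size (pad otherwise). A shifted-window variant could, in the spirit of Swin Transformer, apply `torch.roll` by half a window before matricizing and undo it afterward; we have not checked this against the paper's implementation.

```python
import torch


def local_matricize(x, head_dim, window):
    """(B, C, D, H, W) -> (B * heads * windows, E, w1*w2*w3), with E = head_dim."""
    b, c, d, h, w = x.shape
    e = head_dim
    w1, w2, w3 = window
    x = x.reshape(b, c // e, e, d // w1, w1, h // w2, w2, w // w3, w3)
    # Order: batch, head, window grid (3 dims), then E and the voxels inside a window.
    x = x.permute(0, 1, 3, 5, 7, 2, 4, 6, 8)
    return x.reshape(-1, e, w1 * w2 * w3)


def local_dematricize(m, shape, head_dim, window):
    """Inverse of local_matricize: (B * heads * windows, E, w1*w2*w3) -> (B, C, D, H, W)."""
    b, c, d, h, w = shape
    e = head_dim
    w1, w2, w3 = window
    x = m.reshape(b, c // e, d // w1, h // w2, w // w3, e, w1, w2, w3)
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4, 8)
    return x.reshape(b, c, d, h, w)
```

The batched matrices can be fed directly to a batched NMF solver, so every window and head is factorized independently and in parallel.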
We compare Factorizers against seven baseline models. Three of them, nnU-Net, Res-U-Net, and Performer, follow the same overall architecture and setup as Factorizers except for their encoder/decoder blocks, which allows us to assess the impact of the blocks themselves by eliminating the effect of architectural variability. These three baselines are detailed below:
• nnU-Net: This model is based on nnU-Net [45], i.e., a standard 3D U-Net [2] with minor modifications. Each encoder (and decoder) block is composed of two convolutions with a kernel size of (3, 3, 3). Group Normalization (with a group size of 8) is applied right after each convolution and before the LeakyReLU nonlinearity. This model does not have any positional embedding in its bridge, since CNNs already have some notion of position.
• Res-U-Net: ResNet block [18] is the cornerstone of Res-U-Net. This block is similar to that of nnU-Net, except that it has a residual connection after the last Group Normalization. This model is similar to SegResNet [16].
• Performer: The encoder and decoder blocks of this model are based on ViT (note that the input is first flattened into a sequence of voxels before being fed to a ViT block). However, attention scales quadratically with the number of voxels and is hence prohibitively expensive in our case, where the input can have up to 128³ voxels. Therefore, we replace the original attention layers of a ViT block with FAVOR+ (Fast Attention Via positive Orthogonal Random features), the mechanism used in Performer, which was recently proposed by Choromanski et al. [40] as a linearly scalable alternative to the Transformer. FAVOR+ gives an unbiased estimate of attention using only linear (as opposed to quadratic) time and space complexity. Apart from the attention layers, all components of this baseline are the same as those of Factorizer.
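For reference, the two convolutional baseline blocks can be sketched in PyTorch as follows. We read "a group size of 8" as eight channels per group (if eight groups were meant, set `num_groups=8`), use the default LeakyReLU slope of 0.01, and add a 1×1×1 projection on the residual path when channel counts differ; these details and the class names are assumptions, not the exact baseline code.

```python
import torch
from torch import nn


def conv_gn_act(in_ch, out_ch, group_size=8):
    """Conv3d (3x3x3) -> GroupNorm -> LeakyReLU."""
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.GroupNorm(num_groups=out_ch // group_size, num_channels=out_ch),
        nn.LeakyReLU(negative_slope=0.01),
    )


class PlainBlock(nn.Module):
    """nnU-Net-style block: two (conv -> GN -> LeakyReLU) units."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.body = nn.Sequential(conv_gn_act(in_ch, out_ch), conv_gn_act(out_ch, out_ch))

    def forward(self, x):
        return self.body(x)


class ResBlock(nn.Module):
    """Res-U-Net-style block: residual connection added after the last GroupNorm."""

    def __init__(self, in_ch, out_ch, group_size=8):
        super().__init__()
        self.unit1 = conv_gn_act(in_ch, out_ch, group_size)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.GroupNorm(out_ch // group_size, out_ch)
        # 1x1x1 projection when channel counts differ (assumption, as in standard ResNets)
        self.skip = (nn.Identity() if in_ch == out_ch
                     else nn.Conv3d(in_ch, out_ch, kernel_size=1, bias=False))
        self.act = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        out = self.norm2(self.conv2(self.unit1(x)))
        return self.act(out + self.skip(x))
```

Both blocks preserve the spatial size (padding of 1 with 3×3×3 kernels), so they can be swapped into the same U-shaped skeleton as the Factorizer blocks.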
We also use four state-of-the-art Transformer-based baselines, namely TransBTS [46], UNETR [30], Swin UNETR [47], and nnFormer [31], each of which has an overall architecture different from that of Figure 1. These models follow U-shaped architectures but apply Transformer blocks only to low-resolution feature maps after downsizing the input with patchifying or convolutional tokenizers, thereby avoiding the computational intractability of self-attention on long sequences. In contrast, Global Factorizer and Performer exploit the global context at all stages of their architectures, from the lowest to the highest resolution.

Brain Tumor Segmentation (BraTS)
Quantitative Evaluation. For all the experiments, we performed 5-fold cross-validation to estimate how well our models generalize to unseen data. The results on the BraTS dataset are reported in Table 1 and illustrated by box plots in Figures 3a and 3b, where pairwise Wilcoxon signed-rank tests were used to compare the performance of our best model, Swin Factorizer, with that of the baselines. Swin Factorizer is the clear overall winner in the brain tumor segmentation task. With an average Dice score of 84.21% and an average HD95 of 6.89 mm, Swin Factorizer significantly outperformed Res-U-Net, the best CNN-based and second-best overall baseline (p < 0.01 for the average Dice score and p < 0.05 for the average HD95), while requiring one-sixth of the computation. Swin Factorizer was the best-performing model on ET and the second-best on TC and WT, in particular yielding the highest Dice score of 79.33% on ET. Despite having over 95% fewer parameters and requiring over 60% fewer FLOPs, Swin Factorizer still outperformed nnFormer, the best-performing baseline, in terms of average Dice and HD95. With an average Dice score of 83.61%, Local Factorizer yielded results comparable to those of Res-U-Net, and both Global and Local Factorizer outperformed Performer, nnU-Net, TransBTS, UNETR, and Swin UNETR. As shown in Figure 4, while achieving competitive performance against the baselines, Factorizer models have far fewer parameters and are significantly cheaper to compute.
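The evaluation protocol can be sketched as a random 5-fold split of case indices plus a paired Wilcoxon signed-rank test via `scipy.stats.wilcoxon`. We assume the per-case scores of two models are aligned by case (e.g., pooled over the five validation folds); the helper names and the seed are hypothetical.

```python
import numpy as np
from scipy.stats import wilcoxon


def kfold_indices(n_cases, k=5, seed=0):
    """Shuffle case indices and split them into k nearly equal validation folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_cases), k)


def compare_models(scores_a, scores_b):
    """Two-sided Wilcoxon signed-rank test on paired per-case scores.

    Returns the mean paired difference (a - b) and the p-value.
    """
    a = np.asarray(scores_a, dtype=float)
    b = np.asarray(scores_b, dtype=float)
    result = wilcoxon(a, b)
    return float(np.mean(a - b)), float(result.pvalue)
```

The signed-rank test suits this setting because per-case Dice and HD95 values are paired across models and typically far from normally distributed.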
Performer is the Transformer-based counterpart of Global Factorizer in the sense that both follow the same overall architecture without imposing any locality inductive bias in their blocks. Despite its lower computational complexity, Global Factorizer, with an average Dice score of 83.24% and an average HD95 of 9.71 mm, marginally outperformed Performer, with an average Dice score of 83.16% and an average HD95 of 10.21 mm. Local Factorizer improved on Global Factorizer by exploiting locality; in particular, the average HD95 dropped significantly from 9.71 mm to 7.41 mm (p < 0.0001). As expected, Swin Factorizer improved all the scores of Local Factorizer, which is consistent with the fact that Swin Factorizer modifies Local Factorizer to better represent the boundary voxels in Local Matricize.
Qualitative Comparisons. Qualitative comparisons of glioma segmentation models are presented in Figure 5. Swin Factorizer demonstrates superior performance in segmenting TC and ET. This is evident in row 1 (row 2), where Swin Factorizer delineates TC more accurately than the other models, which are less successful at distinguishing ED (normal tissue) from TC. In particular, nnU-Net misclassifies a considerably larger part of ED (normal tissue) as NCR/NET (ET). Row 3 shows a successful detection of TC by Swin Factorizer, whereas the other models miss a fairly large NCR/NET region.
NMF Components Interpretability. An additional advantage of Factorizer over Transformer and CNN models is its higher interpretability, which stems from the meaningful NMF components: in practice, each component represents specific image semantics. Note that both the first and last Factorizer blocks (i.e., the high-resolution blocks of the encoder and decoder, located just after the stem layer and just before the head layer, respectively) have C = 32 channels, divided into 8-channel heads during matricization, where the NMF of each head has only a single rank-one term, i.e., R = 1. As a result, there are CR/8 = 4 components in total in each of the first and last blocks. Figures 6b and 6c show the components of the first and last NMF layers, respectively, on a high-grade glioma case for Swin Factorizer and Global Factorizer. More precisely, each component shows the factor matrix corresponding to the spatial dimensions after dematricization. Interestingly, the components appear highly meaningful and interpretable in the sense that each component differentiates one region from another. For instance, as observed in Figure 6b, the first (row 1) and second (row 2) components roughly discriminate WT, while the third (row 3) and fourth (row 4) components capture TC. For Swin Factorizer, NCR is distinguished more clearly in the third and fourth components, whereas for Global Factorizer, NCR is more recognizable in the second component. In deeper layers, the components become even more meaningful, such that in the last layer each component clearly detects specific regions. For example, as seen in Figure 6c, the first (row 1) and second (row 2) components very clearly discriminate WT, while the third (row 3) and fourth (row 4) components capture TC and NCR.
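How a component map is obtained can be sketched in NumPy: split the C channels into heads of E = 8, flatten each head into an E × (D·H·W) matrix (a global matricization), factorize it with rank one, and reshape the spatial factor back to (D, H, W). This is an illustration applied to an arbitrary feature tensor, not the trained network's NMF layer; the names, the ReLU before factorization, and the fixed seed are assumptions.

```python
import numpy as np


def rank1_nmf(x, num_iters=5, seed=0):
    """Rank-one NMF X ~ u v^T of a nonnegative (M, N) matrix via alternating updates."""
    rng = np.random.default_rng(seed)
    u = rng.random(x.shape[0])
    v = rng.random(x.shape[1])
    for _ in range(num_iters):
        # For R = 1 and positive u, v, the MU and HALS updates both reduce to these forms.
        u = (x @ v) / (v @ v + 1e-12)
        v = (x.T @ u) / (u @ u + 1e-12)
    return u, v


def component_maps(features, head_dim=8, num_iters=5):
    """Factorize each (head_dim x D*H*W) head matrix with rank one and reshape the
    spatial factor back to (D, H, W); returns one map per component."""
    c, d, h, w = features.shape
    x = np.maximum(features, 0).reshape(c // head_dim, head_dim, d * h * w)
    maps = [rank1_nmf(head, num_iters)[1].reshape(d, h, w) for head in x]
    return np.stack(maps)  # shape (C // head_dim, D, H, W)
```

With C = 32 and E = 8, this yields the four maps per layer counted in the text; each map is nonnegative and can be displayed directly as a heat map over the volume.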
Another interesting observation is that Global Factorizer yields more discriminative, higher-level components in the first layer. This can be attributed to the fact that Global Factorizer models long-range dependencies in all of its blocks, including the first one, whereas the receptive field of Swin Factorizer in the first stage of the encoder is small and increases progressively through the downsampling layers. Therefore, Swin Factorizer extracts lower-level features in shallower layers and higher-level ones in deeper layers, similarly to CNNs. Notice also that the footprints of the sliding windows, which appear as a grid pattern, are evident in all the components of Swin Factorizer.

Ischemic Stroke Lesion Segmentation (ISLES)
Quantitative Evaluation. Similarly to Section 4.3.1, 5-fold cross-validation was performed, and pairwise Wilcoxon signed-rank tests were used for comparison. The results on the ISLES'22 dataset are reported in Table 2 and illustrated by box plots in Figures 3c and 3d. Swin Factorizer, with a Dice score of 76.49% and an HD95 of 11.96 mm, demonstrated the best performance, significantly outperforming nnU-Net and all the Transformer-based models (p < 0.001 for the Dice score and p < 0.01 for HD95). Moreover, Swin Factorizer performed better than Res-U-Net, which has over three times more parameters and needs over four times more FLOPs.
Despite having over 95% fewer parameters and requiring over 50% fewer FLOPs, Local Factorizer yielded a Dice score of 74.28%, higher than that of nnFormer, the best-performing Transformer-based model. Local Factorizer also achieved a smaller HD95 than Res-U-Net, the best baseline overall, while requiring 80% fewer FLOPs. Finally, Global Factorizer outperformed its Transformer-based counterpart, Performer, by a large margin while retaining the advantage of lower computational cost. As a side note, our Factorizer-based model trained on all three modalities (i.e., DWI, ADC, and registered FLAIR) and submitted to the ISLES'22 MICCAI challenge ranked among the top three on the final leaderboard, which further supports the potential of Factorizer as an effective alternative for 3D medical image segmentation (please see https://isles22.grandchallenge.org/isles22/).
Qualitative Comparisons. Qualitative comparisons of stroke lesion segmentation models are presented in Figure 7. Compared to UNETR and nnU-Net, our Swin Factorizer and Global Factorizer models substantially reduce false positives, as observed in row 1. Both UNETR and nnU-Net produce large false-positive regions (circled in green), but nnU-Net suffers from fewer false positives than UNETR, as is evident in row 2.
Although nnFormer appears to yield relatively small false-positive regions compared to UNETR and nnU-Net, Swin Factorizer not only produces slightly fewer false positives but also yields fewer false negatives. This is exemplified by row 3, where Swin Factorizer successfully captures both stroke lesions, whereas nnFormer fails to detect the lesion circled in orange. Overall, Swin Factorizer displays very competitive segmentation results, superior to those of UNETR, nnFormer, and nnU-Net in most cases. These results support the potential of Factorizer as an alternative to state-of-the-art models such as nnFormer and nnU-Net.

Ablation Studies: Training
We conducted ablation studies on Factorizer to further investigate the effects of Factorizer subblocks and positional embedding. In this section, models with an ablated layer were trained from scratch on the BraTS dataset.
Factorizer Subblock. Ablation results for the subblocks on brain tumor segmentation are reported in Table 3. Removing the NMF subblocks from a Factorizer model substantially degraded performance, whereas removing the MLP subblocks caused less deterioration. Local Factorizer without MLP subblocks yielded an average Dice of 82.67%, still outperforming nnU-Net, and Swin Factorizer without MLP subblocks yielded an average Dice of 83.57%, which is significantly higher than that of nnU-Net and comparable to that of Res-U-Net. These results indicate that the NMF subblock is the main contributor to the effectiveness of Factorizer models.
Positional Embedding. Table 4 shows the results of the ablation study on positional embedding. In all cases, adding positional embedding to the bridge of a network improved performance. In particular, the average Dice score of Global Factorizer increased significantly from 83.09% to 83.24% (p = 0.031), and the average HD95 fell from 10.12 mm to 9.71 mm, although the latter improvement is not statistically significant. Local and Swin Factorizers without positional embedding yielded average Dice scores of 83.40% and 83.99%, respectively, slightly underperforming their counterparts with positional embedding. Like Transformers, Factorizers have no inherent notion of voxel position and therefore typically benefit from a positional embedding mechanism, which is consistent with our results.

Ablation Studies: Inference
Since the output of an NMF layer is a low-rank approximation of its input, it is reasonable to perform ablation studies in the inference phase, for instance, by short-circuiting (bypassing) an NMF layer. For all the experiments in this section, we ablated some NMF layers or changed their settings in the inference phase after training the model on the BraTS dataset.
NMF Layer. We investigated the impact of short-circuiting some NMF layers of the pre-trained Swin Factorizer model in the inference phase. Figure 8a shows the average Dice score on BraTS when we kept the first k NMF layers and removed (i.e., short-circuited) the rest, for varying k. As expected, the more layers we kept, the higher the Dice score. We observed that most of the performance is achieved by the encoder, which includes the first five blocks. Figure 8b shows the impact of each individual layer, where we kept all the layers except one at a time. Interestingly, the NMF layer of the bridge block (layer 5) contributes the most to the performance. In fact, if any NMF layer other than that of the bridge is ablated, the average Dice still stays above 80%. In particular, if an NMF layer in the decoder (layers 6 to 9) is ablated, the performance is still better than that of nnU-Net. Note that removing an NMF layer, especially one at a higher-resolution stage of the network, can significantly reduce the computational complexity and speed up inference.
NMF Solver Iterations. We investigated the effect of changing the number of outer iterations T in HALS. Recall that all the Factorizer models were originally trained using HALS with T = 5. Figure 8c shows the results of experimenting with T ∈ {1, ..., 20} in the inference phase for the Swin Factorizer model pre-trained with T = 5. We found that for T > 1, the performance is very close to that of the original model (T = 5). Interestingly, T ∈ {2, 3, 4} yielded even higher average Dice scores than the original model, although the improvement was not statistically significant. We attribute this to a possible regularization effect of reducing T. Note that reducing T decreases the computational cost of the NMF layers proportionally while preserving accuracy. In fact, by training a Factorizer, we also obtain lightweight yet accurate versions of it for free, without retraining any model from scratch: all we need to do is reduce T or ablate some costly NMF layers in the inference phase. When it comes to model speed-up, this gives Factorizers a considerable advantage over CNNs and Transformers, which require much more complex mechanisms to obtain faster versions.
Rank. We investigated the impact of changing the rank of the pre-trained Swin Factorizer model in the inference phase. Recall that Swin Factorizers were trained with R = 1, for which HALS and MU lead to the same update rule. Therefore, for R > 1 at inference, we experimented with both HALS and MU to compare them. In Figure 8d, the average Dice score on BraTS is plotted as a function of rank. As observed, the further the rank deviated from R = 1 (the value used in training), the more the Dice score dropped. Compared with HALS, MU showed a less dramatic reduction in performance as the rank increased. This may be because, within the same number of outer iterations, HALS typically achieves a larger decrease in the NMF objective (given in equation (10)), causing the HALS-based model to deviate further from the original model than the MU-based one.
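The difference between the two solvers for R > 1 can be sketched in NumPy: MU rescales all entries of a factor at once, whereas HALS updates the columns one after another, and that sequential inner loop is what limits GPU parallelism. The function names and the small `eps` are hypothetical; for R = 1 the two updates coincide, as stated above.

```python
import numpy as np


def nmf_objective(x, u, v):
    """Squared Frobenius error ||X - U V^T||_F^2."""
    return float(np.linalg.norm(x - u @ v.T) ** 2)


def mu_step(x, u, v, eps=1e-12):
    """One outer MU iteration: all entries of U, then all entries of V, in parallel."""
    u = u * (x @ v) / (u @ (v.T @ v) + eps)
    v = v * (x.T @ u) / (v @ (u.T @ u) + eps)
    return u, v


def hals_step(x, u, v, eps=1e-12):
    """One outer HALS iteration: columns of U (then V) updated one at a time.

    Each loop over r is a sequential inner iteration.
    """
    u, v = u.copy(), v.copy()
    xv, vtv = x @ v, v.T @ v
    for r in range(u.shape[1]):
        u[:, r] = np.maximum(0.0, u[:, r] + (xv[:, r] - u @ vtv[:, r]) / (vtv[r, r] + eps))
    xtu, utu = x.T @ u, u.T @ u
    for r in range(v.shape[1]):
        v[:, r] = np.maximum(0.0, v[:, r] + (xtu[:, r] - v @ utu[:, r]) / (utu[r, r] + eps))
    return u, v
```

Both steps never increase `nmf_objective`; running a few of each from the same random start on a small matrix is a quick way to compare how far each solver moves away from its initialization.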

Conclusion and Future Work
Vision Transformers, particularly those with hierarchical architectures, have recently achieved results comparable to those of state-of-the-art CNNs on various computer vision tasks. Nevertheless, their lack of locality inductive bias makes them underperform their CNN counterparts in low-data regimes, which are common in medical image segmentation. Moreover, the quadratic complexity of attention forces existing Transformers to apply self-attention layers only after reducing the image resolution, so they fail to fully capture long-range contexts present at higher resolutions. Hence, this paper introduces a family of models, called Factorizer, which leverages the power of low-rank approximation to develop a scalable, interpretable approach to context modeling by formulating NMF as a differentiable layer integrated into an end-to-end U-shaped architecture. Built upon NMF and the shifted window idea, Swin Factorizer competed favorably with CNN and Transformer baselines in terms of accuracy and scalability. Swin Factorizer yielded state-of-the-art results on BraTS for brain tumor segmentation, with Dice scores of 79.33%, 83.14%, and 90.16% for enhancing tumor, tumor core, and whole tumor, respectively, and on ISLES'22 for stroke lesion segmentation, with a Dice score of 76.49%. Our experiments indicated that NMF components are highly meaningful, which gives Factorizers a considerable interpretability advantage over CNNs and Transformers. Moreover, our ablation studies revealed a distinctive feature of Factorizers that allows inference of a pre-trained Factorizer to be sped up with no extra steps and without sacrificing much accuracy. Matrix factorization models are very flexible and versatile. In this work, we used an ordinary NMF model together with some matricization techniques to model local or global contexts. One possible extension of this work is to customize the NMF objective to exploit local and global contexts simultaneously.
Moreover, it would be useful to explore ideas for automatically selecting the rank hyperparameter, for example, by taking a greedy approach to NMF, a.k.a. Nonnegative Matrix Underapproximation, where components are constructed and added one by one (sequentially) until some criterion is met. This could especially benefit Local Factorizer, since different regions may have different optimal ranks depending on their context. We will also explore the effectiveness of other NMF variants, such as Semi-NMF and Convex-NMF, which work on mixed-sign data matrices and remove the need for the ReLU activation function before factorization. Finally, while this paper focuses on the segmentation of 3D medical images, Factorizer may also serve as an effective approach for efficiently processing high-resolution 2D medical or natural images, which can be investigated in future work.