1 Introduction

Semantic segmentation is a pivotal computer vision task that aims to assign a label to every pixel in an image. Widely adopted state-of-the-art methods like Fully Convolutional Networks (FCN) (Long et al., 2015) utilize a deep convolutional neural network (ConvNet) as the encoder and incorporate a segmentation decoder for dense prediction. Prior works (Wang et al., 2020; Yuan et al., 2020; Chen et al., 2018) have aimed to enhance performance by augmenting contextual information or incorporating multi-scale information, leveraging the inherent multi-scale and hierarchical attributes of ConvNet architectures.

The advent of the Vision Transformer (ViT) (Dosovitskiy et al., 2021) has offered a paradigm shift, serving as a robust backbone for numerous computer vision tasks. ViT, distinct from ConvNet base models, retains a plain, non-hierarchical architecture and preserves the resolution of the feature maps. To conveniently leverage existing segmentation decoders for dense prediction, such as U-Net (Ronneberger et al., 2015) or DeepLab (Chen et al., 2018), recent Transformer-based approaches, including Swin Transformer (Liu et al., 2021) and PVT (Wang et al., 2021), have developed hierarchical ViTs to extract hierarchical feature representations.

Fig. 1

Comparison with previous methods in terms of performance and efficiency on the ADE20K dataset. Each bubble in the graph represents a variant segmentation method built on ViT Base or ViT Large, with the bubble size corresponding to its FLOPs. SegViT-BEiT v2 Large achieves state-of-the-art performance with a 58.0% mIoU on the ADE20K validation set. Additionally, our efficient, optimized version, SegViT-Shrunk-BEiT v2 Large, saves half of the GFLOPs compared to UPerNet, significantly reducing computational overhead while maintaining a competitive performance of 55.7%

Fig. 2

The overall concept of our Attention-to-Mask decoder. ATM learns a similarity map for each category by computing the cross-attention between the class tokens and the spatial feature map (Left). \(\texttt{Sigmoid}\) is applied to produce category-specific masks, highlighting areas with high similarity to the corresponding class (Middle). ATM enhances the semantic representations by encouraging the features to be similar to the target class token and dissimilar to other tokens

However, modifying the original ViT structure requires training the network from scratch rather than using off-the-shelf plain ViT checkpoints, owing to the discrepancy between the hierarchical and plain architectures, such as spatial down-sampling (Xu et al., 2022). Altering the plain ViT architecture also compromises the use of rich representations from self-supervised and vision-language pre-training methods like CLIP (Radford et al., 2021), BEiT (Bao et al., 2022), BEiT-v2 (Peng et al., 2022), MVP (Wei et al., 2022), and COTS (Lu et al., 2022).

Hence, there is a clear advantage to developing effective decoders for the original ViT structures in order to leverage those powerful representations. Previous works, such as UPerNet (Xiao et al., 2018) and DPT (Ranftl et al., 2021), have primarily focused on hierarchical feature maps and neglected the distinctive characteristics of the plain Vision Transformer. Consequently, these methods introduce computation-intensive operations while offering limited performance gains, as shown in Fig. 1.

A recent trend in several works, such as SETR (Zheng et al., 2021) or Segmenter (Strudel et al., 2021), aims to develop decoders specifically tailored for the plain ViT architecture. However, these designs are often an extension of per-pixel classification techniques derived from traditional convolution-based decoders. For example, SETR’s decoder (Zheng et al., 2021) uses a sequence of convolutions and bilinear up-sampling to gradually upsample the feature maps extracted by the ViT. It then applies a naive MLP to the extracted features to perform pixel-wise classification, which ignores the neighboring context surrounding each pixel. Current pixel-wise classification decoder designs overlook the importance of contextual learning when assigning labels to each pixel.

Another prevalent issue in deep networks, including Transformers, is ‘catastrophic forgetting’ (French, 1999; Kirkpatrick et al., 2017), where the model’s performance on previously learned tasks deteriorates as it learns new ones (Shao & Feng, 2022; Wang et al., 2022, 2022; Phan et al., 2022). This limitation poses significant challenges for the application of deep segmentation models in dynamic real-world environments. Recently, the rapid development of foundation models pre-trained on large-scale data has sparked interest among researchers in studying their transferability across various downstream tasks (Ostapenko et al., 2022). These models extract powerful and generalized representations, which has led to a growing interest in exploring their extensibility to new classes and tasks while retaining previously learned knowledge (Ramasesh et al., 2022; Wu et al., 2022).

Inspired by these challenges, this paper aims to develop a plain Vision Transformer-based model for effective semantic segmentation without resorting to a hierarchical backbone. As self-supervised and multi-modality pre-training continue to evolve, we anticipate that the plain vision transformer will learn increasingly strong visual representations. Consequently, decoders for dense tasks are expected to adapt more flexibly and efficiently to these representations.

In light of these research gaps, we propose SegViTv2, a novel, efficient segmentation network that features a plain Vision Transformer and exhibits robustness against forgetting. We introduce a novel Attention-to-Mask (ATM) module that operates as a lightweight component of the SegViT decoder. Leveraging the non-linearity of cross-attention learning, the proposed ATM employs learnable class tokens as queries to pinpoint spatial locations that exhibit high compatibility with each class. We expect regions belonging to a particular class to yield high similarity values with the corresponding class token.

As depicted in Fig. 2, the ATM generates a meaningful similarity map that accentuates regions with a strong affinity towards the ‘Table’ and ‘Chair’ categories. By simply implementing a Sigmoid operation, we can transform these similarity maps into mask-level predictions. The computation of the mask scales linearly with the number of pixels, a negligible cost that can be integrated into any backbone to bolster segmentation accuracy. Building upon this efficient ATM module, we present a novel semantic segmentation paradigm that utilizes the cost-effective structure of plain ViT, referred to as SegViT. Within this paradigm, multiple ATM modules are deployed at various layers to extract segmentation masks at different scales. The final prediction is the summation of the outputs derived from these layers.

To alleviate the computational burden of plain Vision Transformers (ViTs), we introduce the Shrunk and Shrunk++ structures, which incorporate query-based downsampling (QD) and query-based upsampling (QU). The proposed QD employs a 2x2 nearest neighbor downsampling technique to obtain a sparser token mesh, reducing the number of tokens involved in attention computations. In Shrunk++, we extend QD to edge-aware query-based downsampling (EQD). EQD selectively preserves tokens situated at object edges, as they carry more discriminative information. QU then recovers the tokens discarded within the homogeneous body of objects, reconstructing the high-resolution features crucial for accurate dense prediction. Integrating the Shrunk++ structure with the ATM module as the decoder, our SegViTv2 achieves computational reductions of up to 50% while maintaining competitive performance.

We further adapt our SegViTv2 framework for continual learning. Leveraging the robust, generalized representation of the foundational model, this paper investigates its adaptability to new classes and tasks, ensuring retention of prior knowledge. Recent techniques in continual semantic segmentation (CSS) aim to replay old data (Maracani et al., 2021; Cha et al., 2021) or distill knowledge from the previous model to mitigate model divergence (Cermelli et al., 2020; Phan et al., 2022; Zhang et al., 2022). These methods fine-tune parameters related to old tasks, which can disrupt the previously learned solutions and result in forgetting. In contrast, our proposed SegViT supports learning new classes without interfering with previously acquired knowledge. We strive to establish a forget-free SegViT framework, achieved by incorporating a new ATM module dedicated to new tasks while freezing all old parameters. Consequently, the proposed SegViT architecture has the potential to eliminate the issue of forgetting.

Our key contributions can be summarized as follows:

  • We introduce the Attention-to-Mask (ATM) decoder module, a potent and efficient tool for semantic segmentation. For the first time, we exploit spatial information present in attention maps to generate mask predictions for each category, proposing a new paradigm for semantic segmentation.

  • We present the Shrunk++ structure, applicable to any plain ViT backbone, which alleviates the intrinsically high computational expense of the non-hierarchical ViT while maintaining competitive performance, as illustrated in Fig. 1. We are the first to capitalize on edge information to prune and recover tokens for efficient computation. Our Shrunk++ version of SegViTv2, tested on the ADE20K dataset, achieves a mIoU of 55.7% with a computational cost of 308.8 GFLOPs, a reduction of approximately 50% compared to the original SegViT (637.9 GFLOPs).

  • We propose a new SegViT architecture capable of continual learning with nearly zero forgetting. To our knowledge, ours is the first work that completely freezes all parameters dedicated to old classes, thereby nearly eliminating the issue of catastrophic forgetting.

2 Related Work

2.1 Semantic Segmentation

Semantic segmentation aims to partition an image into regions with meaningful categories. Fully Convolutional Networks (FCNs) used to be the dominant approach to this task. To enlarge the receptive field, several approaches (Zhao et al., 2017; Chen et al., 2018) propose dilated convolutions or apply spatial pyramid pooling to capture contextual information at multiple scales. Most semantic segmentation methods aim to classify each pixel directly using a classification loss. This paradigm naturally partitions images into different classes.

Various methods have achieved significant advancements by integrating Transformers into the semantic segmentation task. Early works (Liu et al., 2021; Dong et al., 2022) directly adapt the transformer encoder, designed for classification, to semantic segmentation by fine-tuning it together with segmentation decoders such as UPerNet (Xiao et al., 2018). Recent approaches (Xie et al., 2021; Strudel et al., 2021; Cheng et al., 2021) have focused on designing the overall segmentation framework to achieve better adaptation. For instance, SETR (Zheng et al., 2021) views semantic segmentation as a sequence-to-sequence task and proposes a pure Transformer encoder combined with a standard convolution-based decoder. SegFormer (Xie et al., 2021) employs a hierarchical encoder design to extract features from fine-to-coarse levels and a lightweight decoder design for efficient prediction. However, the SegFormer decoder adopts the pyramid structure by fusing multi-scale features, which is specialized for hierarchical ViTs such as Swin Transformer (Liu et al., 2021). The above-mentioned methods aim to design either a naive convolution-based decoder or a pyramid-structure decoder for hierarchical base models. Nonetheless, designing an effective decoder specialized for plain ViTs remains an open research question.

Recently, several segmentation methods propose a universal framework that unifies multiple tasks, including instance segmentation, semantic segmentation, and object detection. For example, Mask DINO (Li et al., 2022) extends DINO with a mask prediction branch, achieving promising results in the instance, panoptic, and semantic segmentation tasks. Mask2Former (Cheng et al., 2022) enhances MaskFormer (Cheng et al., 2021) by introducing deformable multi-scale attention in the decoder and a masked cross-attention mechanism. OneFormer (Jain et al., 2022) represents a universal image segmentation framework with a multi-task train-once design, outperforming specialized models in various tasks.

Recent methods (Cheng et al., 2021; Strudel et al., 2021; Zhang et al., 2021) propose decoupling the per-pixel classification into image partitioning and region classification. For image partitioning, they use learnable tokens as mask embeddings and associate them with the extracted feature map to generate object masks. For region classification, the learnable tokens are fed to a classifier to predict the class corresponding to each mask. This paradigm enables global segmentation and alleviates the burden on the decoder to perform per-pixel classification, resulting in state-of-the-art performance (Cheng et al., 2021). While previous works use generic tokens for mask generation, this work explicitly utilizes class-specific tokens to enhance the semantics of mask embeddings, thereby improving segmentation accuracy.

2.2 Mask-Oriented Segmentation

Compared to previous mask-oriented segmentation techniques such as MaskFormer (Cheng et al., 2021) and Mask2Former (Cheng et al., 2022), our method presents several novel conceptual differences and advantages. Specifically, our approach is tailored to semantic segmentation: each class is assigned a fixed token, and the corresponding mask is generated directly. In contrast, MaskFormer relies on Hungarian matching, with each learnable query corresponding to spatial information instead of category information. Our Attention-to-Mask (ATM) approach eliminates the need for positional embeddings, as we utilize the attention map between the class tokens and the feature map. Our overarching goal is to adapt plain Vision Transformers for dense prediction, as recent studies have demonstrated that self-supervised learning (He et al., 2022; Chen et al., 2022; Touvron et al., 2022; Peng et al., 2022) and multimodal learning (Radford et al., 2021) are typically built upon plain ViT structures. Our approach enhances the representation ability of class tokens by applying transformer blocks.

Previous CNN-based decoders, such as OCRNet (Yuan et al., 2019) and K-Net (Zhang et al., 2021), have demonstrated the effectiveness of the attention mechanism in modeling contextual information. For example, K-Net utilizes semantic kernels (one kernel for each class) and performs convolution operations to generate the semantic mask. In contrast, our proposed ATM module integrates cross-attention mechanisms, allowing for more effective contextual learning. While OCRNet (Yuan et al., 2019) applies cross-attention from the class token to the feature map to enhance feature representations, it still employs a standard linear predictor in the decoder to produce the segmentation map. On the other hand, our proposed ATM module is specifically designed for generating segmentation outputs, paving the way for future research on effective decoders for plain ViT. Additionally, existing convolution-based attention networks such as OCRNet (Yuan et al., 2019), K-Net (Zhang et al., 2021), and DANet (Fu et al., 2019) adopt the traditional per-pixel classification framework for segmentation generation. In contrast, our proposed SegViT decouples segmentation into mask prediction and classification, which proves advantageous for establishing connections between the class proxy and language representations (Zhou et al., 2022), as well as facilitating continual learning.

2.3 Transformers for Vision

In the realm of image classification tasks, attention-based transformer models have emerged as powerful alternatives to standard convolution-based networks. The original ViT (Dosovitskiy et al., 2021) represents a plain, non-hierarchical architecture. However, there have been several advancements in the field of hierarchical transformers, such as PVT (Wang et al., 2021), Swin Transformer (Liu et al., 2021), Twins (Chu et al., 2021), SegFormer (Xie et al., 2021), and P2T (Wu et al., 2022). These hierarchical transformer models inherit certain design elements from convolution-based networks, including hierarchical structures, pooling, and downsampling with convolutions. Consequently, they can be seamlessly employed as direct replacements for convolutional-based networks and can be coupled with existing decoder heads for tasks such as semantic segmentation.

2.4 Self-Supervised Vision Transformers

Self-supervised learning has emerged as a powerful technique for pretraining visual models, eliminating the need for labeled data. One notable self-supervised method is MAE (He et al., 2022) (Masked Autoencoder), which trains a vision transformer to reconstruct masked regions of input images. This approach results in a high generalization capacity. Another significant method is CLIP (Radford et al., 2021) (Contrastive Language-Image Pre-Training), which involves joint training of a vision transformer and a language model on a large corpus of text and images, leading to the creation of a comprehensive knowledge store. CAE (Chen et al., 2022) aims to learn image representations that are invariant to context changes and effectively capture underlying semantic content. Furthermore, iBot (Zhou et al., 2022) performs masked visual learning using an online tokenizer and self-distillation mechanism, facilitating semantic representation learning. In our approach, we leverage attention to masks to optimize the extraction of dense hidden representations, thereby enhancing the segmentation capability of our model.

Fig. 3

The overall SegViT structure with the ATM module. The Attention-to-Mask (ATM) module inherits the typical transformer decoder structure. It takes randomly initialized class embeddings as queries and uses the feature maps from the ViT backbone to generate keys and values. The outputs of the ATM module are used as the input queries for the next layer. The ATM modules are applied sequentially in a cascade manner, with inputs from different layers of the backbone serving as keys and values. A linear transform is then applied to the output of the ATM module to produce the class predictions for each token. The mask for the corresponding class is derived from the similarities between queries and keys in the ATM module. We have removed the self-attention mechanism in the ATM decoder layers to further improve the efficiency while maintaining the performance

2.5 Plain-Backbone Decoders

For dense prediction tasks such as semantic segmentation, the high-resolution feature maps produced by the backbone are vital for preserving spatial details. In typical hierarchical transformer models, techniques such as FPN (Lin et al., 2017) or a dilated backbone are employed to generate high-resolution feature maps by merging features from different levels. However, for a plain, non-hierarchical transformer backbone, the resolution remains the same across all layers. SETR (Zheng et al., 2021) proposed a straightforward approach that treats the transformer outputs of the base model from a sequence-to-sequence perspective. Segmenter (Strudel et al., 2021) combines class embeddings and transformer patch embeddings and applies several self-attention layers on the combined tokens to learn discriminative embeddings. In their approach, the class tokens are used as input to the ViT backbone, resulting in increased computational complexity. In contrast, our SegViT introduces the class tokens as input to the Attention-to-Mask (ATM) module, thereby reducing computational costs while still benefiting from the integration of class tokens.

2.6 Continual Learning

Continual learning (CL) aims to address the issue of forgetting, ensuring consistent performance on previously learned classes while adapting to new ones (Chen & Liu, 2016). Most CL methods propose regularization techniques for convolution-based networks (Li & Hoiem, 2018; Douillard et al., 2020; Kang et al., 2022; Peng et al., 2021) or expand the network architectures to accommodate new tasks (Yan et al., 2021), thereby avoiding the need to store and replay old data. In recent years, efforts have also emerged to prevent forgetting in Transformer models. Dytox (Douillard et al., 2022) dynamically learns new task tokens, which are then utilized to make the learned embeddings more relevant to the specific task. Lifelong ViT (Wang et al., 2022) and contrastive ViT (Wang et al., 2022) introduce cross-attention mechanisms between tasks through external key vectors, and they slow down the changes to these keys to mitigate forgetting. Despite the use of complex mechanisms to prevent forgetting, these methods still require fine-tuning of the network for new classes, which can result in interference with previously learned knowledge.

In the field of semantic segmentation, recent research has been devoted to addressing the forgetting issue in continual learning. However, in addition to forgetting, continual semantic segmentation (CSS) also encounters the problem of "background shift." This refers to the situation where foreground object classes from previous tasks are mistakenly classified as background in the current task (Cermelli et al., 2020). REMINDER (Phan et al., 2022) tackles forgetting in CSS by utilizing class similarity to identify the classes that are more likely to be forgotten. It then focuses on revising those specific classes to mitigate the forgetting problem. RCIL (Zhang et al., 2022) introduces a two-branch convolutional network, with one branch frozen and the other trained to prevent forgetting. At the end of each learning step, the trainable branch is merged with the frozen branch, which can introduce model interference. However, it is worth noting that existing CSS and CL techniques typically involve fine-tuning certain parts of the network dedicated to the old tasks. Unfortunately, this fine-tuning process can lead to forgetting as the model diverges from the previously learned solution.

3 Method

In this section, we first introduce the overall architecture of our proposed SegViT model for semantic segmentation. Then, we discuss the Shrunk and Shrunk++ architectures designed to reduce the model’s computational cost. Lastly, we explore the adaptation of our SegViT model for the context of continual semantic segmentation to minimize forgetting.

3.1 Overall SegViT Architecture

SegViT comprises a ViT-based encoder responsible for feature extraction and a decoder used to learn the segmentation map. For the encoder, we designed the ‘Shrunk’ structure to reduce the computational overhead associated with the plain ViT. Regarding the decoder, we introduce a novel lightweight module named Attention-to-Mask (ATM). This module generates class-specific masks denoted as M and class predictions denoted as P, which determine the presence of a particular class in the image. The mask outputs from a stack of ATM modules are combined and then multiplied by the class predictions to obtain the final segmentation output. Figure 3 illustrates the overall architecture of our proposed SegViT.

3.1.1 Encoder

Given an input image \(I \in {\mathbb {R}}^{H \times W \times 3}\), the plain vision transformer backbone reshapes it into a sequence of tokens \(\mathcal {F}_0 \in {\mathbb {R}}^{L \times C}\), where \(L = \frac{HW}{P^2}\), P is the patch size, and C is the number of channels. To capture positional information, learnable position embeddings of the same size as \(\mathcal {F}_0\) are added. Subsequently, the token sequence \(\mathcal {F}_0\) is processed by m transformer layers to produce the output. The output tokens of each layer are denoted as \([\mathcal {F}_1, \mathcal {F}_2, \dots , \mathcal {F}_m] \in {\mathbb {R}}^{L \times C}\). For a plain vision transformer like ViT, the number of tokens is large and remains constant across layers. Processing a substantial number of tokens at every layer results in elevated computational costs for plain ViT. We refer to a plain ViT-based encoder as the ‘Single’ structure. To mitigate computational costs, we introduce the Shrunk and Shrunk++ structures, tailored to create a more computationally efficient ViT-based encoder. Further details regarding the Shrunk structure can be found in Sect. 3.2.
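A minimal sketch of this tokenization step is given below, assuming a PyTorch-style implementation; the strided-convolution patch projection and the names patch_embed and pos_embed are illustrative, not necessarily the exact implementation used here.

```python
import torch
import torch.nn as nn

H = W = 512                 # input resolution
P = 16                      # patch size
C = 768                     # embedding dimension (e.g., ViT-Base)
L = (H // P) * (W // P)     # number of tokens, L = HW / P^2

patch_embed = nn.Conv2d(3, C, kernel_size=P, stride=P)   # non-overlapping patch projection
pos_embed = nn.Parameter(torch.zeros(1, L, C))           # learnable position embeddings

image = torch.randn(1, 3, H, W)
tokens = patch_embed(image).flatten(2).transpose(1, 2)   # (1, L, C)
tokens = tokens + pos_embed                              # F_0, fed into m transformer layers
```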

3.1.2 Decoder

Attention-to-Mask (ATM) Cross-attention can be described as the mapping between two sequences of tokens, denoted as \(\{\mathbf {v_1}, \mathbf {v_2}\}\). In our case, we define two token sequences: \(\mathcal {G} \in {\mathbb {R}}^{N \times C}\) with a length N equal to the number of classes, and \(\mathcal {F}_i \in {\mathbb {R}}^{L \times C}\). To enable cross-attention, linear transformations are applied to each token sequence, resulting in the query (Q), key (K), and value (V) representations. This process is described by Eq. (1).

$$\begin{aligned} \begin{aligned} Q&= \phi _{q} (\mathcal {G}) \in {{\mathbb {R}}}^{N \times C},\\ K&= \phi _{k} ({\mathcal {F}}_{i}) \in {{\mathbb {R}}}^{L \times C},\\ V&= \phi _{v} ({\mathcal {F}}_{i}) \in {{\mathbb {R}}}^{L \times C}. \end{aligned} \end{aligned}$$
(1)

The similarity map is calculated by computing the dot product between the query and key representations. Following the scaled dot-product attention mechanism, the similarity map and attention map are calculated as follows:

$$\begin{aligned} \begin{aligned} S(Q, K)&= \frac{QK^T}{\sqrt{d_{k}}}\in {{\mathbb {R}}}^{N \times L}, \\ Attention(\mathcal {G}, \mathcal {F}_{i})&= \texttt{Softmax}\big (S(Q, K)\big )V \in {{\mathbb {R}}}^{N \times C}, \end{aligned} \end{aligned}$$
(2)

where \(\sqrt{d_{k}}\) is a scaling factor, with \(d_{k}\) equal to the dimension of the keys.

The shape of the similarity map S(Q, K) is determined by the lengths of the two token sequences, N and L. The attention mechanism updates \(\mathcal {G}\) by performing a weighted sum of V, where the weights are derived from the similarity map after applying the softmax function along the L dimension.

In dot-product attention, the softmax function tends to concentrate attention on the tokens with the highest similarity. However, we believe that tokens other than those with maximum similarity also carry meaningful information. Based on this intuition, we design a lightweight module that generates semantic predictions more directly. To this end, we assign \(\mathcal {G}\) as the class embeddings for the segmentation task, and \(\mathcal {F}_i\) as the output of layer i of the ViT backbone. A semantic mask is paired with each token in \(\mathcal {G}\) to represent the semantic prediction for the corresponding class. The binary mask M is defined as follows:

$$\begin{aligned} Mask(\mathcal {G}, \mathcal {F}_{i}) = \texttt{Sigmoid}(S(Q, K)) \in {{\mathbb {R}}}^{N \times L}. \end{aligned}$$
(3)

The masks have a shape of \(N \times L\), which can be reshaped to \(N \times \frac{H}{P} \times \frac{W}{P}\) and bilinearly upsampled to the original image size \(N \times H \times W\). As depicted in the right section of Fig. 3, the ATM mechanism produces masks as an intermediate output during cross-attention.
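To make the computation concrete, the following is a minimal PyTorch sketch of the similarity-to-mask step in Eqs. (1)-(3). The class name ATMSimilarity and its single-head form are illustrative simplifications; the actual ATM module follows a full multi-head transformer decoder layer (see Fig. 3).

```python
import math
import torch
import torch.nn as nn

class ATMSimilarity(nn.Module):
    """Single-head cross-attention whose pre-softmax similarity map is reused as masks."""
    def __init__(self, dim: int):
        super().__init__()
        self.phi_q = nn.Linear(dim, dim)
        self.phi_k = nn.Linear(dim, dim)
        self.phi_v = nn.Linear(dim, dim)

    def forward(self, class_tokens, feat):
        # class_tokens: (N, C) class embeddings G; feat: (L, C) backbone tokens F_i
        Q, K, V = self.phi_q(class_tokens), self.phi_k(feat), self.phi_v(feat)
        sim = Q @ K.t() / math.sqrt(K.shape[-1])   # S(Q, K) in Eq. (2), shape (N, L)
        updated = sim.softmax(dim=-1) @ V          # standard cross-attention output, (N, C)
        masks = sim.sigmoid()                      # class-specific masks, Eq. (3), (N, L)
        return masks, updated

atm = ATMSimilarity(dim=768)
masks, out_tokens = atm(torch.randn(150, 768), torch.randn(1024, 768))  # 1024 = (512/16)^2
masks = masks.reshape(150, 32, 32)   # N x H/P x W/P, ready for bilinear upsampling
```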

The final output tokens \(Z \in {\mathbb {R}}^{N \times C}\) from the ATM module are utilized for classification. A fully connected layer (FC) parameterized by \(W \in {\mathbb {R}}^{C \times 2}\) followed by the Softmax function is used to predict whether each object class is present in the image or not. The class predictions \(\mathcal {P} \in {\mathbb {R}}^{N \times 2}\) are formally defined as:

$$\begin{aligned} \begin{aligned} \mathcal {P} = \text {Softmax}(ZW). \end{aligned} \end{aligned}$$
(4)

Here, \(P_{c,1}\) indicates the likelihood of class c appearing in the image. For simplicity, we refer to \(P_c\) as the probability score for class c.

The output segmentation map \(O_c \in {\mathbb {R}}^{H \times W}\) for class c is obtained by element-wise multiplication of the reshaped class-specific mask \(M_c\) and its corresponding prediction score \(P_c\): \(O_c = P_c \odot M_c\). During inference, the label of each pixel i is assigned by selecting the class with the highest score using \(\text {argmax}_{c} O_{i,c}\).
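As an illustration, a short sketch (PyTorch assumed; the helper name segment_from_masks is hypothetical) of combining the class-specific masks with the class scores and taking the per-pixel argmax:

```python
import torch

def segment_from_masks(masks: torch.Tensor, class_scores: torch.Tensor) -> torch.Tensor:
    """masks: (N, H, W) sigmoid masks M_c; class_scores: (N,) probabilities P_c."""
    scores = class_scores[:, None, None] * masks    # O_c = P_c * M_c, shape (N, H, W)
    return scores.argmax(dim=0)                     # per-pixel label map, shape (H, W)

labels = segment_from_masks(torch.rand(150, 512, 512), torch.rand(150))
```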

Indeed, plain base models like ViT do not inherently possess multiple stages with features of different scales. Consequently, structures such as Feature Pyramid Networks (FPN) that merge features from multiple scales are not applicable to them.

Nevertheless, features from layers other than the last one in ViT contain valuable low-level semantic information, which can contribute to improving performance. In SegViT, we have developed a structure that leverages feature maps from different layers of ViT to enrich the feature representations. This allows us to incorporate and benefit from the rich low-level semantic information present in those feature maps.

SegViT is trained via a classification loss and a binary mask loss. The classification loss (\(\mathcal {L}_{\text {cls}}\)) minimizes the cross-entropy between the class prediction and the target. The mask loss (\(\mathcal {L}_{\text {mask}}\)) consists of a focal loss (Lin et al., 2017) and a dice loss (Milletari et al., 2016), optimizing segmentation accuracy and addressing sample imbalance in mask prediction. The dice loss and focal loss are computed between the predicted masks and the ground-truth segmentation. The final loss is the combination of these terms, formally defined as:

$$\begin{aligned} \mathcal {L}=\mathcal {L}_{\text {cls}} +\lambda _{\text {focal}}\mathcal {L}_{\text {focal}} + \lambda _{\text {dice}}\mathcal {L}_{\text {dice}} \end{aligned}$$
(5)

where \(\lambda _{\text {focal}}\) and \(\lambda _{\text {dice}}\) are hyperparameters that control the strength of each loss function. Previous mask transformer methods such as MaskFormer (Cheng et al., 2021) and DETR (Carion et al., 2020) have adopted the binary mask loss and fine-tuned their hyperparameters through empirical experiments. Hence, for consistency, we directly use the same values as MaskFormer and DETR for the loss hyperparameters: \(\lambda _{\text {focal}}=20.0\) and \(\lambda _{\text {dice}}=1.0\).
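For completeness, a hedged sketch of the combined loss in Eq. (5) follows, with simplified stand-ins for the focal and dice terms; the cited implementations may differ in details such as class weighting, reductions, and logits handling.

```python
import torch
import torch.nn.functional as F

def dice_loss(pred, target, eps=1.0):
    # pred, target: (N, H*W), pred in [0, 1], target in {0, 1}
    inter = (pred * target).sum(-1)
    return (1 - (2 * inter + eps) / (pred.sum(-1) + target.sum(-1) + eps)).mean()

def focal_loss(pred, target, gamma=2.0):
    ce = F.binary_cross_entropy(pred, target, reduction="none")   # -log(p_t)
    p_t = pred * target + (1 - pred) * (1 - target)
    return ((1 - p_t) ** gamma * ce).mean()

def segvit_loss(cls_logits, cls_target, masks, mask_target,
                lambda_focal=20.0, lambda_dice=1.0):
    l_cls = F.cross_entropy(cls_logits, cls_target)               # L_cls
    return l_cls + lambda_focal * focal_loss(masks, mask_target) \
                 + lambda_dice * dice_loss(masks, mask_target)    # Eq. (5)

loss = segvit_loss(torch.randn(150, 2), torch.randint(0, 2, (150,)),
                   torch.rand(150, 1024), torch.randint(0, 2, (150, 1024)).float())
```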

3.2 Shrunk Structure for Efficient Plain ViT Encoder

Recent efforts, such as DynamicViT (Rao et al., 2021), TokenLearner (Ryoo et al., 2021), and SPViT (Kong et al., 2022), propose token pruning techniques to accelerate vision transformers. However, most of these approaches are specifically designed for image classification and, as a result, discard valuable information. When adapted to semantic segmentation, they may therefore fail to preserve the high-resolution features that are necessary for accurate dense prediction.

Fig. 4

Architecture of the proposed query-downsampling (QD) layer (blue block) and the query-upsampling (QU) layer. The QD layer uses an efficient down-sampling technique (green block) and removes less informative input tokens from the query. The QU layer takes a set of trainable query tokens and learns to recover the discarded tokens using multi-head attention (Color figure online)

Fig. 5

Illustrations of Shrunk and Shrunk++. Boxes in the diagram denote the transformer encoder blocks and the patch embedding block. In SegViT (Zhang et al., 2022), the proposed Shrunk structure employs query downsampling (QD) on the middle-level features to preserve the information. In the new Shrunk++ architecture, we introduce the Edged Query Downsampling (EQD) technique, which consolidates every four adjacent tokens into one token and additionally retains the tokens that contain edges. This enhancement enables downsampling to take place before the first layer without significant performance degradation, offering computational savings for the initial layers of the Shrunk model. The edge information is extracted using a lightweight parallel edge detection head

In this paper, we introduce the Shrunk structure. This method employs query-based down-sampling (QD) to prune the input token sequence \(\mathcal {F}_i\) and uses query up-sampling (QU) to retrieve the discarded tokens, ensuring preservation of fine-detail features vital for semantic segmentation. The overall architecture of QD and QU is illustrated in Fig. 4.

For QD, we have re-designed the Transformer encoder block (Vaswani et al., 2017) and incorporated efficient down-sampling operations to specifically reduce the number of query tokens. In a Transformer encoder layer, the computational cost is directly influenced by the number of query tokens, and the output size is determined by the query token size. To mitigate the computational burden while maintaining information integrity, a viable strategy is to selectively reduce the number of query tokens while preserving the key and value tokens. This approach allows for an effective reduction in the output size of the current layer, leading to reduced computational costs for subsequent layers.

For QU, we perform up-sampling using a higher-resolution token sequence, either predefined or inherited, as the queries. The key and value tokens are taken from the token sequence obtained from the backbone, which typically has a lower resolution. The output size is dictated by the higher-resolution query tokens. Through the cross-attention mechanism, information from the key and value tokens is integrated into the output. This process facilitates a non-linear merging of information and exhibits an upsampling behavior, effectively increasing the resolution of the output.

As illustrated in Fig. 5, our proposed Shrunk structure incorporates the QD and QU modules. Specifically, we integrate a QD operation at the middle depth of the ViT backbone, precisely at the \(8^\text {th}\) layer of a 24-layer backbone. The QD operation downsamples the query tokens using a \(2\times 2\) nearest neighbor downsampling operation, resulting in a feature map size reduction to 1/32. However, such downsampling can potentially cause information loss and performance degradation. To mitigate this issue, prior to applying the QD operation, we employ a QU operation to the feature map. This involves initializing a set of query tokens with a resolution of 1/16 to store the information. Subsequently, as the downsampled feature map progresses through the remaining backbone layers, it is merged and upsampled using another QU operation alongside the previously stored 1/16 high-resolution feature map. This iterative process ultimately generates a 1/16 high-resolution feature map enriched with semantic information processed by the backbone.
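A minimal sketch of the two operations, assuming PyTorch. QDAttention and QueryUpsample are illustrative names and simplifications: the real blocks additionally contain the usual MLP and normalization sub-layers of a transformer layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QDAttention(nn.Module):
    """Attention where only the queries are 2x2 nearest-downsampled; keys and values
    keep the full token set, so the output has a quarter of the tokens."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens, h, w):                            # tokens: (B, h*w, C)
        x = tokens.transpose(1, 2).reshape(tokens.shape[0], -1, h, w)
        q = F.interpolate(x, scale_factor=0.5, mode="nearest")  # 2x2 nearest downsampling
        q = q.flatten(2).transpose(1, 2)                        # (B, h*w/4, C)
        out, _ = self.attn(q, tokens, tokens)                   # keys/values at full resolution
        return out

class QueryUpsample(nn.Module):
    """Cross-attention that recovers a high-resolution token map: the stored or
    predefined high-resolution tokens act as queries over the low-resolution tokens."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, hires_queries, lowres_tokens):
        out, _ = self.attn(hires_queries, lowres_tokens, lowres_tokens)
        return out                                              # resolution follows the queries

B, C, h, w = 1, 1024, 40, 40                                    # e.g. ViT-Large, 640 input, 1/16 grid
feat = torch.randn(B, h * w, C)
down = QDAttention(C)(feat, h, w)                               # 1/32-resolution tokens
restored = QueryUpsample(C)(torch.randn(B, h * w, C), down)     # back to 1/16 resolution
```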

Despite the effectiveness of the proposed Shrunk approach in maintaining performance, it requires the integration of the QD operation within the intermediate layers of the backbone. This necessity arises due to the fact that shallow layers primarily capture low-level features, and applying downsampling to these layers would result in significant information loss. Consequently, these low-level layers continue to be computed at a higher resolution, limiting the potential reduction in computational cost.

To address this limitation and further optimize the backbone, we introduce SegViTv2, which adopts a novel architecture called Shrunk++. In this architecture, we incorporate an edge detection module in the QD section and introduce an Edged Query Downsampling (EQD) technique to update the QD process. In addition to the \(2\times 2\) nearest downsampling operation that keeps one of every four adjacent tokens, our approach aims to retain tokens that contain multiple categories, specifically tokens that contain an edge. The \(2\times 2\) sparse tokens preserve the important semantic information, while the edge tokens preserve detailed spatial information. By retaining both types of information, we minimize the loss of valuable information and overcome the limitations associated with the low-level layers.

To extract edges, we add a separate branch, a lightweight multilayer perceptron (MLP) termed the edge detection head, that learns to detect edges from the input image. The edge detection head operates as an auxiliary branch, trained simultaneously with the main ATM decoder. This head processes the tokenized input image, whose channel dimension C is aligned with the backbone. The MLP in this head consists of three layers, with dimensions C, C/2, and 2, respectively. Let I represent the tokenized input; the output of the MLP is defined as \(E = \text {MLP}(I; W_1, W_2, W_3)\), where \(W_1, W_2, W_3\) are the weights of the three layers. The output E is then passed through a softmax activation function, resulting in \(S = \text {Softmax}(E)\). To determine whether a token belongs to an edge, we apply a confidence threshold \(\tau \), set to 0.7 in our implementation. To obtain the ground-truth (GT) edge, we perform post-processing on the GT segmentation map Y. Since the input has been tokenized with a patch size of P, we tokenize the GT and reshape it into a sequence of tokens denoted as \(Y \in R^{(HW/P^2) \times P \times P}\), where the last two dimensions correspond to the patch dimensions. We consider a patch to contain an edge if there exists any edge pixel within the patch. We define the edge mask \(Mask_{i}\) as follows:

$$\begin{aligned} Mask_{i} = {\left\{ \begin{array}{ll} 1 &{} \text {if } \sum _{j,k} Y_{i,j,k} > 0, \\ 0 &{} \text {otherwise}. \end{array}\right. } \end{aligned}$$
(6)

For each element \(s_i\) in S, we create a binary edge mask \(M_i\), where \(M_i = 1\) if \(s_i \ge \tau \) and \(M_i = 0\) otherwise. The cross-entropy loss is computed between the generated edge mask \(M_i\) and the ground-truth edge mask \(Y_i\): \(\mathcal {L}_{\text {edge}} = - \sum _{i} \big [ Y_i \log (M_i) + (1 - Y_i) \log (1 - M_i) \big ]\). By incorporating the edge detection head as an auxiliary branch, the Shrunk++ architecture retains detailed spatial contexts throughout the query downsampling process, forming the Edged Query Downsampling (EQD) structure. EQD captures and retains edge information during sparse downsampling, significantly reducing computational overhead while maintaining performance, and enables the Shrunk++ architecture to strike a balance between computational efficiency and accuracy.
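To make the edge supervision concrete, the following is a hedged sketch (PyTorch assumed; helper names are illustrative) of deriving a patch-level ground-truth edge mask and of the edge-aware token selection. Here a patch is flagged as containing an edge when more than one class appears inside it, which is one way to realise the "any edge pixel within the patch" criterion of Eq. (6).

```python
import torch

def gt_edge_mask(seg_map: torch.Tensor, patch_size: int) -> torch.Tensor:
    """seg_map: (H, W) integer label map. Returns a (HW/P^2,) binary mask that is 1
    for patches whose P x P window contains pixels of more than one class."""
    H, W = seg_map.shape
    P = patch_size
    patches = seg_map.reshape(H // P, P, W // P, P).permute(0, 2, 1, 3).reshape(-1, P * P)
    return (patches.max(dim=1).values != patches.min(dim=1).values).long()

def eqd_keep_indices(edge_mask: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Indices of tokens kept by EQD: the 2x2-sparse grid plus all edge tokens."""
    grid = torch.zeros(h, w, dtype=torch.bool)
    grid[::2, ::2] = True                                   # 2x2 nearest downsampling pattern
    keep = grid.flatten() | edge_mask.bool()
    return keep.nonzero(as_tuple=False).squeeze(1)

edges = gt_edge_mask(torch.randint(0, 150, (512, 512)), patch_size=16)   # (1024,)
kept = eqd_keep_indices(edges, h=32, w=32)
```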

3.3 Exploration on Continual Semantic Segmentation

Continual semantic segmentation aims to train a segmentation model in T steps without forgetting. At step t, we are given a dataset \( {{\mathcal {D}} } ^t\) which comprises a set of pairs \((X^t, Y^t)\), where \(X^t\) is an image of size \(H\times W\) and \(Y^t\) is the ground-truth segmentation map. Here, \(Y^t\) only consists of labels in current classes \(\mathcal {C}^{t}\), while all other classes (i.e., old classes \(\mathcal {C}^{1:t-1}\) or future classes \(\mathcal {C}^{t+1:T}\)) are assigned to the background. In continual learning, the model at step t should be able to predict all classes \(\mathcal {C}^{1:t}\) in history.

3.3.1 SegViT for Continual Learning

Existing continual semantic segmentation methods (Zhang et al., 2022; Phan et al., 2022) propose regularization algorithms to preserve the past knowledge of a specific architecture, DeepLabV3. These methods focus on continual semantic segmentation for DeepLabV3 with a ResNet backbone, which has a less robust visual representation for distinguishing between different categories. Consequently, these methods require fine-tuning model parameters to learn new classes while attempting to retain knowledge of old classes. Unfortunately, adapting the old parameters dedicated to the previous task inevitably interferes with past knowledge, leading to catastrophic forgetting. In contrast, our proposed SegViT decouples class prediction from mask segmentation, making it inherently suitable for a continual learning setting. By leveraging the powerful representation capability of the plain vision transformer, we can learn new classes by solely fine-tuning the class proxy (i.e., the class token) while keeping the old parameters frozen. This approach eliminates the need for fine-tuning old parameters when learning new tasks, effectively addressing the issue of catastrophic forgetting.

Fig. 6

Overview of SegViT adapted for continual semantic segmentation. When learning a new task t, we grow and train a separate ATM and fully-connected layer to produce mask and class prediction. All the parameters dedicated to the old task \(t-1\), including ATM, FC layers, and the ViT encoder, are frozen. This prevents interfering with the old knowledge, which guarantees no forgetting

During training on the current task t, we add a new sequence of learnable tokens \(\mathcal {G}^{t} \in {{\mathbb {R}}}^{|\mathcal {C}^{t}| \times C}\), where \(|\mathcal {C}^{t}|\) is the number of classes in the current task. To learn the new classes, we grow and train new ATM modules and a new fully connected layer for mask prediction and mask classification. For simplicity, we omit the parallel structure of the ATM modules and use a single ATM module to denote them. Let \(A^{t}\) and \(W^{t}\) denote the ATM module and the weights of the fully connected (FC) layer for task t. All parameters for prior tasks, including the ViT encoder, the ATM modules, and the FC layers, are completely frozen. Figure 6 illustrates the overview of our SegViT architecture adapted for continual semantic segmentation.

Given the encoder extracted features \(\mathcal {F}_{T}\) and the class tokens \(\mathcal {G}^{t}\), the ATM produces the mask predictions \(M^{t}\) and the output tokens \(Z^{t}\) corresponding to the mask:

$$\begin{aligned} M^{t}, Z^{t} = \text {ATM}(\mathcal {G}^{t}, \mathcal {F}_{T}). \end{aligned}$$
(7)

Based on Eq. (4), the class prediction \(\mathcal {P}^t\) is obtained by applying the FC layer \(W^t\) to the output tokens \(Z^t\).

The prediction score \(S^t_c\) for each class c is multiplied by the corresponding mask \(M^t_c\) to get the segmentation map \(O^t_c\) for class c:

$$\begin{aligned} O^t_c = S^t_c \odot M^t_c, \end{aligned}$$
(8)

where \(\odot \) denotes the element-wise multiplication. The segmentation \(\hat{O}^t\) is obtained by taking the class c having the highest score in every pixel, defined as

$$\begin{aligned} \hat{O}^t = \underset{c \in \mathcal {C}^t}{\text {argmax }} O^t_{i,c} \end{aligned}$$
(9)

Based on the ground truth \(Y^t\) for task t, SegViT is trained using the loss function defined in Eq. (5). To produce the final segmentation across all tasks, we concatenate the individual outputs \(O^t\) from each task.
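A minimal sketch of this forget-free extension, assuming PyTorch. The class ContinualSegViT, its add_task method, and the simplified single-layer ATMHead are illustrative names and simplifications, not the exact implementation.

```python
import math
import torch
import torch.nn as nn

class ATMHead(nn.Module):
    """Simplified single-layer ATM: cross-attention whose similarity map gives the masks."""
    def __init__(self, dim):
        super().__init__()
        self.q, self.k, self.v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, class_tokens, feat):                      # (N, C), (L, C)
        sim = self.q(class_tokens) @ self.k(feat).t() / math.sqrt(feat.shape[-1])
        return sim.sigmoid(), sim.softmax(-1) @ self.v(feat)    # masks (N, L), tokens Z (N, C)

class ContinualSegViT(nn.Module):
    def __init__(self, backbone, dim):
        super().__init__()
        self.backbone, self.dim = backbone, dim
        self.class_tokens = nn.ParameterList()                  # one token set G^t per task
        self.heads = nn.ModuleList()                            # one (ATM, FC) pair per task

    def add_task(self, num_new_classes):
        for p in self.parameters():                             # freeze backbone and all old heads
            p.requires_grad_(False)
        self.class_tokens.append(nn.Parameter(torch.randn(num_new_classes, self.dim)))
        self.heads.append(nn.ModuleDict({"atm": ATMHead(self.dim),
                                         "fc": nn.Linear(self.dim, 2)}))   # Eq. (4)

    def forward(self, image):
        feat = self.backbone(image)                             # (L, C) tokens from the frozen ViT
        outputs = []
        for tokens, head in zip(self.class_tokens, self.heads):
            masks, z = head["atm"](tokens, feat)
            score = head["fc"](z).softmax(-1)[:, 1]             # S^t_c, probability of presence
            outputs.append(score[:, None] * masks)              # O^t_c = S^t_c * M^t_c, Eq. (8)
        return torch.cat(outputs, dim=0)                        # concatenate predictions across tasks

model = ContinualSegViT(nn.Identity(), dim=768)   # nn.Identity stands in for a frozen ViT giving (L, C) tokens
model.add_task(100)                               # task 1: 100 classes
out = model(torch.randn(1024, 768))               # (100, 1024) segmentation scores
```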

4 Experiments

4.1 Datasets

ADE20K (Zhou et al., 2017) is a challenging scene parsing dataset which contains 20,210 images as the training set and 2,000 images as the validation set with 150 semantic classes.

COCO-Stuff-10K (Caesar et al., 2018) is a scene parsing benchmark with 9,000 training images and 1,000 test images. Even though the dataset contains 182 categories, not all categories exist in the test split. We follow the implementation of mmsegmentation (MMSegmentation, 2020) with 171 categories to conduct the experiments.

PASCAL-Context (Mottaghi et al., 2014) is a dataset with 4,996 images in the training set and 5,104 images in the validation set. There are 60 semantic classes in total, including a class representing ‘background’.

Table 1 Experiment results on the ADE20K val. split
Fig. 7

Visual results of different segmentation networks and plain ViT backbones on the ADE20K validation set (Zhou et al., 2017). It includes the following models: a Segmenter (Strudel et al., 2021) with ViT large, b StructToken (Lin et al., 2022) with ViT large, c UPerNet (Xiao et al., 2018) with BEiT large, and d SegViT V2 with BEiTv2 large. The results demonstrate that our method effectively generates accurate segmentation masks and unlocks the potential of plain ViT. Zoom in for a better view

4.2 Implementation Details

4.2.1 Transformer Backbone

We employ the naive ViT (Dosovitskiy et al., 2021) as the backbone for our method. For the ablation studies, we primarily utilize the ‘Base’ variant, while also presenting results based on the ‘Large’ variant. Notably, variations in performance can arise from different pre-trained weights, as indicated by Segmenter (Strudel et al., 2021). To ensure equitable comparisons, we adopt the pre-trained weights provided by AugReg (Steiner et al., 2021), aligning with the practices employed in Segmenter (Strudel et al., 2021) and StructToken (Lin et al., 2022). These weights stem from training on ImageNet-21k with strong data augmentation and regularization techniques (Steiner et al., 2021). To explore the maximum capacity and assess the upper bound of our method, we also conduct experiments using stronger base models such as DEiT v3 (Touvron et al., 2022) and BEiT v2 (Peng et al., 2022).

4.2.2 Training Settings

We use MMSegmentation (MMSegmentation, 2020) and follow the commonly used training settings. During training, we apply sequential data augmentation techniques, including random horizontal flipping, random resizing within a ratio of 0.5 to 2.0, and random cropping. For most settings, the cropping dimensions are set to \(512\times 512\), except for PASCAL-Context, where we use \(480 \times 480\), and for the ViT-Large backbone on ADE20K, where we use \(640\times 640\). The batch size is set to 16 for all datasets, with total iterations of 160k, 80k, and 80k for ADE20K, COCO-Stuff-10K, and PASCAL-Context, respectively.

Table 2 Experiment results on the COCO-Stuff-10K test. split
Table 3 Experimental results on the PASCAL-Context val. split

4.2.3 Evaluation Metric

We use the mean Intersection over Union (mIoU) as the metric to evaluate the performance. ‘ss’ denotes single-scale testing and ‘ms’ denotes test-time augmentation with multi-scale (0.5, 0.75, 1.0, 1.25, 1.5, 1.75) inputs. All reported mIoU scores are in percentage format. All reported computational costs in GFLOPs are measured using the fvcore library.
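For reference, a short sketch of the mIoU computation assumed here; this is the standard per-class IoU averaged over classes appearing in either map, not the exact mmsegmentation implementation.

```python
import torch

def mean_iou(pred: torch.Tensor, gt: torch.Tensor, num_classes: int) -> float:
    """pred, gt: (H, W) integer label maps."""
    ious = []
    for c in range(num_classes):
        inter = ((pred == c) & (gt == c)).sum()
        union = ((pred == c) | (gt == c)).sum()
        if union > 0:                                   # skip classes absent from both maps
            ious.append(inter.float() / union.float())
    return 100.0 * torch.stack(ious).mean().item()      # reported as a percentage

score = mean_iou(torch.randint(0, 150, (512, 512)), torch.randint(0, 150, (512, 512)), 150)
```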

4.3 Comparisons with the State-of-the-Art Methods

4.3.1 Results on ADE20K

Table 1 reports the comparison with state-of-the-art methods on the ADE20K validation set using ViT backbones. SegViT uses the ATM module with multi-layer inputs from the original ViT backbone, while Shrunk applies QD to the ViT backbone and saves \(50\%\) of the computational cost without sacrificing too much performance. Our approach achieves a state-of-the-art mIoU of \(58.2\%\) (MS) with the BEiTv2 Large backbone. To ensure a fair comparison, we evaluate our SegViT module with the BEiT-v2 large backbone at a crop size of \(512\times 512\), which consumes 374.0 GFLOPs. Our approach achieves a slightly better performance of \(56.5\%\) mIoU compared to Mask2Former-Swin-L, which achieves \(56.1\%\) with 402.7 GFLOPs at a crop size of \(640\times 640\). Additionally, our Shrunk version offers around a 50% reduction in computational cost (308.8 GFLOPs), while delivering competitive performance with a mIoU of \(57.0\%\) (MS). Optimizing SegViT with ViT-Large using the proposed Shrunk++ reduces the computational cost of Shrunk by 3.05 times, while preserving the mIoU. Figure 7 shows the visual results of different segmentation methods. In contrast to other methods that often confuse similar classes and misclassify related concepts, our SegViT stands out by delineating object boundaries more precisely and achieving accurate segmentation of complete objects, even in cluttered scenes.

4.3.2 Results on COCO-Stuff-10K

Table 2 shows the results on the COCO-Stuff-10K dataset. Our method achieves \(50.3\%\) mIoU, which is higher than the previous state-of-the-art StructToken by \(1.2\%\) with less computational cost. Our Shrunk version achieves \(49.4\%\) mIoU with 224.8 GFLOPs, which is similar to the computational cost of a dilated ResNet-101 backbone but with much higher performance. By extending SegViT with the effective Shrunk++, we further decrease its GFLOPs by 1.82 times, while retaining a competitive mIoU.

4.3.3 Results on PASCAL-Context

Table 3 shows the results on the PASCAL-Context dataset. We follow HRNet (Sun et al., 2019) to evaluate our method and report the results under 59 classes (without background) and 60 classes (with background). Using the full SegViT structure without adopting Shrunk or Shrunk++, we reach mIoUs of \(67.14\%\) and \(61.63\%\), respectively, for those two metrics, outperforming the state-of-the-art methods using ViT backbones with less computational cost. By applying the Shrunk and Shrunk++ architectures, the computational cost in terms of GFLOPs is reduced by \(42\%\) and \(45\%\), respectively. Among all approaches evaluated on the PASCAL-Context dataset, SegViTv2 with Shrunk++ achieves the best trade-off between accuracy and efficiency.

4.4 Ablation Study

In this section, we conduct extensive ablation studies to show the effectiveness of our proposed methods.

4.4.1 Effect of the ATM Module

We conducted an analysis to evaluate the impact of using the proposed ATM module as a decoder. The results are summarized in Table 4. To establish a baseline for comparison, we introduced SETR-naive, which utilizes two \(1\times 1\) convolutions to directly derive per-pixel classifications from the final-layer output of the ViT-Base transformer. From the results, it is evident that applying the ATM module under the supervision of a conventional cross-entropy loss leads to a performance improvement of 0.5%. However, the performance gains become much more substantial when we decouple the classification and mask prediction processes, supervising each separately. This approach results in a significant performance boost of 3.1%, highlighting the efficacy of the ATM module in enhancing semantic segmentation performance.

Table 4 Comparisons between our proposed ATM module and SETR (Zheng et al., 2021)

4.4.2 Ablation of the Feature Levels

The effects of using multiple-layer inputs from the backbone to the ATM modules are presented in Table 5. The incorporation of feature maps from lower layers leads to a notable performance improvement of 1.3%. We further investigated the impact of including more layers of features and observed additional gains in performance. After empirical testing, we determined that utilizing three layers yielded optimal results, resulting in an overall mIoU boost of 1.7%. These ablation studies confirm the effectiveness of our proposed ATM decoder and highlight the advantage of incorporating multi-layer features into the segmentation structure. This integration significantly enhances the performance of semantic segmentation tasks.

Table 5 Results of using different layer inputs to the SegViT structure on ADE20K dataset using ViT-Base as the backbone

4.4.3 SegViT on Hierarchical Base Models

We conducted an analysis to evaluate the performance of SegViT on hierarchical base models. For comparison, we selected two competitive methods, MaskFormer (Cheng et al., 2021) and Mask2Former (Cheng et al., 2022). The results presented in Table 6 indicate that, even though our method was not specifically designed for hierarchical base models, we are still able to achieve competitive performance while maintaining computational efficiency. This demonstrates the applicability of our SegViT approach to various types of base models.

Table 6 The experiments use the Swin-Tiny (Liu et al., 2021) backbone and are carried out on the ADE20K dataset
Table 7 Ablation results of Shrunk and Shrunk++ version on the ADE20K dataset
Table 8 Ablation results of different decoder methods with their corresponding feature merge types and loss types

4.4.4 Ablation of Shrunk and Shrunk++ Strategies

In this section, we analyze the effectiveness of the different SegViT structures. Table 7 presents the effects of various techniques employed in each SegViT structure, including query upsampling (QU), query downsampling (QD), token-squeezing (TS) techniques, and segmentation heads. Applying the ATM head to the ‘Single’ structure yields a notable performance improvement of 6.67% compared with using the SETR head. This demonstrates the effectiveness of the ATM head in enhancing the performance of the baseline structure. However, applying QD to the ‘Single’ structure with the ATM head leads to a performance drop of 2.7%, suggesting the occurrence of information loss during the downsampling phase. Importantly, incorporating QU restores the performance. QU helps recover the discarded information from QD and reconstructs the high-resolution feature map, which is crucial for dense prediction tasks. Jointly leveraging QU and QD, the Shrunk architecture achieves optimal performance while reducing computational costs by 16.15% in comparison to the ‘Single’ structure.

In the proposed Shrunk++ structure, we analyze the performance of two main token-squeezing techniques: nearest downsampling and edge-aware downsampling. It is important to note that token squeezing is directly applied to the first layer of the network for optimal computational efficiency. Applying naive nearest downsampling with a 3x3 kernel reduces the GFLOPs of the Shrunk structure without token-squeezing by a factor of 2.97. However, reducing the computational cost with 3x3 and 2x2 nearest downsampling leads to a performance drop of 13%. In contrast, by incorporating an additional edge extractor into our Shrunk++ architecture, we significantly improve the mIoU, achieving performance on par with Shrunk, i.e., 49.9% mIoU, with a minor increase in computational cost to 74.6 GFLOPs. The edge-aware downsampling technique preserves the edge details, thereby preserving discriminative features for dense predictions. Among the different settings, the 2x2 + Naive MLP Edge setting achieves an optimal balance between performance and efficiency.

4.4.5 Ablation Studies on Decoder Variances

Different decoder methods are associated with specific feature merge types and loss types. In Table 8, we compare the designs of various decoders on a plain ViT backbone. For hierarchical base models like Swin, the resolution of the feature maps in each stage is reduced. Consequently, the adoption of a Feature Pyramid Network (FPN) is necessary to obtain feature maps with larger resolutions and rich semantic information. However, in Table 8, we observe that the FPN structure does not perform well with plain vision transformers. With plain ViT base models, the resolution remains constant, and the feature map of the final layer encapsulates the most comprehensive semantic information. Hence, our proposed method, which utilizes tokens to merge features from different levels, achieves superior performance. By simply replacing the FPN structure with the ATM-based token merge, we improve the performance from 46.7% to 50.6%. Regarding the loss type, the pixel-level loss refers to the conventional cross-entropy loss applied to the feature map. The dot product loss corresponds to the loss utilized in Carion et al. (2020) and Cheng et al. (2021). Attention mask loss indicates the direct application of mask supervision to the similarity map generated by the ATM during attention calculation. Incorporating loss supervision on the attention mask, as in our method, leads to a performance improvement of 0.6%.

Table 9 Ablation of the QD module in terms of the targets and methods to down-sample
Table 10 Comparisons for various ViT pre-training schedules on the validation set of ADE20K

4.4.6 Ablation for the QD Module

The motivation behind using QD is to leverage the pre-trained weights of the backbone. As shown in Table 9, using a stride-2 convolution with learnable parameters to downsample the query will disturb the pre-trained weights, leading to a notable decline in performance. Applying down-sampling to both the query and the key-value pairs would inevitably lead to information loss during the down-sampling process, which is evident in the lower performance. Our results show that applying \(2 \times 2\) nearest down-sampling exclusively to the query in the QD module yields better results. This approach allows us to preserve the pre-trained weights of the backbone while achieving the desired down-sampling effect.

4.5 Application 1: A Better Indicator for Feature Representation Learning

4.5.1 Background

Semantic segmentation serves as a fundamental vision task that has been extensively employed in previous research to assess the representation learning capabilities of weakly, fully, and self-supervised base models (He et al., 2022; Chen et al., 2022; Touvron et al., 2022; Peng et al., 2022). In prior work, the UPerNet decoder structure has been commonly used for this purpose. However, UPerNet may not be a suitable indicator for evaluating the feature representation ability of the base model, primarily because of its heavier computational requirements and slower convergence. Additionally, the feature representations acquired by the base model can vary substantially depending on the training strategies used during fine-tuning on semantic segmentation datasets. Consequently, the task of semantic segmentation may not adequately evaluate the feature representation ability of pre-trained models.

4.5.2 Experiment Settings

In this section, we extensively evaluate our proposed SegViT across diverse weakly, fully, and self-supervised vision transformers, including those proposed by He et al. (2022), Chen et al. (2022), and Touvron et al. (2022), as well as the BEiT model (Peng et al., 2022). We demonstrate that our method outperforms UPerNet (Xiao et al., 2018) on both self-supervised and multi-modality base models, achieving state-of-the-art performance. Notably, our approach surpasses UPerNet while utilizing only 5% of its computational cost in terms of the decoder head. Table 10 shows that our proposed SegViT head consistently outperforms UPerNet across all base models. For ViT-Base, our method improves the performance of UPerNet on the CLIP model by 1.16% while significantly reducing the computational cost. Similar findings hold for ViT-Large base models. Furthermore, compared to UPerNet, our proposed SegViT decoder head exhibits a better alignment between the growth trend of segmentation accuracy and classification accuracy on ImageNet. This clearly demonstrates the superior efficiency of the SegViT head compared to UPerNet, making it a more suitable indicator for representation learning in base models.

4.6 Application 2: Continual Semantic Segmentation

The decoupling of class prediction and mask segmentation in our proposed SegViT decoder makes it inherently well-suited for continual learning settings. This characteristic allows us to learn new classes by fine-tuning only the class proxies (the class tokens), leveraging the powerful representation ability of the plain vision transformer while keeping the old parameters frozen. To validate the effectiveness of this approach to continual learning, we conduct experiments following the standard settings adopted by prior studies.
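A hedged sketch of this recipe is given below: every previously learned parameter, including the backbone, decoder, and old class tokens, is frozen, and only the class tokens added for the new classes are optimized. The module and attribute names are placeholders for illustration, not the actual SegViT code.

```python
import torch
import torch.nn as nn


class TinySegViT(nn.Module):
    """Stand-in for the real model: a placeholder encoder plus class tokens."""
    def __init__(self, dim=256, num_classes=100):
        super().__init__()
        self.backbone = nn.Linear(dim, dim)            # placeholder for the ViT encoder
        self.class_tokens = nn.Parameter(torch.randn(num_classes, dim) * 0.02)

    def forward(self, feats):                           # feats: (B, N, C)
        k = self.backbone(feats)
        return self.class_tokens @ k.transpose(1, 2)    # (B, K, N) similarity masks


def continual_step(model, num_new, dim=256):
    # Freeze every previously learned parameter: backbone, decoder, old tokens.
    for p in model.parameters():
        p.requires_grad_(False)
    # Add fresh class tokens (class proxies) for the new classes; only these
    # are trainable at this continual-learning step.
    new_tokens = nn.Parameter(torch.randn(num_new, dim) * 0.02)
    model.new_class_tokens = new_tokens
    return [new_tokens]


model = TinySegViT()
trainable = continual_step(model, num_new=10)
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
```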

4.6.1 Experiment Settings

Continual Semantic Segmentation (CSS) has two settings (Cermelli et al., 2020; Douillard et al., 2021): disjoint and overlapped. In the disjoint setup, all pixels in the images at each step belong either to the previous classes or to the current classes. In the overlapped setting, the dataset at each step contains all images that have pixels of at least one current class, and all pixels from previous and future tasks are labeled as background. The overlapped setting is more realistic and challenging, so we evaluate performance under the overlapped setup on the ADE20K dataset.
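For clarity, the snippet below illustrates the overlapped-setting labeling rule described above: at a given step, only pixels of the current classes keep their labels, while pixels of past and future classes are mapped to background. The class-id bookkeeping is an assumption for illustration.

```python
import torch


def remap_labels_overlapped(mask, current_classes, background=0):
    """Keep only the labels of the classes introduced at the current step;
    pixels of past and future classes are mapped to background."""
    remapped = torch.full_like(mask, background)
    for c in current_classes:
        remapped[mask == c] = c
    return remapped


# Example: a step introducing classes 101-110; everything else -> background.
mask = torch.randint(0, 151, (512, 512))
step_mask = remap_labels_overlapped(mask, set(range(101, 111)))
```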

Following prior studies (Phan et al., 2022; Cermelli et al., 2020; Douillard et al., 2021), we perform three experiments: adding 50 classes after training on 100 classes (the 100–50 setting, 2 steps); adding 50 classes at a time after training on 50 classes (the 50–50 setting, 3 steps); and adding 10 classes at a time after training on 100 classes (the 100–10 setting, 6 steps).

4.6.2 Baselines

We conducted a comprehensive comparison of our proposed method against state-of-the-art Continual Semantic Segmentation (CSS) techniques, including RCIL (Zhang et al., 2022), PLOP (Douillard et al., 2021), REMINDER (Phan et al., 2022), SDR (Michieli & Zanuttigh, 2021), and MiB (Cermelli et al., 2020). To ensure fair comparisons, existing methods were evaluated using DeepLabV3 (Chen et al., 2017) with ResNet101 and ViT-Base backbones that were pre-trained on ImageNet-21k. The reported results for PLOP, RCIL, and REMINDER were obtained based on the codebases provided by the respective authors. Furthermore, we included the performance of the Oracle model, which represents the upper bound achieved by jointly training on all available data, serving as a benchmark for each method.

4.6.3 Metrics

We evaluate model performance with five mIoU metrics. First, we compute the mIoU of the base classes \(\mathcal {C}^0\), which reflects model rigidity: the model’s resilience to catastrophic forgetting. Second, we compute the mIoU of all incremented classes \(\mathcal {C}^{1:T}\), which measures plasticity: the model’s capacity for learning new tasks. Third, we compute the mIoU of all classes in \(\mathcal {C}^{0:T}\) (all), which shows the overall performance of the model. Fourth, we report the average mIoU (avg) measured step after step as proposed by Douillard et al. (2021), which evaluates performance over the entire continual learning process. Fifth, to ensure fair comparisons, we evaluate the relative performance of each CSS method in terms of its mIoU reduction compared with its Oracle model, which is jointly trained on all data.
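In our notation, these metrics can be summarized as follows; the step-averaged metric and the Oracle-relative reduction are written as we interpret them from the description above, with the reduction taken as the difference in mIoU from the jointly trained Oracle:
\[
\mathrm{mIoU}_{\mathrm{base}} = \mathrm{mIoU}\big(\mathcal {C}^{0}\big), \qquad
\mathrm{mIoU}_{\mathrm{new}} = \mathrm{mIoU}\big(\mathcal {C}^{1:T}\big), \qquad
\mathrm{mIoU}_{\mathrm{all}} = \mathrm{mIoU}\big(\mathcal {C}^{0:T}\big),
\]
\[
\mathrm{mIoU}_{\mathrm{avg}} = \frac{1}{T+1}\sum_{s=0}^{T} \mathrm{mIoU}_{s}\big(\mathcal {C}^{0:s}\big), \qquad
\Delta_{\mathrm{Oracle}} = \mathrm{mIoU}^{\mathrm{Oracle}}_{\mathrm{all}} - \mathrm{mIoU}_{\mathrm{all}},
\]
where \(\mathrm{mIoU}_{s}\) denotes evaluation after step \(s\), \(T\) is the final step, and unsubscripted quantities are evaluated after the final step.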

Table 11 CSS results on ADE20K in mIoU (%) on the 100–50 and 100–10 settings
Table 12 Performance drop (degree of forgetting) of all classes grouped by tasks on the 100–10 setting

5 Results and Discussion

Table 11 shows the results of different CSS methods on ADE20K. Our SegViT-CL consistently outperforms existing methods on all mIoU metrics in both settings. In terms of mIoU reduction, the proposed SegViT-CL decreases the mIoU of the Oracle model by only \(2.2\%\) on the 100–50 setting, roughly half the reduction of the second-best method, RCIL with a ResNet backbone (\(4.6\%\)). This substantial improvement over existing methods underlines the effectiveness of our proposed method in the continual semantic segmentation paradigm. On the longer 100–10 setting with 6 tasks, ours is almost forgetting-free with a marginal mIoU reduction of \(0.3\%\), while recent CSS methods suffer significantly from forgetting, with at least a \(5.4\%\) mIoU reduction. Even with the ViT backbone, existing methods including MiB, REMINDER, and PLOP still suffer from high mIoU reductions. Compared with the Oracle, MiB (Cermelli et al., 2020), PLOP (Douillard et al., 2021), and REMINDER (Phan et al., 2022) decrease the mIoU by 8.6%, 6.5%, and 5.6%, respectively, on the 100–10 setting, demonstrating the sub-optimal performance of current CSS methods on the ViT architecture. This highlights the need for specialized designs for the ViT architecture that are robust to forgetting.

To evaluate the forgetting of each task in the 100–10 setting, we compute the performance drop at the last step relative to the initial mIoU obtained when the model first learns the task. For example, the initial mIoU of task 2 is the mIoU of classes 101–110 evaluated at step 2; similarly, that of task 3 is the mIoU of classes 111–120 reported at step 3. Table 12 shows the performance drop at the last step compared with the initial mIoU of each task. Averaged across the 5 tasks, the mIoU drops by only 0.45%, which shows that SegViT is robust to forgetting across all tasks in the 100–10 setting. Fig. 8 shows the mIoU on the base classes after incrementally training on many tasks in the 100–5 setting, a long continual learning protocol with 11 tasks. Overall, our SegViT achieves nearly zero forgetting for almost all tasks at the last step. In contrast to previous CSS methods, which require partial fine-tuning, the proposed SegViT supports completely freezing old parameters, effectively eliminating interference with previously acquired knowledge.
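In the same notation as above, the per-task drop reported in Table 12 can be written as
\[
D_t = \mathrm{mIoU}_{t}\big(\mathcal {C}^{t}\big) - \mathrm{mIoU}_{T}\big(\mathcal {C}^{t}\big),
\]
where \(\mathcal {C}^{t}\) denotes the set of classes introduced at task \(t\), \(\mathrm{mIoU}_{s}\) denotes evaluation after step \(s\), and \(T\) is the final step; \(D_t \approx 0\) therefore indicates negligible forgetting of task \(t\).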

Fig. 8
figure 8

mIoU of recent CSS methods on the first 100 base classes after incrementally learning new tasks on 100–5 settings with 11 tasks

6 Conclusion

This paper presents SegViTv2, a novel approach to semantic segmentation using plain ViT base models. The proposed method introduces a lightweight decoder head that incorporates the Attention-to-Mask (ATM) module. Additionally, a Shrunk++ structure is proposed to reduce the computational cost of the ViT encoder by 50% while maintaining competitive segmentation accuracy. Moreover, this work extends the SegViT framework to address the challenge of continual semantic segmentation, aiming for nearly zero forgetting. By keeping the parameters of old tasks frozen, SegViT effectively mitigates catastrophic forgetting. Extensive experimental evaluations conducted on various benchmarks demonstrate the superiority of SegViT over UPerNet while significantly reducing computational costs. The introduced decoder head provides a robust and cost-effective avenue for future research in ViT-based semantic segmentation.