Class-Discriminative Attention Maps for Vision Transformers

(2312.02364)
Published Dec 4, 2023 in cs.CV, cs.AI, cs.LG, and stat.ML

Abstract

Interpretability methods are critical components for examining and exploring deep neural networks (DNN), as well as increasing our understanding of and trust in them. Vision transformers (ViT), which can be trained to state-of-the-art performance with a self-supervised learning (SSL) training method, provide built-in attention maps (AM). While AMs can provide high-quality semantic segmentation of input images, they do not account for any signal coming from a downstream classifier. We introduce class-discriminative attention maps (CDAM), a novel post-hoc explanation method that is highly sensitive to the target class. Our method essentially scales attention scores by how relevant the corresponding tokens are for the predictions of a classifier head. As an alternative to classifier outputs, CDAM can also explain a user-defined concept by targeting similarity measures in the latent space of the ViT. This allows for explanations of arbitrary concepts, defined by the user through a few sample images. We investigate the operating characteristics of CDAM in comparison with relevance propagation (RP) and token ablation maps (TAM), an alternative to pixel occlusion methods. CDAM is highly class-discriminative and semantically relevant, while providing implicit regularization of relevance scores. PyTorch implementation: https://github.com/lenbrocki/CDAM. Web live demo: https://cdam.informatism.com/

Overview

  • Introduces a new interpretability method for Vision Transformers called class-discriminative attention maps (CDAM), enhancing understanding of AI decisions.

  • CDAM refines attention maps by incorporating class-specific or concept-based relevance, yielding more intuitive explanations of model decisions.

  • Through gradients, CDAM separates targeted objects from the background and distinguishes them from other classes.

  • CDAM compares favorably with alternative methods such as relevance propagation and token ablation maps, demonstrating superior class discrimination and semantic consistency.

  • CDAM advances interpretability in AI, producing sparser, more focused visualizations with clearer separations between classes that align with human intuition about model decisions.

Interpretability in AI is an essential area that helps us understand, trust, and improve machine learning models, particularly deep neural networks (DNNs) like Vision Transformers (ViTs). Vision Transformers, which apply attention mechanisms initially designed for language processing, have shown impressive results in image recognition tasks. However, while their built-in attention maps capture image features in a way that often looks intuitive, these maps are not tied to specific output classes, which limits what humans can learn from them about a model's decisions.

To address this, a new method called class-discriminative attention maps (CDAM) has been introduced. CDAM refines the attention maps used in ViTs by incorporating class-specific signals from a downstream classifier or concept similarity measures. This can reveal which parts of an image are most relevant to the model when making decisions about certain classes or user-defined concepts, which not only enhances the interpretability of ViTs but also provides insights into how different concepts are represented within the model.

CDAM works by computing gradients of the classifier's output with respect to the token representations in the final layer of the transformer, before they are passed to the classifier head. This approach benefits from the high-quality object segmentation already present in attention maps, while adding information about class relevance. For instance, in addition to revealing evidence for a particular class, the method can also show counter-evidence. It is class-discriminative in the sense that it clearly separates targeted objects both from the background and from objects that belong to other classes.
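To make this concrete, the following is a minimal PyTorch sketch of that gradient step, not the authors' implementation (which is linked in the abstract). It assumes a timm ViT backbone, captures the output of the last transformer block with a forward hook, and scores each patch token by gradient times activation; the additional scaling by the class token's attention scores mentioned in the abstract is omitted here for brevity.

```python
import torch
import timm

# Backbone used for illustration; any ViT with a hookable final block works.
model = timm.create_model("vit_small_patch16_224", pretrained=True).eval()

captured = {}

def save_tokens(module, inputs, output):
    output.retain_grad()           # keep gradients w.r.t. the final-layer tokens
    captured["tokens"] = output    # shape: (batch, 1 + num_patches, dim)

hook = model.blocks[-1].register_forward_hook(save_tokens)

x = torch.randn(1, 3, 224, 224)    # stand-in for a preprocessed input image
logits = model(x)
target = int(logits.argmax(dim=-1))

logits[0, target].backward()       # gradient of the chosen class logit

tokens = captured["tokens"][0, 1:]          # patch tokens (drop the CLS token)
grads = captured["tokens"].grad[0, 1:]      # their gradients
relevance = (tokens * grads).sum(dim=-1)    # signed per-token class relevance
cdam_map = relevance.reshape(14, 14)        # 14x14 grid of 16x16 patches at 224px
hook.remove()
```

Because the relevance is signed, positive entries correspond to evidence for the target class and negative entries to counter-evidence, which is what makes the resulting maps class-discriminative.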

Moreover, CDAM can explain broader concepts defined by a handful of example images. This concept-based approach does not rely on classifier outputs. Instead, it targets a similarity measure between the latent representation of the input image and a 'concept vector' built from the examples, allowing the model to be probed for concepts it has not been explicitly trained to recognize.
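Continuing the sketch above, the concept-based variant can be illustrated by replacing the class logit with a similarity target. The choices below, taking the CLS token returned by timm's forward_features as the latent representation and a plain dot product as the similarity measure, are assumptions for illustration rather than the paper's exact setup.

```python
def latent_cls(model, images):
    # In recent timm versions forward_features returns the full token sequence;
    # the CLS token is assumed to sit at index 0.
    return model.forward_features(images)[:, 0]

concept_images = torch.randn(5, 3, 224, 224)    # a few user-provided example images
with torch.no_grad():
    concept_vector = latent_cls(model, concept_images).mean(dim=0)

hook = model.blocks[-1].register_forward_hook(save_tokens)   # reuse the hook above
x = torch.randn(1, 3, 224, 224)
similarity = latent_cls(model, x)[0] @ concept_vector        # similarity target
similarity.backward()   # gradients w.r.t. final-layer tokens; relevance as before
hook.remove()
```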

In comparison with other methods, such as relevance propagation (RP) and token ablation maps (TAM), CDAM shows strong semantic consistency and class-discrimination. RP is more class-discriminative than regular attention maps (AM), but CDAM and TAM provide clearer distinctions, with CDAM additionally showing implicit regularization and less noise. While RP and TAM are used for comparison, they do not serve as the absolute ground truth for feature relevance, since they provide different perspectives on the decision-making process.
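For reference, a token ablation map in this spirit can be sketched as below, continuing with the model, x, and target from the first snippet: zero out one patch token at a time before the transformer blocks and record the drop in the target class logit. The manual forward pass follows a common timm ViT layout (class token prepended, positional embeddings added, then blocks, norm, and head), and zeroing the token embedding is an assumed ablation strategy rather than necessarily the paper's exact procedure.

```python
@torch.no_grad()
def token_ablation_map(model, x, target_class):
    def head_from_patches(patch_tokens):
        # Re-runs the ViT from patch embeddings onward (assumed timm layout).
        cls = model.cls_token.expand(patch_tokens.shape[0], -1, -1)
        tokens = torch.cat([cls, patch_tokens], dim=1) + model.pos_embed
        tokens = model.norm(model.blocks(tokens))
        return model.head(tokens[:, 0])            # logits from the CLS token

    patches = model.patch_embed(x)                 # (1, num_patches, dim)
    baseline = head_from_patches(patches)[0, target_class]
    scores = torch.zeros(patches.shape[1])
    for i in range(patches.shape[1]):
        ablated = patches.clone()
        ablated[:, i] = 0                          # ablate a single token
        scores[i] = baseline - head_from_patches(ablated)[0, target_class]
    return scores.reshape(14, 14)                  # larger drop = more important

tam = token_ablation_map(model, x, target)
```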

The method is demonstrated to provide class-discriminative visualizations that align well with human intuition about which image regions matter for a given class, as shown by both qualitative visualizations and quantitative correlation assessments. CDAM stands out by not only producing sparser and more focused maps than AM but also showing clearer separations between targeted and non-targeted classes than RP.

In conclusion, CDAM represents a significant step forward in the interpretability of Vision Transformers. It can explain both classifier-based predictions and user-defined concepts, offering a versatile tool for understanding the complex representations within self-supervised ViTs. This method enhances the transparency and trust in AI-powered image recognition and offers a promising approach for examining and refining the sophisticated decision-making processes in advanced AI models.
