Emergent Mind

FastSAM3D: An Efficient Segment Anything Model for 3D Volumetric Medical Images

(2403.09827)
Published Mar 14, 2024 in eess.IV and cs.CV

Abstract

Segment anything models (SAMs) are gaining attention for their zero-shot generalization capability in segmenting objects of unseen classes and in unseen domains when properly prompted. Interactivity is a key strength of SAMs, allowing users to iteratively provide prompts that specify objects of interest to refine outputs. However, to realize the interactive use of SAMs for 3D medical imaging tasks, rapid inference times are necessary. High memory requirements and long processing delays remain constraints that hinder the adoption of SAMs for this purpose. Specifically, while 2D SAMs applied to 3D volumes contend with repetitive computation to process all slices independently, 3D SAMs suffer from an exponential increase in model parameters and FLOPs. To address these challenges, we present FastSAM3D which accelerates SAM inference to 8 milliseconds per 128×128×128 3D volumetric image on an NVIDIA A100 GPU. This speedup is accomplished through 1) a novel layer-wise progressive distillation scheme that enables knowledge transfer from a complex 12-layer ViT-B to a lightweight 6-layer ViT-Tiny variant encoder without training from scratch; and 2) a novel 3D sparse flash attention to replace vanilla attention operators, substantially reducing memory needs and improving parallelization. Experiments on three diverse datasets reveal that FastSAM3D achieves a remarkable speedup of 527.38× compared to 2D SAMs and 8.75× compared to 3D SAMs on the same volumes without significant performance decline. Thus, FastSAM3D opens the door for low-cost truly interactive SAM-based 3D medical imaging segmentation with commonly used GPU hardware. Code is available at https://github.com/arcadelab/FastSAM3D.

FastSAM3D framework: a distilled 6-layer ViT-Tiny image encoder with 3D sparse attention, a lightweight prompt encoder, and a mask decoder.

Overview

  • Introduces FastSAM3D, an efficient 3D Segment Anything Model for volumetric medical image segmentation with rapid inference capabilities.

  • Proposes layer-wise progressive distillation and 3D sparse flash attention to enhance efficiency without sacrificing performance.

  • Validated on diverse datasets, demonstrating significant acceleration and reduced computational requirements while maintaining competitive performance.

  • Highlights implications for clinical applications requiring real-time interaction and future research directions, including mixed reality integration.

Efficient 3D Segmentation in Medical Imaging with FastSAM3D

Introduction

Recent advancements in Segment Anything Models (SAMs) hold significant potential for the field of medical image segmentation. Traditional deep learning models, despite their efficacy in specific tasks, often fall short in terms of generalizability. SAMs, with their zero-shot learning capabilities through prompt-based interactions, offer a promising solution. However, their application in 3D medical imaging has been hampered by extensive computational requirements. Addressing these limitations, this paper introduces FastSAM3D, a highly efficient 3D SAM designed for rapid inference in volumetric medical images, achieving substantial acceleration without compromising performance.

FastSAM3D Architecture and Innovations

FastSAM3D proposes two major innovations to enhance the efficiency of SAMs for 3D medical image segmentation:

  1. Layer-wise Progressive Distillation: This novel approach accelerates the image encoder by distilling a complex 12-layer Vision Transformer (ViT-B) to a more efficient 6-layer ViT-Tiny variant. Unlike traditional distillation methods, this layer-wise approach aligns the intermediate representations of the student and teacher models progressively, facilitating a smoother and more effective knowledge transfer.
  2. 3D Sparse Flash Attention: To further reduce computational requirements, FastSAM3D replaces the conventional self-attention mechanism with a combination of 3D sparse attention and flash attention. This hybrid approach allows for processing volumetric data in segments, leveraging the efficiency of flash attention for parallel processing, thus significantly reducing both the time and memory footprint of SAM operations.
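To make the layer-wise progressive distillation concrete, the toy sketch below pairs each student block with every second teacher block and penalizes the mismatch between their intermediate features. The 2:1 layer pairing, the MSE objective, and the linear+GELU stand-ins for ViT blocks are illustrative assumptions; the paper's core idea is simply that intermediate representations, not only final outputs, are progressively aligned.

```python
import numpy as np

rng = np.random.default_rng(0)
DIM = 32  # illustrative feature width; real ViT-B/ViT-Tiny widths differ

def make_blocks(n_layers, rng):
    # Each "block" is a toy linear weight standing in for a full ViT block.
    return [rng.standard_normal((DIM, DIM)) * 0.1 for _ in range(n_layers)]

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def forward_feats(blocks, x):
    # Run x through every block, keeping each intermediate representation.
    feats, h = [], x
    for w in blocks:
        h = gelu(h @ w)
        feats.append(h)
    return feats

teacher = make_blocks(12, rng)  # stands in for the frozen 12-layer ViT-B
student = make_blocks(6, rng)   # stands in for the 6-layer ViT-Tiny variant

def layerwise_distill_loss(x):
    # Pair student layer i with teacher layer 2i+1 (an assumed mapping),
    # so the student's 6 features track evenly spaced teacher features.
    t_feats = forward_feats(teacher, x)
    s_feats = forward_feats(student, x)
    return sum(np.mean((s - t_feats[2 * i + 1]) ** 2)
               for i, s in enumerate(s_feats))

x = rng.standard_normal((4, DIM))  # a batch of 4 toy token embeddings
loss = layerwise_distill_loss(x)
```

In training, this loss would be minimized with gradient descent on the student weights while the teacher stays frozen, which is what lets the compact encoder avoid training from scratch.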
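The sparse half of the 3D sparse flash attention can be sketched as block-local attention: flattened volume tokens are split into fixed-size segments, and each segment attends only to itself, cutting the quadratic cost of global attention. The block size and this particular sparsity pattern are illustrative assumptions; the actual model additionally fuses the computation with flash-attention kernels for memory efficiency, which this numpy sketch does not model.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sparse_block_attention(q, k, v, block=64):
    """Attend only within fixed-size token blocks instead of globally.

    For n tokens of dimension d, cost drops from O(n^2 * d) to
    O(n * block * d). With block >= n this reduces to full attention.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty_like(q)
    for start in range(0, n, block):
        sl = slice(start, start + block)           # one token segment
        attn = softmax(q[sl] @ k[sl].T * scale)    # segment-local weights
        out[sl] = attn @ v[sl]                     # segment-local values
    return out
```

As a sanity check, calling this with a block size at least as large as the token count reproduces ordinary full self-attention, which is why the approximation degrades gracefully as blocks grow.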

Empirical Evaluation

The efficacy of FastSAM3D was validated across three diverse datasets spanning CT and MRI modalities, involving tasks such as abdominal organ segmentation and lesion localization. Experimental results demonstrate that:

  • Performance: FastSAM3D maintains competitive performance metrics, closely matching those of its 3D SAM predecessors, while substantially outperforming 2D SAM implementations, especially as the number of prompts increases.
  • Efficiency Gains: With a remarkable 527.38× acceleration compared to 2D SAMs and an 8.75× speed-up over existing 3D SAMs, the model showcases unprecedented efficiency. Specifically, FastSAM3D reduced the inference time to 3 ms for the encoder and 5 ms for the decoder per 128×128×128 3D volumetric image on an NVIDIA A100 GPU.
  • Resource Optimization: The introduction of 3D sparse flash attention significantly decreases the model's memory requirements, enabling effective segmentation with commonly available GPU hardware.

Implications and Future Directions

The FastSAM3D model illustrates a significant leap towards the practical deployment of SAMs in medical imaging, especially in tasks requiring real-time interaction, such as surgical planning and intraoperative guidance. Its ability to perform efficient 3D segmentation without sacrificing accuracy paves the way for new clinical applications that leverage AI for enhanced decision-making in diagnostics and treatment planning.

Further research may explore the integration of FastSAM3D within mixed reality applications for an immersive surgical planning experience. Additionally, investigating the model's adaptability to other volumetric imaging tasks beyond the medical domain could further broaden the applicability of SAMs in addressing complex 3D segmentation challenges.

Conclusion

FastSAM3D overcomes the computational barriers associated with SAMs in 3D medical imaging, enabling efficient, real-time interactive segmentation. By marrying the concepts of distilled model architectures and optimized attention mechanisms, it sets a new standard for prompt-based segmentation models, promising substantial benefits for clinical applications and research in medical image analysis.
