
Jamba-1.5: Hybrid Transformer-Mamba Models at Scale

(arXiv:2408.12570)
Published Aug 22, 2024 in cs.CL and cs.LG

Abstract

We present Jamba-1.5, new instruction-tuned LLMs based on our Jamba architecture. Jamba is a hybrid Transformer-Mamba mixture of experts architecture, providing high throughput and low memory usage across context lengths, while retaining the same or better quality as Transformer models. We release two model sizes: Jamba-1.5-Large, with 94B active parameters, and Jamba-1.5-Mini, with 12B active parameters. Both models are fine-tuned for a variety of conversational and instruction-following capabilities, and have an effective context length of 256K tokens, the largest amongst open-weight models. To support cost-effective inference, we introduce ExpertsInt8, a novel quantization technique that allows fitting Jamba-1.5-Large on a machine with 8 80GB GPUs when processing 256K-token contexts without loss of quality. When evaluated on a battery of academic and chatbot benchmarks, Jamba-1.5 models achieve excellent results while providing high throughput and outperforming other open-weight models on long-context benchmarks. The model weights for both sizes are publicly available under the Jamba Open Model License and we release ExpertsInt8 as open source.

Figure: End-to-end latency of Jamba-1.5-Large.

Overview

  • The paper introduces Jamba-1.5, a collection of LLMs that use a hybrid architecture combining Transformer and Mamba layers with a mixture-of-experts (MoE) module, and come in two sizes: Large with 94 billion active parameters and Mini with 12 billion active parameters.

  • It discusses various innovations such as ExpertsInt8, a quantization technique for efficient deployment, and details the training methodology including stages like pre-training, mid-training, and post-training to enhance conversational skills and long-context capabilities.

  • Evaluation metrics show Jamba-1.5 models performing comparably to or better than state-of-the-art models on academic benchmarks and excelling in long-context evaluations and multilingual capabilities, with a strong focus on alignment and safety according to OECD AI principles.

Essay: Overview of Jamba-1.5: Hybrid Transformer-Mamba Models at Scale

The paper presents Jamba-1.5, a collection of LLMs that leverage a hybrid architecture combining Transformer and Mamba layers with a mixture-of-experts (MoE) module. Two model sizes are introduced: Jamba-1.5-Large with 94 billion active parameters and Jamba-1.5-Mini with 12 billion active parameters. Notably, both models offer an effective context length of 256,000 tokens, surpassing other open-weight models. This essay provides a comprehensive analysis of the Jamba-1.5 models, focusing on their architecture, serving considerations, training methodology, evaluation metrics, and alignment approaches.

Model Architecture

The Jamba-1.5 models are built on a hybrid architecture that interleaves Transformer attention layers, Mamba layers, and a MoE module. Jamba-1.5-Large, with 94 billion active parameters, stacks 9 blocks of 8 layers each, using a 1:7 ratio of attention to Mamba layers within each block. An MoE module replaces the dense MLP every two layers and contains 16 experts, of which the top 2 are selected for each token. This design improves efficiency and shrinks the key-value (KV) cache memory footprint by roughly an order of magnitude compared to similarly sized Transformer models.
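To make the layer layout concrete, the following is a minimal illustrative sketch of the stacking pattern described above. It is not the authors' implementation: only the 1:7 attention-to-Mamba ratio, the MoE-every-two-layers rule, and the 16-expert/top-2 routing come from the paper; the exact position of the attention layer within each block and all names are assumptions for illustration.

```python
# Illustrative sketch of the Jamba layer layout (not the authors' code).
def jamba_layer_schedule(num_blocks: int = 9, layers_per_block: int = 8):
    """Return a list of (mixer, ffn) descriptors for the stacked layers."""
    schedule = []
    for block in range(num_blocks):
        for layer in range(layers_per_block):
            # 1 attention layer per 8-layer block, the other 7 are Mamba layers
            # (placement of the attention layer within the block is assumed here).
            mixer = "attention" if layer == layers_per_block - 1 else "mamba"
            # An MoE FFN (16 experts, top-2 routing) replaces the dense MLP every second layer.
            global_idx = block * layers_per_block + layer
            ffn = "moe(16 experts, top-2)" if global_idx % 2 == 1 else "dense_mlp"
            schedule.append((mixer, ffn))
    return schedule

if __name__ == "__main__":
    for i, (mixer, ffn) in enumerate(jamba_layer_schedule()[:8]):
        print(f"layer {i:2d}: {mixer:9s} + {ffn}")
```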

Serving Considerations and Improvements

To enable cost-effective inference, the paper introduces ExpertsInt8, a quantization technique that stores MoE and MLP weights in INT8 and dequantizes them to BF16 at compute time, while activations remain in BF16. The method adds negligible overhead and improves latency on both A100 and H100 GPUs. Because the MoE and MLP weights account for over 85% of the model's parameters, this allows Jamba-1.5-Large to be served on a single machine with eight 80GB GPUs at contexts up to 256K tokens. In addition, an auxiliary activation loss applied during training penalizes large activation values, helping to avoid numerical issues when serving at reduced precision.
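The paper's ExpertsInt8 is implemented inside fused MoE kernels; the short sketch below only illustrates the underlying idea of symmetric INT8 weight storage with BF16 dequantization at compute time. The function names, the per-output-channel scaling granularity, and the tensor shapes are assumptions for illustration, not the paper's exact recipe.

```python
# Minimal sketch of INT8 weight storage with BF16 dequantization at compute time,
# in the spirit of ExpertsInt8 (illustrative only; not the fused-kernel implementation).
import torch

def quantize_int8(weight: torch.Tensor):
    """Symmetric per-output-channel INT8 quantization of a weight matrix."""
    scale = weight.abs().amax(dim=1, keepdim=True) / 127.0        # one scale per row
    q = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize_to_bf16(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover a BF16 weight just before the matmul; activations stay in BF16."""
    return (q.to(torch.float32) * scale).to(torch.bfloat16)

if __name__ == "__main__":
    w = torch.randn(1024, 4096)          # hypothetical expert MLP weight
    q, s = quantize_int8(w)              # stored in INT8: half the memory of BF16
    w_hat = dequantize_to_bf16(q, s)     # dequantized on the fly for compute
    print((w - w_hat.float()).abs().max())   # small round-trip error
```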

Training Methodology

The training process for Jamba-1.5-Large was conducted on NVIDIA H100 GPUs using an in-house framework that includes FSDP, tensor parallelism, sequence parallelism, and expert parallelism adapted from MegaBlocks. Training consisted of three stages: pre-training on a diverse dataset, including multilingual data; mid-training focused on long documents; and post-training to enhance conversational skills while retaining long-context capabilities. Post-training involved supervised fine-tuning on high-quality conversational, skill-specific, and long-context data, with heavy reliance on synthetic data generated through various pipelines.
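The authors' training framework is in-house, so it cannot be reproduced here. As a rough, publicly available analogue of one ingredient mentioned above, the sketch below shows how a model can be wrapped with PyTorch FSDP using BF16 mixed precision; the stand-in model, sharding choices, and dtypes are assumptions for illustration, and the tensor, sequence, and expert parallelism used in the paper are not shown.

```python
# Rough illustrative analogue of FSDP-style sharded training (not the authors' framework).
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

def wrap_model_for_training(model: torch.nn.Module) -> FSDP:
    bf16 = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads, optimizer state
        mixed_precision=bf16,
        device_id=torch.cuda.current_device(),
    )

if __name__ == "__main__":
    dist.init_process_group("nccl")            # one process per GPU (e.g., launched via torchrun)
    model = torch.nn.Transformer(d_model=512)  # stand-in model for illustration
    fsdp_model = wrap_model_for_training(model)
```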

Evaluation Metrics

The Jamba-1.5 models were evaluated across a range of academic benchmarks, chatbot scenarios, and long-context evaluations. In standard academic benchmarks such as MMLU, MMLU-Pro, GPQA, and HumanEval, the Jamba-1.5 models performed comparably to or better than state-of-the-art models of similar sizes. They excelled particularly in long-context evaluations on the RULER and ∞Bench benchmarks, being the only models with a confirmed effective context length of 256K tokens. Furthermore, Jamba-1.5 models demonstrated impressive performance in multilingual capabilities across languages like Spanish, Portuguese, French, and German.

Alignment and Safety Considerations

The paper emphasizes alignment and safety through a structured approach based on OECD AI principles. The alignment strategy includes defining behavioral tenets for the models, collaborating with customers to tailor ethical guidelines, and continuously monitoring and improving model behavior. In particular, the Jamba-1.5 models adhere to principles such as transparency, fairness, and robustness, validated through benchmarks like RealToxicity and TruthfulQA.

Implications and Future Directions

The Jamba-1.5 models represent a significant advancement in large language model architecture, offering high efficiency and scalability. These models not only improve throughput and latency but also extend the practical applicability of long-context language models in diverse domains such as conversational AI, document understanding, and multilingual applications. Future research could explore further enhancements in hybrid architectures and quantization techniques, as well as broader implementation scenarios in AI-driven industries.

In conclusion, the Jamba-1.5 paper articulates a sophisticated approach to modern LLMs, balancing performance, efficiency, and practical deployment considerations. By embracing a hybrid model architecture and innovative quantization techniques, Jamba-1.5 sets a new standard for open-weight models, particularly in handling long-context scenarios, thereby fostering advancements in both theoretical and applied AI research.
