JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented Fine-Tuning

(arXiv:2403.11366)
Published Mar 17, 2024 in cs.LG, cs.CL, and cs.DC

Abstract

The scaling of LLMs for retrieval-based tasks, particularly in Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning extensive prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs but fall short of accommodating the efficient parameter distribution required for retrieved context. Addressing this gap, we introduce a novel framework for PEFT-compatible fine-tuning of Llama-2 models, leveraging distributed training. Our framework uniquely utilizes JAX's just-in-time (JIT) compilation and tensor-sharding for efficient resource management, thereby enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. Our experiments show more than 12x improvement in runtime compared to a Hugging Face/DeepSpeed implementation with four GPUs while consuming less than half the VRAM per GPU.

The JORA library simplifies fine-tuning for retrieval-augmented tasks through efficient memory use and tensor-parallelism.

Overview

  • JORA, a JAX-based library, enhances the fine-tuning of Llama-2 models with tensor-parallelism and Low-Rank Adaptation (LoRA) for improved memory efficiency and computational performance.

  • It utilizes just-in-time (JIT) compilation and tensor-sharding to expedite fine-tuning while reducing GPU memory demands, particularly effective for Retrieval Augmented Generation (RAG) tasks.

  • Experimental results demonstrate JORA's superior memory utilization and computational performance over an existing Hugging Face/DeepSpeed implementation, including a more than 12x runtime improvement in multi-GPU setups.

  • A practical application case study on social media content analysis showcases JORA's ability to handle large sequence lengths and complex retrieval tasks, confirming its utility in real-world scenarios.

Introduction

Recent advancements in language model fine-tuning have spotlighted the efficiency and scalability issues encountered during the process, especially for Retrieval Augmented Generation (RAG) tasks. The paper introduces JORA, a JAX-based library designed to address these challenges. It facilitates the fine-tuning of Llama-2 models, leveraging tensor-parallelism and Low-Rank Adaptation (LoRA) for enhanced memory efficiency and computational performance. JORA's innovative use of JAX's just-in-time (JIT) compilation and tensor-sharding techniques allows for accelerated fine-tuning while significantly reducing GPU memory requirements.
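To ground the LoRA component: low-rank adaptation freezes the pretrained weight and trains only a small low-rank update, which is what keeps the trainable-parameter count and optimizer state small. Below is a minimal JAX sketch of the idea itself; the shapes, scaling, and initialization follow the original LoRA recipe (Hu et al., 2021), but none of this is JORA's actual code:

```python
import jax
import jax.numpy as jnp

# LoRA in one function: the frozen weight W is augmented by a trainable
# low-rank product B @ A, so only r * (d + k) parameters are updated
# instead of d * k. All shapes here are illustrative assumptions.
d, k, r, alpha = 4096, 4096, 8, 16.0
kW, kA = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(kW, (d, k)) * 0.02  # frozen base weight
A = jax.random.normal(kA, (r, k)) * 0.01  # trainable, Gaussian-init (as in Hu et al.)
B = jnp.zeros((d, r))                     # trainable, zero-init so fine-tuning starts from the base model

def lora_forward(x, W, A, B):
    # Same as x @ (W + (alpha / r) * B @ A), but never materializes the
    # dense d-by-k delta.
    return x @ W + (alpha / r) * ((x @ B) @ A)

y = lora_forward(jnp.ones((2, d)), W, A, B)  # shape (2, k)
```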

Background and Motivation

  • Retrieval Augmented Generation (RAG): RAG techniques integrate retrieved external knowledge into a model's prompt, grounding its output in relevant context. While effective, this approach carries sizable memory and computational costs, since retrieved passages can make prompt sequences very long (a toy prompt-construction sketch follows this list).
  • Existing Training Libraries: Libraries such as Hugging Face Transformers and DeepSpeed support distributed full-model training but fall short of parameter-efficient tuning, particularly in tensor-parallel contexts. JORA emerges as a solution targeting these specific gaps.
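To make the memory pressure concrete, here is a toy, library-agnostic sketch of how retrieved passages are folded into a prompt; the function name and format are illustrative assumptions, not part of JORA:

```python
def build_rag_prompt(query: str, retrieved: list[str]) -> str:
    # Prepend each retrieved passage to the user query. With many or long
    # passages, the sequence length -- and hence activation memory during
    # fine-tuning -- grows quickly.
    context = "\n\n".join(f"[{i + 1}] {p}" for i, p in enumerate(retrieved))
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

prompt = build_rag_prompt(
    "What does tensor-parallel training buy us?",
    ["Tensor parallelism splits weight matrices across GPUs.",
     "Each device then stores only a slice of the parameters."],
)
```

Every retrieved passage lengthens the sequence, and activation memory grows with sequence length, which is exactly the bottleneck JORA's sharding is meant to relieve.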

JORA Framework

JORA employs JAX for JIT compilation, optimizing training performance for Llama-2 models. By integrating LoRA into the training process, JORA allows for the efficient fine-tuning of models on retrieval-based tasks. This section discusses the technical foundation of JORA, emphasizing its:

  • Tensor-parallel Training: Distributes the training workload across multiple GPUs so that each device holds only a slice of the parameters, reducing the per-GPU memory footprint (a minimal JAX sharding sketch follows this list).
  • Dataset and Training API: Provides helper functions for loading training data and simplifying the fine-tuning process, including a custom data format and predefined dataset-loading utilities.
  • Model Transfer API: Facilitates the conversion of JORA-trained models into the Hugging Face model format, ensuring compatibility with a wide range of downstream applications (a generic conversion sketch appears after the sharding example).
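The tensor-parallel piece can be pictured with stock JAX sharding primitives. The following is a generic illustration of the approach (a device mesh, NamedSharding, and jit), not JORA's API; the axis name and shapes are assumptions:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Lay all available devices out along a single "model" axis.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)),
            axis_names=("model",))

# Shard a weight matrix column-wise across the mesh: each device holds
# only a 1/N slice of the parameters, which is what shrinks per-GPU VRAM.
w = jax.device_put(jnp.zeros((4096, 4096)), NamedSharding(mesh, P(None, "model")))

@jax.jit  # XLA compiles the computation once and inserts any needed collectives
def forward(x, w):
    return x @ w  # x is replicated; the output is sharded like w's columns

y = forward(jnp.ones((8, 4096)), w)
```

Under jit, XLA propagates the sharding through the matmul, so each device multiplies against only its own parameter slice.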
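As for the model-transfer step, the underlying idea is a parameter-format conversion from JAX arrays to a PyTorch state dict that Hugging Face models can load. Here is a minimal sketch of that idea, with hypothetical parameter names and no claim of matching JORA's real conversion code:

```python
import numpy as np
import torch

def to_torch_state_dict(jax_params):
    # jax_params: a flat {hf_parameter_name: jax.Array} mapping (an assumption).
    # torch.tensor(...) copies, avoiding issues with non-writable buffers.
    return {name: torch.tensor(np.asarray(arr)) for name, arr in jax_params.items()}

# Hypothetical usage with a Hugging Face Llama-2 checkpoint:
# from transformers import LlamaForCausalLM
# model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# model.load_state_dict(to_torch_state_dict(trained_params), strict=False)
# model.save_pretrained("./jora-export")
```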

Experimental Results

The paper presents experimental evidence demonstrating JORA's advantage over a Hugging Face/DeepSpeed baseline. Specifically, it highlights:

  • Memory Utilization: JORA significantly outperforms the baseline in memory efficiency, consuming less than half the VRAM per GPU in the four-GPU setting.
  • Computational Performance: The experiments show JORA achieving a more than 12x runtime improvement over the baseline across the tested GPU configurations. This gain is attributed to JORA's optimized use of JAX's JIT compilation and its tailored tensor-parallelism.

Practical Application

A case study is presented where JORA is applied in fine-tuning models for social media content analysis. The study outlines how JORA aids in understanding the structural relationships within social media posts, underscoring the library's practical utility in handling large sequence lengths and complex retrieval tasks. The results from this application scenario further attest to JORA's effectiveness in improving the model's performance on real-world tasks.

Conclusion

JORA addresses the critical hurdles of fine-tuning LLMs for retrieval-augmented tasks, presenting a robust and efficient solution. It significantly reduces the memory footprint and computational time required for fine-tuning, making it a valuable tool for researchers and practitioners working with complex natural language processing applications. The open-source availability of JORA underscores its potential to facilitate further advancements in the field, promising enhancements in the scalability and efficiency of language model fine-tuning.

The authors' contribution with JORA not only paves the way for more efficient use of LLMs in retrieval-based applications but also sets a precedent for future research in the domain, advocating for a shift towards more resource-efficient methodologies in AI.
