Emergent Mind

Learning to Decode Collaboratively with Multiple Language Models

(arXiv:2403.03870)
Published Mar 6, 2024 in cs.CL and cs.LG

Abstract

We propose a method to teach multiple LLMs (LLMs) to collaborate by interleaving their generations at the token level. We model the decision of which LLM generates the next token as a latent variable. By optimizing the marginal likelihood of a training set under our latent variable model, the base LLM automatically learns when to generate itself and when to call on one of the "assistant" language models to generate, all without direct supervision. Token-level collaboration during decoding allows for a fusion of each model's expertise in a manner tailored to the specific task at hand. Our collaborative decoding is especially useful in cross-domain settings where a generalist base LLM learns to invoke domain expert models. On instruction-following, domain-specific QA, and reasoning tasks, we show that the performance of the joint system exceeds that of the individual models. Through qualitative analysis of the learned latent decisions, we show models trained with our method exhibit several interesting collaboration patterns, e.g., template-filling. Our code is available at https://github.com/clinicalml/co-llm.

The Co-LLM method generates answer templates and leverages a larger Llama base model for detailed answers.

Overview

  • This paper introduces a novel method, Co-LLM, for collaborative decoding among multiple language models (LMs), which learns when and how to utilize the strengths of each model in an ensemble without direct supervision.

  • The method operates under a latent-variable framework, allowing automatic decision-making for token generation by combining a base LM with one or more assistant LMs, optimizing the marginal likelihood of correct sequence generation.

  • Experiments show that Co-LLM outperforms individual models and demonstrates significant advantages in cross-domain scenarios by combining generalist and specialist models for tasks like question-answering and reasoning.

  • The approach suggests a path forward for more versatile and capable LMs by learning collaboration patterns directly from data, and highlights the importance of further research to mitigate cascading errors and explore more complex collaboration strategies.
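The marginal-likelihood objective described in the bullets above can be written out as follows (notation reconstructed from this summary's description; the paper's own symbols may differ):

```latex
% Latent variable Z_t chooses which model emits token X_t.
% p_0 is the base LM; p_1, ..., p_K are the assistant LMs.
P(X_t \mid X_{<t})
  = \sum_{z=0}^{K} P(Z_t = z \mid X_{<t}) \, p_z(X_t \mid X_{<t})
```

Training maximizes \(\sum_t \log P(X_t \mid X_{<t})\) over the training set, so the gate \(P(Z_t \mid X_{<t})\) learns when to defer to an assistant without any token-level supervision on \(Z_t\).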

Exploring Collaborative Decoding with Multiple LLMs

Introduction to Collaborative Decoding

Recent developments in language models (LMs) have introduced sophisticated methods that combine the capabilities of multiple large LMs to improve performance on a wide array of tasks. These methods range from enhancing decoding speed to increasing the fidelity and accuracy of generated text. However, most existing approaches require predefined rules for model combination or direct supervision indicating when to utilize auxiliary tools or models. We propose a new method, Co-LLM, which introduces a latent-variable framework for collaborative decoding among multiple LMs, permitting automatic learning of when and how to best leverage the strengths of each model in the ensemble.

Latent-Variable Framework

Co-LLM operates under a latent-variable framework where the decision of which LM generates the next token is treated as a latent variable. This configuration allows multiple LMs to collaborate at the token level without direct supervision. The joint generation mechanism optimizes the marginal likelihood of generating a correct sequence by judiciously combining the outputs of a "base LM" with one or more "assistant LMs". This setup is particularly advantageous in cross-domain scenarios, enabling a generalist base LM to solicit input from domain-specialized models effortlessly.
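As a minimal sketch of the token-level mixture this describes, the joint next-token distribution marginalizes out the latent model choice. The function names and the fixed `p_defer` gate below are illustrative stand-ins, not the authors' implementation (in the paper, the deferral probability comes from a learned head on the base model's hidden state):

```python
import math

def softmax(logits):
    """Convert raw logits to a probability distribution."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def marginal_next_token_probs(base_logits, assistant_logits, p_defer):
    """Marginalize the latent model-choice variable out of the
    next-token distribution:

        P(x_t) = (1 - p_defer) * P_base(x_t) + p_defer * P_assistant(x_t)

    `p_defer` is the probability of deferring to the assistant for
    this position; the full method would use a learned gate here.
    """
    p_base = softmax(base_logits)
    p_asst = softmax(assistant_logits)
    return [(1 - p_defer) * b + p_defer * a
            for b, a in zip(p_base, p_asst)]

# Toy vocabulary of size 3: the base model favors token 0,
# the assistant favors token 2. The mixture blends both.
mix = marginal_next_token_probs([2.0, 0.0, -1.0], [-1.0, 0.0, 2.0],
                                p_defer=0.5)
print(mix)  # a valid probability distribution over the 3 tokens
```

Because the marginal is differentiable in the gate probability, maximizing its log-likelihood over training sequences is what lets the gate learn when deferring helps.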

Through experiments, we demonstrate that Co-LLM efficiently fuses the expertise of different models, tailoring their collaborative efforts to the task at hand. For instance, in a cross-domain question-answering setup, a generalist base LM can invoke a domain-specialist model to generate answers that require specific knowledge or reasoning capabilities outside the base model's training data.

Key Findings

Our experiments across various datasets, including instruction-following, domain-specific question-answering, and reasoning tasks, reveal several key findings:

  • The performance of the jointly operating models often surpasses that of any individual model involved, suggesting that Co-LLM successfully harmonizes their distinct strengths.
  • The method shows pronounced benefits in cross-domain settings, allowing a generalist model to leverage the specialized knowledge of domain expert models.
  • Through qualitative analysis, we observe emergent collaboration patterns, such as template-filling, where the base model generates a scaffold for the response, and the assistant models fill in the requisite details.
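The template-filling pattern above can be illustrated with a toy threshold-based decoding loop. Everything here (the function names, the stand-in "models", the gate) is hypothetical scaffolding for illustration, not the released Co-LLM code:

```python
def collaborative_greedy_decode(base_step, assistant_step, gate,
                                prompt, max_tokens, threshold=0.5):
    """Greedy token-level collaboration: at each step a gate scores
    how strongly the base model wants to defer; above `threshold`
    the assistant emits the token, otherwise the base does."""
    tokens = list(prompt)
    choices = []
    for _ in range(max_tokens):
        if gate(tokens) > threshold:
            tok, who = assistant_step(tokens), "assistant"
        else:
            tok, who = base_step(tokens), "base"
        tokens.append(tok)
        choices.append(who)
        if tok == "<eos>":
            break
    return tokens, choices

# Toy "models": the base emits an answer template and defers on the
# blank slot; the assistant supplies the domain-specific value.
template = ["The", "answer", "is", "<blank>", ".", "<eos>"]
base_step = lambda toks: template[len(toks)]
assistant_step = lambda toks: "42"
gate = lambda toks: 0.9 if template[len(toks)] == "<blank>" else 0.1

tokens, choices = collaborative_greedy_decode(base_step, assistant_step,
                                              gate, [], 6)
print(tokens)   # ['The', 'answer', 'is', '42', '.', '<eos>']
print(choices)  # the assistant is invoked only for the blank slot
```

The qualitative behavior reported in the paper resembles this: the base model produces the scaffold and hands off individual tokens that need specialist knowledge.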

Implications and Future Directions

The latent-variable framework of Co-LLM offers a flexible and powerful mechanism to dynamically combine the capabilities of multiple LMs. By learning optimal collaboration patterns directly from data, this approach sidesteps the need for hand-coded rules or labeled data signaling when auxiliary models should participate. This research opens up new avenues for constructing more versatile and capable LMs that can adaptively marshal the strengths of various models to address a broad spectrum of tasks more effectively.

Looking ahead, further work could extend this framework to encompass a larger number of models, possibly leading to even more sophisticated and nuanced collaboration strategies. Additionally, exploring methods to safeguard against potential cascading errors when relying on unsupervised model selection will be crucial in enhancing the reliability and robustness of collaborative decoding approaches like Co-LLM.
