
Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking (2403.09629v2)

Published 14 Mar 2024 in cs.CL, cs.AI, and cs.LG

Abstract: When writing and talking, people sometimes pause to think. Although reasoning-focused works have often framed reasoning as a method of answering questions or completing agentic tasks, reasoning is implicit in almost all written text. For example, this applies to the steps not stated between the lines of a proof or to the theory of mind underlying a conversation. In the Self-Taught Reasoner (STaR, Zelikman et al. 2022), useful thinking is learned by inferring rationales from few-shot examples in question-answering and learning from those that lead to a correct answer. This is a highly constrained setting -- ideally, an LM could instead learn to infer unstated rationales in arbitrary text. We present Quiet-STaR, a generalization of STaR in which LMs learn to generate rationales at each token to explain future text, improving their predictions. We address key challenges, including 1) the computational cost of generating continuations, 2) the fact that the LM does not initially know how to generate or use internal thoughts, and 3) the need to predict beyond individual next tokens. To resolve these, we propose a tokenwise parallel sampling algorithm, using learnable tokens indicating a thought's start and end, and an extended teacher-forcing technique. Encouragingly, generated rationales disproportionately help model difficult-to-predict tokens and improve the LM's ability to directly answer difficult questions. In particular, after continued pretraining of an LM on a corpus of internet text with Quiet-STaR, we find zero-shot improvements on GSM8K (5.9%→10.9%) and CommonsenseQA (36.3%→47.2%) and observe a perplexity improvement of difficult tokens in natural text. Crucially, these improvements require no fine-tuning on these tasks. Quiet-STaR marks a step towards LMs that can learn to reason in a more general and scalable way.

References (75)
  1. Thinking fast and slow with deep learning and tree search. Advances in neural information processing systems, 30, 2017.
  2. Fireact: Toward language agent fine-tuning. arXiv preprint arXiv:2310.05915, 2023.
  3. Scaling instruction-finetuned language models. arXiv preprint arXiv:2210.11416, 2022.
  4. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168, 2021.
  5. Strategic reasoning with language models. arXiv preprint arXiv:2305.19165, 2023.
  6. Are we modeling the task or the annotator? an investigation of annotator bias in natural language understanding datasets. arXiv preprint arXiv:1908.07898, 2019.
  7. Think before you speak: Training language models with pause tokens. arXiv preprint arXiv:2310.02226, 2023.
  8. Reinforced self-training (rest) for language modeling. arXiv preprint arXiv:2308.08998, 2023.
  9. Textbooks are all you need. arXiv preprint arXiv:2306.11644, 2023.
  10. Language models can teach themselves to program better. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=SaRj2ka1XZ3.
  11. Backpack language models. arXiv preprint arXiv:2305.16765, 2023.
  12. Large language models are reasoning teachers. arXiv preprint arXiv:2212.10071, 2022.
  13. Training chain-of-thought via latent-variable inference. Advances in Neural Information Processing Systems, 36, 2024.
  14. V-star: Training verifiers for self-taught reasoners. arXiv preprint arXiv:2402.06457, 2024.
  15. Distilling step-by-step! outperforming larger language models with less training data and smaller model sizes. arXiv preprint arXiv:2305.02301, 2023.
  16. Large language models can self-improve. arXiv preprint arXiv:2210.11610, 2022.
  17. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144, 2016.
  18. Mistral 7b. arXiv preprint arXiv:2310.06825, 2023.
  19. Discrete prompt compression with reinforcement learning. arXiv preprint arXiv:2308.08758, 2023.
  20. Demonstrate-search-predict: Composing retrieval and language models for knowledge-intensive nlp. arXiv preprint arXiv:2212.14024, 2022.
  21. Dspy: Compiling declarative language model calls into self-improving pipelines. arXiv preprint arXiv:2310.03714, 2023.
  22. Large language models are zero-shot reasoners. arXiv preprint arXiv:2205.11916, 2022.
  23. Can language models learn from explanations in context? arXiv preprint arXiv:2204.02329, 2022.
  24. Learning to reason and memorize with self-notes. Advances in Neural Information Processing Systems, 36, 2024.
  25. The power of scale for parameter-efficient prompt tuning. arXiv preprint arXiv:2104.08691, 2021.
  26. Solving quantitative reasoning problems with language models. Advances in Neural Information Processing Systems, 35:3843–3857, 2022.
  27. Automated statistical model discovery with language models. arXiv preprint arXiv:2402.17879, 2024.
  28. Explanations from large language models make small reasoners better. arXiv preprint arXiv:2210.06726, 2022.
  29. Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190, 2021.
  30. Compressing context to enhance inference efficiency of large language models. arXiv preprint arXiv:2310.06201, 2023.
  31. Crystal: Introspective reasoners reinforced with self-feedback. arXiv preprint arXiv:2310.04921, 2023.
  32. Wizardmath: Empowering mathematical reasoning for large language models via reinforced evol-instruct. arXiv preprint arXiv:2308.09583, 2023.
  33. Self-refine: Iterative refinement with self-feedback, 2023.
  34. Playing atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602, 2013.
  35. Asynchronous methods for deep reinforcement learning. In International conference on machine learning, pp.  1928–1937. PMLR, 2016.
  36. Learning to compress prompts with gist tokens. Advances in Neural Information Processing Systems, 36, 2024.
  37. Show your work: Scratchpads for intermediate computation with language models. arXiv preprint arXiv:2112.00114, 2021.
  38. Feedback loops with language models drive in-context reward hacking. arXiv preprint arXiv:2402.06627, 2024.
  39. Openwebmath: An open dataset of high-quality mathematical web text. arXiv preprint arXiv:2310.06786, 2023.
  40. Training chain-of-thought via latent-variable inference. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  41. Certified reasoning with language models. arXiv preprint arXiv:2306.04031, 2023.
  42. Generative language modeling for automated theorem proving. arXiv preprint arXiv:2009.03393, 2020.
  43. Why think step by step? reasoning emerges from the locality of experience. Advances in Neural Information Processing Systems, 36, 2024.
  44. Autoact: Automatic agent learning from scratch via self-planning. arXiv preprint arXiv:2401.05268, 2024.
  45. Phenomenal yet puzzling: Testing inductive reasoning capabilities of language models with hypothesis refinement. arXiv preprint arXiv:2310.08559, 2023.
  46. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
  47. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of machine learning research, 21(140):1–67, 2020.
  48. Explain yourself! leveraging language models for commonsense reasoning. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp.  4932–4942, 2019.
  49. Toolformer: Language models can teach themselves to use tools. Advances in Neural Information Processing Systems, 36, 2024.
  50. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347, 2017.
  51. Programming Puzzles. In Thirty-fifth Conference on Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=fe_hCc4RBrg.
  52. Reflexion: Language agents with verbal reinforcement learning. arXiv preprint arXiv:2303.11366, 2023.
  53. Unsupervised commonsense question answering with self-talk. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pp.  4615–4629, 2020.
  54. Mastering chess and shogi by self-play with a general reinforcement learning algorithm. arXiv preprint arXiv:1712.01815, 2017.
  55. Commonsenseqa: A question answering challenge targeting commonsense knowledge. arXiv preprint arXiv:1811.00937, 2018.
  56. Function vectors in large language models. arXiv preprint arXiv:2310.15213, 2023.
  57. Solving math word problems with process-and outcome-based feedback. Neural Information Processing Systems (NeurIPS 2022) Workshop on MATH-AI, 2022.
  58. Hypothesis search: Inductive reasoning with language models. arXiv preprint arXiv:2309.05660, 2023.
  59. Chain-of-thought reasoning without prompting. arXiv preprint arXiv:2402.10200, 2024.
  60. Language modelling as a multi-task problem. arXiv preprint arXiv:2101.11287, 2021.
  61. Finetuned language models are zero-shot learners. In International Conference on Learning Representations, 2021a.
  62. Finetuned language models are zero-shot learners. arXiv preprint arXiv:2109.01652, 2021b.
  63. Emergent abilities of large language models. arXiv preprint arXiv:2206.07682, 2022.
  64. Chain-of-thought prompting elicits reasoning in large language models. arXiv preprint arXiv:2201.11903, 2022.
  65. Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8:229–256, 1992.
  66. React: Synergizing reasoning and acting in language models. International Conference on Learning Representations (ICLR 2023), 2022.
  67. Star: Bootstrapping reasoning with reasoning. Advances in Neural Information Processing Systems, 35:15476–15488, 2022.
  68. Parsel: Algorithmic reasoning with language models by composing decompositions, 2023a.
  69. Self-taught optimizer (stop): Recursively self-improving code generation. arXiv preprint arXiv:2310.02304, 2023b.
  70. Chain-of-thought reasoning is a policy improvement operator. arXiv preprint arXiv:2309.08589, 2023.
  71. In-context principle learning from mistakes. arXiv preprint arXiv:2402.05403, 2024.
  72. Automatic chain of thought prompting in large language models. arXiv preprint arXiv:2210.03493, 2022.
  73. Hop, union, generate: Explainable multi-hop reasoning without rationale supervision. arXiv preprint arXiv:2305.14237, 2023.
  74. Teaching algorithmic reasoning via in-context learning. arXiv preprint arXiv:2211.09066, 2022.
  75. Large language models can learn rules. arXiv preprint arXiv:2310.07064, 2023.

Summary

  • The paper introduces Quiet-STaR, a self-teaching framework in which a language model learns to generate a rationale after every token, trained with an extended teacher-forcing objective and a REINFORCE-style reward.
  • Continued pretraining with Quiet-STaR yields zero-shot gains without task-specific fine-tuning: GSM8K accuracy rises from 5.9% to 10.9% and CommonsenseQA from 36.3% to 47.2%, with the largest benefits on difficult-to-predict tokens.
  • The approach points towards more general, unsupervised training of reasoning, with possible extensions to dynamic thought allocation, self-improving dialogue systems, and multimodal models.

"Quiet-STaR: LLMs Can Teach Themselves to Think Before Speaking" Analysis

Introduction to Quiet-STaR

The paper "Quiet-STaR: LLMs Can Teach Themselves to Think Before Speaking" (2403.09629) introduces a significant advancement in LLM reasoning. It builds upon the Self-Taught Reasoner (STaR) architecture to enable LMs to infer rationales independently at each token in a text. Unlike traditional approaches that restrict reasoning to structured question-answering datasets, Quiet-STaR leverages diverse unstructured internet text to facilitate reasoning, thereby improving prediction accuracy through tokenwise parallel sampling and a teacher-forcing framework.

Algorithm and Methodology

Quiet-STaR Mechanism: The Quiet-STaR framework generates a thought, or rationale, after each token of an input sequence. These rationales, delimited by learned meta-tokens that mark the start and end of a thought, let the LM express the intermediate reasoning needed to predict future text (Figure 1).

Figure 1: Quiet-STaR. We visualize the algorithm as applied during training to a single thought. We generate thoughts, in parallel, following all tokens in the text.
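The sketch below illustrates this per-token thought-generation loop in Python. It is a minimal illustration under stated assumptions, not the authors' implementation: sample_next_token is a stand-in for the language model's sampling step, and the meta-token strings are assumptions based on the paper's description of learnable start- and end-of-thought tokens.

```python
# Minimal sketch of per-token thought generation (illustrative, not the paper's code).
import random

START_THOUGHT = "<|startofthought|>"  # assumed name for the learned start-of-thought token
END_THOUGHT = "<|endofthought|>"      # assumed name for the learned end-of-thought token
VOCAB = ["the", "cat", "sat", "on", "mat", "so", "therefore", "because"]

def sample_next_token(context):
    """Placeholder for one LM sampling step conditioned on `context`."""
    return random.choice(VOCAB)

def generate_thoughts(tokens, thought_len=4):
    """After every token, sample a short rationale bracketed by the meta-tokens."""
    thoughts = []
    for i in range(len(tokens)):
        prefix = tokens[: i + 1] + [START_THOUGHT]   # text so far, then open a thought
        thought = []
        for _ in range(thought_len):                 # sample a fixed-length rationale
            thought.append(sample_next_token(prefix + thought))
        thoughts.append(thought + [END_THOUGHT])     # close the thought
    return thoughts

if __name__ == "__main__":
    text = ["the", "cat", "sat"]
    for tok, th in zip(text, generate_thoughts(text)):
        print(tok, "->", " ".join(th))
```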

Parallel Generation: A cornerstone of Quiet-STaR is its parallel generation capability: rationales are generated simultaneously for all tokens in the sequence. This is achieved with an attention mask in which each thought token attends to itself, to earlier tokens of the same thought, and to the preceding text, allowing continuations of all thoughts to be computed efficiently in a single pass (Figure 2).

Figure 2: Parallel Generation. By constructing an attention mask that allows all thought tokens to attend to themselves, to all preceding thought tokens within the same thought, and to the preceding text, we can generate continuations of all of the thoughts in parallel.
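The following sketch shows one way such a mask could be constructed, assuming T base tokens each followed by a thought of fixed length L; the shapes and function name are illustrative assumptions rather than the paper's code.

```python
# Sketch of the parallel-generation attention mask (illustrative shapes and names).
# Thought token j of thought i may attend to: base tokens 0..i (the preceding text)
# and thought tokens 0..j of its own thought.
import torch

def thought_attention_mask(T: int, L: int) -> torch.Tensor:
    """Return a boolean mask of shape (T, L, T + L).

    mask[i, j, :T]    -> which base tokens thought token j of thought i may see
    mask[i, j, T:T+L] -> which tokens of its own thought it may see
    """
    mask = torch.zeros(T, L, T + L, dtype=torch.bool)
    for i in range(T):          # index of the base token the thought follows
        for j in range(L):      # position within the thought
            mask[i, j, : i + 1] = True          # preceding text (causal prefix)
            mask[i, j, T : T + j + 1] = True    # itself and earlier tokens of its thought
    return mask

if __name__ == "__main__":
    m = thought_attention_mask(T=4, L=3)
    print(m.shape)        # torch.Size([4, 3, 7])
    print(m[2, 1].int())  # thought 2, token 1: sees base tokens 0..2 and thought tokens 0..1
```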

Mixing and Reinforcement: A learned mixing head determines how to interpolate between the rationale-conditioned predictions and the base LM's next-token predictions, which eases the distribution shift introduced by the new thoughts and stabilizes training. Quiet-STaR then applies a REINFORCE-style update in which a rationale is rewarded by how much it improves prediction of the true future tokens relative to the average rationale sampled at that position, driving iterative improvement of the generated reasoning (Figure 3).

Figure 3: Forward Pass and Teacher Forcing. We visualize a single forward pass of our algorithm. Solid lines denote LM computation, dashed lines indicate tokens inserted via teacher forcing, and the mixer represents the mixing head.
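Below is a hedged sketch of these two components using dummy tensors in place of real model activations: the mixing head interpolates between post-thought and base next-token distributions, and a centered REINFORCE-style reward scores each sampled thought by how much it improves the log probability of the teacher-forced future token relative to the average thought at that position. Module shapes and names here are assumptions for illustration, not the authors' implementation.

```python
# Sketch of the mixing head and a REINFORCE-style update (illustrative, dummy tensors).
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, vocab = 16, 50
batch = 8  # number of thoughts sampled for the same position

# Mixing head: predicts an interpolation weight w from the hidden states
# before and after the thought (a shallow MLP is one plausible choice).
mixing_head = nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
                            nn.Linear(hidden_dim, 1), nn.Sigmoid())

h_base = torch.randn(batch, hidden_dim)        # hidden state without the thought
h_thought = torch.randn(batch, hidden_dim)     # hidden state after the thought
logits_base = torch.randn(batch, vocab)        # base next-token logits
logits_thought = torch.randn(batch, vocab, requires_grad=True)
true_next = torch.randint(0, vocab, (batch,))  # ground-truth future token (teacher-forced)

# 1) Interpolate between post-thought and base predictions.
w = mixing_head(torch.cat([h_base, h_thought], dim=-1))              # (batch, 1)
log_p = torch.log(w * F.softmax(logits_thought, -1)
                  + (1 - w) * F.softmax(logits_base, -1) + 1e-9)

# 2) REINFORCE-style reward: how much each thought beats the average thought
#    at predicting the true future token.
logp_true = log_p.gather(1, true_next.unsqueeze(1)).squeeze(1)        # (batch,)
reward = logp_true - logp_true.mean()                                 # centered baseline

# 3) Policy loss on the sampled thought tokens (placeholder log-probs here),
#    plus the ordinary NLL that trains the mixing head and LM head.
logp_thought_tokens = torch.randn(batch, requires_grad=True)          # placeholder
loss = -(reward.detach() * logp_thought_tokens).mean() + (-logp_true).mean()
loss.backward()
print(float(loss))
```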

Experimental Results

The empirical evaluation establishes that Quiet-STaR enhances zero-shot problem-solving on challenging datasets such as GSM8K and CommonsenseQA without dataset-specific fine-tuning. Zero-shot accuracy rises from 5.9% to 10.9% on GSM8K and from 36.3% to 47.2% on CommonsenseQA, demonstrating the effectiveness of the internal rationales (Figure 4).

Figure 4: Generalization Results. We evaluate the extent to which the model trained with Quiet-STaR generalizes to directly answering problems that require reasoning.

The distribution of per-token improvements indicates that Quiet-STaR disproportionately aids the prediction of challenging tokens, consistent with the hypothesis that LMs benefit most from reasoning in difficult-to-predict contexts (Figure 5).

Figure 5: Distribution of changes in log probability. We visualize the distribution of changes in log probability resulting from the generated thoughts across the evaluation dataset.
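A small sketch of how this per-token effect can be measured, assuming one already has the model's log probabilities for each true next token with and without the generated thought (the values below are random placeholders):

```python
# Sketch of the quantity behind Figure 5: the change in log probability of each
# true next token with vs. without the generated thought (dummy values).
import torch

logp_base = torch.log(torch.rand(10) * 0.5 + 1e-3)          # log p(token | context)
logp_with_thought = torch.log(torch.rand(10) * 0.5 + 1e-3)  # log p(token | context, thought)

delta = logp_with_thought - logp_base                        # per-token change
print("mean change:", float(delta.mean()))
print("tokens helped:", int((delta > 0).sum()), "of", delta.numel())
```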

Implications and Future Directions

Quiet-STaR marks a step towards language models that learn to reason from ordinary, unlabeled text rather than curated question-answer data. The approach suggests applications in self-improving dialogue systems, interactive agents, and multimodal reasoning. Future work may investigate dynamic allocation of thought tokens, improved reward formulations, and meta-learning strategies to further strengthen coherence and reasoning fidelity.

Conclusion

Quiet-STaR represents substantial progress in language model reasoning. By enabling unsupervised rationale generation from diverse text, it improves both the prediction of difficult tokens and zero-shot performance on reasoning benchmarks, and it points towards language models that continue to improve their own reasoning without task-specific supervision.
