Structured State Space Models for In-Context Reinforcement Learning (2303.03982v3)
Abstract: Structured state space sequence (S4) models have recently achieved state-of-the-art performance on long-range sequence modeling tasks. These models also have fast inference speeds and parallelisable training, making them potentially useful in many reinforcement learning settings. We propose a modification to a variant of S4 that enables us to initialise and reset the hidden state in parallel, allowing us to tackle reinforcement learning tasks. We show that our modified architecture runs asymptotically faster than Transformers in sequence length and performs better than RNNs on a simple memory-based task. We evaluate our modified architecture on a set of partially observable environments and find that, in practice, our model outperforms RNNs while also running over five times faster. Then, by leveraging the model's ability to handle long-range sequences, we achieve strong performance on a challenging meta-learning task in which the agent is given a randomly sampled continuous control environment, combined with a randomly sampled linear projection of the environment's observations and actions. Furthermore, we show that the resulting model can adapt to out-of-distribution held-out tasks. Overall, the results presented in this paper show that structured state space models are fast and performant for in-context reinforcement learning tasks. We provide code at https://github.com/luchris429/popjaxrl.
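The key idea in the abstract, resetting a linear recurrence's hidden state in parallel, can be sketched as follows. This is a minimal, hypothetical illustration (not the paper's actual S4/S5 layer): a scalar recurrence h_t = A·h_{t-1} + b_t where a reset flag zeroes the carried state. Folding the reset into the transition coefficient keeps each step an affine map, and affine maps compose associatively, which is what permits a parallel scan (e.g. via `jax.lax.associative_scan`) rather than a sequential loop.

```python
def combine(e1, e2):
    # Compose two affine maps h -> a*h + b, applied e1 first, then e2.
    # This operator is associative, so a parallel prefix scan is valid.
    a1, b1 = e1
    a2, b2 = e2
    return (a2 * a1, a2 * b1 + b2)

def resettable_scan(A, xs, resets):
    # Each timestep becomes an affine element; a reset zeroes the
    # transition coefficient so no state is carried across the boundary.
    elems = [(0.0 if r else A, x) for x, r in zip(xs, resets)]
    hs, acc = [], (1.0, 0.0)  # identity element of `combine`
    for e in elems:           # stand-in for a parallel prefix scan
        acc = combine(acc, e)
        hs.append(acc[1])     # with h_0 = 0, the state is just `b`
    return hs

def sequential_reference(A, xs, resets):
    # Plain sequential recurrence, for comparison.
    h, hs = 0.0, []
    for x, r in zip(xs, resets):
        h = x if r else A * h + x
        hs.append(h)
    return hs
```

Running both on `A = 0.5`, inputs `[1, 2, 3, 4]`, and a reset at step 2 yields identical states `[1.0, 2.5, 3.0, 5.5]`, confirming that the scan-with-resets matches the sequential recurrence while remaining parallelisable.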
Authors: Chris Lu, Yannick Schroecker, Albert Gu, Emilio Parisotto, Jakob Foerster, Satinder Singh, Feryal Behbahani