How Transformers Learn Causal Structure with Gradient Descent

Video link: https://www.youtube.com/watch?v=xlWBsISnaRA





Jason Lee (Princeton University)
https://simons.berkeley.edu/talks/jason-lee-princeton-university-2024-11-12
Domain Adaptation and Related Areas

The remarkable success of transformers on sequence-modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention lets transformers encode causal structure, which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens; as a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
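The data-processing-inequality argument can be illustrated in the Markov-chain special case with a small sketch. The vocabulary size, sequence length, and transition matrix below are illustrative assumptions, not taken from the talk: we estimate the mutual information between the last token and each earlier position in sequences sampled from a fixed Markov chain, and the largest value falls on the immediate predecessor, i.e., the causal parent.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative setup (assumed, not from the talk): sequences over a small
# vocabulary, sampled from a fixed Markov chain.
V, T, N = 4, 6, 20000  # vocab size, sequence length, number of sequences

# A mixing transition matrix: stay with probability 0.7, move uniformly otherwise.
P = np.full((V, V), 0.1)
np.fill_diagonal(P, 0.7)

# Sample N sequences from the chain.
X = np.zeros((N, T), dtype=int)
X[:, 0] = rng.integers(V, size=N)
for t in range(1, T):
    for v in range(V):
        mask = X[:, t - 1] == v
        X[mask, t] = rng.choice(V, size=int(mask.sum()), p=P[v])

def mutual_info(a, b, V):
    """Plug-in estimate of the mutual information (in nats) between two positions."""
    joint = np.zeros((V, V))
    np.add.at(joint, (a, b), 1.0)
    joint /= joint.sum()
    pa = joint.sum(axis=1, keepdims=True)
    pb = joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float((joint[nz] * np.log(joint[nz] / (pa @ pb)[nz])).sum())

# Mutual information between the last position and each earlier one. By the
# data processing inequality, it is maximized at the causal parent, T-2.
mi = [mutual_info(X[:, t], X[:, T - 1], V) for t in range(T - 1)]
parent = int(np.argmax(mi))
print(parent)  # 4, i.e. T-2: the immediate predecessor carries the most information
```

The same logic is what the proof exploits: gradient entries of the attention matrix track these mutual-information values, so the largest ones pick out edges of the latent causal graph.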




Other Videos By Simons Institute for the Theory of Computing


2024-11-14  Open-Source and Science in the Era of Foundation Models
2024-11-13  Toward Understanding the Extrapolation of Nonlinear Models to Unseen Domains or the Whole Domain
2024-11-13  Language-guided Adaptation
2024-11-13  On Spurious Associations and LLM Alignment
2024-11-13  Causally motivated robustness to shortcut learning
2024-11-13  Talk by Zachary Lipton
2024-11-12  Distribution shift in ecological data: generalization vs. specialization
2024-11-12  Transfer learning via local convergence rates of the nonparametric least squares estimator
2024-11-12  Transfer learning for weak-to-strong generalization
2024-11-12  User-level and federated local differential privacy
2024-11-11  How Transformers Learn Causal Structure with Gradient Descent
2024-10-16  The Enigma of LLMs: on Creativity, Compositionality, Pluralism, and Paradoxes
2024-10-02  Let’s Try and Be More Tolerant: On Tolerant Property Testing and Distance Approximation
2024-10-02  A Strong Separation for Adversarially Robust L_0 Estimation for Linear Sketches
2024-10-02  Towards Practical Distribution Testing
2024-10-02  Toward Optimal Semi-streaming Algorithm for (1+ε)-approximate Maximum Matching
2024-10-02  Plenary Talk: Privately Evaluating Untrusted Black-Box Functions
2024-10-02  The long path to \sqrt{d} monotonicity testers
2024-10-02  O(log log n) Passes is Optimal for Semi-Streaming Maximal Independent Set
2024-10-02  Distribution Learning Meets Graph Structure Sampling
2024-10-02  On the instance optimality of detecting collisions and subgraphs