Shortening the Loss Plateau
Transformer models waste enormous compute stuck in two training stalls: loss plateaus and grokking. We identify what causes them and show targeted interventions that cut grokking delay by up to 316× through sparse initialization.
Core Question
Transformer models (a type of machine learning model that powers modern AI systems like ChatGPT and Claude) often spend enormous amounts of compute stuck in inefficient training regimes. These stalls appear in two well-known forms: training-loss plateaus and grokking (generalization plateau).
In this project we investigate whether these phenomena arise from similar optimization dynamics, and whether the same interventions can shorten both forms of stalled learning. We focus exclusively on small-scale modular arithmetic tasks; generalization to large language models is left for future work.
Two Forms of Training Stall
Training-Loss Plateau
Early in training, loss stalls due to representation collapse and repetition bias in token embeddings. Attention structure forms slowly, delaying useful learning.
Grokking (Generalization Plateau)
Models memorize training data almost instantly but take hundreds of thousands of steps before generalizing. We investigate what drives this delay and how to eliminate it.
Our Approach
We test whether shared interventions can accelerate both phenomena.
Task Diversity
Training on multiple arithmetic tasks simultaneously to break memorization dynamics.
Optimizer Noise
Comparing SGD and AdamW to study how gradient noise influences escape from memorization basins.
Initialization Constraints
Testing sparse and small-weight initialization to restrict early representational capacity.
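The initialization constraint above can be sketched in a few lines. This is a minimal illustration, not the exact scheme used in our runs: the sparsity level and weight scale here are assumed placeholder values.

```python
import numpy as np

def sparse_small_init(shape, sparsity=0.9, scale=0.01, rng=None):
    """Initialize a weight matrix with mostly-zero, small-magnitude entries.

    `sparsity` is the fraction of weights forced to zero; the remainder are
    drawn from a small Gaussian. Both hyperparameters are illustrative.
    """
    rng = rng or np.random.default_rng(0)
    w = rng.normal(0.0, scale, size=shape)
    mask = rng.random(shape) >= sparsity  # keep ~(1 - sparsity) of entries
    return w * mask

W = sparse_small_init((128, 128))
print(f"nonzero fraction: {np.count_nonzero(W) / W.size:.3f}")
```

Masking most weights to zero at initialization is one simple way to restrict early representational capacity; the same idea applies to embedding and attention projection matrices.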
Dataset
All experiments use synthetically generated modular arithmetic data, produced online during training. The tasks are chosen so that the model can solve them to perfect accuracy. There is no external dataset, no train/test leakage, and no sensitive or private information involved.
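Online generation of modular arithmetic data can be sketched as follows, using modulus 97 (stated in the Limitations section). The specific task set beyond Division is an assumption for illustration; division uses the modular inverse, which exists because the modulus is prime.

```python
import random

P = 97  # prime modulus used throughout the experiments

# Binary operations mod P. Division multiplies by the modular inverse,
# computed via Fermat's little theorem (b^(P-2) mod P) since P is prime.
TASKS = {
    "add": lambda a, b: (a + b) % P,
    "sub": lambda a, b: (a - b) % P,
    "mul": lambda a, b: (a * b) % P,
    "div": lambda a, b: (a * pow(b, P - 2, P)) % P,
}

def sample(task, rng=None):
    """Draw one (a, b, answer) triple online; no fixed dataset is stored."""
    rng = rng or random.Random()
    a = rng.randrange(P)
    b = rng.randrange(1, P)  # avoid b = 0 so division is always defined
    return a, b, TASKS[task](a, b)

a, b, ans = sample("div", random.Random(0))
```

Because samples are drawn fresh each step, the train/test split is a partition of the finite input space rather than of a stored dataset.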
Key Findings
Our experiments reveal that training stalls arise from inefficient early representation learning. Several interventions significantly accelerate training dynamics across both loss plateaus and grokking.
Loss Plateau Insights
Slow Representation and Attention Formation
The loss plateau is driven by slow development of meaningful internal representations. Early in training, attention maps and token embeddings change very slowly, limiting gradient signal quality and delaying useful learning.
Task Diversity Accelerates Learning
Training on multiple arithmetic tasks simultaneously shortens the loss plateau. By distributing training across tasks, the model requires fewer samples per task while learning more generalizable representations, allowing loss to converge faster.
Grokking Acceleration
4× Faster via Task Diversity
Training on 4 arithmetic tasks simultaneously reduced Division grokking from 334,000 to ~80,000 steps.
8× Faster via SGD
Replacing AdamW with SGD introduced gradient noise that helped the model escape the memorization basin, achieving stable grokking at 44,900 steps.
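A single-parameter toy sketch contrasts the two update rules: plain SGD passes raw minibatch gradient noise straight into the weights, while AdamW rescales each coordinate by a running gradient variance and applies decoupled weight decay. The hyperparameters below are standard defaults, not the values tuned in our experiments.

```python
import numpy as np

def sgd_step(w, g, lr=0.1):
    """Plain SGD: the raw (noisy) minibatch gradient moves the weights directly."""
    return w - lr * g

def adamw_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    """AdamW: first/second-moment estimates smooth and rescale the gradient;
    weight decay is applied directly to the weights (decoupled)."""
    m = b1 * m + (1 - b1) * g          # momentum (first moment)
    v = b2 * v + (1 - b2) * g**2       # running gradient variance (second moment)
    m_hat = m / (1 - b1**t)            # bias correction
    v_hat = v / (1 - b2**t)
    w = w - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * w)
    return w, m, v
```

AdamW's variance normalization damps exactly the per-step noise that, in our SGD runs, appears to help kick the model out of the memorization basin.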
316× Faster via Initialization
Sparse or small weight initialization reduced the delay from ~332,000 steps to just 1,050 steps.
Discussion & Implications
Our results suggest that both training-loss plateaus and grokking share a common root cause: the model's early representations are too unconstrained to learn efficiently. Interventions that restrict or guide early representation learning, whether through initialization, optimizer choice, or task diversity, consistently accelerate convergence.
These findings have practical implications for anyone training transformers at scale: careful initialization and optimizer selection can dramatically reduce the compute needed to reach generalization, potentially saving significant GPU hours in real workloads.
Limitations
All experiments are conducted on small toy arithmetic tasks (mod 97). While these are standard benchmarks in grokking research, we cannot guarantee the same speedups will transfer to large-scale language model training or other domains. The interaction between our interventions (e.g., combining SGD with sparse initialization) was not fully explored. Future work should validate these findings on larger models and real NLP tasks.
What Changed Along the Way
Early experiments with AdamW showed that grokking delay was extremely sensitive to weight decay but still required hundreds of thousands of steps. Switching to SGD introduced gradient noise that helped escape memorization basins, but the biggest breakthrough came from initialization: restricting early representational capacity via sparse or small-weight initialization collapsed the grokking delay from ~332,000 steps to just 1,050, a result we did not anticipate from the optimizer experiments alone.
References
Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. arXiv:2201.02177.
Lyu, K., Jin, T., Li, Y., Du, S., Lee, J. D., & Hu, W. (2024). Dichotomy of Early and Late Phase Implicit Biases Can Provably Induce Grokking. ICLR 2024.
Kim, J., Kwon, S., Choi, J. Y., Park, J., Cho, J., Lee, J. D., & Ryu, E. K. (2025). Task Diversity Shortens the ICL Plateau. arXiv:2410.05448.
Lee, S., et al. (2024). Grokfast: Accelerated Grokking by Amplifying Slow Gradients. arXiv:2405.20233.
Gopalani, P., & Hu, W. (2025). What Happens During the Loss Plateau? Understanding Abrupt Learning in Transformers. arXiv:2506.13688.
He, J., Pan, X., Chen, S., & Yang, Z. (2025). In-Context Linear Regression Demystified: Training Dynamics and Mechanistic Interpretability of Multi-Head Softmax Attention. arXiv:2503.12734.