TY - JOUR
T1 - Dissecting the Interplay of Attention Paths in a Statistical Mechanics Theory of Transformers
AU - Tiberi, Lorenzo
AU - Mignacco, Francesca
AU - Irie, Kazuki
AU - Sompolinsky, Haim
N1 - Publisher Copyright:
© 2024 Neural Information Processing Systems Foundation. All rights reserved.
PY - 2024
Y1 - 2024
AB - Despite the remarkable empirical performance of transformers, their theoretical understanding remains elusive. Here, we consider a deep multi-head self-attention network that is closely related to transformers yet analytically tractable. We develop a statistical mechanics theory of Bayesian learning in this model, deriving exact equations for the network's predictor statistics in the finite-width thermodynamic limit, i.e., N, P → ∞ with P/N = O(1), where N is the network width and P is the number of training examples. Our theory shows that the predictor statistics are expressed as a sum of independent kernels, each one pairing different attention paths, defined as information pathways through different attention heads across layers. The kernels are weighted according to a task-relevant kernel combination mechanism that aligns the total kernel with the task labels. As a consequence, this interplay between attention paths enhances generalization performance. Experiments confirm our findings on both synthetic and real-world sequence classification tasks. Finally, our theory explicitly relates the kernel combination mechanism to properties of the learned weights, allowing for a qualitative transfer of its insights to models trained via gradient descent. As an illustration, we demonstrate an efficient size reduction of the network by pruning those attention heads that are deemed less relevant by our theory.
UR - http://www.scopus.com/inward/record.url?scp=105000504411&partnerID=8YFLogxK
M3 - Conference article
AN - SCOPUS:105000504411
SN - 1049-5258
VL - 37
JO - Advances in Neural Information Processing Systems
JF - Advances in Neural Information Processing Systems
T2 - 38th Conference on Neural Information Processing Systems, NeurIPS 2024
Y2 - 9 December 2024 through 15 December 2024
ER -