SMYRF: Efficient Attention using Asymmetric Clustering

Published in NeurIPS 2020.

Citation: Giannis Daras, Nikita Kitaev, Augustus Odena, Alexandros G. Dimakis, "SMYRF: Efficient Attention using Asymmetric Clustering", NeurIPS 2020

We propose a novel type of balanced clustering algorithm to approximate attention. Attention complexity is reduced from O(N^2) to O(N log N), where N is the sequence length. Our algorithm, SMYRF, uses Locality Sensitive Hashing (LSH) in a novel way by defining new asymmetric transformations and an adaptive scheme that produces balanced clusters. The biggest advantage of SMYRF is that it can be used as a drop-in replacement for dense attention layers without any retraining. In contrast, prior fast attention methods impose constraints (e.g. queries and keys share the same vector representations) and require retraining from scratch. We apply our method to pre-trained state-of-the-art Natural Language Processing and Computer Vision models and report significant memory and speed benefits. Notably, SMYRF-BERT slightly outperforms BERT on GLUE while using 50% less memory. We also show that SMYRF can be used interchangeably with dense attention before and after training. Finally, we use SMYRF to train GANs with attention at high resolutions. Using a single TPU, we were able to scale attention to 128x128=16k and 256x256=65k tokens in BigGAN trained on CelebA-HQ.
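The clustering idea above can be sketched in plain Python. This is a minimal, single-hash-round sketch under stated assumptions, not the SMYRF implementation. The asymmetric transformations F and G below are one standard construction that turns maximum inner product search into nearest-neighbor search (we assume M^2 = max|q|^2 + max|k|^2), and the LSH step is simplified to sorting by one random projection. All function names are hypothetical. Queries and keys are sorted separately and cut into equal chunks of C, so every cluster is balanced by construction; sorting costs O(N log N) and attention inside clusters costs O(N*C).

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def asymmetric_transforms(queries, keys):
    # Assumed form: F(q) = [q, 0, sqrt(M2 - |q|^2)], G(k) = [k, sqrt(M2 - |k|^2), 0],
    # with M2 = max|q|^2 + max|k|^2. Both land on a sphere of radius sqrt(M2), so
    # |F(q) - G(k)|^2 = 2*M2 - 2*(q . k): a larger inner product means a closer pair.
    m2 = max(dot(q, q) for q in queries) + max(dot(k, k) for k in keys)
    fq = [list(q) + [0.0, math.sqrt(max(0.0, m2 - dot(q, q)))] for q in queries]
    gk = [list(k) + [math.sqrt(max(0.0, m2 - dot(k, k))), 0.0] for k in keys]
    return fq, gk


def attend(query, key_ids, keys, values):
    # Scaled softmax attention of one query over a subset of keys.
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * dot(query, keys[j]) for j in key_ids]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # shift for numerical stability
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * values[j][d] for w, j in zip(weights, key_ids)) / total
            for d in range(dim)]


def dense_attention(queries, keys, values):
    # O(N^2): every query attends to every key.
    all_keys = list(range(len(keys)))
    return [attend(q, all_keys, keys, values) for q in queries]


def clustered_attention(queries, keys, values, cluster_size, seed=0):
    # One hashing round: project transformed vectors on a shared random direction,
    # sort queries and keys separately, and cut both sorted lists into chunks of
    # `cluster_size`. Query chunk i attends only to key chunk i.
    n = len(queries)
    if n != len(keys) or n % cluster_size != 0:
        raise ValueError("sketch assumes len(queries) == len(keys), divisible by cluster_size")
    fq, gk = asymmetric_transforms(queries, keys)
    rng = random.Random(seed)
    direction = [rng.gauss(0.0, 1.0) for _ in range(len(fq[0]))]
    q_order = sorted(range(n), key=lambda i: dot(direction, fq[i]))
    k_order = sorted(range(n), key=lambda j: dot(direction, gk[j]))
    out = [None] * n
    for start in range(0, n, cluster_size):
        k_ids = k_order[start:start + cluster_size]
        for i in q_order[start:start + cluster_size]:
            out[i] = attend(queries[i], k_ids, keys, values)
    return out


# Small demo on random data: 16 tokens, 4 balanced clusters of 4.
rng = random.Random(1)
N, D = 16, 4
Q = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
K = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
V = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]
exact = dense_attention(Q, K, V)
approx = clustered_attention(Q, K, V, cluster_size=4)
print(max(abs(a - b) for ra, rb in zip(approx, exact) for a, b in zip(ra, rb)))
```

With `cluster_size` equal to N there is a single cluster and the result coincides with dense attention; smaller clusters trade accuracy for memory.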

Results

Memory-quality trade-off

GLUE benchmark

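| Model | Avg. | # | C | CoLA | MNLI-m/mm | MRPC | QNLI | QQP | RTE | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|---|
| BERT128 | 82.69 | 1 | 128 | 57.83 | 84.43/84.68 | 88.41 | 91.31 | 89.70 | 65.70 | 93.46 | 88.73 |
| SMYRF-BERT2x32 | 82.98 | 2 | 32 | 58.79 | 83.76/84.27 | 87.69 | 91.14 | 89.72 | 68.59 | 93.23 | 89.65 |
| SMYRF-BERT2x16 | 81.74 | 2 | 16 | 58.90 | 82.86/83.49 | 85.72 | 89.53 | 89.33 | 64.98 | 93.12 | 87.75 |
| BERT64 | 81.57 | 1 | 64 | 58.80 | 82.34/82.47 | 87.02 | 90.48 | 89.69 | 61.73 | 93.00 | 88.64 |
| BERT32 | 73.56 | 1 | 32 | 56.40 | 64.51/63.41 | 77.89 | 79.81 | 88.59 | 55.23 | 92.66 | 83.53 |

Here # is the number of hashes and C the number of queries per cluster.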
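As a rough check on the "50% less memory" claim, assume attention memory is dominated by query-key score entries. With # hashing rounds and balanced clusters of C queries and C keys, each query is scored against # × C keys, versus all 128 tokens for dense BERT128. The helper name below is hypothetical; the (#, C) pairs come from the table above.

```python
def keys_per_query(num_hashes: int, cluster_size: int) -> int:
    # Hypothetical helper: each query is scored against `cluster_size` keys
    # in each of `num_hashes` hashing rounds.
    return num_hashes * cluster_size


DENSE_SEQ_LEN = 128  # BERT128: every query attends to all 128 tokens

# (model, #, C) as listed in the GLUE table
configs = [("BERT128", 1, 128), ("SMYRF-BERT2x32", 2, 32), ("SMYRF-BERT2x16", 2, 16)]
fractions = {name: keys_per_query(h, c) / DENSE_SEQ_LEN for name, h, c in configs}

for name, frac in fractions.items():
    print(f"{name}: {frac:.0%} of dense attention memory")
# BERT128: 100% of dense attention memory
# SMYRF-BERT2x32: 50% of dense attention memory
# SMYRF-BERT2x16: 25% of dense attention memory
```

Under this assumption, SMYRF-BERT2x32 uses half the attention memory of BERT128, consistent with the abstract.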

Interchangeability of SMYRF and dense attention

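Results on the IMDB dataset. Models trained with SMYRF (50% memory) are evaluated with either SMYRF or dense attention at inference. Switching to dense attention at inference consistently improves accuracy, nearly matching the performance of the fully dense models.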

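| Model | Memory | SMYRF Inference | Accuracy |
|---|---|---|---|
| RoBERTa | 100% | No | 94.96% |
| SMYRF-RoBERTa | 50% | Yes | 93.72% |
| SMYRF-RoBERTa | 50% | No | 94.62% |
| BERT | 100% | No | 94.12% |
| SMYRF-BERT | 50% | Yes | 92.64% |
| SMYRF-BERT | 50% | No | 93.54% |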

SMYRF-BigGAN training on CelebA-HQ-128

Faces generated by a SMYRF-BigGAN trained at 128x128 resolution with attention at 128x128, using 50% of the memory of dense attention.

Results after 120k iterations:

| Model | Resolution | Attention | # | C | FID |
|---|---|---|---|---|---|
| BigGAN | 128x128 | 64x64 | 1 | 4096 | 26.06 |
| SMYRF-BigGAN | 128x128 | 128x128 | 4 | 2048 | 25.03 |

where # denotes the number of hashes and C the number of queries per cluster. Lower FID is better.
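The BigGAN numbers can be checked with a little arithmetic, assuming the attention cost is the number of score entries, # × N × C, compared with N^2 for dense attention over the same N = height × width positions. The names below are hypothetical. With C equal to all 4096 positions, the BigGAN row is ordinary dense attention at 64x64; the SMYRF row needs half of what dense attention at 128x128 would, matching the "50% of dense memory" caption.

```python
from typing import NamedTuple


class AttentionCost(NamedTuple):
    tokens: int             # N = height * width positions in the attention map
    clusters_per_hash: int  # N / C balanced clusters in each hashing round
    dense_fraction: float   # (# * N * C) / N^2 score entries relative to dense


def attention_cost(height: int, width: int, num_hashes: int, cluster_size: int) -> AttentionCost:
    tokens = height * width
    if tokens % cluster_size:
        raise ValueError("cluster_size must divide the number of tokens")
    return AttentionCost(
        tokens=tokens,
        clusters_per_hash=tokens // cluster_size,
        dense_fraction=num_hashes * cluster_size / tokens,
    )


biggan = attention_cost(64, 64, num_hashes=1, cluster_size=4096)
smyrf_biggan = attention_cost(128, 128, num_hashes=4, cluster_size=2048)
print(biggan)        # AttentionCost(tokens=4096, clusters_per_hash=1, dense_fraction=1.0)
print(smyrf_biggan)  # AttentionCost(tokens=16384, clusters_per_hash=8, dense_fraction=0.5)
```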