Flash Attention is a method for improving the efficiency of transformer models, in particular large language models (LLMs), reducing both model training time and inference latency. Inference latency is a particular challenge for LLMs, and Flash Attention has become a key technique for making your LLM applications respond faster.
Transformer models are built on the attention mechanism, which helps the model focus on relevant parts of the text input when making predictions. However, as transformer-based models grow larger and handle longer inputs, a major limitation arises from the self-attention mechanism: its cost grows quadratically with the sequence length, and it becomes increasingly slow and memory intensive because it keeps loading and unloading data from GPU memory. Flash Attention was introduced as a solution to mitigate this memory bottleneck in the attention mechanism of transformer models.
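A quick back-of-the-envelope calculation shows why the intermediate results become a problem. The sequence length and data type below are illustrative assumptions, not values from any particular model:

```python
# Size of the full attention score matrix that standard attention materializes
# for a single head of a single sequence (illustrative numbers).
seq_len = 8192          # assumed context length
bytes_per_element = 2   # fp16/bf16

score_matrix_bytes = seq_len * seq_len * bytes_per_element
print(f"{score_matrix_bytes / 2**20:.0f} MiB per head per sequence")  # -> 128 MiB
```

Multiply that by the number of heads and the batch size, and the intermediate matrices are far too large to keep in the few megabytes of fast on-chip memory, so a standard implementation shuttles them to and from slower GPU memory instead.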
By improving the efficiency of attention operations, Flash Attention allows for faster training and inference of transformer-based models. Rather than loading queries, keys, values, and intermediate results multiple times, once for each step of the computation, Flash Attention loads the queries, keys, and values just once, performs the full chain of attention operations on the loaded data, and only then writes the final results back. Additionally, it divides the loaded data into smaller blocks, aiding parallel processing.
By strategically minimizing the back-and-forth data transfers between memory types, Flash Attention optimizes resource utilization. One key strategy is "kernel fusion", which combines multiple computation steps into a single GPU operation, removing repeated data transfers and the overhead they cause. Because the fused kernel is exposed as a single operation, frameworks can offer it as a drop-in replacement for standard attention, making the speedups accessible to practitioners without any low-level GPU programming. The other key strategy is "tiling", which partitions the input data into blocks small enough to fit in fast on-chip memory and processes them in parallel. This strategy keeps memory usage bounded, enabling scalable solutions for models with larger input sizes.
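From user code, a fused attention kernel looks like a single call. As a minimal sketch (shapes and values below are illustrative), PyTorch's scaled_dot_product_attention computes softmax(QKᵀ/√d)·V in one operation and, on supported GPUs, can dispatch to a FlashAttention-style fused kernel:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, sequence length, head dimension)
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)

# One fused call: the scores, softmax, and weighted sum are computed inside a
# single kernel, so the full attention matrix is never written out to HBM.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```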
High Bandwidth Memory (HBM) offers large capacity, but its bandwidth is low compared to memory located on the chip itself. SRAM (Static Random-Access Memory), on the other hand, provides much faster access to data but is far smaller in capacity than HBM. On-chip SRAM, as the name suggests, sits directly on the GPU die next to the compute units, enabling much faster access than off-chip memory; on an NVIDIA A100, for example, HBM offers 40-80 GB at roughly 1.5-2 TB/s, while the roughly 20 MB of on-chip SRAM delivers around an order of magnitude more bandwidth.
In a standard attention implementation, such as the one in the original transformer, HBM is used to store, read, and write the keys, queries, and values used in the attention computation. The operations involved lead to frequent data transfers between HBM and on-chip SRAM: during computation, keys, queries, and values are loaded from HBM into on-chip SRAM for processing, and intermediate results and final outputs are written back to HBM after each step of the attention mechanism. This frequent movement of data between HBM and SRAM results in high overhead, with much of the time spent on data transfer rather than computation.
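A rough NumPy sketch of this pattern makes the intermediates explicit (single head, illustrative shapes). The score matrix S and the attention-weight matrix P below are exactly the N×N intermediates that a standard implementation writes out to HBM and reads back between steps:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive single-head attention; Q, K, V each have shape (N, d)."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # (N, N) score matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # (N, N) softmax numerator
    P = P / P.sum(axis=-1, keepdims=True)          # (N, N) attention weights
    return P @ V                                   # (N, d) output
```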
Flash Attention, by contrast, optimizes the movement of data between HBM and on-chip SRAM by reducing redundant reads and writes. Rather than writing intermediate results back to HBM after each individual attention step, it loads the keys, queries, and values once, combines or "fuses" the operations of the attention mechanism, and writes only the final results back to memory. This reduces the overall overhead and improves efficiency.
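The block-wise computation can be sketched in plain NumPy. The function below is a sketch under illustrative assumptions, not the actual CUDA kernel: it walks over the keys and values one tile at a time and maintains running softmax statistics (the so-called online softmax), so the full N×N attention matrix never exists at once:

```python
import numpy as np

def flash_like_attention(Q, K, V, block_size=128):
    """Tiled single-head attention with an online softmax (NumPy sketch)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))              # running (unnormalized) output
    row_max = np.full(N, -np.inf)     # running max of scores per query row
    row_sum = np.zeros(N)             # running softmax denominator per row

    for start in range(0, N, block_size):
        Kb = K[start:start + block_size]      # one tile of keys ("loaded into SRAM")
        Vb = V[start:start + block_size]      # matching tile of values
        S = (Q @ Kb.T) * scale                # scores against this tile only

        new_max = np.maximum(row_max, S.max(axis=-1))
        P = np.exp(S - new_max[:, None])        # softmax numerator for this tile
        correction = np.exp(row_max - new_max)  # rescale earlier partial results

        row_sum = row_sum * correction + P.sum(axis=-1)
        O = O * correction[:, None] + P @ Vb
        row_max = new_max

    return O / row_sum[:, None]               # final normalization
```

Run on the same random Q, K, and V, this should match the naive version above up to floating-point tolerance (e.g. via np.allclose), which is the point: the tiled version computes exactly the same attention, just without ever materializing the N×N intermediate. In a real Flash Attention kernel, the loop body runs inside a single fused GPU kernel, the tiles live in on-chip SRAM, and the queries are tiled across thread blocks as well.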
In summary, while standard attention mechanisms rely heavily on data movement between HBM and SRAM, Flash Attention minimizes that movement through kernel fusion, tiling, and more efficient memory usage, reducing overhead in both memory access and computation. The result is a tangible benefit in both training speed and inference latency.
Axolotl supports flash attention for open-source models like Llama-2 and Mistral. You can enable it by installing axolotl with the corresponding extra:
pip install axolotl[flash-attn]
Axolotl can be used for fine-tuning models on Hopsworks by simply installing it as a Python dependency in your project. Your fine-tuning data, stored on HopsFS-S3, is made available to Axolotl as local files through Hopsworks' built-in FUSE support.
Several model serving servers now support flash attention, including vLLM and Hugging Face's model serving server, and many more are expected to add support to speed up LLM inference.
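As a minimal illustration of one of these servers (a sketch with an illustrative model name, not a production setup), vLLM's offline Python API loads a model and selects its attention backend automatically, using a FlashAttention-based kernel on GPUs that support it:

```python
from vllm import LLM, SamplingParams

# Illustrative model; vLLM chooses its attention backend automatically,
# using a FlashAttention-based kernel on supported GPUs.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2")

outputs = llm.generate(
    ["Explain in one sentence why attention is memory bound."],
    SamplingParams(temperature=0.7, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```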
For an enterprise model serving solution with flash attention, Hopsworks comes with KServe support, including both the vLLM and Hugging Face model serving servers. This gives you the benefits of scale, low latency, logging, monitoring, and access control when serving LLMs at high performance.