This repository provides PAT (Prefix-Aware Attention), a high-performance CUTLASS-based implementation designed to optimize the decoding attention phase in transformer models.
PAT: Accelerating LLM Decoding via Prefix-Aware Attention with Resource Efficient Multi-Tile Kernel
Jinjun Yi†, Zhixin Zhao†, Yitao Hu*, Ke Yan, Weiwei Sun, Hao Wang, Laiping Zhao, Yuhao Zhang, Wenxin Li, Keqiu Li.
ACM International Conference on Architectural Support for Programming Languages and Operating Systems (ASPLOS), 2026
Features: PAT identifies complex shared prefix patterns within batched sequences and schedules shared prefixes into separate CTA computations. This approach significantly reduces KV cache reads, which is the primary bottleneck in attention computation during LLM decoding.
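To illustrate the idea (this is not PAT's actual scheduler; the grouping policy and names below are purely illustrative), the following Python sketch groups batched decoding requests by a shared prefix so that each shared prefix's KV cache would be read once per group rather than once per sequence:

```python
# Illustrative sketch only: group decode requests that share a prefix so the
# shared prefix's KV cache can be scheduled once per group. The real PAT
# kernel performs this scheduling at the CTA/tile level inside the attention
# kernel; this toy version only shows the batching intuition.
from collections import defaultdict

def group_by_shared_prefix(batch, prefix_len=16):
    """batch: list of (request_id, token_ids); returns prefix -> request ids."""
    groups = defaultdict(list)
    for req_id, tokens in batch:
        # Toy policy: treat the first `prefix_len` tokens as the shared prefix.
        groups[tuple(tokens[:prefix_len])].append(req_id)
    return groups

system_prompt = list(range(16))            # e.g., a shared system prompt
batch = [
    ("req0", system_prompt + [101, 102]),
    ("req1", system_prompt + [201, 202, 203]),
    ("req2", [7, 8, 9]),                   # no prefix shared with the others
]
for prefix, req_ids in group_by_shared_prefix(batch).items():
    # Attention over the shared prefix's KV blocks is computed once per group
    # and combined with each request's per-sequence suffix attention.
    print(f"prefix of {len(prefix)} tokens shared by {req_ids}")
```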
Usage: PAT serves as a plugin for LLM serving systems. To enable PAT in vLLM, setting a single environment variable, VLLM_ATTENTION_BACKEND="PREFIX_ATTN", is all that is required. Please refer to Install PAT from Source for detailed instructions on installing PAT for NVIDIA A100 and H100 GPUs.
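For offline inference, the same backend selection can be done programmatically. Below is a minimal sketch (the prompts and sampling settings are placeholders; VLLM_USE_V1=0 and the Qwen/Qwen3-8B model mirror the serving command shown later in this README):

```python
import os

# Select PAT as the attention backend before vLLM is initialized.
os.environ["VLLM_ATTENTION_BACKEND"] = "PREFIX_ATTN"
os.environ["VLLM_USE_V1"] = "0"  # matches the serving command used by this artifact

from vllm import LLM, SamplingParams

# Prefix caching lets PAT exploit the prompt shared across these requests.
llm = LLM(model="Qwen/Qwen3-8B", enable_prefix_caching=True, enforce_eager=True)

shared = "You are a helpful assistant specialized in GPU programming. "
prompts = [shared + q for q in ["What is CUDA?", "What is CUTLASS?"]]
for out in llm.generate(prompts, SamplingParams(max_tokens=64)):
    print(out.outputs[0].text)
```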
Below are the instructions to reproduce the experimental results presented in our paper "PAT: Accelerating LLM Decoding via Prefix-Aware Attention with Resource Efficient Multi-Tile Kernel". The repository is organized as follows:
- benchmark/: Contains scripts for kernel performance experiments and end-to-end serving performance experiments.
- csrc/: Contains the core implementation of PAT.
- plot/: Contains scripts for generating plots from experimental results.
- plugin/: Contains vLLM plugins to integrate PAT with vLLM.
- prefix_attn/: Contains the main Python package for PAT.
- test/: Contains unit tests for PAT.
To run these experiments, you will need:
- An x86-64 Linux host with at least 64GB RAM.
- 200GB of free disk space.
- An NVIDIA A100 GPU with 80GB of memory.
- NVIDIA driver >= 550 and CUDA >= 12.4.
We have tested the experiments on a Google Cloud a2-ultragpu-1g instance (200 GB disk) with the Deep Learning VM with CUDA 12.4 M129 system image. We recommend a similar setup for consistent performance.
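To quickly confirm that a machine meets these requirements, a short PyTorch-based check such as the following can be run (a sketch; it only reports what is visible to PyTorch):

```python
import torch

# Report the details that matter for this artifact: device name (expect an
# A100 80GB), total device memory, CUDA version, and number of visible GPUs.
assert torch.cuda.is_available(), "No CUDA device is visible"
props = torch.cuda.get_device_properties(0)
print("GPU:", props.name)
print("GPU memory (GiB):", round(props.total_memory / 1024**3, 1))
print("CUDA (seen by PyTorch):", torch.version.cuda)  # expect >= 12.4
print("Visible GPUs:", torch.cuda.device_count())
```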
Hint: You can use multiple GPUs (e.g., an a2-ultragpu-8g instance) to speed up the end-to-end performance experiments. The scripts will automatically detect the available GPUs and distribute the experiments across them.
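The distribution has roughly the following shape (an illustrative Python sketch, not the benchmark scripts themselves; the experiment names and the per-run script are placeholders): enumerate the visible GPUs and pin each experiment to one of them via CUDA_VISIBLE_DEVICES.

```python
import itertools
import torch

# Placeholder experiment list; the real configurations live in benchmark/.
experiments = ["workload_A", "workload_B", "workload_C", "workload_D"]

num_gpus = max(torch.cuda.device_count(), 1)
for exp, gpu in zip(experiments, itertools.cycle(range(num_gpus))):
    # Each run is pinned to one GPU, so runs on different GPUs can proceed
    # in parallel. run_one_experiment.sh is a hypothetical per-run entry point.
    print(f"CUDA_VISIBLE_DEVICES={gpu} bash run_one_experiment.sh {exp}")
```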
If Docker has not been installed, run the following commands:
curl -fsSL https://get.docker.com -o get-docker.sh
sh get-docker.sh
dockerd &

git clone https://github.com/flashserve/PAT.git
docker pull flashserve/pat:ae  # (~50 GB, including model weights)
docker run -it --gpus all -v ${PWD}/PAT:/workspace/PAT -w /workspace \
    --shm-size=64g flashserve/pat:ae /bin/bash

cd /workspace/PAT/test
python test.py

If the tests pass successfully, you should see the output: [INFO] successfully pass the test!
This script evaluates the attention kernel execution latency under synthetic workloads for different methods, as shown in Section 8.3 and Figure 10 of the paper.
cd /workspace/PAT/benchmark
# This experiment takes about 1.5 hours to complete
bash ./run_kernel_bench.sh

The results will be saved in kernel_perf.json.
This script evaluates end-to-end inference latency across different methods under real-world workloads, corresponding to Section 8.4 and Figure 11 in the paper. Note that completing all experiments requires over 60 GPU-hours, so we provide two scripts: (1) run_e2e_bench_part.sh: runs a subset of experiments (QPS=7, all workloads, all baselines) for quick verification of results; (2) run_e2e_bench_full.sh: runs all experiments to reproduce the results in the paper.
Hint: To run the full experiments (run_e2e_bench_full.sh), you can use multiple GPUs (e.g., an a2-ultragpu-8g instance) to speed up the experiments. The scripts will automatically detect the available GPUs and distribute the experiments across them.
cd /workspace/PAT/benchmark
# Quick verification (takes 4-5 GPU-hours to complete)
bash ./run_e2e_bench_part.sh
# Full experiments (takes over 60 GPU-hours to complete)
# bash ./run_e2e_bench_full.sh

The results will be saved in e2e_perf.jsonl.
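Since e2e_perf.jsonl is in JSON Lines format (one JSON object per line), the raw records can be inspected before plotting; the sketch below makes no assumption about the field names:

```python
import json

# Print the field names of the first few result records.
with open("e2e_perf.jsonl") as f:
    for i, line in enumerate(f):
        record = json.loads(line)
        print(sorted(record.keys()))
        if i >= 2:
            break
```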
cd /workspace/PAT/plot
python eval_kernel_perf.py --log-file ../benchmark/kernel_perf.json

This will generate a plot, fig/kernel_performance_overall.pdf, showing the kernel performance comparison among different methods, corresponding to Figure 10 in the paper.
cd /workspace/PAT/plot
python eval_e2e_from_jsonl.py --log-file ../benchmark/e2e_perf.jsonl

This will generate a plot, fig/eval_e2e_overall_p99.pdf, showing the end-to-end serving performance comparison among different methods, corresponding to Figure 11 in the paper.
Alternatively, PAT can be installed directly from source without using Docker. This setup supports both NVIDIA A100 and H100 GPUs, and we have validated PAT on both. Depending on hardware and network conditions, the full installation typically takes about 1–3 hours.
- Requirements: A100 / H100 GPU, CUDA >= 12.4
- Clone PAT, vLLM, and CUTLASS repositories
mkdir ~/workspace && cd ~/workspace
git clone https://github.com/flashserve/PAT.git
git clone https://github.com/NVIDIA/cutlass.git
git clone https://github.com/vllm-project/vllm.git

- Build vLLM with PAT plugin (1-2 hours)
cd ~/workspace/vllm
git checkout v0.9.0
# Add PAT plugin to vLLM
rsync -av --progress ../PAT/plugin/vllm/ ./vllm/
# "8.0" targets A100 (compute capability 8.0); use "9.0" for H100
TORCH_CUDA_ARCH_LIST="8.0" pip install .

- Install FlashInfer and other dependencies
pip install flashinfer-python==0.2.5 transformers==4.53.0 numpy==1.24.0

- Build PAT from source (~10 minutes)
# PyTorch (2.7.0) is required if vLLM is not installed
cd ~/workspace/PAT
# Replace <abs_path_to_cutlass> with the absolute path to CUTLASS repo
CUTLASS_ROOT=<abs_path_to_cutlass> pip install . --no-build-isolation
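After the build completes, the install can be sanity-checked by importing the package (assuming it is exposed under the name prefix_attn, matching the repository's Python package directory):

```python
# A successful import indicates the extension was built and installed.
import prefix_attn
print("PAT installed at:", prefix_attn.__file__)
```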
- Launch vLLM with PAT

VLLM_ATTENTION_BACKEND="PREFIX_ATTN" VLLM_USE_V1=0 \
    vllm serve Qwen/Qwen3-8B --enable-prefix-caching --enforce-eager

If you use this codebase or otherwise find our work valuable, please cite:
@inproceedings{yi2026pat,
title={PAT: Accelerating LLM Decoding via Prefix-Aware Attention with Resource Efficient Multi-Tile Kernel},
author={Yi, Jinjun and Zhao, Zhixin and Hu, Yitao and Yan, Ke and Sun, Weiwei and Wang, Hao and Zhao, Laiping and Zhang, Yuhao and Li, Wenxin and Li, Keqiu},
booktitle={Proceedings of the 31st ACM International Conference on Architectural Support for Programming Languages and Operating Systems},
year={2026}
}

