(Vaswani et al., 2023)

Attention operates on a sequence of query $Q$, key $K$ and value $V$ vectors. The attention output for a sequence is then computed as:

$$A(Q, K, V) = \text{softmax}\left(\frac{Q K^{T}}{\sqrt{d}}\right)V \quad \text{ for } Q_{L \times d}, K_{L \times d}, V_{L \times d}$$
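A minimal NumPy sketch of the formula above (a single head, no masking); the sizes $L = 4$ and $d = 8$ are arbitrary choices for the example:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row max keeps exp() stable and does not change the result.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """A(Q, K, V) = softmax(Q K^T / sqrt(d)) V for Q, K, V of shape (L, d)."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)        # (L, L) attention logits
    return softmax(scores, axis=-1) @ V  # (L, d) weighted sum of values

rng = np.random.default_rng(0)
L, d = 4, 8
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))
out = attention(Q, K, V)
print(out.shape)  # (4, 8)
```

Each row of the softmax sums to 1, so every output row is a convex combination of the rows of $V$.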

Following (Elhage et al., 2021), the contribution of attention head $h$ in layer $l$ (one of the multiple heads making up the attention layer) at position $i$ can be written as:

$$\text{Attn}^{\vec{l,h}}(X_{\leq i}^{l-1}) = \sum_{j \leq i} a^{l,h}_{i,j}\, x^{l-1}_j W^{l,h}_{V} W_{O}^{l,h}$$

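where $W_{V}^{l,h} \in \mathbb{R}^{d \times d_h}$ and $W_{O}^{l,h} \in \mathbb{R}^{d_h \times d}$ are learnable weight matrices, $d_h$ is the dimension per head, and $a^{l,h}_{i,j}$ is the attention weight from position $i$ to position $j$. Together, $W_{V}^{l,h} W_{O}^{l,h} \in \mathbb{R}^{d \times d}$ form the combined OV matrix.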
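A NumPy check of the per-head formula. For illustration only, it assumes the pattern $a_{i,j}$ comes from causal softmax attention with hypothetical query/key weights `W_Q`, `W_K`; summing $a_{i,j}\, x_j W_V W_O$ explicitly gives the same result as applying the combined OV matrix once:

```python
import numpy as np

rng = np.random.default_rng(0)
L, d, d_h = 5, 8, 4
X = rng.standard_normal((L, d))  # residual stream x_j^{l-1}, one row per position
W_Q, W_K, W_V = (rng.standard_normal((d, d_h)) for _ in range(3))  # hypothetical weights
W_O = rng.standard_normal((d_h, d))

# Causal attention pattern a_{i,j}: softmax over positions j <= i.
scores = (X @ W_Q) @ (X @ W_K).T / np.sqrt(d_h)
scores = np.where(np.tril(np.ones((L, L), dtype=bool)), scores, -np.inf)
A = np.exp(scores - scores.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)

# Literal sum from the formula: sum_{j <= i} a_{i,j} x_j W_V W_O
explicit = np.stack([sum(A[i, j] * X[j] @ W_V @ W_O for j in range(i + 1)) for i in range(L)])

# Same computation through the combined d x d OV matrix.
W_OV = W_V @ W_O
combined = A @ X @ W_OV
print(np.allclose(explicit, combined))  # True
```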

Multi-head Attention

Allows the model to jointly attend to information from different representation subspaces at different positions:

$$\begin{aligned} \text{MHA}(Q,K,V) &= \text{concat}(\text{head}_1, \ldots, \text{head}_h) W^O \\ \text{where } \text{head}_i &= A(QW_i^Q, KW_i^K, VW_i^V) \\ W_i^Q, W_i^K &\in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v} \\ W^O &\in \mathbb{R}^{hd_v \times d_{\text{model}}} \end{aligned}$$
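A NumPy sketch of the multi-head formula, assuming $d_k = d_v = d_{\text{model}} / h$ and stacking the per-head projections into arrays of shape $(h, d_{\text{model}}, d_k)$ (a layout chosen for this example, not taken from the source):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha(Q, K, V, W_q, W_k, W_v, W_o):
    """W_q, W_k: (h, d_model, d_k); W_v: (h, d_model, d_v); W_o: (h * d_v, d_model)."""
    heads = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        q, k, v = Q @ Wq_i, K @ Wk_i, V @ Wv_i  # per-head projections
        heads.append(softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v)  # (L, d_v)
    return np.concatenate(heads, axis=-1) @ W_o  # (L, h * d_v) -> (L, d_model)

rng = np.random.default_rng(0)
L, d_model, h = 3, 8, 2
d_k = d_v = d_model // h
X = rng.standard_normal((L, d_model))
W_q = rng.standard_normal((h, d_model, d_k))
W_k = rng.standard_normal((h, d_model, d_k))
W_v = rng.standard_normal((h, d_model, d_v))
W_o = rng.standard_normal((h * d_v, d_model))
out = mha(X, X, X, W_q, W_k, W_v, W_o)  # self-attention: Q = K = V = X
print(out.shape)  # (3, 8)
```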

Grouped-Query Attention

by (Ainslie et al., 2023)

idea: reduce the number of KV heads from $n_q$ (one per query head, as in multi-head attention) to $n_k = n_q / r$, by evenly dividing the $n_q$ query heads into $n_k$ groups of $r$ heads that each share a single key/value head.
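A NumPy sketch of grouped-query attention, assuming the $r = n_q / n_k$ query heads of a group are consecutive (query head $i$ uses KV head $\lfloor i / r \rfloor$) and leaving out the input/output projections and causal mask. Only the $n_k$ key/value heads need to be stored, which is what shrinks the KV cache:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (n_q, L, d_h); k, v: (n_k, L, d_h), with n_q divisible by n_k."""
    n_q, n_k = q.shape[0], k.shape[0]
    if n_q % n_k:
        raise ValueError("query heads must split evenly into KV groups")
    r = n_q // n_k  # query heads per group
    # Each KV head is shared by the r consecutive query heads of its group.
    k = np.repeat(k, r, axis=0)  # (n_q, L, d_h)
    v = np.repeat(v, r, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # (n_q, L, L)
    return softmax(scores) @ v  # (n_q, L, d_h)

rng = np.random.default_rng(0)
n_q, n_k, L, d_h = 8, 2, 5, 4
q = rng.standard_normal((n_q, L, d_h))
k = rng.standard_normal((n_k, L, d_h))  # only n_k KV heads are cached
v = rng.standard_normal((n_k, L, d_h))
out = grouped_query_attention(q, k, v)
print(out.shape)  # (8, 5, 4)
```

With $n_k = n_q$ this reduces to standard multi-head attention, and with $n_k = 1$ to multi-query attention.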

RadixAttention

Implemented in (Zheng et al., 2024), which keeps the KV cache of all requests in a radix tree and uses an LRU eviction policy to retain the relevant entries.

radix tree setup:

  • key: sequence of tokens
  • value: KV cache tensor (stored in GPU in a paged layout)

dynamic evolution of the radix tree in response to various requests.
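A simplified sketch of such a cache, assuming one token per edge (a real radix tree merges chains of single-child nodes into one edge) and counting tokens instead of storing KV tensors; `PrefixCache` and its methods are hypothetical names, loosely modelled on the operations used in Algorithm 1:

```python
import itertools

class Node:
    def __init__(self, parent=None, token=None):
        self.parent, self.token = parent, token
        self.children = {}    # token -> child Node
        self.last_access = 0  # logical timestamp used for LRU
        # A real system would also keep this token's KV-cache slot (page index) here.

class PrefixCache:
    """Token-level prefix tree (one token per edge) with LRU eviction of leaves."""

    def __init__(self, capacity):
        self.root = Node()
        self.capacity = capacity  # max number of cached tokens
        self.size = 0
        self.clock = itertools.count(1)

    def match_prefix(self, tokens):
        """Number of leading tokens whose KV cache is already in the tree."""
        node, matched, now = self.root, 0, next(self.clock)
        for t in tokens:
            if t not in node.children:
                break
            node = node.children[t]
            node.last_access = now
            matched += 1
        return matched

    def insert(self, tokens):
        """Cache a finished request's tokens, then evict LRU leaves if over capacity."""
        node, now = self.root, next(self.clock)
        for t in tokens:
            if t not in node.children:
                node.children[t] = Node(node, t)
                self.size += 1
            node = node.children[t]
            node.last_access = now
        self._evict()

    def _leaves(self):
        stack, leaves = [self.root], []
        while stack:
            node = stack.pop()
            if node.children:
                stack.extend(node.children.values())
            elif node is not self.root:
                leaves.append(node)
        return leaves

    def _evict(self):
        # Every access also refreshes all ancestors, so no node is older than its
        # descendants: evicting the oldest leaf follows LRU and keeps the tree valid.
        while self.size > self.capacity:
            victim = min(self._leaves(), key=lambda n: n.last_access)
            del victim.parent.children[victim.token]
            self.size -= 1

cache = PrefixCache(capacity=8)
cache.insert([1, 2, 3, 4])               # e.g. shared system prompt + question A
print(cache.match_prefix([1, 2, 3, 9]))  # 3: the shared prefix is reused
cache.insert([1, 2, 3, 9])               # only token 9 is new; size 5
cache.insert([7, 8, 9, 10])              # size 9 > 8, so the LRU leaf (token 4) is evicted
```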

cache-aware scheduling

We define the hit rate as

$$\begin{aligned} \text{hit rate} &= \frac{\sum_{r \in R} \text{number of cached prefill tokens in } r}{\sum_{r \in R} \text{number of prefill tokens in } r} \\[8pt] &= 1 - \frac{C}{\sum_{r \in R} \text{number of prefill tokens in } r} \end{aligned}$$

where $R$ is the set of requests and $C$ is the number of prefill tokens that actually have to be computed (i.e. are not served from the cache).
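A small sketch of the two equivalent forms of the hit rate; the per-request token counts are made up for illustration:

```python
def hit_rate(requests):
    """requests: one (cached_prefill_tokens, prefill_tokens) pair per request r in R."""
    total = sum(n for _, n in requests)
    cached = sum(c for c, _ in requests)
    C = total - cached  # prefill tokens that actually had to be computed
    return 1 - C / total  # equals cached / total

# Hypothetical batch: one cold request, then two that reuse an 80-token shared prefix.
print(hit_rate([(0, 100), (80, 120), (80, 90)]))  # 160 cached out of 310 prefill tokens
```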

in batch settings: instead of a FIFO schedule, sort the waiting requests by matched prefix length and prioritise those with longer matched prefixes.
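A minimal sketch of longest-shared-prefix-first ordering, assuming the cache contents are given as plain token lists (a stand-in for matching against the radix tree); `matched_prefix_len` and `schedule` are hypothetical helpers:

```python
def matched_prefix_len(tokens, cached_sequences):
    """Longest prefix of `tokens` shared with any cached sequence."""
    best = 0
    for seq in cached_sequences:
        n = 0
        while n < min(len(tokens), len(seq)) and tokens[n] == seq[n]:
            n += 1
        best = max(best, n)
    return best

def schedule(waiting, cached_sequences):
    """Longest matched prefix first; sorted() is stable, so ties keep FIFO order."""
    return sorted(waiting, key=lambda req: -matched_prefix_len(req, cached_sequences))

cached = [[1, 2, 3, 4, 5]]
waiting = [[9, 9], [1, 2, 7], [1, 2, 3, 4, 6], [1, 8]]  # FIFO arrival order
print(schedule(waiting, cached))  # [[1, 2, 3, 4, 6], [1, 2, 7], [1, 8], [9, 9]]
```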

Algorithm 1 Cache-Aware Scheduling
Input: Radix tree T, Memory pool P.
Input: current running batch B, waiting queue Q.
Output: Finished requests and updated system state.
// Get all requests from the waiting queue
requests ← Q.get_all_requests()
// Search for prefix matching for all waiting requests
for req ∈ requests do
  req.prefix_node, req.prefix_len ← T.match_prefix(req.input_tokens)
end for
// Sort the requests by matched prefix length (longest first)
requests.sort()
// Select requests for the next batch
available_size ← T.evictable_size() + P.available_size()
current_size ← 0
new_batch ← []
for req ∈ requests do
  if req.size() + current_size ≤ available_size then
    new_batch.append(req)
    current_size ← current_size + req.size()
    δ ← T.increase_ref_counter(req.prefix_node)
    available_size ← available_size + δ
  end if
end for
Q.remove_requests(new_batch)
// Insert requests into the current running batch
B.merge(new_batch)
// Allocate new memory and do eviction if necessary
needed_size ← B.needed_size()
success, buffer ← P.alloc(needed_size)
if ¬success then
  T.evict(needed_size)
  success, buffer ← P.alloc(needed_size)
end if
B.run(buffer)
// Process finished requests
finished_requests ← B.drop_finished_requests()
for req ∈ finished_requests do
  T.decrease_ref_counter(req.prefix_node)
  T.insert(req)
end for
return finished_requests
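A Python sketch of the admission part of the algorithm (sorting plus the size check), assuming `req.size()` is the number of uncached prompt tokens and ignoring the reference-counting adjustment $\delta$; the `Request` class and `select_batch` are hypothetical:

```python
from dataclasses import dataclass

@dataclass
class Request:
    rid: str
    prefix_len: int  # tokens already cached (result of T.match_prefix)
    input_len: int   # total prompt tokens

    def size(self):
        # Assumption: only the uncached suffix needs new KV memory.
        return self.input_len - self.prefix_len

def select_batch(requests, available_size):
    new_batch, current_size = [], 0
    # Longest matched prefix first; sorted() is stable, so ties keep arrival order.
    for req in sorted(requests, key=lambda r: r.prefix_len, reverse=True):
        if req.size() + current_size <= available_size:
            new_batch.append(req)
            current_size += req.size()  # account for memory already promised
    return new_batch

waiting = [Request("a", 0, 50), Request("b", 40, 60), Request("c", 90, 100), Request("d", 10, 80)]
print([r.rid for r in select_batch(waiting, available_size=40)])  # ['c', 'b']
```

Requests `d` and `a` stay in the queue because their uncached suffixes (70 and 50 tokens) no longer fit in the remaining budget.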

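Every token on every edge $e$ of the radix tree $T$ of the requests (an edge holds $|e|$ tokens) must be computed at least once, so the computational cost $C$ has the lower bound: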

$$C \ge \sum_{e \in \text{edges}(T)} |e|$$

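Consider visiting the radix tree $T$ in DFS order. For each edge $e$ of $T$, the first time we compute the KV cache associated with $e$, we go on to compute the whole subtree of $e$.

During the computation of the subtree of $e$, edge $e$ is continuously hit, so no additional computation happens for it. Provided $e$ is not evicted in the meantime, each edge is therefore computed exactly once, and DFS order attains the lower bound $C = \sum_{e} |e|$, i.e. the maximum hit rate.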

cache hit

With a cache size $\ge$ the maximum request length (which equals the longest path in the radix tree), edge $e$ will not be evicted during the computation of its subtree, since the common prefix of the subtree (which includes $e$) is continuously hit.

We can show by induction that longest-shared-prefix-first order is equivalent to DFS order (note 1).
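A small simulation of the argument, assuming a token-level cache with plain LRU eviction (a simplification of RadixAttention, which evicts radix-tree leaves) and using lexicographic sorting as one DFS order of the tree. With cache size equal to the maximum request length, DFS order computes exactly $\sum_e |e|$ tokens, while an interleaved FIFO order recomputes the shared prefixes:

```python
from collections import OrderedDict

def prefill_cost(requests, capacity):
    """Tokens computed when serving `requests` in the given order with an LRU token cache.

    Each cache entry is a prefix tuple (one per token position). Matching stops at the
    first miss, as in radix-tree prefix matching, so orphaned entries are never hit.
    """
    cache, computed = OrderedDict(), 0
    for req in requests:
        prefixes = [tuple(req[: i + 1]) for i in range(len(req))]
        hit = 0
        while hit < len(prefixes) and prefixes[hit] in cache:
            cache.move_to_end(prefixes[hit])  # refresh recency
            hit += 1
        for p in prefixes[hit:]:
            computed += 1
            cache[p] = True
        while len(cache) > capacity:
            cache.popitem(last=False)  # evict the least recently used entry
    return computed

def edge_token_sum(requests):
    """sum of |e| over the radix tree's edges = number of distinct non-empty prefixes."""
    return len({tuple(r[: i + 1]) for r in requests for i in range(len(r))})

reqs = [[1, 2, 3], [4, 5, 6], [1, 2, 7], [4, 5, 8]]  # FIFO arrival order
cap = max(map(len, reqs))                           # cache size = max request length
print(edge_token_sum(reqs), prefill_cost(sorted(reqs), cap), prefill_cost(reqs, cap))
# 8 8 12
```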

Compressed FSM for jump-forward decoding

Implemented in (Zheng et al., 2024)

Method 1: FSM-based decoding

  • intuition: use an FSM (Willard & Louf, 2023) to guide generation, applying a logit bias so that only tokens conforming to the given JSON schema can be sampled. The FSM tracks the current state during decoding, and invalid tokens are filtered out through the logit bias applied to the output.

  • limitation: since the FSM is constructed at the token level, it can only transition the state by one token at a time, resulting in slow decoding.
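A toy illustration of FSM-guided decoding, assuming a hand-written character-level DFA for the regex `"[a-z]+"`, a five-token vocabulary and a scripted stand-in for the LM (all hypothetical). Invalid tokens get a logit of $-\infty$, and every generated token costs one model step:

```python
import math
import string

EOS = "<eos>"
VOCAB = ['"', "he", "llo", '",', EOS]
ACCEPTING = {3}

def step(state, ch):
    """DFA for "[a-z]+": 0 expects the opening quote, 1 the first letter,
    2 is inside the string, 3 means the string is closed."""
    if state == 0 and ch == '"':
        return 1
    if state in (1, 2) and ch in string.ascii_lowercase:
        return 2
    if state == 2 and ch == '"':
        return 3
    return None  # dead state: the character violates the pattern

def advance(state, token):
    """Run a whole token through the DFA; None if any character is rejected."""
    for ch in token:
        state = step(state, ch)
        if state is None:
            return None
    return state

def allowed(state, tok):
    if tok == EOS:
        return state in ACCEPTING  # may only stop once the pattern is complete
    return advance(state, tok) is not None

def toy_model(prefix):
    """Stand-in for an LM: strongly prefers one scripted token per step."""
    script = ['"', "he", "llo", '",', EOS]
    wanted = script[min(len(prefix), len(script) - 1)]
    return [5.0 if tok == wanted else (1.0 if tok == '"' else 0.0) for tok in VOCAB]

def guided_decode(max_steps=10):
    state, out = 0, []
    for _ in range(max_steps):  # one forward pass per generated token
        logits = [lg if allowed(state, tok) else -math.inf
                  for tok, lg in zip(VOCAB, toy_model(out))]
        tok = VOCAB[max(range(len(VOCAB)), key=logits.__getitem__)]
        if tok == EOS:
            break
        out.append(tok)
        state = advance(state, tok)
    return out, state

tokens, state = guided_decode()
print("".join(tokens), state)  # "hello" 3
```

At the fourth step the scripted model prefers `",`, but that token is masked because the comma has no valid transition after the closing quote, so `"` is emitted instead.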

Method 2: Interleaved-based

  • intuition: break the JSON schema down into segments, each either a chunked-prefill part or a constrained-decoding part, which the inference system then executes interleaved. This is faster than per-token decoding because the chunked-prefill segments can process multiple tokens per forward pass.

    See also https://github.com/guidance-ai/guidance#guidance-acceleration, which uses llama.cpp as a backend.

  • limitation:

    • the interleaved approach requires custom syntax, making it less expressive than regex.
    • it struggles with tokenization boundaries because decode and chunked-prefill segments can conflict.
    • frequent communication between the interpreter and the back-end adds overhead.

Method 3: Jump-Forward Decoding with compressed FSM
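As we understand the SGLang paper, the idea is to compress chains of singular transitions in the FSM (states with exactly one allowed continuation) into single edges, so that a whole deterministic string can be appended at once instead of being decoded token by token. A character-level sketch under that reading; `build_dfa`, the segment format and the scripted "model" characters are hypothetical, and a real system works on tokens and must handle tokenization boundaries:

```python
import string

def build_dfa(segments):
    """Char-level DFA for a sequence of ("lit", text) / ("plus", charset) segments.
    Returns (transitions, final_state); transitions[s] maps a char to the next state."""
    trans, state = {0: {}}, 0
    for kind, arg in segments:
        if kind == "lit":  # a literal becomes a chain of singular transitions
            for ch in arg:
                nxt = len(trans)
                trans[nxt] = {}
                trans[state][ch] = nxt
                state = nxt
        else:  # "plus": one or more characters from the charset
            nxt = len(trans)
            trans[nxt] = {ch: nxt for ch in arg}  # self-loop for repetition
            for ch in arg:
                trans[state][ch] = nxt
            state = nxt
    return trans, state

def singular_run(trans, state):
    """Follow transitions while exactly one next character is allowed."""
    chars = []
    while len(trans[state]) == 1:
        ch, nxt = next(iter(trans[state].items()))
        if nxt == state:  # a one-char self-loop still leaves a choice (repeat or stop)
            break
        chars.append(ch)
        state = nxt
    return "".join(chars), state

def compress(trans):
    """Compressed FSM: each state starting a singular chain maps to (string, end_state)."""
    return {s: singular_run(trans, s) for s in trans if len(trans[s]) == 1}

def decode(trans, compressed, model_chars):
    """Alternate jump-forward steps with (simulated) constrained model steps."""
    out, state, model_steps = "", 0, 0
    chars = iter(model_chars)
    while True:
        run, state = compressed.get(state, ("", state))
        out += run  # appended in one go, no sampling needed
        if not trans[state]:  # nothing can follow: the pattern is complete
            return out, model_steps
        ch = next(chars)  # stand-in for one constrained decoding step
        assert ch in trans[state], f"invalid character {ch!r}"
        out, state, model_steps = out + ch, trans[state][ch], model_steps + 1

segments = [("lit", '{"name": "'), ("plus", string.ascii_lowercase), ("lit", '"}')]
trans, final = build_dfa(segments)
text, steps = decode(trans, compress(trans), 'bob"')
print(text, steps)  # {"name": "bob"} 4
```

Only 4 of the 15 output characters need a model step; the other 11 come from the compressed edges.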

tokenization boundary handling

During decoding, the model often emits several characters combined into a single token.

For example, when decoding "Hello" as a JSON string, the LLM might output the tokens `"`, `He`, `llo`, `",`.

This causes strange behaviour when the closing `"` is combined with `,`: under the regex `"[\w\d\s]*"`, the token `",` is not valid even when the LM wants to stop, which can lead to endless decoding.

Fix:

  • implement a re-tokenization mechanism during the jump-forward phase (append the string instead of the tokens, then re-tokenize the entire text) → adds approximately 4% overhead
  • use one comprehensive regex to guide the decoding phase instead of multiple concatenated regexes (note 2)
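A toy illustration of the re-tokenization fix, assuming a hypothetical greedy longest-match tokenizer over a tiny vocabulary (real BPE tokenizers differ in detail). Appending the jump-forward string as separate tokens leaves a `"` / `,` boundary the model would not have produced, whereas re-tokenizing the full text recovers the merged `",` token:

```python
VOCAB = ['"', ",", '",', "He", "llo"]

def tokenize(text):
    """Greedy longest-match tokenization over the toy vocabulary."""
    tokens, i = [], 0
    while i < len(text):
        match = max((t for t in VOCAB if text.startswith(t, i)), key=len, default=None)
        if match is None:
            raise ValueError(f"cannot tokenize at position {i}: {text[i:]!r}")
        tokens.append(match)
        i += len(match)
    return tokens

def jump_forward_retokenize(prefix_tokens, forced_string):
    """Append the jump-forward string as text, then re-tokenize the whole text."""
    return tokenize("".join(prefix_tokens) + forced_string)

prefix = ['"', "He", "llo", '"']          # tokens decoded so far
print(prefix + tokenize(","))             # naive: ['"', 'He', 'llo', '"', ',']
print(jump_forward_retokenize(prefix, ","))  # ['"', 'He', 'llo', '",']
```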

RingAttention

RazorAttention

Paged Attention

by (Kwon et al., 2023)

Used in conjunction with continuous batching; implemented in vLLM.

Reduces KV-cache memory waste by managing the cache in fixed-size blocks, with a block manager similar to virtual memory in an OS; blocks can also be swapped in and out of GPU memory.

Essentially, it is a form of paging: the KV cache of each sequence is partitioned into KV blocks, so the keys and values of a sequence do not need to be stored in contiguous memory.
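A minimal sketch of the block-table idea, assuming a single pool of `num_blocks` physical blocks with `block_size` token slots each; `BlockAllocator` and its methods are hypothetical names, not vLLM's API. The logical blocks of a sequence are filled in order but can map to any free physical block:

```python
class BlockAllocator:
    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks))[::-1]  # free physical block ids (pop -> 0 first)
        self.tables = {}   # seq_id -> physical block ids, in logical order
        self.lengths = {}  # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve a KV slot for one more token; returns (physical_block, offset)."""
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:  # last block full (or none yet): take a new one
            if not self.free:
                raise MemoryError("no free KV blocks; a real system would preempt or swap")
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def release(self, seq_id):
        """Return a finished sequence's blocks to the free pool."""
        self.free.extend(self.tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = BlockAllocator(num_blocks=4, block_size=2)
print([alloc.append_token("a") for _ in range(3)])  # [(0, 0), (0, 1), (1, 0)]
print(alloc.append_token("b"))                      # (2, 0)
alloc.release("a")                                  # blocks 0 and 1 become free again
print(alloc.append_token("c"))                      # (1, 0): reuses a freed block
```

Waste is limited to the unused slots of each sequence's last block, since no contiguous region has to be reserved up front.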

Another optimization is to use KV compression to reduce the size of the KV cache for longer context.

Given:

  • each block contains the KV vectors for a fixed number of tokens, the block size $B$.
  • Key block $K_j = (k_{(j-1)B+1}, \ldots, k_{jB})$
  • Value block $V_j = (v_{(j-1)B+1}, \ldots, v_{jB})$
$$A_{ij} = \frac{\exp(q_i^T K_j / \sqrt{d})}{\sum_{t=1}^{\lceil i/B \rceil} \exp(q_i^T K_t \mathbf{1} / \sqrt{d})}, \quad o_i = \sum_{j=1}^{\lceil i/B \rceil} V_j A_{ij}^T$$

where $A_{ij} = (a_{i,(j-1)B+1}, \ldots, a_{i,jB})$ is the row vector of attention scores on the $j$-th KV block.
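A NumPy check of the block-wise computation, assuming 1-indexed tokens, a causal setting where token $i$ attends to tokens $1..i$, and keys/values kept in one array instead of being gathered from physical blocks. The code iterates over $\lceil i/B \rceil$ blocks, so the last, possibly partial, block is included; the result matches ordinary attention over the first $i$ tokens:

```python
import numpy as np

def paged_attention_row(q_i, K, V, i, B):
    """o_i computed block by block; block j holds 1-indexed tokens (j-1)B+1 .. jB."""
    d = q_i.shape[0]
    n_blocks = -(-i // B)  # ceil(i / B)
    bounds = [((j - 1) * B, min(j * B, i)) for j in range(1, n_blocks + 1)]  # 0-indexed rows
    scores = [np.exp(K[lo:hi] @ q_i / np.sqrt(d)) for lo, hi in bounds]  # exp(q_i^T K_j / sqrt d)
    denom = sum(s.sum() for s in scores)  # sum over blocks t of exp(q_i^T K_t / sqrt d) 1
    o = np.zeros_like(q_i)
    for (lo, hi), s in zip(bounds, scores):
        o += V[lo:hi].T @ (s / denom)  # V_j A_ij^T
    return o

def dense_attention_row(q_i, K, V, i):
    """Reference: standard softmax attention of q_i over the first i tokens."""
    s = K[:i] @ q_i / np.sqrt(q_i.shape[0])
    w = np.exp(s - s.max())
    return V[:i].T @ (w / w.sum())

rng = np.random.default_rng(0)
K, V = rng.standard_normal((7, 4)), rng.standard_normal((7, 4))
q = rng.standard_normal(4)
print(all(np.allclose(paged_attention_row(q, K, V, i, B=3), dense_attention_row(q, K, V, i))
          for i in range(1, 8)))  # True
```

Like the formula, the sketch does not subtract the maximum score before exponentiating; a production kernel would, for numerical stability.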

Bibliography

  • Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv preprint arXiv:2305.13245.
  • Elhage, N., Nanda, N., Olsson, C., Henighan, T., Joseph, N., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., DasSarma, N., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., … Olah, C. (2021). A Mathematical Framework for Transformer Circuits. Transformer Circuits Thread.
  • Kwon, W., Li, Z., Zhuang, S., Sheng, Y., Zheng, L., Yu, C. H., Gonzalez, J. E., Zhang, H., & Stoica, I. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles.
  • Liu, H., Zaharia, M., & Abbeel, P. (2023). Ring Attention with Blockwise Transformers for Near-Infinite Context. arXiv preprint arXiv:2310.01889.
  • Tang, H., Lin, Y., Lin, J., Han, Q., Hong, S., Yao, Y., & Wang, G. (2024). RazorAttention: Efficient KV Cache Compression Through Retrieval Heads. arXiv preprint arXiv:2407.15891.
  • Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2023). Attention Is All You Need. arXiv preprint arXiv:1706.03762.
  • Zheng, L., Yin, L., Xie, Z., Sun, C., Huang, J., Yu, C. H., Cao, S., Kozyrakis, C., Stoica, I., Gonzalez, J. E., Barrett, C., & Sheng, Y. (2024). SGLang: Efficient Execution of Structured Language Model Programs. arXiv preprint arXiv:2312.07104.

Notes

  1. Base case: a random request corresponding to node $x \in T$ is processed.

    • All requests corresponding to nodes $\{v_1, \ldots, v_n\}$ on the path from $x$ to the root do not need recomputation.
    • Thus, the computation cost for the requests of nodes $\{v_1, \ldots, v_n, x\}$ is aligned with DFS.

    Induction step: assume we visit node $y \in T$ and the nodes visited so far align with DFS order. Let $P$ denote the path from $y$ to the root.

    • Each node that has not been visited has its lowest common ancestor with the visited nodes on $P$.
    • Since the nodes on $P$ are cached, the not-yet-visited node $z$ whose lowest common ancestor on $P$ is deepest has the longest shared prefix.
    • The longest-shared-prefix-first order will therefore select $z$, which is a valid DFS step. q.e.d.