
Attention operates on sequences of query $Q$, key $K$, and value $V$ vectors. The attention output of a sequence is then computed as (Vaswani et al., 2023):

$$A(Q, K, V) = \text{softmax}\left(\frac{Q K^{T}}{\sqrt{d}}\right)V \quad \text{for } Q, K, V \in \mathbb{R}^{L \times d}$$
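A minimal NumPy sketch of this scaled dot-product attention (single head, no masking):

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                       # (L, L) similarity logits
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)      # row-wise softmax
    return weights @ V                                  # (L, d) attention output
```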

equivalent

We can equivalently rewrite the attention function (composed of multiple attention heads) following Elhage et al. (2021):

$$\text{Attn}^{l,h}(X_{\leq i}^{l-1}) = \sum_{j \leq i} a^{l,h}_{i,j}\, x^{l-1}_j W^{l,h}_{V} W_{O}^{l,h}$$

where the learnable weight matrices $W_{V}^{l,h} \in \mathbb{R}^{d \times d_h}$ and $W_{O}^{l,h} \in \mathbb{R}^{d_h \times d}$ ($d_h$ being the dimension per head) combine into the OV matrix $W_{V}^{l,h} W_{O}^{l,h}$.

Multi-head Attention

Allows the model to jointly attend to information from different representation subspaces at different positions:

$$\begin{aligned} \text{MHA}(Q,K,V) &= \text{concat}(\text{head}_1, \cdots, \text{head}_n) W^O \\ \text{where head}_i &= A(QW_i^Q, KW_i^K, VW_i^V) \end{aligned}$$

with $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$, and $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$.
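A sketch of multi-head attention reusing the `attention` helper above; the per-head projection lists and their shapes are illustrative:

```python
import numpy as np

def mha(Q, K, V, Wq, Wk, Wv, Wo):
    """Wq, Wk: lists of (d_model, d_k); Wv: list of (d_model, d_v); Wo: (h * d_v, d_model)."""
    # project inputs per head, run single-head attention, then concatenate and mix with Wo
    heads = [attention(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(Wq, Wk, Wv)]
    return np.concatenate(heads, axis=-1) @ Wo
```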

Grouped-Query Attention

(Ainslie et al., 2023)

idea: reduce the number of KV heads $n_k$ to a fraction $n_k' = \frac{n_q}{k}$ of the number of query heads $n_q$, by evenly dividing the $n_q$ query heads into $n_k'$ groups of $r$ heads that each share a single KV head
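A sketch of the grouping, reusing the `attention` helper from above: each KV head is shared by a contiguous group of query heads.

```python
import numpy as np

def gqa(q_heads, k_heads, v_heads):
    """q_heads: n_q arrays of shape (L, d_h); k_heads, v_heads: n_k' arrays of shape (L, d_h)."""
    group_size = len(q_heads) // len(k_heads)     # query heads per KV head
    outs = []
    for i, q in enumerate(q_heads):
        g = i // group_size                       # KV group shared by this query head
        outs.append(attention(q, k_heads[g], v_heads[g]))
    return np.concatenate(outs, axis=-1)
```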

RadixAttention

Zheng et al. (2024) propose maintaining the KV cache of all requests in a radix tree with an LRU eviction policy, so that relevant prefixes stay cached across requests. Implemented in sgl-project/sglang.

radix tree setup:

  • key: sequence of tokens
  • value: KV cache tensor (stored in GPU in a paged layout)

The radix tree evolves dynamically in response to incoming requests.
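A toy sketch of the radix-tree bookkeeping described above (token-sequence keys, cached KV handles, LRU timestamps, reference counts); the node layout and method names are illustrative, not sglang's actual implementation:

```python
class RadixNode:
    def __init__(self):
        self.children = {}       # first token of the edge label -> (edge_tokens, child_node)
        self.kv_block_ids = []   # handles to paged GPU KV-cache blocks for this edge
        self.last_access = 0.0   # timestamp used by the LRU eviction policy
        self.ref_count = 0       # > 0 while an in-flight request pins this prefix

def match_prefix(root, tokens):
    """Length of the longest cached prefix of `tokens` (token lists; no node splitting here)."""
    node, matched = root, 0
    while matched < len(tokens) and tokens[matched] in node.children:
        edge, child = node.children[tokens[matched]]
        common = 0
        while (common < len(edge) and matched + common < len(tokens)
               and edge[common] == tokens[matched + common]):
            common += 1
        matched += common
        if common < len(edge):   # partial edge match: a real implementation would split the node
            break
        node = child
    return matched
```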

cache-aware scheduling

We define the hit rate as

$$\begin{aligned} \text{hit rate} &= \frac{\sum_{r \in R} \text{number of cached prefill tokens in } r}{\sum_{r \in R} \text{number of prefill tokens in } r} \\[8pt] &= 1 - \frac{C}{\sum_{r \in R} \text{number of prefill tokens in } r} \end{aligned}$$

where $C$ is the number of prefill tokens that actually have to be computed (i.e., the cache misses).
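A direct transcription of this formula, assuming we have the per-request (cached, total) prefill-token counts:

```python
def hit_rate(requests):
    """requests: iterable of (cached_prefill_tokens, total_prefill_tokens) per request."""
    cached = sum(c for c, _ in requests)
    total = sum(t for _, t in requests)
    return cached / total    # equivalently 1 - C / total, with C = total - cached
```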

in batch settings: sort requests by matched prefix length and prioritise those with longer matched prefixes, instead of a FIFO schedule.

"\\begin{algorithm}\n\\caption{Cache-Aware Scheduling}\n\\begin{algorithmic}\n\\State \\textbf{Input:} Radix tree $T$, Memory pool $P$.\n\\State \\textbf{Input:} current running batch $B$, waiting queue $Q$.\n\\State \\textbf{Output:} Finished requests and updated system state.\n\\State // Get all requests from the waiting queue\n\\State requests $\\gets Q.\\text{get\\_all\\_requests}()$\n\\State // Search for prefix matching for all waiting request\n\\For{req $\\in$ requests}\n \\State req.prefix\\_node, req.prefix\\_len $\\gets$ T.match\\_prefix(req.input\\_tokens)\n\\EndFor\n\\State // Sort the request according to matched prefix lengths\n\\State requests.sort()\n\\State // Select requests for the next batch\n\\State available\\_size $\\gets$ T.evictable\\_size() + P.available\\_size()\n\\State current\\_size $\\gets$ 0\n\\State new\\_batch $\\gets$ []\n\\For{req $\\in$ requests}\n \\If{req.size() + current\\_size $\\le$ available\\_size}\n \\State new\\_batch.append(req)\n \\State $\\delta \\gets T.\\text{increase\\_ref\\_counter}(req.\\text{prefix\\_node})$\n \\State available\\_size $\\gets$ available\\_size + $\\delta$\n \\EndIf\n\\EndFor\n\\State Q.remove\\_requests(new\\_batch)\n\\State // Insert requests into the current running batch\n\\State B.merge(new\\_batch)\n\\State // Allocate new memory and do eviction if necessary\n\\State needed\\_size $\\gets$ B.needed\\_size()\n\\State success, buffer $\\gets$ P.alloc(needed\\_size)\n\\If{$\\neg \\text{success}$}\n \\State T.evict(needed\\_size)\n \\State success, buffer $\\gets$ P.alloc(needed\\_size)\n\\EndIf\n\\State B.run(buffer)\n\\State // Process finished requests\n\\State finished\\_requests $\\gets$ B.drop\\_finished\\_requests()\n\\For{req $\\in$ finished\\_requests}\n \\State T.decrease\\_ref\\_counter(req.prefix\\_node)\n \\State T.insert(req)\n\\EndFor\n\\State \\Return finished\\_requests\n\\end{algorithmic}\n\\end{algorithm}"

Algorithm 1 Cache-Aware Scheduling

Input: Radix tree TT, Memory pool PP.

Input: current running batch BB, waiting queue QQ.

Output: Finished requests and updated system state.

// Get all requests from the waiting queue

requests Q.get_all_requests()\gets Q.\text{get\_all\_requests}()

// Search for prefix matching for all waiting request

for req \in requests do

req.prefix_node, req.prefix_len \gets T.match_prefix(req.input_tokens)

end for

// Sort the request according to matched prefix lengths

requests.sort()

// Select requests for the next batch

available_size \gets T.evictable_size() + P.available_size()

current_size \gets 0

new_batch \gets []

for req \in requests do

if req.size() + current_size \le available_size then

new_batch.append(req)

δT.increase_ref_counter(req.prefix_node)\delta \gets T.\text{increase\_ref\_counter}(req.\text{prefix\_node})

available_size \gets available_size + δ\delta

end if

end for

Q.remove_requests(new_batch)

// Insert requests into the current running batch

B.merge(new_batch)

// Allocate new memory and do eviction if necessary

needed_size \gets B.needed_size()

success, buffer \gets P.alloc(needed_size)

if ¬success\neg \text{success} then

T.evict(needed_size)

success, buffer \gets P.alloc(needed_size)

end if

B.run(buffer)

// Process finished requests

finished_requests \gets B.drop_finished_requests()

for req \in finished_requests do

T.decrease_ref_counter(req.prefix_node)

T.insert(req)

end for

return finished_requests
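A minimal Python sketch of one scheduling step, assuming hypothetical `RadixTree`, `MemoryPool`, `Batch`, and queue objects exposing the methods named in Algorithm 1:

```python
def schedule_step(T, P, B, Q):
    """One cache-aware scheduling iteration over hypothetical T/P/B/Q interfaces."""
    requests = Q.get_all_requests()
    # match each waiting request against the radix tree
    for req in requests:
        req.prefix_node, req.prefix_len = T.match_prefix(req.input_tokens)
    # longest matched prefix first, instead of FIFO
    requests.sort(key=lambda r: r.prefix_len, reverse=True)

    available_size = T.evictable_size() + P.available_size()
    current_size, new_batch = 0, []
    for req in requests:
        if req.size() + current_size <= available_size:
            new_batch.append(req)
            # pin the matched prefix so it cannot be evicted
            delta = T.increase_ref_counter(req.prefix_node)
            available_size += delta
    Q.remove_requests(new_batch)
    B.merge(new_batch)

    needed_size = B.needed_size()
    success, buffer = P.alloc(needed_size)
    if not success:
        T.evict(needed_size)            # LRU eviction of unpinned nodes
        success, buffer = P.alloc(needed_size)
    B.run(buffer)

    finished = B.drop_finished_requests()
    for req in finished:
        T.decrease_ref_counter(req.prefix_node)
        T.insert(req)                   # cache the newly generated KV prefix
    return finished
```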

This gives a lower bound on the computation $C$:

$$C \ge \sum_{e \in \text{edges}(T)} \lvert e \rvert$$

Consider visiting the radix tree $T$ in DFS order. For each edge $e$ of $T$, the first time we compute the KV cache associated with $e$, we then go on to compute the whole subtree of $e$.

During the computation of $e$'s subtree, the edge $e$ is continuously hit, so no additional computation of $e$ happens.

cache hit

With cache size $\ge$ the maximum request length (which equals the longest path in the radix tree), edge $e$ WILL NOT be evicted during the computation of its subtree, since the common prefix (including $e$) of the subtree is continuously hit.

We can show that longest-shared-prefix-first order is equivalent to DFS order by induction 1

compressed FSM for jump-forward decoding.

Implemented in (Zheng et al., 2024)

Method 1: FSM-based decoding

  • intuition: use an FSM (Willard & Louf, 2023) to guide generation by applying a logit bias so that only tokens conforming to the given JSON schema can be sampled. This lets us track the current state during decoding and filter out invalid tokens (see the sketch after this list).

    figure 3: Decoding with FSM
  • limitation: since the FSM is constructed at the token level, it can only transition the state one token at a time, resulting in slow decoding.
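A minimal sketch of FSM-guided logit masking (the sketch referenced above), assuming a hypothetical `fsm` object with `allowed_tokens(state)` and `next_state(state, token)`; this is not the actual outlines/SGLang API:

```python
import torch

def constrained_decode_step(logits, fsm, state):
    """One decoding step: mask logits so only FSM-legal tokens can be chosen."""
    mask = torch.full_like(logits, float("-inf"))
    mask[fsm.allowed_tokens(state)] = 0.0        # token ids legal in the current FSM state
    token = int(torch.argmax(logits + mask))     # greedy pick among legal tokens
    return token, fsm.next_state(state, token)   # note: the state advances one token at a time
```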

Method 2: Interleaved-based

  • intuition: break the JSON schema down into parts, each containing either a chunked-prefill segment or a constrained-decoding segment. These are then executed interleaved by the inference system. This is faster than per-token decoding, since the chunked-prefill segments can process multiple tokens per forward pass.

    See also guidance-ai/guidance using llama.cpp as backend.

  • limitation:

    • interleaved-based methods require custom syntax, making them less expressive than regex.
    • they struggle with tokenization boundaries due to conflicts between decoding and chunked-prefill segments.
    • frequent communication between the interpreter and the backend adds overhead.

Method 3: Jump-Forward Decoding with compressed FSM

figure 4: Jump-forward decoding via compressed FSM

tokenization boundary handling

During decoding, the tokenizer prefers to combine multiple characters into a single token.

For example, when decoding "Hello" in the context of JSON decoding, the LLM might output the following tokens: ", He, llo, ",

This can cause strange behaviour if we combine the last " with , (under the regex "[\w\d\s]*", the trailing , makes the token ", invalid, leading to endless decoding even though the LM wants to stop.)

Fix:

  • implement a re-tokenization mechanism during the jump-forward phase (append the string instead of the tokens, then re-tokenize the entire text) $\to$ adds approximately 4% overhead
  • use one comprehensive regex to guide the decoding phase, instead of employing multiple concatenated regexes 2
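A rough sketch of jump-forward decoding with the re-tokenization fix, assuming a Hugging Face-style `model`/`tokenizer` and a hypothetical compressed-FSM object (`initial_state`, `is_final`, `jump_forward`, `advance_string`, `allowed_tokens`, `next_state` are illustrative names, not the sglang API):

```python
import torch

def jump_forward_decode(model, tokenizer, fsm, prompt):
    """Deterministic FSM segments are appended as strings; the full text is re-tokenized."""
    text, state = prompt, fsm.initial_state
    while not fsm.is_final(state):
        jump = fsm.jump_forward(state)        # deterministic continuation, e.g. '{"name": "'
        if jump:                              # jump-forward: append the *string*, not tokens,
            text += jump                      # so re-tokenization can merge boundary tokens
            state = fsm.advance_string(state, jump)
            continue
        # constrained decoding of a single token
        ids = tokenizer(text, return_tensors="pt").input_ids   # re-tokenize the entire text
        logits = model(ids).logits[0, -1]
        mask = torch.full_like(logits, float("-inf"))
        mask[fsm.allowed_tokens(state)] = 0.0                  # only schema-legal tokens
        token = int(torch.argmax(logits + mask))
        state = fsm.next_state(state, token)
        text += tokenizer.decode([token])
    return text
```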

Multi-head Latent Attention (MLA)

low-rank joint compression of attention keys and values to reduce the KV cache during inference (DeepSeek-AI et al., 2025, see §2.1.1)

  • $d$ denotes the embedding dimension
  • $n_h$ denotes the number of attention heads
  • $d_h$ denotes the dimension per head
  • $h_t \in \mathbb{R}^d$ denotes the attention input for the $t$-th token at a given attention layer
$$\begin{align} \boxed{\textcolor{blue}{\mathbf{c}_t^{KV}}} &= W^{DKV} \mathbf{h}_t, \tag{1} \\ [\mathbf{k}_{t,1}^{C}; \mathbf{k}_{t,2}^{C}; \dots; \mathbf{k}_{t, n_h}^{C}] &= \mathbf{k}_t^C = W^{UK} \mathbf{c}_t^{KV}, \tag{2} \\ \boxed{\textcolor{blue}{\mathbf{k}_t^{R}}} &= \mathrm{RoPE}(W^{KR} \mathbf{h}_t), \tag{3} \\ \mathbf{k}_{t,i} &= [\mathbf{k}_{t,i}^{C}; \mathbf{k}_t^{R}], \tag{4} \\ [\mathbf{v}_{t,1}^{C}; \mathbf{v}_{t,2}^{C}; \dots; \mathbf{v}_{t,n_h}^{C}] &= \mathbf{v}_t^{C} = W^{UV} \mathbf{c}_t^{KV}. \tag{5} \end{align}$$
  • where $\mathbf{c}_{t}^{KV} \in \mathbb{R}^{d_{c}}$ is the compressed latent for keys and values
  • $d_c \ll d_h n_h$ indicates the KV compression dimension
  • $W^{DKV} \in \mathbb{R}^{d_c \times d}$ denotes the down-projection matrix
  • $W^{UK}, W^{UV} \in \mathbb{R}^{d_h n_h \times d_c}$ are the up-projection matrices for keys and values, respectively
  • $W^{KR} \in \mathbb{R}^{d^R_h \times d}$ is the matrix used to produce the decoupled key that carries RoPE
  • $\mathrm{RoPE}(\cdot)$ denotes the RoPE operation, and $[\cdot\,;\cdot]$ denotes concatenation

cached generations

Both $\textcolor{blue}{\mathbf{c}_t^{KV}}$ and $\textcolor{blue}{\mathbf{k}_t^{R}}$ are cached, which reduces the KV cache while maintaining performance comparable to MHA.

For attention queries, we can perform the same operation:

$$\begin{align} \mathbf{c}_t^{Q} &= W^{DQ} \mathbf{h}_t, \tag{6} \\ [\mathbf{q}_{t,1}^{C}; \mathbf{q}_{t,2}^{C}; \dots; \mathbf{q}_{t, n_h}^{C}] &= \mathbf{q}_t^C = W^{UQ} \mathbf{c}_t^{Q}, \tag{7} \\ [\mathbf{q}_{t,1}^{R}; \mathbf{q}_{t,2}^{R}; \dots; \mathbf{q}_{t, n_h}^{R}] &= \mathbf{q}_t^R = \mathrm{RoPE}(W^{QR} \mathbf{c}_t^Q), \tag{8} \\ \mathbf{q}_{t,i} &= [\mathbf{q}_{t,i}^{C}; \mathbf{q}_{t,i}^{R}], \tag{9} \end{align}$$
  • $\mathbf{c}_t^Q \in \mathbb{R}^{d_c'}$ is the compressed latent for queries
  • $d_c' \ll d_h n_h$ indicates the query compression dimension
  • $W^{DQ} \in \mathbb{R}^{d_c' \times d}$ and $W^{UQ} \in \mathbb{R}^{d_h n_h \times d_c'}$ are the down- and up-projection matrices, respectively
  • $W^{QR} \in \mathbb{R}^{d_{h}^R n_{h} \times d_{c}'}$ is the matrix that produces the decoupled queries that carry RoPE

Attention output

The attention output $\mathbf{u}_{t}$ can be calculated with the following:

$$\begin{align} \mathbf{o}_{t,i} &= \sum_{j=1}^{t} \mathrm{Softmax}_j \left(\frac{\mathbf{q}_{t,i}^T \mathbf{k}_{j,i}}{\sqrt{d_h + d_h^R}}\right) \mathbf{v}_{j,i}^C, \tag{10} \\ \mathbf{u}_t &= W^O [\mathbf{o}_{t,1}; \mathbf{o}_{t,2}; \dots; \mathbf{o}_{t, n_h}] \tag{11} \end{align}$$
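A shape-level sketch of the low-rank KV compression in equations (1)–(5), with random projection matrices and illustrative dimensions; only $\mathbf{c}_t^{KV}$ and $\mathbf{k}_t^R$ would be cached:

```python
import numpy as np

d, n_h, d_h, d_c, d_hR = 1024, 8, 128, 64, 32   # illustrative sizes, not DeepSeek's
W_DKV = np.random.randn(d_c, d)           # down-projection (eq. 1)
W_UK  = np.random.randn(n_h * d_h, d_c)   # up-projection to keys (eq. 2)
W_KR  = np.random.randn(d_hR, d)          # decoupled RoPE key projection (eq. 3)
W_UV  = np.random.randn(n_h * d_h, d_c)   # up-projection to values (eq. 5)

def rope(x):
    return x                              # placeholder for the actual rotary embedding

h_t  = np.random.randn(d)
c_kv = W_DKV @ h_t                        # cached: d_c numbers per token
k_R  = rope(W_KR @ h_t)                   # cached: d_h^R numbers, shared across heads
k_C  = (W_UK @ c_kv).reshape(n_h, d_h)    # per-head keys, reconstructed on the fly
v_C  = (W_UV @ c_kv).reshape(n_h, d_h)    # per-head values
k    = [np.concatenate([k_C[i], k_R]) for i in range(n_h)]   # eq. 4
# cache footprint per token: d_c + d_h^R instead of 2 * n_h * d_h
```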

RingAttention

RazorAttention

Paged Attention

(Kwon et al., 2023)

In conjunction with continuous batching, implemented in vLLM

Reduces memory usage of the attention mechanism by swapping the KV cache in and out of memory. A block manager, similar to the virtual-memory manager in an OS, handles this.

Essentially, it's a form of paging: the KV cache of each sequence is partitioned into KV blocks, so the cache no longer needs to be stored in contiguous memory.

Another optimization is to use KV compression to reduce the size of the KV cache for longer context.

Given:

  • each block contains the KV vectors for a fixed number of tokens, denoted as the block size $B$
  • Key block $K_j = (k_{(j-1)B+1}, \ldots, k_{jB})$
  • Value block $V_j = (v_{(j-1)B+1}, \ldots, v_{jB})$
$$A_{ij} = \frac{\exp(q_i^T K_j / \sqrt{d})}{\sum_{t=1}^{i//B} \exp(q_i^T K_t / \sqrt{d})}, \quad o_i = \sum_{j=1}^{i//B} V_j A_{ij}^T$$

where $A_{ij} = (a_{i,(j-1)B+1}, \ldots, a_{i,jB})$ is the row vector of attention scores over the $j$-th KV block.
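A sketch of this block-wise formulation for a single query position (no block table or paging machinery, just the per-block computation):

```python
import numpy as np

def paged_attention(q_i, key_blocks, value_blocks, d):
    """q_i: (d,); key_blocks, value_blocks: lists of (B, d) blocks covering tokens 1..i."""
    # logits over all cached tokens, computed block by block
    scores = np.concatenate([kb @ q_i / np.sqrt(d) for kb in key_blocks])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                          # global softmax over all blocks
    out, offset = np.zeros(d), 0
    for vb in value_blocks:
        a_ij = weights[offset:offset + vb.shape[0]]   # A_{ij}: scores for the j-th block
        out += a_ij @ vb
        offset += vb.shape[0]
    return out
```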

Remarks

  1. base case: an arbitrary request corresponding to a node $x \in T$ is processed.

    • All requests corresponding to the nodes $\{v_{1}, \ldots, v_{n}\}$ on the path from the root to $x$ need no recomputation.
    • Thus, the computation for the requests of nodes $\{v_{1}, \ldots, v_{n}, x\}$ is aligned with a DFS order.

    induction: assume we visit node $y \in T$, and the nodes visited so far align with a DFS order. Let $P$ denote the path from the root to $y$.

    • Each node that has not yet been visited has its lowest common ancestor with the visited nodes on $P$.
    • Since the nodes on $P$ are cached, a yet-unvisited node $z$ whose lowest common ancestor lies on $P$ has the longest shared prefix.
    • The longest-shared-prefix-first order will therefore select $z$, which is a valid DFS order. q.e.d.
  2. this phenomenon is also known as coalescence in structured generation, where deterministic structure in the desired output is exploited to skip expensive forward passes