
See also: LLMs, embedding, visualisation from Brendan Bycroft

A multi-layer perceptron (MLP) stacked on top of a multi-head attention mechanism (Vaswani et al., 2023), where attention amplifies high-entropy tokens and diminishes less important ones.
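
A minimal sketch of one such block, assuming a pre-norm layout and illustrative sizes (the `TransformerBlock` name, `d_model`, and `n_heads` are not from the source):

    import torch
    import torch.nn as nn

    class TransformerBlock(nn.Module):
        """One pre-norm block: multi-head attention followed by an MLP, each with a residual."""
        def __init__(self, d_model: int = 512, n_heads: int = 8):
            super().__init__()
            self.ln1 = nn.LayerNorm(d_model)
            self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.ln2 = nn.LayerNorm(d_model)
            self.mlp = nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.GELU(),
                nn.Linear(4 * d_model, d_model),
            )

        def forward(self, x, causal_mask=None):
            h = self.ln1(x)
            attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
            x = x + attn_out                 # attention mixes information across tokens
            x = x + self.mlp(self.ln2(x))    # MLP transforms each position independently
            return x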

ELI5: Mom writes a shopping list of $n$ items to buy. Your job is to guess what the last item on the list will be.

Most implementations are autoregressive. Most SOTA models are decoder-only, as encoder-decoder models have lagged behind due to their expensive encoding phase.

state-space models, which address transformers' efficiency issues in attention layers on information-dense data

memory limitations.


Arithmetic intensity is defined as:

$$\text{Arithmetic Intensity} = \frac{\#\,\text{FLOPs}}{\#\,\text{MOPs}}$$
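
As a rough worked example (the GEMM shapes, fp16 byte counts, and accelerator numbers below are illustrative assumptions, not from the source), comparing a kernel's intensity against the hardware's FLOPs-per-byte ridge point tells you which regime it sits in:

    def arithmetic_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
        """FLOPs per byte moved for a dense (m x k) @ (k x n) matmul in fp16."""
        flops = 2 * m * k * n                               # each multiply-accumulate counts as 2 FLOPs
        mops = bytes_per_elem * (m * k + k * n + m * n)     # read A, read B, write C
        return flops / mops

    # Hypothetical accelerator: ~300 TFLOP/s fp16 over ~1.5 TB/s HBM => ridge point ~200 FLOPs/byte.
    RIDGE_POINT = 300e12 / 1.5e12

    print(arithmetic_intensity(1, 4096, 4096))      # ~1 FLOP/byte: decode-style matvec, memory-bound
    print(arithmetic_intensity(4096, 4096, 4096))   # ~1365 FLOPs/byte: prefill-style GEMM, compute-bound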

inference.

Inference is either compute-bound (batched inference, saturated usage) or memory-bound (latency-sensitive decoding).

speculative decoding targets the memory-bound regime: it spends spare FLOPs to reduce the number of sequential, memory-bound decode steps.
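
A minimal greedy sketch of that idea (the `draft_model`/`target_model` callables and the greedy acceptance rule are simplifying assumptions; real schemes use a rejection-sampling step so the target distribution is preserved exactly):

    import torch

    @torch.no_grad()
    def speculative_step(target_model, draft_model, tokens: list[int], k: int = 4) -> list[int]:
        """Draft k tokens with a small model, then verify them in one batched target pass."""
        # 1. Draft phase: k cheap sequential steps with the small model.
        draft = list(tokens)
        for _ in range(k):
            logits = draft_model(torch.tensor([draft]))[0, -1]   # assumes 1 x T -> 1 x T x V logits
            draft.append(int(logits.argmax()))

        # 2. Verify phase: one target forward scores all drafted positions in parallel,
        #    replacing k memory-bound decode steps with a single compute-heavy pass.
        logits = target_model(torch.tensor([draft]))[0]
        accepted = list(tokens)
        for pos in range(len(tokens), len(draft)):
            target_tok = int(logits[pos - 1].argmax())
            accepted.append(target_tok)        # the target's choice is always kept
            if draft[pos] != target_tok:
                break                          # first disagreement: discard the rest of the draft
        return accepted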

KV

The core "retrieval" bag that stores all previously computed key-value pairs along with any newly added ones.
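
A minimal per-layer sketch of such a cache (the append-based update and tensor shapes are illustrative; production systems preallocate or page these buffers):

    import torch

    class KVCache:
        """Stores keys/values for every generated position of one attention layer."""
        def __init__(self):
            self.k = None   # [batch, heads, seq, head_dim]
            self.v = None

        def update(self, k_new, v_new):
            """Append this step's keys/values and return the full tensors to attend over."""
            if self.k is None:
                self.k, self.v = k_new, v_new
            else:
                self.k = torch.cat([self.k, k_new], dim=2)   # grow along the sequence axis
                self.v = torch.cat([self.v, v_new], dim=2)
            return self.k, self.v

    # Each decode step only computes K/V for the new token, then attends over the whole cache:
    cache = KVCache()
    for _ in range(3):
        k_all, v_all = cache.update(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64))
    print(k_all.shape)   # torch.Size([1, 8, 3, 64])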

Prefill disaggregation is interesting in the sense that the prefill stage can be separated out onto dedicated nodes (Qin et al., 2024).

figure 1: KV-centric optimization

next-token prediction.

Sampling: we essentially look ahead $K$ tokens, then sample from the distribution over the next token.
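
A minimal autoregressive sampling loop for reference (temperature, top-$k$ truncation, the `eos_id`, and the assumption that `model` maps a `1 x T` token tensor to `1 x T x V` logits are all illustrative):

    import torch

    @torch.no_grad()
    def generate(model, tokens: list[int], max_new: int = 32,
                 temperature: float = 0.8, top_k: int = 50, eos_id: int = 2) -> list[int]:
        """Sample one token at a time from the model's next-token distribution."""
        out = list(tokens)
        for _ in range(max_new):
            logits = model(torch.tensor([out]))[0, -1] / temperature   # next-token logits
            topk = torch.topk(logits, top_k)                           # keep the k most likely tokens
            probs = torch.softmax(topk.values, dim=-1)
            next_id = int(topk.indices[torch.multinomial(probs, 1)])   # sample within the top-k
            out.append(next_id)
            if next_id == eos_id:
                break
        return out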

multi-token prediction.

(Gloeckle et al., 2024)

figure 2: MTP implementation in DeepSeek, which keeps the complete causal chain for the prediction of each token at each depth

tl;dr: predict $n$ tokens at once, via a shared trunk and $n$ dedicated output heads 1

Note that during inference, we only employ the next-token head.
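
A minimal sketch of this layout (the `trunk`/`heads` names, plain linear heads, and the omission of causal masking are simplifications, not the paper's exact architecture):

    import torch
    import torch.nn as nn

    class MultiTokenPredictor(nn.Module):
        """Shared trunk with n output heads; head i predicts the token at offset i+1."""
        def __init__(self, vocab: int, d_model: int = 512, n: int = 4):
            super().__init__()
            self.embed = nn.Embedding(vocab, d_model)
            self.trunk = nn.TransformerEncoder(     # causal masking omitted for brevity
                nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True), num_layers=4
            )
            self.heads = nn.ModuleList([nn.Linear(d_model, vocab) for _ in range(n)])

        def forward(self, x):
            z = self.trunk(self.embed(x))                # shared representation, computed once
            return [head(z) for head in self.heads]      # one logit tensor per future offset

    model = MultiTokenPredictor(vocab=32000)
    logits_per_offset = model(torch.randint(0, 32000, (2, 16)))
    next_token_logits = logits_per_offset[0][:, -1]      # inference uses only the next-token head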

Byte-Latent Transformer

idea: learn from raw bytes and skip the tokenizer/detokenizer step entirely.
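
A minimal illustration of the input side only (the latent patching that gives BLT its name is not shown): the bytes themselves serve as ids, so the vocabulary is fixed at 256 and detokenization is just decoding bytes.

    # Raw UTF-8 bytes as input ids: no tokenizer, no detokenizer.
    text = "transformers"
    byte_ids = list(text.encode("utf-8"))
    print(byte_ids)                          # [116, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115]
    print(bytes(byte_ids).decode("utf-8"))   # "transformers"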

Feynman-Kac

Let $\mathcal{V}$ be the vocabulary of a given transformer model, and $\mathcal{S} = \mathcal{V}^{*}$ the set of multi-token strings. Assume $\mathcal{V}$ contains the token EOS, and write $\mathcal{F} \subseteq \mathcal{S}$ for the set of EOS-terminated strings.

Feynman-Kac Transformer model

is a tuple $(s_{0}, \{M_t\}_{t\ge 1}, \{G_t\}_{t\ge 1})$ where:

  • $s_{0} \in \mathcal{S}$ is an initial state, which we take to be the empty string $\epsilon$
  • $M_t(s_t \mid s_{t-1}, f_\theta)$ is a Markov kernel from $s_{t-1} \in \mathcal{F}^c$ to $s_t \in \mathcal{S}$, parameterised by a transformer network $f_\theta: \mathcal{F}^c \to \mathbb{R}^{|\mathcal{V}|}$ mapping non-EOS-terminated strings to vectors of logits
  • $G_t(s_{t-1}, s_t, f_\theta)$ is a potential function, mapping a pair $(s_{t-1}, s_t) \in \mathcal{F}^c \times \mathcal{S}$ to a real-valued non-negative score (a minimal code sketch of this tuple follows the list)
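
A minimal sketch of this tuple as code (the `FeynmanKacTransformer` class, the `f_theta` callable returning logits, and the constraint-style potential are illustrative assumptions, not an implementation from the paper):

    import torch
    from dataclasses import dataclass
    from typing import Callable

    EOS = "<eos>"

    @dataclass
    class FeynmanKacTransformer:
        s0: str                                     # initial state, the empty string
        f_theta: Callable[[str], torch.Tensor]      # non-EOS-terminated string -> logits over V
        vocab: list[str]
        potential: Callable[[str, str], float]      # G_t(s_{t-1}, s_t) >= 0

        def M_t(self, s_prev: str) -> str:
            """Markov kernel: extend s_{t-1} by one token sampled from the transformer."""
            probs = torch.softmax(self.f_theta(s_prev), dim=-1)
            return s_prev + self.vocab[int(torch.multinomial(probs, 1))]

        def G_t(self, s_prev: str, s_new: str) -> float:
            """Potential: reweights the chain, e.g. 0.0 when a constraint is violated."""
            return self.potential(s_prev, s_new)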

Goal: generate from the distribution $\mathbb{P}$ that reweights the Markov chain $\mathbb{M}$ by the potential functions $G_t$. We define the step-$t$ filtering posteriors:

$$\mathbb{P}_t(s_t) = \frac{\mathbb{E}_\mathbb{M} \left[ \prod_{i=1}^{t \wedge T} G_i(S_{i-1}, S_i, f_\theta) \cdot [S_t = s_t] \right]}{\mathbb{E}_\mathbb{M} \left[ \prod_{i=1}^{t \wedge T} G_i(S_{i-1}, S_i, f_\theta) \right]}$$

Given that $T$ is almost surely finite, we can define the overall posterior $\mathbb{P}(s) = \lim_{t \to \infty} \mathbb{P}_t(s)$ (Lew et al., 2023; see §2.2 for examples).

"\\begin{algorithm}\n\\caption{Sequential Monte Carlo Transformer Steering}\n\\begin{algorithmic}\n\\State \\textbf{Input:} $N$ (\\# particles), $K$ (factor), Feynman-Kac Transformer model $\\{s_0, \\{M_t\\}_{t \\geq 1}, \\{G_t\\}_{t \\geq 1}\\}$\n\\State \\textbf{Output:} Weighted particle approximation $\\{(x_i, w_i)\\}_{i=1,\\ldots,N}$ of the posterior $\\mathbb{P}$ \\\\\n\\State \\textbf{Output:} Unbiased estimate $\\hat{Z}$ of the partition function $Z = \\mathbb{E}_\\mathbb{M}[\\prod_{t=1}^T G_t(s_t, s_{t-1}, f_\\theta)]$ \\\\\n\\State Initialize $f_\\theta \\gets \\texttt{CachedTransformer}()$\n\\State Initialize $(x_i, w_i) \\gets (s_0, 1)$ for $i = 1, \\ldots, N$\n\\State Initialize $t \\gets 1$\n\\While{$x_i \\not\\in \\mathcal{F}$ for some $i \\in \\{1, \\ldots, N\\}$}\n \\State $K_i \\gets K (1 - \\mathbb{1}_{\\mathcal{F}}(x_i)) + \\mathbb{1}_{\\mathcal{F}}(x_i)$ for $i = 1, \\ldots, N$\n \\State $N' \\gets \\sum_{i=1}^N K_i$\n \\For{$i \\in \\{1, \\ldots, N\\}$}\n \\If{$x_i \\in \\mathcal{F}$}\n \\State Set $(x_{i,1}, w_{i,1}) \\gets (x_i, w_i \\cdot \\frac{N'}{N})$\n \\Else\n \\State Generate $x_{i,k} \\sim M_t(\\cdot \\mid x_i, f_\\theta)$ for $k = 1, \\ldots, K$\n \\State Set $w_{i,k} \\gets w_i \\cdot G_t(x_i, x_{i,k}, f_\\theta) \\cdot \\frac{N'}{K N}$ for $k = 1, \\ldots, K$\n \\EndIf\n \\EndFor\n \\State Set normalized weights $\\hat{w}_{i,k} \\gets \\frac{w_{(i,k)}}{\\sum_{j=1}^N \\sum_{l=1}^{K_j} w_{(j,l)}}$ for $i = 1, \\ldots, N$ and $k = 1, \\ldots, K_i$\n \\State Set $c^* \\gets \\inf\\{c \\in \\mathbb{R}_{> 0} \\mid \\sum_{i=1}^N \\sum_{k=1}^{K_i} (\\mathbb{1} \\wedge c \\hat{w}_{(i,k)}) > N\\}$\n \\State Set $(I_\\text{det}, I_\\text{stoch}, I_\\text{strat}) \\gets (\\{(i,k) \\mid c^{*} \\hat{w}_{i,k} \\geq 1\\}, \\{(i,k) \\mid c^{*} \\cdot \\hat{w}_{i,k} < 1\\}, \\{\\})$\n \\State Set $\\alpha \\gets \\frac{\\sum_{i \\in I_\\text{stoch}} \\hat{w}_i}{|I_\\text{det}|}$ and generate $U \\sim \\text{Uniform}([0, \\alpha])$\n \\For{$i \\in I_\\text{stoch}$}\n \\State Set $U \\gets U - \\hat{w}_i$\n \\If{$U < 0$}\n \\State Set $I_\\text{strat} \\gets I_\\text{strat} \\cup \\{i\\}$\n \\State Set $U \\gets U + \\alpha$\n \\EndIf\n \\EndFor\n \\State Set particles $\\{(x_i, w_i)\\}_{i=1,\\ldots,|I_\\text{det}|} \\gets \\{(x_j, w_j \\cdot \\frac{N}{N'}) \\mid j \\in I_\\text{det}\\}$\n \\State Set particles $\\{(x_i, w_i)\\}_{i=|I_\\text{det}|+1,\\ldots,N} \\gets \\{(x_j, \\frac{N}{c^* N'} \\sum_{l=1}^{N} \\sum_{k=1}^{K_l} w_{(j,k)}) \\mid j \\in I_\\text{strat}\\}$\n\\EndWhile\n\\State \\Return $\\left((x_i, w_i)_{i=1,\\ldots,N}, \\hat{Z} = \\frac{1}{N} \\sum_{i=1}^N w_i \\right)$\n\\end{algorithmic}\n\\end{algorithm}"

Algorithm 4 Sequential Monte Carlo Transformer Steering

Input: $N$ (# particles), $K$ (factor), Feynman-Kac Transformer model $(s_0, \{M_t\}_{t \geq 1}, \{G_t\}_{t \geq 1})$

Output: Weighted particle approximation $\{(x_i, w_i)\}_{i=1,\ldots,N}$ of the posterior $\mathbb{P}$

Output: Unbiased estimate $\hat{Z}$ of the partition function $Z = \mathbb{E}_\mathbb{M}\left[\prod_{t=1}^T G_t(s_t, s_{t-1}, f_\theta)\right]$

Initialize $f_\theta \gets \texttt{CachedTransformer}()$

Initialize $(x_i, w_i) \gets (s_0, 1)$ for $i = 1, \ldots, N$

Initialize $t \gets 1$

while $x_i \notin \mathcal{F}$ for some $i \in \{1, \ldots, N\}$ do

$K_i \gets K(1 - \mathbb{1}_{\mathcal{F}}(x_i)) + \mathbb{1}_{\mathcal{F}}(x_i)$ for $i = 1, \ldots, N$

$N' \gets \sum_{i=1}^N K_i$

for $i \in \{1, \ldots, N\}$ do

if $x_i \in \mathcal{F}$ then

Set $(x_{i,1}, w_{i,1}) \gets (x_i, w_i \cdot \frac{N'}{N})$

else

Generate $x_{i,k} \sim M_t(\cdot \mid x_i, f_\theta)$ for $k = 1, \ldots, K$

Set $w_{i,k} \gets w_i \cdot G_t(x_i, x_{i,k}, f_\theta) \cdot \frac{N'}{KN}$ for $k = 1, \ldots, K$

end if

end for

Set normalized weights $\hat{w}_{i,k} \gets \frac{w_{(i,k)}}{\sum_{j=1}^N \sum_{l=1}^{K_j} w_{(j,l)}}$ for $i = 1, \ldots, N$ and $k = 1, \ldots, K_i$

Set $c^* \gets \inf\{c \in \mathbb{R}_{> 0} \mid \sum_{i=1}^N \sum_{k=1}^{K_i} (\mathbb{1} \wedge c\,\hat{w}_{(i,k)}) > N\}$

Set $(I_\text{det}, I_\text{stoch}, I_\text{strat}) \gets (\{(i,k) \mid c^{*} \hat{w}_{i,k} \geq 1\}, \{(i,k) \mid c^{*} \hat{w}_{i,k} < 1\}, \{\})$

Set $\alpha \gets \frac{\sum_{i \in I_\text{stoch}} \hat{w}_i}{|I_\text{det}|}$ and generate $U \sim \text{Uniform}([0, \alpha])$

for $i \in I_\text{stoch}$ do

Set $U \gets U - \hat{w}_i$

if $U < 0$ then

Set $I_\text{strat} \gets I_\text{strat} \cup \{i\}$

Set $U \gets U + \alpha$

end if

end for

Set particles $\{(x_i, w_i)\}_{i=1,\ldots,|I_\text{det}|} \gets \{(x_j, w_j \cdot \frac{N}{N'}) \mid j \in I_\text{det}\}$

Set particles $\{(x_i, w_i)\}_{i=|I_\text{det}|+1,\ldots,N} \gets \{(x_j, \frac{N}{c^* N'} \sum_{l=1}^{N} \sum_{k=1}^{K_l} w_{(j,k)}) \mid j \in I_\text{strat}\}$

end while

return $\left((x_i, w_i)_{i=1,\ldots,N},\ \hat{Z} = \frac{1}{N} \sum_{i=1}^N w_i \right)$
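
A simplified Python sketch of the loop above (it substitutes plain multinomial resampling for the stratified scheme, fixes $K = 1$, and reuses the hypothetical `FeynmanKacTransformer` sketch from earlier), mainly to show how the weights and the partition-function estimate evolve:

    import torch

    def smc_steer(fk, n_particles: int = 8, max_steps: int = 64):
        """Simplified SMC steering: propose with M_t, weight with G_t, resample, estimate Z."""
        particles = [fk.s0] * n_particles
        weights = torch.ones(n_particles)
        log_z = torch.tensor(0.0)
        for _ in range(max_steps):
            if all(p.endswith(EOS) for p in particles):
                break
            for i in range(n_particles):
                if particles[i].endswith(EOS):
                    continue                                         # finished particles are carried along
                new = fk.M_t(particles[i])                           # propose a one-token extension
                weights[i] = weights[i] * fk.G_t(particles[i], new)  # reweight by the potential
                particles[i] = new
            log_z += torch.log(weights.mean())                       # running estimate of Z
            idx = torch.multinomial(weights / weights.sum(), n_particles, replacement=True)
            particles = [particles[int(i)] for i in idx]             # multinomial resampling
            weights = torch.ones(n_particles)                        # weights reset after resampling
        return particles, float(log_z.exp())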

Remark

  1. Gloeckle et al. (2024) employ $n=4$. The order of the forward and backward passes in an $n$-token prediction model with $n=4$ heads on the shared trunk works as follows:

    z = model.shared(x)            # one forward pass through the shared trunk
    d = z.detach()                 # cut the graph so each head's backward stops here
    d.requires_grad = True         # lets the heads' gradients accumulate in d.grad

    for i in range(n):
        p = model.heads[i](d)      # head i predicts the token at offset i+1
        loss(p, y[i]).backward()   # frees head i's graph; adds its gradient to d.grad
    z.backward(d.grad)             # one backward through the trunk with the summed gradient