See also: LLMs, embedding, visualisation from Brendan Bycroft

An architecture that stacks multi-head attention (Vaswani et al., 2023) with multi-layer perceptron (MLP) blocks. Attention signals which tokens matter for each position, so informative tokens are amplified and less important tokens are diminished.
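The weighting that attention performs can be sketched as single-head scaled dot-product attention, $\mathrm{softmax}(QK^\top/\sqrt{d_k})V$, as defined by Vaswani et al. This is a minimal plain-Python sketch assuming tiny hypothetical matrices `Q`, `K`, `V`. Real implementations use batched tensor ops, learned projections, several heads in parallel, and interleave the result with MLP blocks.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]

def softmax(xs: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q: Matrix, K: Matrix, V: Matrix) -> Tuple[Matrix, Matrix]:
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Each argument is a list of row vectors, one row per token.
    Returns the output rows and the attention weights.
    """
    d_k = len(K[0])
    outputs, weights = [], []
    for q in Q:
        # Similarity of this query with every key, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)  # relevant keys get large weights, others small ones
        weights.append(w)
        # Output row = attention-weighted average of the value rows.
        outputs.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return outputs, weights

# Hypothetical toy inputs: one query pointing in the direction of the first key.
Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
out, w = attention(Q, K, V)
print(w)    # the first key receives the larger weight
print(out)
```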

ELI5: Mom often writes a grocery list of $n$ items to buy. Your job is to guess what the last item on the list will be.

Most implementations are autoregressive. Most major state-of-the-art (SOTA) models are decoder-only, as encoder-decoder models have lagged behind due to their expensive encoding phase.
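In the grocery-list picture, autoregressive generation means predicting one item at a time, each conditioned on everything written so far. A minimal sketch of that decoding loop, assuming a hypothetical `toy_next_token` lookup table in place of a trained decoder-only model:

```python
from typing import Callable, List

def generate(prompt: List[str],
             next_token: Callable[[List[str]], str],
             max_new_tokens: int,
             eos: str = "<eos>") -> List[str]:
    """Autoregressive decoding: predict a token from the full prefix, append it, repeat."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tok = next_token(tokens)  # model call conditioned on everything so far
        tokens.append(tok)
        if tok == eos:            # stop once the model emits end-of-sequence
            break
    return tokens

# Hypothetical stand-in for a trained model: looks only at the last token.
TOY_BIGRAM = {"eggs": "milk", "milk": "bread", "bread": "<eos>"}

def toy_next_token(prefix: List[str]) -> str:
    return TOY_BIGRAM.get(prefix[-1], "<eos>")

result = generate(["eggs"], toy_next_token, max_new_tokens=10)
print(result)  # ['eggs', 'milk', 'bread', '<eos>']
```

A real model would condition on the whole prefix through attention rather than on the last token only; the loop structure is the same.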

State-space models address transformers' efficiency issues in the attention layers on information-dense data.

memory limitations.

excerpt from arxiv

inference.

Inference is either compute-bound (batch inference, saturated usage) or memory-bound (latency-sensitive, small-batch decoding).
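Which regime applies can be estimated with the roofline model: compare a matmul's arithmetic intensity (FLOPs per byte moved) with the hardware's ratio of peak FLOP/s to memory bandwidth. A minimal sketch, assuming fp16 weights (2 bytes per element), a single $d \times d$ projection, and hypothetical accelerator numbers; it ignores on-chip caching and KV-cache traffic.

```python
def matmul_intensity(batch: int, d_in: int, d_out: int, bytes_per_elem: int = 2) -> float:
    """Arithmetic intensity (FLOPs per byte moved) of Y = X W,
    with X of shape (batch, d_in) and W of shape (d_in, d_out)."""
    flops = 2 * batch * d_in * d_out  # one multiply and one add per weight per token
    bytes_moved = bytes_per_elem * (d_in * d_out + batch * d_in + batch * d_out)
    return flops / bytes_moved

def regime(batch: int, d_in: int, d_out: int, peak_flops: float, mem_bandwidth: float) -> str:
    ridge = peak_flops / mem_bandwidth  # FLOPs/byte at which the roofline bends
    if matmul_intensity(batch, d_in, d_out) >= ridge:
        return "compute-bound"
    return "memory-bound"

# Hypothetical accelerator: 1e15 FLOP/s peak and 2e12 bytes/s bandwidth (ridge = 500).
PEAK_FLOPS, MEM_BW = 1e15, 2e12
print(regime(1, 4096, 4096, PEAK_FLOPS, MEM_BW))     # memory-bound
print(regime(1024, 4096, 4096, PEAK_FLOPS, MEM_BW))  # compute-bound
```

At batch size 1 each weight loaded from memory is used for a single multiply-add, so the intensity is about 1 FLOP/byte; batching reuses every loaded weight across many tokens, which is why batch inference can saturate compute.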

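Speculative decoding targets the memory-bound regime: a cheap draft model proposes several tokens and the target model verifies them in a single forward pass, putting otherwise idle FLOPs to work.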
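The verification step that keeps speculative decoding exact can be sketched with the acceptance rule used in speculative sampling: accept a draft token $x$ with probability $\min(1, p(x)/q(x))$, otherwise resample from the normalised residual $\max(0, p - q)$. This sketch handles a single position with hypothetical toy distributions `p` (target) and `q` (draft); a real implementation verifies several draft tokens in one batched forward pass of the target model.

```python
import random
from typing import Dict

def accept_or_resample(draft_token: str, p: Dict[str, float], q: Dict[str, float],
                       rng: random.Random) -> str:
    """Verify one draft token.

    p: target-model next-token distribution; q: draft-model distribution the
    token was sampled from. The returned token is distributed according to p.
    """
    # Accept with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token
    # On rejection, sample from the residual max(0, p - q); choices() normalises it.
    tokens = list(p)
    residual = [max(0.0, p[t] - q.get(t, 0.0)) for t in tokens]
    return rng.choices(tokens, weights=residual)[0]

# Hypothetical toy distributions over a three-token vocabulary.
p = {"a": 0.6, "b": 0.3, "c": 0.1}  # target model
q = {"a": 0.2, "b": 0.5, "c": 0.3}  # cheaper draft model
rng = random.Random(0)
n = 20000
counts = {t: 0 for t in p}
for _ in range(n):
    draft = rng.choices(list(q), weights=list(q.values()))[0]  # draft proposal x ~ q
    counts[accept_or_resample(draft, p, q, rng)] += 1
freqs = {t: c / n for t, c in counts.items()}
print(freqs)  # empirically close to p, not q
```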

next-token prediction.

Sampling: we look ahead $K$ tokens and then sample from the distribution over the next token.
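One common way to sample with a cutoff $K$ is top-$k$ sampling: keep the $K$ highest logits, rescale by a temperature, and draw from the softmax over the survivors. Reading the note this way is an assumption; the sketch below uses hypothetical logits and plain Python.

```python
import math
import random
from typing import List

def sample_next_token(logits: List[float], k: int, temperature: float,
                      rng: random.Random) -> int:
    """Sample a token id from the k highest logits after temperature scaling."""
    # Keep only the k most likely candidate ids.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)
    # Unnormalised softmax weights; random.choices normalises them internally.
    probs = [math.exp(s - m) for s in scaled]
    return rng.choices(top, weights=probs)[0]

rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -3.0]  # hypothetical model output for 4 vocabulary ids
draws = [sample_next_token(logits, k=2, temperature=1.0, rng=rng) for _ in range(1000)]
print(set(draws))  # only ids 0 and 1 can appear when k=2
```

With `k=1` this reduces to greedy decoding; larger `k` or higher temperature trades determinism for diversity.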

Byte-Latent Transformer

Idea: learn directly from raw bytes and skip the tokenizer/detokenizer step.
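With raw bytes as input the vocabulary is fixed at 256 symbols and the "tokenizer" reduces to UTF-8 encoding. A minimal sketch of that byte-level input and output only; the Byte-Latent Transformer additionally groups bytes into patches, which is not attempted here.

```python
# U+00E9 is a 2-byte character in UTF-8, U+1F642 (an emoji) a 4-byte one.
text = "caf\u00e9 \U0001F642"
byte_ids = list(text.encode("utf-8"))  # integer ids, each in range(256)
print(len(text), len(byte_ids))        # 6 10: byte sequences are longer than character ones
decoded = bytes(byte_ids).decode("utf-8")  # "detokenization" is plain UTF-8 decoding
assert decoded == text
```

The trade-off is sequence length: the same text becomes more positions for the model to process, which is what patching is meant to offset.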

Feynman-Kac

Let $\mathcal{V}$ be the vocabulary of a given transformer model, and $\mathcal{S} = \mathcal{V}^{*}$ the set of multi-token strings. Assume $\mathcal{V}$ contains the token EOS, and write $\mathcal{F} \subseteq \mathcal{S}$ for the set of EOS-terminated strings.

Feynman-Kac Transformer model

is a tuple $(s_{0}, \{M_t\}_{t\ge 1}, \{G_t\}_{t\ge 1})$ where:

  • $s_{0} \in \mathcal{S}$ is an initial state, which we take to be the empty string $\epsilon$
  • $M_t(s_t \mid s_{t-1}, f_\theta)$ is a Markov kernel from $s_{t-1} \in \mathcal{F}^c$ to $s_t \in \mathcal{S}$, parameterised by a transformer network $f_\theta: \mathcal{F}^c \to \mathbb{R}^{\mid \mathcal{V} \mid}$ mapping non-EOS-terminated strings to vectors of logits
  • $G_t(s_{t-1}, s_t, f_\theta)$ is a potential function, mapping a pair $(s_{t-1}, s_t) \in \mathcal{F}^c \times \mathcal{S}$ to a real-valued non-negative score.
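To make the definition concrete, here is a toy Feynman-Kac model. Everything in it is hypothetical: a three-token vocabulary, a fixed `next_token_probs` standing in for $\mathrm{softmax}(f_\theta(s))$, a length cap `MAX_LEN` that forces EOS, and a hard-constraint potential that zeroes any string containing `"b"`. For simplicity the kernel `M` and potential `G` ignore $t$ (time-homogeneous).

```python
import random
from typing import Dict, Tuple

EOS = "<eos>"
VOCAB = ["a", "b", EOS]
MAX_LEN = 3  # hypothetical cap: after MAX_LEN tokens, EOS is forced

State = Tuple[str, ...]  # a string s in S = V* represented as a tuple of tokens

def is_final(s: State) -> bool:
    """Membership in F: the string is EOS-terminated."""
    return len(s) > 0 and s[-1] == EOS

def next_token_probs(s: State) -> Dict[str, float]:
    """Hypothetical stand-in for softmax(f_theta(s)); a real model runs a transformer."""
    if len(s) >= MAX_LEN:
        return {EOS: 1.0}
    return {"a": 0.5, "b": 0.3, EOS: 0.2}

def M(s_prev: State, rng: random.Random) -> State:
    """Markov kernel M_t(. | s_prev): extend a non-final string by one sampled token."""
    assert not is_final(s_prev)
    probs = next_token_probs(s_prev)
    tok = rng.choices(list(probs), weights=list(probs.values()))[0]
    return s_prev + (tok,)

def G(s_prev: State, s_new: State) -> float:
    """Potential G_t: hard constraint scoring 0 once a 'b' token is emitted."""
    return 0.0 if s_new[-1] == "b" else 1.0

s0: State = ()  # initial state: the empty string epsilon

# One rollout of the chain, tracking the running product of potentials.
rng = random.Random(0)
s, weight = s0, 1.0
while not is_final(s):
    s_next = M(s, rng)
    weight *= G(s, s_next)
    s = s_next
print(s, weight)
```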

Goal: generate from a distribution $\mathbb{P}$ that reweights the Markov chain $\mathbb{M}$ by the potential functions $G_t$. We define the step-$t$ filtering posteriors:

$$\mathbb{P}_t(s_t) = \frac{\mathbb{E}_\mathbb{M} \left[ \prod_{i=1}^{t \wedge T} G_i(S_{i-1}, S_i, f_\theta) \cdot [S_t = s_t] \right]}{\mathbb{E}_\mathbb{M} \left[ \prod_{i=1}^{t \wedge T} G_i(S_{i-1}, S_i, f_\theta) \right]}$$

Given that $T$, the first step at which the chain reaches an EOS-terminated string in $\mathcal{F}$, is almost surely finite, we can then define the overall posterior $\mathbb{P}(s) = \lim_{t \to \infty} \mathbb{P}_t(s)$ (Lew et al., 2023; see §2.2 for examples).
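When the chain has a hard length cap, $T$ is bounded and $\mathbb{P}$ can be computed exactly by enumerating every EOS-terminated string and weighting it by its $\mathbb{M}$-probability times the product of potentials along its path. A minimal sketch with exact fractions, assuming a hypothetical toy model (vocabulary `a`, `b`, EOS; EOS forced after `MAX_LEN = 3` tokens; a potential that forbids `"b"`):

```python
from fractions import Fraction

EOS = "<eos>"
MAX_LEN = 3  # hypothetical cap: EOS is forced once MAX_LEN tokens are reached

def next_token_probs(s):
    """Hypothetical stand-in for softmax(f_theta(s)), as exact fractions."""
    if len(s) >= MAX_LEN:
        return {EOS: Fraction(1)}
    return {"a": Fraction(1, 2), "b": Fraction(3, 10), EOS: Fraction(1, 5)}

def G(s_prev, s_new):
    """Hypothetical hard-constraint potential: zero once a 'b' token is emitted."""
    return Fraction(0) if s_new[-1] == "b" else Fraction(1)

def enumerate_weights(s=(), weight=Fraction(1), out=None):
    """Map every EOS-terminated string s to M(s) * prod_t G_t along its path."""
    if out is None:
        out = {}
    if s and s[-1] == EOS:  # s is in F: the chain has stopped (T reached)
        out[s] = weight
        return out
    for tok, prob in next_token_probs(s).items():
        s_new = s + (tok,)
        enumerate_weights(s_new, weight * prob * G(s, s_new), out)
    return out

weights = enumerate_weights()
Z = sum(weights.values())  # partition function E_M[prod_t G_t]
posterior = {s: w / Z for s, w in weights.items() if w > 0}
print(Z)  # 19/40
for s, prob in posterior.items():
    print(s, prob)
```

Here $Z = 19/40$ is the probability that the unconstrained chain never emits `"b"`, and the posterior renormalises the four surviving strings by it. Realistic vocabularies make this enumeration infeasible, which is what the SMC algorithm approximates.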

Algorithm 3 Sequential Monte Carlo Transformer Steering

Input: $N$ (# particles), $K$ (factor), Feynman-Kac Transformer model $(s_0, \{M_t\}_{t \geq 1}, \{G_t\}_{t \geq 1})$

Output: Weighted particle approximation $\{(x_i, w_i)\}_{i=1,\ldots,N}$ of the posterior $\mathbb{P}$

Output: Unbiased estimate $\hat{Z}$ of the partition function $Z = \mathbb{E}_\mathbb{M}[\prod_{t=1}^T G_t(S_{t-1}, S_t, f_\theta)]$

Initialize $f_\theta \gets \texttt{CachedTransformer}()$

Initialize $(x_i, w_i) \gets (s_0, 1)$ for $i = 1, \ldots, N$

Initialize $t \gets 1$

while $x_i \notin \mathcal{F}$ for some $i \in \{1, \ldots, N\}$ do

$K_i \gets K (1 - \mathbb{1}_{\mathcal{F}}(x_i)) + \mathbb{1}_{\mathcal{F}}(x_i)$ for $i = 1, \ldots, N$

$N' \gets \sum_{i=1}^N K_i$

for $i \in \{1, \ldots, N\}$ do

if $x_i \in \mathcal{F}$ then

Set $(x_{i,1}, w_{i,1}) \gets (x_i, w_i \cdot \frac{N'}{N})$

else

Generate $x_{i,k} \sim M_t(\cdot \mid x_i, f_\theta)$ for $k = 1, \ldots, K$

Set $w_{i,k} \gets w_i \cdot G_t(x_i, x_{i,k}, f_\theta) \cdot \frac{N'}{K N}$ for $k = 1, \ldots, K$

end if

end for

Set normalized weights $\hat{w}_{i,k} \gets \frac{w_{(i,k)}}{\sum_{j=1}^N \sum_{l=1}^{K_j} w_{(j,l)}}$ for $i = 1, \ldots, N$ and $k = 1, \ldots, K_i$

Set $c^* \gets \inf\{c \in \mathbb{R}_{> 0} \mid \sum_{i=1}^N \sum_{k=1}^{K_i} (1 \wedge c \hat{w}_{(i,k)}) > N\}$

Set $(I_\text{det}, I_\text{stoch}, I_\text{strat}) \gets (\{(i,k) \mid c^{*} \hat{w}_{i,k} \geq 1\}, \{(i,k) \mid c^{*} \hat{w}_{i,k} < 1\}, \{\})$

Set $\alpha \gets \frac{\sum_{i \in I_\text{stoch}} \hat{w}_i}{N - |I_\text{det}|}$ and generate $U \sim \text{Uniform}([0, \alpha])$

for $i \in I_\text{stoch}$ do

Set $U \gets U - \hat{w}_i$

if $U < 0$ then

Set $I_\text{strat} \gets I_\text{strat} \cup \{i\}$

Set $U \gets U + \alpha$

end if

end for

Set particles $\{(x_i, w_i)\}_{i=1,\ldots,|I_\text{det}|} \gets \{(x_j, w_j \cdot \frac{N}{N'}) \mid j \in I_\text{det}\}$

Set particles $\{(x_i, w_i)\}_{i=|I_\text{det}|+1,\ldots,N} \gets \{(x_j, \frac{N}{c^* N'} \sum_{l=1}^{N} \sum_{k=1}^{K_l} w_{(l,k)}) \mid j \in I_\text{strat}\}$

Set $t \gets t + 1$

end while

return $\left((x_i, w_i)_{i=1,\ldots,N}, \hat{Z} = \frac{1}{N} \sum_{i=1}^N w_i \right)$
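A compact Python sketch of this loop on a hypothetical toy model (vocabulary `a`, `b`, EOS; EOS forced after 3 tokens; a potential that forbids `"b"`). To stay short it swaps the algorithm's without-replacement stratified resampling (the $c^*$, $I_\text{det}$, $I_\text{strat}$ steps) for plain multinomial resampling, after which every particle carries the mean weight so that $\hat{Z}$ remains an unbiased estimate. The extension and weighting steps follow the pseudocode, including the $N'/N$ and $N'/(KN)$ factors. The names `smc_steer`, `M`, and `G` are illustrative, not from the paper's code.

```python
import random
from typing import List, Tuple

EOS = "<eos>"
MAX_LEN = 3  # hypothetical cap: EOS is forced once MAX_LEN tokens are reached
State = Tuple[str, ...]

def is_final(s: State) -> bool:
    return len(s) > 0 and s[-1] == EOS

def M(s: State, rng: random.Random) -> State:
    # Hypothetical toy kernel standing in for sampling from a transformer.
    if len(s) >= MAX_LEN:
        return s + (EOS,)
    return s + (rng.choices(["a", "b", EOS], weights=[0.5, 0.3, 0.2])[0],)

def G(s_prev: State, s_new: State) -> float:
    # Hypothetical hard-constraint potential: no 'b' tokens allowed.
    return 0.0 if s_new[-1] == "b" else 1.0

def smc_steer(N: int, K: int, rng: random.Random) -> Tuple[List[Tuple[State, float]], float]:
    """Simplified SMC steering: K-fold extension plus multinomial resampling.

    Returns N weighted particles and the estimate Z_hat = mean(weights).
    """
    particles: List[Tuple[State, float]] = [((), 1.0)] * N
    while any(not is_final(x) for x, _ in particles):
        n_prime = sum(1 if is_final(x) else K for x, _ in particles)
        pool: List[Tuple[State, float]] = []
        for x, w in particles:
            if is_final(x):
                pool.append((x, w * n_prime / N))  # finished particle carried over
            else:
                for _ in range(K):                 # K proposals from M_t
                    x_new = M(x, rng)
                    pool.append((x_new, w * G(x, x_new) * n_prime / (K * N)))
        mean_w = sum(w for _, w in pool) / len(pool)
        if mean_w == 0.0:  # every proposal violated the constraint
            return [(x, 0.0) for x, _ in pool[:N]], 0.0
        # Multinomial resampling (swapped in for the stratified scheme above).
        chosen = rng.choices([x for x, _ in pool], weights=[w for _, w in pool], k=N)
        particles = [(x, mean_w) for x in chosen]
    z_hat = sum(w for _, w in particles) / N
    return particles, z_hat

particles, z_hat = smc_steer(N=2000, K=2, rng=random.Random(0))
print(z_hat)
```

With $N = 2000$ and $K = 2$, $\hat{Z}$ should land near the exact value $19/40$ (the probability that the unconstrained toy chain never emits `"b"`), and no returned particle contains `"b"`.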

Bibliography

  • Lew, A. K., Zhi-Xuan, T., Grand, G., & Mansinghka, V. K. (2023). Sequential Monte Carlo Steering of Large Language Models using Probabilistic Programs. arXiv preprint arXiv:2306.03081 [arxiv]
  • Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2023). Attention Is All You Need. arXiv preprint arXiv:1706.03762 [arxiv]