Find the code for this project here: https://github.com/Information-Theory-Group/Adaptive-Sampling-Networks.git

Authors: Navneel Singhal, Saurav Panigrahi

In the deployment of Large Language Models (LLMs), the decoding strategy (how we select the next token given a probability distribution) is often treated as a hyperparameter search rather than a learnable component of the system. We rely on heuristics like Temperature scaling (T), Top-k, Top-p (Nucleus), and Min-p to balance exploration and exploitation.

However, these heuristics are rigid. They apply the same logic regardless of the model's uncertainty or the semantic context. The Adaptive Sampling Network (ASN) project introduces a paradigm shift: replacing static heuristics with a lightweight, permutation-equivariant neural network that transforms raw logits into a task-optimised distribution.

This post explores the mathematical foundations of ASNs, focusing on Permutation Equivariance, the Linear Attention mechanism over probability mass, and the optimisation objectives used to distill complex sampling behaviours.

1. Problem Formulation

Let $M$ be a base LLM that outputs a set of logits $\mathbf{z} \in \mathbb{R}^V$ over a vocabulary $\mathcal{V}$ of size $V$, given a context $\mathbf{x}$. The standard softmax probability distribution is $P(\cdot|\mathbf{x}) = \text{softmax}(\mathbf{z})$.

Traditional sampling applies a function $H: \mathbb{R}^V \to \mathbb{R}^V$ (the heuristic) followed by sampling. For example, Top-$k$ sets all but the $k$ largest logits to $-\infty$.
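As a concrete instance, here is a minimal PyTorch sketch of Top-$k$ as such a function $H: \mathbb{R}^V \to \mathbb{R}^V$; the names and shapes are illustrative, not taken from the ASN codebase:

```python
import torch

def top_k_heuristic(z: torch.Tensor, k: int) -> torch.Tensor:
    """H for Top-k: keep the k largest logits, set the rest to -inf."""
    kth_largest = torch.topk(z, k).values[..., -1, None]  # smallest surviving logit
    return torch.where(z >= kth_largest, z, torch.full_like(z, float("-inf")))

logits = torch.randn(32_000)  # raw logits z over a vocabulary of size V
probs = torch.softmax(top_k_heuristic(logits, k=50), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```

Because the masked logits become $-\infty$, the subsequent softmax assigns them exactly zero probability, so sampling is restricted to the $k$ most likely tokens.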

An Adaptive Sampling Network, denoted $S_\theta$, is a parameterized function that maps the raw logits to a modified set of logits:

$$ \mathbf{z}' = S_\theta(\mathbf{z}) $$

The goal is to learn parameters $\theta$ such that sampling from $\text{softmax}(\mathbf{z}')$ maximizes a specific utility function (quality, diversity, or correctness) or mimics a complex oracle heuristic.
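In code, $S_\theta$ simply replaces the fixed heuristic in the decoding step. A minimal sketch, where `sampler` stands in for any trained ASN (an assumption for illustration, not the repository's actual API):

```python
import torch
import torch.nn as nn

def sample_with_asn(z: torch.Tensor, sampler: nn.Module) -> torch.Tensor:
    """Sample one token from softmax(S_theta(z)) instead of softmax(H(z))."""
    z_prime = sampler(z)                    # z' = S_theta(z)
    probs = torch.softmax(z_prime, dim=-1)  # softmax(z')
    return torch.multinomial(probs, num_samples=1)
```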

2. Permutation Equivariance

A critical theoretical constraint imposed on $S_\theta$ is Permutation Equivariance. The sampler should operate on the distribution of probability values, not on the identities of the tokens. If the base model assigns high probability to "dog" or to "cat", the sampler's decision to truncate the tail should depend only on the probability values, not on the token "dog" itself.

Formally, let $\Pi$ be any $V \times V$ permutation matrix. $S_\theta$ is permutation equivariant if:

$$ S_\theta(\Pi \mathbf{z}) = \Pi S_\theta(\mathbf{z}) $$

This constraint prevents the sampler from memorising token-specific semantics (which is the job of the base LLM) and forces it to learn universal distribution transformations.
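To make the constraint concrete, the sketch below builds a small DeepSets-style layer that is permutation equivariant by construction (a shared per-logit MLP plus a permutation-invariant pooled statistic) and checks the identity $S_\theta(\Pi \mathbf{z}) = \Pi S_\theta(\mathbf{z})$ numerically. This is an illustrative stand-in, not the ASN architecture itself:

```python
import torch
import torch.nn as nn

class EquivariantLayer(nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        # The same MLP is applied to every logit position, so permuting the
        # input positions just permutes the outputs accordingly.
        self.mlp = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        pooled = z.mean(dim=-1, keepdim=True).expand_as(z)  # permutation-invariant context
        feats = torch.stack([z, pooled], dim=-1)            # (V, 2) per-token features
        return self.mlp(feats).squeeze(-1)                  # back to (V,)

layer = EquivariantLayer()
z = torch.randn(100)
perm = torch.randperm(100)
assert torch.allclose(layer(z[perm]), layer(z)[perm], atol=1e-5)  # S(Pi z) = Pi S(z)
```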


Architectural Implementation