1. Introduction
The softmax function is a fundamental component in machine learning, particularly in classification tasks and attention mechanisms. However, its conventional implementation needs access to all input elements before it can produce any output, because the normalizing sum depends on every one of them; this becomes problematic for large inputs or streaming scenarios. The online softmax addresses this by computing the function incrementally.
2. Standard softmax
Given an input vector $x = (x_1, \dots, x_n)$, the softmax function is defined as:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}, \qquad i = 1, \dots, n$$
This can be computed in three steps (sketched in code after the list):
1) compute all exponentials: $e^{x_i}$ for each $i$
2) compute the sum: $S = \sum_{j=1}^{n} e^{x_j}$
3) compute each output: $\mathrm{softmax}(x)_i = e^{x_i} / S$
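A minimal sketch of these three steps in NumPy (naive_softmax is our name for it, not a library function):

import numpy as np

def naive_softmax(x):
    # Step 1: compute all exponentials
    exps = np.exp(x)
    # Step 2: compute their sum
    total = exps.sum()
    # Step 3: divide each exponential by the sum
    return exps / total

For inputs much above 700 or so, np.exp overflows to infinity in double precision, which motivates the stable version described next.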
2.1 Numerical stability
In practice, we use a numerically stable version by subtracting the maximum value $m = \max_j x_j$:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}}$$
This prevents overflow while producing identical results due to the shift invariance of softmax.
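A stable counterpart to the sketch above (again only an illustration; stable_softmax is our name):

import numpy as np

def stable_softmax(x):
    # Shift by the maximum before exponentiating; the outputs are unchanged
    # because adding a constant to every input cancels in the ratio.
    x = np.asarray(x, dtype=float)
    exps = np.exp(x - x.max())
    return exps / exps.sum()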
3. Online softmax derivation
The online version computes softmax incrementally, processing one element at a time while maintaining necessary statistics. We derive it by decomposing the computation into running updates.
3.1 Key variables
Define two running quantities over the first $k$ elements:
- $m_k = \max_{1 \le i \le k} x_i$, the running maximum
- $S_k = \sum_{i=1}^{k} e^{x_i - m_k}$, the running sum of exponentials
3.2 Recursive updates
When a new element $x_{k+1}$ arrives, we update the maximum:

$$m_{k+1} = \max(m_k, x_{k+1})$$

For the sum, we need to adjust previous terms for the new maximum: each term $e^{x_i - m_k}$ must become $e^{x_i - m_{k+1}} = e^{x_i - m_k} \cdot e^{m_k - m_{k+1}}$, so the old sum is rescaled by $e^{m_k - m_{k+1}}$ before the new term is added:

$$S_{k+1} = S_k \, e^{m_k - m_{k+1}} + e^{x_{k+1} - m_{k+1}}$$
3.3 Final computation
After processing all $n$ elements, the softmax for any $i$ is:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m_n}}{S_n}$$
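For example, with $x = (1, 3, 2)$ the running statistics evolve as:
- after $x_1 = 1$: $m_1 = 1$, $S_1 = e^{0} = 1$
- after $x_2 = 3$: $m_2 = 3$, $S_2 = 1 \cdot e^{1-3} + e^{0} \approx 1.135$
- after $x_3 = 2$: $m_3 = 3$, $S_3 = S_2 + e^{2-3} \approx 1.503$
The final outputs $e^{x_i - 3} / S_3 \approx (0.090, 0.665, 0.245)$ agree with the direct computation.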
3.4 Algorithm

import math

class OnlineSoftmax:
    def __init__(self):
        self.running_max = float('-inf')   # m_k
        self.running_sum = 0.0             # S_k
        self.processed_values = []         # Store processed x_i - m_k

    def process(self, x_new):
        # Update running maximum
        old_max = self.running_max
        self.running_max = max(self.running_max, x_new)

        # Adjust running sum if the maximum changed
        if self.running_max != old_max:
            # Rescale previous sum to the new maximum
            self.running_sum *= math.exp(old_max - self.running_max)
            # Rescale stored values
            self.processed_values = [
                val - (self.running_max - old_max)
                for val in self.processed_values
            ]

        # Add new value
        processed_x = x_new - self.running_max
        self.processed_values.append(processed_x)
        self.running_sum += math.exp(processed_x)

    def get_softmax(self):
        # Return softmax for all processed values
        return [
            math.exp(val) / self.running_sum
            for val in self.processed_values
        ]
# create an instance
online_softmax = OnlineSoftmax()

# process the input one element at a time
x = [1.0, 2.0, 3.0, 4.0]
for xi in x:
    online_softmax.process(xi)

# get all the softmax values
result = online_softmax.get_softmax()
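Continuing this example, a quick sanity check against a direct stable computation (the tolerance 1e-12 is an arbitrary choice):

# Direct stable softmax on the same input, for comparison
m = max(x)
exps = [math.exp(xi - m) for xi in x]
total = sum(exps)
direct = [e / total for e in exps]

# The streamed result agrees up to floating-point error
assert all(abs(a - b) < 1e-12 for a, b in zip(result, direct))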
4. Complexity analysis
Both standard and online softmax have $O(n)$ time complexity, but online softmax:
- processes elements sequentially
- requires only $O(1)$ additional memory for the running statistics ($m_k$ and $S_k$)
- enables streaming applications
5. Online softmax for chunked attention in LLMs
In large language models (LLMs), processing long sequences with attention mechanisms requires computing softmax over prohibitively large matrices. We present a chunked computation approach using online softmax.
5.1 Problem formulation
For an input sequence of length $N$, divided into $\lceil N / B \rceil$ chunks where each chunk has size $B$, the attention computation requires:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{T}}{\sqrt{d}}\right) V, \qquad Q, K, V \in \mathbb{R}^{N \times d},$$

with the softmax applied row-wise. The quadratic $O(N^2)$ cost of the score matrix $Q K^{T}$ makes full computation infeasible for large $N$.
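For reference, the dense computation that the chunked scheme must reproduce, written as a minimal single-head NumPy sketch (no masking or batching; naive_attention is our name):

import numpy as np

def naive_attention(Q, K, V):
    d = K.shape[1]
    # Full (N, N) score matrix: this is the O(N^2) memory cost
    scores = Q @ K.T / np.sqrt(d)
    # Row-wise stable softmax over the keys
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V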
5.2 Chunked online softmax
1) local statistics: for each chunk $c$ and each query row $q$, compute the chunk scores $s^{(c)} = q K_c^{T} / \sqrt{d}$ and their statistics
$$m^{(c)} = \max_j s^{(c)}_j, \qquad S^{(c)} = \sum_j e^{s^{(c)}_j - m^{(c)}}$$
2) global aggregation across chunks (checked numerically in the sketch after this list):
$$m = \max_c m^{(c)}, \qquad S = \sum_c e^{m^{(c)} - m}\, S^{(c)}$$
3) normalization: the attention weight on key $j$ of chunk $c$ is $e^{s^{(c)}_j - m} / S$
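As a quick numeric check of the aggregation rule, split one query's scores into two chunks and compare the combined normalizer against the direct one (a small self-contained sketch with arbitrary numbers):

import numpy as np

s = np.array([0.3, 2.1, -1.0, 0.7, 1.5])   # one query's scores
chunks = [s[:3], s[3:]]                     # two arbitrary chunks

# Local statistics per chunk
local_max = [c.max() for c in chunks]
local_sum = [np.exp(c - m).sum() for c, m in zip(chunks, local_max)]

# Global aggregation
m = max(local_max)
S = sum(np.exp(lm - m) * ls for lm, ls in zip(local_max, local_sum))

# Same normalizer as the direct computation
assert np.isclose(S, np.exp(s - s.max()).sum())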
5.3 Pseudocode implementation

import math
import numpy as np

def chunked_online_softmax(Q, K, V, chunk_size):
    # Initialize
    N = K.shape[0]                              # sequence (key) length
    d = K.shape[1]                              # head dimension
    num_queries = Q.shape[0]
    num_chunks = math.ceil(N / chunk_size)
    global_max = np.full(num_queries, -np.inf)  # per-query running maximum
    global_sum = np.zeros(num_queries)          # per-query running sum
    chunk_stats = []

    # Step 1: Compute local statistics for each chunk
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, N)

        # Get current chunk of K and its attention scores
        K_chunk = K[start_idx:end_idx]
        scores = Q @ K_chunk.T / math.sqrt(d)   # (num_queries, chunk_len)

        # Local statistics, per query row; the scores themselves are
        # recomputed in step 3 so only one chunk is held in memory at a time
        local_max = scores.max(axis=1)
        local_sum = np.exp(scores - local_max[:, None]).sum(axis=1)
        chunk_stats.append({'local_max': local_max, 'local_sum': local_sum})

        # Update global maximum
        global_max = np.maximum(global_max, local_max)

    # Step 2: Global aggregation
    for stats in chunk_stats:
        # Adjust sums based on the global maximum
        global_sum += np.exp(stats['local_max'] - global_max) * stats['local_sum']

    # Step 3: Compute final attention weights and accumulate the output
    output = np.zeros((num_queries, V.shape[1]))
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, N)

        # Recompute this chunk's scores instead of storing them
        K_chunk = K[start_idx:end_idx]
        scores = Q @ K_chunk.T / math.sqrt(d)

        # Compute normalized attention weights
        attention_weights = np.exp(scores - global_max[:, None]) / global_sum[:, None]

        # Accumulate chunk output
        V_chunk = V[start_idx:end_idx]
        output += attention_weights @ V_chunk

    return output
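A short usage example, comparing against the naive_attention sketch from section 5.1 (shapes, seed, and chunk size are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
N, d = 128, 16
Q = rng.normal(size=(N, d))
K = rng.normal(size=(N, d))
V = rng.normal(size=(N, d))

out_chunked = chunked_online_softmax(Q, K, V, chunk_size=32)
out_dense = naive_attention(Q, K, V)
assert np.allclose(out_chunked, out_dense)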
5.4 Key properties
- memory: the full $N \times N$ score matrix is never materialized; peak memory for the scores drops from $O(N^2)$ to $O(N \cdot B)$ for chunk size $B$
- numerical stability: every exponential is taken of a value shifted by the online maximum, so it cannot overflow
- implementation: the same rescaling idea is used in FlashAttention and memory-efficient attention implementations