import math
import time
from typing import Tuple
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
@jax.jit
def scaled_dot_product_attention(Q, K, V):
    """
    Compute scaled dot-product attention and return the weighted values.

    Args:
    - Q, K, V: query, key, and value tensors of shape [n_head, seqlen, head_dim].
    """
    # Dot product of queries and keys, scaled by the square root of the key depth
    matmul_qk = jnp.einsum('tih,tjh->tij', Q, K)  # [n_head, seqlen, seqlen]
    dim_k = K.shape[-1]
    scaled_attention_logits = matmul_qk / jnp.sqrt(dim_k)
    # Softmax over the key axis gives the attention weights
    weights = jax.nn.softmax(scaled_attention_logits, axis=-1)  # [n_head, seqlen, seqlen]
    # Apply the weights to the values
    output = jnp.einsum('tij,tjh->tih', weights, V)  # [n_head, seqlen, head_dim]
    return output

@jax.jit
def vanilla_attn(Q, K, V):
    """
    Multi-head self-attention over a batch of inputs.

    Args:
    - Q, K, V: tensors of shape [batch, n_head, seqlen, head_dim]; jax.vmap maps
      scaled_dot_product_attention over the batch dimension.
    """
    return jax.vmap(scaled_dot_product_attention)(Q, K, V)

def generate_data(key: int, batch_size: int, seq_len: int, embed_dim: int, num_heads: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Generate random Q, K, V of shape [batch_size, num_heads, seq_len, embed_dim] (embed_dim is the per-head dimension here)."""
    scale = 1.0 / math.sqrt(embed_dim)
    q = jax.random.normal(jax.random.PRNGKey(key), (batch_size, num_heads, seq_len, embed_dim)) * scale
    k = jax.random.normal(jax.random.PRNGKey(key + 1), (batch_size, num_heads, seq_len, embed_dim)) * scale
    v = jax.random.normal(jax.random.PRNGKey(key + 2), (batch_size, num_heads, seq_len, embed_dim)) * scale
    return q, k, v

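
# Alternative sketch (not used by the benchmark below): idiomatic JAX derives
# independent streams from a single key with jax.random.split rather than
# constructing PRNGKey(key), PRNGKey(key + 1), ... . The helper name is ours.
def generate_data_split(key: int, batch_size: int, seq_len: int, embed_dim: int, num_heads: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Same output shapes as generate_data, but with properly split PRNG keys."""
    scale = 1.0 / math.sqrt(embed_dim)
    shape = (batch_size, num_heads, seq_len, embed_dim)
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(key), 3)
    return (jax.random.normal(kq, shape) * scale,
            jax.random.normal(kk, shape) * scale,
            jax.random.normal(kv, shape) * scale)
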
@jax.jit
def flash_attn(Q, K, V):
    """Pallas TPU flash attention on [batch, num_heads, seq_len, head_dim] inputs."""
    return flash_attention(Q, K, V)

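
# Sanity-check sketch (not part of the original benchmark): both paths compute exact
# attention, so their outputs should agree up to floating-point error. Assumption:
# the Pallas kernel's default softmax scale is 1.0, so Q is pre-scaled by 1/sqrt(d)
# here to mirror the explicit scaling in scaled_dot_product_attention; the tolerance
# is a guess, not a guarantee.
def check_outputs_match(Q, K, V, atol: float = 2e-2) -> bool:
    """Return True if the naive and Pallas flash attention outputs roughly agree."""
    dim_k = K.shape[-1]
    out_vanilla = vanilla_attn(Q, K, V)
    out_flash = flash_attn(Q / jnp.sqrt(dim_k), K, V)
    return bool(jnp.allclose(out_vanilla, out_flash, atol=atol))
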
def test_performance():
    n: int = 50
    batch_size, seq_len, embed_dim, num_heads = 256, 512, 128, 8
    q, k, v = generate_data(1, batch_size, seq_len, embed_dim, num_heads)
    # Warm up both paths so compilation time is excluded from the measurements
    _ = (
        vanilla_attn(q, k, v).block_until_ready(),
        flash_attn(q, k, v).block_until_ready(),
    )
    # Time the naive multi-head self-attention (note: data generation happens inside
    # the timed loop, so it is included in the measurement for both paths)
    start_time = time.time()
    for i in range(n):
        q, k, v = generate_data(i, batch_size, seq_len, embed_dim, num_heads)
        q, k, v = q.block_until_ready(), k.block_until_ready(), v.block_until_ready()
        _ = vanilla_attn(q, k, v).block_until_ready()
    vanilla_time = time.time() - start_time
    # Time Pallas's flash_attention
    start_time = time.time()
    for i in range(n):
        q, k, v = generate_data(i, batch_size, seq_len, embed_dim, num_heads)
        q, k, v = q.block_until_ready(), k.block_until_ready(), v.block_until_ready()
        _ = flash_attn(q, k, v).block_until_ready()
    pallas_time = time.time() - start_time
    print(f"Average naive MHSA time: {vanilla_time / n:.5f}s")
    print(f"Average Pallas flash_attention time: {pallas_time / n:.5f}s")

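
# Optional timing sketch (an addition, not the original measurement): unlike the
# loops above, this excludes data generation from the timed region. The name
# `bench` and the default iteration count are our choices.
def bench(fn, q, k, v, iters: int = 50) -> float:
    """Average wall-clock seconds per call of fn(q, k, v), after one warm-up call."""
    fn(q, k, v).block_until_ready()  # compile / warm up
    start = time.time()
    for _ in range(iters):
        fn(q, k, v).block_until_ready()
    return (time.time() - start) / iters
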
if __name__ == "__main__":
    test_performance()