Source: MarkTechPost
In this tutorial, we use xFormers, a toolkit for building fast, memory-efficient Transformer models on GPUs. We first validate its memory-efficient attention against a standard attention implementation, then compare the speed and memory consumption of the two across sequence lengths. Next, we examine causal masking, packed variable-length sequences, grouped-query attention, and custom ALiBi positional biases. Finally, we combine these techniques into a trainable GPT-style model that uses xFormers attention, SwiGLU feed-forward layers, and automatic mixed-precision training.
Setting Up xFormers and Validating Memory-Efficient Attention
import subprocess, sys

def _pip(*a):
    subprocess.run([sys.executable, "-m", "pip", "install", *a], check=False)

# Install xformers only if it is not already importable.
try:
    import xformers
except Exception:
    _pip("-q", "-U", "xformers")

import math, time
import torch, torch.nn as nn, torch.nn.functional as F
import xformers, xformers.ops as xops
from xformers.ops import fmha

ab = fmha.attn_bias

assert torch.cuda.is_available(), (
    "No GPU detected. In Colab: Runtime → Change runtime type → GPU, then re-run.")
device = "cuda"
torch.manual_seed(0)

print("torch :", torch.__version__)
print("xformers :", xformers.__version__)
print("GPU :", torch.cuda.get_device_name(0))
print("\n--- xformers.info (which kernels are built/available) ---")
try:
    subprocess.run([sys.executable, "-m", "xformers.info"], check=False)
except Exception as e:
    print("xformers.info unavailable:", e)

def cuda_time(fn, iters=20, warmup=5):
    # Average milliseconds per call, measured with CUDA events after a warmup.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    s, e = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    s.record()
    for _ in range(iters):
        fn()
    e.record(); torch.cuda.synchronize()
    return s.elapsed_time(e) / iters

def peak_mem_mb(fn):
    # Peak allocated GPU memory (in MB) while running fn once.
    torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
    fn(); torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e6

def vanilla_attention(q, k, v, causal=False):
    """Reference attention that MATERIALIZES the [B,H,M,M] score matrix.
    Inputs are xformers-layout [B, M, H, K]."""
    q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if causal:
        M = scores.shape[-1]
        m = torch.triu(torch.ones(M, M, device=q.device, dtype=torch.bool), 1)
        scores = scores.masked_fill(m, float("-inf"))
    out = scores.softmax(-1) @ v
    return out.transpose(1, 2)

print("\n" + "="*70 + "\n1. memory_efficient_attention basics + correctness\n" + "="*70)
B, M, H, K = 2, 512, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_xf = xops.memory_efficient_attention(q, k, v)
out_ref = vanilla_attention(q, k, v).half()
print("output shape :", tuple(out_xf.shape), "(layout B, M, H, K)")
print("max abs diff vs ref : {:.2e}".format((out_xf - out_ref).abs().max().item()))
print("-> it's EXACT attention (fp16 rounding only), just computed without")
print("   ever storing the full MxM score matrix.")
We install and import xFormers, verify GPU availability, and inspect the attention kernels supported by the environment. We define helper functions for measuring CUDA execution time and peak memory consumption. We then validate memory-efficient attention against standard attention to confirm that both produce results that closely match each other.
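To make the "exact attention without the full score matrix" claim concrete, here is a minimal pure-Python sketch of the streaming (online-softmax) idea for a single query. It is an illustration under our own assumptions, not the kernel xFormers actually runs: the function names, the block size, and the toy data are all hypothetical.

```python
import math

def naive_attention(q, keys, values):
    """Softmax attention for one query, holding every score in memory at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def streaming_attention(q, keys, values, block=2):
    """Same result, but only `block` scores exist at any moment (hypothetical helper)."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")   # largest score seen so far
    denom = 0.0                   # running softmax denominator
    acc = [0.0] * len(values[0])  # running weighted sum of values
    for start in range(0, len(keys), block):
        kb = keys[start:start + block]
        vb = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in kb]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum (exp(-inf) == 0.0).
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, vb):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]

# Toy data: one query of dimension 4, five key/value pairs.
Q = [0.5, -1.0, 0.25, 2.0]
KEYS = [[0.1, 0.2, 0.3, 0.4], [1.0, -1.0, 0.5, 0.0], [0.0, 0.0, 2.0, 1.0],
        [-0.5, 0.5, -0.5, 0.5], [2.0, 1.0, 0.0, -1.0]]
VALUES = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],
          [1.0, 1.0, 0.0], [0.5, 0.5, 0.5]]

ref = naive_attention(Q, KEYS, VALUES)
out = streaming_attention(Q, KEYS, VALUES, block=2)
print("max abs diff:", max(abs(a - b) for a, b in zip(ref, out)))
```

The two functions agree up to floating-point rounding, which mirrors the small difference the tutorial reports between the xFormers output and the reference.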
Benchmarking Memory and Speed Against Naive Causal Attention
print("\n" + "="*70 + "\n2. Memory & speed vs naive attention (fwd+bwd)\n" + "="*70)
print(f"{'seqlen':>8} | {'naive MB':>10} | {'xformers MB':>12} | {'naive ms':>9} | {'xf ms':>7}")
print("-"*60)
for M in [512, 1024, 2048, 4096]:
    q, k, v = (torch.randn(2, M, 8, 64, device=device, dtype=torch.float16,
                           requires_grad=True) for _ in range(3))
    def run_xf():
        o = xops.memory_efficient_attention(q, k, v); o.sum().backward()
    def run_naive():
        o = vanilla_attention(q, k, v); o.sum().backward()
    # The naive path may run out of memory at long sequence lengths; report NaN then.
    try:
        nm = peak_mem_mb(run_naive); nt = cuda_time(run_naive, 8, 3)
    except RuntimeError:
        nm, nt = float("nan"), float("nan"); torch.cuda.empty_cache()
    xm = peak_mem_mb(run_xf); xt = cuda_time(run_xf, 8, 3)
    print(f"{M:>8} | {nm:>10.0f} | {xm:>12.0f} | {nt:>9.2f} | {xt:>7.2f}")
print("-> naive memory grows ~4x per doubling of M (it stores BxHxMxM);")
print("   xformers grows ~linearly and stays fast.")

print("\n" + "="*70 + "\n3. Causal attention via LowerTriangularMask\n" + "="*70)
B, M, H, K = 2, 256, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=ab.LowerTriangularMask())
ref_causal = vanilla_attention(q, k, v, causal=True).half()
print("causal max abs diff : {:.2e}".format((out_causal - ref_causal).abs().max().item()))
print("-> the mask is implicit; no MxM boolean tensor is allocated.")
We benchmark naive attention and xFormers attention across progressively longer sequences using forward and backward passes. We compare their execution times and peak GPU memory usage to observe how xFormers avoids quadratic memory growth. We also apply an implicit lower-triangular mask and verify causal attention against the reference implementation.
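The "~4x per doubling" pattern follows from simple arithmetic on tensor sizes. The sketch below counts only two tensors (the float32 score matrix that the reference implementation materializes, and the fp16 attention output), so it is a rough lower bound and not a prediction of the measured peaks, which also include the softmax output, gradients, and allocator overhead. The helper names are our own.

```python
def score_matrix_bytes(batch, heads, seqlen, bytes_per_el=4):
    """Size of one [B, H, M, M] score matrix; float32 because the reference casts with .float()."""
    return batch * heads * seqlen * seqlen * bytes_per_el

def output_bytes(batch, heads, seqlen, head_dim, bytes_per_el=2):
    """Size of one [B, M, H, K] attention output in fp16."""
    return batch * seqlen * heads * head_dim * bytes_per_el

# Same shapes as the benchmark: B=2, H=8, K=64.
for M in (512, 1024, 2048, 4096):
    scores_mb = score_matrix_bytes(2, 8, M) / 1e6
    out_mb = output_bytes(2, 8, M, 64) / 1e6
    print(f"M={M:>5} | score matrix {scores_mb:>8.1f} MB | output {out_mb:>5.1f} MB")
```

Doubling `M` quadruples the score-matrix term but only doubles the output term, which is the gap the benchmark table makes visible.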
Packing Variable-Length Sequences and Running Grouped-Query Attention
print("\n" + "="*70 + "\n4. Variable-length packed batch — no padding waste\n" + "="*70)
seqlens = [37, 120, 8, 200]
total = sum(seqlens)
H, K = 8, 64
# All sequences are concatenated along the token axis of a single batch entry.
q = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
k = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
v = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
try:
    bias = ab.BlockDiagonalMask.from_seqlens(seqlens)
    out_packed = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
    # Check the first segment against unpacked reference attention.
    s0 = seqlens[0]
    ref0 = vanilla_attention(q[:, :s0], k[:, :s0], v[:, :s0]).half()
    print("packed shape :", tuple(out_packed.shape), "(all", total, "tokens, no pad)")
    print("segment-0 max diff : {:.2e}".format((out_packed[:, :s0] - ref0).abs().max().item()))
    cbias = ab.BlockDiagonalCausalMask.from_seqlens(seqlens)
    _ = xops.memory_efficient_attention(q, k, v, attn_bias=cbias)
    print("-> also did a packed CAUSAL pass. This is how vLLM-style engines")
    print("   batch requests of different lengths with zero padding overhead.")
    splits = bias.split(out_packed)
    print("recovered segments :", [tuple(t.shape) for t in splits])
except Exception as e:
    print("BlockDiagonalMask path skipped on this version/backend:", repr(e))

print("\n" + "="*70 + "\n5. Grouped-query attention (5-D BMGHK layout)\n" + "="*70)
B, M, K = 2, 256, 64
n_q_heads, n_kv_heads = 8, 2
G, Hq = n_kv_heads, n_q_heads // n_kv_heads
try:
    qg = torch.randn(B, M, G, Hq, K, device=device, dtype=torch.float16)
    # xFormers does not broadcast K/V automatically, so each KV head is expanded
    # across its Hq query heads (a stride-0 view, no extra memory is copied).
    kg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16).expand(B, M, G, Hq, K)
    vg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16).expand(B, M, G, Hq, K)
    out_gqa = xops.memory_efficient_attention(qg, kg, vg)
    print("GQA output shape :", tuple(out_gqa.shape), "= [B, M, G, Hq, K]")
    print(f"-> {n_q_heads} query heads, only {n_kv_heads} KV heads: smaller KV-cache,")
    print("   which is exactly what Llama-/Mistral-class models use at inference.")
except Exception as e:
    print("GQA 5-D path skipped on this version/backend:", repr(e))
We concatenate variable-length sequences and use BlockDiagonalMask to prevent attention from crossing sequence boundaries without padding. We recover the individual outputs and also perform packed causal attention for decoder-style workloads. We then demonstrate grouped-query attention, where multiple query heads share fewer key-value heads to reduce KV-cache requirements.
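The block-diagonal rule is easy to state explicitly: a query token may attend to a key token only if both belong to the same packed sequence (and, in the causal variant, only if the key is not later than the query). The pure-Python sketch below builds that rule as a dense boolean matrix purely for illustration; the helper names are hypothetical, and the point of the xFormers mask objects is that such a dense matrix does not have to be built by hand.

```python
def block_diagonal_mask(seqlens, causal=False):
    """Dense boolean mask: entry [r][c] is True when token r may attend to token c."""
    total = sum(seqlens)
    # Segment id of every packed token, e.g. [2, 3, 1] -> [0, 0, 1, 1, 1, 2].
    seg = [i for i, n in enumerate(seqlens) for _ in range(n)]
    return [[seg[r] == seg[c] and (not causal or c <= r) for c in range(total)]
            for r in range(total)]

def split_by_seqlens(tokens, seqlens):
    """Cut a packed token list back into its original sequences."""
    out, start = [], 0
    for n in seqlens:
        out.append(tokens[start:start + n])
        start += n
    return out

toy_seqlens = [2, 3, 1]
for row in block_diagonal_mask(toy_seqlens):
    print("".join("x" if allowed else "." for allowed in row))
print(split_by_seqlens(list(range(6)), toy_seqlens))
```

With lengths `[2, 3, 1]`, the printed pattern shows three square blocks along the diagonal, and the split step recovers three lists of lengths 2, 3, and 1.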
Adding a Custom ALiBi Additive Positional Bias
print("\n" + "="*70 + "\n6. Custom ALiBi additive bias\n" + "="*70)
B, M, H, K = 1, 128, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
try:
    # One slope per head: 2^(-8/H), 2^(-16/H), ...
    slopes = (2.0 ** (-8.0 / H)) ** torch.arange(1, H + 1, device=device)
    pos = torch.arange(M, device=device)
    # rel[i, j] = j - i for past keys (<= 0), clamped to 0 for future keys.
    rel = (pos[None, :] - pos[:, None]).clamp(max=0).float()
    alibi = slopes[:, None, None] * rel[None]
    alibi = alibi[None].expand(B, H, M, M).to(torch.float16).contiguous()
    # Future positions are masked out with -inf to make the bias causal.
    causal = torch.triu(torch.ones(M, M, device=device, dtype=torch.bool), 1)
    alibi = alibi.masked_fill(causal[None, None], float("-inf"))
    out_alibi = xops.memory_efficient_attention(q, k, v, attn_bias=alibi)
    print("ALiBi output shape :", tuple(out_alibi.shape))
    print("-> any per-(head,query,key) additive bias works the same way.")
except Exception as e:
    print("Custom-bias path skipped (some backends restrict bias shapes):", repr(e))
We construct a custom ALiBi tensor that applies a different linear positional penalty to each attention head. We combine this additive bias with a causal mask so that tokens attend only to valid previous positions. We pass the resulting bias directly to xFormers attention and verify the shape of its output.
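To see what the bias tensor contains, the following pure-Python sketch reproduces the same slope and distance formulas on a tiny example (8 heads, 4 positions). It mirrors the tutorial's construction with plain lists; the function names are our own and nothing here calls xFormers.

```python
def alibi_slopes(n_heads):
    """Per-head slopes (2^(-8/H))^h for h = 1..H, as in the tutorial code."""
    base = 2.0 ** (-8.0 / n_heads)
    return [base ** (h + 1) for h in range(n_heads)]

def alibi_causal_bias(n_heads, seqlen):
    """bias[h][i][j] = slope_h * (j - i) for keys j <= i, and -inf for future keys."""
    neg_inf = float("-inf")
    return [[[slope * (j - i) if j <= i else neg_inf for j in range(seqlen)]
             for i in range(seqlen)]
            for slope in alibi_slopes(n_heads)]

slopes = alibi_slopes(8)
bias = alibi_causal_bias(8, 4)
print("slopes:", slopes)
print("head 0 bias rows:")
for row in bias[0]:
    print(row)
```

Each head penalizes distant keys linearly, with steeper slopes for the first heads and gentler ones for the last, while every future key receives `-inf` and therefore zero attention weight after the softmax.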
Training a GPT Block with xFormers Attention and SwiGLU
print("\n" + "="*70 + "\n7. Train a small GPT block (xformers attn + SwiGLU)\n" + "="*70)

def make_swiglu(d, hidden):
    """Fused xformers SwiGLU if available, else a clean manual fallback."""
    try:
        m = xops.SwiGLU(in_features=d, hidden_features=hidden, out_features=d, bias=True)
        return m, "fused xops.SwiGLU"
    except Exception:
        class SwiGLU(nn.Module):
            def __init__(s):
                super().__init__()
                s.w12 = nn.Linear(d, 2 * hidden); s.w3 = nn.Linear(hidden, d)
            def forward(s, x):
                a, b = s.w12(x).chunk(2, -1)
                return s.w3(F.silu(a) * b)
        return SwiGLU(), "manual SwiGLU fallback"

class Block(nn.Module):
    def __init__(self, d, n_heads, mlp_mult=4):
        super().__init__()
        self.h, self.k = n_heads, d // n_heads
        self.n1, self.n2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.qkv, self.proj = nn.Linear(d, 3 * d), nn.Linear(d, d)
        self.ff, self.ff_kind = make_swiglu(d, mlp_mult * d)
    def forward(self, x):
        B, M, d = x.shape
        # Pre-norm attention: project to q/k/v in the [B, M, H, K] layout.
        qkv = self.qkv(self.n1(x)).reshape(B, M, 3, self.h, self.k)
        q, kk, vv = qkv.unbind(2)
        a = xops.memory_efficient_attention(q, kk, vv, attn_bias=ab.LowerTriangularMask())
        x = x + self.proj(a.reshape(B, M, d))
        return x + self.ff(self.n2(x))

class TinyGPT(nn.Module):
    def __init__(self, vocab, d=128, n_layers=3, n_heads=8, maxlen=64):
        super().__init__()
        self.tok = nn.Embedding(vocab, d); self.pos = nn.Embedding(maxlen, d)
        self.blocks = nn.ModuleList(Block(d, n_heads) for _ in range(n_layers))
        self.nf, self.head = nn.LayerNorm(d), nn.Linear(d, vocab)
    def forward(self, idx):
        B, M = idx.shape
        x = self.tok(idx) + self.pos(torch.arange(M, device=idx.device))[None]
        for b in self.blocks:
            x = b(x)
        return self.head(self.nf(x))

VOCAB, SEQ = 64, 64

def make_batch(B):
    # Each row counts upward from a random start, wrapping modulo VOCAB.
    start = torch.randint(0, VOCAB, (B, 1), device=device)
    return (start + torch.arange(SEQ, device=device)[None]) % VOCAB

model = TinyGPT(VOCAB).to(device)
print("FFN type :", model.blocks[0].ff_kind)
opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
scaler = torch.amp.GradScaler("cuda")

for step in range(400):
    seq = make_batch(64)
    inp, tgt = seq[:, :-1], seq[:, 1:]
    with torch.autocast("cuda", dtype=torch.float16):
        logits = model(inp)
        loss = F.cross_entropy(logits.reshape(-1, VOCAB), tgt.reshape(-1))
    opt.zero_grad(); scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
    if step % 80 == 0 or step == 399:
        acc = (logits.argmax(-1) == tgt).float().mean().item()
        print(f"step {step:4d} | loss {loss.item():.4f} | next-token acc {acc*100:5.1f}%")

print("-> a full causal transformer running on memory-efficient attention,")
print("   trained end-to-end with AMP. Swap in real data/tokenizer to scale up.")
print("\nDone. Sections 1-3 are core; 4-6 are the advanced bits worth keeping.")
We build a compact GPT-style Transformer using causal xFormers attention, residual connections, normalization, and SwiGLU feed-forward layers. We train the model with automatic mixed precision on a synthetic next-token prediction task that counts upward modulo the vocabulary size. We monitor its loss and accuracy to confirm that the complete memory-efficient Transformer learns successfully end-to-end.
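The synthetic task is simple enough to write out by hand. The pure-Python sketch below mirrors the counting rule behind the tutorial's batches and the one-token shift between inputs and targets; the helper names are hypothetical and the start value is chosen only to show the wrap-around.

```python
def make_sequence(start, length, vocab):
    """Count upward from `start`, wrapping modulo `vocab`."""
    return [(start + i) % vocab for i in range(length)]

def split_inputs_targets(seq):
    """Next-token setup: the target at each position is the following token."""
    return seq[:-1], seq[1:]

seq = make_sequence(start=62, length=5, vocab=64)
inp, tgt = split_inputs_targets(seq)
print("sequence:", seq)
print("inputs  :", inp)
print("targets :", tgt)
```

Because every target equals the input plus one modulo the vocabulary size, a model that learns this rule can in principle reach near-perfect next-token accuracy, which makes the task a convenient end-to-end sanity check.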
Conclusion
This tutorial showed how xFormers improves Transformer efficiency without changing the attention calculation itself. Memory-efficient kernels reduce the cost of long sequences, while causal masks, packed sequences, grouped-query attention, and additive biases support realistic training and inference workflows. We finished by integrating these capabilities into a compact GPT model and training it end-to-end, which gives us a starting point for applying xFormers to larger language models and more demanding datasets.
Sana Hassan
Sana Hassan, a consulting intern at Marktechpost and dual-degree student at IIT Madras, is passionate about applying technology and AI to address real-world challenges. With a keen interest in solving practical problems, he brings a fresh perspective to the intersection of AI and real-life solutions.

