How to Speed Up Transformer Training Using NVIDIA Apex (FusedAdam, FusedLayerNorm) and Native torch.amp

Source: MarkTechPost

In this tutorial, we work through a hands-on tour of NVIDIA Apex, focusing on the components that still matter in modern GPU training workflows. Instead of treating Apex as a general mixed-precision library, we separate its legacy parts from the still-useful ones and test each of them directly. We begin by checking the CUDA runtime, building Apex with its CUDA and C++ extensions, and detecting which fused kernels are actually available in the environment. This matters because a Python-only Apex installation can appear successful while silently missing the high-performance kernels that make Apex useful. After the setup, we benchmark FusedAdam against PyTorch AdamW, compare FusedLayerNorm and FusedRMSNorm with standard normalization layers, and run both the legacy apex.amp API and its modern torch.amp replacement. We then bring everything together in a small Transformer training experiment that compares a vanilla FP32 PyTorch path with a fused Apex-plus-AMP path and measures the effect on throughput.

```python
import os, sys, time, subprocess, importlib
import torch

assert torch.cuda.is_available(), (
    "No CUDA GPU found. In Colab: Runtime > Change runtime type > Hardware accelerator = GPU"
)
DEV = torch.device("cuda")
print(f"[env] torch {torch.__version__} | CUDA {torch.version.cuda} | GPU {torch.cuda.get_device_name(0)}")


def _module_present(name: str) -> bool:
    """Return True if `name` imports cleanly (used to detect Apex's compiled extensions)."""
    try:
        importlib.import_module(name)
        return True
    except Exception:
        return False


def _build_apex():
    print("[apex] building from source with CUDA + C++ extensions "
          "(~10-20 min on first run; grab a coffee)...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "ninja", "packaging"], check=True)
    if not os.path.isdir("apex"):
        subprocess.run(["git", "clone", "--depth", "1",
                        "https://github.com/NVIDIA/apex"], check=True)
    env = os.environ.copy()
    env["APEX_CPP_EXT"]      = "1"            # build the C++ extensions
    env["APEX_CUDA_EXT"]     = "1"            # build the CUDA extensions (amp_C, fused_layer_norm_cuda, ...)
    env["MAX_JOBS"]          = "4"            # parallel compile jobs (keeps RAM use manageable on Colab)
    env["NVCC_APPEND_FLAGS"] = "--threads 4"  # let nvcc run its compilation steps in parallel
    cmd = [sys.executable, "-m", "pip", "install", "-v",
           "--no-build-isolation", "--no-cache-dir", "./apex"]
    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        print("[apex] CUDA build failed -> falling back to PYTHON-ONLY install "
              "(fused kernels will be unavailable, tutorial still runs).")
        # Without the APEX_*_EXT variables, pip builds only the pure-Python package.
        subprocess.run([sys.executable, "-m", "pip", "install", "-v",
                        "--no-build-isolation", "--no-cache-dir", "./apex"], check=False)
    # Make freshly installed modules visible to this already-running interpreter.
    importlib.invalidate_caches()


if not _module_present("amp_C"):
    _build_apex()

HAS_AMP_C = _module_present("amp_C")                  # multi-tensor kernels used by FusedAdam
HAS_FLN   = _module_present("fused_layer_norm_cuda")  # kernels behind FusedLayerNorm / FusedRMSNorm
HAS_RMS   = False
try:
    import apex
    from apex.optimizers import FusedAdam
    from apex.normalization import FusedLayerNorm
    try:
        from apex.normalization import FusedRMSNorm
        HAS_RMS = True
    except Exception:
        HAS_RMS = False
    from apex import amp  # legacy, deprecated API (used only in Section C)
    APEX_OK = True
except Exception as e:
    print(f"[apex] import failed: {e}")
    APEX_OK = False

print("\n[capabilities]")
print(f"  apex importable    : {APEX_OK}")
print(f"  FusedAdam kernels  : {HAS_AMP_C}")
print(f"  FusedLayerNorm krnl: {HAS_FLN}")
print(f"  FusedRMSNorm       : {APEX_OK and HAS_RMS}")
print("=" * 78)


def bench(fn, iters=50, warmup=10):
    """Average wall-clock time of fn() in milliseconds, measured after warmup."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # CUDA launches are asynchronous: drain the queue before starting the clock
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return (time.perf_counter() - t0) / iters * 1e3
```

We start by preparing the CUDA environment: we check that a GPU is available and print the active PyTorch, CUDA, and GPU details. If Apex's compiled amp_C extension is not already importable, we build NVIDIA Apex from source with its CUDA and C++ extensions so that the fused kernels can be used directly, rather than relying on a limited Python-only installation. We then detect whether the FusedAdam, FusedLayerNorm, and FusedRMSNorm kernels and the legacy AMP API are available, and define a reusable benchmarking helper for the subsequent tests.
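A GPU-free sketch of the two helpers described above follows. It assumes you only want to know whether a compiled module can be found, so it uses importlib.util.find_spec, which locates a module without executing it; unlike the tutorial's _module_present, it will not catch a module that exists but fails to load (for example, an extension compiled against a different PyTorch build). The timing helper mirrors bench but reports the median of per-iteration timings, which is less sensitive to occasional slow iterations; sync is a hypothetical hook standing in for torch.cuda.synchronize. The names find_extensions and bench_median are illustrative and not part of Apex.

```python
import importlib.util
import statistics
import time
from typing import Callable, Dict, Iterable


def find_extensions(names: Iterable[str]) -> Dict[str, bool]:
    """Report which top-level modules can be located, without importing them.

    Sketch only: find_spec shows that the module file exists, not that it loads.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


def bench_median(fn: Callable[[], object], iters: int = 50, warmup: int = 10,
                 sync: Callable[[], None] = lambda: None) -> float:
    """Median wall-clock time of fn() in milliseconds.

    `sync` is a hypothetical hook; on a GPU you would pass torch.cuda.synchronize
    so that each measurement waits for the queued asynchronous kernels.
    """
    for _ in range(warmup):
        fn()
    sync()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        sync()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)


if __name__ == "__main__":
    # With a full CUDA build of Apex, amp_C and fused_layer_norm_cuda should report True.
    print(find_extensions(["json", "amp_C", "fused_layer_norm_cuda"]))
    print(f"{bench_median(lambda: sum(range(10_000))):.3f} ms")
```

Syncing after every iteration adds a little overhead, but it is what makes individual timings (and therefore a median) meaningful for asynchronous GPU work.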

```python
print("\n### SECTION A: FusedAdam vs AdamW ###")


def make_many_param_model(n_layers=60, dim=512):
    # Many small Linear layers -> many parameter tensors, so per-tensor optimizer overhead is visible.
    return torch.nn.Sequential(*[torch.nn.Linear(dim, dim) for _ in range(n_layers)]).to(DEV)


def opt_step_factory(optimizer, model, dim=512):
    x = torch.randn(64, dim, device=DEV)
    def step():
        optimizer.zero_grad(set_to_none=True)
        out = model(x).pow(2).mean()
        out.backward()
        optimizer.step()
    return step


m1 = make_many_param_model()
torch_adam = torch.optim.AdamW(m1.parameters(), lr=1e-3)
ms_torch = bench(opt_step_factory(torch_adam, m1))
print(f"  torch.optim.AdamW : {ms_torch:6.2f} ms / step")

if HAS_AMP_C and APEX_OK:
    m2 = make_many_param_model()
    m2.load_state_dict(m1.state_dict())  # start from the same weights so only the optimizer differs
    fused_adam = FusedAdam(m2.parameters(), lr=1e-3)
    ms_fused = bench(opt_step_factory(fused_adam, m2))
    print(f"  apex.FusedAdam    : {ms_fused:6.2f} ms / step   "
          f"(~{ms_torch/ms_fused:0.2f}x on optimizer-bound step)")
else:
    print("  apex.FusedAdam    : SKIPPED (cuda ext not built)")
```

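We benchmark PyTorch AdamW against Apex FusedAdam using a model with many linear layers, which makes per-tensor optimizer overhead visible. Both optimizers run the same step pattern (zero_grad, forward, backward, step) starting from the same weights, so the comparison focuses on update speed rather than model differences. We then report the step time and speedup to judge whether the fused multi-tensor optimizer provides a practical benefit in the current GPU runtime. Keep in mind that recent PyTorch versions already default to a multi-tensor (foreach) AdamW implementation on CUDA and also offer fused=True, so the gap can be smaller than older comparisons suggest.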
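To make "fused multi-tensor optimizer" concrete, here is a pure-Python reference for the elementwise AdamW update (bias-corrected moments plus decoupled weight decay) applied across a list of parameter tensors in one call. It assumes FusedAdam is used in its AdamW mode (adam_w_mode=True, which we believe is the default; check your Apex version) and that weight decay behaves as in torch.optim.AdamW. A fused kernel performs this same arithmetic, but in a few GPU launches over chunks of many tensors instead of several launches per tensor. The function name and the list-of-lists layout are hypothetical.

```python
import math
from typing import List

Tensor = List[float]  # stand-in for one flattened parameter tensor


def multi_tensor_adamw_step(params: List[Tensor], grads: List[Tensor],
                            exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor],
                            step: int, lr: float = 1e-3, beta1: float = 0.9,
                            beta2: float = 0.999, eps: float = 1e-8,
                            weight_decay: float = 0.0) -> None:
    """One AdamW step over every tensor in the lists, updated in place.

    `step` is the 1-based step count used for bias correction.
    """
    bc1 = 1.0 - beta1 ** step  # bias correction for the first moment
    bc2 = 1.0 - beta2 ** step  # bias correction for the second moment
    for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):  # the "multi-tensor" loop
        for i in range(len(p)):  # elementwise work that a GPU kernel parallelizes
            m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
            v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i]
            update = (m[i] / bc1) / (math.sqrt(v[i] / bc2) + eps)
            # Decoupled weight decay: shrink the weight directly instead of adding it to the gradient.
            p[i] -= lr * (update + weight_decay * p[i])


if __name__ == "__main__":
    params = [[1.0, -2.0], [0.5]]
    grads = [[1.0, -1.0], [0.25]]
    m = [[0.0] * len(p) for p in params]
    v = [[0.0] * len(p) for p in params]
    multi_tensor_adamw_step(params, grads, m, v, step=1, lr=0.1)
    print(params)  # the first Adam step moves each weight by about lr * sign(grad)
```

The first step illustrates a known property of Adam: after bias correction, the update magnitude is close to lr regardless of gradient scale, which is why all three weights above move by roughly 0.1.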

```python
print("\n### SECTION B: FusedLayerNorm / FusedRMSNorm ###")
B, T, H = 32, 512, 1024  # batch, sequence length, hidden size
x = torch.randn(B, T, H, device=DEV, requires_grad=True)

torch_ln = torch.nn.LayerNorm(H).to(DEV)
def ln_torch():
    y = torch_ln(x); y.sum().backward()
ms_ln_torch = bench(ln_torch)
print(f"  nn.LayerNorm       : {ms_ln_torch:6.2f} ms / fwd+bwd")

if HAS_FLN and APEX_OK:
    fused_ln = FusedLayerNorm(H).to(DEV)
    with torch.no_grad():
        # Copy the affine parameters so both layers compute the same function.
        fused_ln.weight.copy_(torch_ln.weight); fused_ln.bias.copy_(torch_ln.bias)
        diff = (fused_ln(x.detach()) - torch_ln(x.detach())).abs().max().item()
    print(f"    max|fused - torch| = {diff:.2e}  (should be ~1e-3 or smaller)")

    def ln_fused():
        y = fused_ln(x); y.sum().backward()
    ms_ln_fused = bench(ln_fused)
    print(f"  apex.FusedLayerNorm: {ms_ln_fused:6.2f} ms / fwd+bwd   "
          f"(~{ms_ln_torch/ms_ln_fused:0.2f}x)")

    if HAS_RMS:
        fused_rms = FusedRMSNorm(H).to(DEV)
        def rms_fused():
            y = fused_rms(x); y.sum().backward()
        print(f"  apex.FusedRMSNorm  : {bench(rms_fused):6.2f} ms / fwd+bwd "
              f"(RMSNorm: no mean-subtraction, used by LLaMA-style models)")
else:
    print("  apex.FusedLayerNorm: SKIPPED (cuda ext not built)")
```

We compare the standard PyTorch LayerNorm with Apex FusedLayerNorm on a large tensor resembling transformer hidden states. We first check numerical correctness by copying the same affine parameters and measuring the maximum difference between fused and standard outputs. We then benchmark forward and backward passes and, when available, test FusedRMSNorm to demonstrate how Apex supports normalization layers used in LLaMA-style models.
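For reference, both layers normalize each hidden vector x of size H (here H = 1024). LayerNorm subtracts the mean and divides by the standard deviation; RMSNorm skips the mean subtraction and divides by the root mean square, and has no bias. The pure-Python sketch below spells this out for a single vector, assuming biased variance and eps = 1e-5 (the nn.LayerNorm default); the function names are illustrative. Roughly speaking, a fused kernel computes the same result in a single pass instead of a chain of separate reduction and elementwise operations.

```python
import math
from typing import List, Optional


def layer_norm(x: List[float], weight: Optional[List[float]] = None,
               bias: Optional[List[float]] = None, eps: float = 1e-5) -> List[float]:
    """y = (x - mean) / sqrt(var + eps) * weight + bias, with biased variance."""
    n = len(x)
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    w = weight or [1.0] * n
    b = bias or [0.0] * n
    return [(xi - mean) * inv_std * wi + bi for xi, wi, bi in zip(x, w, b)]


def rms_norm(x: List[float], weight: Optional[List[float]] = None,
             eps: float = 1e-5) -> List[float]:
    """y = x / sqrt(mean(x^2) + eps) * weight; no mean subtraction and no bias."""
    n = len(x)
    inv_rms = 1.0 / math.sqrt(sum(xi * xi for xi in x) / n + eps)
    w = weight or [1.0] * n
    return [xi * inv_rms * wi for xi, wi in zip(x, w)]


if __name__ == "__main__":
    print(layer_norm([1.0, 2.0, 3.0]))  # roughly [-1.2247, 0.0, 1.2247]
    print(rms_norm([3.0, 4.0]))         # roughly [0.8485, 1.1314]
```

Dropping the mean reduction is what makes RMSNorm slightly cheaper than LayerNorm, which is one reason LLaMA-style models adopt it.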

```python
print("\n### SECTION C: mixed precision (apex.amp opt-levels, DEPRECATED) ###")


def tiny_net():
    return torch.nn.Sequential(
        torch.nn.Linear(256, 256), torch.nn.ReLU(),
        torch.nn.Linear(256, 256), torch.nn.ReLU(),
        torch.nn.Linear(256, 10),
    ).to(DEV)


if APEX_OK:
    for level in ["O0", "O1", "O2"]:
        net = tiny_net()
        optimizer = (FusedAdam(net.parameters(), lr=1e-3) if HAS_AMP_C
                     else torch.optim.AdamW(net.parameters(), lr=1e-3))
        # Legacy API: amp.initialize wraps model + optimizer according to the opt level.
        # Apex documents amp.initialize as meant to be called once per process;
        # re-initializing for each opt level here is for illustration only.
        net, optimizer = amp.initialize(net, optimizer, opt_level=level, verbosity=0)
        xb = torch.randn(128, 256, device=DEV)
        yb = torch.randint(0, 10, (128,), device=DEV)
        lossfn = torch.nn.CrossEntropyLoss()
        for _ in range(20):
            optimizer.zero_grad()
            loss = lossfn(net(xb), yb)
            # Backward on the scaled loss; on exit, apex unscales the grads
            # (and skips the next step if they overflowed).
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
        print(f"  opt_level={level}: final loss = {loss.item():.4f}")
else:
    print("  apex.amp: SKIPPED (apex not importable)")

print("\n  >> Modern recommended equivalent (torch.amp, no Apex needed):")
net = tiny_net()
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")
xb = torch.randn(128, 256, device=DEV); yb = torch.randint(0, 10, (128,), device=DEV)
lossfn = torch.nn.CrossEntropyLoss()
for _ in range(20):
    optimizer.zero_grad()
    with torch.amp.autocast("cuda", dtype=torch.float16):  # forward pass in mixed precision
        loss = lossfn(net(xb), yb)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscale grads; skip the step if they contain inf/NaN
    scaler.update()                # adjust the scale for the next iteration
print(f"  torch.amp: final loss = {loss.item():.4f}")
```

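We demonstrate the legacy apex.amp mixed-precision workflow by running small training loops at the O0, O1, and O2 opt levels: O0 is pure FP32 training, O1 patches Torch functions so that individual operations run in FP16 where considered safe, and O2 casts the model to FP16 (keeping batchnorm in FP32) while maintaining FP32 master weights. We use amp.initialize and amp.scale_loss to show how the older API wraps the model and optimizer and handles loss scaling. We then run the same kind of mixed-precision training with torch.amp.autocast and torch.amp.GradScaler, which is the recommended approach for new PyTorch code.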
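Both amp.scale_loss and torch.amp.GradScaler rely on dynamic loss scaling: multiply the loss by a large factor so that small FP16 gradients do not underflow, unscale the gradients before the update, skip the step and shrink the factor when an inf or NaN appears, and grow the factor again after a run of clean steps. The pure-Python sketch below models that control loop. Its defaults mirror the documented defaults of torch.amp.GradScaler (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000); the class and method names are hypothetical and do not reproduce either library's internals.

```python
import math
from typing import List, Tuple


class DynamicLossScaler:
    """Minimal model of dynamic loss scaling (hypothetical helper, not a library API)."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN gradients

    def unscale(self, scaled_grads: List[float]) -> Tuple[List[float], bool]:
        """Divide gradients by the current scale and report whether any is inf/NaN."""
        grads = [g / self.scale for g in scaled_grads]
        found_inf = not all(math.isfinite(g) for g in grads)
        return grads, found_inf

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should be applied."""
        if found_inf:
            self.scale *= self.backoff_factor  # overflow: shrink the scale and skip this step
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= self.growth_factor   # long clean run: try a larger scale
            self._good_steps = 0
        return True


if __name__ == "__main__":
    scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
    grads, found_inf = scaler.unscale([16.0, float("inf")])
    print(found_inf, scaler.update(found_inf), scaler.scale)  # True False 4.0
    grads, found_inf = scaler.unscale([8.0, -4.0])
    print(grads, scaler.update(found_inf), scaler.scale)      # [2.0, -1.0] True 4.0
    grads, found_inf = scaler.unscale([8.0, -4.0])
    print(grads, scaler.update(found_inf), scaler.scale)      # [2.0, -1.0] True 8.0
```

The asymmetry (halve immediately on overflow, double only after a long clean run) keeps the scale near the largest value that does not overflow, which is also why BF16 autocast, with its FP32-sized exponent range, usually needs no scaler.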

```python
print("\n### SECTION D: end-to-end Transformer (vanilla fp32 vs Apex fused + AMP) ###")
VOCAB, D, NHEAD, LAYERS, SEQ, BATCH, STEPS = 2000, 256, 4, 4, 128, 32, 60


class Block(torch.nn.Module):
    """Pre-norm Transformer block: x + Attn(Norm(x)), then x + FF(Norm(x))."""
    def __init__(self, d, nhead, norm_cls):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d, nhead, batch_first=True)
        self.ff = torch.nn.Sequential(torch.nn.Linear(d, 4 * d), torch.nn.GELU(),
                                      torch.nn.Linear(4 * d, d))
        self.n1, self.n2 = norm_cls(d), norm_cls(d)

    def forward(self, x):
        # Causal mask (True = masked out): position t may only attend to positions <= t,
        # otherwise it could simply read the next token it is asked to predict.
        T = x.size(1)
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
        h = self.n1(x)
        x = x + self.attn(h, h, h, attn_mask=causal, need_weights=False)[0]
        return x + self.ff(self.n2(x))


class TinyTransformer(torch.nn.Module):
    def __init__(self, norm_cls):
        super().__init__()
        self.emb = torch.nn.Embedding(VOCAB, D)
        self.blocks = torch.nn.ModuleList([Block(D, NHEAD, norm_cls) for _ in range(LAYERS)])
        self.norm = norm_cls(D)
        self.head = torch.nn.Linear(D, VOCAB)

    def forward(self, idx):
        x = self.emb(idx)
        for b in self.blocks:
            x = b(x)
        return self.head(self.norm(x))


# Fixed random token data: predict token t+1 from tokens up to t.
g = torch.Generator(device="cpu").manual_seed(0)
data = torch.randint(0, VOCAB, (BATCH, SEQ + 1), generator=g).to(DEV)
inp, tgt = data[:, :-1], data[:, 1:]
lossfn = torch.nn.CrossEntropyLoss()


def run_training(use_apex):
    torch.manual_seed(0)
    norm_cls = (FusedLayerNorm if (use_apex and HAS_FLN and APEX_OK) else torch.nn.LayerNorm)
    model = TinyTransformer(norm_cls).to(DEV)
    if use_apex and HAS_AMP_C and APEX_OK:
        optimizer = FusedAdam(model.parameters(), lr=3e-4)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # With enabled=False, GradScaler and autocast become no-ops (plain FP32 path).
    scaler = torch.amp.GradScaler("cuda", enabled=use_apex)

    def one_step():
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_apex):
            logits = model(inp)
            loss = lossfn(logits.reshape(-1, VOCAB), tgt.reshape(-1))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss

    for _ in range(5):  # warmup steps, excluded from timing
        one_step()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(STEPS):
        loss = one_step()
    torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    return loss.item(), (STEPS * BATCH * SEQ) / dt, dt


loss_v, tps_v, dt_v = run_training(use_apex=False)
print(f"  vanilla (fp32, nn.LayerNorm, AdamW)        : "
      f"{dt_v:5.2f}s  | {tps_v:9.0f} tok/s | final loss {loss_v:.3f}")

if APEX_OK and (HAS_AMP_C or HAS_FLN):
    loss_a, tps_a, dt_a = run_training(use_apex=True)
    print(f"  apex   (fp16, FusedLayerNorm, FusedAdam)   : "
          f"{dt_a:5.2f}s  | {tps_a:9.0f} tok/s | final loss {loss_a:.3f}")
    print(f"  ----> speedup: {tps_a / tps_v:0.2f}x throughput")
else:
    print("  apex path SKIPPED (no fused kernels built)")

print("\n" + "=" * 78)
print("DONE. Key takeaways:")
print("  - FusedAdam/FusedLayerNorm/FusedRMSNorm are the still-relevant Apex pieces;")
print("    speedups grow with model size & parameter count (tiny demo understates it).")
print("  - apex.amp is deprecated -> prefer torch.amp.autocast + torch.amp.GradScaler.")
print("  - FusedAdam composes cleanly with native torch.amp (Section D).")
print("  - On real workloads, also try a larger model and bf16 autocast (no scaler needed).")
print("=" * 78)
```

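We build a small pre-norm Transformer with embeddings, attention blocks, feed-forward layers, and normalization to test Apex in an end-to-end training workload. We train it once with vanilla FP32 PyTorch using AdamW and standard LayerNorm, then again with FusedLayerNorm, FusedAdam, and native torch.amp FP16 autocast, using whichever fused kernels are available. Finally, we compare runtime, token throughput, final loss, and speedup. Note that the second run changes numeric precision and kernels at the same time, so the reported speedup combines the effect of FP16 autocast with that of the fused Apex components; isolating Apex's own contribution would require an additional FP16 run with plain nn.LayerNorm and AdamW.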
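One way to separate those effects is to time all four combinations of (FP16 autocast, fused kernels) and divide each throughput by the FP32, non-fused baseline. The helper below is a hypothetical sketch of only the bookkeeping: it converts steps, batch size, sequence length, and seconds into tokens per second with the same formula run_training uses, and reports each configuration's speedup. Producing the four measurements would require extending run_training with separate flags for autocast and fused kernels, which the tutorial does not do.

```python
from typing import Dict, Tuple

Config = Tuple[bool, bool]  # (use_fp16_autocast, use_fused_kernels)


def tokens_per_second(steps: int, batch: int, seq: int, seconds: float) -> float:
    """Same formula as run_training: tokens processed divided by wall-clock time."""
    return steps * batch * seq / seconds


def ablation_speedups(tps: Dict[Config, float]) -> Dict[Config, float]:
    """Speedup of each configuration relative to the FP32, non-fused baseline."""
    baseline = tps[(False, False)]
    return {cfg: value / baseline for cfg, value in tps.items()}


if __name__ == "__main__":
    # 60 steps x 32 sequences x 128 tokens = 245,760 tokens per timed run.
    # Illustrative input only: 2.0 s is not a measured value.
    print(tokens_per_second(60, 32, 128, seconds=2.0))  # 122880.0
```

Comparing the (True, False) and (True, True) entries, for example, would isolate what the fused kernels add on top of FP16 autocast.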

In conclusion, we now have a clear, practical picture of where NVIDIA Apex still fits in a 2026 deep learning workflow. Apex is no longer primarily about mixed precision, since native PyTorch AMP now handles that more cleanly. Its fused optimizer and fused normalization kernels, however, can still be useful when the environment supports a proper CUDA extension build. We also saw how to write Apex-aware code that degrades gracefully when fused kernels are unavailable, which keeps the tutorial reliable across Colab runtimes. The final Transformer benchmark shows how FusedAdam, FusedLayerNorm, and torch.amp work together in an end-to-end training loop. Beyond installation and API usage, we evaluated Apex the way it should be evaluated: by checking kernel availability, comparing against PyTorch baselines, and measuring performance in an actual training workload.


Sana Hassan, a consulting intern at Marktechpost and dual-degree student at IIT Madras, is passionate about applying technology and AI to address real-world challenges. With a keen interest in solving practical problems, he brings a fresh perspective to the intersection of AI and real-life solutions.