Source: MachineLearningMastery.com
Some language models are too large to train on a single GPU. In addition to creating the model as a pipeline of stages, as in Pipeline Parallelism, you can split the model across multiple GPUs using Fully Sharded Data Parallelism (FSDP). In this article, you will learn how to use FSDP to split models for training. In particular, you will learn about:
- The idea of sharding and how FSDP works
- How to use FSDP in PyTorch
Let’s get started!

Train Your Large Model on Multiple GPUs with Fully Sharded Data Parallelism.
Photo by Ferenc Horvath. Some rights reserved.
Overview
This article is divided into five parts; they are:
- Introduction to Fully Sharded Data Parallel
- Preparing Model for FSDP Training
- Training Loop with FSDP
- Fine-Tuning FSDP Behavior
- Checkpointing FSDP Models
Introduction to Fully Sharded Data Parallel
Sharding is a term originally used in database management systems, where it refers to dividing a database into smaller units, called shards, to improve performance. In machine learning, sharding refers to dividing model parameters across multiple devices. Unlike pipeline parallelism, the shards contain only a portion of any complete operation. For example, the nn.Linear module is essentially a matrix multiplication. A sharded version of it contains only a portion of the matrix. When a sharded module needs to process data, you must gather the shards to create a complete matrix temporarily and perform the operation. Afterwards, this matrix is discarded to reclaim memory.
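To make this concrete, below is a minimal single-process sketch (an illustration only, not the actual FSDP implementation) of what sharding the weight of an nn.Linear means: each simulated rank keeps only a slice of the weight matrix, the full matrix is rebuilt temporarily just before the matrix multiplication, and the gathered copy is then discarded.

import torch

torch.manual_seed(0)
world_size = 4
full_weight = torch.randn(64, 32)                    # weight of a hypothetical nn.Linear(32, 64)
shards = list(full_weight.chunk(world_size, dim=0))  # each "rank" stores only one slice

x = torch.randn(8, 32)                               # a micro-batch of inputs
gathered = torch.cat(shards, dim=0)                  # "all-gather": rebuild the full weight temporarily
y = x @ gathered.t()                                 # run the operation with the full weight
del gathered                                         # discard the unsharded weight to reclaim memory

assert torch.allclose(y, x @ full_weight.t())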
When you use FSDP, all model parameters are sharded, and each process holds exactly one shard. Unlike data parallelism, where each GPU keeps a full copy of the model and only the data is split and the gradients synchronized across GPUs, FSDP never keeps a full copy of the model on any GPU; instead, model shards must be exchanged on every step in addition to the gradient synchronization. Therefore, FSDP incurs higher communication overhead in exchange for lower memory usage.

FSDP requires processes to exchange data to unshard the model.
The workflow of FSDP is as follows:
There will be multiple processes running together, possibly on multiple machines across a network. Each process (equivalently, each GPU) holds only one shard of the model. When the model is sharded, each module’s weights are stored as a DTensor (distributed tensor, sharded across multiple GPUs) rather than a plain Tensor. Therefore, no process can run any module independently. Before each operation, FSDP issues an all-gather request to enable all processes to exchange a module’s shards with one another. This creates a temporary unsharded module, and each process runs the forward pass on this module with its micro-batch of data. Afterward, the unsharded module is discarded as the processes move on to the next module in the model.
Similar operations happen in the backward pass. Each module must be unsharded again, so FSDP issues another all-gather request for it, and the backward pass then computes gradients from the forward pass results. Note that each process operates on a different micro-batch of data, so the gradients computed by each process are different. Therefore, FSDP issues a reduce-scatter request, causing all processes to exchange gradients so that the final, batch-wide gradient is averaged. Each process then uses its portion of this averaged gradient to update the parameters in its own shard.
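The gradient synchronization can be pictured the same way. Below is another single-process sketch (plain tensors only, no actual communication) of what a reduce-scatter achieves: the per-rank gradients, each computed from a different micro-batch, are averaged, and each rank keeps only the slice of the averaged gradient that matches its own parameter shard.

import torch

world_size = 4
# hypothetical per-rank gradients for the same 64x32 weight, one per micro-batch
per_rank_grads = [torch.randn(64, 32) for _ in range(world_size)]

averaged = torch.stack(per_rank_grads).mean(dim=0)   # "reduce": average over all ranks
scattered = list(averaged.chunk(world_size, dim=0))  # "scatter": each rank keeps one slice

# rank r would use scattered[r] to update its local shard of the parameters
assert scattered[0].shape == (16, 32)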
As shown in the figure above, FSDP requires more communication and has a more complex workflow than plain data parallelism. Since the model is distributed across multiple GPUs, you do not need as much VRAM to host a very large model. This is the motivation for using FSDP for training.
To improve FSDP’s efficiency, PyTorch uses prefetching to overlap communication and computation. While your GPU is computing the first module, the processes exchange shards from the second module, so the second module is available once the first is complete. This keeps both the GPU and the network busy, reducing the latency per training step. Some tuning in FSDP can help you maximize such overlap and improve training throughput, often at the cost of higher memory usage.
Preparing Model for FSDP Training
When you need FSDP, usually it means your model is too large to fit on a single GPU. One way to handle such a large model is to create it on the meta device (a fake device that allocates no memory for parameters), then shard it and materialize the shards across multiple GPUs.
In PyTorch, you need to use the torchrun command to launch an FSDP training script with multiple processes. Under torchrun, each process will see the world size (total number of processes), its rank (the index of the current process), and its local rank (the index of the GPU device on the current machine). In the script, you need to initialize this as a process group:
import os

import torch
import torch.distributed as dist

# Initialize the distributed environment
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")
Next, you should create the model and then shard it. The code below is based on the model architecture described in the previous post:
...
from torch.distributed.fsdp import FSDPModule, fully_shard

with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)

for layer in model.base_model.layers:
    fully_shard(layer)
fully_shard(model.base_model)
fully_shard(model)

model.to_empty(device=device)
model.reset_parameters()

assert isinstance(model, FSDPModule), f"Expected FSDPModule, got {type(model)}"
In PyTorch, you use the fully_shard() function to create a sharded model. This function replaces parameters of type Tensor with DTensor in-place. It also modifies the model to perform the all-gather operation before the actual computation.
You should notice that in the above, fully_shard() is not only called on model, but also on model.base_model as well as each transformer block in the base model. This needs careful consideration.
Usually, you do not want to shard only the top-level model, but also some submodules beneath it. When you do so, you must apply fully_shard() from bottom up, with the top-level model being sharded last. Each sharded module will be one unit of all-gather. In the design shown above, when you pass a tensor to model, the top-level model components will be unsharded, except for those that were sharded separately. Since it is a decoder-only transformer model, the input should be processed by the base model first, then the prediction head in the top model. FSDP will unshard the base model, except for each repeating transformer block. This includes the input embedding layer, which is the first operation applied to the input tensor.
After the embedding layer, the input tensor should be processed by a sequence of transformer blocks. Each block is sharded separately, so all-gather is triggered for each block. The block transforms the input and passes it on to the next transformer block. After the last transformer block, the RMS norm layer in the base model, which is already unsharded, processes the output before returning to the top model for the prediction.
This is why you do not want to shard only the top-level model: if you do, the all-gather operation will recreate the full model on each GPU, violating the assumption that each GPU has insufficient memory to hold the full model. In that case, you should use plain data parallelism rather than FSDP.
In this design, each GPU requires one complete transformer block plus the other modules in the top and base models, such as the embedding layer, the final RMS norm layer in the base model, and the prediction head in the top model. You can revise this design (for example, by further sharding model.base_model.embed_tokens and breaking down each transformer block into attention and feed-forward sublayers) to further reduce the memory requirement.
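For instance, below is a sketch of sharding the embedding table as its own unit, in the same bottom-up order as before (treat this as an illustrative variation rather than a prescribed design):

...
for layer in model.base_model.layers:
    fully_shard(layer)
fully_shard(model.base_model.embed_tokens)  # optional: shard the large embedding table separately
fully_shard(model.base_model)
fully_shard(model)
model.to_empty(device=device)
model.reset_parameters()

The trade-off is one extra all-gather per step, issued for the embedding layer, in exchange for a smaller per-GPU memory footprint.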
After you have the sharded model, you can transfer it from a meta device to your local GPU with model.to_empty(device=device). You also need to reset the weights of the newly created model (unless you want to initialize them from a checkpoint). You can borrow the function reset_all_weights() from the previous post to reset the weights. Here is another way that uses model.reset_parameters(). This requires you to implement the corresponding member function in each module:
class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q
        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def reset_parameters(self):
        self.q_proj.reset_parameters()
        self.k_proj.reset_parameters()
        self.v_proj.reset_parameters()
        self.o_proj.reset_parameters()

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        ...


class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Two parallel projections for SwiGLU
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SwiGLU activation function
        # Project back to hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def reset_parameters(self):
        self.gate_proj.reset_parameters()
        self.up_proj.reset_parameters()
        self.down_proj.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        ...


class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def reset_parameters(self):
        self.input_layernorm.reset_parameters()
        self.self_attn.reset_parameters()
        self.post_attention_layernorm.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        ...


class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def reset_parameters(self):
        self.embed_tokens.reset_parameters()
        for layer in self.layers:
            layer.reset_parameters()
        self.norm.reset_parameters()

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        ...


class LlamaForPretraining(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def reset_parameters(self):
        self.base_model.reset_parameters()
        self.lm_head.reset_parameters()

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        ...
You know the model is sharded if it is an instance of FSDPModule. Subsequently, you can create the optimizer and other training components as usual. The PyTorch optimizer supports updating DTensor objects the same way as plain Tensor objects.
Training Loop with FSDP
Using FSDP is straightforward. Virtually nothing needs to be changed in the training loop:
...
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.99),
    eps=1e-8,
    weight_decay=0.1,
)
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_training_steps - num_warmup_steps, eta_min=0,
)
scheduler = lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[num_warmup_steps],
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

# Start training loop
for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        # Explicit prefetching before sending any data to model
        model.unshard()
        # Get batched data, move from CPU to GPU
        input_ids, target_ids = batch
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        # create attention mask: causal mask + padding mask
        attn_mask = create_causal_mask(input_ids) + create_padding_mask(input_ids, PAD_TOKEN_ID)
        # Extract output from model
        logits = model(input_ids, attn_mask)
        # Compute loss: cross-entropy between logits and target, ignoring padding tokens
        loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        # Backward with loss and gradient clipping by L2 norm to 1.0
        # Optimizer and gradient clipping work on DTensor
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item())
        pbar.update(1)
    pbar.close()
The only change you can observe is the use of model.unshard() to trigger the all-gather before the forward pass, but this is optional. Even if you do not call it, model(input_ids, attn_mask) will still trigger the all-gather operation internally. This line simply starts the all-gather before the input tensor is prepared for the forward pass.
However, FSDP is partially a data parallelism technique. As with distributed data parallelism, you should use a sampler with your data loader so that each rank in the process group processes a different micro-batch. This works because each process sees a complete module of the model at a time via all-gather, so each process can use that module to process a different micro-batch of data. In essence, FSDP splits both the model and the training data across processes, going one step further than data parallelism, which splits only the data. Below is how you should set up your data loader:
...
dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(dataset, shuffle=False)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    pin_memory=True,  # optional
    shuffle=False,
    num_workers=2,
    prefetch_factor=2,
)
This is the same as how you set up the data loader for distributed data parallel in the previous article.
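One detail to keep in mind: if you enable shuffling in the sampler (shuffle=True instead of the shuffle=False used above), you should also call sampler.set_epoch() at the start of every epoch so that each epoch uses a different shuffle order consistently across all ranks. A minimal sketch of that variation:

...
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)

for epoch in range(epochs):
    sampler.set_epoch(epoch)  # reseed the sampler so the shuffle order differs per epoch
    for batch in dataloader:
        ...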
Fine-Tuning FSDP Behavior
The above is all you need to run FSDP training. However, you can introduce variations to fine-tune FSDP’s behavior.
Using torch.compile()
If your model can be compiled, you can also compile an FSDP model. However, you need to compile it after sharding the model, so the compiled model can reference the distributed tensors rather than plain tensors.
...
# create sharded model first
fully_shard(model)
model.to_empty(device=device)
model.reset_parameters()

...
# create data loader, optimizer, scheduler, loss function for the training loop

# then create compiled version
model = torch.compile(model)
loss_fn = torch.compile(loss_fn)

# Start training loop
for epoch in range(epochs):
    for batch in dataloader:
        ...
        logits = model(input_ids, attn_mask)
        loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
Arguments to fully_shard()
Recall that you can use torch.autocast() to run mixed precision training. You can enable mixed-precision training in FSDP as well, but you must apply it when you shard the model. The change needed is particularly simple:
...
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
for layer in model.base_model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model.base_model, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

model.to_empty(device=device)
When you shard the model, you can specify the argument mp_policy to describe exactly how mixed precision training should be performed. In the example above, you keep the model parameters in bfloat16 but use float32 for the gradients during reduce-scatter. You can also specify output_dtype and cast_forward_inputs to define the data types of the forward pass inputs and outputs. Note that since fully_shard() is applied to each module separately, you are free to use different mixed precision policies for different modules.
Of course, PyTorch still allows you to use torch.set_default_dtype(torch.bfloat16) to change the default data type for the entire model. This changes the default data type for all DTensor objects created.
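A sketch of that alternative, assuming you want every parameter created in bfloat16 from the start instead of (or in addition to) using a mixed precision policy:

...
torch.set_default_dtype(torch.bfloat16)  # newly created floating-point tensors default to bfloat16

with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)

for layer in model.base_model.layers:
    fully_shard(layer)
fully_shard(model.base_model)
fully_shard(model)
model.to_empty(device=device)
model.reset_parameters()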
In FSDP, you need an all-gather step before the actual forward or backward computation. Before all-gather, you do not have a complete parameter for the operation. Since inter-process communication is slow and a lot of data needs to be moved to the GPU anyway, you can apply CPU offloading to keep your sharded model in CPU memory when it is not in use. This means:
...
from torch.distributed.fsdp import MixedPrecisionPolicy, CPUOffloadPolicy, fully_shard

with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
offload_policy = CPUOffloadPolicy(pin_memory=True)
for layer in model.base_model.layers:
    fully_shard(layer, mp_policy=mp_policy, offload_policy=offload_policy)
fully_shard(model.base_model, mp_policy=mp_policy, offload_policy=offload_policy)
fully_shard(model, mp_policy=mp_policy, offload_policy=offload_policy)

model.to_empty(device="cpu")
Typically, using CPU offloading makes the training loop noticeably slower. If you use CPU offloading, you should consider changing the training loop such that the optimizer zeros out gradient tensors instead of setting the gradients to None:
...
for batch_id, batch in enumerate(pbar):
    ...
    logits = model(input_ids, attn_mask)
    loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
    optimizer.zero_grad(set_to_none=False)  # retain allocated grad tensors
    loss.backward()
    optimizer.step()
    ...
This is because CPU memory is usually more abundant than GPU memory, and you can afford to keep the allocated gradient tensors in memory to avoid the overhead of re-allocating them.
A third argument you can pass to fully_shard() is reshard_after_forward. By default (reshard_after_forward=None), FSDP keeps the unsharded parameters of the root module in memory after the forward pass, so the backward pass does not need to call all-gather on it again. Non-root modules always discard their unsharded tensors after the forward pass unless you set reshard_after_forward=False.
Usually, you do not want to change this setting, since this likely means you need to run all-gather immediately after discarding the unsharded tensors. But understanding how this parameter works lets you reconsider your model design: In the implementation of LlamaForPretraining above, the root module contains only the prediction head. But if you move the embedding layer from the base model LlamaModel to the root model, you will keep the embedding layer (which is usually large) in memory for a long time. This is the model engineering you can consider when applying FSDP.
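If you do want to experiment with this trade-off, the setting is passed per fully_shard() call. Below is a sketch (an option to benchmark, not a recommendation) that keeps each transformer block's gathered parameters in memory between the forward and backward passes, trading memory for fewer all-gathers:

...
for layer in model.base_model.layers:
    fully_shard(layer, reshard_after_forward=False)  # keep gathered weights until backward
fully_shard(model.base_model)
fully_shard(model)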
Gradient Checkpointing
FSDP has a lower memory requirement than plain data parallelism. If you want to further reduce memory usage, you can use gradient checkpointing with FSDP. Unlike the plain model, you do not use torch.utils.checkpoint.checkpoint() to wrap the part that requires gradient checkpointing. Instead, you set a policy and apply it to the sharded model:
...
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer, nn.Embedding},
)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    auto_wrap_policy=wrap_policy,
)
The wrap_policy is a helper function that checks whether the module belongs to one of the listed classes. If so, gradient checkpointing will be applied to it, so its internal activations are discarded after the forward pass and recomputed during the backward pass. The function apply_activation_checkpointing() recursively scans the module and applies gradient checkpointing to its submodules.
As a reminder, gradient checkpointing is a technique that trades time for memory during training. You save memory by discarding intermediate activations, but the backward pass is slower due to recomputation.
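For contrast, on a plain (non-FSDP) model you would wrap the expensive calls yourself with torch.utils.checkpoint.checkpoint(). The hypothetical variant below is shown only to highlight the difference from the policy-based wrapping above:

import torch.utils.checkpoint

class CheckpointedLlamaModel(LlamaModel):
    """Hypothetical LlamaModel variant that recomputes each decoder layer during backward."""

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            # discard this layer's activations after forward; recompute them in backward
            hidden_states = torch.utils.checkpoint.checkpoint(
                layer, hidden_states, self.rotary_emb, attn_mask, use_reentrant=False,
            )
        return self.norm(hidden_states)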
All-Gather Prefetching
FSDP implements a similar efficiency optimization to pipeline parallelism: it issues an all-gather request to the next module while the current module is processing data. This is called prefetching and it deliberately overlaps communication and computation to reduce the latency of each training step.
You can indeed control how the prefetching is performed. Below is an example:
...
num_prefetch = 2

modules = list(model.base_model.layers)
for i, module in enumerate(modules):
    if i == len(modules) - 1:
        break
    module.set_modules_to_forward_prefetch(modules[i+1:i+num_prefetch+1])
for i, module in enumerate(modules):
    if i == 0:
        continue
    module.set_modules_to_backward_prefetch(modules[max(0, i-num_prefetch):i])
By default, FSDP determines the next module and prefetches only that one. The code above makes FSDP prefetch the next two modules instead. The modules list enumerates the sharded modules of the model in their execution order. Then, for each module, you set the forward prefetch to the two subsequent modules and the backward prefetch to the two preceding modules.
Note that FSDP will not check if you specify them in the correct execution order. If you prefetch the wrong module, your training performance will deteriorate. But you also must not specify a module that is not sharded (such as model.lm_head in the example above) as FSDP will not be able to issue all-gather requests for it.
Checkpointing FSDP Models
An FSDP model is still a PyTorch model, but with its weights replaced by DTensor objects. You can still manipulate a DTensor much like a plain Tensor, as the optimizer does in your training loop. You can also inspect the DTensor objects to see what is in each shard:
...
from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import Shard

# Expect the model object to be both FSDPModule type and the original model type
assert isinstance(model, FSDPModule)
assert isinstance(model, LlamaForPretraining)

for param in model.parameters():
    # DTensors should have a placement: FSDP shards parameters along dim 0
    assert param.placements == (Shard(0),)
    # DTensors have the same dtype as the original tensor
    assert param.dtype == torch.float32
    # You can see what is in this shard
    print(param.to_local())
You can use this property to save and load a sharded model. Note that gathering the full tensors is a collective operation that every process must participate in, but only one process should write the file to disk so that the ranks do not overwrite each other:
# Gather the full state dict on all ranks, but save only on rank 0 (the master process)
sharded_state_dict = model.state_dict()            # maps to DTensor
full_state_dict = {}                               # maps to plain Tensor on CPU
for param_name, sharded_param in sharded_state_dict.items():
    full_param = sharded_param.full_tensor()       # all-gather: every rank must participate
    if torch.distributed.get_rank() == 0:
        full_state_dict[param_name] = full_param.cpu()
if torch.distributed.get_rank() == 0:
    torch.save(full_state_dict, "model.pth")

# Load model on all processes together
from torch.distributed.tensor import distribute_tensor

dist.barrier()
full_state_dict = torch.load("model.pth", map_location="cpu", mmap=True)
meta_sharded_state_dict = model.state_dict()       # FSDPModule on meta device
sharded_state_dict = {}
for param_name, full_tensor in full_state_dict.items():
    # create a new DTensor, reusing the device mesh and placement that fully_shard() assigned
    sharded_meta_param = meta_sharded_state_dict.get(param_name)
    dtensor = distribute_tensor(
        full_tensor,
        sharded_meta_param.device_mesh,
        sharded_meta_param.placements,
    )
    sharded_state_dict[param_name] = nn.Parameter(dtensor)
# must use `assign=True` to replace tensors on the meta device with actual DTensors
model.load_state_dict(sharded_state_dict, strict=False, assign=True)
dist.barrier()
Indeed, there is an easier method: the distributed checkpointing API, which you have already seen in the previous article:
...
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, StateDictOptions

def save_checkpoint(model, optimizer):
    dist.barrier()
    model_state, optimizer_state = get_state_dict(
        model, optimizer,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True)
    )
    save(
        {"model": model_state, "optimizer": optimizer_state},
        checkpoint_id="checkpoint-dist",  # each rank will save its own file
    )
    dist.barrier()

def load_checkpoint(model, optimizer):
    dist.barrier()
    model_state, optimizer_state = get_state_dict(
        model, optimizer,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True)
    )
    load(
        {"model": model_state, "optimizer": optimizer_state},
        checkpoint_id="checkpoint-dist",  # each rank will load its own file
    )
    # necessary if model.load_state_dict() should be called
    set_state_dict(
        model, optimizer,
        model_state_dict=model_state,
        optim_state_dict=optimizer_state,
        options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, cpu_offload=True)
    )
    dist.barrier()
The cpu_offload option must be removed if you do not use CPU offloading.
These two functions are supposed to be called by all processes together. Each process will save its own sharded model and optimizer state to a different file, all under the same directory as the checkpoint_id you specified. Do not attempt to read them with torch.load() since these files are in a different format. However, you can still use the same load_checkpoint() function above on an unsharded model in a plain Python script. Usually, after training is completed, you can recreate the model file from sharded checkpoints:
...
model = LlamaForPretraining(model_config)  # unsharded model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
load_checkpoint(model, optimizer)
torch.save(model.state_dict(), "model.pth")
For completeness, below is the full script that you can use to run FSDP training:
import dataclasses
import functools
import os

import datasets
import tokenizers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from torch import Tensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_state_dict,
    set_state_dict,
)
from torch.distributed.fsdp import (
    CPUOffloadPolicy,
    FSDPModule,
    MixedPrecisionPolicy,
    fully_shard,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.utils.data.distributed import DistributedSampler


# Build the model

@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000              # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768               # Dimension of hidden layers
    intermediate_size: int = 4*768       # Dimension of MLP's hidden layer
    num_hidden_layers: int = 12          # Number of transformer layers
    num_attention_heads: int = 12        # Number of attention heads
    num_key_value_heads: int = 3         # Number of key-value heads for GQA


class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module.

        Args:
            dim: The hidden dimension of the input tensor to which RoPE is applied
            max_position_embeddings: The maximum sequence length of the input tensor
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # compute a matrix of n*theta_i
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(max_position_embeddings)
        sinusoid_inp = torch.outer(position, inv_freq)
        # save cosine and sine matrices as buffers, not parameters
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x.

        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim)

        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim)
        """
        batch_size, seq_len, num_heads, head_dim = x.shape
        device = x.device
        dtype = x.dtype
        # transform the cosine and sine matrices to 4D tensor and the same dtype as x
        cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        # apply RoPE to x
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        output = (x * cos) + (rotated * sin)
        return output


class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads  # GQA: H_kv < H_q
        # hidden_size must be divisible by num_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size
        # Linear layers for Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def reset_parameters(self):
        self.q_proj.reset_parameters()
        self.k_proj.reset_parameters()
        self.v_proj.reset_parameters()
        self.o_proj.reset_parameters()

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        bs, seq_len, dim = hidden_states.size()

        # Project inputs to Q, K, V
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

        # Apply rotary position embeddings
        query_states = rope(query_states)
        key_states = rope(key_states)

        # Transpose tensors from BSHD to BHSD dimension for scaled_dot_product_attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Use PyTorch's optimized attention implementation
        # setting is_causal=True is incompatible with setting explicit attention mask
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=attn_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )

        # Transpose output tensor from BHSD to BSHD dimension, reshape to 3D, and then project output
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output


class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        # Two parallel projections for SwiGLU
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SwiGLU activation function
        # Project back to hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def reset_parameters(self):
        self.gate_proj.reset_parameters()
        self.up_proj.reset_parameters()
        self.down_proj.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU activation: multiply gate and up-projected inputs
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def reset_parameters(self):
        self.input_layernorm.reset_parameters()
        self.self_attn.reset_parameters()
        self.post_attention_layernorm.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:
        # First residual block: Self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope, attn_mask=attn_mask)
        hidden_states = attn_outputs + residual
        # Second residual block: MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states) + residual
        return hidden_states


class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def reset_parameters(self):
        self.embed_tokens.reset_parameters()
        for layer in self.layers:
            layer.reset_parameters()
        self.norm.reset_parameters()

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        # Convert input token IDs to embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Process through all transformer layers, then the final norm layer
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
        hidden_states = self.norm(hidden_states)
        # Return the final hidden states
        return hidden_states


class LlamaForPretraining(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def reset_parameters(self):
        self.base_model.reset_parameters()
        self.lm_head.reset_parameters()

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids, attn_mask)
        return self.lm_head(hidden_states)


def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a causal mask for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        dtype: Data type of the mask

    Returns:
        Causal mask of shape (seq_len, seq_len)
    """
    batch_size, seq_len = batch.shape
    mask = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype) \
                .triu(diagonal=1)
    return mask


def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a padding mask for a batch of sequences for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len)
        padding_token_id: ID of the padding token
        dtype: Data type of the mask

    Returns:
        Padding mask of shape (batch_size, 1, seq_len, seq_len)
    """
    padded = torch.zeros_like(batch, device=batch.device, dtype=dtype) \
                  .masked_fill(batch == padding_token_id, float("-inf"))
    mask = padded[:, :, None] + padded[:, None, :]
    return mask[:, None, :, :]


# Generator function to create padded sequences of fixed length
class PretrainingDataset(torch.utils.data.Dataset):
    def __init__(self, dataset: datasets.Dataset, tokenizer: tokenizers.Tokenizer, seq_length: int):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.bot = tokenizer.token_to_id("[BOT]")
        self.eot = tokenizer.token_to_id("[EOT]")
        self.pad = tokenizer.token_to_id("[PAD]")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        """Get a sequence of token ids from the dataset.

        [BOT] and [EOT] tokens are added. Clipped and padded to the sequence length.
        """
        seq = self.dataset[index]["text"]
        tokens: list[int] = [self.bot] + self.tokenizer.encode(seq).ids + [self.eot]
        # pad to target sequence length
        toklen = len(tokens)
        if toklen < self.seq_length + 1:
            pad_length = self.seq_length + 1 - toklen
            tokens += [self.pad] * pad_length
        # return the sequence
        x = torch.tensor(tokens[:self.seq_length], dtype=torch.int64)
        y = torch.tensor(tokens[1:self.seq_length+1], dtype=torch.int64)
        return x, y


def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer,
                    scheduler: lr_scheduler.SequentialLR) -> None:
    dist.barrier()
    model_state, optimizer_state = get_state_dict(
        model, optimizer,
        options=StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload),
    )
    load(
        {"model": model_state, "optimizer": optimizer_state},
        checkpoint_id="checkpoint-dist",
    )
    set_state_dict(
        model, optimizer,
        model_state_dict=model_state,
        optim_state_dict=optimizer_state,
        options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True, cpu_offload=cpu_offload),
    )
    scheduler.load_state_dict(
        torch.load("checkpoint-dist/lrscheduler.pt", map_location=device),
    )
    dist.barrier()


def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer,
                    scheduler: lr_scheduler.SequentialLR) -> None:
    dist.barrier()
    model_state, optimizer_state = get_state_dict(
        model, optimizer,
        options=StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload),
    )
    save(
        {"model": model_state, "optimizer": optimizer_state},
        checkpoint_id="checkpoint-dist",
    )
    if dist.get_rank() == 0:
        torch.save(scheduler.state_dict(), "checkpoint-dist/lrscheduler.pt")
    dist.barrier()


# Load the tokenizer and dataset
tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

# Initialize the distributed environment
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {device}")

# Create pretraining model on meta device, on all ranks
with torch.device("meta"):
    model_config = LlamaConfig()
    model = LlamaForPretraining(model_config)

# Convert model from meta device to FSDP2, must shard every component
cpu_offload = False
fsdp_kwargs = {
    # optional: use mixed precision training
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
    # optional: CPU offloading
    "offload_policy": CPUOffloadPolicy() if cpu_offload else None,
    # optional: discard all-gathered parameters after forward pass even on root modules
    # "reshard_after_forward": True,
}
for layer in model.base_model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model.base_model, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
model.to_empty(device="cpu" if cpu_offload else device)
model.reset_parameters()
assert isinstance(model, FSDPModule), f"Expected FSDPModule, got {type(model)}"

# Set explicit prefetching on models
# more prefetching uses more memory, but allows more overlap of computation and communication
num_prefetch = 1
if num_prefetch > 1:
    modules = list(model.base_model.layers)
    for i, module in enumerate(modules):
        if i == len(modules) - 1:
            break
        module.set_modules_to_forward_prefetch(modules[i+1:i+num_prefetch+1])
    for i, module in enumerate(modules):
        if i == 0:
            continue
        module.set_modules_to_backward_prefetch(modules[max(0, i-num_prefetch):i])

# Optional: Apply gradient checkpointing on a distributed model (all ranks)
#wrap_policy = functools.partial(
#    transformer_auto_wrap_policy,
#    transformer_layer_cls={LlamaDecoderLayer, nn.Embedding},
#)
#apply_activation_checkpointing(
#    model,
#    checkpoint_wrapper_fn=checkpoint_wrapper,
#    auto_wrap_policy=wrap_policy,
#)

# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 64 // world_size
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")
model.train()

# DataLoader, optimizer, scheduler, and loss function
# Sampler is needed to shard the dataset across world size
dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
dataloader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=batch_size,
    pin_memory=True,  # optional
    shuffle=False,
    num_workers=2,
    prefetch_factor=2,
)
num_training_steps = len(dataloader) * epochs
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.99),
    eps=1e-8,
    weight_decay=0.1,
)
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=num_training_steps - num_warmup_steps, eta_min=0,
)
scheduler = lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[num_warmup_steps],
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

# Optional: Compile the model and loss function
#model = torch.compile(model)
#loss_fn = torch.compile(loss_fn)

# if checkpoint-dist dir exists, load the checkpoint to model and optimizer
if os.path.exists("checkpoint-dist"):
    load_checkpoint(model, optimizer, scheduler)

# start training
for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        if batch_id % 1000 == 0:
            save_checkpoint(model, optimizer, scheduler)
        # Explicit prefetching before sending any data to model
        model.unshard()
        # Get batched data, move from CPU to GPU
        input_ids, target_ids = batch
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)
        # create attention mask: causal mask + padding mask
        attn_mask = create_causal_mask(input_ids) + create_padding_mask(input_ids, PAD_TOKEN_ID)
        # Extract output from model
        logits = model(input_ids, attn_mask)
        # Compute loss: cross-entropy between logits and target, ignoring padding tokens
        loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        # Backward with loss and gradient clipping by L2 norm to 1.0
        # Optimizer and gradient clipping work on DTensor
        optimizer.zero_grad(set_to_none=False if cpu_offload else True)
        loss.backward()
        # All-reduce fails if using CPU offloading
        if not cpu_offload:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item())
        pbar.update(1)
    pbar.close()

# Save the model
save_checkpoint(model, optimizer, scheduler)

# Clean up the distributed environment
dist.destroy_process_group()
To run this code, you need to launch it with the torchrun command, for example, torchrun --standalone --nproc_per_node=4 fsdp_training.py.
This code incorporates all elements discussed in this article. It may not be the most efficient implementation. You should read and modify it to suit your needs.
Further Readings
Below are some resources that you may find useful:
- Ott et al., Fully Sharded Data Parallel: Faster AI training with fewer GPUs, Engineering at Meta, 2021.
- Zhao et al., PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel, Proceedings of the VLDB Endowment, Vol. 16, No. 12, pp. 3848-3860, 2023.
- Getting Started with Fully Sharded Data Parallelism (FSDP2), from PyTorch tutorials
- Advanced Model Training with Fully Sharded Data Parallelism (FSDP), from PyTorch tutorials
- Large Scale Transformer model training with Tensor Parallel (TP), from PyTorch tutorials
- torch.distributed.fsdp.fully_shard (FSDP2) API, from PyTorch documentation
Summary
In this article, you learned about Fully Sharded Data Parallelism (FSDP) and how to use it in PyTorch. Specifically, you learned:
- FSDP is a data parallelism technique that shards the model across multiple GPUs.
- FSDP requires more communication and has a more complex workflow than plain data parallelism.
- FSDP can be used to train very large models with fewer GPUs. You can also apply mixed-precision training and other techniques to trade off memory and compute performance.
