@deependujha (Collaborator) commented Sep 8, 2025

What does this PR do?

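Fixes #<issue_number>

Adds initial support for FSDP2 to the PyTorch Lightning Trainer. Example script using strategy="fsdp2":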

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.strategies import FSDP2Strategy
from torch.utils.data import DataLoader, TensorDataset


class SimpleModel(L.LightningModule):
    def __init__(self, input_dim=32, hidden_dim=64, output_dim=10):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, output_dim)
        self.save_hyperparameters()
        
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("test_loss", loss, prog_bar=True)
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


class RandomDataModule(L.LightningDataModule):
    def __init__(self, input_dim=32, output_dim=10, num_samples=1000, batch_size=32):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_samples = num_samples
        self.batch_size = batch_size
        
    def setup(self, stage=None):
        # Generate random data; stage "validate" (trainer.validate) also needs the validation split
        if stage in ("fit", "validate", None):
            # Training data
            x_train = torch.randn(self.num_samples, self.input_dim)
            y_train = torch.randint(0, self.output_dim, (self.num_samples,))
            self.train_dataset = TensorDataset(x_train, y_train)
            
            # Validation data
            x_val = torch.randn(self.num_samples // 5, self.input_dim)
            y_val = torch.randint(0, self.output_dim, (self.num_samples // 5,))
            self.val_dataset = TensorDataset(x_val, y_val)
            
        if stage == 'test' or stage is None:
            # Test data
            x_test = torch.randn(self.num_samples // 5, self.input_dim)
            y_test = torch.randint(0, self.output_dim, (self.num_samples // 5,))
            self.test_dataset = TensorDataset(x_test, y_test)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


def main():
    # Set up model and data
    input_dim = 32
    hidden_dim = 64
    output_dim = 10
    
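    # Build the model on the meta device: parameters get shapes but no storage,
    # so they must be materialized (presumably by the FSDP2 strategy) before training.
    with torch.device("meta"):
        model = SimpleModel(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)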
    data_module = RandomDataModule(input_dim=input_dim, output_dim=output_dim)
    
    # Set up the checkpoint callback to save the best model
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints/",
        filename="model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=1,
        mode="min",
    )

    # Alternatively, pass a strategy instance instead of the "fsdp2" string
    # (e.g. to customize it): strategy=FSDP2Strategy()
    
    # Set up the trainer
    trainer = L.Trainer(
        max_epochs=10,
        callbacks=[checkpoint_callback],
        strategy="fsdp2",
        accelerator="auto",  # Automatically detect available accelerator
        devices="auto",      # Use all available devices
    )
    
    # Train the model
    trainer.fit(model, data_module)
    
    if trainer.is_global_zero:
        print(f"Done training! Best model saved at: {checkpoint_callback.best_model_path}")


if __name__ == "__main__":
    main()
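The script constructs SimpleModel under torch.device("meta"). The PR does not show how FSDP2Strategy materializes such a model, so the sketch below only illustrates the underlying PyTorch mechanism in a single CPU process: to_empty() allocates uninitialized storage, and each layer's own reset_parameters() re-initializes it. build_on_meta and materialize are hypothetical helpers, not Lightning APIs.

```python
import torch
import torch.nn as nn


def build_on_meta(input_dim=32, hidden_dim=64, output_dim=10):
    # Parameters created here are shape-only placeholders (no memory allocated)
    with torch.device("meta"):
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )


def materialize(module, device="cpu"):
    # Allocate real but uninitialized storage on the target device...
    module.to_empty(device=device)
    # ...then re-run each layer's own initializer to fill it with sensible values
    for sub in module.modules():
        if hasattr(sub, "reset_parameters"):
            sub.reset_parameters()
    return module


model = build_on_meta()
was_meta = all(p.is_meta for p in model.parameters())
materialize(model)
out = model(torch.randn(4, 32))
```

Calling to_empty() alone would leave garbage values in the weights, which is why the per-layer reset_parameters() pass follows it.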
Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21184.org.readthedocs.build/en/21184/

@github-actions bot added the pl label (Generic label for PyTorch Lightning package) Sep 8, 2025
@github-actions bot added the fabric label (lightning.fabric.Fabric) Sep 9, 2025
@deependujha deependujha marked this pull request as ready for review September 10, 2025 09:23
codecov bot commented Sep 10, 2025

Codecov Report

❌ Patch coverage is 49.04110% with 186 lines in your changes missing coverage. Please review.
✅ Project coverage is 86%. Comparing base (3998b5d) to head (029ebff).
⚠️ Report is 8 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff            @@
##           master   #21184    +/-   ##
========================================
- Coverage      87%      86%    -1%     
========================================
  Files         269      271     +2     
  Lines       23665    24051   +386     
========================================
+ Hits        20642    20710    +68     
- Misses       3023     3341   +318     
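The Codecov figures are internally consistent, and the size of the patch can be backed out from them. A small check using only the numbers in the report, assuming no partially covered lines (the helper pct and the derived patch size are illustrative, not Codecov output):

```python
# Figures copied from the Codecov report above
BASE = {"lines": 23665, "hits": 20642, "misses": 3023}  # master (3998b5d)
HEAD = {"lines": 24051, "hits": 20710, "misses": 3341}  # PR head (029ebff)
PATCH_MISSED = 186
PATCH_COVERAGE = 49.04110  # percent


def pct(hits, lines):
    """Coverage as a percentage, assuming no partially covered lines."""
    return 100.0 * hits / lines


# Every tracked line is either a hit or a miss
for totals in (BASE, HEAD):
    assert totals["hits"] + totals["misses"] == totals["lines"]

base_cov = pct(BASE["hits"], BASE["lines"])
head_cov = pct(HEAD["hits"], HEAD["lines"])

# Back out the number of coverable lines in the patch:
# missed = lines * (1 - coverage), so lines = missed / (1 - coverage)
patch_lines = round(PATCH_MISSED / (1 - PATCH_COVERAGE / 100))
patch_hits = patch_lines - PATCH_MISSED

print(f"project: {base_cov:.2f}% -> {head_cov:.2f}%")
print(f"patch: {patch_hits}/{patch_lines} = {pct(patch_hits, patch_lines):.5f}%")
```

This prints project coverage of 87.23% -> 86.11% (reported rounded as 87% -> 86%) and a patch of 179 covered lines out of 365, i.e. 49.04110%.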

@SkafteNicki (Collaborator) left a comment

Probably in a follow-up PR:

@Borda (Member) commented Sep 11, 2025

  • add fsdp2 to fabric?

I would start with FSDP2 in Fabric

@deependujha (Collaborator, Author) commented Sep 11, 2025

✅ This PR is functionally complete and introduces initial support for FSDP2 in the PyTorch Lightning Trainer.

Originally, I planned to follow up with:

  • Gradient accumulation support for FSDP2.
  • Discussion & improvements around best practices for wrapping (e.g., top-level fully_shard(model) vs. selectively wrapping layers like nn.Linear, transformer blocks, etc.).

However, per discussion with @tchaton, we’re ⚠️ pausing this work for now until the PyTorch Lightning Enterprise direction is clear.

I’ll leave this PR open so it can be easily revived later, if we decide to.

Labels: fabric (lightning.fabric.Fabric), pl (Generic label for PyTorch Lightning package)

3 participants