
UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides() #144913

Description

@vecorro

🐛 Describe the bug

I'm getting this warning when using the HF Trainer and FSDP to pre-train Llama 3.1-8B.

UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides()

This might introduce overhead in the training process, presumably because the backward pass has to materialize a copy of grad_output with matching strides at every step.
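
For reference, here is a standalone sketch (not part of my training run, and untested as written) of what I understand the warning to be about: force the cuDNN backend, then run backward with a grad_output whose strides differ from the output's. The shapes and dtype below are arbitrary choices on my part.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# SDPA inputs in (batch, heads, seq, head_dim) layout.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Force the cuDNN backend (this raises if the GPU/build doesn't support it).
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

# grad_output allocated in (batch, seq, heads, head_dim) memory layout, then viewed
# back as (batch, heads, seq, head_dim): same shape as `out`, different strides.
grad = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16).transpose(1, 2)
out.backward(grad)  # expected to hit the stride-mismatch path in the cuDNN SDPA backward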

I have tried to disable the backend with:

import os
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"

from torch.nn.attention import SDPBackend
torch.backends.cuda.sdp_kernel = SDPBackend.FLASH_ATTENTION

However, the HF Trainer appears to ignore these settings and the cuDNN SDPA backend keeps being used (the warning persists).
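
For what it's worth, assigning to torch.backends.cuda.sdp_kernel only overwrites the (deprecated) context-manager function rather than selecting a backend, so that line is probably a no-op. The documented knobs I'm aware of look like this (a sketch, not something I have verified fixes the warning):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Process-wide switch: disable the cuDNN SDPA backend entirely.
torch.backends.cuda.enable_cudnn_sdp(False)

# Scoped preference: only the flash backend is allowed inside the block.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    ...  # forward/backward that should avoid the cuDNN kernel

The scoped form would need to wrap the actual training call (see the note after the script below).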

Here is the full script:

import datasets
import torch
import time

from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    TrainerCallback,
    Trainer,
    set_seed,
    DataCollatorWithPadding,
)

from transformers.integrations import TensorBoardCallback
import GPUtil, psutil

from torch.utils.tensorboard import SummaryWriter

# Explicitly disable cuDNN SDPA to avoid stride mismatch warnings
import os
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"

# Set Flash Attention as the preferred backend
from torch.nn.attention import SDPBackend
torch.backends.cuda.sdp_kernel = SDPBackend.FLASH_ATTENTION

# Model and dataset configuration
LLM_MODEL = "meta-llama/Meta-Llama-3.1-8B"
DATASET_PATH = "../data-prep/data_files/llama31_tokenized_docs_full_dataset.parquet"
OUTPUT_DIR = "./llama3_8b_ddp_pretraining"
set_seed(42)

# Load model and tokenizer
model_name = LLM_MODEL
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model.config.use_cache = False
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure pad token is set for the tokenizer
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Custom dataset class with contiguous tensors
class CustomDataset(Dataset):
    def __init__(self, dataset_name, tokenizer, split="train", max_tokens=None, max_length=512):
        self.dataset = datasets.load_dataset(
            "parquet",
            data_files=dataset_name,
            split=split
        )
        if max_tokens is not None:
            self.dataset = self.dataset.filter(lambda x: x["num_tokens"] <= max_tokens)
        
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        input_ids = self.dataset[idx]["input_ids"]
        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
        
        attention_mask = [1] * len(input_ids)
        padding_length = self.max_length - len(input_ids)
        
        if padding_length > 0:
            input_ids += [self.tokenizer.pad_token_id] * padding_length
            attention_mask += [0] * padding_length
        
        # Ensure tensors are contiguous
        input_ids = torch.tensor(input_ids, dtype=torch.long).contiguous()
        attention_mask = torch.tensor(attention_mask, dtype=torch.long).contiguous()
        labels = input_ids.clone().contiguous()
        
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Initialize dataset and data collator
train_dataset = CustomDataset(
    dataset_name=DATASET_PATH,
    tokenizer=tokenizer,
    split="train",
    max_tokens=512,
    max_length=512,
)
print(f"Training dataset size is: {len(train_dataset.dataset)} samples")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,  
    optim="adamw_torch",  
    num_train_epochs=1,
    per_device_train_batch_size=64, 
    gradient_accumulation_steps=8,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=10,
    lr_scheduler_type="cosine",  
    gradient_checkpointing=True,
    dataloader_num_workers=8,
    bf16=True,  
    logging_steps=10,  
    report_to="tensorboard", 
    save_strategy="epoch",
    save_total_limit=2,  
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=None,
    data_collator=data_collator,
)

trainer.train()
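
As noted above, a variant I have been considering (not what produced the warning) is to wrap the training call itself in the sdpa_kernel context manager, since the backend preference only applies inside the with block:

from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the flash / memory-efficient kernels for the whole run,
# so the cuDNN backend should not be picked during training.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    trainer.train()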


Versions

NVIDIA (PyTorch container) Release 24.12 (build 126674149)
Using CUDA 12.6 driver version 560.35.05 with kernel driver version 550.127.08
pytorch-triton             3.0.0+72734f086
torch                      2.6.0a0+df5bbc09d1.nv24.12
torch-tb-profiler          0.4.3
torch_tensorrt             2.6.0a0
torchprofile               0.0.4
torchvision                0.20.0a0
transformers               4.48.0
accelerate                 1.2.1

cc @csarofeen @ptrblck @xwang233 @eqy


Labels

module: cudnn (Related to torch.backends.cudnn and cuDNN support)
module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
