Open
Labels
module: cudnn (Related to torch.backends.cudnn and CuDNN support)
module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
I'm getting this warning when using the HF Trainer and FSDP to pre-train Llama 3.1-8B.
UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides()
This might introduce overhead in the training process.
I have tried to disable the backend with:
import os
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"
from torch.nn.attention import SDPBackend
torch.backends.cuda.sdp_kernel = SDPBackend.FLASH_ATTENTION
However, the HF Trainer ignores these settings and continues to use the cuDNN SDPA backend (the warning still appears).
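For reference, PyTorch also exposes backend selection through torch.backends.cuda.enable_cudnn_sdp and the torch.nn.attention.sdpa_kernel context manager. Below is a minimal sketch of that API; the q/k/v tensors are just placeholders to make it runnable and are not part of my training code:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Globally disable the cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)

# Placeholder tensors, only so the sketch runs on a CUDA device
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Restrict SDPA to the Flash Attention backend for this region of code
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

Note that sdpa_kernel is a context manager, so it only takes effect around the code that actually calls scaled_dot_product_attention.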
Here is the full script:
import datasets
import torch
import time
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    TrainerCallback,
    Trainer,
    set_seed,
    DataCollatorWithPadding,
)
from transformers.integrations import TensorBoardCallback
import GPUtil, psutil
from torch.utils.tensorboard import SummaryWriter

# Explicitly disable cuDNN SDPA to avoid stride mismatch warnings
import os
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"

# Set Flash Attention as the preferred backend
from torch.nn.attention import SDPBackend
torch.backends.cuda.sdp_kernel = SDPBackend.FLASH_ATTENTION

# Model and dataset configuration
LLM_MODEL = "meta-llama/Meta-Llama-3.1-8B"
DATASET_PATH = "../data-prep/data_files/llama31_tokenized_docs_full_dataset.parquet"
OUTPUT_DIR = "./llama3_8b_ddp_pretraining"

set_seed(42)

# Load model and tokenizer
model_name = LLM_MODEL
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Ensure pad token is set for the tokenizer
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Custom dataset class with contiguous tensors
class CustomDataset(Dataset):
    def __init__(self, dataset_name, tokenizer, split="train", max_tokens=None, max_length=512):
        self.dataset = datasets.load_dataset(
            "parquet",
            data_files=dataset_name,
            split=split,
        )
        if max_tokens is not None:
            self.dataset = self.dataset.filter(lambda x: x["num_tokens"] <= max_tokens)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        input_ids = self.dataset[idx]["input_ids"]
        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
        attention_mask = [1] * len(input_ids)
        padding_length = self.max_length - len(input_ids)
        if padding_length > 0:
            input_ids += [self.tokenizer.pad_token_id] * padding_length
            attention_mask += [0] * padding_length
        # Ensure tensors are contiguous
        input_ids = torch.tensor(input_ids, dtype=torch.long).contiguous()
        attention_mask = torch.tensor(attention_mask, dtype=torch.long).contiguous()
        labels = input_ids.clone().contiguous()
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Initialize dataset and data collator
train_dataset = CustomDataset(
    dataset_name=DATASET_PATH,
    tokenizer=tokenizer,
    split="train",
    max_tokens=512,
    max_length=512,
)
print(f"Training dataset size is: {len(train_dataset.dataset)} samples")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    optim="adamw_torch",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=8,
    learning_rate=3e-5,
    weight_decay=0.01,
    warmup_steps=10,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,
    dataloader_num_workers=8,
    bf16=True,
    logging_steps=10,
    report_to="tensorboard",
    save_strategy="epoch",
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=None,
    data_collator=data_collator,
)

trainer.train()
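A possible workaround, sketched below but not verified in this setup, would be to wrap the training call in the sdpa_kernel context manager so that attention calls made during trainer.train() are restricted to the Flash Attention backend. This assumes the Trainer does not run the attention calls outside the enclosing context, and whether it actually suppresses the cuDNN warning here is untested:

from torch.nn.attention import SDPBackend, sdpa_kernel

# Sketch: restrict SDPA to Flash Attention for the whole training run.
# Assumes all scaled_dot_product_attention calls happen inside this context.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    trainer.train()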
Versions
NVIDIA (PyTorch container) Release 24.12 (build 126674149)
Using CUDA 12.6 driver version 560.35.05 with kernel driver version 550.127.08
pytorch-triton 3.0.0+72734f086
torch 2.6.0a0+df5bbc09d1.nv24.12
torch-tb-profiler 0.4.3
torch_tensorrt 2.6.0a0
torchprofile 0.0.4
torchvision 0.20.0a0
transformers 4.48.0
accelerate 1.2.1