🐛 Describe the bug
I am getting the following error when training LlamaForCausalLM with torch 2.1, FSDP (mixed precision), and torch.compile. The exact same code works when torch.compile is disabled or when torch 2.0.1 is used. I also tried enabling and disabling amp autocast; it makes no difference, and the same error occurs.
RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'.
Please ensure that the gradient and the tensor have the same dtype
I am using a Docker image; the error happens in Environment 2, which is given in the Versions section.
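For context, here is a minimal sketch of the setup that triggers this. This is not the actual pretrain_fsdp_torch2.1_minimal.py; the model config, wrapping policy, and dtype choices below are assumptions reconstructed from the traceback (which passes through `_writeback_orig_params`, implying `use_orig_params=True`):

```python
# Illustrative sketch only -- assumed configuration, not the original script.
import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = LlamaForCausalLM(LlamaConfig())  # parameters are created in fp32

model = FSDP(
    model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,   # compute in bf16
        reduce_dtype=torch.bfloat16,  # bf16 gradients vs. the fp32 sharded flat_param
        buffer_dtype=torch.bfloat16,
    ),
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    ),
    use_orig_params=True,  # required for torch.compile; the failing frame
                           # (_writeback_orig_params) only runs with this flag
    device_id=torch.cuda.current_device(),
)
model = torch.compile(model)  # with this line removed, training runs fine
```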
Error logs
0%|          | 0/5 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/workspace/workdir/models/pretraining/huggingface/llama/pretrain_fsdp_torch2.1_minimal.py", line 511, in <module>
    main()
  File "/workspace/workdir/models/pretraining/huggingface/llama/pretrain_fsdp_torch2.1_minimal.py", line 387, in main
    outputs = model(**batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/accelerate/src/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/workspace/accelerate/src/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/llama/modeling_llama.py", line 921, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 825, in forward
    args, kwargs = _pre_forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 429, in _pre_forward
    unshard_fn(state, handle)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 464, in _pre_forward_unshard
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 336, in _unshard
    ran_pre_unshard = handle.pre_unshard()
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/flat_param.py", line 1194, in pre_unshard
    ret = self._writeback_orig_params()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/flat_param.py", line 2202, in _writeback_orig_params
    flat_param.grad = flat_param_grad
RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'. Please ensure that the gradient and the tensor have the same dtype
============================================================
pretrain_fsdp_torch2.1_minimal.py FAILED
------------------------------------------------------------
Failures:
[1]:
time : 2023-10-14_20:44:01
host : ec3b2a9a542c
rank : 2 (local_rank: 2)
exitcode : 1 (pid: 348906)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
time : 2023-10-14_20:44:01
host : ec3b2a9a542c
rank : 3 (local_rank: 3)
exitcode : 1 (pid: 348907)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
time : 2023-10-14_20:44:01
host : ec3b2a9a542c
rank : 4 (local_rank: 4)
exitcode : 1 (pid: 348908)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[4]:
time : 2023-10-14_20:44:01
host : ec3b2a9a542c
rank : 5 (local_rank: 5)
exitcode : 1 (pid: 348909)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[5]:
time : 2023-10-14_20:44:01
host : ec3b2a9a542c
rank : 6 (local_rank: 6)
exitcode : 1 (pid: 348910)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[6]:
time : 2023-10-14_20:44:01
host : ec3b2a9a542c
rank : 7 (local_rank: 7)
exitcode : 1 (pid: 348911)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2023-10-14_20:44:01
host : ec3b2a9a542c
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 348905)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
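For what it's worth, the RuntimeError itself is just PyTorch's generic dtype check on the `.grad` setter, which can be reproduced in isolation without FSDP or torch.compile (a minimal sketch):

```python
import torch

# Assigning a gradient whose dtype differs from the tensor's raises the
# same RuntimeError seen above (here: a bf16 grad onto an fp32 parameter).
p = torch.zeros(4, dtype=torch.float32, requires_grad=True)
p.grad = torch.zeros(4, dtype=torch.bfloat16)
# RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16'
# to a tensor with dtype 'float'. Please ensure that the gradient and the
# tensor have the same dtype
```

In the FSDP case, the bf16 gradient is what FSDP's mixed precision produces during backward, while the sharded flat_param it is written back to is still fp32, which is why the check fires inside `_writeback_orig_params`.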
Minified repro
No response
Versions
Environment 1
PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35
Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.128
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.104.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7713 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 1
Core(s) per socket: 224
Socket(s): 1
Stepping: 1
BogoMIPS: 3999.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization: AMD-V
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 14 MiB (224 instances)
L1i cache: 14 MiB (224 instances)
L2 cache: 112 MiB (224 instances)
L3 cache: 3.5 GiB (224 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] onnx==1.14.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.0.1
[pip3] torch-tensorrt==2.0.0.dev0
[pip3] torchdata==0.7.0a0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0
[pip3] triton==2.0.0
[conda] Could not collect
Environment 2
Collecting environment information...
PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35
Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.128
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.104.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7713 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 1
Core(s) per socket: 224
Socket(s): 1
Stepping: 1
BogoMIPS: 3999.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization: AMD-V
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 14 MiB (224 instances)
L1i cache: 14 MiB (224 instances)
L2 cache: 112 MiB (224 instances)
L3 cache: 3.5 GiB (224 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] onnx==1.14.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.1.0
[pip3] torch-tensorrt==2.0.0.dev0
[pip3] torchdata==0.7.0a0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0+e621604
[conda] Could not collect
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang @penguinwu @pritamdamania87 @satgera @gqchen @aazzolini @osalpekar @jiayisuse @XilunWu @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @kiukchung @LucasLLC @tianyu-l @gchanan @kadeng