
Torch 2.1 compile + FSDP (mixed precision) + LlamaForCausalLM: RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'. #111317

@KeremTurgutlu

Description


🐛 Describe the bug

I am getting the following error when training LlamaForCausalLM with torch 2.1, FSDP (mixed precision), and torch.compile. The exact same code works when torch.compile is disabled or when torch 2.0.1 is used. I also tried both enabling and disabling AMP autocast; it makes no difference and the same error occurs.

RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'.
Please ensure that the gradient and the tensor have the same dtype

I am using a Docker image; the error happens in Environment 2, which is listed in the Versions section.
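The traceback points to a specific nesting. The `torch._dynamo` frames sit above the outer FSDP `forward` (line 839), and each `decoder_layer(...)` call enters FSDP `forward` again (line 825). So the compiled module appears to wrap an FSDP-wrapped LlamaForCausalLM whose decoder layers are sharded individually. The `_writeback_orig_params` frame also suggests that FSDP was running with `use_orig_params=True`. The report builds all of this through Accelerate, so the plain-PyTorch sketch below is only an approximation of that setup. The bf16 dtypes and the helper names `make_bf16_policy` and `wrap_and_compile` are assumptions and are not taken from the report's script.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def make_bf16_policy() -> MixedPrecision:
    # Assumed bf16 policy; the report does not state the exact dtypes used.
    return MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )


def wrap_and_compile(model: nn.Module, layer_cls: type) -> nn.Module:
    """Hypothetical helper: shard each `layer_cls` block with FSDP, then compile the outer wrapper.

    Must be called inside an initialized NCCL process group (e.g. under torchrun)
    with the current CUDA device already set.
    """
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={layer_cls}
    )
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=make_bf16_policy(),
        # Implied by the _writeback_orig_params frame in the traceback.
        use_orig_params=True,
        device_id=torch.cuda.current_device(),
    )
    # torch.compile wraps the outermost FSDP module, matching the frame order above.
    return torch.compile(fsdp_model)
```

Under `torchrun`, after `torch.distributed.init_process_group("nccl")` and setting the CUDA device, a call such as `wrap_and_compile(model, LlamaDecoderLayer)` (with `LlamaDecoderLayer` imported from `transformers.models.llama.modeling_llama`) would roughly reproduce the compiled-FSDP arrangement seen in the traceback.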

Error logs

  0%|          | 0/5 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/workspace/workdir/models/pretraining/huggingface/llama/pretrain_fsdp_torch2.1_minimal.py", line 511, in <module>
    main()
  File "/workspace/workdir/models/pretraining/huggingface/llama/pretrain_fsdp_torch2.1_minimal.py", line 387, in main
    outputs = model(**batch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/accelerate/src/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/workspace/accelerate/src/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/transformers/src/transformers/models/llama/modeling_llama.py", line 921, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 825, in forward
    args, kwargs = _pre_forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 429, in _pre_forward
    unshard_fn(state, handle)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 464, in _pre_forward_unshard
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_runtime_utils.py", line 336, in _unshard
    ran_pre_unshard = handle.pre_unshard()
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/flat_param.py", line 1194, in pre_unshard
    ret = self._writeback_orig_params()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/flat_param.py", line 2202, in _writeback_orig_params
    flat_param.grad = flat_param_grad
RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'float'. Please ensure that the gradient and the tensor have the same dtype
============================================================
pretrain_fsdp_torch2.1_minimal.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 348906)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 348907)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 4 (local_rank: 4)
  exitcode  : 1 (pid: 348908)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[4]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 5 (local_rank: 5)
  exitcode  : 1 (pid: 348909)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[5]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 348910)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[6]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 7 (local_rank: 7)
  exitcode  : 1 (pid: 348911)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-10-14_20:44:01
  host      : ec3b2a9a542c
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 348905)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Minified repro

No response

Versions

Environment 1

PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.128
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.104.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             224
On-line CPU(s) list:                0-223
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7713 64-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 1
Core(s) per socket:                 224
Socket(s):                          1
Stepping:                           1
BogoMIPS:                           3999.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization:                     AMD-V
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          14 MiB (224 instances)
L1i cache:                          14 MiB (224 instances)
L2 cache:                           112 MiB (224 instances)
L3 cache:                           3.5 GiB (224 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] onnx==1.14.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.0.1
[pip3] torch-tensorrt==2.0.0.dev0
[pip3] torchdata==0.7.0a0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0
[pip3] triton==2.0.0
[conda] Could not collect

Environment 2

PyTorch version: 2.1.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.128
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.104.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             224
On-line CPU(s) list:                0-223
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7713 64-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 1
Core(s) per socket:                 224
Socket(s):                          1
Stepping:                           1
BogoMIPS:                           3999.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization:                     AMD-V
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          14 MiB (224 instances)
L1i cache:                          14 MiB (224 instances)
L2 cache:                           112 MiB (224 instances)
L3 cache:                           3.5 GiB (224 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] onnx==1.14.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.1.0
[pip3] torch-tensorrt==2.0.0.dev0
[pip3] torchdata==0.7.0a0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0+e621604
[conda] Could not collect
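
Two parts of the environment dumps above differ: the PyTorch build (2.0.1+cu117 built with CUDA 11.7 vs 2.1.0+cu121 built with CUDA 12.1) and the Triton package (2.0.0 vs 2.1.0+e621604). Everything else is identical. The following is a minimal stdlib sketch for finding this kind of difference automatically. It assumes the `key: value` and `[pip3] package==version` line formats printed by `collect_env`. `parse_env` and `diff_envs` are hypothetical helpers, and `ENV_1`/`ENV_2` hold short excerpts of the two dumps:

```python
from typing import Dict, Optional, Tuple

# Short excerpts copied from the Environment 1 and Environment 2 dumps above.
ENV_1 = """\
PyTorch version: 2.0.1+cu117
CUDA used to build PyTorch: 11.7
CUDA runtime version: 12.2.128
Nvidia driver version: 535.104.12
[pip3] torch==2.0.1
[pip3] triton==2.0.0
"""

ENV_2 = """\
PyTorch version: 2.1.0+cu121
CUDA used to build PyTorch: 12.1
CUDA runtime version: 12.2.128
Nvidia driver version: 535.104.12
[pip3] torch==2.1.0
[pip3] triton==2.1.0+e621604
"""


def parse_env(text: str) -> Dict[str, str]:
    """Map field names to values for 'key: value' and '[pip3] pkg==version' lines."""
    fields: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("[pip3] ") and "==" in line:
            # Package lines: key them as "pip3:<package>" so they don't collide with other fields.
            name, version = line[len("[pip3] "):].split("==", 1)
            fields[f"pip3:{name}"] = version
        elif ":" in line:
            # Regular fields: split on the first colon only.
            key, value = line.split(":", 1)
            fields[key.strip()] = value.strip()
    return fields


def diff_envs(a: str, b: str) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Return {field: (value_in_a, value_in_b)} for every field whose value differs."""
    fa, fb = parse_env(a), parse_env(b)
    return {
        key: (fa.get(key), fb.get(key))
        for key in sorted(fa.keys() | fb.keys())
        if fa.get(key) != fb.get(key)
    }


if __name__ == "__main__":
    for field, (old, new) in diff_envs(ENV_1, ENV_2).items():
        print(f"{field}: {old} -> {new}")
```

Running it prints one line per changed field, for example `pip3:triton: 2.0.0 -> 2.1.0+e621604`. Fields with identical values, such as the CUDA runtime and driver versions, are omitted.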

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang @penguinwu @pritamdamania87 @satgera @gqchen @aazzolini @osalpekar @jiayisuse @XilunWu @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @kiukchung @LucasLLC @tianyu-l @gchanan @kadeng

Labels

module: fsdp, oncall: distributed, oncall: pt2, pt2d-triage-nov2024, triaged