Hello, cross-posting from ipex #647: torch-ccl does not support torch.distributed.reduce_scatter, even though the docs claim it does.
For instance, in 2.1.300+xpu we have:
torch-ccl/src/ProcessGroupCCL.cpp, lines 871 to 877 in b9ce713:

    c10::intrusive_ptr<C10D_Work> ProcessGroupCCL::reduce_scatter(
        std::vector<at::Tensor>& /* unused */,
        std::vector<std::vector<at::Tensor>>& /* unused */,
        const ReduceScatterOptions& /* unused */)
    {
      TORCH_CHECK(false, "ProcessGroupCCL does not support reduce_scatter");
    }

where the TORCH_CHECK line raises a RuntimeError.
See the ipex ticket for more details and code to reproduce the error.
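
For anyone unfamiliar with the collective, here is a rough single-process sketch of what reduce_scatter with a SUM reduction is expected to compute. It uses plain Python lists and makes no torch or torch-ccl calls. `simulate_reduce_scatter` is a hypothetical helper written only for this illustration and is not part of either library.

```python
# Single-process simulation of reduce_scatter semantics (SUM) with plain lists.
# No torch or torch-ccl calls are made; `simulate_reduce_scatter` is a
# hypothetical helper written for this illustration only.
from typing import List


def simulate_reduce_scatter(
    inputs_per_rank: List[List[List[float]]],
) -> List[List[float]]:
    """inputs_per_rank[r][k] is the k-th chunk held by rank r.

    Returns outputs where outputs[k] is the element-wise sum over all
    ranks of chunk k, i.e. what rank k would receive.
    """
    world_size = len(inputs_per_rank)
    for chunks in inputs_per_rank:
        if len(chunks) != world_size:
            raise ValueError("each rank must provide one chunk per rank")

    outputs = []
    for k in range(world_size):
        # Collect chunk k from every rank and sum it element-wise.
        columns = zip(*(inputs_per_rank[r][k] for r in range(world_size)))
        outputs.append([sum(col) for col in columns])
    return outputs


if __name__ == "__main__":
    # Two simulated ranks, each holding two chunks of two elements.
    rank0 = [[1.0, 2.0], [3.0, 4.0]]
    rank1 = [[10.0, 20.0], [30.0, 40.0]]
    # Rank 0 would receive the first result, rank 1 the second.
    print(simulate_reduce_scatter([rank0, rank1]))
```

With the ccl backend, the call that should produce this result instead hits the stub above and fails before doing any communication.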