Deadlock attempting to do concurrent send, receive #72

@pspillai

Description

I am trying to implement concurrent asynchronous send and receive between multiple processes, but this results in a deadlock. Minimal code to reproduce:

import os
import torch
import torch.distributed as dist
import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch

os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))

print(os.environ['RANK'], os.environ['WORLD_SIZE'])
backend = 'ccl'
dist.init_process_group(backend)
my_rank = dist.get_rank()
my_size = dist.get_world_size()
print("my rank = %d  my size = %d" % (my_rank, my_size))

dev = f"xpu:{my_rank}"
torch.xpu.set_device(my_rank)
A = torch.ones(1, 2, dtype=torch.float32).to(dev)
_ = A[0, 0].item()  # force the copy to the device to complete
B = torch.zeros(1, 2, dtype=torch.float32).to(dev)
_ = B[0, 0].item()

dist.barrier()

dist.all_reduce(A)  # collectives work fine; only point-to-point deadlocks

print("START")
# Both ranks post the send first, then the receive; with truly
# asynchronous point-to-point transfers this must not deadlock.
o1 = dist.isend(A, 1 - my_rank)
o2 = dist.irecv(B, 1 - my_rank)
o1.wait()  # never returns -- deadlock here
o2.wait()

print("DONE")

Run with

mpirun -n 2 python -u test.py

It looks like the isend and irecv on each process are serialized. This particular example completes if one process posts the send first and the other the recv first, but I believe the two transfers are still serialized rather than proceeding concurrently.
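To illustrate the rank-ordering workaround mentioned above: if the lower rank sends first and the higher rank receives first, the exchange completes even when the backend serializes point-to-point transfers. A minimal sketch (function name is mine; assumes a world size of 2):

```python
import torch
import torch.distributed as dist

def ordered_sendrecv(t_out, t_in, my_rank):
    """Exchange tensors with the peer rank, ordering the blocking
    send/recv by rank so the two calls can never wait on each other
    even if the backend serializes point-to-point transfers."""
    peer = 1 - my_rank
    if my_rank < peer:
        dist.send(t_out, peer)  # lower rank: send first ...
        dist.recv(t_in, peer)   # ... then receive
    else:
        dist.recv(t_in, peer)   # higher rank: receive first ...
        dist.send(t_out, peer)  # ... then send
```

This avoids the hang, but it still forces the two transfers to run back-to-back instead of overlapping, which is exactly the problem being reported.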

I also tried batch_isend_irecv to define the list of transfers, but this resulted in the same deadlock.
Without concurrent transfers it is almost impossible to implement efficient distributed compute-and-shift algorithms, Cannon's algorithm, etc.
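For reference, the batch_isend_irecv formulation I tried is essentially the following (function name is mine; with ccl it deadlocks the same way as the plain isend/irecv pair):

```python
import torch
import torch.distributed as dist

def batched_exchange(t_out, t_in, my_rank):
    """Post the send and the receive of one exchange as a single batch,
    so the backend is free to progress both transfers concurrently."""
    peer = 1 - my_rank
    ops = [dist.P2POp(dist.isend, t_out, peer),
           dist.P2POp(dist.irecv, t_in, peer)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
```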
