
Conversation

@baoleai commented Aug 1, 2024

This PR leverages the MaybeSetDevice function (#94864) in the CUDA device guard's setDevice function to address the issue of unnecessary context memory allocation on cuda:0 when using other devices, such as XLA:GPU. This PR resolves the issue described in pytorch/xla#6208 (comment).

Before this PR, when testing pytorch/xla with the following command:

PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc_per_node 4 xla/test/test_train_mp_imagenet.py  --fake_data --pjrt_distributed --batch_size=32 --num_epochs=1

the GPU memory usage for each process was as follows:

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    106448      C   /usr/bin/python                           62336MiB |
|    0   N/A  N/A    106449      C   /usr/bin/python                             414MiB |
|    0   N/A  N/A    106450      C   /usr/bin/python                             414MiB |
|    0   N/A  N/A    106451      C   /usr/bin/python                             414MiB |
|    1   N/A  N/A    106449      C   /usr/bin/python                           62480MiB |
|    2   N/A  N/A    106450      C   /usr/bin/python                           62480MiB |
|    3   N/A  N/A    106451      C   /usr/bin/python                           62336MiB |
+---------------------------------------------------------------------------------------+

There are unnecessary memory allocations on GPU 0: each of the processes bound to GPUs 1-3 also holds a 414 MiB context on GPU 0.

After this PR, the CUDA device guard's setDevice will no longer call cudaSetDevice if a primary context does not exist on cuda:0 (a sketch of the idea follows below). As a result, the GPU memory usage is now:

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     66928      C   /usr/bin/python                           62336MiB |
|    1   N/A  N/A     66929      C   /usr/bin/python                           62480MiB |
|    2   N/A  N/A     66930      C   /usr/bin/python                           62480MiB |
|    3   N/A  N/A     66931      C   /usr/bin/python                           62336MiB |
+---------------------------------------------------------------------------------------+

Test environment: CUDA 12 + PyTorch main branch + pytorch/xla v2.3.0.
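To illustrate the intent of the change, here is a minimal sketch in raw CUDA of the "maybe set device" idea: only switch devices when doing so will not create a new primary context. This is an illustration only, not the actual c10::cuda::MaybeSetDevice implementation, and the helper names (maybe_set_device, primary_context_active) are made up for this sketch:

#include <cuda.h>
#include <cuda_runtime.h>

// Returns true if `device` already has an active primary context.
static bool primary_context_active(int device) {
  CUdevice dev;
  unsigned int flags = 0;
  int active = 0;
  if (cuInit(0) != CUDA_SUCCESS) return false;  // idempotent after the first call
  if (cuDeviceGet(&dev, device) != CUDA_SUCCESS) return false;
  if (cuDevicePrimaryCtxGetState(dev, &flags, &active) != CUDA_SUCCESS) return false;
  return active != 0;
}

// Since CUDA 12, cudaSetDevice eagerly initializes a primary context on the
// target device, so skip the call when it would only create an unused context.
static void maybe_set_device(int target) {
  int current = -1;
  (void)cudaGetDevice(&current);       // does not create a context by itself
  if (current == target) return;       // already current, nothing to do
  if (primary_context_active(target)) {
    cudaSetDevice(target);             // safe: a context already exists there
  }
  // Otherwise skip the switch so no context memory is allocated on `target`.
}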

cc @Aidyn-A , @vanbasten23 , @ezyang

@baoleai requested review from eqy and syed-ahmed as code owners August 1, 2024 11:55
pytorch-bot (bot) commented Aug 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132398

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5488512 with merge base d6a82ce:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla (bot) commented Aug 1, 2024

CLA Signed


The committers listed above are authorized under a signed CLA.

@ezyang requested a review from ngimel August 2, 2024 12:15
@ezyang (Contributor) commented Aug 2, 2024

This might be correct, but it's actually kind of hard to tell.

@Aidyn-A (Collaborator) commented Aug 2, 2024

I second Edward; this change looks both correct and incorrect at the same time.
A corner case would be: a user sets device 1 first, but then decides to use a device guard targeting device 0. In this case, device 0 will never be set, and as a consequence ops will continue to run on device 1 (see the sketch below).
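To make that scenario concrete, here is a hypothetical libtorch snippet (a sketch of the concern only, assuming the c10::cuda::CUDAGuard API; it is not code from this PR):

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>

int main() {
  c10::cuda::CUDAGuard outer(1);        // the user works on cuda:1 first;
  auto x = at::ones({8}, at::kCUDA);    // this creates a primary context on device 1

  {
    c10::cuda::CUDAGuard inner(0);      // later, a guard targets cuda:0, which
                                        // has no primary context yet
    // If setDevice() skips cudaSetDevice because cuda:0 has no context, the
    // current device silently stays 1, and work intended for cuda:0 would run
    // on cuda:1 instead.
    auto y = at::empty({8}, at::kCUDA);
  }
  return 0;
}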

@baoleai (Author) commented Aug 5, 2024

Thanks @ezyang @Aidyn-A. Is there a better solution for this memory issue with torch_xla? Can we add a check in set_device, or should we avoid setting the CUDA device at all when using torch_xla?

@baoleai (Author) commented Aug 5, 2024

I second Edward; this change looks both correct and incorrect at the same time. A corner case would be: a user sets device 1 first, but then decides to use a device guard targeting device 0. In this case, device 0 will never be set, and as a consequence ops will continue to run on device 1.

What kind of user code would cause the device in autograd to differ from the previously set device?

@ezyang self-requested a review August 5, 2024 20:42
@ezyang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 5, 2024
@ezyang (Contributor) commented Aug 8, 2024

I think the proper way to go about this would be to find out where XLA is accidentally setting device to zero, and to make it stop doing that.

@baoleai (Author) commented Aug 8, 2024

I think the proper way to go about this would be to find out where XLA is accidentally setting device to zero, and to make it stop doing that.

When using XLA:GPU, there are two registered device guard implementations: one for CUDA and one for XLA. In the example above, the CUDA device count is 4 (https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L1508), which leads to set_device being called for device indices 0 through 3. I believe that as long as a CUDA device guard implementation is registered while running XLA:GPU, the CUDA device count will be greater than 1, which will trigger set_device(0). So we should either avoid registering the CUDA device guard when using XLA:GPU (I'm not sure whether that is feasible), or add a check at https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L844 to skip the set_device(0) call for the CUDA implementation when running XLA:GPU. A simplified sketch of this code path follows below.
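For reference, the relevant engine code path can be paraphrased roughly as follows. This is a simplified sketch of torch::autograd::set_device in engine.cpp, not a verbatim copy; the real function includes additional bookkeeping (for example, handling of the CPU pseudo-device):

#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/irange.h>

// Every registered device guard implementation whose device count exceeds
// `device` gets its setDevice() called. With XLA:GPU both the XLA and the CUDA
// implementations are registered, so each backward device thread also calls
// the CUDA setDevice(), which on CUDA 12 allocates a primary context.
static void set_device_sketch(int device) {
  for (const auto i : c10::irange(static_cast<size_t>(
           c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES))) {
    auto* impl = c10::impl::device_guard_impl_registry[i].load();
    if (impl && device < impl->deviceCount()) {
      impl->setDevice(c10::Device(
          static_cast<c10::DeviceType>(i),
          static_cast<c10::DeviceIndex>(device)));
    }
  }
}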

@ezyang (Contributor) commented Aug 9, 2024

I think it's not supposed to actually call set_device on the backward thread unless you actually queue some work up on that device.

@baoleai (Author) commented Aug 9, 2024

I think it's not supposed to actually call set_device on the backward thread unless you actually queue some work up on that device.

Maybe I didn't fully understand your meaning. When the device threads are initialized (https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L1476), the CUDA device_count > 1 leads to a call to the CUDA set_device(0). This happens at device-thread initialization, before any work is queued on that device (see the simplified sketch below).
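As a stand-alone illustration of that claim (not the actual engine code; set_device_stub is a placeholder for torch::autograd::set_device), the device threads call set_device eagerly at startup:

#include <cstdio>
#include <thread>
#include <vector>

// Placeholder for torch::autograd::set_device; here it only records the call.
static void set_device_stub(int device) {
  std::printf("device thread %d: set_device(%d)\n", device, device);
}

// Simplified view of Engine::start_device_threads: one thread per device index
// (the max deviceCount() over all registered implementations, 4 here), and each
// thread sets its device during initialization, before any backward work is
// queued on it.
static void start_device_threads_sketch(int num_devices) {
  std::vector<std::thread> threads;
  for (int i = 0; i < num_devices; ++i) {
    threads.emplace_back([i] {
      set_device_stub(i);  // happens eagerly in thread_init
      // ... the real thread then blocks on its ready queue for backward work ...
    });
  }
  for (auto& t : threads) {
    t.join();  // the real engine keeps these threads alive instead
  }
}

int main() {
  start_device_threads_sketch(4);  // e.g. 4 visible CUDA devices
  return 0;
}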

@ezyang (Contributor) commented Aug 9, 2024

If you are able to recompile PyTorch, here is something instructive. Make set_device raise an error if the device index is zero. Then, initialize XLA:GPU on only GPU 4 with TORCH_SHOW_CPP_STACKTRACES=1. Report the backtrace here.

@baoleai (Author) commented Aug 16, 2024

If you are able to recompile PyTorch, here is something instructive. Make set_device raise an error if the device index is zero. Then, initialize XLA:GPU on only GPU 4 with TORCH_SHOW_CPP_STACKTRACES=1. Report the backtrace here.

I added

// Note: i == 1 corresponds to c10::DeviceType::CUDA in the device-type loop.
if (i == 1 && device == 0) {
  throw std::runtime_error("Wrong device 0!");
}

after https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L842, and then tested test.py with the following command:

TORCH_SHOW_CPP_STACKTRACES=1 CUDA_VISIBLE_DEVICES=3 PJRT_DEVICE=CUDA python test.py

test.py :

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
a = torch.ones(512, 512)
b = torch.nn.Linear(512, 512)
a = a.to(device)
b = b.to(device)
c = b(a).sum()
c.backward()
xm.mark_step()

Here is the gdb stack:

#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1  0x00007f2465948859 in __GI_abort () at abort.c:79
#2  0x00007f24649c4ee6 in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#3  0x00007f24649d6f8c in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#4  0x00007f24649d6ff7 in std::terminate() () from /lib/x86_64-linux-gnu/libstdc++.so.6
#5  0x00007f24649d7258 in __cxa_throw () from /lib/x86_64-linux-gnu/libstdc++.so.6
#6  0x00007f2457eb275c in torch::autograd::set_device(int) [clone .cold] () from /home/baole.abl/baole/github/pytorch/torch/lib/libtorch_cpu.so
#7  0x00007f245bbe9165 in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) () from /home/baole.abl/baole/github/pytorch/torch/lib/libtorch_cpu.so
#8  0x00007f245bbe02b6 in torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/baole.abl/baole/github/pytorch/torch/lib/libtorch_cpu.so
#9  0x00007f2463e00975 in torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/baole.abl/baole/github/pytorch/torch/lib/libtorch_python.so
#10 0x00007f2464a06793 in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#11 0x00007f246590b609 in start_thread (arg=<optimized out>) at pthread_create.c:477
#12 0x00007f2465a45353 in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

It's important to note that no context will be initialized on GPU 0: because CUDA_VISIBLE_DEVICES=3, device index 0 in engine.cpp corresponds to physical GPU 3.

In a distributed setting with 4 GPUs, like in the test mentioned above, I suspect the issue arises because CUDA_VISIBLE_DEVICES is not set for each individual rank. This can lead to the invocation of set_device(0), which may cause excessive memory allocation on GPU 0. However, it may not be feasible to set CUDA_VISIBLE_DEVICES=LOCAL_RANK for each rank.

@ezyang (Contributor) commented Aug 19, 2024

Err, you didn't run the right test? If you want to find out why it's allocating in the distributed setting, you should keep this assert and run the distributed workload, and THAT will tell you who is calling set_device(0). What you're seeing here is 100% expected, for the reasons you described.

@baoleai (Author) commented Aug 20, 2024

Actually, the error stack in the distributed case is exactly the same as the one I posted above. Because the CUDA device count is 4, thread_init runs for device indices 0 through 3, ultimately triggering set_device(0).

This logic seems correct when using PyTorch without XLA, so the issue likely arises from XLA. I need to investigate further.

@baoleai closed this Aug 20, 2024