[CUDA12] use MaybeSetDevice in cuda device guard setDevice #132398
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132398. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 5488512 with merge base d6a82ce. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
This might be correct, but it's actually kind of hard to tell.
I second Edward: this change looks both correct and incorrect at the same time.
Thanks @ezyang @Aidyn-A. So is there a better solution for this memory issue with torch_xla? Can we add a check in set_device? Or should we not set the CUDA device at all when using torch_xla?
What kind of user code would cause the device in autograd to differ from the previously set device?
I think the proper way to go about this would be to find out where XLA is accidentally setting the device to zero, and to make it stop doing that.
When using XLA:GPU, there are two registered device guard implementations: one for CUDA and one for XLA. In the example above, the CUDA device count is 4 (https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L1508), which leads to set_device being called for devices 0 through 3. I believe that as long as a CUDA device guard implementation is registered while running XLA:GPU, the CUDA device count will be greater than 1, which triggers set_device(0). So we should either avoid registering the CUDA device guard when using XLA:GPU (I'm not sure that is feasible), or add a check at https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L844 that skips the set_device(0) call for the CUDA implementation when running XLA:GPU.
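To make that mechanism concrete, here is a small, self-contained toy model of the startup logic being described (the class names are hypothetical and this is not the actual engine.cpp code): one worker thread is spawned per device index, and each thread calls setDevice on every registered guard implementation, so the CUDA guard receives setDevice(0) even when only XLA devices are in use.

```cpp
// Toy model of the autograd engine's device-thread startup.
// Hypothetical names; the real logic lives in torch/csrc/autograd/engine.cpp.
#include <algorithm>
#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

struct DeviceGuardImpl {
  virtual ~DeviceGuardImpl() = default;
  virtual int deviceCount() const = 0;
  virtual void setDevice(int index) = 0;
};

struct FakeCudaGuard : DeviceGuardImpl {
  int deviceCount() const override { return 4; }
  void setDevice(int index) override {
    // In real PyTorch this is where cudaSetDevice() could create a context.
    std::printf("CUDA guard: setDevice(%d)\n", index);
  }
};

struct FakeXlaGuard : DeviceGuardImpl {
  int deviceCount() const override { return 4; }
  void setDevice(int index) override {
    std::printf("XLA guard: setDevice(%d)\n", index);
  }
};

int main() {
  // Both guards are registered when XLA:GPU runs on a CUDA machine.
  std::vector<std::unique_ptr<DeviceGuardImpl>> registry;
  registry.push_back(std::make_unique<FakeCudaGuard>());
  registry.push_back(std::make_unique<FakeXlaGuard>());

  // num_devices is the maximum device count over all registered guards.
  int num_devices = 0;
  for (const auto& impl : registry) {
    num_devices = std::max(num_devices, impl->deviceCount());
  }

  // One worker thread per device index; each thread pins itself by calling
  // setDevice on *every* registered guard, so the CUDA guard also sees
  // setDevice(0) even if the user only cares about XLA devices.
  std::vector<std::thread> workers;
  for (int device = 0; device < num_devices; ++device) {
    workers.emplace_back([&registry, device] {
      for (const auto& impl : registry) {
        impl->setDevice(device);
      }
    });
  }
  for (auto& t : workers) {
    t.join();
  }
  return 0;
}
```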
I think it's not supposed to actually call set_device on the backward thread unless you actually queue some work up on that device.
Maybe I didn't fully understand your point. When the device threads are initialized (https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L1476), the CUDA device_count being greater than 1 leads to a call to the CUDA set_device(0). This happens during device thread initialization.
If you are able to recompile PyTorch, here is something instructive. Make set_device raise an error if the device index is zero. Then, initialize XLA:GPU on only GPU 4 with TORCH_SHOW_CPP_STACKTRACES=1. Report the backtrace here.
I added a check that raises an error on set_device(0) after https://github.com/pytorch/pytorch/blob/main/torch/csrc/autograd/engine.cpp#L842.
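The check was presumably along these lines (a hypothetical reconstruction based on the suggestion above, not the exact snippet that was added):

```cpp
// Hypothetical reconstruction, not the exact snippet: inside the autograd
// engine's set_device() in engine.cpp, fail loudly (with a C++ stack trace
// when TORCH_SHOW_CPP_STACKTRACES=1) whenever an engine worker thread
// tries to switch to device index 0.
TORCH_CHECK(device != 0, "autograd engine called set_device(0)");
```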
test.py:
Here is the gdb stack:
It's important to note that no context will be initialized on GPU 0, as device 0 in engine.cpp corresponds to GPU 3. In a distributed setting with 4 GPUs, like the test mentioned above, I suspect the issue arises because CUDA_VISIBLE_DEVICES is not set for each individual rank. This can lead to the invocation of set_device(0), which may cause excessive memory allocation on GPU 0. However, it may not be feasible to set CUDA_VISIBLE_DEVICES=LOCAL_RANK for each rank.
Err, you didn't run the right test? Like, if you want to find out why it's allocating in the distributed setting, you should have this assert and try to run the distributed thing, and THAT will tell you who is calling set_device(0). What you're seeing here is 100% expected, for the reasons you described.
Actually, the error stack in the distributed case is exactly the same as the one I posted above. Currently, we can see that because the CUDA device count is 4, thread_init runs for GPUs 0 to 3, ultimately triggering set_device(0). This logic seems correct when using PyTorch without XLA, so the issue likely arises from XLA. I need to investigate further.
This PR uses the `MaybeSetDevice` function (#94864) in the CUDA device guard's `setDevice` to address unnecessary context memory allocation on cuda:0 when using other devices, such as XLA:GPU. It resolves the issue described in pytorch/xla#6208 (comment).

Before this PR, when testing pytorch/xla with a 4-GPU distributed run, there was unnecessary memory allocation on GPU#0 in every process.

After this PR, the CUDA device guard's `setDevice` no longer calls `cudaSetDevice` if a primary context does not exist on cuda:0, and the unnecessary allocation on GPU#0 no longer occurs.

Test environment: CUDA 12 + pytorch main branch + pytorch/xla v2.3.0.
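The shape of the change is roughly the following (an illustrative sketch, not the exact diff; it assumes the `c10::cuda::MaybeSetDevice` helper from #94864, which only switches the current device when the target device already has a primary context and otherwise just records the target index for lazy initialization):

```cpp
// c10/cuda/impl/CUDAGuardImpl.h (illustrative sketch, not the exact diff)
void setDevice(c10::Device d) const override {
  TORCH_INTERNAL_ASSERT(d.is_cuda());
  // Before: unconditionally switched the current device, which eagerly
  // created a CUDA context (and its memory overhead) on that device:
  //   C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
  // After: only call cudaSetDevice if a primary context already exists
  // on d.index(); otherwise just remember the target device.
  C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(d.index()));
}
```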
cc @Aidyn-A, @vanbasten23, @ezyang