[Pytorch] General broadcast for arithmetic operators #104718
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104718
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit 05f3818:
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D46874508
…ize arithmetic operators (pytorch#104718)

Summary:
Pull Request resolved: pytorch#104718

## This diff
1. Templatizes the arithmetic operators
2. Adds general broadcasting to arithmetic operators

Follow on diff:
* Templatize the remaining arithmetic operators (scalar, in-place)
* Rename Arithmetic.cpp --> BinaryOps.cpp

## Templatizing arithmetic ops
Create a template so that `add`, `sub`, `mul`, `div` can be generated from one shader. See Stephen's comment on v2.

Note that there is a special case for `div`, where we account for division by 0.

## Adding general broadcasting to arithmetic ops
Currently, broadcast is supported for 4D tensors where, if the batch or channel dimensions are not equal, then the batch and channel of one tensor must both be 1, i.e.:
```
tensorA NCHW: 5, 2, 3, 3
tensorB NCHW: 1, 1, 3, 3 --> batch=1, channel=1
```
This diff adds broadcast support for 4D tensors where the batch and channel of a tensor are different, i.e.:
```
tensorA NCHW: 5, 1, 3, 3
tensorB NCHW: 1, 5, 3, 3
```
Broadcast rules (for each dimension `x`, one of the following must hold):
```
- tensorA.dim()[x] == tensorB.dim()[x]
- tensorA.dim()[x] == 1 || tensorB.dim()[x] == 1
- tensorA.dim()[x] does not exist || tensorB.dim()[x] does not exist
```
Broadcast method:
1. Pass `output`, `input` and `other` tensors to the shader
2. Iterate through the output texture to calculate the value of each texel (no repeating)
3. Mapping NHW positions: use modulo
4. Mapping C position: divide pos.z by ceil(C/4) to map to original tensor range

---
Also some test refactoring to reduce repeated setup code.
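The broadcast rules and the modulo mapping described above can be illustrated with a small CPU-side reference. This is only a sketch in plain Python, not the Vulkan shader from this diff: the function names (`broadcast_shape`, `source_index`, `broadcast_add`) are hypothetical, tensors are modeled as flat row-major lists, and the texel packing of the channel dimension (the `ceil(C/4)` step) is deliberately ignored.

```python
from itertools import product


def broadcast_shape(shape_a, shape_b):
    """Return the broadcast output shape, or raise ValueError if incompatible."""
    ndim = max(len(shape_a), len(shape_b))
    # A dimension that "does not exist" behaves like a leading dimension of size 1.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    out = []
    for da, db in zip(a, b):
        # Sizes must be equal, or one of them must be 1.
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
        out.append(db if da == 1 else da)
    return tuple(out)


def source_index(out_index, src_shape):
    """Map an output index to the source index it reads from, using modulo."""
    offset = len(out_index) - len(src_shape)
    # i % 1 == 0 for broadcast dimensions; i % size == i for full-size ones.
    return tuple(i % s for i, s in zip(out_index[offset:], src_shape))


def flat_offset(index, shape):
    """Row-major offset of a multi-dimensional index."""
    offset = 0
    for i, s in zip(index, shape):
        offset = offset * s + i
    return offset


def broadcast_add(a, shape_a, b, shape_b):
    """Add two flat row-major buffers, visiting each output element once."""
    out_shape = broadcast_shape(shape_a, shape_b)
    out = []
    for idx in product(*(range(s) for s in out_shape)):
        ia = flat_offset(source_index(idx, shape_a), shape_a)
        ib = flat_offset(source_index(idx, shape_b), shape_b)
        out.append(a[ia] + b[ib])
    return out_shape, out
```

With the shapes from the example above, `broadcast_shape((5, 1, 3, 3), (1, 5, 3, 3))` yields `(5, 5, 3, 3)`. As in the described method, the loop runs over the output and reads each operand through the mapped index, so neither input is materialized at the output size.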
Test Plan:
New tests:

Add
```
[ RUN      ] VulkanAPITest.add_broadcast5
[       OK ] VulkanAPITest.add_broadcast5 (0 ms)
[ RUN      ] VulkanAPITest.add_broadcast6
[       OK ] VulkanAPITest.add_broadcast6 (0 ms)
```
Sub
```
[ RUN      ] VulkanAPITest.sub_broadcast5
[       OK ] VulkanAPITest.sub_broadcast5 (0 ms)
[ RUN      ] VulkanAPITest.sub_broadcast6
[       OK ] VulkanAPITest.sub_broadcast6 (0 ms)
```
Mul
```
[ RUN      ] VulkanAPITest.mul_broadcast5
[       OK ] VulkanAPITest.mul_broadcast5 (1 ms)
[ RUN      ] VulkanAPITest.mul_broadcast6
[       OK ] VulkanAPITest.mul_broadcast6 (1 ms)
```
Div
```
[ RUN      ] VulkanAPITest.div_broadcast5
[       OK ] VulkanAPITest.div_broadcast5 (1 ms)
[ RUN      ] VulkanAPITest.div_broadcast6
[       OK ] VulkanAPITest.div_broadcast6 (2 ms)
```

All tests:
https://www.internalfb.com/phabricator/paste/view/P781794761
```
xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp:6377: Skipped
QueryPool is not available
[  SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms)
[----------] 307 tests from VulkanAPITest (5576 ms total)

[----------] Global test environment tear-down
[==========] 307 tests from 1 test suite ran. (5576 ms total)
[  PASSED  ] 306 tests.
[  SKIPPED ] 1 test, listed below:
[  SKIPPED ] VulkanAPITest.querypool_flushed_shader_log

  YOU HAVE 5 DISABLED TESTS
```

Test Vulkan Delegate on OD:
```
buck2 test 'fbcode//mode/dev' fbcode//executorch/backends/vulkan/test:test_vulkan_delegate -- --exact 'executorch/backends/vulkan/test:test_vulkan_delegate - test_vulkan_backend_add (executorch.backends.vulkan.test.test_vulkan_delegate.TestBackends)'

Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Run clang-format on glsl files and Arithmetic.cpp

Reviewed By: SS-JIA

Differential Revision: D46874508

fbshipit-source-id: e2c0f4c4525c5d567c75c2e0a50065e00de24066
@pytorchbot merge
(Initiating merge automatically since Phabricator Diff has merged)

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @anijain2305