davidberard98 (Contributor) commented Mar 10, 2023

Stack from ghstack:

Summary: profiler.record_function inserts an event into the chrome trace generated by the PyTorch profiler. This PR adds record_function to every function annotated with @dynamo_timed.

dynamo_timed and the CLI viewer torch._dynamo.utils.compile_times() are already useful on their own, but for identifying when these functions get called, it's nice to be able to view them in the profiler chrome trace.
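As a rough illustration of the idea (not the actual dynamo_timed implementation), a timing decorator can open a profiler range around each call while still accumulating wall-clock durations. The names `timed_with_profiler_range`, `phase_times`, and `fake_compile_step` are hypothetical, and the no-op fallback when torch is missing is an assumption made only so the sketch runs anywhere.

```python
import contextlib
import functools
import time
from collections import defaultdict

try:
    # Real profiler range: appears as an event in the exported chrome trace
    # when a torch.profiler.profile() context is active.
    from torch.profiler import record_function
except ImportError:
    # Assumption for this sketch only: fall back to a no-op without torch.
    def record_function(name):
        return contextlib.nullcontext()

# Hypothetical store: function name -> list of wall-clock durations (seconds).
phase_times = defaultdict(list)


def timed_with_profiler_range(fn):
    """Time each call to fn and wrap the call in a named profiler range."""
    key = fn.__name__

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with record_function(f"{key} (timed)"):
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                # Record the duration even if fn raises.
                phase_times[key].append(time.perf_counter() - start)

    return wrapper


@timed_with_profiler_range
def fake_compile_step(n):
    # Stand-in for a compilation phase.
    return sum(i * i for i in range(n))


result = fake_compile_step(1000)
```

The accumulated durations answer "how long did each phase take in total", while the profiler range answers "when did it run relative to everything else in the trace".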

Why not just turn on Python stack traces in the profiler to get this information? Dynamo compilation is implemented in Python, so with Python stack tracing enabled the profiler records a huge number of events during compilation. The resulting trace files are often too large to load in chrome://tracing, and they take a long time to generate. Additionally, the stack traces are deep enough that they are often hard to read. The record_function approach produces much more readable traces with lower overhead.

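Tests:

- Added in test/dynamo/test_profiler.py. Verified in https://github.com/pytorch/pytorch/actions/runs/4559322864/jobs/8043307798?pr=96495 that the tests are actually running.
- A performance run with ciflow/inductor-perf-compare shows no noticeable change in compilation time or speedup numbers. Geomean speedup changes from 1.275 -> 1.277, and geomean compilation time changes from 54.2s -> 53.8s, which is likely just noise. No individual benchmark number regressed by more than 5% between the two runs, and we see improvements of around the same magnitude, suggesting this is, again, just noise. Meta employees can see the results in this Google Sheet: https://docs.google.com/spreadsheets/d/1Ki69XvcgxcA3ZnqC5n_jav5KiD4u7Wojlad3VTnIdlk/edit?usp=sharing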

Example:

Run this:

```python
import torch

def gn(x):
    return x.sin().cos()

def fn(x, y):
    return x.sin() * y.cos()

x, y = [torch.rand((2, 2), device='cuda') for _ in range(2)]

# just to clear out any lazy initialization
with torch.profiler.profile() as prof:
    torch.compile(gn)(x)

with torch.profiler.profile() as prof:
    torch.compile(fn)(x, y)

prof.export_chrome_trace("./dynamo_timed_profile.json")
```

and we can see that the resulting trace shows important dynamo steps, even when python tracing is turned off.
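To check which events landed in the exported file without opening chrome://tracing, the JSON can be scanned directly. The following is a minimal sketch that assumes the export is a JSON object with a `traceEvents` list whose complete events carry `"ph": "X"` and a `dur` in microseconds. The helper names, the synthetic trace, and the `"dynamo_timed"` filter substring are illustrative assumptions, so check the event names in your own trace before filtering.

```python
import json
from collections import Counter


def load_trace(path):
    """Load an exported chrome trace, e.g. ./dynamo_timed_profile.json."""
    with open(path) as f:
        return json.load(f)


def summarize_trace(trace, name_filter=None):
    """Count complete ("ph" == "X") events by name and total their durations."""
    counts = Counter()
    total_us = Counter()
    for event in trace.get("traceEvents", []):
        if event.get("ph") != "X":
            continue  # skip metadata and other non-duration events
        name = event.get("name", "")
        if name_filter is not None and name_filter not in name:
            continue
        counts[name] += 1
        total_us[name] += event.get("dur", 0)
    return counts, total_us


# Synthetic stand-in for a real export; the event names are hypothetical.
example_trace = {
    "traceEvents": [
        {"ph": "X", "name": "hypothetical_step (dynamo_timed)", "ts": 0, "dur": 1500},
        {"ph": "X", "name": "aten::sin", "ts": 1600, "dur": 20},
        {"ph": "X", "name": "hypothetical_step (dynamo_timed)", "ts": 1700, "dur": 500},
        {"ph": "M", "name": "process_name"},
    ]
}

counts, total_us = summarize_trace(example_trace, name_filter="dynamo_timed")
```

For a real run, replace `example_trace` with `load_trace("./dynamo_timed_profile.json")`.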

[Image: Screenshot 2023-03-29 at 7 26 15 PM]

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

profiler.record_function inserts an event into the chrome trace generated by the pytorch profiler. This PR adds record_function everywhere that @dynamo_timed is annotated.

dynamo_timed and the CLI viewer torch._dynamo.utils.compile_times() are already useful on their own, but for identifying _when_ these functions get called, it's nice to be able to view them in the profiler chrome trace.

TODO:
* add tests
* add screenshots
* run benchmarks to make sure this doesn't slow anything down.

[ghstack-poisoned]
pytorch-bot bot commented Mar 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96495

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 93c101c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

davidberard98 added a commit that referenced this pull request Mar 10, 2023
ghstack-source-id: 793ab8a
Pull Request resolved: #96495
@davidberard98
Contributor Author

@pytorchbot rebase -s

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gh/davidberard98/175/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/96495)

davidberard98 added a commit that referenced this pull request Mar 28, 2023
ghstack-source-id: 7174f15
Pull Request resolved: #96495
davidberard98 added a commit that referenced this pull request Mar 30, 2023
ghstack-source-id: 77881f5
Pull Request resolved: #96495
…unctions"


**Summary**: profiler.record_function inserts an event into the chrome trace generated by the PyTorch profiler. This PR adds record_function to every function annotated with dynamo_timed.

dynamo_timed and the CLI viewer torch._dynamo.utils.compile_times() are already useful on their own, but for identifying _when_ these functions get called, it's nice to be able to view them in the profiler chrome trace.

Why not just turn on Python stack traces in the profiler to get this information? Dynamo compilation is implemented in Python, so with Python stack tracing enabled the profiler records a huge number of events during compilation. The resulting trace files are often too large to load in chrome://tracing, and they take a long time to generate. Additionally, the stack traces are deep enough that they are often hard to read.

**Tests**:
- Added in test/dynamo/test_profiler.py. Verified in https://github.com/pytorch/pytorch/actions/runs/4559322864/jobs/8043307798?pr=96495 that the tests are actually running.
- A performance run with `ciflow/inductor-perf-compare` shows no noticeable change in compilation time or speedup numbers. Geomean speedup changes from 1.275 -> 1.277, and geomean compilation time changes from 54.2s -> 53.8s, which is likely just noise. No individual benchmark number regressed by more than 5% between the two runs, and we see improvements of around the same magnitude, suggesting this is, again, just noise. Meta employees can see the results in this Google Sheet: https://docs.google.com/spreadsheets/d/1Ki69XvcgxcA3ZnqC5n_jav5KiD4u7Wojlad3VTnIdlk/edit?usp=sharing
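
For a sense of scale, the reported geomean deltas can be turned into relative changes with a few lines of arithmetic. This sketch uses only the four geomean figures quoted above. The per-benchmark list `hypothetical_speedups` is made up purely to show how a geometric mean is computed and is not data from the benchmark run.

```python
import math


def geomean(values):
    """Geometric mean: exp of the arithmetic mean of the logs."""
    return math.exp(sum(math.log(v) for v in values) / len(values))


def relative_change(before, after):
    """Signed fractional change from before to after."""
    return (after - before) / before


# Figures quoted in the PR description.
speedup_change = relative_change(1.275, 1.277)
compile_time_change = relative_change(54.2, 53.8)

# Hypothetical per-benchmark speedups, only to illustrate geomean().
hypothetical_speedups = [1.0, 2.0, 4.0]
hypothetical_geomean = geomean(hypothetical_speedups)

print(f"geomean speedup change: {speedup_change:+.2%}")
print(f"geomean compile time change: {compile_time_change:+.2%}")
```

Both changes come out well under 1% in magnitude, far smaller than the 5% swings seen on individual benchmarks, which is consistent with the conclusion that the differences are noise.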

**Example**:

Run this:

```python
import torch

def gn(x):
    return x.sin().cos()

def fn(x, y):
    return x.sin() * y.cos()

x, y = [torch.rand((2, 2), device='cuda') for _ in range(2)]

# just to clear out any lazy initialization
with torch.profiler.profile() as prof:
    torch.compile(gn)(x)

with torch.profiler.profile() as prof:
    torch.compile(fn)(x, y)

prof.export_chrome_trace("./dynamo_timed_profile.json")
```

and we can see that the resulting trace shows important dynamo steps, even when python tracing is turned off.

<img width="867" alt="Screenshot 2023-03-29 at 7 26 15 PM" src="https://user-images.githubusercontent.com/5067123/228712263-8ae67ab9-1a52-4765-a9c2-7c5cf0abe2f5.png">

cc soumith voznesenskym penguinwu anijain2305 @EikanWang jgong5 @Guobing-Chen @XiaobingSuper zhuhaozhe blzheng @Xia-Weiwen wenzhe-nrv jiayisunx desertfire

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Mar 30, 2023
ghstack-source-id: 7743051
Pull Request resolved: #96495
@davidberard98 davidberard98 added the ciflow/trunk label Mar 30, 2023
@davidberard98 davidberard98 marked this pull request as ready for review March 30, 2023 17:27
@davidberard98 davidberard98 requested a review from ngimel March 30, 2023 17:30
@davidberard98 davidberard98 changed the title [WIP][dynamo] profiler.record_function on all dynamo_timed functions [dynamo] profiler.record_function on all dynamo_timed functions Mar 30, 2023
@davidberard98
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@davidberard98 davidberard98 added the release notes: profiler label Mar 30, 2023
@davidberard98
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/davidberard98/175/head branch June 8, 2023 16:02
Labels
ciflow/inductor, ciflow/trunk, Merged, module: dynamo, release notes: profiler
4 participants