@abock abock (Contributor) commented Aug 25, 2023

This reworks the DORT backend factory function to support the `options` kwarg of `torch.compile`, and defines a concrete `OrtBackendOptions` type that can be used to configure the backend.
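The sketch below shows one way a backend factory might accept `options` in either form: `None`, a plain dictionary (the form used with `torch.compile(options={...})` in the demo), or a typed object. `BackendOptions` and `normalize_options` are hypothetical names. The fields `infer_execution_providers` and `use_aot_autograd` are assumed for illustration. This is not the `torch.onnx` implementation.

```python
import dataclasses
from typing import Any, Mapping, Optional, Tuple, Union


@dataclasses.dataclass(frozen=True)
class BackendOptions:
    """Hypothetical stand-in for OrtBackendOptions (only a few fields modeled)."""

    preferred_execution_providers: Optional[Tuple[str, ...]] = None
    infer_execution_providers: bool = True  # assumed field name
    use_aot_autograd: bool = True  # assumed field name


OptionsLike = Union[None, Mapping[str, Any], BackendOptions]


def normalize_options(options: OptionsLike) -> BackendOptions:
    """Accept None, a dict of option values, or an already-typed options object."""
    if options is None:
        return BackendOptions()
    if isinstance(options, BackendOptions):
        return options
    if isinstance(options, Mapping):
        known = {f.name for f in dataclasses.fields(BackendOptions)}
        unknown = set(options) - known
        if unknown:
            # Fail loudly on misspelled keys instead of silently ignoring them.
            raise ValueError(f"unknown backend options: {sorted(unknown)}")
        kwargs = dict(options)
        eps = kwargs.get("preferred_execution_providers")
        if eps is not None:
            # Store as a tuple so the frozen dataclass compares by value.
            kwargs["preferred_execution_providers"] = tuple(eps)
        return BackendOptions(**kwargs)
    raise TypeError(f"unsupported options type: {type(options).__name__}")
```

Rejecting unknown keys is a choice made for this sketch. A dictionary and its typed equivalent normalize to equal objects.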

Backend instances are also cached, so a backend created with equal options is reused.

Wrapping the backend in `aot_autograd` also becomes an option (enabled by default). This allows `OrtBackend` to always be returned as the callable for `torch.compile`; the wrapping happens internally when opted into.
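The pattern of handing `torch.compile` a single callable while hiding the optional wrapping inside it might look like this. `autograd_wrapper` is a hypothetical placeholder for `aot_autograd`. It only tags its output and does not trace or partition any graph. `Backend` and its `use_aot_autograd` flag are likewise illustrative assumptions.

```python
from typing import Any, Callable, List, Tuple

CompileFn = Callable[[Any, List[Any]], Any]


def autograd_wrapper(fw_compiler: CompileFn) -> CompileFn:
    """Hypothetical stand-in for aot_autograd: wraps a forward compiler."""

    def wrapped(graph: Any, example_inputs: List[Any]) -> Tuple[str, Any]:
        # A real wrapper would trace forward/backward graphs first; this one
        # only records that the call went through the wrapper.
        return ("wrapped", fw_compiler(graph, example_inputs))

    return wrapped


class Backend:
    """The object handed to torch.compile is always this callable."""

    def __init__(self, use_aot_autograd: bool = True) -> None:
        # Decide once, at construction time, which entry point __call__ uses.
        self._entry: CompileFn = (
            autograd_wrapper(self.compile) if use_aot_autograd else self.compile
        )

    def compile(self, graph: Any, example_inputs: List[Any]) -> Tuple[str, Any]:
        # Placeholder for the actual ONNX export + ONNX Runtime compilation.
        return ("ort", graph)

    def __call__(self, graph: Any, example_inputs: List[Any]) -> Any:
        return self._entry(graph, example_inputs)
```

Callers never need to know whether wrapping is enabled, because the same object is returned in both cases.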

Lastly, this exposes options for configuring the preferred execution providers (EPs), which are attempted first; whether to try inferring an ORT EP from a torch device found in the graph or its inputs; and the default (fallback) EPs.
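The EP ordering described above can be sketched as a merge of three lists with duplicates removed: preferred EPs first, then EPs inferred from devices, then defaults. The device-to-EP mapping, the `["CPUExecutionProvider"]` fallback, and the function name `select_execution_providers` are assumptions. The sketch also does not check which EPs are actually installed.

```python
from typing import Iterable, List, Optional, Sequence

# Assumed device-type -> EP mapping, for illustration only.
DEVICE_TO_EP = {
    "cuda": "CUDAExecutionProvider",
    "cpu": "CPUExecutionProvider",
}


def select_execution_providers(
    preferred: Optional[Sequence[str]] = None,
    infer_from_devices: bool = True,
    devices: Iterable[str] = (),
    default: Optional[Sequence[str]] = None,
) -> List[str]:
    """Order EPs as: preferred, then inferred from devices, then defaults."""
    inferred = (
        [DEVICE_TO_EP[d] for d in devices if d in DEVICE_TO_EP]
        if infer_from_devices
        else []
    )
    fallback = list(default) if default is not None else ["CPUExecutionProvider"]
    selected: List[str] = []
    for ep in [*(preferred or []), *inferred, *fallback]:
        if ep not in selected:  # keep the first occurrence, drop duplicates
            selected.append(ep)
    return selected
```

With the demo's preferred list and a CUDA device, this yields `["NotARealEP", "CPUExecutionProvider", "CUDAExecutionProvider"]`. Skipping unavailable providers such as `"NotARealEP"` is left to the runtime in this sketch.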

Demo

The following demo runs `Gelu` through `torch.compile(backend="onnxrt")`, passing backend options once as a dictionary and once in the strongly typed form. It also exports the model with both the ONNX TorchScript exporter and the new TorchDynamo exporter.

```python
import math

import onnx.inliner
import onnxruntime
import torch
import torch.onnx

torch.manual_seed(0)


class Gelu(torch.nn.Module):
    # Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    def forward(self, x):
        return x * 0.5 * (torch.erf(math.sqrt(0.5) * x) + 1.0)


# Backend options passed as a plain dictionary. "NotARealEP" is deliberately
# bogus; ONNX Runtime is expected to warn about it and use the next provider.
@torch.compile(
    backend="onnxrt",
    options={
        "preferred_execution_providers": [
            "NotARealEP",
            "CPUExecutionProvider",
        ],
        "export_options": torch.onnx.ExportOptions(dynamic_shapes=True),
    },
)
def dort_gelu(x):
    return Gelu()(x)


# Custom ONNX Runtime session options (severity 0 = verbose logging).
ort_session_options = onnxruntime.SessionOptions()
ort_session_options.log_severity_level = 0

# The same backend options in the strongly typed form, plus the session options.
dort_gelu2 = torch.compile(
    Gelu(),
    backend="onnxrt",
    options=torch.onnx._OrtBackendOptions(
        preferred_execution_providers=[
            "NotARealEP",
            "CPUExecutionProvider",
        ],
        export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
        ort_session_options=ort_session_options,
    ),
)

x = torch.randn(10)

# Export with the TorchScript-based exporter.
torch.onnx.export(Gelu(), (x,), "gelu_ts.onnx")

# Export with the TorchDynamo-based exporter, then inline local functions.
export_output = torch.onnx.dynamo_export(Gelu(), x)
export_output.save("gelu_dynamo.onnx")
inlined_model = onnx.inliner.inline_local_functions(export_output.model_proto)
onnx.save_model(inlined_model, "gelu_dynamo_inlined.onnx")

print("Torch Eager:")
print(Gelu()(x))

print("DORT:")
print(dort_gelu(x))
print(dort_gelu2(x))
```

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov

@pytorch-bot pytorch-bot bot commented Aug 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107973

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 23fc2a8 with merge base d4a9963:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Aug 25, 2023
@abock abock added this to the 2.1.0 milestone Aug 25, 2023
@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 25, 2023
@abock abock added the module: onnx Related to torch.onnx label Aug 25, 2023
@abock abock force-pushed the abock/dort-torch-compile-options branch 2 times, most recently from 6e7bc42 to 555c074 August 25, 2023 20:53

@BowenBao BowenBao (Collaborator) left a comment

LG w/ comments

@justinchuby justinchuby changed the title [ONNX] Support torch.compile(backend="onnxrt", options=OrtBackendOptions(...)) to [ONNX] Support `torch.compile(backend="onnxrt", options=OrtBackendOptions(...))` Aug 25, 2023
@BowenBao BowenBao (Collaborator) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 26, 2023
@pytorchmergebot pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here

@pytorchmergebot pytorchmergebot (Collaborator):

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Failing merge rule: Core Maintainers

@BowenBao BowenBao (Collaborator) commented:

@pytorchbot merge

@pytorchmergebot pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-12-py3-arm64 / test (default, 3, 3, macos-m1-12)

@abock abock force-pushed the abock/dort-torch-compile-options branch 2 times, most recently from d9afa40 to 026e677 August 26, 2023 14:29
@abock abock force-pushed the abock/dort-torch-compile-options branch from 026e677 to cfeb9ee August 26, 2023 14:38
@abock abock assigned abock and unassigned wschin Aug 26, 2023
@abock abock force-pushed the abock/dort-torch-compile-options branch from cfeb9ee to 23fc2a8 August 26, 2023 15:12
@abock abock (Contributor, Author) commented Aug 26, 2023

@pytorchbot merge

@pytorchmergebot pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

voznesenskym pushed a commit that referenced this pull request Aug 27, 2023
…ions(...))` (#107973)

Pull Request resolved: #107973
Approved by: https://github.com/BowenBao
@github-actions github-actions bot deleted the abock/dort-torch-compile-options branch February 25, 2025 02:07
Labels
ciflow/inductor
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: dynamo
module: onnx (Related to torch.onnx)
open source
release notes: onnx (torch.onnx related changes that should show up in the release notes)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects
Status: Done

6 participants