@XiaobingSuper XiaobingSuper commented May 11, 2023

Stack from ghstack (oldest at bottom):

Running the TIMM eca_halonext26ts model with dynamic shapes (python -m torch.backends.xeon.run_cpu --node_id 0 benchmarks/dynamo/timm_models.py --performance --float32 -dcpu --inference -n5 --inductor --dynamic-shapes --only eca_halonext26ts) hits an error in aten.expand when both the source size and the expanded size are symbolic:

Traceback (most recent call last):
  File "/home/xiaobing/pytorch-offical/torch/_dynamo/utils.py", line 1309, in run_node
    return node.target(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_decomp/decompositions.py", line 3270, in matmul
    tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
  File "/home/xiaobing/pytorch-offical/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1105, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_subclasses/fake_tensor.py", line 1269, in dispatch
    return decomposition_table[func](*args, **kwargs)
  File "/home/xiaobing/pytorch-offical/torch/_refs/__init__.py", line 2799, in expand
    check(
  File "/home/xiaobing/pytorch-offical/torch/_prims_common/__init__.py", line 1648, in check
    raise exc_type(s())
RuntimeError: expand: attempting to expand a dimension of length 8*s0!

The source size is 8*s0 and the expanded size is ((s0*(((s2 - 1)//16))**2 + 2*s0*(((s2 - 1)//16)) + s0)//(8*(((((s2 - 1)//16) + 1)//8))**2)).

This PR tries to fix it.

cc @ezyang

@pytorch-bot

pytorch-bot bot commented May 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101173

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures

As of commit 7b95b3f:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XiaobingSuper added a commit that referenced this pull request May 11, 2023
…ze are all symbolic size

ghstack-source-id: 848eb52
Pull Request resolved: #101173
@XiaobingSuper
Collaborator Author

Note: I could not create a minimal reproducer, even after following the steps at https://github.com/pytorch/pytorch/blob/main/docs/source/compile/troubleshooting.rst.

@XiaobingSuper XiaobingSuper changed the title from "inductor: fix the issue of aten.expand when the source and expaned size are all symbolic size" to "dynamo: fix the issue of aten.expand when the source and expaned size are all symbolic size" May 11, 2023
@XiaobingSuper XiaobingSuper requested review from ngimel and lezcano May 11, 2023 09:46
@@ -2791,11 +2793,12 @@ def expand(a: Tensor, *shape) -> Tensor:
    for idx, x in enumerate(a.shape):
        offset_idx = idx + offset
        requested_length = shape[offset_idx]
        requested_length = guard_int(requested_length)
        x = guard_int(x)
Collaborator Author

I'm not sure whether adding such a guard is OK. For the TIMM model, the root error comes from https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L3266, where matmul calls expand.

Contributor

No, it's not OK, unfortunately.

Collaborator Author

Do you have any suggestions?

Contributor

So, the error message looks legit: you can't expand a non-one dimension. But somehow, the dimension is 8*s0. You must figure out what caused the dimension to be this size; that is where the bug truly is.
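
To make the rule behind that error message concrete, here is a minimal pure-Python sketch of the per-dimension check that expand applies. It is a simplified illustration using concrete integers, not the code in torch/_refs/__init__.py, and the helper name check_expand is hypothetical.

```python
def check_expand(src_shape, requested_shape):
    """Return the expanded shape, or raise if a non-1 dimension would change.

    Simplified illustration with plain ints; the name is hypothetical and
    this is not the torch._refs implementation.
    """
    if len(requested_shape) < len(src_shape):
        raise RuntimeError("expand: requested shape has too few dimensions")
    # Source dims are aligned with the trailing dims of the requested shape.
    offset = len(requested_shape) - len(src_shape)
    out = list(requested_shape)
    for idx, x in enumerate(src_shape):
        requested_length = requested_shape[idx + offset]
        # -1 means "keep the source length for this dimension".
        if requested_length == -1:
            out[idx + offset] = x
            continue
        # A dimension may only change if its source length is 1.
        if x != requested_length and x != 1:
            raise RuntimeError(
                f"expand: attempting to expand a dimension of length {x}!"
            )
    return tuple(out)


# With s0 = 2 the leading source dimension is 8 * s0 = 16.
print(check_expand((16, 4, 16, 144), (16, 4, 16, 144)))  # sizes agree: accepted
print(check_expand((1, 4), (3, 4)))                      # length-1 dim: accepted
```

With symbolic sizes, the comparison between the source length and the requested length is evaluated on expressions rather than integers, so a wrongly decided equality triggers the same error even when the two sizes would be numerically equal.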

@lezcano lezcano requested review from ezyang and removed request for lezcano May 11, 2023 10:20
@XiaobingSuper XiaobingSuper marked this pull request as draft May 11, 2023 13:26
@XiaobingSuper
Collaborator Author

XiaobingSuper commented Jun 8, 2023

@ezyang, here is a smaller case that reproduces this issue:

import torch
import torch._dynamo
import torch.nn.functional as F

try:
    from torch import _assert
except ImportError:
    def _assert(condition: bool, message: str):
        assert condition, message

torch._dynamo.config.dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = False

class HaloAttn(torch.nn.Module):
    def __init__(
            self, dim=256, qkv_bias=False):
        super().__init__()
        self.num_heads = 8
        self.dim_head_qk = 16
        self.dim_head_v = 32
        self.dim_out_qk = 128
        self.dim_out_v = 256
        self.scale = 0.25
        self.scale_pos_embed = False
        self.block_size = self.block_size_ds = 8
        self.halo_size = 2
        self.win_size = 12  # neighbourhood window size
        self.block_stride = 1

        self.q = torch.nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
        self.kv = torch.nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)

    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H % self.block_size == 0, '')
        _assert(W % self.block_size == 0, '')

        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

        q = self.q(x)
        # unfold
        q = q.reshape(
            -1, self.dim_head_qk,
            num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
        # B, num_heads * dim_head * block_size ** 2, num_blocks
        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
        # B * num_heads, num_blocks, block_size ** 2, dim_head

        kv = self.kv(x)
        # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
        # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
        # FIXME figure out how to switch impl between this and conv2d if XLA being used.
        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
        kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
            B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v

        return q @ k.transpose(-1, -2)

x = torch.rand(128, 256, 16, 16)

model = HaloAttn().eval()
model = model.eval()
print(model)
opt_model = torch.compile(model, dynamic=True)
with torch.no_grad():
    for i in range(3):
        out = opt_model(x)

In this case, if I comment out _assert(H % self.block_size == 0, '') and _assert(W % self.block_size == 0, ''), it works.

@XiaobingSuper
Collaborator Author

I checked the expand input sizes: src is [8*s0, ((s2//8))**2, 16, 144] and dst is [((s0*s2**2)//(8*((s2//8))**2)), ((s2//8))**2, 16, 144]. The expand check 8*s0 == ((s0*s2**2)//(8*((s2//8))**2)) evaluates to true when the assert code is removed, but to false when the assert code is present, so the guard seems to have an issue.
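
As a sanity check on the two leading sizes quoted above, the following brute-force sketch compares them over concrete integers. The function names src_dim and dst_dim are only illustrative, and s2 starts at 8 so that s2 // 8 is never zero.

```python
def src_dim(s0):
    # Leading dimension of the expand source.
    return 8 * s0


def dst_dim(s0, s2):
    # Leading dimension of the requested (expanded) shape.
    return (s0 * s2**2) // (8 * (s2 // 8) ** 2)


# When s2 is a multiple of 8 (what the asserts in the repro require),
# the two sizes agree for every tested value.
mismatches = [
    (s0, s2)
    for s0 in range(1, 50)
    for s2 in range(8, 257, 8)
    if src_dim(s0) != dst_dim(s0, s2)
]
print(mismatches)

# Without divisibility the sizes can differ, e.g. s0 = 1, s2 = 12.
print(src_dim(1), dst_dim(1, 12))
```

So under the divisibility assumption the equality is genuinely true, which suggests that the failure with the asserts present comes from how the expression is evaluated rather than from the shapes themselves.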

@ngimel
Collaborator

ngimel commented Jun 8, 2023

Likely we'd have to create a guard that s2 % 8 == 0 and, if we don't have the necessary optimization for sympy expr, an optimization pass that'll simplify ((s0*s2**2)//(8*((s2//8))**2)) to 8 * s0. The latter part probably works if there's a guard on s2

@XiaobingSuper
Collaborator Author

Likely we'd have to create a guard that s2 % 8 == 0 and, if we don't have the necessary optimization for sympy expr, an optimization pass that'll simplify ((s0*s2**2)//(8*((s2//8))**2)) to 8 * s0. The latter part probably works if there's a guard on s2

Where should I add a guard for s2 % 8 == 0?

@ezyang
Contributor

ezyang commented Jun 10, 2023

@ysiraichi wondering if you could run TV on this and see if it's a sympy simplification bug

@ysiraichi
Collaborator

I tried running it with TV, but it failed before actually reaching the validation check.
Investigating it a bit further, I think the problem is somewhat related to a SymPy issue sympy/sympy#25146. Here's what I'm seeing:

With the asserts:

symbolic_shapes: [DEBUG] 0.0: eval Eq(((s0*s2**2)//(8*((s2//8))**2)), 8*s0) == False [statically known]

Without the asserts:

symbolic_shapes: [INFO] 0.0: eval Eq(8*s0, ((s0*s2**2)//(8*((s2//8))**2))) [guard added]

The Actual Problem

Because of the assertion, the s2//8 term was marked as "divisible", so the expression goes through ShapeEnv.simplify and ends up being transformed into:

Eq(8*s0, ((s0*s2**2)//(8*((s2/8))**2)))

From there, SymPy kicks in and applies the following transformations:

>>> Eq(8*s0, ((s0*s2**2)//(8*((s2/8))**2)))
>>> Eq(8*s0, s0//Rational(1, 8))  # Correct.
>>> False  # Incorrect.

A Possible Solution

The easiest way I see to solve this problem is to special-case it (a Rational divisor) in FloorDiv.eval.
That said, maybe we should also check whether one side of the equality is real. If so, we multiply the other side by 1.0.
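
The idea of special-casing a Rational divisor can be sketched with the standard library alone. The example below uses fractions.Fraction as a stand-in for sympy.Rational, and the helper floor_div is hypothetical; it is not the FloorDiv.eval code, only an illustration of why floor-dividing an integer by 1/b can be rewritten as a plain multiplication.

```python
from fractions import Fraction


def floor_div(a, b):
    """Floor-divide integer `a` by `b`, special-casing divisors of the form 1/d.

    Hypothetical helper: Fraction stands in for sympy.Rational here.
    """
    b = Fraction(b)
    if b.numerator == 1:
        # a // (1/d) is exactly a * d for integer a, so no flooring is needed.
        return a * b.denominator
    # General case: Python floors the exact quotient.
    return a // b


# The shortcut agrees with exact rational floor division.
agrees = all(floor_div(s0, Fraction(1, 8)) == s0 // Fraction(1, 8) for s0 in range(1, 100))
print(agrees)
print(floor_div(5, Fraction(1, 8)))
```

This stays in integer arithmetic throughout, in line with the intermediate step s0//Rational(1, 8) above being equal to 8*s0.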

@ezyang
Contributor

ezyang commented Jun 13, 2023

Hmm, failing before the validation check sucks. Maybe we should have a way to trigger validation early... we ought to be able to attempt to produce_guards at any point during compilation.

@ngimel
Collaborator

ngimel commented Jun 13, 2023

Special-casing Rational in FloorDiv (especially when the Rational numerator divides the FloorDiv numerator) sgtm. I'd be wary of actually multiplying by 1.0: there are very few cases where we encounter non-integer expressions, and imo we should first try staying in integer land.

@ysiraichi
Collaborator

@XiaobingSuper may I take over this one?

ysiraichi added a commit that referenced this pull request Aug 4, 2023
This PR wraps `InstructionTranslator` run with a try-catch block so as to run the
translation validation (TV) if it ends up raising an error.

In this context, we run TV so as to catch simplification errors. These may turn
`ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect.

For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since
it's run only in the end of the tracing.

[ghstack-poisoned]
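
The control flow described in that commit message (run validation when tracing raises) can be sketched as a small wrapper. This is only an illustration of the pattern; run_with_validation, failing_translate and recording_validate are hypothetical names, not Dynamo APIs.

```python
def run_with_validation(translate, validate):
    """Run `translate`; if it raises, run `validate` before re-raising."""
    try:
        return translate()
    except Exception:
        # Tracing failed part-way. Validate what was recorded so far, so that
        # a bad simplification can surface as a validation failure instead of
        # an unrelated-looking tracing error.
        validate()
        raise


events = []


def failing_translate():
    events.append("translate")
    raise RuntimeError("expand: attempting to expand a dimension")


def recording_validate():
    events.append("validate")


try:
    run_with_validation(failing_translate, recording_validate)
except RuntimeError as err:
    events.append(f"re-raised: {err}")

print(events)
```

If the validation step itself raises, that exception propagates in place of the original one, which is the point of running it early.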
ysiraichi added a commit that referenced this pull request Aug 5, 2023
…tal.py. on "Handle `Rational` divisors in `FloorDiv`."

Follow-up: #101173

This PR fixes the bug presented in #101173 by creating a special case for `sympy.Rational`
divisors, inside `FloorDiv` evaluation. In summary:

```python
FloorDiv(a, Rational(1, b))
a * b
```

Besides that, this PR also does 2 other things:

- Replaces the use of the old `sympy.Mod` by the internal `Mod` (there were a few places
that were still looking for the SymPy one)

- Introduces debugging logs to the translation validator. These can be seen by setting the
environment variable: `TORCH_LOGS=+torch.fx.experimental.validator`

[ghstack-poisoned]
@ysiraichi
Collaborator

Solved at #106644

@ysiraichi ysiraichi closed this Aug 9, 2023
pytorchmergebot pushed a commit that referenced this pull request Aug 14, 2023
This PR wraps `InstructionTranslator` run with a try-catch block so as to run the
translation validation (TV) if it ends up raising an error.

In this context, we run TV so as to catch simplification errors. These may turn
`ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect.

For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since
it's run only in the end of the tracing.

Pull Request resolved: #106645
Approved by: https://github.com/ezyang
Cyril-Anto pushed a commit to Cyril-Anto/pytorch that referenced this pull request Aug 17, 2023
Follow-up: pytorch#101173

This PR fixes the bug presented in pytorch#101173 by creating a special case for `sympy.Rational`
divisors, inside `FloorDiv` evaluation. In summary:

```python
FloorDiv(a, Rational(1, b))
a * b
```

Besides that, this PR also does 2 other things:

- Replaces the use of the old `sympy.Mod` by the internal `Mod` (there were a few places
that were still looking for the SymPy one)

- Introduces debugging logs to the translation validator. These can be seen by setting the
environment variable: `TORCH_LOGS=+torch.fx.experimental.validator`
Pull Request resolved: pytorch#106644
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#106643
@facebook-github-bot facebook-github-bot deleted the gh/XiaobingSuper/114/head branch September 9, 2023 14:21