-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Open
Labels
dynamo-symbolic-analysis; module: higher order operators (torch.cond and similar); module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op, …); oncall: pt2; triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_while_loop_tensor_constant_idx(self):
    """Round-trip a ``while_loop`` with a tensor loop counter through ``export``.

    The loop body converts its tensor counter to a Python int with
    ``idx.item()``. It then uses that int to index ``x`` and to write into
    ``out``. The test compares the result of calling ``TestModel`` directly
    with the result of the exported program's module.

    NOTE(review): this is the reproducer for the reported bug. Per the
    traceback, it currently fails while Dynamo speculates the loop body
    (``speculate_subgraph``), raising ``PendingUnbackedSymbolNotFound``. The
    report does not show whether the direct call or ``export`` raises.
    ``capture_scalar_outputs=True`` is presumably set so that ``.item()``
    is traced as a symbolic (unbacked) value rather than causing a graph
    break; confirm against the Dynamo config docs.
    """

    def while_loop_decomp(x, y0):
        """Fill row ``k`` of a zeros-like buffer with ``y0 + x[0] + ... + x[k]``.

        Written as a ``while_loop`` over a tensor counter instead of a
        Python ``for`` loop.
        """
        out = torch.zeros_like(x)

        def cond_fn(idx, out, y0):
            # Keep iterating while the counter is below the size of dim 0,
            # i.e. once per row of ``out``.
            return idx < out.size(0)

        def body_fn(idx, out, y0):
            # Convert the tensor counter to a Python scalar for indexing.
            # This data-dependent .item() is presumably where the pending
            # unbacked symbol in the error (u2) originates. TODO confirm.
            i = idx.item()
            # Running sum. ``x`` is closed over from while_loop_decomp; it is
            # not one of the loop-carried values.
            y0 = x[i] + y0
            # Copy before the indexed write so the incoming ``out`` is not
            # mutated in place (presumably required for while_loop bodies).
            out = out.clone()
            out[i] = y0
            # New loop-carried state: (counter, buffer, running sum).
            return idx + 1, out, y0

        # 0-dim tensor counter starting at 0.
        cnt = torch.tensor(0)
        # ``while_loop`` and ``export`` come from imports outside this snippet
        # (presumably torch._higher_order_ops.while_loop and
        # torch.export.export). Verify against the test file's imports.
        _, out, _ = while_loop(cond_fn, body_fn, [cnt, out, y0])
        return out

    class TestModel(torch.nn.Module):
        # Thin nn.Module wrapper so the loop can be passed to ``export``.
        def forward(self, x, y0):
            return while_loop_decomp(x, y0)

    # x has shape (16, 8); y0 has shape (8,), the shape of one row of x.
    x, y0 = torch.randn(16, 8), torch.randn(8)
    # Reference result from calling the module directly.
    exp_out = TestModel()(x, y0)
    ep = export(TestModel(), (x, y0))
    # Run the exported program and compare it with the reference.
    out = ep.module()(x, y0)
    self.assertEqual(exp_out, out)
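Traceback (truncated: the outermost frames and the start of this first frame are missing from the report):
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]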
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1237, in call_function
return self._call_function(tx, args, kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 106, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1521, in _call_function
return _call_while_loop(self, tx, args, kwargs, stack_output=False)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 475, in _call_while_loop
) = speculate_subgraph(
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 997, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/functions.py", line 598, in call_function
return super().call_function(tx, args, kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/functions.py", line 342, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 1283, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 4197, in inline_call
return tracer.inline_call_()
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 4410, in inline_call_
self.run()
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 1482, in run
while self.step():
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 1343, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 899, in wrapper
return inner_fn(self, inst)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 2377, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 1261, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/functions.py", line 598, in call_function
return super().call_function(tx, args, kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/functions.py", line 342, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 1283, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 4197, in inline_call
return tracer.inline_call_()
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 4410, in inline_call_
self.run()
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 1482, in run
while self.step():
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 1343, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/symbolic_convert.py", line 464, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/builtin.py", line 1346, in call_function
return handler(tx, args, kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/builtin.py", line 1306, in _handle_insert_op_in_graph
return wrap_fx_proxy(tx, proxy)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/builder.py", line 2708, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/builder.py", line 2774, in wrap_fx_proxy_cls
return _wrap_fx_proxy(
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/builder.py", line 2874, in _wrap_fx_proxy
return handle_traced_output(
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/builder.py", line 2886, in handle_traced_output
var = construct_tensor_variable(
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/variables/builder.py", line 3140, in construct_tensor_variable
set_example_value(proxy.node, example_value)
File "/home/lsakka/pytorch10/pytorch/torch/_dynamo/utils.py", line 2713, in set_example_value
:= torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
File "/home/lsakka/pytorch10/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1310, in compute_unbacked_bindings
raise PendingUnbackedSymbolNotFound(
torch._dynamo.exc.InternalTorchDynamoError: PendingUnbackedSymbolNotFound: Pending unbacked symbols {u2} not in returned outputs FakeTensor(..., size=(), dtype=torch.int64) ((), 0).
Did you accidentally call new_dynamic_size() or item() more times than you needed to in your fake implementation?
For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit
This could be related to #164341.
Versions
NA
Metadata
Metadata
Assignees
Labels
dynamo-symbolic-analysis; module: higher order operators (torch.cond and similar); module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op, …); oncall: pt2; triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)