Skip to content

[Fuzzer][Eager/Compile Divergence] The size of tensor a (u0) must match the size of tensor b (18) at non-singleton dimension #164876

@bobrenjc93

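Description

The program below runs successfully in eager mode, but compiling it with `torch.compile(fuzzed_program, fullgraph=True, dynamic=True)` fails while AOTAutograd traces the Inductor `mm` decomposition. The full traceback follows the script.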

@bobrenjc93
import torch
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

torch.manual_seed(1012969)

def fuzzed_program(arg_0, arg_1, sentinel):
    """Fuzzer-generated program whose eager and compiled runs diverge.

    In eager mode it returns an 18-element tensor.  Under
    ``torch.compile(..., fullgraph=True, dynamic=True)`` the length of the
    ``torch.unique`` result becomes the symbolic size ``u0``, and tracing the
    second ``torch.matmul`` fails with "The size of tensor a (u0) must match
    the size of tensor b (18)".

    Args:
        arg_0: In this repro, a float64 CPU tensor of size (2, 10),
            stride (10, 1).
        arg_1: In this repro, a float64 CPU tensor of size (10, 3),
            stride (3, 1).
        sentinel: Tensor multiplied into the output.  The driver passes a
            0-d tensor with ``requires_grad=True`` so the compiled run goes
            through the autograd (joint-graph) tracing path.

    Returns:
        ``var_node_0 * sentinel`` (an (18,)-shaped result for a 0-d
        sentinel), or its real part if that product is complex.

    NOTE: the op sequence is fuzzer output, including ``.to(torch.float64)``
    casts that are no-ops for the float64 inputs used here; keep it as-is so
    the repro stays faithful to what the compiler was given.
    """
    var_node_3 = arg_0 # size=(2, 10), stride=(10, 1), dtype=float64, device=cpu in this repro
    var_node_4 = arg_1 # size=(10, 3), stride=(3, 1), dtype=float64, device=cpu in this repro
    var_node_2 = torch.matmul(var_node_3.to(torch.float64), var_node_4.to(torch.float64)) # size=(2, 3), stride=(3, 1), dtype=float64
    # arange(1) is tensor([0]); its unique is still one element in eager mode,
    # but under torch.compile the length is data-dependent (symbolic u0).
    _inp_unique_wide = torch.arange(1, device=var_node_2.device, dtype=torch.int64)
    _uniq_wide = torch.unique(_inp_unique_wide)
    var_node_1 = _uniq_wide.to(var_node_2.dtype) # eager: size=(1,), stride=(1,), dtype=float64; compiled: size=(u0,)
    # NOTE(review): no device= here, so this is created on the default device,
    # not var_node_2's device; fine for the CPU inputs in this repro, but CUDA
    # inputs would presumably hit a device mismatch in the next matmul.
    var_node_5 = torch.full((1, 18), 0.40330381448978797, dtype=torch.float64) # size=(1, 18), stride=(18, 1), dtype=float64
    # 1-D (1,) @ 2-D (1, 18) -> (18,).  This is the call whose compiled trace
    # fails: the traceback shows matmul's decomposition unsqueezing var_node_1
    # to (1, u0), and the inductor mm decomposition then multiplying it
    # elementwise with (1, 18), so u0 is broadcast against 18.
    var_node_0 = torch.matmul(var_node_1.to(torch.float64), var_node_5.to(torch.float64)) # size=(18,), stride=(1,), dtype=float64
    # Multiply by the sentinel so the output depends on a requires_grad input
    # (forcing gradient computation), then take the real part if complex.
    result = var_node_0 * sentinel
    if result.is_complex():
        result = result.real
    return result

# Sentinel tensor to ensure gradient computation: because it requires grad,
# torch.compile traces the program through the autograd (joint-graph) path.
sentinel = torch.tensor(1.0, requires_grad=True)

# Inputs are created on the default device (CPU here) as float64 tensors
# with contiguous strides.
arg_0 = torch.as_strided(torch.randn(20).to(torch.float64), (2, 10), (10, 1))
arg_1 = torch.as_strided(torch.randn(30).to(torch.float64), (10, 3), (3, 1))

args = (arg_0, arg_1, sentinel)
result_original = fuzzed_program(*args)
print('✅ eager success')
compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)
result_compiled = compiled_program(*args)
print('✅ compile success')
# A compile-time crash is only one form of eager/compile divergence; once
# compilation succeeds, also verify the compiled output matches eager.
torch.testing.assert_close(result_compiled, result_original)
print('✅ outputs match')
   ✅ eager success                                                                                                                                                                                                                             
   Traceback (most recent call last):                                                                                                                                                                                                           
     File "/tmp/torchfuzz/fuzz_35dd026b.py", line 32, in <module>                                                                                                                                                                               
       result_compiled = compiled_program(*args)                                                                                                                                                                                                
     File "/home/bobren/local/a/pytorch/torch/_dynamo/eval_frame.py", line 899, in compile_wrapper                                                                                                                                              
       raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_dynamo/output_graph.py", line 2309, in _call_user_compiler                                                                                                                                       
       raise BackendCompilerFailed(                                                                                                                                                                                                             
     File "/home/bobren/local/a/pytorch/torch/_dynamo/output_graph.py", line 2284, in _call_user_compiler                                                                                                                                       
       compiled_fn = compiler_fn(gm, example_inputs)                                                                                                                                                                                            
     File "/home/bobren/local/a/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__                                                                                                                                             
       compiled_gm = compiler_fn(gm, example_inputs)                                                                                                                                                                                            
     File "/home/bobren/local/a/pytorch/torch/__init__.py", line 2399, in __call__                                                                                                                                                              
       return compile_fx(model_, inputs_, config_patches=self.config)                                                                                                                                                                           
     File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2465, in compile_fx                                                                                                                                                
       return _maybe_wrap_and_compile_fx_main(                                                                                                                                                                                                  
     File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2539, in _maybe_wrap_and_compile_fx_main                                                                                                                           
       return _compile_fx_main(                                                                                                                                                                                                                 
     File "/home/bobren/local/a/pytorch/torch/_inductor/compile_fx.py", line 2726, in _compile_fx_main                                                                                                                                          
       return aot_autograd(                                                                                                                                                                                                                     
     File "/home/bobren/local/a/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__                                                                                                                                                
       cg = aot_module_simplified(gm, example_inputs, **self.kwargs)                                                                                                                                                                            
     File "/home/bobren/local/a/pytorch/torch/_functorch/aot_autograd.py", line 1108, in aot_module_simplified                                                                                                                                  
       aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call)                                                                                                                                                                 
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 164, in aot_stage1_graph_capture                                                                                                                 
       aot_dispatch_autograd_graph(                                                                                                                                                                                                             
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture.py", line 436, in aot_dispatch_autograd_graph                                                                                                              
       fx_g = _create_graph(                                                                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture.py", line 82, in _create_graph                                                                                                                             
       fx_g = make_fx(                                                                                                                                                                                                                          
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2433, in wrapped                                                                                                                                           
       return make_fx_tracer.trace(f, *args)                                                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2360, in trace                                                                                                                                             
       return self._trace_inner(f, *args)                                                                                                                                                                                                       
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2322, in _trace_inner                                                                                                                                      
       t = dispatch_trace(                                                                                                                                                                                                                      
     File "/home/bobren/local/a/pytorch/torch/_compile.py", line 54, in inner                                                                                                                                                                   
       return disable_fn(*args, **kwargs)                                                                                                                                                                                                       
     File "/home/bobren/local/a/pytorch/torch/_dynamo/eval_frame.py", line 1098, in _fn                                                                                                                                                         
       return fn(*args, **kwargs)                                                                                                                                                                                                               
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1309, in dispatch_trace                                                                                                                                    
       graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]                                                                                                                                                                      
     File "/home/bobren/local/a/pytorch/torch/_dynamo/eval_frame.py", line 1098, in _fn                                                                                                                                                         
       return fn(*args, **kwargs)                                                                                                                                                                                                               
     File "/home/bobren/local/a/pytorch/torch/fx/_symbolic_trace.py", line 868, in trace                                                                                                                                                        
       (self.create_arg(fn(*args)),),                                                                                                                                                                                                           
     File "/home/bobren/local/a/pytorch/torch/fx/_symbolic_trace.py", line 721, in flatten_fn                                                                                                                                                   
       tree_out = root_fn(*tree_args)                                                                                                                                                                                                           
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1367, in wrapped                                                                                                                                           
       out = f(*tensors)  # type:ignore[call-arg]                                                                                                                                                                                               
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture.py", line 70, in inner_f                                                                                                                                   
       out, out_descs = call_and_expect_output_descs(f, args)                                                                                                                                                                                   
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/utils.py", line 551, in call_and_expect_output_descs                                                                                                                     
       outs_pair = fn(*args)                                                                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1096, in inner_fn                                                                                                                       
       outs, outs_descs = call_and_expect_output_descs(fn, args)                                                                                                                                                                                
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/utils.py", line 551, in call_and_expect_output_descs                                                                                                                     
       outs_pair = fn(*args)                                                                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1044, in joint_helper                                                                                                                   
       return _functionalized_f_helper(primals, tangents)                                                                                                                                                                                       
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 802, in _functionalized_f_helper                                                                                                        
       f_outs, f_outs_descs = call_and_expect_output_descs(fn, f_args)                                                                                                                                                                          
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/utils.py", line 551, in call_and_expect_output_descs                                                                                                                     
       outs_pair = fn(*args)                                                                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 431, in inner_fn_with_anomaly                                                                                                           
       return inner_fn(primals, tangents)                                                                                                                                                                                                       
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 293, in inner_fn                                                                                                                        
       (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/utils.py", line 551, in call_and_expect_output_descs                                                                                                                     
       outs_pair = fn(*args)                                                                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 170, in inner_fn                                                                                                                        
       outs, outs_descs = call_and_expect_output_descs(fn, args_maybe_cloned)                                                                                                                                                                   
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/utils.py", line 551, in call_and_expect_output_descs                                                                                                                     
       outs_pair = fn(*args)                                                                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 135, in orig_flat_fn2                                                                                                                            
       out = orig_flat_fn(*args)                                                                                                                                                                                                                
     File "/home/bobren/local/a/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1351, in functional_call                                                                                                                
       out = PropagateUnbackedSymInts(mod).run(                                                                                                                                                                                                 
     File "/home/bobren/local/a/pytorch/torch/fx/interpreter.py", line 174, in run                                                                                                                                                              
       self.env[node] = self.run_node(node)                                                                                                                                                                                                     
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7896, in run_node                                                                                                                                       
       result = super().run_node(n)                                                                                                                                                                                                             
     File "/home/bobren/local/a/pytorch/torch/fx/interpreter.py", line 256, in run_node                                                                                                                                                         
       return getattr(self, n.op)(n.target, args, kwargs)                                                                                                                                                                                       
     File "/home/bobren/local/a/pytorch/torch/fx/interpreter.py", line 336, in call_function                                                                                                                                                    
       return target(*args, **kwargs)                                                                                                                                                                                                           
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1415, in __torch_function__                                                                                                                                
       return func(*args, **kwargs)                                                                                                                                                                                                             
     File "/home/bobren/local/a/pytorch/torch/_prims_common/wrappers.py", line 307, in _fn                                                                                                                                                      
       result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]                                                                                                                                                         
     File "/home/bobren/local/a/pytorch/torch/_decomp/decompositions.py", line 4515, in matmul                                                                                                                                                  
       return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)                                                                                                                                                                  
     File "/home/bobren/local/a/pytorch/torch/_subclasses/functional_tensor.py", line 512, in __torch_dispatch__                                                                                                                                
       outs_unwrapped = func._op_dk(                                                                                                                                                                                                            
     File "/home/bobren/local/a/pytorch/torch/utils/_stats.py", line 29, in wrapper                                                                                                                                                             
       return fn(*args, **kwargs)                                                                                                                                                                                                               
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1540, in __torch_dispatch__                                                                                                                                
       return proxy_call(self, func, self.pre_dispatch, args, kwargs)                                                                                                                                                                           
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 899, in proxy_call                                                                                                                                         
       r = maybe_handle_decomp(proxy_mode, func, args, kwargs)                                                                                                                                                                                  
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 2505, in maybe_handle_decomp                                                                                                                               
       out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)                                                                                                                                                                                   
     File "/home/bobren/local/a/pytorch/torch/_decomp/decompositions.py", line 91, in inner                                                                                                                                                     
       r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))                                                                                                                                                                 
     File "/home/bobren/local/a/pytorch/torch/_inductor/decomposition.py", line 353, in mm                                                                                                                                                      
       return self * input2                                                                                                                                                                                                                     
     File "/home/bobren/local/a/pytorch/torch/utils/_stats.py", line 29, in wrapper                                                                                                                                                             
       return fn(*args, **kwargs)                                                                                                                                                                                                               
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1540, in __torch_dispatch__                                                                                                                                
       return proxy_call(self, func, self.pre_dispatch, args, kwargs)                                                                                                                                                                           
     File "/home/bobren/local/a/pytorch/torch/fx/experimental/proxy_tensor.py", line 1000, in proxy_call                                                                                                                                        
       out = func(*args, **kwargs)                                                                                                                                                                                                              
     File "/home/bobren/local/a/pytorch/torch/_ops.py", line 841, in __call__                                                                                                                                                                   
       return self._op(*args, **kwargs)                                                                                                                                                                                                         
     File "/home/bobren/local/a/pytorch/torch/utils/_stats.py", line 29, in wrapper                                                                                                                                                             
       return fn(*args, **kwargs)                                                                                                                                                                                                               
     File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 1384, in __torch_dispatch__                                                                                                                                     
       return self.dispatch(func, types, args, kwargs)                                                                                                                                                                                          
     File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 2111, in dispatch                                                                                                                                               
       return self._cached_dispatch_impl(func, types, args, kwargs)                                                                                                                                                                             
     File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 1526, in _cached_dispatch_impl                                                                                                                                  
       output = self._dispatch_impl(func, types, args, kwargs)                                                                                                                                                                                  
     File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_tensor.py", line 2634, in _dispatch_impl                                                                                                                                         
       return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))                                                                                                                                                                    
     File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_impls.py", line 1309, in fast_binary_impl                                                                                                                                        
       final_shape = infer_size(final_shape, shape)                                                                                                                                                                                             
     File "/home/bobren/local/a/pytorch/torch/_subclasses/fake_impls.py", line 1269, in infer_size                                                                                                                                              
       torch._check(                                                                                                                                                                                                                            
     File "/home/bobren/local/a/pytorch/torch/__init__.py", line 1702, in _check                                                                                                                                                                
       _check_with(RuntimeError, cond, message)                                                                                                                                                                                                 
     File "/home/bobren/local/a/pytorch/torch/__init__.py", line 1684, in _check_with                                                                                                                                                           
       raise error_type(message_evaluated)                                                                                                                                                                                                      
   torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:                                                                                                                                                                          
   RuntimeError: The size of tensor a (u0) must match the size of tensor b (18) at non-singleton dimension 1)                                                                                                                                   
   While executing %var_node_0 : [num_users=1] = call_function[target=torch.matmul](args = (%to_3, %to_4), kwargs = {})                                                                                                                         
   Original traceback:                                                                                                                                                                                                                          
     File "/tmp/torchfuzz/fuzz_35dd026b.py", line 15, in fuzzed_program                                                                                                                                                                         
       var_node_0 = torch.matmul(var_node_1.to(torch.float64), var_node_5.to(torch.float64)) # size=(18,), stride=(1,), dtype=float64, device=cuda                                                                                              
   Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)                                                                                                                     
   Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"  

cc @chauhang @penguinwu @ezyang @zou3519 @bdhirsh

Metadata

Metadata

Assignees

Labels

- module: aotdispatch (umbrella label for AOTAutograd issues)
- module: dynamic shapes
- module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op)
- oncall: pt2
- topic: fuzzer
- triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions