-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Closed
Labels
bug (Something isn't working)
Description
Repro:
import jax
jax.lax.abs(jax.numpy.uint32(1))
Traceback (most recent call last):
File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/mlir.py", line 814, in lower_jaxpr_to_module
if not ctx.module.operation.verify():
jaxlib.mlir._mlir_libs.MLIRError: Verification failed:
error: "jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0): 'stablehlo.abs' op operand #0 must be tensor of 4/8/16/32/64-bit signless integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<ui32>'
note: "jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0): see current operation: %0 = "stablehlo.abs"(%arg0) : (tensor<ui32>) -> tensor<ui32>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/vanderplas/github/google/jax/tmp.py", line 2, in <module>
jax.lax.abs(jax.numpy.uint32(1))
File "/Users/vanderplas/github/google/jax/jax/_src/lax/lax.py", line 368, in abs
return abs_p.bind(x)
File "/Users/vanderplas/github/google/jax/jax/_src/core.py", line 386, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/vanderplas/github/google/jax/jax/_src/core.py", line 389, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/vanderplas/github/google/jax/jax/_src/core.py", line 869, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/vanderplas/github/google/jax/jax/_src/dispatch.py", line 128, in apply_primitive
compiled_fun = xla_primitive_callable(
File "/Users/vanderplas/github/google/jax/jax/_src/util.py", line 263, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/Users/vanderplas/github/google/jax/jax/_src/util.py", line 256, in cached
return f(*args, **kwargs)
File "/Users/vanderplas/github/google/jax/jax/_src/dispatch.py", line 157, in xla_primitive_callable
computation = sharded_lowering(
File "/Users/vanderplas/github/google/jax/jax/_src/dispatch.py", line 188, in sharded_lowering
return pxla.lower_sharding_computation(
File "/Users/vanderplas/github/google/jax/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/pxla.py", line 2049, in lower_sharding_computation
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/pxla.py", line 1850, in _cached_lowering_to_hlo
lowering_result = mlir.lower_jaxpr_to_module(
File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/mlir.py", line 829, in lower_jaxpr_to_module
raise ValueError("\n".join(msg_lines)) from e
ValueError: Cannot lower jaxpr with verifier errors:
'stablehlo.abs' op operand #0 must be tensor of 4/8/16/32/64-bit signless integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<ui32>'
at loc("jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0))
see current operation: %0 = "stablehlo.abs"(%arg0) : (tensor<ui32>) -> tensor<ui32>
at loc("jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0))
Module string:
#loc = loc(unknown)
"builtin.module"() <{sym_name = "jit_abs"}> ({
"func.func"() <{arg_attrs = [{mhlo.sharding = "{replicated}"}], function_type = (tensor<ui32>) -> tensor<ui32>, res_attrs = [{}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<ui32> loc(unknown)):
%0 = "stablehlo.abs"(%arg0) : (tensor<ui32>) -> tensor<ui32> loc(#loc2)
"func.return"(%0) : (tensor<ui32>) -> () loc(#loc)
}) : () -> () loc(#loc)
}) {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} : () -> () loc(#loc)
#loc1 = loc("/Users/vanderplas/github/google/jax/tmp.py":2:0)
#loc2 = loc("jit(abs)/jit(main)/abs"(#loc1))
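I think the fix is to avoid binding the primitive (`abs_p`) for unsigned inputs. For unsigned integer dtypes, `abs` is the identity, so `lax.abs` can return `x` unchanged instead of emitting a `stablehlo.abs` op. That op's verifier rejects unsigned tensors such as `tensor<ui32>`, as the error above shows.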
Metadata
Assignees
Labels
bug (Something isn't working)