lax.abs crashes on unsigned ints #17958

@jakevdp

Description

Repro:

import jax
jax.lax.abs(jax.numpy.uint32(1))
Traceback (most recent call last):
  File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/mlir.py", line 814, in lower_jaxpr_to_module
    if not ctx.module.operation.verify():
jaxlib.mlir._mlir_libs.MLIRError: Verification failed:
error: "jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0): 'stablehlo.abs' op operand #0 must be tensor of 4/8/16/32/64-bit signless integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<ui32>'
 note: "jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0): see current operation: %0 = "stablehlo.abs"(%arg0) : (tensor<ui32>) -> tensor<ui32>

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/vanderplas/github/google/jax/tmp.py", line 2, in <module>
    jax.lax.abs(jax.numpy.uint32(1))
  File "/Users/vanderplas/github/google/jax/jax/_src/lax/lax.py", line 368, in abs
    return abs_p.bind(x)
  File "/Users/vanderplas/github/google/jax/jax/_src/core.py", line 386, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/vanderplas/github/google/jax/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/vanderplas/github/google/jax/jax/_src/core.py", line 869, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/vanderplas/github/google/jax/jax/_src/dispatch.py", line 128, in apply_primitive
    compiled_fun = xla_primitive_callable(
  File "/Users/vanderplas/github/google/jax/jax/_src/util.py", line 263, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/Users/vanderplas/github/google/jax/jax/_src/util.py", line 256, in cached
    return f(*args, **kwargs)
  File "/Users/vanderplas/github/google/jax/jax/_src/dispatch.py", line 157, in xla_primitive_callable
    computation = sharded_lowering(
  File "/Users/vanderplas/github/google/jax/jax/_src/dispatch.py", line 188, in sharded_lowering
    return pxla.lower_sharding_computation(
  File "/Users/vanderplas/github/google/jax/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/pxla.py", line 2049, in lower_sharding_computation
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
  File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/pxla.py", line 1850, in _cached_lowering_to_hlo
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/Users/vanderplas/github/google/jax/jax/_src/interpreters/mlir.py", line 829, in lower_jaxpr_to_module
    raise ValueError("\n".join(msg_lines)) from e
ValueError: Cannot lower jaxpr with verifier errors:
	'stablehlo.abs' op operand #0 must be tensor of 4/8/16/32/64-bit signless integer or f8E4M3B11FNUZ type or f8E4M3FN type or f8E4M3FNUZ type or f8E5M2 type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<ui32>'
		at loc("jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0))
	see current operation: %0 = "stablehlo.abs"(%arg0) : (tensor<ui32>) -> tensor<ui32>
		at loc("jit(abs)/jit(main)/abs"("/Users/vanderplas/github/google/jax/tmp.py":2:0))
Module string:
#loc = loc(unknown)
"builtin.module"() <{sym_name = "jit_abs"}> ({
  "func.func"() <{arg_attrs = [{mhlo.sharding = "{replicated}"}], function_type = (tensor<ui32>) -> tensor<ui32>, res_attrs = [{}], sym_name = "main", sym_visibility = "public"}> ({
  ^bb0(%arg0: tensor<ui32> loc(unknown)):
    %0 = "stablehlo.abs"(%arg0) : (tensor<ui32>) -> tensor<ui32> loc(#loc2)
    "func.return"(%0) : (tensor<ui32>) -> () loc(#loc)
  }) : () -> () loc(#loc)
}) {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} : () -> () loc(#loc)
#loc1 = loc("/Users/vanderplas/github/google/jax/tmp.py":2:0)
#loc2 = loc("jit(abs)/jit(main)/abs"(#loc1))

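I think the fix is to avoid binding the primitive for unsigned inputs. The absolute value of an unsigned integer is the value itself, so lax.abs can return its input unchanged instead of emitting stablehlo.abs, which does not accept unsigned tensor types.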
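Here is a minimal sketch of that idea, written as a user-level wrapper rather than a change to JAX itself. The name abs_sketch is hypothetical. It assumes the dtype check can be done on x.dtype at trace time, which also holds under jax.jit, so no stablehlo.abs op is ever emitted for unsigned inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def abs_sketch(x):
    """Hypothetical wrapper: skips binding lax.abs for unsigned integer inputs."""
    x = jnp.asarray(x)
    # The dtype is known at trace time, so this branch is resolved statically
    # and behaves the same inside jax.jit.
    if jnp.issubdtype(x.dtype, np.unsignedinteger):
        # |x| == x for every unsigned integer, so return the input unchanged.
        return x
    # Signed integer, float, and complex inputs still go through lax.abs.
    return lax.abs(x)


print(abs_sketch(jnp.uint32(1)))
print(abs_sketch(jnp.int32(-3)))
print(jax.jit(abs_sketch)(jnp.uint32(7)))
```

Returning x directly rather than a copy is safe here because JAX arrays are immutable.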

Labels: bug
