
[jax2tf] TypeError when performing operations on a bfloat16-typed array #3942

@russbates

Description

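```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def jax_fn(x):
  y = x.astype(jnp.bfloat16)  # cast the input to bfloat16
  y *= 2.0                    # arithmetic on the bfloat16 value: this triggers the error
  return y

tf_fn = jax2tf.convert(jax_fn)
tf_fn(tf.convert_to_tensor(1.0))  # TypeError: Cannot convert value dtype(bfloat16) to a TensorFlow DType.
```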
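As a sanity check, here is a minimal sketch that runs the same `jax_fn` directly in JAX, with no jax2tf or TensorFlow involved. It assumes a recent JAX install. In plain JAX the bfloat16 multiply succeeds, which suggests the failure comes from how the conversion handles the bfloat16 dtype rather than from the function itself.

```python
import jax.numpy as jnp

def jax_fn(x):
  y = x.astype(jnp.bfloat16)  # cast the input to bfloat16
  y *= 2.0                    # the Python float is weakly typed, so y stays bfloat16
  return y

# Same input as the repro (a float32 scalar), but evaluated by JAX only.
out = jax_fn(jnp.asarray(1.0, dtype=jnp.float32))
print(out.dtype, out)  # expected: a bfloat16 scalar equal to 2.0
```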

The repro should be self-explanatory: I get a TypeError when performing arithmetic on bfloat16-typed jnp arrays inside a function converted with jax2tf. The same function works fine if I remove the `y *= 2.0` line and just return the bfloat16-cast `y`.
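Since returning the bfloat16-cast value without further arithmetic reportedly works, one possible workaround is to do the arithmetic in float32 and cast to bfloat16 only at the very end. This is a sketch under that assumption and has not been verified against jax2tf. The function name `jax_fn_f32_math` is hypothetical. The block only runs it in plain JAX; the idea is that it would be passed to `jax2tf.convert` in place of `jax_fn`.

```python
import jax.numpy as jnp

def jax_fn_f32_math(x):
  # Hypothetical workaround: keep the arithmetic in the input dtype (float32 here)...
  y = x * 2.0
  # ...and cast to bfloat16 once, right before returning, so no bfloat16
  # arithmetic happens inside the function.
  return y.astype(jnp.bfloat16)

out = jax_fn_f32_math(jnp.asarray(1.0, dtype=jnp.float32))
print(out.dtype, out)  # expected: a bfloat16 scalar equal to 2.0
```

For this multiply-by-2.0 example the result matches `jax_fn` on ordinary inputs. For other operations, computing in float32 and rounding once can differ slightly from doing the arithmetic in bfloat16.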
