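```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def jax_fn(x):
    y = x.astype(jnp.bfloat16)  # cast the input to bfloat16
    y *= 2.0                    # multiplying by a Python float triggers the error
    return y

tf_fn = jax2tf.convert(jax_fn)
tf_fn(tf.convert_to_tensor(1.0))  # TypeError: Cannot convert value dtype(bfloat16) to a TensorFlow DType.
```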
The repro should be fairly self-explanatory: I get a `TypeError` when a function converted with `jax2tf` operates on bf16-typed `jnp` arrays. The same function works fine if I remove the `y *= 2.0` line and just return the bf16-cast `y`.