Initial import of jax2tf into JAX core #3202
.travis.yml
fi
- if [ "$JAX_TF_BRIDGE" = true ] ;then
# jax_to_tf needs some fixes that are not in tensorflow==2.2.0
pip install "tf-nightly>=2.3.0.dev20200525" ;
I would pin a fixed version to avoid having the CI fail when a broken version is pushed.
Done
@@ -0,0 +1,80 @@
# JAX to TensorFlow converter
I'd avoid the name `tf_bridge` because that has another meaning in the TF world (namely the TF->XLA bridge). I think `jax2tf` or `jax_to_tf` is preferable.
I have changed it to `jax_to_tf`.
# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla  # type: ignore[import]
I believe these won't work in TensorFlow as built in open source, because they aren't part of TF's public API, even though the code is part of TensorFlow. Or has this changed?
These do work for me in OSS. They work even with the public 2.2.0 release.
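Since the question is whether a non-public module path is importable in open-source TensorFlow builds, one defensive pattern is to attempt the import once and record whether it succeeded, so code that needs `tfxla` fails with a clear message instead of at module load. This is a minimal sketch under that assumption, not what this PR does; `try_import`, `HAS_TFXLA`, and `require_tfxla` are hypothetical names.

```python
import importlib
from types import ModuleType
from typing import Optional


def try_import(module_path: str) -> Optional[ModuleType]:
    """Return the imported module, or None if it cannot be imported.

    Hypothetical helper: catches ImportError (which includes
    ModuleNotFoundError) so a missing or relocated module is not fatal.
    """
    try:
        return importlib.import_module(module_path)
    except ImportError:
        return None


# The non-public TF module from the diff. It is absent if TensorFlow is not
# installed, and could disappear if the internal path changes.
tfxla = try_import("tensorflow.compiler.tf2xla.python.xla")
HAS_TFXLA = tfxla is not None  # hypothetical feature flag


def require_tfxla() -> ModuleType:
    """Raise an explicit error when the module is unavailable."""
    if tfxla is None:
        raise RuntimeError(
            "tensorflow.compiler.tf2xla.python.xla is not importable "
            "in this TensorFlow installation.")
    return tfxla
```

Catching only `ImportError` keeps unrelated failures raised inside TensorFlow's own import visible rather than silently turning them into "module missing".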
return tf.add_n([a] + b)


def _threefry2x32(key1, key2, x1, x2):
I realize this most likely isn't new in this PR, but I'm a bit unhappy that we have to repeat the definition of threefry here. Note that on CPU we don't have a direct definition of the primitive; we expand it using `xla.lower_fun`. Could we do something similar here rather than repeating its definition?
Added a TODO, will do separately.
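The reviewer's suggestion amounts to keeping a single definition of threefry2x32 in terms of a few 32-bit integer operations (add, xor, rotate-left) and letting each backend supply those operations, instead of re-deriving the algorithm per backend. The sketch below illustrates that idea in plain Python; it is not JAX's `xla.lower_fun` mechanism, and `IntOps`, `PY_OPS`, and the extra `ops` parameter are hypothetical. It implements the standard 20-round Threefry-2x32 (key injection every 4 rounds, rotation constants 13, 15, 26, 6 / 17, 29, 16, 24).

```python
from typing import Callable, NamedTuple, Tuple

MASK32 = 0xFFFFFFFF


class IntOps(NamedTuple):
    """Hypothetical bundle of the 32-bit integer ops threefry needs."""
    add: Callable[[int, int], int]
    xor: Callable[[int, int], int]
    rotl: Callable[[int, int], int]


# One concrete backend: plain Python ints wrapped to 32 bits.
PY_OPS = IntOps(
    add=lambda a, b: (a + b) & MASK32,
    xor=lambda a, b: (a ^ b) & MASK32,
    rotl=lambda v, r: ((v << r) | (v >> (32 - r))) & MASK32,
)

ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))


def threefry2x32(key1: int, key2: int, x1: int, x2: int,
                 ops: IntOps = PY_OPS) -> Tuple[int, int]:
    """Threefry-2x32 with 20 rounds, written once against `ops`."""
    # Key schedule: the third key word is the xor of both keys and a constant.
    ks = (key1, key2, ops.xor(ops.xor(key1, key2), 0x1BD11BDA))
    x1 = ops.add(x1, ks[0])
    x2 = ops.add(x2, ks[1])
    for i in range(5):  # 5 groups of 4 rounds = 20 rounds
        for r in ROTATIONS[i % 2]:
            x1 = ops.add(x1, x2)
            x2 = ops.rotl(x2, r)
            x2 = ops.xor(x1, x2)
        # Inject the rotated key schedule plus the group counter.
        x1 = ops.add(x1, ks[(i + 1) % 3])
        x2 = ops.add(ops.add(x2, ks[(i + 2) % 3]), i + 1)
    return x1, x2
```

With all-zero key and counter, `threefry2x32(0, 0, 0, 0)` returns the Random123 known-answer value `(0x6b200159, 0x99ba4efe)`. A different backend would pass its own `IntOps` rather than copying the round structure.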
return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)


def ResNet50(num_classes):  # pylint: disable=invalid-name
I'm a bit sad at having multiple copies of the Stax model around, especially given we aren't really maintaining Stax in favor of other libraries.
This was originally using Flax, but I think the Flax + `jax_to_tf` code should go into the Flax repo. As a substitute, Tom wrote this. I think it is fine as a test for now.
Given that we're inside JAX core now, we could just `from jax.examples import resnet50` and replace all of this with `resnet50.ResNet50(..)`?
Done.
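For reference, the residual line in the diff, `stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)`, copies one input into two branches, runs the main path on one and leaves the other unchanged, sums them, and applies a ReLU. The sketch below mimics only the apply side of that wiring on plain floats; `serial`, `fan_out`, `parallel`, `fan_in_sum`, `relu`, and `residual_block` are hypothetical stand-ins, not the stax API, whose layers are `(init_fun, apply_fun)` pairs that also thread parameters and shapes.

```python
from typing import Callable, Sequence

Apply = Callable[[object], object]  # hypothetical: a layer as a plain function


def serial(*layers: Apply) -> Apply:
    """Apply layers one after another (like stax.serial, apply side only)."""
    def apply(x):
        for layer in layers:
            x = layer(x)
        return x
    return apply


def fan_out(n: int) -> Apply:
    """Copy the input into a list of n branches (like stax.FanOut)."""
    return lambda x: [x] * n


def parallel(*layers: Apply) -> Apply:
    """Apply the i-th layer to the i-th branch (like stax.parallel)."""
    return lambda xs: [layer(x) for layer, x in zip(layers, xs)]


def fan_in_sum(xs: Sequence[float]) -> float:
    """Sum the branches (like stax.FanInSum)."""
    return sum(xs)


def relu(x: float) -> float:
    return max(x, 0.0)


def identity(x):
    return x


def residual_block(main: Apply) -> Apply:
    # Mirrors: stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)
    return serial(fan_out(2), parallel(main, identity), fan_in_sum, relu)
```

For example, `residual_block(lambda x: 3.0 * x)(2.0)` computes `relu(3.0 * 2.0 + 2.0)`, which is `8.0`.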
Force-pushed from 6cd80fe to 0ba6063.
Renamed jax2tf.convert to jax_to_tf. Added Travis test support. Added OSS build configuration.
* Initial import of jax2tf into JAX core
  Renamed jax2tf.convert to jax_to_tf. Added Travis test support. Added OSS build configuration.
* Added support for squeeze
Import experimental bridge from JAX to TensorFlow.