gnecula (Collaborator) commented May 25, 2020

Import experimental bridge from JAX to TensorFlow.

.travis.yml (outdated)
fi
- if [ "$JAX_TF_BRIDGE" = true ] ;then
# jax_to_tf needs some fixes that are not in tensorflow==2.2.0
pip install "tf-nightly>=2.3.0.dev20200525" ;

Collaborator:

I would pin a fixed version to avoid having the CI fail when a broken version is pushed.

Collaborator Author:
Done

@@ -0,0 +1,80 @@
# JAX to TensorFlow converter

Collaborator:
I'd avoid the name tf_bridge because that has another meaning in the TF world (namely the TF->XLA bridge).

I think jax2tf or jax_to_tf is preferable.

Collaborator Author:
I have changed it to jax_to_tf.


# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]

Collaborator:
I believe these won't work in TensorFlow as built in open source, because they aren't part of TF's public API, even though the code is part of TensorFlow. Or has this changed?

Collaborator Author:
These do work for me in OSS. They work even with the public 2.2.0 release.

return tf.add_n([a] + b)


def _threefry2x32(key1, key2, x1, x2):

Collaborator:
I realize this most likely isn't new in this PR, but I'm a bit unhappy that we have to repeat the definition of threefry here. Note that on CPU we don't have a direct definition of the primitive; we expand it using xla.lower_fun. Could we do something similar here rather than repeating its definition?

Collaborator Author:
Added a TODO, will do separately.
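The reviewer's suggestion amounts to defining the hash once, in terms of generic arithmetic and bitwise ops, and letting each backend lower that single definition (as xla.lower_fun does on CPU). Below is a minimal sketch of such a backend-agnostic definition, assuming Python integers masked to 32 bits stand in for uint32 tensors. The helper names are hypothetical; this is not the PR's _threefry2x32. It follows the standard 20-round Threefry-2x32 schedule (rotation groups 13, 15, 26, 6 and 17, 29, 16, 24; key parity constant 0x1BD11BDA):

```python
from typing import Tuple

MASK32 = 0xFFFFFFFF  # emulate uint32 wraparound on Python ints

# Rotation amounts for the two alternating groups of four rounds.
ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))


def rotate_left(x: int, d: int) -> int:
    """32-bit left rotation of x by d bits."""
    return ((x << d) | (x >> (32 - d))) & MASK32


def apply_round(x1: int, x2: int, rot: int) -> Tuple[int, int]:
    """One Threefry mix round: add, rotate, xor."""
    x1 = (x1 + x2) & MASK32
    x2 = rotate_left(x2, rot) ^ x1
    return x1, x2


def threefry2x32(key1: int, key2: int, x1: int, x2: int) -> Tuple[int, int]:
    """Hash the counter pair (x1, x2) under the key (key1, key2)."""
    # Extended key schedule: the third word is the key parity.
    ks = (key1, key2, key1 ^ key2 ^ 0x1BD11BDA)
    x1 = (x1 + ks[0]) & MASK32
    x2 = (x2 + ks[1]) & MASK32
    # Five groups of four rounds (20 rounds total), each followed by a
    # key injection that also adds the group counter to the second word.
    for i in range(5):
        for rot in ROTATIONS[i % 2]:
            x1, x2 = apply_round(x1, x2, rot)
        x1 = (x1 + ks[(i + 1) % 3]) & MASK32
        x2 = (x2 + ks[(i + 2) % 3] + i + 1) & MASK32
    return x1, x2
```

Because the body only uses addition, XOR, shifts and masking, the same structure maps directly onto jnp or tf ops, so one definition could serve every backend instead of being repeated per converter.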

return stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu)


def ResNet50(num_classes): # pylint: disable=invalid-name

Collaborator:
I'm a bit sad at having multiple copies of the Stax model around, especially given we aren't really maintaining Stax in favor of other libraries.

Collaborator Author:
This originally used Flax, but I think the Flax + jax_to_tf combination should go into the Flax repo. As a substitute, Tom wrote this. I think it is fine as a test, for now.

Collaborator:
Given that we're inside JAX core now, could we just do from jax.examples import resnet50 and replace all of this with resnet50.ResNet50(..)?

Collaborator Author:
Done.
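The residual block quoted in the diff, stax.serial(FanOut(2), stax.parallel(Main, Identity), FanInSum, Relu), computes relu(main(x) + x). A minimal sketch of that dataflow follows, using hypothetical plain-Python stand-ins for the stax combinators (apply functions only, no parameter initialization, inputs as lists of floats); it is not stax itself:

```python
from typing import Callable, List, Sequence

Vec = List[float]


def serial(*layers: Callable) -> Callable:
    """Apply layers left to right (stand-in for stax.serial)."""
    def apply(x):
        for layer in layers:
            x = layer(x)
        return x
    return apply


def fan_out(n: int) -> Callable:
    """Copy the input n times (stand-in for FanOut(n))."""
    return lambda x: tuple(x for _ in range(n))


def parallel(*layers: Callable) -> Callable:
    """Apply the i-th layer to the i-th input (stand-in for stax.parallel)."""
    return lambda xs: tuple(f(x) for f, x in zip(layers, xs))


def fan_in_sum(xs: Sequence[Vec]) -> Vec:
    """Elementwise sum of all inputs (stand-in for FanInSum)."""
    return [sum(vals) for vals in zip(*xs)]


def identity(x: Vec) -> Vec:
    return x


def relu(x: Vec) -> Vec:
    return [max(v, 0.0) for v in x]


def residual(main: Callable) -> Callable:
    """Same wiring as the diff: relu(main(x) + x)."""
    return serial(fan_out(2), parallel(main, identity), fan_in_sum, relu)


block = residual(lambda x: [2 * v - 3 for v in x])
print(block([1.0, 2.0]))  # [0.0, 3.0]
```

FanOut(2) plus parallel(Main, Identity) is what creates the skip connection: one copy of the input goes through the main branch while the other bypasses it, and FanInSum merges them before the final ReLU.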

gnecula marked this pull request as ready for review May 27, 2020 16:58
gnecula force-pushed the jax_tf branch 2 times, most recently from 6cd80fe to 0ba6063, May 29, 2020 04:38
gnecula added 2 commits May 29, 2020 09:51
Renamed jax2tf.convert to jax_to_tf.
Added Travis test support.
Added OSS build configuration.
gnecula merged commit 8e0a012 into jax-ml:master May 29, 2020
gnecula deleted the jax_tf branch May 29, 2020 06:56
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jun 11, 2020
* Initial import of jax2tf into JAX core

Renamed jax2tf.convert to jax_to_tf.
Added Travis test support.
Added OSS build configuration.

* Added support for squeeze
5 participants