Add optix to experimental #1620
Conversation
optix is a composable gradient processing and optimization library
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). 📝 Please visit https://cla.developers.google.com/ to sign. Once you've signed (or fixed any issues), please reply here.
@googlebot I signed it
CLAs look good, thanks!
jax/experimental/optix.py
### Utilities for building and using custom optimizers. ###


def chainer(*args):
I feel like this is maybe better as chain?
Done
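For readers of this thread, here is a minimal sketch of what a chain-style combinator does, assuming each transformation is an (init_fn, update_fn) pair whose update_fn has the (updates, state) signature quoted below; this is an illustration, not the PR's actual code:

```python
def chain(*transforms):
  """Composes several (init_fn, update_fn) transformations, applied in order."""
  init_fns, update_fns = zip(*transforms)

  def init_fn(params):
    # Each transformation keeps its own slice of the optimizer state.
    return [init(params) for init in init_fns]

  def update_fn(updates, state):
    new_state = []
    for update, s in zip(update_fns, state):
      # Each update_fn consumes the updates produced by the previous one.
      updates, s = update(updates, s)
      new_state.append(s)
    return updates, new_state

  return init_fn, update_fn
```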
def update_fn(updates, state):
  f = lambda g, t: g + decay * t
  update_trace = tree_multimap(f, updates, state.trace)
Almost every optimizer update_fn (all of them except clip_by_global_norm) performs tree_multimap over an inner per-parameter update. Is that a pattern that you can extract out like @optimizer in optimizers.py, or is there a reason why that doesn't work?
I purposefully diverged from optimizers.py in this aspect for two reasons:
- There are several gradient transformations of interest that need to consider the entire gradient and cannot be computed variable by variable. The clip_by_global_norm transformation included here is one example (see the sketch after this list), but others could be PopArt, K-FAC, etc.
- It really is just a handful of characters that are saved, and at the cost of introducing an additional level of indirection and a reduction in flexibility.
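To make the first point concrete, here is a rough sketch of a global-norm clip (simplified: the init_fn/state plumbing is omitted, and this is not the PR's actual implementation):

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map

def clip_by_global_norm_sketch(max_norm):
  def update_fn(updates, state):
    # The norm is computed across *all* leaves of the gradient tree, so this
    # transformation cannot be expressed as an independent per-parameter rule.
    g_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in tree_leaves(updates)))
    factor = jnp.minimum(1.0, max_norm / g_norm)
    return tree_map(lambda g: g * factor, updates), state
  return update_fn
```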
return tree_multimap(lambda p, u: p + u, params, updates)


### Aliases for popular optimizers. ###
Are there one or two (maybe less popular) optimizers that you can add as examples of the benefits of composability? For instance, perhaps nadam from https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ or LARS can be expressed using chainer and the existing primitives?
As an example I added a noisy_sgd, following the paper https://arxiv.org/abs/1511.06807. Crucially though, the composable nature of optix allows the user to build, in a one-liner, a noisy_adam or noisy_rmsprop as well, or more generally to combine the idea from the paper with any optimizer of their choice.
By comparison, in TF or jax/experimental/optimizers.py the user could only add noise to the gradient before applying the adam/rmsprop rescaling, because adam/rmsprop would immediately apply the updates and the user could not insert a step before the update without rewriting the entire optimizer. In optix, instead, the user may experiment with adding the noise before or after the adam/rmsprop rescaling (and I suspect the latter, which is very complicated without optix, will actually work better).
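To illustrate the one-liner mentioned above, a hypothetical noisy_adam could look roughly like this (scale_by_adam, add_noise and scale are assumed optix-style transformations; the exact names and signatures may differ from what lands in this PR):

```python
from jax.experimental import optix

def noisy_adam(learning_rate, eta=0.01, gamma=0.55, seed=0):
  # Noise is injected *after* the Adam rescaling, the ordering described above
  # as hard to express without rewriting the optimizer in optimizers.py.
  return optix.chain(
      optix.scale_by_adam(),
      optix.add_noise(eta, gamma, seed),
      optix.scale(-learning_rate))
```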
Introducing optix: a composable gradient processing and optimization library.
Its objective is to support the composition of arbitrary sets of gradient transformations, including sequential transformations (e.g. clip, then rescale by rms) and parallel transformations (where multiple distinct optimizers share a subset of the variables to optimize).
Many popular optimizers can be implemented as one-liners, and, for convenience, we additionally provide aliases for the most common ones.
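A usage sketch of the workflow this description implies, assuming chain returns an (init_fn, update_fn) pair and that transformations such as clip_by_global_norm, scale_by_rms, scale and apply_updates are available as discussed above (details may not match the final code):

```python
import jax
import jax.numpy as jnp
from jax.experimental import optix

# A toy quadratic objective over a parameter pytree.
params = {'w': jnp.ones(3), 'b': jnp.zeros(())}
loss_fn = lambda p: jnp.sum(p['w'] ** 2) + p['b'] ** 2

# Sequential composition: clip, rescale by rms, then scale by -learning_rate.
opt_init, opt_update = optix.chain(
    optix.clip_by_global_norm(1.0),
    optix.scale_by_rms(),
    optix.scale(-1e-3))
opt_state = opt_init(params)

for _ in range(10):
  grads = jax.grad(loss_fn)(params)
  updates, opt_state = opt_update(grads, opt_state)
  params = optix.apply_updates(params, updates)
```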