
Conversation

@mtthss (Contributor) commented Nov 1, 2019

Introducing optix: a composable gradient processing and optimization library.

Its objective is to support the composition of arbitrary sets of gradient transformations,
including sequential transformations (e.g. clip, then rescale by RMS) and parallel transformations (where multiple distinct optimizers share a subset of the variables to optimize).

Many popular optimizers can be implemented as one-liners and, for convenience,
we additionally provide aliases for the most common ones.
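The sequential-composition idea can be sketched in plain Python (hypothetical helper names, flat lists standing in for parameter trees; this illustrates the pattern, not the actual optix API):

```python
# A minimal sketch of composable gradient transformations, assuming each
# transformation is a function from gradients to gradients.

def clip(max_delta):
    """Elementwise clipping of gradients to [-max_delta, max_delta]."""
    def transform(grads):
        return [max(-max_delta, min(max_delta, g)) for g in grads]
    return transform

def scale(factor):
    """Rescale all gradients by a fixed factor (e.g. -learning_rate)."""
    def transform(grads):
        return [factor * g for g in grads]
    return transform

def chain(*transforms):
    """Compose transformations sequentially: each one's output feeds the next."""
    def transform(grads):
        for t in transforms:
            grads = t(grads)
        return grads
    return transform

# A "one-liner" optimizer: clip, then take a gradient-descent step.
sgd_with_clipping = chain(clip(1.0), scale(-0.1))
updates = sgd_with_clipping([10.0, -0.5])  # -> [-0.1, 0.05]
```

Because each transformation shares the same interface, swapping the clipping for any other preprocessing step, or reordering the chain, is a one-line change.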

@googlebot (Collaborator)

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@mtthss (Contributor, Author) commented Nov 1, 2019

@googlebot I signed it
(I am a Google employee, so I registered as such.)

@googlebot (Collaborator)

CLAs look good, thanks!


@googlebot added the cla: yes label and removed the cla: no label Nov 1, 2019
### Utilities for building and using custom optimizers. ###


def chainer(*args):
Contributor:

I feel like this is maybe better as `chain`?

Contributor (Author):
Done


def update_fn(updates, state):
  f = lambda g, t: g + decay * t
  update_trace = tree_multimap(f, updates, state.trace)
Contributor:
Almost every optimizer `update_fn` (all of them except `clip_by_global_norm`) performs `tree_multimap` over an inner per-parameter update. Is that a pattern that you can extract out, like `@optimizer` in optimizers.py, or is there a reason why that doesn't work?

Contributor (Author):

I purposefully diverged from optimizers.py in this respect, for two reasons:

  1. Several gradient transformations of interest need to consider the entire gradient and cannot be computed variable by variable. The `clip_by_global_norm` transformation included here is one example, but others could be PopArt, KFac, etc.
  2. Only a handful of characters are saved, at the cost of introducing an additional level of indirection and a reduction in flexibility.
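The distinction in point 1 can be illustrated with a plain-Python sketch (flat lists standing in for gradient pytrees; the helper names are illustrative, and the global-norm clipping mirrors the transformation discussed above rather than reproducing its exact code):

```python
import math

def per_variable_decay(grads, traces, decay=0.9):
    # A per-variable rule: each output element depends only on the
    # corresponding (gradient, trace) pair, so a simple map suffices.
    return [g + decay * t for g, t in zip(grads, traces)]

def clip_by_global_norm(grads, max_norm=1.0):
    # A whole-gradient rule: the scaling factor depends on the norm of
    # *all* variables together, so no independent per-variable map can
    # express it -- the reduction couples every leaf of the tree.
    global_norm = math.sqrt(sum(g * g for g in grads))
    factor = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    return [g * factor for g in grads]

clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)  # norm 5, scaled by 1/5
# clipped is approximately [0.6, 0.8]
```

A decorator that fixes the per-variable map as the outer structure would make transformations like the second one impossible to express without escaping the abstraction.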

return tree_multimap(lambda p, u: p + u, params, updates)


### Aliases for popular optimizers. ###
Contributor:
Are there one or two (maybe less popular) optimizers that you can add as examples of the benefits of composability? For instance, perhaps nadam from https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ or LARS can be expressed using `chainer` and the existing primitives?

Contributor (Author):
As an example I added noisy_sgd, from the paper https://arxiv.org/abs/1511.06807.

Crucially, though, the composable nature of optix allows the user to build, also in a one-liner, a noisy_adam or noisy_rmsprop, or more generally to combine the idea from the paper with any optimizer of their choice.

By comparison, in TF or jax/experimental/optimizers.py the user could only add noise to the gradient before applying the adam/rmsprop rescaling, because adam/rmsprop would immediately apply the updates, and the user could not insert themselves before the update step without rewriting the entire optimizer.

In optix, instead, the user may experiment with adding the noise before or after the adam/rmsprop rescaling (and I suspect the latter, very complicated without optix, will actually work better).
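The ordering experiment described above can be sketched in plain Python (hypothetical helper names, flat lists in place of pytrees; the rescaling is a stateless stand-in, not the real rmsprop or optix implementation):

```python
import random

def chain(*transforms):
    # Sequential composition: each transformation's output feeds the next.
    def transform(grads):
        for t in transforms:
            grads = t(grads)
        return grads
    return transform

def add_noise(stddev, rng):
    # Gradient noise in the spirit of https://arxiv.org/abs/1511.06807.
    def transform(grads):
        return [g + rng.gauss(0.0, stddev) for g in grads]
    return transform

def scale_by_rms(eps=1e-8):
    # Stateless stand-in for an rmsprop-style rescaling (a real version
    # would track a moving average of squared gradients in its state).
    def transform(grads):
        return [g / (abs(g) + eps) for g in grads]
    return transform

rng = random.Random(0)
# Noise injected *before* the rescaling...
noisy_then_rescale = chain(add_noise(0.01, rng), scale_by_rms())
# ...or *after* it: a one-line reordering, with no optimizer rewrite.
rescale_then_noisy = chain(scale_by_rms(), add_noise(0.01, rng))
```

Swapping the two orderings is just a matter of reordering the arguments to `chain`, which is exactly the experiment that a monolithic optimizer makes hard.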

@jekbradbury jekbradbury requested a review from mattjj November 1, 2019 21:39
@jekbradbury jekbradbury merged commit e4d4e4e into jax-ml:master Nov 7, 2019
@mattjj mattjj mentioned this pull request Mar 10, 2020