
[Proposal] Consistent argnums and argnames parameters for transformations #10614

@JeppeKlitgaard


Hey JAX team,

I have been trying to wrap my head around 'argument annotation' in JAX for a while, hoping to find a more intuitive and consistent implementation, which has led me to the big block of text below. I would be super keen to hear your thoughts as I dive deeper into the inner workings of JAX.

Lately there have been a number of issues requesting improvements to the *_argnums and *_argnames parameters used in transformations, along with other ergonomic improvements to how one declares which function arguments should be annotated with a given property. I figured it might help to open an overarching issue with the end goal of a consistent, ergonomic way of specifying these parameters. Managing argument 'annotations' in transformations has definitely been one of the more frustrating parts of learning JAX (which is otherwise entirely amazing, of course).

Related issues:

jax.jit correctly implements static_argnames even for keyword-only arguments, which suggests that it should be possible to add argnames equivalents to every function that currently only implements argnums.

An easier but less robust fix could be to map argnames to argnums using inspect (see discussion: #1159). This would likely not work for keyword-only arguments, since they have no position (though it might work for things like donate_arg...?).
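
To make the inspect-based idea concrete, here is a minimal sketch, assuming a transformation that already accepts an `argnums` keyword (as jax.grad does). The names `argnames_to_argnums` and `add_argnames` are hypothetical helpers for illustration, not part of JAX. It also shows where the approach breaks: a keyword-only parameter has no index, so it cannot be translated.

```python
import inspect
from functools import wraps

# Parameter kinds that have a position and can therefore be named by argnums.
_POSITIONAL = (inspect.Parameter.POSITIONAL_ONLY,
               inspect.Parameter.POSITIONAL_OR_KEYWORD)


def argnames_to_argnums(fun, argnames):
    """Translate parameter names into positional indices of `fun`."""
    params = inspect.signature(fun).parameters.values()
    # Positional parameters always come first in a signature, so their order
    # in this filtered list equals their positional index.
    positional = [p.name for p in params if p.kind in _POSITIONAL]
    missing = [name for name in argnames if name not in positional]
    if missing:
        # Keyword-only (or misspelled) parameters have no position, so they
        # cannot be expressed as argnums at all.
        raise ValueError(f"cannot map {missing} to positional indices")
    return tuple(positional.index(name) for name in argnames)


def add_argnames(transform):
    """Hypothetical wrapper giving an argnums-only transform an `argnames` option."""
    @wraps(transform)
    def wrapped(fun, *args, argnames=None, **kwargs):
        if argnames is not None:
            if "argnums" in kwargs:
                raise TypeError("pass either argnums or argnames, not both")
            kwargs["argnums"] = argnames_to_argnums(fun, argnames)
        return transform(fun, *args, **kwargs)
    return wrapped
```

For example, `add_argnames(jax.grad)(loss, argnames=("w",))` would differentiate a `loss(x, w, *, scale)` with respect to `w` (returning a one-element tuple of gradients, since argnums becomes a tuple), while `argnames=("scale",)` raises, which is exactly the keyword-only limitation described above.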

Current shortcomings

Currently even the most robust implementation of the 'argument annotation' mechanism behaves in a somewhat counter-intuitive way (although this is hinted at in the fine print of the docstring, if one reads it carefully):

from jax import jit

def f(a, /, b, *, c):
    print(a, b, c)

jf = jit(f, static_argnames=("a", "b", "c"))
jf(1, 2, c=3)
> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 2 3
# Expected: 1 2 3

jf2 = jit(f, static_argnames=("b", "c"), static_argnums=(0,))
jf2(1, 2, c=3)
> 1 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> 3
# Expected: 1 2 3

jf2(1, b=2, c=3)
> 1 2 3
# As expected

jf3 = jit(f, static_argnums=(0, 1, 2))
jf3(1, 2, c=3)
> 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3

jf3(1, b=2, c=3)
> 1 2 Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
# Expected: 1 2 3

The fact that one of these cases gives the expected result suggests that a solution should be possible by inspecting the function signature and arguments and adjusting static_argnums and static_argnames accordingly. Or perhaps a better solution exists? Ideally we would avoid inspecting the arguments at call time.
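
One way to avoid call-time inspection is to complete both parameters once, at wrap time, from the signature alone. The sketch below assumes non-negative indices, and `resolve_static` is a hypothetical helper, not JAX API: every static positional-or-keyword parameter gets both its index and its name, positional-only parameters get only an index, and keyword-only parameters get only a name.

```python
import inspect

P = inspect.Parameter


def resolve_static(fun, static_argnums=(), static_argnames=()):
    """Complete static_argnums/static_argnames from fun's signature.

    Runs once at wrap time, so no per-call inspection of arguments is needed.
    Assumes non-negative static_argnums.
    """
    params = list(inspect.signature(fun).parameters.values())
    positional = [p for p in params
                  if p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD)]
    # Names a caller could actually pass as keywords.
    keywordable = {p.name for p in params
                   if p.kind in (P.POSITIONAL_OR_KEYWORD, P.KEYWORD_ONLY)}
    nums, names = set(static_argnums), set(static_argnames)
    for i, p in enumerate(positional):
        # A static parameter that may arrive positionally needs its index...
        if p.name in names:
            nums.add(i)
        # ...and one that may arrive by keyword needs its name.
        if i in nums and p.kind is P.POSITIONAL_OR_KEYWORD:
            names.add(p.name)
    # Drop names that can never be passed by keyword (positional-only or unknown).
    names &= keywordable
    return tuple(sorted(nums)), tuple(sorted(names))
```

For `f(a, /, b, *, c)`, both `static_argnames=("a", "b", "c")` and `static_argnums=(0,), static_argnames=("b", "c")` resolve to `((0, 1), ("b", "c"))`. Passing those as `jit(f, static_argnums=nums, static_argnames=names)` should make all three arguments static for both `jf(1, 2, c=3)` and `jf(1, b=2, c=3)`. The `jf3` case is deliberately left alone: index 2 has no positional counterpart because `c` is keyword-only, and guessing a mapping there would be another surprise of the kind described above.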

I have started experimenting with validation of static_argnums and static_argnames in #10603.
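
As an illustration of what such validation might flag (a sketch under stated assumptions, not the code in #10603), the hypothetical `check_static_args` below reports the two surprising cases from the example above: a positional-only name in static_argnames (`jf`) and an index with no positional counterpart (`jf3`).

```python
import inspect

P = inspect.Parameter


def check_static_args(fun, static_argnums=(), static_argnames=()):
    """Return human-readable problems with a static_argnums/static_argnames pair."""
    params = list(inspect.signature(fun).parameters.values())
    kinds = {p.name: p.kind for p in params}
    n_positional = sum(p.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD)
                       for p in params)
    has_var_positional = any(p.kind is P.VAR_POSITIONAL for p in params)
    has_var_keyword = any(p.kind is P.VAR_KEYWORD for p in params)
    problems = []
    for i in static_argnums:
        # With *args any index can be filled; otherwise it must hit a parameter.
        if not has_var_positional and not -n_positional <= i < n_positional:
            problems.append(f"static_argnums entry {i} does not match any of "
                            f"the {n_positional} positional parameters")
    for name in static_argnames:
        kind = kinds.get(name)
        if kind is None and not has_var_keyword:
            problems.append(f"static_argnames entry {name!r} is not a parameter")
        elif kind is P.POSITIONAL_ONLY:
            # A keyword argument can never fill a positional-only parameter.
            problems.append(f"static_argnames entry {name!r} is positional-only, "
                            "so no keyword argument can fill it")
    return problems
```

Returning a list rather than raising keeps the choice between a warning and an error open, which seems to be part of what this issue is asking to decide.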

Goals

My suggestion would be to find a solution for jax.jit that fixes the inconsistencies above (or, in the worst case, documents them thoroughly).

Once that is done, it would be great to see *_argnames and keyword-argument support added to other functions:

  • jax.experimental.pjit
  • jax.pmap
  • jax.value_and_grad
  • jax.custom_vjp
  • jax.custom_jvp
  • jax.hessian
  • jax.jacrev
  • jax.jacfwd
  • jax.grad

Additionally, #10476 could be explored (it could live in jax.experimental.annotations, if there is any interest in this feature at all).

Progress

  • Get feedback and decide on the following (this issue):
    • Interface (potential changes in function signatures for argument annotations)
    • Behaviour
  • Document interface and behaviour (initial PR: [WIP] Document argument annotations #10677)
  • Write tests and ensure consistency across functions

Metadata

Assignees

No one assigned

    Labels

    enhancement (New feature or request)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
