
static_argnames of jax.jit does not correctly infer argnums  #10618

@JeppeKlitgaard

Description

This is most easily shown by calling _infer_argnums_and_argnames (the private helper jax.jit uses to fill in argnums from static_argnames) directly:

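# Private helper; its module path may differ between JAX versions.
from jax._src.api import _infer_argnums_and_argnames

def f(a, /, b, *, c):
    ...

static_argnames = ("a", "b", "c")
result = _infer_argnums_and_argnames(f, None, static_argnames)
# Actual:   ((1,), ("a", "b", "c"))
# Expected: ((0, 1, 2), ("a", "b", "c"))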
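The ((1,), ...) result is consistent with the helper only matching names against positional-or-keyword parameters. The sketch below illustrates that hypothesis using only the standard library: it prints each parameter's kind for f and re-creates the filtering in a hypothetical infer_argnums_sketch function. This is an assumption about the behaviour inferred from the output above, not JAX's actual source.

```python
import inspect


def f(a, /, b, *, c):
    ...


def infer_argnums_sketch(fun, argnames):
    """Hypothetical re-creation (not JAX's code): map names to positional
    indices, keeping only POSITIONAL_OR_KEYWORD parameters."""
    params = inspect.signature(fun).parameters
    return tuple(
        i
        for i, (name, p) in enumerate(params.items())
        if p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and name in argnames
    )


# Show the kind of each parameter of f.
for name, p in inspect.signature(f).parameters.items():
    print(name, p.kind.name)
# a POSITIONAL_ONLY
# b POSITIONAL_OR_KEYWORD
# c KEYWORD_ONLY

# Only b survives the kind filter, matching the reported ((1,), ...).
print(infer_argnums_sketch(f, ("a", "b", "c")))  # (1,)
```

Under this hypothesis, a (positional-only) and c (keyword-only) are silently dropped from the inferred argnums, which is the mismatch reported here.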

Labels: bug (Something isn't working)
