
kwargs don't work with jit when static_argnums is set #1159

@zhongwen

Description

Minimal repro:

from functools import partial
from jax import jit

@partial(jit, static_argnums=(3,))
def f(a, b, c, d):
  if d > 0:
    return a + b - c
  else:
    return d

Calling f(a=1, b=2, c=3, d=4) raises TypeError: Jitted function has static_argnums=(3,) but was called with only 0 positional arguments.

Calling f(1, 2, 3, 4), with the same values passed positionally, works.
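The error message suggests that static_argnums is checked against the positional *args only, so a call that passes every argument by keyword has zero positional arguments and index 3 does not exist. One way a jit-style wrapper could avoid this is to bind the call against the function's signature first, which moves keyword arguments for positional-or-keyword parameters into their positional slots. The sketch below is plain Python under that assumption, not JAX's actual implementation; normalize_to_positional and last_static_key are hypothetical names used only for illustration.

```python
import inspect
from functools import wraps


def normalize_to_positional(fun, static_argnums):
    """Hypothetical helper: bind keyword arguments to positional slots
    before reading the static arguments by index."""
    sig = inspect.signature(fun)

    @wraps(fun)
    def wrapper(*args, **kwargs):
        # Map the call onto fun's parameters. Keyword arguments for
        # positional-or-keyword parameters end up in bound.args.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        pos_args = bound.args
        if any(i >= len(pos_args) for i in static_argnums):
            raise TypeError(
                f"static_argnums={static_argnums} but only "
                f"{len(pos_args)} positional arguments after binding")
        # A real jit would use these values as part of its compile-cache key;
        # here they are only recorded so the behavior can be inspected.
        wrapper.last_static_key = tuple(pos_args[i] for i in static_argnums)
        return fun(*pos_args, **bound.kwargs)

    wrapper.last_static_key = None
    return wrapper


def f(a, b, c, d):
    if d > 0:
        return a + b - c
    else:
        return d


g = normalize_to_positional(f, static_argnums=(3,))
print(g(a=1, b=2, c=3, d=4))  # 0: all-keyword call now finds d at index 3
print(g(1, 2, 3, 4))          # 0: positional call behaves the same
print(g.last_static_key)      # (4,)
```

Later JAX releases also accept a static_argnames parameter on jax.jit (for example jit(f, static_argnames=("d",))), which marks an argument as static when it is passed by keyword.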
