Minimal repro:

```python
from functools import partial

from jax import jit


@partial(jit, static_argnums=(3,))
def f(a, b, c, d):
    if d > 0:
        return a + b - c
    else:
        return d
```

Calling `f(a=1, b=2, c=3, d=4)` fails with

```
TypeError: Jitted function has static_argnums=(3,) but was called with only 0 positional arguments.
```

but `f(1, 2, 3, 4)` works.
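The error message suggests that `static_argnums` is matched only against positional arguments. When all four arguments are passed by keyword, nothing sits at positional index 3. The sketch below shows one way to make the index lookup independent of how the caller passes arguments: bind the call to the function's signature first, then resolve the static indices by parameter name. It is plain Python and does not use JAX. `split_static_args` is a hypothetical helper, not JAX's implementation, and it assumes the wrapped function has only ordinary positional-or-keyword parameters (no `*args`/`**kwargs`).

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_static_args(
    fn: Callable, static_argnums: Tuple[int, ...], *args: Any, **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: split a call into static and dynamic arguments.

    Binding to the signature maps both positional and keyword arguments to
    parameter names, so an index in static_argnums resolves to the same
    parameter regardless of how the caller passed it.
    """
    sig = inspect.signature(fn)
    bound = sig.bind(*args, **kwargs)  # raises TypeError on a bad call
    bound.apply_defaults()
    # Assumes only positional-or-keyword parameters, so index i == i-th name.
    names = list(sig.parameters)
    static = {names[i]: bound.arguments[names[i]] for i in static_argnums}
    dynamic = {k: v for k, v in bound.arguments.items() if k not in static}
    return static, dynamic


# Plain-Python stand-in for the function in the repro (not jitted).
def f(a, b, c, d):
    return a + b - c if d > 0 else d


if __name__ == "__main__":
    # All three call styles resolve index 3 to the parameter `d`.
    print(split_static_args(f, (3,), 1, 2, 3, 4))
    print(split_static_args(f, (3,), a=1, b=2, c=3, d=4))
    print(split_static_args(f, (3,), 1, 2, c=3, d=4))
```

With this normalization, the positional call and the all-keyword call produce the same static/dynamic split, which is the behavior the repro expects from `f(a=1, b=2, c=3, d=4)`.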