
while_loop vmap error #489

@mattjj


This one was reported by @fehiepsi in another thread. Quoting the repro from there:

from jax import lax, random, vmap

# generate a random number in the interval [0, 0.5] by rejection sampling
def f(key):
    def body_fn(uk):
        key = uk[1]
        u = random.uniform(key, ())
        key, _ = random.split(key)
        return u, key

    u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, (1., key))
    # u = random.uniform(key, ())  # drawing directly, without the while_loop, is fine
    return u

print(f(random.PRNGKey(0)))  # no error
print(vmap(f)(random.split(random.PRNGKey(0), 2)))  # TypeError: 'NoneType' object cannot be interpreted as an integer
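On a JAX version where this is fixed, one way to check the vmapped `while_loop` is to compare its output against a plain Python loop over the same keys. The sketch below is a variant of the repro, not code from the JAX repo: it splits the key before each draw instead of reusing it, and starts the carry at an explicit `float32` so the loop's carry types match exactly. Treating "batched equals looped" as the pass criterion is an assumption about what "fixed" should mean here.

```python
import jax.numpy as jnp
from jax import lax, random, vmap


def f(key):
    """Rejection-sample a uniform draw until it lands in [0, 0.5]."""
    def body_fn(uk):
        _, key = uk
        # Split first so the draw and the next iteration use different keys.
        key, subkey = random.split(key)
        u = random.uniform(subkey, (), dtype=jnp.float32)
        return u, key

    init = (jnp.float32(1.0), key)  # 1.0 > 0.5, so the body runs at least once
    u, _ = lax.while_loop(lambda uk: uk[0] > 0.5, body_fn, init)
    return u


keys = random.split(random.PRNGKey(0), 2)
batched = vmap(f)(keys)                    # one while_loop with a batched predicate
looped = jnp.stack([f(k) for k in keys])   # reference: one call per key

print(batched, looped)
```

Under `vmap`, a batched predicate means the loop keeps running until every element's predicate is false, with already-finished elements held fixed, which is why the batched and looped results are expected to agree.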

Labels: bug
