
xla in pmap fails (i.e. jit-of-pmap or lax.scan with collectives) #804

@jheek

Description

The parallel XLA interpreter (pxla) currently doesn't properly support nested jit compilation.
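For illustration, the simplest form of this nesting is a jitted function that calls a collective and is then mapped with pmap. The sketch below assumes a recent JAX release, where this pattern is supported. normalize is a hypothetical name, and the mapped axis is sized with jax.local_device_count() so the example also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def normalize(x):
    # psum refers to the axis name "i", which is bound by the enclosing pmap,
    # not by jit itself: the jitted computation is nested inside the pmap.
    return x / lax.psum(x, axis_name="i")


n = jax.local_device_count()
# One scalar per device: 1.0, 2.0, ..., n.
out = jax.pmap(normalize, axis_name="i")(jnp.arange(1, n + 1, dtype=jnp.float32))
# Each entry is k / (1 + 2 + ... + n), so the entries sum to 1 up to rounding.
print(out)
```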
A practical example of this issue is using psum inside the body of lax.scan:

pmap(partial(lax.scan, lambda c, x: (c + lax.psum(x, "i"), c), 0.), axis_name="i")(np.ones((8, 4)))
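A self-contained version of the reproduction is sketched below. It assumes a current JAX release (where psum inside a scan body under pmap is supported) and replaces the hardcoded 8 devices of np.ones((8, 4)) with jax.local_device_count(), so it runs on whatever devices are available. body is a hypothetical name for the original lambda, and the initial carry is written as jnp.float32(0.0) instead of 0. to keep the carry dtype fixed across iterations.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def body(c, x):
    # New carry adds the cross-device sum of x; the old carry is emitted as output.
    return c + lax.psum(x, axis_name="i"), c


n = jax.local_device_count()
xs = jnp.ones((n, 4))  # one row of 4 scan steps per device

scan_fn = partial(lax.scan, body, jnp.float32(0.0))
carry, ys = jax.pmap(scan_fn, axis_name="i")(xs)
# Each step adds psum(1.0) == n, so every device ends with carry == 4 * n
# and each row of ys is [0, n, 2n, 3n].
print(carry, ys)
```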

Scan needs to compile the body of the loop with XLA, which fails because psum is only defined in the context of the pxla interpreter.
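To see where the collective ends up, one can print the jaxpr of the pmapped scan: lax.scan traces its body into a separate sub-jaxpr, so the psum sits inside the scan body rather than directly in the pmapped function. This is a sketch assuming a recent JAX release; body and scan_rows are hypothetical names, and the exact printed form of the jaxpr varies between versions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(c, x):
    return c + lax.psum(x, axis_name="i"), c


def scan_rows(xs):
    # lax.scan traces `body` into its own sub-jaxpr, so the psum is nested
    # inside the scan rather than appearing at the top level of the pmap.
    return lax.scan(body, jnp.float32(0.0), xs)


n = jax.local_device_count()
jaxpr_text = str(jax.make_jaxpr(jax.pmap(scan_rows, axis_name="i"))(jnp.ones((n, 4))))
print(jaxpr_text)
```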

Labels: bug (Something isn't working)