Open
Labels: enhancement (New feature or request)
Description
With jaxprs now supporting effect types, we can express side effects such as those of the State monad, where references can be read from and written to (i.e. mutation). We can use state to implement a simpler `scan` control-flow primitive via a `for` primitive that supports reads and writes.
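To make the idea concrete without relying on JAX internals, here is a minimal sketch. It assumes a hypothetical `Ref` class and `for_loop` helper (neither is JAX's actual API). The loop body only reads and writes refs through `get`/`swap`/`addupdate`, and `for_loop` "discharges" that state by threading the ref contents through the carry of the public `lax.fori_loop`.

```python
import jax
import jax.numpy as jnp
from jax import lax


class Ref:
    """Hypothetical mutable reference for this sketch (not JAX's internal type).

    Reads and writes update a Python-side slot while the loop body is traced;
    `for_loop` below discharges that state by threading it through a carry.
    """

    def __init__(self, value):
        self.value = value

    def get(self, idx=None):
        # Read the whole ref, or the element/slice at `idx`.
        return self.value if idx is None else self.value[idx]

    def swap(self, new, idx=None):
        # Write `new` (at `idx` if given) and return the previous contents.
        old = self.get(idx)
        self.value = new if idx is None else self.value.at[idx].set(new)
        return old

    def addupdate(self, delta, idx=None):
        # Accumulate `delta` into the ref (at `idx` if given); returns nothing.
        if idx is None:
            self.value = self.value + delta
        else:
            self.value = self.value.at[idx].add(delta)


def for_loop(n, body, init_vals):
    """Hypothetical helper: run body(i, refs) for i in range(n), return final ref values."""

    def step(i, vals):
        refs = tuple(Ref(v) for v in vals)   # fresh refs over the current state
        body(i, refs)                        # the body mutates the refs
        return tuple(r.value for r in refs)  # discharged: updated state out

    return lax.fori_loop(0, n, step, tuple(init_vals))


def cumsum(xs):
    """Cumulative sum written with reads/writes instead of an explicit carry."""

    def body(i, refs):
        xs_ref, out_ref, acc_ref = refs
        acc_ref.addupdate(xs_ref.get(i))  # acc += xs[i]
        out_ref.swap(acc_ref.get(), i)    # out[i] = acc

    init = (xs, jnp.zeros_like(xs), jnp.zeros((), xs.dtype))
    _, out, _ = for_loop(xs.shape[0], body, init)
    return out
```

For example, `cumsum(jnp.array([1.0, 2.0, 3.0]))` returns `[1., 3., 6.]`, also under `jax.jit`. Because the mutation here happens on Python objects at trace time, the sketch only works while refs stay inside the body; the approach tracked in this issue instead represents refs and their effects in the jaxpr itself.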
This issue will track the implementation progress:
- `get`/`swap`/`addupdate` primitives
  - `impl` rules
  - `abstract_eval` rules
  - `jvp` rules
  - `transpose` rules
  - `vmap` rules
- Discharging state
  - Basic implementation
  - Handling higher-order primitives
- `for` primitive
  - `impl` rule
  - `abstract_eval` rule
  - MLIR lowering
  - `jvp` rule
  - `partial_eval`
    - Basic implementation
    - Optimizations
      - Loop invariant
      - Make the loop index a ref and use reads/writes to determine which values are loop-invariant
      - Residual passthrough
      - Rematerializing loop-dependent values
  - `transpose` rule
  - `vmap` rule
  - `partial_eval_custom` rule
- Miscellaneous
  - Handling closed-over refs
  - Nested `for` loops
  - Unrolling
  - Reimplement `scan` in terms of `for`
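The `transpose` rules for the state primitives pair up naturally. A minimal sketch using the public `jax.linear_transpose` on hypothetical "discharged" versions of `get` and `addupdate` (plain functions that take the ref's value explicitly, with an assumed fixed index `IDX`) shows that transposing a read yields an accumulation into the cotangent, and transposing an accumulation yields a read. This illustrates the math only; it is not JAX's internal rule code.

```python
import jax
import jax.numpy as jnp

IDX = 2  # hypothetical fixed index used by the discharged primitives below


def get_discharged(x):
    # get: read x[IDX] (the ref's value is passed in explicitly)
    return x[IDX]


def addupdate_discharged(x, delta):
    # addupdate: return the ref's new value after x[IDX] += delta
    return x.at[IDX].add(delta)


x = jnp.arange(5.0)

# Transposing a read: the output cotangent is accumulated into position IDX
# of a zero cotangent for x, i.e. an addupdate.
(ct_x,) = jax.linear_transpose(get_discharged, x)(jnp.float32(1.0))

# Transposing an accumulation: x's cotangent passes straight through, and
# delta's cotangent is read out of position IDX, i.e. a get.
ct_x2, ct_delta = jax.linear_transpose(addupdate_discharged, x, jnp.float32(0.0))(
    jnp.ones(5)
)
```

Here `ct_x` is `[0., 0., 1., 0., 0.]`, `ct_x2` is all ones and `ct_delta` is `1.0`.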
The "raw" version of `for` can be found here. Next steps involve porting that code to JAX core and adding tests.