
Tracker: decomposing scan (aka "Five Loop") #10982

@sharadmv

Description

With jaxprs now supporting effect types, we can express side effects such as those modeled by the State monad, in which references can be read from and written to (i.e. mutation). We can use state to implement scan in terms of a simpler control-flow primitive: a for primitive whose body reads from and writes to references.
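As a rough illustration of the idea (not the for primitive itself), the sketch below expresses scan's semantics as a plain loop: each iteration reads xs[i], updates a carry, and writes ys[i] into a preallocated output buffer. It uses the public lax.fori_loop, and the "writes" are functional .at[i].set updates threaded through the loop, so the references have effectively already been discharged into loop-carried values. The helper name scan_via_loop is hypothetical, and the sketch assumes f maps a single-array carry and a single-array x to a (carry, y) pair of arrays.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_via_loop(f, init, xs):
    """Hypothetical sketch: lax.scan semantics written as a for loop.

    The carry and the stacked output buffer play the role of refs. Each
    iteration reads xs[i], updates the carry, and writes ys[i]. The writes
    are functional .at[i].set updates threaded through fori_loop, i.e. the
    state has already been discharged into loop-carried values.
    Assumes f(carry, x) -> (carry, y) with single arrays (no pytrees).
    """
    n = xs.shape[0]
    # Trace f abstractly to learn the shape/dtype of one per-step output.
    _, y_aval = jax.eval_shape(f, init, xs[0])
    ys = jnp.zeros((n,) + y_aval.shape, y_aval.dtype)

    def body(i, state):
        carry, ys = state
        carry, y = f(carry, xs[i])  # "read" xs[i]
        ys = ys.at[i].set(y)        # "write" ys[i]
        return carry, ys

    return lax.fori_loop(0, n, body, (init, ys))


# Usage: a running sum.
f = lambda c, x: (c + x, c + x)
carry, ys = scan_via_loop(f, jnp.float32(0.0), jnp.arange(5.0))
# carry -> 10.0, ys -> [0., 1., 3., 6., 10.]
```

For this running sum, the result matches lax.scan(f, init, xs) on the same inputs; the real for primitive instead keeps the buffers as refs inside the loop body and only discharges them later.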

This issue will track the implementation progress:

  • get/swap/addupdate primitives
    • impl rules
    • abstract_eval rules
    • jvp rules
    • transpose rules
    • vmap rules
  • Discharging state
    • Basic implementation
    • Handling higher-order primitives
  • for primitive
    • impl rule
    • abstract_eval rule
    • MLIR lowering
    • jvp rule
    • partial_eval
      • basic implementation
      • optimizations
        • loop invariant
        • make the loop index a ref and use reads/writes to determine which values are loop-invariant
        • residual passthrough
        • rematerializing loop-dependent values
    • transpose
    • vmap rule
    • partial_eval_custom rule
    • Miscellaneous
      • Handling closed-over refs
      • Nested for loops
      • Unrolling
    • Reimplement scan in terms of for
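To make the get/swap/addupdate and discharge items more concrete, here is one way to model what discharging state means: each operation on a ref becomes a pure function of the ref's current array value that returns a result together with the updated value. The discharged_* helpers and the two example programs are hypothetical illustrations built on jax.numpy's .at[] updates, not JAX's internal state API. The second example uses jax.grad to show that the cotangents of two reads from the same slot are summed into that slot, which is exactly what an addupdate does; this is consistent with get transposing to addupdate, but the snippet only demonstrates it on the already-discharged program.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers (not JAX's internal API): the functional meaning of
# the three state primitives once a ref is discharged into a plain array.
# Each takes the ref's current value and returns (result, new_value).


def discharged_get(ref_val, idx):
    # get reads ref[idx]; the ref itself is unchanged.
    return ref_val[idx], ref_val


def discharged_swap(ref_val, idx, val):
    # swap writes val into ref[idx] and returns the value it replaced.
    return ref_val[idx], ref_val.at[idx].set(val)


def discharged_addupdate(ref_val, idx, val):
    # addupdate accumulates val into ref[idx] and produces no result.
    return None, ref_val.at[idx].add(val)


def program(x):
    # Effectful pseudo-program on a ref x_ref:
    #   a = x_ref[1]; old = swap(x_ref, 0, a); x_ref[2] += old
    a, x = discharged_get(x, 1)
    old, x = discharged_swap(x, 0, a)
    _, x = discharged_addupdate(x, 2, old)
    return x


out = program(jnp.array([1.0, 2.0, 3.0]))
# out -> [2., 2., 4.]


def read_twice(x):
    # Two reads of the same slot; only the values are used.
    a, _ = discharged_get(x, 1)
    b, _ = discharged_get(x, 1)
    return 3.0 * a + b


g = jax.grad(read_twice)(jnp.array([1.0, 2.0, 3.0]))
# Both read cotangents (3.0 and 1.0) are accumulated into slot 1:
# g -> [0., 4., 0.]
```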

The "raw" version of the for primitive can be found here. Next steps are to port that code to JAX core and add tests.

Metadata

Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests