
[User model][tracker] Improve compilation of PPO model (Stable-baselines3) #93697

@msaroufim

Description


Setup

pip install stable-baselines3[extra]

Repro

from stable_baselines3 import PPO
import torchdynamo
import time

# Compile the entire PPO training run with TorchDynamo + Inductor.
@torchdynamo.optimize("inductor")
def train():
    model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)

# Time the compiled training run.
tic = time.time()
train()
toc = time.time()
print(toc - tic)
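
One way to narrow this down (not part of the original report, just a sketch that assumes SB3's ActorCriticPolicy forward signature) is to compile only the policy's forward pass instead of the whole .learn() loop, so graph breaks coming from the model itself can be separated from the ones triggered by the training-loop machinery:

import torch
import torchdynamo
from stable_baselines3 import PPO

# Build the model eagerly; only the policy forward pass is compiled below.
model = PPO("MlpPolicy", "CartPole-v1")

@torchdynamo.optimize("inductor")
def policy_forward(obs):
    # SB3's ActorCriticPolicy.forward returns (actions, values, log_probs).
    return model.policy(obs)

# CartPole-v1 observations are 4-dimensional; use a batch of one for the check.
obs = torch.zeros(1, 4)
print(policy_forward(obs))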

Running the repro currently fails with lots of assertion errors.
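
To make the logs below more actionable, the full traceback of the first failure can be captured with just the standard library (train is the repro function above; treat this as an illustrative sketch rather than part of the original report):

import traceback

try:
    train()
except Exception:
    # Dump the complete stack trace of the first failure so the failing
    # Dynamo frame can be pasted into the logs below. Depending on the
    # installed version, setting the TORCHDYNAMO_VERBOSE=1 environment
    # variable may also produce a more detailed internal trace.
    traceback.print_exc()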

Logs

Current state of things (feel free to edit as we fix things):
#93697 (comment)

cc @ezyang @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @mruberry @rgommers @wconstab @zou3519 @aakhundov @soumith @ngimel


Labels

dynamo-must-fix, dynamo-user-empathy-day, module: dynamo, oncall: pt2, triaged
