Status: Closed
Labels: module: onnx, onnx-needs-info, triaged
Description
🐛 Describe the bug
Hi,
I was exporting https://github.com/yunjey/stargan to ONNX. The model uses Instance Normalization layers with track_running_stats=True, so each layer keeps four tensors: the running mean and running variance (stored as buffers) and the affine weight and bias (learnable parameters).
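To make the four tensors concrete, here is a minimal NumPy sketch of what an InstanceNorm2d layer with affine=True and track_running_stats=True computes. It is a hypothetical re-implementation for illustration (the class name InstanceNorm2dSketch is made up), assuming PyTorch's defaults eps=1e-5 and momentum=0.1. The running-statistics update approximates PyTorch's behavior and is not a copy of its source.

```python
import numpy as np

EPS = 1e-5      # assumed default eps of nn.InstanceNorm2d
MOMENTUM = 0.1  # assumed default momentum of nn.InstanceNorm2d


class InstanceNorm2dSketch:
    """Hypothetical NumPy stand-in for InstanceNorm2d(affine=True, track_running_stats=True)."""

    def __init__(self, num_features):
        # learnable affine parameters
        self.weight = np.ones(num_features)
        self.bias = np.zeros(num_features)
        # running statistics (buffers), updated only in training mode
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.training = True

    def __call__(self, x):
        # x has shape (N, C, H, W)
        if self.training:
            # statistics of each instance and channel: shape (N, C)
            mean = x.mean(axis=(2, 3))
            var = x.var(axis=(2, 3))  # biased variance is used for normalizing
            n = x.shape[2] * x.shape[3]
            unbiased = var * n / (n - 1)
            # the buffers track the batch average of the per-instance statistics
            self.running_mean = (1 - MOMENTUM) * self.running_mean + MOMENTUM * mean.mean(axis=0)
            self.running_var = (1 - MOMENTUM) * self.running_var + MOMENTUM * unbiased.mean(axis=0)
        else:
            # eval mode: every instance is normalized with the tracked statistics
            mean = np.broadcast_to(self.running_mean, x.shape[:2])
            var = np.broadcast_to(self.running_var, x.shape[:2])
        x_hat = (x - mean[:, :, None, None]) / np.sqrt(var[:, :, None, None] + EPS)
        return self.weight[None, :, None, None] * x_hat + self.bias[None, :, None, None]


# usage, mirroring the reproduction below
rng = np.random.default_rng(0)
layer = InstanceNorm2dSketch(4)
x = rng.standard_normal((1, 4, 8, 8))
layer(x)                # training-mode pass updates running_mean / running_var
layer.training = False  # analogous to norm.eval()
y_eval = layer(x)
```

The point of the sketch is that, after eval(), the output depends on running_mean and running_var, so an exported graph that omits them cannot reproduce it in general.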
Here is a minimal reproduction:
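import torch
import torch.nn as nn

norm = nn.InstanceNorm2d(64, affine=True, track_running_stats=True)
input = torch.randn(1, 64, 128, 128)
norm(input)   # a training-mode forward pass updates running_mean / running_var
norm.eval()   # eval mode normalizes with the running statistics, not per-input ones

torch.onnx.export(norm,                          # model being run
                  torch.rand(1, 64, 128, 128),  # example input used for tracing
                  "./norm.onnx")

with torch.no_grad():
    torchout = norm(input)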
However, the exported ONNX model keeps only the weight and bias; the running mean and running variance are not included. Can this be fixed?
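One possible workaround, sketched here as an assumption rather than an official fix: in eval mode the layer is just a per-channel affine map, y = scale * x + shift with scale = weight / sqrt(running_var + eps) and shift = bias - running_mean * scale. Folding the four tensors into scale and shift means the model could express the layer as a plain per-channel multiply and add, which should keep the running statistics in the exported graph as constants. The NumPy code below only checks the algebra; the function names are hypothetical.

```python
import numpy as np


def _per_channel(v):
    # reshape a (C,) vector so it broadcasts over an (N, C, H, W) tensor
    return v.reshape(1, -1, 1, 1)


def eval_instance_norm(x, weight, bias, running_mean, running_var, eps=1e-5):
    # reference: normalize with the running buffers, then apply the affine parameters
    return (_per_channel(weight) * (x - _per_channel(running_mean))
            / np.sqrt(_per_channel(running_var) + eps) + _per_channel(bias))


def fold_running_stats(weight, bias, running_mean, running_var, eps=1e-5):
    # collapse the four tensors into a single per-channel scale and shift
    scale = weight / np.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift


def folded_instance_norm(x, scale, shift):
    # the same eval-mode computation as an element-wise multiply-add
    return x * _per_channel(scale) + _per_channel(shift)


rng = np.random.default_rng(0)
C = 64
w, b = rng.standard_normal(C), rng.standard_normal(C)
rm, rv = rng.standard_normal(C), rng.uniform(0.5, 2.0, C)
x = rng.standard_normal((1, C, 16, 16))
scale, shift = fold_running_stats(w, b, rm, rv)
print(np.allclose(eval_instance_norm(x, w, b, rm, rv), folded_instance_norm(x, scale, shift)))  # True
```

This only holds in eval mode; in training mode the statistics come from the input, so the folded form does not apply.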
I also tested the exported layer with ONNX Runtime, and both the absolute and relative errors against the PyTorch output are large:
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession("./norm.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    # detach tensors that require grad before converting them to NumPy
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torchout), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
And this is the ONNX Runtime result for the instance normalization layer.
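A likely explanation for the large error, offered as a hypothesis: the ONNX InstanceNormalization operator takes only input, scale and B (plus an epsilon attribute) and computes the mean and variance per instance and channel from the input itself, so it has no place for running statistics. The NumPy sketch below contrasts that with the eval-mode PyTorch computation and reports the maximum absolute and relative differences, similar to what assert_allclose prints. All function names here are hypothetical.

```python
import numpy as np


def _per_channel(v):
    # reshape a (C,) vector so it broadcasts over an (N, C, H, W) tensor
    return v.reshape(1, -1, 1, 1)


def input_stats_norm(x, weight, bias, eps=1e-5):
    # ONNX-style InstanceNormalization: statistics of this input, per instance and channel
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return _per_channel(weight) * (x - mean) / np.sqrt(var + eps) + _per_channel(bias)


def running_stats_norm(x, weight, bias, running_mean, running_var, eps=1e-5):
    # PyTorch after norm.eval() with track_running_stats=True: use the stored buffers
    return (_per_channel(weight) * (x - _per_channel(running_mean))
            / np.sqrt(_per_channel(running_var) + eps) + _per_channel(bias))


def error_report(expected, actual):
    # maximum absolute and relative differences between two arrays
    abs_err = np.abs(expected - actual)
    rel_err = abs_err / np.maximum(np.abs(expected), np.finfo(expected.dtype).tiny)
    return abs_err.max(), rel_err.max()


rng = np.random.default_rng(0)
C = 8
x = rng.standard_normal((1, C, 16, 16)) * 3.0 + 2.0  # input whose stats differ from the buffers
w, b = np.ones(C), np.zeros(C)
expected = running_stats_norm(x, w, b, np.zeros(C), np.ones(C))
actual = input_stats_norm(x, w, b)
max_abs, max_rel = error_report(expected, actual)
print(f"max abs error: {max_abs:.3f}, max rel error: {max_rel:.3f}")
```

The two paths agree only when the running statistics happen to match the statistics of the input being normalized, which is not the case for a trained model in general.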
Versions
Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.8.1+cpu
[pip3] torchaudio==0.8.1
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.9.1+cpu
[pip3] torchviz==0.0.1
[conda] Could not collect