
MPS topk failure for 5D tensor or above #154890

@manuelcandales

πŸ•΅οΈβ€β™‚οΈ Detected with FACTO

πŸ› Describe the bug

Similar to #154881

import torch

# 5D input on MPS: topk along dim 0 aborts the whole process instead of raising a RuntimeError
x = torch.ones([5, 4, 3, 2, 1], device="mps")
torch.ops.aten.topk(x, k=5, dim=0)

Error message:

/AppleInternal/Library/BuildRoots/01adf19d-fba1-11ef-a947-f2a857e00a32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArraySort.mm:208: failed assertion `(null)" Axis = 4. This class only supports axis = 0, 1, 2, 3
'
zsh: abort
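The assertion says that MPSNDArraySort only accepts axes 0 through 3. Until this is fixed, one possible user-side workaround is to collapse the tensor to 2D before calling topk and then restore its shape afterward. This is a minimal sketch, assuming the crash depends only on the rank of the tensor handed to the MPS sort kernel (the title suggests ranks up to 4 are fine). The helper name `topk_rank_safe` is hypothetical, not a PyTorch API, and the sketch is not a fix for the underlying kernel.

```python
import torch


def topk_rank_safe(x: torch.Tensor, k: int, dim: int = -1, largest: bool = True):
    """Hypothetical helper: topk on a tensor of any rank (>= 1) that only ever
    passes a 2D tensor to torch.topk, sidestepping the 4-axis MPS sort limit."""
    dim = dim % x.dim()  # normalize negative dims
    # Move the target dim to the end, then flatten every other dim into one batch dim.
    moved = x.movedim(dim, -1)
    batch_shape = moved.shape[:-1]
    flat = moved.reshape(-1, moved.shape[-1])
    values, indices = torch.topk(flat, k, dim=-1, largest=largest, sorted=True)
    # Undo the flattening and move the topk dim back to its original position.
    values = values.reshape(*batch_shape, k).movedim(-1, dim)
    indices = indices.reshape(*batch_shape, k).movedim(-1, dim)
    return values, indices


if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    x = torch.ones([5, 4, 3, 2, 1], device=device)
    values, indices = topk_rank_safe(x, k=5, dim=0)
    print(values.shape)  # torch.Size([5, 4, 3, 2, 1])
```

Because topk treats each slice along `dim` independently, flattening the other dimensions does not change which elements are selected. The extra `reshape` may copy a non-contiguous tensor, which costs memory but keeps the helper backend-agnostic.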

Versions

torch==2.7.0

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen


Labels

module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError)
module: mps (Related to Apple Metal Performance Shaders framework)
module: third_party
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
