
Commit ba32b9c

Don't fold aten.clone if result isn't same type as input (#3347)
Similar to #2824, we were seeing assertion failures after the checks around folders were tightened up in LLVM: llvm/llvm-project#75887. This PR essentially moves the logic that used to be applied at the LLVM level into the folder, which seems to be the suggested fix.
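The case that must no longer fold is a clone whose declared result type differs from its input type; replacing the result with the operand would change the value's type in the IR, which the tightened folder checks assert on. A minimal illustration, adapted from the regression test added below:

%none = torch.constant.none
// The operand is a ranked value tensor, but the declared result is an
// unranked !torch.vtensor, so folding %0 to %arg0 would change the type.
%0 = torch.aten.clone %arg0, %none : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor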
1 parent 5928f68 commit ba32b9c

2 files changed (+13, -1 lines)

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 2 additions & 1 deletion
@@ -2581,7 +2581,8 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
 
 OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) {
   // note: memory_format would be ignored
-  if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
+  if (getSelf().getType() == getResult().getType() &&
+      llvm::dyn_cast<ValueTensorType>(getSelf().getType())) {
     // self should have value semantics
     return getSelf();
   }

test/Dialect/Torch/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
@@ -3015,3 +3015,14 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor
   %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
   return %result0 : !torch.vtensor<[10,64,56,56],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.clone$no_fold(
+func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
+  // CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
+  %none = torch.constant.none
+  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
+  %1 = torch.copy.to_tensor %0 : !torch.tensor
+  return %1 : !torch.tensor
+}
