Skip to content

Commit d69e949

Browse files
authored
[mlir] [linalg] Fix bufferize error in tensor.parallel_insert_slice op (#98312)
The tensor.parallel_insert_slice op has implicit in-place behavior. In the "copy-before-write" bufferization mode, the resolveConflicts function generates a bufferize.copy, making the result incorrect. This patch fixes the issue.
1 parent 5c205b6 commit d69e949

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
387387
if (failed(resultMemrefType))
388388
return failure();
389389
Value subView = rewriter.create<memref::SubViewOp>(
390-
loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
391-
mixedSizes, mixedStrides);
390+
loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
391+
mixedOffsets, mixedSizes, mixedStrides);
392392

393393
replaceOpWithBufferizedValues(rewriter, op, subView);
394394
return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
407407
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
408408
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
409409
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
410-
extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
411-
mixedOffsets, mixedSizes, mixedStrides));
410+
extractSliceOp.getType().getShape(),
411+
llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
412+
mixedStrides));
412413
}
413414
};
414415

@@ -997,6 +998,13 @@ struct ParallelInsertSliceOpInterface
997998
rewriter.eraseOp(op);
998999
return success();
9991000
}
1001+
1002+
/// The tensor.parallel_insert_slice op has implicit in-place behavior. We
1003+
/// shouldn't create a copy to resolve conflicts.
1004+
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
1005+
const AnalysisState &state) const {
1006+
return success();
1007+
}
10001008
};
10011009

10021010
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,3 +626,23 @@ func.func @tensor.splat_dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf
626626
return %0 : tensor<?x3x?xf32>
627627
}
628628

629+
// -----
630+
631+
// CHECK-LABEL: func.func @parallel_insert_slice_copy_before_write
632+
func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: tensor<4xf32>) {
633+
%c1 = arith.constant 1 : index
634+
%num_threads = arith.constant 4 : index
635+
636+
// CHECK: scf.forall {{.*}} {
637+
%result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<4xf32> {
638+
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<4xf32> to tensor<1xf32>
639+
scf.forall.in_parallel {
640+
// CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
641+
// CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
642+
tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
643+
tensor<1xf32> into tensor<4xf32>
644+
}
645+
}
646+
// CHECK: }
647+
return
648+
}

0 commit comments

Comments
 (0)