Skip to content

Commit e865860

Browse files
committed
[mlir] [linalg] Fix bufferize error in tensor.parallel_insert_slice op
The tensor.parallel_insert_slice op has implicit in-place behavior. In the "copy-before-write" bufferization mode, the resolveConflicts function generates a bufferize.copy, making the result incorrect. This patch fixes the issue.
1 parent 17316a5 commit e865860

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,11 @@ struct ParallelInsertSliceOpInterface
997997
rewriter.eraseOp(op);
998998
return success();
999999
}
1000+
1001+
/// Conflict-resolution hook for this op, implemented as a no-op: it returns
/// success() unconditionally and does not touch `op`, `rewriter`, or `state`.
///
/// Rationale (per the commit message accompanying this change): the op has
/// implicit in-place behavior, and in "copy-before-write" bufferization mode
/// conflict resolution would otherwise generate a copy, producing an
/// incorrect result.
/// NOTE(review): presumably this shadows a default implementation provided by
/// the interface's external model — confirm against BufferizableOpInterface.
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
1002+
const AnalysisState &state) const {
1003+
return success();
1004+
}
10001005
};
10011006

10021007
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled

mlir/test/Dialect/Tensor/bufferize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,3 +626,23 @@ func.func @tensor.splat_dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf
626626
return %0 : tensor<?x3x?xf32>
627627
}
628628

629+
// -----
630+
631+
// CHECK-LABEL: func.func @parallel_insert_slice_copy_before_write
632+
// Test input: an scf.forall over 4 threads where each thread extracts a
// 1-element slice of %in and inserts it into the shared output %o via
// tensor.parallel_insert_slice. The CHECK lines expect an scf.forall
// containing two memref.subview ops of type
// memref<4xf32> -> memref<1xf32, strided<[1], offset: ?>>.
// NOTE(review): %c1 appears to be unused in this function — confirm.
func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: tensor<4xf32>) {
633+
%c1 = arith.constant 1 : index
634+
%num_threads = arith.constant 4 : index
635+
636+
// CHECK: scf.forall {{.*}} {
637+
%result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<4xf32> {
638+
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<4xf32> to tensor<1xf32>
639+
scf.forall.in_parallel {
640+
// CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
641+
// CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
642+
tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
643+
tensor<1xf32> into tensor<4xf32>
644+
}
645+
}
646+
// CHECK: }
647+
return
648+
}

0 commit comments

Comments
 (0)