
Commit 3cccb20

[MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (#128871)
Instead of inferring the output shape of the memref.expand_shape op, use the output_shape argument of the tensor.expand_shape op. This adds dynamic-dimension support to the bufferization of tensor.expand_shape for the case where a reassociation set contains more than one dynamic dimension.
1 parent 9f28621 commit 3cccb20

File tree: 2 files changed, +26, -19 lines
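Why the output_shape operands are needed (an illustrative sketch, not taken from the commit): when one reassociation group expands into two or more dynamic dimensions, the expanded extents cannot be inferred by dividing the source extent, since the product of the dynamic factors admits many factorizations. In the hypothetical function below, only the output_shape operands pin down the result shape:

// Hypothetical example: two dynamic dims in reassociation group [0, 1].
// dim 0 of %t equals %sz0 * %sz1, so neither factor is inferable on its own.
func.func @two_dynamic_dims(%t: tensor<?x32xf32>, %sz0: index, %sz1: index)
    -> tensor<?x?x32xf32> {
  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
      : tensor<?x32xf32> into tensor<?x?x32xf32>
  return %0 : tensor<?x?x32xf32>
}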

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (6 additions, 8 deletions)
@@ -337,14 +337,12 @@ struct ExpandShapeOpInterface
     if (failed(buffer))
       return failure();
 
-    // Memref result type is inferred by the builder based on reassociation
-    // indices and result shape.
-    // TODO: Instead of inferring the output shape argument of
-    // memref.expand_shape op, use output_shape argument of tensor.expand_shape
-    // op.
-    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, tensorResultType.getShape(), *buffer,
-        expandShapeOp.getReassociationIndices());
+    auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
+        op->getLoc(), tensorResultType.getShape(), *buffer,
+        expandShapeOp.getReassociationIndices(),
+        expandShapeOp.getMixedOutputShape());
+    replaceOpWithBufferizedValues(rewriter, op,
+                                  memrefExpandShape->getResults());
     return success();
   }
 };
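Sketch of the resulting IR (illustrative, assuming identity layout maps): memref::ExpandShapeOp now receives the mixed static/dynamic output shape via getMixedOutputShape(), so the hypothetical function above bufferizes roughly to the following, with no arith.divsi-based inference:

// Sketch under the stated assumptions: output_shape operands are forwarded
// to memref.expand_shape unchanged.
func.func @two_dynamic_dims(%t: memref<?x32xf32>, %sz0: index, %sz1: index)
    -> memref<?x?x32xf32> {
  %0 = memref.expand_shape %t [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
      : memref<?x32xf32> into memref<?x?x32xf32>
  return %0 : memref<?x?x32xf32>
}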

mlir/test/Dialect/Tensor/bufferize.mlir (20 additions, 11 deletions)
@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
 func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
   %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
       : tensor<?x10xf32> into tensor<2x?x10xf32>
 
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
 func.func @tensor.expand_shape_of_slice(
     %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
       tensor<?x20xf32> to tensor<?x10xf32>
-  // CHECK: %[[C7:.*]] = arith.constant 7 : index
-  // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
   %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
       tensor<?x10xf32> into tensor<?x7x2x5xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
   return %1 : tensor<?x7x2x5xf32>
 }
-
 // -----
 
 // CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
   // CHECK: return %[[r]]
   return %1 : tensor<1xf32>
 }
+// -----
 
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+      : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<?x?x?x256xf32>
+}
 // -----
 
 // CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
   // CHECK: }
   return
 }
+
+// -----
+
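For local verification, bufferize.mlir is driven by a RUN line at its top; a plausible invocation looks like the following sketch (the flags are an assumption, the file's actual RUN line is authoritative):

// Assumed flags; consult the RUN line in bufferize.mlir for the real ones.
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" \
// RUN:   -split-input-file | FileCheck %s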