[MLIR][Tensor] Enhance bufferization of tensor.expand_shape op #128871
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Arnab Dutta (arnab-polymage)

Changes: Instead of inferring the output shape argument of memref.expand_shape op, use the output_shape argument of tensor.expand_shape op by adding dynamic dimension support for bufferization of tensor.expand_shape when there is more than one dynamic dim within a reassociation set.

Full diff: https://github.com/llvm/llvm-project/pull/128871.diff

2 Files Affected:
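For context, a minimal before/after sketch of the lowering this patch changes, adapted from the bufferize.mlir test updated below (function and value names are illustrative):

// Source IR: the dynamic output extent %sz0 is already an operand of
// tensor.expand_shape.
func.func @example(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
  %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
      : tensor<?x10xf32> into tensor<2x?x10xf32>
  return %0 : tensor<2x?x10xf32>
}

// Bufferized before this patch (dynamic extent re-derived from the source
// buffer):
//   %dim = memref.dim %m1, %c0 : memref<?x10xf32>
//   %sz  = arith.divsi %dim, %c2 : index
//   memref.expand_shape %m1 [[0, 1], [2]] output_shape [2, %sz, 10]
//       : memref<?x10xf32> into memref<2x?x10xf32>
//
// Bufferized after this patch (%sz0 forwarded directly):
//   memref.expand_shape %m1 [[0, 1], [2]] output_shape [2, %sz0, 10]
//       : memref<?x10xf32> into memref<2x?x10xf32>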
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..efbe09f4d2419 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -337,14 +337,27 @@ struct ExpandShapeOpInterface
if (failed(buffer))
return failure();
- // Memref result type is inferred by the builder based on reassociation
- // indices and result shape.
- // TODO: Instead of inferring the output shape argument of
- // memref.expand_shape op, use output_shape argument of tensor.expand_shape
- // op.
- replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, tensorResultType.getShape(), *buffer,
- expandShapeOp.getReassociationIndices());
+ // Use the output_shape argument of the tensor.expand_shape op to get the
+ // result shape of the memref.expand_shape op to be created.
+ SmallVector<OpFoldResult> outShape;
+ unsigned dynDimCount = 0;
+ for (unsigned i = 0, e = tensorResultType.getRank(); i < e; i++) {
+ if (tensorResultType.isDynamicDim(i))
+ outShape.push_back(expandShapeOp.getOutputShape()[dynDimCount++]);
+ }
+ auto memrefExpandShape = rewriter.create<memref::ExpandShapeOp>(
+ op->getLoc(), tensorResultType.getShape(), *buffer,
+ expandShapeOp.getReassociationIndices(), outShape);
+ SmallVector<int64_t> staticShape;
+ for (unsigned i = 0, e = tensorResultType.getRank(); i < e; i++) {
+ if (tensorResultType.isDynamicDim(i))
+ staticShape.push_back(ShapedType::kDynamic);
+ else
+ staticShape.push_back(tensorResultType.getDimSize(i));
+ }
+ memrefExpandShape.setStaticOutputShape(staticShape);
+ replaceOpWithBufferizedValues(rewriter, op,
+ memrefExpandShape->getResults());
return success();
}
};
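The case that motivates the change, exercised by the new test added below, is a reassociation group containing more than one dynamic dimension; a sketch (names illustrative):

// Group [0, 1, 2] expands one dynamic source dimension into three dynamic
// result dimensions. The source extent equals %sz0 * %sz1 * %sz2, so the
// old divide-by-the-static-product inference cannot recover the individual
// sizes; the explicit output_shape operands are required.
func.func @multi_dynamic(%t: tensor<?x256xf32>, %sz0: index, %sz1: index,
                         %sz2: index) -> tensor<?x?x?x256xf32> {
  %0 = tensor.expand_shape %t [[0, 1, 2], [3]]
      output_shape [%sz0, %sz1, %sz2, 256]
      : tensor<?x256xf32> into tensor<?x?x?x256xf32>
  return %0 : tensor<?x?x?x256xf32>
}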
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 9ea0a15f31185..c1beed95f2006 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
// -----
// CHECK-LABEL: func @tensor.expand_shape(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
%0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
: tensor<?x10xf32> into tensor<2x?x10xf32>
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
func.func @tensor.expand_shape_of_slice(
%t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
%0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
tensor<?x20xf32> to tensor<?x10xf32>
- // CHECK: %[[C7:.*]] = arith.constant 7 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
%1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
tensor<?x10xf32> into tensor<?x7x2x5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<?x7x2x5xf32>
}
-
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
}
+// -----
+// CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+ %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+ : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<?x?x?x256xf32>
+}
// -----
// CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
// CHECK: }
return
}
+
+// -----
+
@ramiro050 @bondhugula please review
What is dd support?

I read it as dynamic dimensions :)
The branch was force-pushed from c710f83 to d9b2521 with the message:

Instead of inferring the output shape argument of memref.expand_shape op, use the output_shape argument of tensor.expand_shape op by adding dynamic dimension support for bufferization of tensor.expand_shape when there is more than one dynamic dim within a reassociation set.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/52/builds/6432
[MLIR][Tensor] Enhance bufferization of tensor.expand_shape op (#128871)

Instead of inferring the output shape argument of memref.expand_shape op, use the output_shape argument of tensor.expand_shape op by adding dynamic dimension support for bufferization of tensor.expand_shape when there is more than one dynamic dim within a reassociation set.