[mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support #97297
base: main
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR with a comment. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord, or on the forums.
@llvm/pr-subscribers-mlir-linalg

Author: xiaohui1.xu (BRUCE11111)

Changes: Patch is 26.87 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/97297.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..e0fd5f1b14070 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3388,8 +3388,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
- target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+ tensor::BitcastOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp,
+ tensor::ConcatOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a75d2ac08157..7a4db82749fd1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
return success();
}
+/// Vectorize a `tensor::expandshape` to these 3 Ops:
+/// Vector::TransferReadOp - Reads a vector from the source tensor
+/// ShapeCastOp - Reshape the data based on the target.
+/// vector::TransferWriteOp. - Write the result vector back to the destination
+/// tensor
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+ Operation *inputOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(inputOp);
+ auto src = inputOp->getOperand(0);
+ auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+ auto result = inputOp->getResults()[0];
+ auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ Location loc = inputOp->getLoc();
+
+ llvm::SmallVector<int64_t> srcVectorizedShape;
+ llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
+ auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+ ArrayRef<int64_t> &inputShape) {
+ bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+ int64_t cur = 1, resultIdx = 0;
+ for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+ cur *= ss;
+ if (!isResultShapeBigger) {
+ // collapse
+ srcVectorizedShape.emplace_back(ss);
+ if (cur == retShape[resultIdx]) {
+ if (shapeScales.count(resultIdx)) {
+ srcVectorizedShape.back() *= shapeScales[resultIdx];
+ }
+ cur = 1;
+ resultIdx++;
+ }
+ } else {
+ // expand
+ if (cur == retShape[resultIdx]) {
+ srcVectorizedShape.emplace_back(cur);
+ if (shapeScales.count(srcIdx)) {
+ srcVectorizedShape.back() *= shapeScales[srcIdx];
+ }
+ cur = 1;
+ resultIdx++;
+ }
+ }
+ }
+ };
+ if (!inputVectorSizes.empty()) {
+ for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
+ if (vs != resultShape[idx])
+ shapeScales[idx] = vs / resultShape[idx];
+ }
+
+ bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+ if (!isResultShapeBigger) {
+ getVectorizeShape(resultShape, srcShape);
+ } else {
+ getVectorizeShape(srcShape, resultShape);
+ }
+ } else {
+ srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
+ }
+ // read
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(srcType.getElementType()));
+ Value readResult = vector::createReadOrMaskedRead(
+ rewriter, loc, src,
+ inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
+ padValue, false);
+
+ auto shapeCastType =
+ VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
+ resultType.getElementType());
+ vector::ShapeCastOp shapeCastOp =
+ rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);
+
+ // write
+ SmallVector<OpFoldResult> destSizes;
+ for (auto size : resultShape) {
+ destSizes.emplace_back(rewriter.getIndexAttr(size));
+ }
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, shapeCastOp->getResults()[0], destSizes,
+ inputVectorSizes.empty() ? resultShape : inputVectorSizes, false);
+ newResults.push_back(write->getResult(0));
+ return success();
+}
+
+/// Vectorize a `tensor::bitcast` to these 3 Ops:
+/// vector::TransferReadOp - Reads a vector from the source tensor
+/// vector.Bitcast - Bitcast the data based on the target.
+/// vector::TransferWriteOp. - Write the result vector back to the destination
+/// tensor
+static LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter,
+ tensor::BitcastOp bitCastOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(bitCastOp);
+
+ auto sourceType = bitCastOp.getSource().getType();
+ auto resultType = bitCastOp.getResult().getType();
+ auto resultShape = resultType.getShape();
+ if (inputVectorSizes.empty()) {
+ inputVectorSizes = resultShape;
+ }
+ Location loc = bitCastOp->getLoc();
+
+ // read
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(sourceType.getElementType()));
+ Value readResult = vector::createReadOrMaskedRead(
+ rewriter, loc, bitCastOp.getSource(), inputVectorSizes, padValue, false);
+
+ // bitcast
+ auto resultVectorType =
+ VectorType::get(inputVectorSizes, resultType.getElementType());
+ vector::BitCastOp vectorbitCastOp =
+ rewriter.create<vector::BitCastOp>(loc, resultVectorType, readResult);
+
+ // write
+ llvm::SmallVector<OpFoldResult> destSizes;
+ for (auto size : resultShape)
+ destSizes.emplace_back(rewriter.getIndexAttr(size));
+ auto write =
+ createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0),
+ destSizes, inputVectorSizes, false);
+ newResults.push_back(write->getResults()[0]);
+ return success();
+}
+
+/// Vectorize a `tensor::concat` to these 3 Ops:
+/// Tensor::EmptyOp - The result tensor.
+/// Vector::TransferWriteOp - Write the result vector back to the destination
+/// tensor.
+/// Vector::TransferWriteOp - Write the result vector back to the destination
+/// tensor.
+static LogicalResult lowerTensorConcatOp(RewriterBase &rewriter,
+ tensor::ConcatOp concatOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(concatOp);
+
+ Location loc = concatOp.getLoc();
+ FailureOr<Value> dest =
+ tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+ if (failed(dest))
+ return failure();
+
+ auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+ if (!empty)
+ return failure();
+
+ // Compute the partial sums for the slice offsets.
+ auto dim = concatOp.getDim();
+ Value dimValue =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
+
+ int64_t rank = concatOp.getResultType().getRank();
+ auto srcType =
+ mlir::dyn_cast<RankedTensorType>(concatOp->getResultTypes()[0]);
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(srcType.getElementType()));
+
+ // Construct the chain of insert_slice ops into the destination.
+ Value result = *dest;
+ Value previous_offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {
+
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, input);
+ SmallVector<int64_t> readMaskShape;
+ auto inputType = mlir::dyn_cast<RankedTensorType>(input.getType());
+ auto sourceShape = inputType.getShape();
+
+ readMaskShape.append(sourceShape.begin(), sourceShape.end());
+ Value readResult = vector::createReadOrMaskedRead(
+ rewriter, loc, input, sourceShape, padValue, false);
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(rank, zero);
+ indices[dim] = previous_offset;
+ result = rewriter
+ .create<vector::TransferWriteOp>(
+ loc, readResult, result, indices,
+ rewriter.getMultiDimIdentityMap(rank))
+ ->getResults()[0];
+ if (idx != concatOp.getNumOperands() - 1) {
+ auto dimOp = rewriter.create<tensor::DimOp>(loc, input, dimValue);
+ previous_offset =
+ rewriter.create<arith::AddIOp>(loc, dimOp, previous_offset);
+ }
+ }
+
+ newResults.push_back(result);
+ return success();
+}
+
// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
@@ -1931,6 +2134,108 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}
+static LogicalResult
+lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto resultType = expandOp->getResultTypes()[0];
+ auto resultShape = mlir::dyn_cast<ShapedType>(resultType);
+ // check reassociation
+ llvm::SmallVector<int64_t> associateIndices;
+ for (auto &attr : expandOp.getReassociation()) {
+ for (auto &indice : mlir::dyn_cast<ArrayAttr>(attr)) {
+ associateIndices.push_back(mlir::dyn_cast<IntegerAttr>(indice).getInt());
+ }
+ }
+
+ if (llvm::any_of(associateIndices,
+ [](int64_t x) { return x == ShapedType::kDynamic; })) {
+ LDBG("Reassociation must be static: " << expandOp << "\n");
+ return failure();
+ }
+ // check input and output shape
+ if (!resultShape.hasStaticShape() ||
+ !expandOp.getSrcType().hasStaticShape()) {
+ LDBG("Input and output shape must be static: " << expandOp << "\n");
+ return failure();
+ }
+ if (!inputVectorSizes.empty() &&
+ failed(vector::isValidMaskedInputVector(resultShape.getShape(),
+ inputVectorSizes)))
+ return failure();
+
+ return success();
+}
+
+static LogicalResult
+lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto resultType = bitCastOp->getResultTypes()[0];
+ auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+ auto srcType = bitCastOp.getSource().getType();
+ auto srcShapeType = mlir::dyn_cast<ShapedType>(srcType);
+
+ bool isStaticInputOutput =
+ resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+ if (!isStaticInputOutput) {
+ LDBG("Input and output shape must be static: " << bitCastOp << "\n");
+ return failure();
+ }
+
+ if (!inputVectorSizes.empty() &&
+ failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+ inputVectorSizes)))
+ return failure();
+ return success();
+}
+
+static LogicalResult
+lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto resultType = collapseOp->getResultTypes()[0];
+ auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+ auto srcShapeType = collapseOp.getSrcType();
+
+ bool isStaticInputOutput =
+ resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+ if (!isStaticInputOutput) {
+ LDBG("Input and output shape must be static: " << collapseOp << "\n");
+ return failure();
+ }
+
+ if (!inputVectorSizes.empty() &&
+ failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+ inputVectorSizes)))
+ return failure();
+ return success();
+}
+
+static LogicalResult
+lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ if (!inputVectorSizes.empty()) {
+ LDBG("Concat operation do not support specify inputVectorSizes: "
+ << concatOp << "\n");
+ }
+ for (auto x : concatOp->getOperands()) {
+ auto type = mlir::dyn_cast<ShapedType>(x.getType());
+ if (!type) {
+ LDBG("Operation type error: " << concatOp << "\n");
+ return failure();
+ }
+ if (!type.hasStaticShape()) {
+ LDBG("Type must be static: " << concatOp << "\n");
+ return failure();
+ }
+ }
+ auto dim = concatOp.getDim();
+ if (dim >= (uint64_t)concatOp.getResultType().getRank()) {
+ LDBG("Invalid dim: " << concatOp << "\n");
+ return failure();
+ }
+
+ return success();
+}
+
/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
+ .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+ return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);
+ })
+ .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+ return lowerCollapseShapeOpPrecondition(collapseShapeOp,
+ inputVectorSizes);
+ })
+ .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+ return lowerBitcastOpPrecondition(bitCastOp, inputVectorSizes);
+ })
+ .Case<tensor::ConcatOp>([&](auto concatOp) {
+ return lowerConcatOpPrecondition(concatOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
@@ -2075,6 +2393,22 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes, results);
})
+ .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+ return lowerTensorReshape(rewriter, expandShapeOp, inputVectorSizes,
+ results);
+ })
+ .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+ return lowerTensorReshape(rewriter, collapseShapeOp,
+ inputVectorSizes, results);
+ })
+ .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+ return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes,
+ results);
+ })
+ .Case<tensor::ConcatOp>([&](auto concatOp) {
+ return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes,
+ results);
+ })
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..114815b4e3de8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1055,3 +1055,195 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
transform.yield
}
}
+
+ // -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape
+func.func @test_vectorize_collapseshape(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x32xf32> to vector<64x1024xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+ // CHECK: %[[C01:.*]] = arith.constant 0 : index
+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
+ // CHECK: %[[C512:.*]] = arith.constant 512 : index
+ // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+ // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<64x1024xi1> -> tensor<64x512xf32>
+ // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+ %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+ return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [64, 1024] : !transform.any_op
+ transform.yield
+ }
+}
+
+ // -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape_no_vector_size
+func.func @test_vectorize_collapseshape_no_vector_size(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true, true, true]} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x16xf32> to vector<64x512xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+ // CHECK: %[[C01:.*]] = arith.constant 0 : index
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>
+ // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+ %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+ return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+ // -----
+
+ // CHECK-LABEL: func @test_vectorize_expandshape
+func.func @test_vectorize_expandshape(%source: tensor<64x512xf32>, %dest: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
+ // CHECK: %[[C512:.*]] = arith.constant 512 : index
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<64x1024xi1> -> vector<64x1024xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<64x1024xf32> to vector<8x8x32x32xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<8x8x32x16xf32>
+ // CHECK: %[[C01:.*]]= arith.constant 0 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x3...
[truncated]
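For readers skimming the truncated diff, the following hand-written sketch shows roughly what the new tensor.concat lowering produces: a tensor.empty destination plus one read/write pair per input, offset along the concatenation dimension. The function name, shapes, and static offsets are invented for illustration; the actual pattern computes the offsets with tensor.dim and arith.addi and does not necessarily set in_bounds.

// Illustrative sketch only: vectorizing a two-input tensor.concat along dim 0,
// with hypothetical static shapes.
func.func @concat_lowering_sketch(%a: tensor<2x8xf32>, %b: tensor<3x8xf32>) -> tensor<5x8xf32> {
  %c0 = arith.constant 0 : index
  // Offset of the second input along the concatenated dimension.
  %c2 = arith.constant 2 : index
  %pad = arith.constant 0.0 : f32
  %dest = tensor.empty() : tensor<5x8xf32>
  // Read the first input and write it at offset [0, 0].
  %ra = vector.transfer_read %a[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<2x8xf32>, vector<2x8xf32>
  %w0 = vector.transfer_write %ra, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<2x8xf32>, tensor<5x8xf32>
  // Read the second input and write it at offset [2, 0] into the previous result.
  %rb = vector.transfer_read %b[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<3x8xf32>, vector<3x8xf32>
  %w1 = vector.transfer_write %rb, %w0[%c2, %c0] {in_bounds = [true, true]}
      : vector<3x8xf32>, tensor<5x8xf32>
  return %w1 : tensor<5x8xf32>
}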
@llvm/pr-subscribers-mlir

Author: xiaohui1.xu (BRUCE11111)

Changes: same patch summary and diff as in the @llvm/pr-subscribers-mlir-linalg comment above.
                                        SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(inputOp);
  auto src = inputOp->getOperand(0);
Please expand `auto` unless the type is obvious from line-level context.
  rewriter.setInsertionPoint(inputOp);
  auto src = inputOp->getOperand(0);
  auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
  auto result = inputOp->getResults()[0];
In MLIR, we generally dislike "raw" indexed accessors. Can this function be turned into a function template and use named accessors instead?
  ArrayRef<int64_t> srcShape = srcType.getShape();
  Location loc = inputOp->getLoc();

  llvm::SmallVector<int64_t> srcVectorizedShape;
No need to prefix `SmallVector` with `llvm::`.
  llvm::SmallVector<int64_t> srcVectorizedShape;
  llvm::SmallDenseMap<int64_t, int64_t> shapeScales;

  auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
`ArrayRef` is already a reference, as the name indicates; there is no need to pass it by reference.
    for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
      cur *= ss;
      if (!isResultShapeBigger) {
        // collapse
Please use full sentences in comments, including capitalization and trailing full stops.
  // check reassociation
  llvm::SmallVector<int64_t> associateIndices;
  for (auto &attr : expandOp.getReassociation()) {
    for (auto &indice : mlir::dyn_cast<ArrayAttr>(attr)) {
Nit: the singular of "indices" is "index".
    if (!type) {
      LDBG("Operation type error: " << concatOp << "\n");
      return failure();
    }
Can this ever happen?
lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
                          ArrayRef<int64_t> inputVectorSizes) {
  if (!inputVectorSizes.empty()) {
    LDBG("Concat operation do not support specify inputVectorSizes: "
LDBG("Concat operation do not support specify inputVectorSizes: " | |
LDBG("Concat operation does not support specify inputVectorSizes: " |
@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
      .Case<tensor::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
        return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);
It is a strange choice to call these `lowerFoo` when everything above is called `vectorizeFoo`.
  // CHECK: %[[C32:.*]] = arith.constant 32 : index
  // CHECK: %[[C16:.*]] = arith.constant 16 : index
  // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
  // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>
This doesn't check that we actually read inside the mask. Same below.
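For illustration, one way to tighten this (a sketch only, reusing the capture names from the test above) is to anchor the transfer_read inside the vector.mask region, so the check fails if the read is emitted outside the mask:

// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] { vector.transfer_read {{.*}} } : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>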
I'm not sure it is a good idea to vectorize expand_shape and collapse_shape into vector.shape_cast, because the tensor version does not have data-movement behavior but vector.shape_cast does. (The doc of vector.shape_cast is outdated.) Is there an RFC or discussion about why we are adding vectorization patterns for these ops? In our use cases, we handle all of this at the graph level, so we don't really need this vectorization.

Also, can you add a high-level overview to the PR description? I'm sorry that I can't look at the implementation details because of my own priorities, so I can't advise much right now. However, having an overview of a big change (+529, -2) is generally good for reviewers and others.

On the other hand, concat and bitcast are new to me. Looking at the docs, tensor.bitcast is not equivalent to vector.bitcast. E.g., …
Sorry, I've been travelling and didn't have the time to reply earlier. Basically, I agree with all the points from @hanhanW, thanks! I think that we should understand the motivation a bit better before proceeding with this.
I believe it is this one: … We haven't really agreed on any specific path forward.
Yes, the current work is about this RFC. Motivation: the community discussion has reached a certain degree of consensus on the RFC work, such as vector-to-loops, and this work is necessary, but some detailed designs are still under discussion. For the current PR, we only need to know that, before converting the large virtual vector to …
In the current vector dialect, we only have shape_cast that can align semantically with tensor.expand_shape, so I currently use shape_cast.
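For context, a minimal sketch of the correspondence being relied on here; the shapes and reassociation below are invented for the example:

// The tensor-level reshape handled by the patch, next to the vector op that
// carries the same reshape after vectorization.
func.func @reshape_correspondence(%t: tensor<8x8x16xf32>, %v: vector<8x8x16xf32>)
    -> (tensor<64x16xf32>, vector<64x16xf32>) {
  %collapsed = tensor.collapse_shape %t [[0, 1], [2]]
      : tensor<8x8x16xf32> into tensor<64x16xf32>
  %cast = vector.shape_cast %v : vector<8x8x16xf32> to vector<64x16xf32>
  return %collapsed, %cast : tensor<64x16xf32>, vector<64x16xf32>
}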
Thanks for pointing that out. I noticed the difference between tensor.bitcast and vector.bitcast; I will convert tensor.bitcast → vector.transfer_read + arith.bitcast + vector.transfer_write. In addition, I have a mask-related question for @hanhanW, which is written in another PR, #97248. Sincere thanks~
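A minimal sketch of the revised lowering described here, assuming a static i32 -> f32 case (the function name and shapes are hypothetical):

func.func @bitcast_lowering_sketch(%src: tensor<4x8xi32>) -> tensor<4x8xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i32
  %dest = tensor.empty() : tensor<4x8xf32>
  // Read the source tensor into a vector.
  %read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<4x8xi32>, vector<4x8xi32>
  // Reinterpret the bits element-wise; the element widths match (i32 -> f32).
  %cast = arith.bitcast %read : vector<4x8xi32> to vector<4x8xf32>
  // Write the result back into the destination tensor.
  %write = vector.transfer_write %cast, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x8xf32>, tensor<4x8xf32>
  return %write : tensor<4x8xf32>
}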
I don't believe that that's the case.
Why?
No description provided.