[mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support #97297


Open
wants to merge 1 commit into main

Conversation

BRUCE11111

No description provided.


github-actions bot commented Jul 1, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Jul 1, 2024

@llvm/pr-subscribers-mlir-linalg

Author: xiaohui1.xu (BRUCE11111)

Changes

Patch is 26.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97297.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+334)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+192)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..e0fd5f1b14070 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3388,8 +3388,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
-            target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+             tensor::BitcastOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp,
+             tensor::ConcatOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a75d2ac08157..7a4db82749fd1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   return success();
 }
 
+/// Vectorize a `tensor.expand_shape` or `tensor.collapse_shape` to these ops:
+///   vector::TransferReadOp - Reads a vector from the source tensor.
+///   vector::ShapeCastOp - Reshapes the data based on the target shape.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination tensor.
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+                                        Operation *inputOp,
+                                        ArrayRef<int64_t> inputVectorSizes,
+                                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(inputOp);
+  auto src = inputOp->getOperand(0);
+  auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+  auto result = inputOp->getResults()[0];
+  auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+  ArrayRef<int64_t> resultShape = resultType.getShape();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  Location loc = inputOp->getLoc();
+
+  llvm::SmallVector<int64_t> srcVectorizedShape;
+  llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
+  auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+                               ArrayRef<int64_t> &inputShape) {
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+    int64_t cur = 1, resultIdx = 0;
+    for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+      cur *= ss;
+      if (!isResultShapeBigger) {
+        // collapse
+        srcVectorizedShape.emplace_back(ss);
+        if (cur == retShape[resultIdx]) {
+          if (shapeScales.count(resultIdx)) {
+            srcVectorizedShape.back() *= shapeScales[resultIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      } else {
+        // expand
+        if (cur == retShape[resultIdx]) {
+          srcVectorizedShape.emplace_back(cur);
+          if (shapeScales.count(srcIdx)) {
+            srcVectorizedShape.back() *= shapeScales[srcIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      }
+    }
+  };
+  if (!inputVectorSizes.empty()) {
+    for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
+      if (vs != resultShape[idx])
+        shapeScales[idx] = vs / resultShape[idx];
+    }
+
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+    if (!isResultShapeBigger) {
+      getVectorizeShape(resultShape, srcShape);
+    } else {
+      getVectorizeShape(srcShape, resultShape);
+    }
+  } else {
+    srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
+  }
+  // read
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(srcType.getElementType()));
+  Value readResult = vector::createReadOrMaskedRead(
+      rewriter, loc, src,
+      inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
+      padValue, false);
+
+  auto shapeCastType =
+      VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
+                      resultType.getElementType());
+  vector::ShapeCastOp shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);
+
+  // write
+  SmallVector<OpFoldResult> destSizes;
+  for (auto size : resultShape) {
+    destSizes.emplace_back(rewriter.getIndexAttr(size));
+  }
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp->getResults()[0], destSizes,
+      inputVectorSizes.empty() ? resultShape : inputVectorSizes, false);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
+/// Vectorize a `tensor.bitcast` to these ops:
+///   vector::TransferReadOp - Reads a vector from the source tensor.
+///   vector::BitCastOp - Bitcasts the data based on the target type.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination tensor.
+static LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter,
+                                          tensor::BitcastOp bitCastOp,
+                                          ArrayRef<int64_t> inputVectorSizes,
+                                          SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(bitCastOp);
+
+  auto sourceType = bitCastOp.getSource().getType();
+  auto resultType = bitCastOp.getResult().getType();
+  auto resultShape = resultType.getShape();
+  if (inputVectorSizes.empty()) {
+    inputVectorSizes = resultShape;
+  }
+  Location loc = bitCastOp->getLoc();
+
+  // read
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(sourceType.getElementType()));
+  Value readResult = vector::createReadOrMaskedRead(
+      rewriter, loc, bitCastOp.getSource(), inputVectorSizes, padValue, false);
+
+  // bitcast
+  auto resultVectorType =
+      VectorType::get(inputVectorSizes, resultType.getElementType());
+  vector::BitCastOp vectorbitCastOp =
+      rewriter.create<vector::BitCastOp>(loc, resultVectorType, readResult);
+
+  // write
+  llvm::SmallVector<OpFoldResult> destSizes;
+  for (auto size : resultShape)
+    destSizes.emplace_back(rewriter.getIndexAttr(size));
+  auto write =
+      createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0),
+                               destSizes, inputVectorSizes, false);
+  newResults.push_back(write->getResults()[0]);
+  return success();
+}
+
+/// Vectorize a `tensor.concat` to these ops:
+///   tensor::EmptyOp - Creates the result tensor.
+///   vector::TransferReadOp - Reads each input tensor into a vector.
+///   vector::TransferWriteOp - Writes each input vector into the
+///   destination tensor at the appropriate offset along the concatenation
+///   dimension.
+static LogicalResult lowerTensorConcatOp(RewriterBase &rewriter,
+                                         tensor::ConcatOp concatOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(concatOp);
+
+  Location loc = concatOp.getLoc();
+  FailureOr<Value> dest =
+      tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+  if (failed(dest))
+    return failure();
+
+  auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+  if (!empty)
+    return failure();
+
+  // Compute the partial sums for the slice offsets.
+  auto dim = concatOp.getDim();
+  Value dimValue =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
+
+  int64_t rank = concatOp.getResultType().getRank();
+  auto srcType =
+      mlir::dyn_cast<RankedTensorType>(concatOp->getResultTypes()[0]);
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(srcType.getElementType()));
+
+  // Construct the chain of insert_slice ops into the destination.
+  Value result = *dest;
+  Value previous_offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {
+
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, input);
+    SmallVector<int64_t> readMaskShape;
+    auto inputType = mlir::dyn_cast<RankedTensorType>(input.getType());
+    auto sourceShape = inputType.getShape();
+
+    readMaskShape.append(sourceShape.begin(), sourceShape.end());
+    Value readResult = vector::createReadOrMaskedRead(
+        rewriter, loc, input, sourceShape, padValue, false);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> indices(rank, zero);
+    indices[dim] = previous_offset;
+    result = rewriter
+                 .create<vector::TransferWriteOp>(
+                     loc, readResult, result, indices,
+                     rewriter.getMultiDimIdentityMap(rank))
+                 ->getResults()[0];
+    if (idx != concatOp.getNumOperands() - 1) {
+      auto dimOp = rewriter.create<tensor::DimOp>(loc, input, dimValue);
+      previous_offset =
+          rewriter.create<arith::AddIOp>(loc, dimOp, previous_offset);
+    }
+  }
+
+  newResults.push_back(result);
+  return success();
+}
+
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
@@ -1931,6 +2134,108 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
+static LogicalResult
+lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp,
+                          ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = expandOp->getResultTypes()[0];
+  auto resultShape = mlir::dyn_cast<ShapedType>(resultType);
+  // check reassociation
+  llvm::SmallVector<int64_t> associateIndices;
+  for (auto &attr : expandOp.getReassociation()) {
+    for (auto &indice : mlir::dyn_cast<ArrayAttr>(attr)) {
+      associateIndices.push_back(mlir::dyn_cast<IntegerAttr>(indice).getInt());
+    }
+  }
+
+  if (llvm::any_of(associateIndices,
+                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
+    LDBG("Reassociation must be static: " << expandOp << "\n");
+    return failure();
+  }
+  // check input and output shape
+  if (!resultShape.hasStaticShape() ||
+      !expandOp.getSrcType().hasStaticShape()) {
+    LDBG("Input and output shape must be static: " << expandOp << "\n");
+    return failure();
+  }
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShape.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+
+  return success();
+}
+
+static LogicalResult
+lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = bitCastOp->getResultTypes()[0];
+  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+  auto srcType = bitCastOp.getSource().getType();
+  auto srcShapeType = mlir::dyn_cast<ShapedType>(srcType);
+
+  bool isStaticInputOutput =
+      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+  if (!isStaticInputOutput) {
+    LDBG("Input and output shape must be static: " << bitCastOp << "\n");
+    return failure();
+  }
+
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+  return success();
+}
+
+static LogicalResult
+lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp,
+                                 ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = collapseOp->getResultTypes()[0];
+  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+  auto srcShapeType = collapseOp.getSrcType();
+
+  bool isStaticInputOutput =
+      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+  if (!isStaticInputOutput) {
+    LDBG("Input and output shape must be static: " << collapseOp << "\n");
+    return failure();
+  }
+
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+  return success();
+}
+
+static LogicalResult
+lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
+                          ArrayRef<int64_t> inputVectorSizes) {
+  if (!inputVectorSizes.empty()) {
+    LDBG("Concat operation do not support specify inputVectorSizes: "
+         << concatOp << "\n");
+  }
+  for (auto x : concatOp->getOperands()) {
+    auto type = mlir::dyn_cast<ShapedType>(x.getType());
+    if (!type) {
+      LDBG("Operation type error: " << concatOp << "\n");
+      return failure();
+    }
+    if (!type.hasStaticShape()) {
+      LDBG("Type must be static: " << concatOp << "\n");
+      return failure();
+    }
+  }
+  auto dim = concatOp.getDim();
+  if (dim >= (uint64_t)concatOp.getResultType().getRank()) {
+    LDBG("Invalid dim: " << concatOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 /// Preconditions for scalable vectors.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::UnPackOp>([&](auto unpackOp) {
         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
       })
+      .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+        return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);
+      })
+      .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+        return lowerCollapseShapeOpPrecondition(collapseShapeOp,
+                                                inputVectorSizes);
+      })
+      .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+        return lowerBitcastOpPrecondition(bitCastOp, inputVectorSizes);
+      })
+      .Case<tensor::ConcatOp>([&](auto concatOp) {
+        return lowerConcatOpPrecondition(concatOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -2075,6 +2393,22 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                              inputVectorSizes, results);
           })
+          .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+            return lowerTensorReshape(rewriter, expandShapeOp, inputVectorSizes,
+                                      results);
+          })
+          .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+            return lowerTensorReshape(rewriter, collapseShapeOp,
+                                      inputVectorSizes, results);
+          })
+          .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+            return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes,
+                                        results);
+          })
+          .Case<tensor::ConcatOp>([&](auto concatOp) {
+            return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes,
+                                       results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..114815b4e3de8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1055,3 +1055,195 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   } 
  }
+
+  // -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape
+func.func @test_vectorize_collapseshape(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C80:.*]] = arith.constant 8 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
+    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x32xf32> to vector<64x1024xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+    // CHECK: %[[C01:.*]] = arith.constant 0 : index
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[C512:.*]] = arith.constant 512 : index
+    // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+    // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<64x1024xi1> -> tensor<64x512xf32>
+    // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+   return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 vector_sizes [64, 1024] : !transform.any_op
+    transform.yield
+  } 
+}
+
+  // -----
+
+ // CHECK-LABEL: func @test_vectorize_collapseshape_no_vector_size
+func.func @test_vectorize_collapseshape_no_vector_size(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true, true, true]} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x16xf32> to vector<64x512xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+    // CHECK: %[[C01:.*]] = arith.constant 0 : index
+    // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>
+    // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+   return %collapsed : tensor<64x512xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  } 
+}
+
+  // -----
+
+ // CHECK-LABEL: func @test_vectorize_expandshape
+func.func @test_vectorize_expandshape(%source: tensor<64x512xf32>, %dest: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+    // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[C0:.*]]= arith.constant 0 : index
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[C512:.*]] = arith.constant 512 : index
+    // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+    // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<64x1024xi1> -> vector<64x1024xf32>
+    // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<64x1024xf32> to vector<8x8x32x32xf32>
+    // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<8x8x32x16xf32>
+    // CHECK: %[[C01:.*]]= arith.constant 0 : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C80:.*]] = arith.constant 8 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x3...
[truncated]

@llvmbot
Member

llvmbot commented Jul 1, 2024

@llvm/pr-subscribers-mlir


@hanhanW hanhanW requested review from banach-space and dcaballe July 1, 2024 16:34
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(inputOp);
auto src = inputOp->getOperand(0);

Please expand auto unless the type is obvious from line-level context.

rewriter.setInsertionPoint(inputOp);
auto src = inputOp->getOperand(0);
auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
auto result = inputOp->getResults()[0];

In MLIR, we generally dislike "raw" indexed accessors. Can this function be turned into a function template and use named accessors instead?

ArrayRef<int64_t> srcShape = srcType.getShape();
Location loc = inputOp->getLoc();

llvm::SmallVector<int64_t> srcVectorizedShape;

No need to prefix SmallVector with llvm::.

llvm::SmallVector<int64_t> srcVectorizedShape;
llvm::SmallDenseMap<int64_t, int64_t> shapeScales;

auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,

ArrayRef is already a reference as the name indicates, there is no need to pass it by reference.

for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
cur *= ss;
if (!isResultShapeBigger) {
// collapse

Please use full sentences in comments, including capitalization and trailing full stops.

// check reassociation
llvm::SmallVector<int64_t> associateIndices;
for (auto &attr : expandOp.getReassociation()) {
for (auto &indice : mlir::dyn_cast<ArrayAttr>(attr)) {

Nit: the singular of "indices" is "index".

Comment on lines +2221 to +2224
if (!type) {
LDBG("Operation type error: " << concatOp << "\n");
return failure();
}

Can this ever happen?

lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
ArrayRef<int64_t> inputVectorSizes) {
if (!inputVectorSizes.empty()) {
LDBG("Concat operation do not support specify inputVectorSizes: "

Suggested change:
-    LDBG("Concat operation do not support specify inputVectorSizes: "
+    LDBG("Concat operation does not support specify inputVectorSizes: "

@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
})
.Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);

It is a strange choice to call these lowerFoo when everything above is called vectorizeFoo.

// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>

This doesn't check that we actually read inside the mask. Same below.
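A stricter pattern could also match the transfer_read inside the mask region, for example (a hypothetical sketch, assuming the masked read is printed on a single line):

// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {
// CHECK-SAME: vector.transfer_read {{.*}}, %[[CST]]
// CHECK-SAME: } : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>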

@hanhanW
Contributor

hanhanW commented Jul 9, 2024

I'm not sure it is a good idea to vectorize expand_shape/collapse_shape into vector.shape_cast, because the tensor versions have no data-movement behavior while vector.shape_cast does. (The doc of vector.shape_cast is outdated.) Is there an RFC or discussion about why we are adding vectorization patterns for these ops? In our use cases, we handle all of this at the graph level, so we don't really need to think about this vectorization.

Also, can you add a high-level overview to the PR description? I'm sorry that I can't look at the implementation details because of my own priorities, so I can't advise much now. However, having an overview of a big change (+529-2) is generally good for reviewers and others.

On the other hand, concat and bitcast are new to me. Looking at the docs, tensor.bitcast is not equivalent to vector.bitcast. E.g., %0 = tensor.bitcast %arg0 : tensor<2x4xi8> to tensor<2x2xi16> is not valid, but %0 = vector.bitcast %arg0 : vector<2x4xi8> to vector<2x2xi16> is valid. IMO, tensor/linalg ops have higher-level semantics than vector ops. Lowering a tensor op to vector ops while adding more semantics looks bad to me.
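A minimal sketch of the asymmetry (assuming the usual constraints: tensor.bitcast preserves the shape and only reinterprets elements of equal bitwidth, while vector.bitcast may rescale the innermost dimension when the element bitwidth changes):

// vector.bitcast may change the innermost dimension size:
%v = vector.bitcast %arg0 : vector<2x4xi8> to vector<2x2xi16>   // valid
// tensor.bitcast must keep the shape and element bitwidth:
%t = tensor.bitcast %arg1 : tensor<2x4xi8> to tensor<2x4xui8>   // valid
%u = tensor.bitcast %arg1 : tensor<2x4xi8> to tensor<2x2xi16>   // invalid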

@hanhanW hanhanW requested a review from MaheshRavishankar July 9, 2024 17:19
@banach-space
Contributor

Sorry, I've been travelling and didn't have the time to reply earlier.

Basically, I agree with all the points from @hanhanW, thanks! I think that we should understand the motivation a bit better before proceeding with this.

Is there a RFC or discussion about why we are adding vectorization patterns for these ops?

I believe it is this one:

We haven't really agreed on any specific path forward.

@BRUCE11111
Author

I believe it is this one:

We haven't really agreed on any specific path forward.

Yes, the current work is about this RFC.

Motivation:

The current community discussion has reached a certain degree of consensus on the RFC work, such as vector-to-loops. This work is necessary, but some detailed designs are still under discussion.

For the current PR, the key point is that before converting a large virtual vector into scf.for + tiled vectors, we first need to convert the operation into vector operations. So I want to extend the vectorization support that already exists in the community.


vectorize expand_shape, collapse_shape into vector.shape_cast.

In the current vector dialect, shape_cast is the only op that aligns semantically with tensor.expand_shape, so I currently use shape_cast.
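For example, for the collapse direction with no explicit vector sizes, the lowering in this patch produces roughly the following (simplified from the test above; value names are illustrative):

%collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]]
    : tensor<8x8x32x16xf32> into tensor<64x512xf32>

becomes

%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%read = vector.transfer_read %source[%c0, %c0, %c0, %c0], %cst
    {in_bounds = [true, true, true, true]} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
%cast = vector.shape_cast %read : vector<8x8x32x16xf32> to vector<64x512xf32>
%empty = tensor.empty() : tensor<64x512xf32>
%write = vector.transfer_write %cast, %empty[%c0, %c0]
    {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>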

tensor.bitcast is not equivalent to vector.bitcast.

Thanks for pointing that out. I see the difference between tensor.bitcast and vector.bitcast now; I will convert tensor.bitcast → vector.transfer_read + arith.bitcast + vector.transfer_write instead.


In addition, I have a mask-related question for @hanhanW, which is written in another PR, #97248.

Sincere Thanks~

@banach-space
Contributor

The current community discussion has reached a certain degree of consensus on the RFC work,

I don't believe that that's the case.

This work is necessary.

Why?
