Skip to content

[mlir][vector] add support for linearizing vector.bitcast in VectorLinearize #123110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 27, 2025

Conversation

chencha3
Copy link
Contributor

This PR adds support for converting vector.bitcast working on ND (N > 1) vectors into the same op working on linearized (1D) vectors.

@llvmbot
Copy link
Member

llvmbot commented Jan 15, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

Changes

This PR adds support for converting vector.bitcast working on ND (N > 1) vectors into the same op working on linearized (1D) vectors.


Full diff: https://github.com/llvm/llvm-project/pull/123110.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+34-3)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+20-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 68535ae5a7a5c6..b450ea91fef651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
     auto resType =
         getTypeConverter()->convertType<VectorType>(constOp.getType());
 
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
     if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
       return rewriter.notifyMatchFailure(
           loc,
           "Cannot linearize a constant scalable vector that's not a splat");
 
-    if (!resType)
-      return rewriter.notifyMatchFailure(loc, "can't convert return type");
     if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           loc, "Can't flatten since targetBitWidth <= OpSize");
@@ -459,6 +460,35 @@ struct LinearizeVectorInsert final
 private:
   unsigned targetVectorBitWidth;
 };
+
+struct LinearizeVectorBitCast final
+    : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorBitCast(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = castOp.getLoc();
+    auto resType = getTypeConverter()->convertType(castOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type.");
+
+    if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
+
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, adaptor.getSource());
+    return mlir::success();
+  }
+private:
+  unsigned targetVectorBitWidth;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -486,6 +516,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   target.markUnknownOpDynamicallyLegal(
       [=](Operation *op) -> std::optional<bool> {
         if ((isa<arith::ConstantOp>(op) ||
+             isa<vector::BitCastOp>(op) ||
              op->hasTrait<OpTrait::Vectorizable>())) {
           return (isLessThanTargetBitWidth(op, targetBitWidth)
                       ? typeConverter.isLegal(op)
@@ -494,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable>(
+  patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
       typeConverter, patterns.getContext(), targetBitWidth);
 }
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..0358c2637f72b2 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
 
 // ALL-LABEL:   func.func @test_extract_strided_slice_1_scalable(
 // ALL-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {  
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
   // ALL-NOT: vector.shuffle
   // ALL-NOT: vector.shape_cast
   // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
@@ -318,3 +318,22 @@ func.func @test_vector_extract_scalar() {
   %0 = vector.extract %cst[0] : i32 from vector<4xi32>
   return
 }
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>)
+func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> {
+
+  // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+  // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+  // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+  // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+  // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+  // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+  // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16>
+  %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16>
+  return %1 : vector<4x2xf16>
+}

Copy link

github-actions bot commented Jan 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!


// ALL-LABEL: test_vector_bitcast
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>)
func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind changing the test so that it doesn't use a unit dim?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I updated it. Thanks for your feedback.

@@ -459,6 +460,37 @@ struct LinearizeVectorInsert final
private:
unsigned targetVectorBitWidth;
};

struct LinearizeVectorBitCast final
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add documentation and tests for scalable vectors (bailing out for scalable vectors would be fine, though this should just work).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I added comments and tests for scalable vectors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space Do you have any more concerns?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay—this slipped off my radar.

Everything looks good, but I have a small request regarding the tests.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

I've made some small requests re test formatting, but these are nits.

@@ -459,6 +460,37 @@ struct LinearizeVectorInsert final
private:
unsigned targetVectorBitWidth;
};

struct LinearizeVectorBitCast final
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay—this slipped off my radar.

Everything looks good, but I have a small request regarding the tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind using more descriptive LIT variables ? (e.g. RES or DOWNCAST/UPCAST for vector.shape_cast) This would be more consistent with the existing convention within the file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, fixed them and thanks.

// -----

// ALL-LABEL: test_vector_bitcast
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Since the name of the argument is arg0.

Suggested change
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32>
// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, fixed and thanks.

@chencha3 chencha3 merged commit bd5d361 into main Jan 27, 2025
8 checks passed
@chencha3 chencha3 deleted the users/chencha3/VectorLinearize/bitcast branch January 27, 2025 20:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants