Skip to content

[mlir][spirv] Support vector.step in vector to spirv conversion #100651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 26, 2024

Conversation

angelz913
Copy link
Contributor

@angelz913 angelz913 commented Jul 25, 2024

Added a conversion pattern and LIT tests for lowering vector.step to SPIR-V.
Fixes: #100602

@llvmbot
Copy link
Member

llvmbot commented Jul 25, 2024

@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

Changes

Added a conversion pattern and LIT tests for lowering vector.step to SPIR-V. Related issue: #100602


Full diff: https://github.com/llvm/llvm-project/pull/100651.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+39-2)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+28)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
   }
 };
 
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = typeConverter.convertType(stepOp.getType());
+    if (!dstType)
+      return failure();
+
+    Location loc = stepOp.getLoc();
+    int64_t numElements = stepOp.getType().getNumElements();
+    auto intType =
+        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We just create a constant in this case.
+    if (numElements == 1) {
+      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+      rewriter.replaceOp(stepOp, zero);
+      return success();
+    }
+
+    SmallVector<Value> source;
+    for (int64_t i = 0; i < numElements; ++i) {
+      Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+      Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+      source.push_back(constOp);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorStepOpConvert>(typeConverter, patterns.getContext(),
+                           PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // -----
 
+// CHECK-LABEL: @step
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST1:.*]] = spirv.Constant 1 : i32
+//       CHECK:   %[[CST2:.*]] = spirv.Constant 2 : i32
+//       CHECK:   %[[CST3:.*]] = spirv.Constant 3 : i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+//       CHECK:   return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+  %0 = vector.step : vector<1xindex>
+  return %0 : vector<1xindex>
+}
+
+// -----
+
 module attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>

@llvmbot
Copy link
Member

llvmbot commented Jul 25, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

Changes

Added a conversion pattern and LIT tests for lowering vector.step to SPIR-V. Related issue: #100602


Full diff: https://github.com/llvm/llvm-project/pull/100651.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+39-2)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+28)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 527fbe5cf628a..8b5789f9e8497 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,42 @@ struct VectorReductionToFPDotProd final
   }
 };
 
+struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = typeConverter.convertType(stepOp.getType());
+    if (!dstType)
+      return failure();
+
+    Location loc = stepOp.getLoc();
+    int64_t numElements = stepOp.getType().getNumElements();
+    auto intType =
+        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
+
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We just create a constant in this case.
+    if (numElements == 1) {
+      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
+      rewriter.replaceOp(stepOp, zero);
+      return success();
+    }
+
+    SmallVector<Value> source;
+    for (int64_t i = 0; i < numElements; ++i) {
+      Attribute intAttr = rewriter.getIntegerAttr(intType, i);
+      Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
+      source.push_back(constOp);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 #define CL_INT_MAX_MIN_OPS                                                     \
   spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +965,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
       VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
+      VectorStepOpConvert>(typeConverter, patterns.getContext(),
+                           PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index edad208749930..016c9e141a712 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,34 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // -----
 
+// CHECK-LABEL: @step
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST1:.*]] = spirv.Constant 1 : i32
+//       CHECK:   %[[CST2:.*]] = spirv.Constant 2 : i32
+//       CHECK:   %[[CST3:.*]] = spirv.Constant 3 : i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST]] : vector<4xindex>
+func.func @step() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @step_size1
+//  CHECK-SAME: ()
+//       CHECK:   %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
+//       CHECK:   return %[[CAST]] : vector<1xindex>
+func.func @step_size1() -> vector<1xindex> {
+  %0 = vector.step : vector<1xindex>
+  return %0 : vector<1xindex>
+}
+
+// -----
+
 module attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nits

@kuhar kuhar merged commit 599a91a into llvm:main Jul 26, 2024
7 checks passed
@angelz913 angelz913 deleted the vector-step-to-spirv branch August 8, 2024 20:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Support vector.step in vector to spirv conversion
3 participants