Skip to content

[mlir][VectorToSPIRV] Add conversion for vector.extract with dynamic indices #114137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 6, 2024

Conversation

Groverkss
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Kunwar Grover (Groverkss)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/114137.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+26-22)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+42)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 6184225cb6285d..ee8dccf025a0c6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
 /// Returns the integer value from the first valid input element, assuming Value
 /// inputs are defined by a constant index ops and Attribute inputs are integer
 /// attributes.
-static uint64_t getFirstIntValue(ValueRange values) {
-  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
-}
-static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
-  return cast<IntegerAttr>(attr[0]).getInt();
-}
 static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
-static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
-  auto attr = foldResults[0].dyn_cast<Attribute>();
-  if (attr)
-    return getFirstIntValue(attr);
-
-  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
-}
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(extractOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-        extractOp, adaptor.getVector(), id);
+    std::optional<int64_t> id =
+        getConstantIntValue(extractOp.getMixedPosition()[0]);
+
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, dstType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr(id.value()));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+          extractOp, dstType, adaptor.getVector(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(insertOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    std::optional<int64_t> id =
+        getConstantIntValue(insertOp.getMixedPosition()[0]);
+
+    //    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+    //        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    //    return success();
+
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+          insertOp, insertOp.getDest(), adaptor.getSource(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 25ec5d0159bd5d..62210108aa73cf 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,26 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_dynamic
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
+//       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
+  %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// CHECK-LABEL: @extract_dynamic_cst
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
+  %idx = arith.constant 1 : index
+  %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
 //       CHECK:   spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +236,28 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 
 // -----
 
+// CHECK-LABEL: @insert_dynamic
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
+//       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
+  %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic_cst
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+  %idx = arith.constant 2 : index
+  %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32

@Groverkss Groverkss requested a review from kuhar November 5, 2024 09:02
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@Groverkss Groverkss force-pushed the vector-insert-extract-dynamic-spirv branch from 9d41a73 to 46795c8 Compare November 5, 2024 23:09
@Groverkss Groverkss merged commit c96a85a into llvm:main Nov 6, 2024
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants