Skip to content

[mlir][vector] Propagate scalability to gather/scatter ptrs vector #97584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

c-rhodes
Copy link
Collaborator

@c-rhodes c-rhodes commented Jul 3, 2024

In convert-vector-to-llvm the first operand (vector of pointers holding
all memory addresses to read) to the masked.gather (and scatter)
intrinsic has a fixed vector type.

This may result in intrinsics where the scalable flag has been dropped:

  %0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
    : (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>

Fortunately the operand is overloaded on the result type so we end up
with the correct IR when lowering to LLVM, but this is still incorrect.
This patch fixes it by propagating scalability.

c-rhodes added 2 commits July 3, 2024 14:20
In convert-vector-to-llvm the first operand (vector of pointers holding
all memory addresses to read) to the masked.gather (and scatter)
intrinsic has a fixed vector type.

This may result in intrinsics where the scalable flag has been dropped:

  %0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
    : (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>

Fortunately the operand is overloaded on the result type so we end up
with the correct IR when lowering to LLVM, but this is still incorrect.
This patch fixes it by propagating scalability.
@llvmbot
Copy link
Member

llvmbot commented Jul 3, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

In convert-vector-to-llvm the first operand (vector of pointers holding
all memory addresses to read) to the masked.gather (and scatter)
intrinsic has a fixed vector type.

This may result in intrinsics where the scalable flag has been dropped:

  %0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
    : (!llvm.vec&lt;4 x ptr&gt;, vector&lt;[4]xi1&gt;, vector&lt;[4]xi32&gt;) -&gt; vector&lt;[4]xi32&gt;

Fortunately the operand is overloaded on the result type so we end up
with the correct IR when lowering to LLVM, but this is still incorrect.
This patch fixes it by propagating scalability.


Full diff: https://github.com/llvm/llvm-project/pull/97584.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+11-10)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+25)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 0eac55255b133..77bdacbc46990 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -102,11 +102,13 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                             const LLVMTypeConverter &typeConverter,
                             MemRefType memRefType, Value llvmMemref, Value base,
-                            Value index, uint64_t vLen) {
+                            Value index, VectorType vectorType) {
   assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
          "unsupported memref type");
   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
-  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+  auto ptrsType =
+      LLVM::getVectorType(pType, vectorType.getDimSize(0),
+                          /*isScalable=*/vectorType.getScalableDims()[0]);
   return rewriter.create<LLVM::GEPOp>(
       loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
       base, index);
@@ -288,9 +290,9 @@ class VectorGatherOpConversion
     if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
       auto vType = gather.getVectorType();
       // Resolve address.
-      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
-                                  memRefType, base, ptr, adaptor.getIndexVec(),
-                                  /*vLen=*/vType.getDimSize(0));
+      Value ptrs =
+          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                         base, ptr, adaptor.getIndexVec(), vType);
       // Replace with the gather intrinsic.
       rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
           gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
@@ -305,8 +307,7 @@ class VectorGatherOpConversion
       // Resolve address.
       Value ptrs = getIndexedPtrs(
           rewriter, loc, typeConverter, memRefType, base, ptr,
-          /*index=*/vectorOperands[0],
-          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
       // Create the gather intrinsic.
       return rewriter.create<LLVM::masked_gather>(
           loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
@@ -343,9 +344,9 @@ class VectorScatterOpConversion
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Value ptrs = getIndexedPtrs(
-        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
-        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09b79708a9ab2..6f8145b618b71 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2248,6 +2248,19 @@ func.func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3
 
 // -----
 
+func.func @gather_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) -> vector<[3]xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+  return %1 : vector<[3]xf32>
+}
+
+// CHECK-LABEL: func @gather_op_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: return %[[G]] : vector<[3]xf32>
+
+// -----
+
 func.func @gather_op_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
   %0 = arith.constant 0: index
   %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32, 1>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
@@ -2351,6 +2364,18 @@ func.func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<
 
 // -----
 
+func.func @scatter_op_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2: vector<[3]xi1>, %arg3: vector<[3]xf32>) {
+  %0 = arith.constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<[3]xi32>, vector<[3]xi1>, vector<[3]xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_op_scalable
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+
+// -----
+
 func.func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) {
   %0 = arith.constant 0: index
   vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xindex>, vector<3xi1>, vector<3xindex>

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for the fix! LGTM

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Is there a way to check this in the verifier?

@c-rhodes
Copy link
Collaborator Author

c-rhodes commented Jul 5, 2024

Thanks! Is there a way to check this in the verifier?

I've added verifiers for gather/scatter intrinsic to check this. Generally the intrinsics are quite loose with constraints, seems like a bit of a rabbit hole.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks!

@c-rhodes c-rhodes merged commit 1e7d6d3 into llvm:main Jul 9, 2024
7 checks passed
@c-rhodes c-rhodes deleted the mlir-vector-to-llvm-scalable-indexed-ptrs branch July 9, 2024 08:06
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
…lvm#97584)

In convert-vector-to-llvm the first operand (vector of pointers holding
all memory addresses to read) to the masked.gather (and scatter)
intrinsic has a fixed vector type.

This may result in intrinsics where the scalable flag has been dropped:
```
  %0 = llvm.intr.masked.gather %1, %2, %3 {alignment = 4 : i32}
    : (!llvm.vec<4 x ptr>, vector<[4]xi1>, vector<[4]xi32>) -> vector<[4]xi32>
```
Fortunately the operand is overloaded on the result type so we end up
with the correct IR when lowering to LLVM, but this is still incorrect.
This patch fixes it by propagating scalability.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants