Skip to content

[mlir][Vector] Lower vector.to_elements to LLVM #145766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2025

Conversation

dcaballe
Copy link
Contributor

Only elements with at least one use are lowered to llvm.extractelement op.

@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

Only elements with at least one use are lowered to llvm.extractelement op.


Full diff: https://github.com/llvm/llvm-project/pull/145766.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+33-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+35-1)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d53d11f87efe8..f1543200fb56f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1985,6 +1985,37 @@ struct VectorFromElementsLowering
   }
 };
 
+/// Conversion pattern for a `vector.to_elements`.
+struct VectorToElementsLowering
+    : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = toElementsOp.getLoc();
+    auto idxType = typeConverter->convertType(rewriter.getIndexType());
+    Value source = adaptor.getSource();
+
+    SmallVector<Value> results(toElementsOp->getNumResults());
+    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+      // Create an extractelement operation only for results that are not dead.
+      if (!element.use_empty()) {
+        auto constIdx = rewriter.create<LLVM::ConstantOp>(
+            loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+        auto llvmType = typeConverter->convertType(element.getType());
+
+        Value result = rewriter.create<LLVM::ExtractElementOp>(
+            loc, llvmType, source, constIdx);
+        results[idx] = result;
+      }
+    }
+
+    rewriter.replaceOp(toElementsOp, results);
+    return success();
+  }
+};
+
 /// Conversion pattern for vector.step.
 struct VectorScalableStepOpLowering
     : public ConvertOpToLLVMPattern<vector::StepOp> {
@@ -2035,7 +2066,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
-               VectorScalableStepOpLowering>(converter);
+               VectorToElementsLowering, VectorScalableStepOpLowering>(
+      converter);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 3df14528bac39..8f73e79d7bfc2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
 // CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
 // CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
-// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
 // CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
 // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
@@ -2421,6 +2421,40 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
+// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
+ // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+ //      CHECK:   %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ //      CHECK:   %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ //      CHECK:   %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ //      CHECK:   %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
+ //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ //      CHECK:   return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_to_elements_dead_elements
+ // CHECK-SAME:     %[[A:.*]]: vector<4xf32>)
+ //  CHECK-NOT:   llvm.mlir.constant(0 : i64) : i64
+ //      CHECK:   %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ //      CHECK:   %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ //  CHECK-NOT:   llvm.mlir.constant(2 : i64) : i64
+ //      CHECK:   %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ //      CHECK:   %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ //      CHECK:   return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  return %0#1, %0#3 : f32, f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.step
 //===----------------------------------------------------------------------===//

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG % minor nits, thanks!

Comment on lines 2424 to 2458
// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}

// -----

// CHECK-LABEL: func.func @vector_to_elements_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#1, %0#3 : f32, f32
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Add block comment, remove vector from function names (for consistency)

Suggested change
// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: func.func @vector_to_elements_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#1, %0#3 : f32, f32
}
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @to_elements_no_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: func.func @to_elements_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#1, %0#3 : f32, f32
}

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dcaballe do you plan to add SPIR-V lowering on your own? If not, could you open an issue and tag me so that I can assign it to someone?

SmallVector<Value> results(toElementsOp->getNumResults());
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
// Create an extractelement operation only for results that are not dead.
if (!element.use_empty()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can invert this check and continue instead to reduce nesting

@dcaballe
Copy link
Contributor Author

@dcaballe do you plan to add SPIR-V lowering on your own? If not, could you open an issue and tag me so that I can assign it to someone?

#145929

dcaballe added 2 commits June 26, 2025 17:08
Only elements with at least one use are lowered to `llvm.extractelement` op.
@dcaballe dcaballe force-pushed the vector-to-elements-llvm-lowering branch from 2c40823 to cc50307 Compare June 26, 2025 17:25
@dcaballe dcaballe merged commit 7842e9e into llvm:main Jun 26, 2025
5 of 7 checks passed
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
Only elements with at least one use are lowered to `llvm.extractelement`
op.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants