Skip to content

[mlir][vector] Add vector.from_elements op #95938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

This commit adds a new operation to the vector dialect: vector.from_elements

The op constructs a new vector from a given list of scalar values. It is similar to tensor.from_elements.

%0 = vector.from_elements %a, %b, %c, %a, %a, %a : vector<2x3xf32>

Constructing a new vector from elements was tedious before this op existed: a typical way was to define an arith.constant ... : vector<...>, followed by a chain of vector.insert.

Folders/canonicalizations are added that can fold vector.extract ops and convert the vector.from_elements op into a vector.splat op.

The LLVM lowering generates an llvm.mlir.undef, followed by a sequence of scalar insertions in the form of llvm.insertelement. Only 0-D and 1-D vectors are currently supported in the LLVM lowering.

@llvmbot
Copy link
Member

llvmbot commented Jun 18, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new operation to the vector dialect: vector.from_elements

The op constructs a new vector from a given list of scalar values. It is similar to tensor.from_elements.

%0 = vector.from_elements %a, %b, %c, %a, %a, %a : vector&lt;2x3xf32&gt;

Constructing a new vector from elements was tedious before this op existed: a typical way was to define an arith.constant ... : vector&lt;...&gt;, followed by a chain of vector.insert.

Folders/canonicalizations are added that can fold vector.extract ops and convert the vector.from_elements op into a vector.splat op.

The LLVM lowering generates an llvm.mlir.undef, followed by a sequence of scalar insertions in the form of llvm.insertelement. Only 0-D and 1-D vectors are currently supported in the LLVM lowering.


Full diff: https://github.com/llvm/llvm-project/pull/95938.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+31-3)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+30-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+103)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+31)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+69)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+17)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 56d866ac5b40c..f58b550266cd0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -720,10 +720,9 @@ def Vector_ExtractOp :
       return getStaticPosition().size();
     }
 
+    /// Return "true" if the op has at least one dynamic position.
     bool hasDynamicPosition() {
-      auto dynPos = getDynamicPosition();
-      return std::any_of(dynPos.begin(), dynPos.end(),
-                         [](Value operand) { return operand != nullptr; });
+      return !getDynamicPosition().empty();
     }
   }];
 
@@ -769,6 +768,35 @@ def Vector_FMAOp :
   }];
 }
 
+def Vector_FromElementsOp : Vector_Op<"from_elements", [
+    Pure,
+    TypesMatchWith<"operand types match result element type",
+                   "result", "elements", "SmallVector<Type>("
+                   "::llvm::cast<VectorType>($_self).getNumElements(), "
+                   "::llvm::cast<VectorType>($_self).getElementType())">]> {
+  let summary = "operation that defines a vector from scalar elements";
+  let description = [{
+    This operation defines a vector from one or multiple scalar elements. The
+    number of elements must match the number of elements in the result type.
+    All elements must have the same type, which must match the element type of
+    the result vector type.
+
+    Example:
+
+    ```mlir
+    // [%f1, %f2]
+    %0 = vector.from_elements %f1, %f2 : vector<2xf32>
+    // [[%f1, %f2, %f3], [%f4, %f5, %f6]]
+    %1 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$elements);
+  let results = (outs AnyVectorOfAnyRank:$result);
+  let assemblyFormat = "$elements attr-dict `:` type($result)";
+  let hasCanonicalizer = 1;
+}
+
 def Vector_InsertElementOp :
   Vector_Op<"insertelement", [Pure,
      TypesMatchWith<"source operand type matches element type of result",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 60f7e95ade689..e67c22fe3aa5d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1836,6 +1836,34 @@ struct VectorDeinterleaveOpLowering
   }
 };
 
+/// Conversion pattern for a `vector.from_elements`.
+struct VectorFromElementsLowering
+    : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = fromElementsOp.getLoc();
+    VectorType vectorType = fromElementsOp.getType();
+    // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
+    // Such ops should be handled in the same way as vector.insert.
+    if (vectorType.getRank() > 1)
+      return rewriter.notifyMatchFailure(fromElementsOp,
+                                         "rank > 1 vectors are not supported");
+    Type llvmType = typeConverter->convertType(vectorType);
+    Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
+    for (auto it : llvm::enumerate(adaptor.getElements())) {
+      result = rewriter.create<LLVM::InsertElementOp>(
+          loc, result, it.value(),
+          rewriter.create<LLVM::ConstantOp>(
+              loc, rewriter.getI64IntegerAttr(it.index())));
+    }
+    rewriter.replaceOp(fromElementsOp, result);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1861,7 +1889,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
-               VectorDeinterleaveOpLowering>(converter);
+               VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
+      converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2bf4f16f96e6a..269352b3a8f7e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1877,6 +1877,39 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
   return Value();
 }
 
+static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
+  // Dynamic extractions cannot be folded.
+  if (extractOp.hasDynamicPosition())
+    return {};
+
+  // Look for extract(from_elements).
+  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+  if (!fromElementsOp)
+    return {};
+
+  // Scalable vectors are not supported.
+  auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
+  if (vecType.isScalable())
+    return {};
+
+  // Only extractions of scalars are supported.
+  int64_t rank = vecType.getRank();
+  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
+  if (extractOp.getType() != vecType.getElementType())
+    return {};
+  assert(static_cast<int64_t>(indices.size()) == rank &&
+         "unexpected number of indices");
+
+  // Compute flattened/linearized index and fold to operand.
+  int flatIndex = 0;
+  int stride = 1;
+  for (int i = rank - 1; i >= 0; --i) {
+    flatIndex += indices[i] * stride;
+    stride *= vecType.getDimSize(i);
+  }
+  return fromElementsOp.getElements()[flatIndex];
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1895,6 +1928,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
     return val;
   if (auto val = foldExtractStridedOpFromInsertChain(*this))
     return val;
+  if (auto val = foldScalarExtractFromFromElements(*this))
+    return val;
   return OpFoldResult();
 }
 
@@ -2099,6 +2134,47 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
   return success();
 }
 
+/// Fold extract(from_elements) into a smaller from_elements op.
+LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
+                                          PatternRewriter &rewriter) {
+  // Dynamic positions are not supported.
+  if (extractOp.hasDynamicPosition())
+    return failure();
+
+  // Scalar extracts are handled by the folder.
+  auto resultType = dyn_cast<VectorType>(extractOp.getType());
+  if (!resultType)
+    return failure();
+
+  // Look for extracts from a from_elements op.
+  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
+  if (!fromElementsOp)
+    return failure();
+  VectorType inputType = fromElementsOp.getType();
+
+  // Scalable vectors are not supported.
+  if (resultType.isScalable() || inputType.isScalable())
+    return failure();
+
+  // Compute the position of first extracted element and flatten/linearize the
+  // position.
+  SmallVector<int64_t> firstElementPos =
+      llvm::to_vector(extractOp.getStaticPosition());
+  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
+  int flatIndex = 0;
+  int stride = 1;
+  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
+    flatIndex += firstElementPos[i] * stride;
+    stride *= inputType.getDimSize(i);
+  }
+
+  // Replace the op with a smaller from_elements op.
+  rewriter.replaceOpWithNewOp<FromElementsOp>(
+      extractOp, resultType,
+      fromElementsOp.getElements().slice(flatIndex,
+                                         resultType.getNumElements()));
+  return success();
+}
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2106,6 +2182,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
               ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
+  results.add(foldExtractFromFromElements);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2122,6 +2199,32 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+//===----------------------------------------------------------------------===//
+// FromElementsOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+class FromElementsToSplat final : public OpRewritePattern<FromElementsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp fromElementsOp,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::all_equal(fromElementsOp.getElements()))
+      return failure();
+    rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp,
+                                         fromElementsOp.getType(),
+                                         fromElementsOp.getElements().front());
+    return success();
+  }
+};
+} // namespace
+
+void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                 MLIRContext *context) {
+  results.add<FromElementsToSplat>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index bf4281ebcdec9..09b79708a9ab2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2590,3 +2590,34 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
   %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
   return %0 : vector<2x2xi64>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @vector_from_elements_1d(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+//       CHECK:   %[[undef:.*]] = llvm.mlir.undef : vector<3xf32>
+//       CHECK:   %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<3xf32>
+//       CHECK:   %[[c1:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[insert1:.*]] = llvm.insertelement %[[b]], %[[insert0]][%[[c1]] : i64] : vector<3xf32>
+//       CHECK:   %[[c2:.*]] = llvm.mlir.constant(2 : i64) : i64
+//       CHECK:   %[[insert2:.*]] = llvm.insertelement %[[a]], %[[insert1]][%[[c2]] : i64] : vector<3xf32>
+//       CHECK:   return %[[insert2]]
+func.func @vector_from_elements_1d(%a: f32, %b: f32) -> vector<3xf32> {
+  %0 = vector.from_elements %a, %b, %a : vector<3xf32>
+  return %0 : vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_from_elements_0d(
+//  CHECK-SAME:     %[[a:.*]]: f32)
+//       CHECK:   %[[undef:.*]] = llvm.mlir.undef : vector<1xf32>
+//       CHECK:   %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<1xf32>
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[insert0]] : vector<1xf32> to vector<f32>
+//       CHECK:   return %[[cast]]
+func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
+  %0 = vector.from_elements %a : vector<f32>
+  return %0 : vector<f32>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index caccd1f1c9c24..8181f1a8c5d13 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2642,3 +2642,72 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   // CHECK:   return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
   return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+  // Extract from 0D.
+  %0 = vector.from_elements %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Extract from 1D.
+  %2 = vector.from_elements %a : vector<1xf32>
+  %3 = vector.extract %2[0] : f32 from vector<1xf32>
+  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+  %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+  // Extract from 2D.
+  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
+  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
+  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
+  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
+  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
+  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_splat(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
+  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
+  %2 = vector.from_elements %a : vector<f32>
+  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
+  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1516f51fe1458..d0eaed8f98cc5 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1854,3 +1854,20 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
   %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
   return
 }
+
+// -----
+
+func.func @invalid_from_elements(%a: f32) {
+  // expected-error @+1 {{'vector.from_elements' 1 operands present, but expected 2}}
+  vector.from_elements %a : vector<2xf32>
+  return
+}
+
+// -----
+
+// expected-note @+1 {{prior use here}}
+func.func @invalid_from_elements(%a: f32, %b: i32) {
+  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
+  vector.from_elements %a, %b : vector<2xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c868c881d079a..4da09584db88b 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1158,3 +1158,17 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
   %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
   return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
 }
+
+// CHECK-LABEL: func @from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
+  // CHECK: vector.from_elements %[[a]] : vector<f32>
+  %0 = vector.from_elements %a : vector<f32>
+  // CHECK: vector.from_elements %[[a]] : vector<1xf32>
+  %1 = vector.from_elements %a : vector<1xf32>
+  // CHECK: vector.from_elements %[[a]], %[[b]] : vector<1x2xf32>
+  %2 = vector.from_elements %a, %b : vector<1x2xf32>
+  // CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
+  %3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
+  return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
+}
\ No newline at end of file

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great, thanks! Added some minor comments about doc but LGTM in general!

This commit adds a new operation to the vector dialect: `vector.from_elements`

The op constructs a new vector from a given list of scalar values. It is similar to `tensor.from_elements`.

Constructing a new vector from elements was tedious before this op existed: a typical way was to define an `arith.constant ... : vector<...>`, followed by a chain of `vector.insert`.

Folders/canonicalizations are added that can fold `vector.extract` ops and convert the `vector.from_elements` op into a `vector.splat` op.

The LLVM lowering generates an `llvm.mlir.undef`, followed by a sequence of scalar insertions in the form of `llvm.insertelement`. Only 0-D and 1-D vectors are currently supported in the LLVM lowering.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_from_elements branch from 0f9835f to dbab92d Compare June 19, 2024 07:48
@matthias-springer matthias-springer merged commit c6ff244 into main Jun 19, 2024
5 of 6 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/vector_from_elements branch June 19, 2024 07:58
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
This commit adds a new operation to the vector dialect:
`vector.from_elements`

The op constructs a new vector from a given list of scalar values. It is
similar to `tensor.from_elements`.
```mlir
%0 = vector.from_elements %a, %b, %c, %a, %a, %a : vector<2x3xf32>
```

Constructing a new vector from elements was tedious before this op
existed: a typical way was to define an `arith.constant ... :
vector<...>`, followed by a chain of `vector.insert`.

Folders/canonicalizations are added that can fold `vector.extract` ops
and convert the `vector.from_elements` op into a `vector.splat` op.

The LLVM lowering generates an `llvm.mlir.undef`, followed by a sequence
of scalar insertions in the form of `llvm.insertelement`. Only 0-D and
1-D vectors are currently supported in the LLVM lowering.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants