Skip to content

Commit dbab92d

Browse files
[mlir][vector] Add vector.from_elements op
This commit adds a new operation to the vector dialect: `vector.from_elements` The op constructs a new vector from a given list of scalar values. It is similar to `tensor.from_elements`. Constructing a new vector from elements was tedious before this op existed: a typical way was to define an `arith.constant ... : vector<...>`, followed by a chain of `vector.insert`. Folders/canonicalizations are added that can fold `vector.extract` ops and convert the `vector.from_elements` op into a `vector.splat` op. The LLVM lowering generates an `llvm.mlir.undef`, followed by a sequence of scalar insertions in the form of `llvm.insertelement`. Only 0-D and 1-D vectors are currently supported in the LLVM lowering.
1 parent e64ed1d commit dbab92d

File tree

7 files changed

+305
-4
lines changed

7 files changed

+305
-4
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -720,10 +720,9 @@ def Vector_ExtractOp :
720720
return getStaticPosition().size();
721721
}
722722

723+
/// Return "true" if the op has at least one dynamic position.
723724
bool hasDynamicPosition() {
724-
auto dynPos = getDynamicPosition();
725-
return std::any_of(dynPos.begin(), dynPos.end(),
726-
[](Value operand) { return operand != nullptr; });
725+
return !getDynamicPosition().empty();
727726
}
728727
}];
729728

@@ -769,6 +768,41 @@ def Vector_FMAOp :
769768
}];
770769
}
771770

771+
def Vector_FromElementsOp : Vector_Op<"from_elements", [
772+
Pure,
773+
TypesMatchWith<"operand types match result element type",
774+
"result", "elements", "SmallVector<Type>("
775+
"::llvm::cast<VectorType>($_self).getNumElements(), "
776+
"::llvm::cast<VectorType>($_self).getElementType())">]> {
777+
let summary = "operation that defines a vector from scalar elements";
778+
let description = [{
779+
This operation defines a vector from one or multiple scalar elements. The
780+
number of elements must match the number of elements in the result type.
781+
All elements must have the same type, which must match the element type of
782+
the result vector type.
783+
784+
`elements` are a flattened version of the result vector in row-major order.
785+
786+
Example:
787+
788+
```mlir
789+
// %f1
790+
%0 = vector.from_elements %f1 : vector<f32>
791+
// [%f1, %f2]
792+
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
793+
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
794+
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
795+
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
796+
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
797+
```
798+
}];
799+
800+
let arguments = (ins Variadic<AnyType>:$elements);
801+
let results = (outs AnyVectorOfAnyRank:$result);
802+
let assemblyFormat = "$elements attr-dict `:` type($result)";
803+
let hasCanonicalizer = 1;
804+
}
805+
772806
def Vector_InsertElementOp :
773807
Vector_Op<"insertelement", [Pure,
774808
TypesMatchWith<"source operand type matches element type of result",

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,30 @@ struct VectorDeinterleaveOpLowering
18361836
}
18371837
};
18381838

1839+
/// Conversion pattern for a `vector.from_elements`.
1840+
struct VectorFromElementsLowering
1841+
: public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1842+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1843+
1844+
LogicalResult
1845+
matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1846+
ConversionPatternRewriter &rewriter) const override {
1847+
Location loc = fromElementsOp.getLoc();
1848+
VectorType vectorType = fromElementsOp.getType();
1849+
// TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1850+
// Such ops should be handled in the same way as vector.insert.
1851+
if (vectorType.getRank() > 1)
1852+
return rewriter.notifyMatchFailure(fromElementsOp,
1853+
"rank > 1 vectors are not supported");
1854+
Type llvmType = typeConverter->convertType(vectorType);
1855+
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1856+
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
1857+
result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
1858+
rewriter.replaceOp(fromElementsOp, result);
1859+
return success();
1860+
}
1861+
};
1862+
18391863
} // namespace
18401864

18411865
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1861,7 +1885,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
18611885
VectorSplatOpLowering, VectorSplatNdOpLowering,
18621886
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
18631887
MaskedReductionOpConversion, VectorInterleaveOpLowering,
1864-
VectorDeinterleaveOpLowering>(converter);
1888+
VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
1889+
converter);
18651890
// Transfer ops with rank > 1 are handled by VectorToSCF.
18661891
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
18671892
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,6 +1877,45 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
18771877
return Value();
18781878
}
18791879

1880+
/// Try to fold the extraction of a scalar from a vector defined by
1881+
/// vector.from_elements. E.g.:
1882+
///
1883+
/// %0 = vector.from_elements %a, %b : vector<2xf32>
1884+
/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
1885+
/// ==> fold to %a
1886+
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1887+
// Dynamic extractions cannot be folded.
1888+
if (extractOp.hasDynamicPosition())
1889+
return {};
1890+
1891+
// Look for extract(from_elements).
1892+
auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
1893+
if (!fromElementsOp)
1894+
return {};
1895+
1896+
// Scalable vectors are not supported.
1897+
auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
1898+
if (vecType.isScalable())
1899+
return {};
1900+
1901+
// Only extractions of scalars are supported.
1902+
int64_t rank = vecType.getRank();
1903+
ArrayRef<int64_t> indices = extractOp.getStaticPosition();
1904+
if (extractOp.getType() != vecType.getElementType())
1905+
return {};
1906+
assert(static_cast<int64_t>(indices.size()) == rank &&
1907+
"unexpected number of indices");
1908+
1909+
// Compute flattened/linearized index and fold to operand.
1910+
int flatIndex = 0;
1911+
int stride = 1;
1912+
for (int i = rank - 1; i >= 0; --i) {
1913+
flatIndex += indices[i] * stride;
1914+
stride *= vecType.getDimSize(i);
1915+
}
1916+
return fromElementsOp.getElements()[flatIndex];
1917+
}
1918+
18801919
OpFoldResult ExtractOp::fold(FoldAdaptor) {
18811920
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
18821921
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1895,6 +1934,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
18951934
return val;
18961935
if (auto val = foldExtractStridedOpFromInsertChain(*this))
18971936
return val;
1937+
if (auto val = foldScalarExtractFromFromElements(*this))
1938+
return val;
18981939
return OpFoldResult();
18991940
}
19001941

@@ -2099,13 +2140,60 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
20992140
return success();
21002141
}
21012142

2143+
/// Try to canonicalize the extraction of a subvector from a vector defined by
2144+
/// vector.from_elements. E.g.:
2145+
///
2146+
/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2147+
/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2148+
/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2149+
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2150+
PatternRewriter &rewriter) {
2151+
// Dynamic positions are not supported.
2152+
if (extractOp.hasDynamicPosition())
2153+
return failure();
2154+
2155+
// Scalar extracts are handled by the folder.
2156+
auto resultType = dyn_cast<VectorType>(extractOp.getType());
2157+
if (!resultType)
2158+
return failure();
2159+
2160+
// Look for extracts from a from_elements op.
2161+
auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
2162+
if (!fromElementsOp)
2163+
return failure();
2164+
VectorType inputType = fromElementsOp.getType();
2165+
2166+
// Scalable vectors are not supported.
2167+
if (resultType.isScalable() || inputType.isScalable())
2168+
return failure();
2169+
2170+
// Compute the position of first extracted element and flatten/linearize the
2171+
// position.
2172+
SmallVector<int64_t> firstElementPos =
2173+
llvm::to_vector(extractOp.getStaticPosition());
2174+
firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2175+
int flatIndex = 0;
2176+
int stride = 1;
2177+
for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2178+
flatIndex += firstElementPos[i] * stride;
2179+
stride *= inputType.getDimSize(i);
2180+
}
2181+
2182+
// Replace the op with a smaller from_elements op.
2183+
rewriter.replaceOpWithNewOp<FromElementsOp>(
2184+
extractOp, resultType,
2185+
fromElementsOp.getElements().slice(flatIndex,
2186+
resultType.getNumElements()));
2187+
return success();
2188+
}
21022189
} // namespace
21032190

21042191
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
21052192
MLIRContext *context) {
21062193
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
21072194
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
21082195
results.add(foldExtractFromShapeCastToShapeCast);
2196+
results.add(foldExtractFromFromElements);
21092197
}
21102198

21112199
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -2122,6 +2210,29 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
21222210
return llvm::to_vector<4>(getVectorType().getShape());
21232211
}
21242212

2213+
//===----------------------------------------------------------------------===//
2214+
// FromElementsOp
2215+
//===----------------------------------------------------------------------===//
2216+
2217+
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
2218+
/// same SSA value. E.g.:
2219+
///
2220+
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2221+
/// ==> rewrite to vector.splat %a : vector<3xf32>
2222+
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2223+
PatternRewriter &rewriter) {
2224+
if (!llvm::all_equal(fromElementsOp.getElements()))
2225+
return failure();
2226+
rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2227+
fromElementsOp.getElements().front());
2228+
return success();
2229+
}
2230+
2231+
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2232+
MLIRContext *context) {
2233+
results.add(rewriteFromElementsAsSplat);
2234+
}
2235+
21252236
//===----------------------------------------------------------------------===//
21262237
// BroadcastOp
21272238
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2590,3 +2590,34 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
25902590
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
25912591
return %0 : vector<2x2xi64>
25922592
}
2593+
2594+
// -----
2595+
2596+
// CHECK-LABEL: func.func @vector_from_elements_1d(
2597+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2598+
// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<3xf32>
2599+
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
2600+
// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<3xf32>
2601+
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : i64) : i64
2602+
// CHECK: %[[insert1:.*]] = llvm.insertelement %[[b]], %[[insert0]][%[[c1]] : i64] : vector<3xf32>
2603+
// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : i64) : i64
2604+
// CHECK: %[[insert2:.*]] = llvm.insertelement %[[a]], %[[insert1]][%[[c2]] : i64] : vector<3xf32>
2605+
// CHECK: return %[[insert2]]
2606+
func.func @vector_from_elements_1d(%a: f32, %b: f32) -> vector<3xf32> {
2607+
%0 = vector.from_elements %a, %b, %a : vector<3xf32>
2608+
return %0 : vector<3xf32>
2609+
}
2610+
2611+
// -----
2612+
2613+
// CHECK-LABEL: func.func @vector_from_elements_0d(
2614+
// CHECK-SAME: %[[a:.*]]: f32)
2615+
// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<1xf32>
2616+
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
2617+
// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<1xf32>
2618+
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[insert0]] : vector<1xf32> to vector<f32>
2619+
// CHECK: return %[[cast]]
2620+
func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
2621+
%0 = vector.from_elements %a : vector<f32>
2622+
return %0 : vector<f32>
2623+
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2642,3 +2642,72 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
26422642
// CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
26432643
return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
26442644
}
2645+
2646+
// -----
2647+
2648+
// CHECK-LABEL: func @extract_scalar_from_from_elements(
2649+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2650+
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
2651+
// Extract from 0D.
2652+
%0 = vector.from_elements %a : vector<f32>
2653+
%1 = vector.extract %0[] : f32 from vector<f32>
2654+
2655+
// Extract from 1D.
2656+
%2 = vector.from_elements %a : vector<1xf32>
2657+
%3 = vector.extract %2[0] : f32 from vector<1xf32>
2658+
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
2659+
%5 = vector.extract %4[4] : f32 from vector<5xf32>
2660+
2661+
// Extract from 2D.
2662+
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
2663+
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
2664+
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
2665+
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
2666+
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
2667+
2668+
// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
2669+
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
2670+
}
2671+
2672+
// -----
2673+
2674+
// CHECK-LABEL: func @extract_1d_from_from_elements(
2675+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2676+
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
2677+
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
2678+
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
2679+
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
2680+
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
2681+
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
2682+
// CHECK: return %[[splat1]], %[[splat2]]
2683+
return %1, %2 : vector<3xf32>, vector<3xf32>
2684+
}
2685+
2686+
// -----
2687+
2688+
// CHECK-LABEL: func @extract_2d_from_from_elements(
2689+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2690+
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
2691+
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
2692+
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
2693+
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
2694+
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
2695+
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
2696+
// CHECK: return %[[splat1]], %[[splat2]]
2697+
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
2698+
}
2699+
2700+
// -----
2701+
2702+
// CHECK-LABEL: func @from_elements_to_splat(
2703+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2704+
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
2705+
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
2706+
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
2707+
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
2708+
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
2709+
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
2710+
%2 = vector.from_elements %a : vector<f32>
2711+
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
2712+
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
2713+
}

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,3 +1854,20 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
18541854
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
18551855
return
18561856
}
1857+
1858+
// -----
1859+
1860+
func.func @invalid_from_elements(%a: f32) {
1861+
// expected-error @+1 {{'vector.from_elements' 1 operands present, but expected 2}}
1862+
vector.from_elements %a : vector<2xf32>
1863+
return
1864+
}
1865+
1866+
// -----
1867+
1868+
// expected-note @+1 {{prior use here}}
1869+
func.func @invalid_from_elements(%a: f32, %b: i32) {
1870+
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
1871+
vector.from_elements %a, %b : vector<2xf32>
1872+
return
1873+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,3 +1158,17 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
11581158
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
11591159
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
11601160
}
1161+
1162+
// CHECK-LABEL: func @from_elements(
1163+
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
1164+
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
1165+
// CHECK: vector.from_elements %[[a]] : vector<f32>
1166+
%0 = vector.from_elements %a : vector<f32>
1167+
// CHECK: vector.from_elements %[[a]] : vector<1xf32>
1168+
%1 = vector.from_elements %a : vector<1xf32>
1169+
// CHECK: vector.from_elements %[[a]], %[[b]] : vector<1x2xf32>
1170+
%2 = vector.from_elements %a, %b : vector<1x2xf32>
1171+
// CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
1172+
%3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
1173+
return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
1174+
}

0 commit comments

Comments
 (0)