Skip to content

[mlir][vector] Add vector.from_elements op #95938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -720,10 +720,9 @@ def Vector_ExtractOp :
return getStaticPosition().size();
}

/// Return "true" if the op has at least one dynamic position.
///
/// The dynamic position operand list only holds the SSA values supplied for
/// dynamic indices, so a non-empty list is sufficient evidence that at least
/// one position is dynamic; no per-operand null check is required.
bool hasDynamicPosition() {
  return !getDynamicPosition().empty();
}
}];

Expand Down Expand Up @@ -769,6 +768,41 @@ def Vector_FMAOp :
}];
}

// Definition of `vector.from_elements`: builds a vector result from a variadic
// list of scalar operands.
def Vector_FromElementsOp : Vector_Op<"from_elements", [
// NOTE(review): `Pure` is defined outside this file; presumably it marks the
// op as free of side effects -- confirm against the trait definition.
Pure,
// Ties the operand types to the result type: `elements` must consist of
// exactly `getNumElements()` types, each equal to the result vector's element
// type. This constrains both the operand count and every operand type.
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
number of elements must match the number of elements in the result type.
All elements must have the same type, which must match the element type of
the result vector type.

`elements` are a flattened version of the result vector in row-major order.

Example:

```mlir
// %f1
%0 = vector.from_elements %f1 : vector<f32>
// [%f1, %f2]
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
```
}];

// Operands are declared as `AnyType`; the actual restriction to the result's
// element type comes from the `TypesMatchWith` trait above.
let arguments = (ins Variadic<AnyType>:$elements);
let results = (outs AnyVectorOfAnyRank:$result);
// Only the result type is printed; operand types are inferred from it.
let assemblyFormat = "$elements attr-dict `:` type($result)";
// Requests a `getCanonicalizationPatterns` hook for this op.
let hasCanonicalizer = 1;
}

def Vector_InsertElementOp :
Vector_Op<"insertelement", [Pure,
TypesMatchWith<"source operand type matches element type of result",
Expand Down
27 changes: 26 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1836,6 +1836,30 @@ struct VectorDeinterleaveOpLowering
}
};

/// Lowers `vector.from_elements` towards the LLVM dialect.
///
/// The result is seeded with an `llvm.mlir.undef` of the converted vector type
/// and every (already converted) operand is then inserted at its own index
/// with a `vector.insert`. Only rank-0 and rank-1 vectors are handled here.
struct VectorFromElementsLowering
    : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultVectorType = fromElementsOp.getType();

    // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
    // Such ops should be handled in the same way as vector.insert.
    if (resultVectorType.getRank() > 1) {
      return rewriter.notifyMatchFailure(fromElementsOp,
                                         "rank > 1 vectors are not supported");
    }

    Location loc = fromElementsOp.getLoc();

    // Start from an undefined vector and fill it one element at a time.
    Value accumulated = rewriter.create<LLVM::UndefOp>(
        loc, typeConverter->convertType(resultVectorType));
    int64_t position = 0;
    for (Value element : adaptor.getElements()) {
      accumulated = rewriter.create<vector::InsertOp>(loc, element, accumulated,
                                                      position);
      ++position;
    }

    rewriter.replaceOp(fromElementsOp, accumulated);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
Expand All @@ -1861,7 +1885,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering>(converter);
VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
Expand Down
111 changes: 111 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,45 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
return Value();
}

/// Try to fold the extraction of a scalar from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b : vector<2xf32>
/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
/// ==> fold to %a
///
/// Returns the from_elements operand that corresponds to the extracted
/// position, or a null Value when the fold does not apply (dynamic position,
/// source not defined by from_elements, scalable source vector, or a
/// non-scalar extraction result).
static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
  // Dynamic extractions cannot be folded.
  if (extractOp.hasDynamicPosition())
    return {};

  // Look for extract(from_elements).
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return {};

  // Scalable vectors are not supported.
  VectorType vecType = fromElementsOp.getType();
  if (vecType.isScalable())
    return {};

  // Only extractions of scalars are supported.
  int64_t rank = vecType.getRank();
  ArrayRef<int64_t> indices = extractOp.getStaticPosition();
  if (extractOp.getType() != vecType.getElementType())
    return {};
  assert(static_cast<int64_t>(indices.size()) == rank &&
         "unexpected number of indices");

  // Compute flattened/linearized index and fold to operand. The operands are
  // stored in row-major order, so the innermost dimension has stride 1. All
  // arithmetic stays in int64_t, matching the type of the indices and the
  // dimension sizes, to avoid implicit narrowing.
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int64_t i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= vecType.getDimSize(i);
  }
  return fromElementsOp.getElements()[flatIndex];
}

OpFoldResult ExtractOp::fold(FoldAdaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
Expand All @@ -1895,6 +1934,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
return val;
if (auto val = foldExtractStridedOpFromInsertChain(*this))
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
return OpFoldResult();
}

Expand Down Expand Up @@ -2099,13 +2140,60 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
return success();
}

/// Try to canonicalize the extraction of a subvector from a vector defined by
/// vector.from_elements. E.g.:
///
/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
///
/// Returns success() after replacing `extractOp` with a smaller from_elements
/// op, and failure() when the pattern does not apply (dynamic position, scalar
/// result, source not defined by from_elements, or scalable vectors).
LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
                                          PatternRewriter &rewriter) {
  // Dynamic positions are not supported.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Scalar extracts are handled by the folder.
  auto resultType = dyn_cast<VectorType>(extractOp.getType());
  if (!resultType)
    return failure();

  // Look for extracts from a from_elements op.
  auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();
  VectorType inputType = fromElementsOp.getType();

  // Scalable vectors are not supported.
  if (resultType.isScalable() || inputType.isScalable())
    return failure();

  // Compute the position of first extracted element and flatten/linearize the
  // position. The static position only covers the leading dimensions; it is
  // padded with zeros for the trailing dimensions that make up the result.
  SmallVector<int64_t> firstElementPos =
      llvm::to_vector(extractOp.getStaticPosition());
  firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);

  // Keep the arithmetic in int64_t, matching the type of the positions and
  // dimension sizes, to avoid implicit narrowing.
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
    flatIndex += firstElementPos[i] * stride;
    stride *= inputType.getDimSize(i);
  }

  // Replace the op with a smaller from_elements op. Because the operands are
  // laid out in row-major order, the extracted subvector is a contiguous
  // slice of the original operand list.
  rewriter.replaceOpWithNewOp<FromElementsOp>(
      extractOp, resultType,
      fromElementsOp.getElements().slice(flatIndex,
                                         resultType.getNumElements()));
  return success();
}
} // namespace

/// Adds the canonicalization patterns of vector.extract to `results`.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// Patterns named by type; `context` is forwarded to their construction.
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
// Patterns implemented as plain functions.
results.add(foldExtractFromShapeCastToShapeCast);
// Rewrites a subvector extract of a from_elements op into a smaller
// from_elements op.
results.add(foldExtractFromFromElements);
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
Expand All @@ -2122,6 +2210,29 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

/// Canonicalize a vector.from_elements whose operands are all the very same
/// SSA value into a vector.splat of that value. E.g.:
///
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// ==> rewrite to vector.splat %a : vector<3xf32>
///
/// Fails without touching the IR as soon as two operands differ.
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
                                                PatternRewriter &rewriter) {
  auto scalars = fromElementsOp.getElements();
  if (llvm::all_equal(scalars)) {
    // Every operand is the same value, so the first one can stand for all.
    rewriter.replaceOpWithNewOp<SplatOp>(
        fromElementsOp, fromElementsOp.getType(), scalars.front());
    return success();
  }
  return failure();
}

/// Adds the canonicalization patterns of vector.from_elements to `results`.
/// `context` is not used by this function.
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// Turns a from_elements with identical operands into a vector.splat.
results.add(rewriteFromElementsAsSplat);
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 31 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2590,3 +2590,34 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
return %0 : vector<2x2xi64>
}

// -----

// Lowering of a 1-D vector.from_elements: the result starts out as an
// llvm.mlir.undef and each operand is inserted, in operand order, with an
// llvm.insertelement at a constant i64 index.
// CHECK-LABEL: func.func @vector_from_elements_1d(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<3xf32>
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<3xf32>
// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[insert1:.*]] = llvm.insertelement %[[b]], %[[insert0]][%[[c1]] : i64] : vector<3xf32>
// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[insert2:.*]] = llvm.insertelement %[[a]], %[[insert1]][%[[c2]] : i64] : vector<3xf32>
// CHECK: return %[[insert2]]
func.func @vector_from_elements_1d(%a: f32, %b: f32) -> vector<3xf32> {
%0 = vector.from_elements %a, %b, %a : vector<3xf32>
return %0 : vector<3xf32>
}

// -----

// Lowering of a 0-D vector.from_elements: the value is built as a
// vector<1xf32> and converted back to vector<f32> through a
// builtin.unrealized_conversion_cast before being returned.
// CHECK-LABEL: func.func @vector_from_elements_0d(
// CHECK-SAME: %[[a:.*]]: f32)
// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<1xf32>
// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<1xf32>
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[insert0]] : vector<1xf32> to vector<f32>
// CHECK: return %[[cast]]
func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
%0 = vector.from_elements %a : vector<f32>
return %0 : vector<f32>
}
69 changes: 69 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2642,3 +2642,72 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
// CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
}

// -----

// Scalar extracts from 0-D, 1-D and 2-D from_elements ops are expected to
// fold to the matching scalar operand, so only the function arguments remain
// in the return.
// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
// Extract from 0D.
%0 = vector.from_elements %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>

// Extract from 1D.
%2 = vector.from_elements %a : vector<1xf32>
%3 = vector.extract %2[0] : f32 from vector<1xf32>
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
%5 = vector.extract %4[4] : f32 from vector<5xf32>

// Extract from 2D.
// Operands are in row-major order: row 0 is all %a, row 1 is all %b.
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>

// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
}

// -----

// Extracting a whole row from a 2-D from_elements op. Each row consists of a
// single repeated operand, so the expected output is a vector.splat per row
// rather than a smaller from_elements op.
// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<3xf32>, vector<3xf32>
}

// -----

// Extracting a 2-D slice from a 3-D from_elements op is expected to produce a
// smaller from_elements op over the four consecutive operands of that slice.
// Note that the captures named splat1/splat2 below bind from_elements ops,
// not splats.
// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
// CHECK: return %[[splat1]], %[[splat2]]
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
}

// -----

// A from_elements op whose operands are all the same value is expected to
// become a vector.splat (including the single-operand 0-D case); one
// differing operand keeps the op as from_elements.
// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
%2 = vector.from_elements %a : vector<f32>
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}
17 changes: 17 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1854,3 +1854,20 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
return
}

// -----

// Negative test: the number of operands (1) does not match the number of
// elements of the result vector type (2).
func.func @invalid_from_elements(%a: f32) {
// expected-error @+1 {{'vector.from_elements' 1 operands present, but expected 2}}
vector.from_elements %a : vector<2xf32>
return
}

// -----

// Negative test: operand %b is i32 while the result element type is f32, so
// the use of %b is diagnosed as a type mismatch against its definition.
// expected-note @+1 {{prior use here}}
func.func @invalid_from_elements(%a: f32, %b: i32) {
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
vector.from_elements %a, %b : vector<2xf32>
return
}
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1158,3 +1158,17 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}

// Checks the printed form of vector.from_elements for 0-D, 1-D and 2-D result
// types; each op is expected to appear with the same operands and type.
// CHECK-LABEL: func @from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
// CHECK: vector.from_elements %[[a]] : vector<f32>
%0 = vector.from_elements %a : vector<f32>
// CHECK: vector.from_elements %[[a]] : vector<1xf32>
%1 = vector.from_elements %a : vector<1xf32>
// CHECK: vector.from_elements %[[a]], %[[b]] : vector<1x2xf32>
%2 = vector.from_elements %a, %b : vector<1x2xf32>
// CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
%3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
}
Loading