Skip to content

[mlir][vector] Add support for scalable vectors to VectorLinearize #86786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
PatternRewriter &rewriter) const = 0;
};

/// Returns true if the input Vector type can be linearized.
///
/// Linearization is meant in the sense of flattening vectors, e.g.:
///   * vector<NxMxKxi32> -> vector<N*M*Kxi32>
/// In this sense, Vectors that either:
///   * are already linearized (i.e. have rank <= 1), or
///   * contain more than one scalable dimension,
/// are not linearizable. For example, vector<2x[2]xf32> is linearizable,
/// whereas vector<[2]x[2]xf32> and vector<[4]xf32> are not.
bool isLinearizableVector(VectorType type);

} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
Expand Down
12 changes: 9 additions & 3 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
Location loc = constOp.getLoc();
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());

if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
return rewriter.notifyMatchFailure(
loc,
"Cannot linearize a constant scalable vector that's not a splat");

if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
Expand Down Expand Up @@ -104,11 +110,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
ConversionTarget &target, unsigned targetBitWidth) {

typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
// Ignore scalable vectors for now.
if (type.getRank() <= 1 || type.isScalable())
if (!isLinearizableVector(type))
return type;

return VectorType::get(type.getNumElements(), type.getElementType());
return VectorType::get(type.getNumElements(), type.getElementType(),
type.isScalable());
});

auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,8 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
: memref::getMixedSizes(rewriter, loc, base);
return mixedSourceDims;
}

/// Reports whether `type` can be flattened into a 1-D vector.
///
/// A vector qualifies only when both of the following hold:
///   * its rank is greater than 1, and
///   * at most one of its dimensions is scalable.
bool vector::isLinearizableVector(VectorType type) {
  // Vectors of rank <= 1 are never reported as linearizable.
  if (type.getRank() <= 1)
    return false;

  // Count the dimensions flagged as scalable; two or more disqualify the type.
  const auto scalableDimCount = llvm::count(type.getScalableDims(), true);
  return scalableDimCount < 2;
}
61 changes: 59 additions & 2 deletions mlir/test/Dialect/Vector/linearize.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize -verify-diagnostics | FileCheck %s --check-prefixes=ALL,DEFAULT
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0

// ALL-LABEL: test_linearize
Expand Down Expand Up @@ -97,3 +97,60 @@ func.func @test_tensor_no_linearize(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf3

return %0, %arg0 : tensor<2x2xf32>, tensor<2x2xf32>
}

// -----

// Positive scalable case: `vector<2x[2]xf32>` has exactly one scalable
// dimension and the constant below is a splat (every element is 3.0). The
// checks for the DEFAULT and BW-128 prefixes expect each op to be rewritten
// on the flattened `vector<[4]xf32>` type, bracketed by vector.shape_cast ops;
// the BW-0 checks expect the ops to keep their original 2-D type.
// ALL-LABEL: func.func @test_scalable_linearize(
// ALL-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
// DEFAULT: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
// DEFAULT: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
// BW-128: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
// BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
// BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[2]xf32>
%0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32>

// DEFAULT: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
// BW-128: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
// BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[2]xf32>
%1 = math.sin %arg0 : vector<2x[2]xf32>

// DEFAULT: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
// BW-128: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
// BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[2]xf32>
%2 = arith.addf %0, %1 : vector<2x[2]xf32>

// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
// ALL: return %[[RES]] : vector<2x[2]xf32>
return %2 : vector<2x[2]xf32>
}

// -----

// Negative scalable case: `vector<[2]x[2]xf32>` has two scalable dimensions.
// Every check below uses the prefix shared by each RUN configuration and
// expects the ops to keep their original 2-D type.
// ALL-LABEL: func.func @test_scalable_no_linearize(
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
// ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
%0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>

// ALL: %[[SIN:.*]] = math.sin %[[VAL_0]] : vector<[2]x[2]xf32>
%1 = math.sin %arg0 : vector<[2]x[2]xf32>

// ALL: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<[2]x[2]xf32>
%2 = arith.addf %0, %1 : vector<[2]x[2]xf32>

// ALL: return %[[RES]] : vector<[2]x[2]xf32>
return %2 : vector<[2]x[2]xf32>
}

// -----

// Negative scalable case: the constant below is not a splat (its two rows
// hold different values) while its type, `vector<2x[2]xf32>`, is scalable.
// The diagnostic annotation inside the function expects legalization of that
// arith.constant to fail; the annotation is only checked by RUN lines that
// pass -verify-diagnostics.
func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
// expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
%0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
%1 = math.sin %arg0 : vector<2x[2]xf32>
%2 = arith.addf %0, %1 : vector<2x[2]xf32>

return %2 : vector<2x[2]xf32>
}
4 changes: 3 additions & 1 deletion mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,9 @@ struct TestFlattenVectorTransferPatterns
Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
"Minimum vector bitwidth to enable the flattening transformation"),
"Minimum vector bitwidth to enable the flattening transformation. "
"For scalable vectors this is the base size, i.e. the size "
"corresponding to vscale=1."),
llvm::cl::init(std::numeric_limits<unsigned>::max())};

void runOnOperation() override {
Expand Down