Skip to content

[mlir][vector] Add tests for populateSinkVectorBroadcastPatterns (1/n) #102286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,15 +979,18 @@ struct ReorderElementwiseOpsOnBroadcast final
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
if (op->getNumOperands() == 0 ||
op->getResults()[0].getType() != op->getOperand(0).getType()) {
return failure();
}
// Avoid operations that only accept vector types, since broadcast
// source might be scalar types.
if (op->getResults()[0].getType() != op->getOperand(0).getType())
return rewriter.notifyMatchFailure(op,
"result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return failure();
return rewriter.notifyMatchFailure(
op,
"Op only accepts vector types - not supported as broadcast source "
"might be a scalar");
}

// Get the type of the lhs operand
Expand Down
117 changes: 100 additions & 17 deletions mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s

//-----------------------------------------------------------------------------
// [Pattern: ReorderElementwiseOpsOnBroadcast]
//-----------------------------------------------------------------------------

// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>

func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>

func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
Expand All @@ -21,13 +38,26 @@ func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat_scalable(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
%0 = vector.splat %arg1 : vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector(
Expand All @@ -37,13 +67,27 @@ func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) ->
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>

func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
return %2 : vector<3x4xf32>
}

// CHECK-LABEL: func.func @broadcast_vector_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<[4]xf32> to vector<3x[4]xf32>
// CHECK: return %[[BCAST]] : vector<3x[4]xf32>

func.func @broadcast_vector_scalable(%arg1: vector<[4]xf32>, %arg2: vector<[4]xf32>) -> vector<3x[4]xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<[4]xf32> to vector<3x[4]xf32>
%arg2_bcast = vector.broadcast %arg2 : vector<[4]xf32> to vector<3x[4]xf32>
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x[4]xf32>
return %2 : vector<3x[4]xf32>
}

// -----

// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
Expand All @@ -53,13 +97,27 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
// CHECK: return %[[ADD]] : vector<1x4xindex>
func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
%0 = vector.splat %arg1 : vector<1x4xindex>
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
return %2 : vector<1x4xindex>
}

// CHECK-LABEL: func.func @broadcast_scalar_and_vec_scalable(
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: vector<[4]xindex>) -> vector<1x[4]xindex> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x[4]xindex>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<[4]xindex> to vector<1x[4]xindex>
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x[4]xindex>
// CHECK: return %[[ADD]] : vector<1x[4]xindex>
func.func @broadcast_scalar_and_vec_scalable(%arg1: index, %arg2: vector<[4]xindex>) -> vector<1x[4]xindex> {
%0 = vector.splat %arg1 : vector<1x[4]xindex>
%1 = vector.broadcast %arg2 : vector<[4]xindex> to vector<1x[4]xindex>
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
return %2 : vector<1x[4]xindex>
}

// -----

// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
Expand All @@ -69,12 +127,25 @@ func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> ve
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
// CHECK: return %[[ADD]] : vector<4xi32>

func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
func.func @broadcast_vector_and_scalar(%arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
%2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
return %2 : vector<4xi32>
}

// CHECK-LABEL: func.func @broadcast_vector_and_scalar_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: i32,
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xi32>) -> vector<[4]xi32> {
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<[4]xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<[4]xi32>
// CHECK: return %[[ADD]] : vector<[4]xi32>

func.func @broadcast_vector_and_scalar_scalable(%arg1: i32, %arg2: vector<[4]xi32>) -> vector<[4]xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<[4]xi32>
%2 = arith.addi %arg1_bcast, %arg2 : vector<[4]xi32>
return %2 : vector<[4]xi32>
}

// -----

#matmat_accesses = [
Expand All @@ -87,40 +158,52 @@ func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vec
iterator_types = ["parallel", "parallel", "reduction"]
}

// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
// CHECK: %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
// CHECK-LABEL: func.func @negative_not_elementwise
// CHECK-DAG: %[[F1:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[F2:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[F3:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[F1]], %[[F2]], %[[F3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
func.func @negative_not_elementwise() -> vector<2x2xf32> {
%f1 = arith.constant 1.0: f32
%f2 = arith.constant 2.0: f32
%f3 = arith.constant 3.0: f32

%A = vector.broadcast %f1 : f32 to vector<2x2xf32>
%B = vector.broadcast %f2 : f32 to vector<2x2xf32>
%C = vector.broadcast %f3 : f32 to vector<2x2xf32>
%mm1 = vector.contract #matmat_trait %A, %B, %C
%res = vector.contract #matmat_trait %A, %B, %C
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

return %mm1 : vector<2x2xf32>
return %res : vector<2x2xf32>
}

// CHECK-LABEL: func.func @dont_sink_cmp(
// -----

// The source and the result for arith.cmp have different types - not supported

// CHECK-LABEL: func.func @negative_source_and_result_mismatch
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
// CHECK: return %[[RETURN]]
func.func @dont_sink_cmp(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
}

// CHECK-LABEL: func.func @dont_sink_fma(
// -----

// vector.fma only supports vectors - currently it's not possible to replace this with e.g.:
// %scalar_res = vector.fma %scalar_1, %scalar2
// %vec_res = vector.broadcast %scalar_res
//
// TODO: It should be possible to support this case

// CHECK-LABEL: func.func @negative_op_only_supports_vectors
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
// CHECK: %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
// CHECK: return %[[RESULT]]
func.func @dont_sink_fma(%arg0 : f32) -> vector<1xf32> {
func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = vector.fma %0, %0, %0 : vector<1xf32>
return %1 : vector<1xf32>
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,12 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x


//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
// Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested.
//
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//

// -----
Expand All @@ -272,6 +276,11 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {

// -----

//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
Expand All @@ -282,6 +291,7 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {

//===----------------------------------------------------------------------===//
// Reorder elementwise ops and vector ops.
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//

// -----
Expand Down
Loading