
Commit b08accd

correct the way to broadcast a scalar value
1 parent 549029b commit b08accd
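
This commit fixes how the Winograd output transform materializes its scalar normalization factor. tensor.from_elements builds a tensor from exactly one scalar operand per result element, so a single scalar operand cannot produce the retRows x retCols tensor the old code asked for; the rewrite now broadcasts the scalar with a linalg.generic whose input indexing map has no results, so every output iteration reads the same scalar. A minimal sketch of the IR the new lowering emits, with the constant and shapes taken from the updated test below (the map aliases and %s11, standing in for the preceding matmul result, are illustrative names):

    #scalar   = affine_map<(d0, d1) -> ()>        // zero-result map: every iteration reads the one scalar input
    #identity = affine_map<(d0, d1) -> (d0, d1)>  // identity map: one write per output element
    %cst = arith.constant 1.024000e+03 : f32
    %empty = tensor.empty() : tensor<4x4xf32>
    %bcast = linalg.generic
        {indexing_maps = [#scalar, #identity],
         iterator_types = ["parallel", "parallel"]}
        ins(%cst : f32) outs(%empty : tensor<4x4xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<4x4xf32>
    // Elementwise scale of the matmul result %s11 (hypothetical name).
    %scaled = linalg.mul ins(%bcast, %s11 : tensor<4x4xf32>, tensor<4x4xf32>)
                         outs(%empty : tensor<4x4xf32>) -> tensor<4x4xf32>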

2 files changed: 43 additions & 15 deletions

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 35 additions & 13 deletions
@@ -686,21 +686,43 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       matmulRetValue = matmulOp.getResult(0);
     }
 
-    // Multiply scalar factor.
-    Value scalarFactor = builder.create<arith::ConstantOp>(
-        loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
-    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-    auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                elementType);
-    Value broadcastedScalar =
-        builder.create<tensor::FromElementsOp>(loc, matmulType, scalarFactor);
-    auto scaledMatmul = builder.create<linalg::MulOp>(
-        loc, matmulType, ValueRange{broadcastedScalar, matmulRetValue},
-        ValueRange{init});
+    if (leftScalarFactor * rightScalarFactor != 1) {
+      // Multiply scalar factor.
+      Value scalarFactor = builder.create<arith::ConstantOp>(
+          loc,
+          FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType);
+
+      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
+      SmallVector<AffineMap> affineMaps = {
+          AffineMap::get(2, 0, init.getContext()), identityAffineMap};
+      auto broadcastedScalar =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, matmulType, ValueRange{scalarFactor}, ValueRange{init},
+                  affineMaps,
+                  llvm::ArrayRef<utils::IteratorType>{
+                      utils::IteratorType::parallel,
+                      utils::IteratorType::parallel},
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+                  })
+              .getResult(0);
+
+      matmulRetValue = builder
+                           .create<linalg::MulOp>(
+                               loc, matmulType,
+                               ValueRange{broadcastedScalar, matmulRetValue},
+                               ValueRange{init})
+                           .getResult(0);
+    }
 
     // Insert (H, W) to (N, H, W, F).
-    Value combinedVal = insert2DData(builder, loc, scaledMatmul.getResult(0),
-                                     args[0], NIter, FIter, retRows, retCols,
+    Value combinedVal = insert2DData(builder, loc, matmulRetValue, args[0],
+                                     NIter, FIter, retRows, retCols,
                                      /*outLoopIdx=*/0,
                                      /*inLoopIdx=*/3, /*heightIdx=*/1,
                                      /*widthIdx=*/2, /*destSize=*/4);
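
Two details of the replacement worth noting: AffineMap::get(2, 0, init.getContext()) constructs the zero-result map affine_map<(d0, d1) -> ()> that drives the broadcast, and the scaling is now guarded so it is skipped entirely when leftScalarFactor * rightScalarFactor == 1, with matmulRetValue threaded through to insert2DData in both cases. For contrast, the intended use of tensor.from_elements supplies one operand per result element (a minimal example; %a through %d are hypothetical f32 values):

    %t = tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>

which is why the deleted single-operand form was not a valid broadcast.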

mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

Lines changed: 8 additions & 2 deletions
@@ -14,9 +14,11 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
   return %6 : tensor<2x4x4x2xf32>
 }
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func.func @conv2d_4x4_3x3
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.024000e+03> : tensor<4x4xf32>
+// CHECK-DAG: %[[CST:.*]] = arith.constant 1.024000e+03 : f32
 // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
 // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
 // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
@@ -73,7 +75,11 @@ func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>
 // CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S10]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK-NEXT: %[[S13:.*]] = linalg.mul ins(%cst, %[[S11]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[BROADCAST:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:   linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<4x4xf32>
+// CHECK-NEXT: %[[S13:.*]] = linalg.mul ins(%[[BROADCAST]], %[[S11]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<1x4x4x1xf32>
 // CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32>
 // CHECK-NEXT: %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32>
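
Together these checks pin down the corrected lowering: the factor now stays a plain f32 arith.constant instead of a dense<1.024000e+03> tensor, the linalg.generic broadcasts it via the #[[$MAP0]]/#[[$MAP1]] affine-map aliases into a tensor<4x4xf32>, and only the broadcast result feeds the elementwise linalg.mul against the matmul output %[[S11]].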
