
Commit bb4696c

Authored by Dmitriy Smirnov
[mlir][linalg] Fix for bias handling for Winograd (#110331)
This PR makes the winograd.output_transform op a destination-style op and fixes the handling of pre-existing data in its output argument (i.e. data possibly pre-initialized with a bias, which was previously discarded).

---------

Signed-off-by: Dmitriy Smirnov <[email protected]>
1 parent 8bb12ca · commit bb4696c

5 files changed, +106 -105 lines changed
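In effect, where the output transform previously computed Y = s · (Aᵀ · X · A) and overwrote whatever its output tensor already held, it now computes Y = s · (Aᵀ · X · A) + Y_init (informal notation: s is the combined scalar factor, Aᵀ/A the Winograd output-transform matrices, and Y_init the pre-existing contents of the output operand), so a bias loaded into the output ahead of time is accumulated into the result instead of being thrown away.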

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (2 additions, 1 deletion)

@@ -313,7 +313,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
   }
 
 def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
-    [AllElementTypesMatch<["value", "output"]>,
+    [AllElementTypesMatch<["value", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -396,6 +396,7 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
     int64_t getOutputFDim() {
       return 3;
    }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];
  let hasVerifier = 1;
}
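With DestinationStyleOpInterface and getDpsInitsMutable wired up, the output operand becomes the op's DPS init, tying the result to that operand in destination-passing style. A minimal sketch of the op in IR (shapes follow the m = 4, r = 3 cases in the test below; SSA names are illustrative):

  %res = linalg.winograd_output_transform m(4) r(3)
         ins(%transformed : tensor<6x6x1x1x2x2xf32>)
         outs(%bias_init : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>

Here %bias_init may carry a pre-computed bias; after this commit its values feed into the result instead of being overwritten.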

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (55 additions, 59 deletions)

@@ -729,6 +729,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
 
   auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                        ValueRange args) -> scf::ValueVector {
+    auto context = builder.getContext();
     Value tileHIter = ivs[0];
     Value tileWIter = ivs[1];
     Value NIter = ivs[2];
@@ -740,29 +741,41 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                             FIter, 2, 3, /*loopNorFIdx=*/4,
                             /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1);
 
-    TransformMapKeyTy key = {m, r};
-    int64_t retRows = 1;
-    int64_t retCols = 1;
-    int64_t leftScalarFactor = 1;
-    int64_t rightScalarFactor = 1;
+    const TransformMapKeyTy key = {m, r};
+    const TransformMatrix &AMatrix = AMatrices.at(key);
+    const TransformMatrix &ATMatrix = ATMatrices.at(key);
+    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
+                           (leftTransform ? ATMatrix.scalarFactor : 1);
+    int64_t retCols = rightTransform ? AMatrix.cols : 1;
+    int64_t retRows = leftTransform ? ATMatrix.rows : 1;
+
     Value matmulRetValue = extractValue;
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
-    if (leftTransform) {
-      // Get constant transform matrix AT.
-      auto it = ATMatrices.find(key);
-      if (it == ATMatrices.end())
-        return {};
-      const TransformMatrix &ATMatrix = it->second;
 
-      leftScalarFactor = ATMatrix.scalarFactor;
-      retRows = ATMatrix.rows;
+    auto affineMap =
+        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+    Value heightOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
+    Value widthOffset =
+        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+
+    Value outInitVal =
+        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
+                            widthOffset, retRows, retCols,
+                            /*loopNorFIdx=*/0,
+                            /*loopCorFIdx=*/3, /*heightIdx=*/1,
+                            /*widthIdx=*/2);
+    if (leftTransform) {
       auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
-      auto empty =
-          builder
-              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
-              .getResult();
-      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      Value init = outInitVal;
+      if (rightTransform || scalarFactor != 1) {
+        auto empty = builder
+                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType)
+                         .getResult();
+        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      }
 
       Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
       // Multiply AT x m.
@@ -772,21 +785,16 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     }
 
     if (rightTransform) {
-      // Get constant transform matrix T.
-      auto it = AMatrices.find(key);
-      if (it == AMatrices.end())
-        return {};
-      const TransformMatrix &AMatrix = it->second;
-
-      rightScalarFactor = AMatrix.scalarFactor;
       auto matmulType =
           RankedTensorType::get({retRows, AMatrix.cols}, elementType);
-      retCols = AMatrix.cols;
-      auto empty =
-          builder
-              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
-              .getResult();
-      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      Value init = outInitVal;
+      if (scalarFactor != 1) {
+        auto empty = builder
+                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                  elementType)
+                         .getResult();
+        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
+      }
 
       Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
       // Multiply y = (AT x m) x A.
@@ -795,48 +803,36 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
       matmulRetValue = matmulOp.getResult(0);
     }
 
-    if (leftScalarFactor * rightScalarFactor != 1) {
-      // Multiply scalar factor.
-      Value scalarFactor = builder.create<arith::ConstantOp>(
-          loc,
-          FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
+    if (scalarFactor != 1) {
+      // Multiply by scalar factor and add outInitVal.
+      Value scalarFactorValue = builder.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, scalarFactor));
       auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
-      auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
-                                                  elementType);
-
       auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
       SmallVector<AffineMap> affineMaps = {
-          AffineMap::get(2, 0, init.getContext()), identityAffineMap};
-      auto broadcastedScalar =
+          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
+
+      matmulRetValue =
           rewriter
               .create<linalg::GenericOp>(
-                  loc, matmulType, ValueRange{scalarFactor}, ValueRange{init},
-                  affineMaps,
+                  loc, matmulType,
+                  ValueRange{scalarFactorValue, matmulRetValue},
+                  ValueRange{outInitVal}, affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
-                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+                    auto mulf = nestedBuilder.create<arith::MulFOp>(
+                        nestedLoc, args[0], args[1]);
+                    auto addf = nestedBuilder.create<arith::AddFOp>(
+                        nestedLoc, mulf.getResult(), args[2]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
+                                                          addf.getResult());
                  })
              .getResult(0);
-
-      matmulRetValue = builder
-                           .create<linalg::MulOp>(
-                               loc, matmulType,
-                               ValueRange{broadcastedScalar, matmulRetValue},
-                               ValueRange{init})
-                           .getResult(0);
     }
 
-    auto context = builder.getContext();
-    auto affineMap =
-        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
-
     // Insert (H, W) to (N, H, W, F).
     Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
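The former broadcast-then-linalg.mul pair is thus folded into a single linalg.generic that scales the matmul result and accumulates it onto the slice extracted from the output (outInitVal). Roughly the IR this path now emits for an m = 4, r = 3 tile (map and SSA names are illustrative; the updated CHECK lines below pin down the same pattern):

  #scalar = affine_map<(d0, d1) -> ()>
  #ident = affine_map<(d0, d1) -> (d0, d1)>
  %fused = linalg.generic {indexing_maps = [#scalar, #ident, #ident],
                           iterator_types = ["parallel", "parallel"]}
      ins(%factor, %matmul : f32, tensor<4x4xf32>)
      outs(%out_slice : tensor<4x4xf32>) {
  ^bb0(%s: f32, %v: f32, %init: f32):
    // %s broadcasts the scalar factor; %init carries the bias-initialized slice.
    %mul = arith.mulf %s, %v : f32
    %add = arith.addf %mul, %init : f32
    linalg.yield %add : f32
  } -> tensor<4x4xf32>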

mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (27 additions, 24 deletions)

@@ -85,31 +85,32 @@ module attributes {transform.with_named_sequence} {
 // CHECK: scf.yield %[[S9]]
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
 // CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S7:.*]] = tensor.empty()
 // CHECK: %[[S6:.*]] = linalg.batch_matmul
 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
-// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
-// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
+// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
 // CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
 // CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
 // CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
 // CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S25:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK: %[[S16:.*]] = tensor.empty() : tensor<4x6xf32>
 // CHECK: %[[S17:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S16]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK: %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_8]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK: %[[S19:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK: %[[S20:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK: %[[S21:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[S22:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK: %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S22]] : tensor<4x4xf32>) {
-// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: %[[S23:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S21]] : f32, tensor<4x4xf32>) outs(%[[S25]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: %[[VAL_90:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK: %[[VAL_91:.*]] = arith.addf %[[VAL_90]], %[[OUT]] : f32
+/// CHECK: linalg.yield %[[VAL_91]] : f32
 // CHECK: } -> tensor<4x4xf32>
-// CHECK: %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S23]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK: scf.yield %[[INSERTED_SLICE_9]]
 // CHECK: scf.yield %[[S15]]
 // CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
@@ -218,32 +219,33 @@ module attributes {transform.with_named_sequence} {
 // CHECK: scf.yield %[[S9]]
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
 // CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S7:.*]] = tensor.empty()
 // CHECK: %[[S6:.*]] = linalg.batch_matmul
 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
 // CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
-// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
-// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
+// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[PADDED_8]])
 // CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
 // CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
 // CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
 // CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
 // CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
 // CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S26:.*]] = tensor.extract_slice %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK: %[[S17:.*]] = tensor.empty() : tensor<4x6xf32>
 // CHECK: %[[S18:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK: %[[S19:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_11]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S18]] : tensor<4x6xf32>) -> tensor<4x6xf32>
 // CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32>
 // CHECK: %[[S21:.*]] = linalg.fill ins(%[[CST_6]] : f32) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32>
 // CHECK: %[[S22:.*]] = linalg.matmul ins(%[[S19]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
-// CHECK: %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S23]] : tensor<4x4xf32>) {
-// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: %[[S24:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S22]] : f32, tensor<4x4xf32>) outs(%[[S26]] : tensor<4x4xf32>) {
+// CHECK: ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: %[[VAL_104:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK: %[[VAL_105:.*]] = arith.addf %[[VAL_104]], %[[OUT]] : f32
+/// CHECK: linalg.yield %[[VAL_105]] : f32
 // CHECK: } -> tensor<4x4xf32>
-// CHECK: %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
-// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
+// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S24]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
 // CHECK: scf.yield %[[INSERTED_SLICE_12]]
 // CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32>
 // CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
@@ -330,16 +332,17 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
 // CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S15:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK: %[[S9:.*]] = tensor.empty() : tensor<4x1xf32>
 // CHECK: %[[S10:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S9]] : tensor<4x1xf32>) -> tensor<4x1xf32>
 // CHECK: %[[S11:.*]] = linalg.matmul ins(%[[CST_0]], %[[EXTRACTED_SLICE]] : tensor<4x6xf32>, tensor<6x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK: %[[S12:.*]] = tensor.empty() : tensor<4x1xf32>
-// CHECK: %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S12]] : tensor<4x1xf32>) {
-// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// CHECK: linalg.yield %[[IN]] : f32
+// CHECK: %[[S13:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S11]] : f32, tensor<4x1xf32>) outs(%[[S15]] : tensor<4x1xf32>) {
+// CHECK: ^bb0(%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK: %[[VAL_57:.*]] = arith.mulf %[[IN1]], %[[IN2]] : f32
+// CHECK: %[[VAL_58:.*]] = arith.addf %[[VAL_57]], %[[OUT]] : f32
+/// CHECK: linalg.yield %[[VAL_58]] : f32
 // CHECK: } -> tensor<4x1xf32>
-// CHECK: %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
-// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
 // CHECK: scf.yield %[[INSERTED_SLICE]]
 // CHECK: scf.yield %[[S7]]
 // CHECK: return %[[S6]]
