Skip to content

Commit e26e836

Browse files
committed
Address comments
1 parent 8c42da9 commit e26e836

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
415415
}
416416
};
417417

418-
// Shuffles arith extend ops after vector.extract op.
418+
// Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
419419
//
420420
// This transforms IR like:
421421
// %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
@@ -460,13 +460,13 @@ struct SwapVectorExtractOfArithExtend
460460
Operation *newExtend =
461461
rewriter.create(loc, extendOpName, Value(newExtract), resultType);
462462

463-
rewriter.replaceOp(extractOp, newExtend->getResult(0));
463+
rewriter.replaceOp(extractOp, newExtend);
464464

465465
return success();
466466
}
467467
};
468468

469-
// Shuffles arith extend ops after vector.scalable.extract op.
469+
// Same as above, but for vector.scalable.extract.
470470
//
471471
// This transforms IR like:
472472
// %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
@@ -499,16 +499,15 @@ struct SwapVectorScalableExtractOfArithExtend
499499

500500
// Create new extract from source of extend.
501501
VectorType extractResultVectorType =
502-
VectorType::Builder(resultType)
503-
.setElementType(extendSourceVectorType.getElementType());
502+
resultType.clone(extendSourceVectorType.getElementType());
504503
Value newExtract = rewriter.create<vector::ScalableExtractOp>(
505504
loc, extractResultVectorType, extendSource, extractOp.getPos());
506505

507506
// Extend new extract to original result type.
508507
Operation *newExtend =
509508
rewriter.create(loc, extendOpName, Value(newExtract), resultType);
510509

511-
rewriter.replaceOp(extractOp, newExtend->getResult(0));
510+
rewriter.replaceOp(extractOp, newExtend);
512511

513512
return success();
514513
}

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -322,23 +322,23 @@ func.func @extract_from_arith_ext(%src: vector<4x[8]xi8>) -> vector<[8]xi32> {
322322
// CHECK-SAME: %[[SRC:[a-z0-9]+]]: vector<4x[8]xi8>,
323323
// CHECK-SAME: %[[DIM:[a-z0-9]+]]: index
324324
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]]] : vector<[8]xi8> from vector<4x[8]xi8>
325-
// CHECK: %[[EXTEND:.*]] = arith.extsi %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32>
325+
// CHECK: %[[EXTEND:.*]] = arith.extui %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32>
326326
// CHECK: return %[[EXTEND]]
327327
func.func @non_constant_extract_from_arith_ext(%src: vector<4x[8]xi8>, %dim: index) -> vector<[8]xi32> {
328-
%0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
328+
%0 = arith.extui %src : vector<4x[8]xi8> to vector<4x[8]xi32>
329329
%1 = vector.extract %0[%dim] : vector<[8]xi32> from vector<4x[8]xi32>
330330
return %1 : vector<[8]xi32>
331331
}
332332

333333
// -----
334334

335335
// CHECK-LABEL: @scalable_extract_from_arith_ext(
336-
// CHECK-SAME: %[[SRC:.*]]: vector<[8]xi8>
337-
// CHECK: %[[EXTRACT:.*]] = vector.scalable.extract %[[SRC]][0] : vector<[4]xi8> from vector<[8]xi8>
338-
// CHECK: %[[EXTEND:.*]] = arith.extsi %[[EXTRACT]] : vector<[4]xi8> to vector<[4]xi32>
336+
// CHECK-SAME: %[[SRC:.*]]: vector<[8]xf16>
337+
// CHECK: %[[EXTRACT:.*]] = vector.scalable.extract %[[SRC]][0] : vector<[4]xf16> from vector<[8]xf16>
338+
// CHECK: %[[EXTEND:.*]] = arith.extf %[[EXTRACT]] : vector<[4]xf16> to vector<[4]xf32>
339339
// CHECK: return %[[EXTEND]]
340-
func.func @scalable_extract_from_arith_ext(%src: vector<[8]xi8>) -> vector<[4]xi32> {
341-
%0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
342-
%1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
343-
return %1 : vector<[4]xi32>
340+
func.func @scalable_extract_from_arith_ext(%src: vector<[8]xf16>) -> vector<[4]xf32> {
341+
%0 = arith.extf %src : vector<[8]xf16> to vector<[8]xf32>
342+
%1 = vector.scalable.extract %0[0] : vector<[4]xf32> from vector<[8]xf32>
343+
return %1 : vector<[4]xf32>
344344
}

0 commit comments

Comments
 (0)