@@ -330,91 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
330
330
return DiagnosedSilenceableFailure::success ();
331
331
}
332
332
333
- // / Populates `result` with the positional identifiers relative to `maxNumber`.
334
- // / If `isAll` is set, the result will contain all numbers from `0` to
335
- // / `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
336
- // / values from `rawList` are are interpreted as counting backwards from
337
- // / `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
338
- // / numbers remain as is. If `isInverted` is set, populates `result` with those
339
- // / values from the `0` to `maxNumber - 1` inclusive range that don't appear in
340
- // / `rawList`. If `rawList` contains values that are greater than or equal to
341
- // / `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
342
- // / given location. `maxNumber` must be positive. If `rawList` contains
343
- // / duplicate numbers or numbers that become duplicate after negative value
344
- // / remapping, emits a silenceable error.
345
- static DiagnosedSilenceableFailure
346
- expandTargetSpecification (Location loc, bool isAll, bool isInverted,
347
- ArrayRef<int64_t > rawList, int64_t maxNumber,
348
- SmallVectorImpl<int64_t > &result) {
349
- assert (maxNumber > 0 && " expected size to be positive" );
350
- assert (!(isAll && isInverted) && " cannot invert all" );
351
- if (isAll) {
352
- result = llvm::to_vector (llvm::seq<int64_t >(0 , maxNumber));
353
- return DiagnosedSilenceableFailure::success ();
354
- }
355
-
356
- SmallVector<int64_t > expanded;
357
- llvm::SmallDenseSet<int64_t > visited;
358
- expanded.reserve (rawList.size ());
359
- SmallVectorImpl<int64_t > &target = isInverted ? expanded : result;
360
- for (int64_t raw : rawList) {
361
- int64_t updated = raw < 0 ? maxNumber + raw : raw;
362
- if (updated >= maxNumber) {
363
- return emitSilenceableFailure (loc)
364
- << " position overflow " << updated << " (updated from " << raw
365
- << " ) for maximum " << maxNumber;
366
- }
367
- if (updated < 0 ) {
368
- return emitSilenceableFailure (loc) << " position underflow " << updated
369
- << " (updated from " << raw << " )" ;
370
- }
371
- if (!visited.insert (updated).second ) {
372
- return emitSilenceableFailure (loc) << " repeated position " << updated
373
- << " (updated from " << raw << " )" ;
374
- }
375
- target.push_back (updated);
376
- }
377
-
378
- if (!isInverted)
379
- return DiagnosedSilenceableFailure::success ();
380
-
381
- result.reserve (result.size () + (maxNumber - expanded.size ()));
382
- for (int64_t candidate : llvm::seq<int64_t >(0 , maxNumber)) {
383
- if (llvm::is_contained (expanded, candidate))
384
- continue ;
385
- result.push_back (candidate);
386
- }
387
-
388
- return DiagnosedSilenceableFailure::success ();
389
- }
390
-
391
- // / Checks if the positional specification defined is valid and reports errors
392
- // / otherwise.
393
- LogicalResult verifyStructuredTransformDimsOp (Operation *op,
394
- ArrayRef<int64_t > raw,
395
- bool inverted, bool all) {
396
- if (all) {
397
- if (inverted) {
398
- return op->emitOpError ()
399
- << " cannot request both 'all' and 'inverted' values in the list" ;
400
- }
401
- if (!raw.empty ()) {
402
- return op->emitOpError ()
403
- << " cannot both request 'all' and specific values in the list" ;
404
- }
405
- }
406
- if (!all && raw.empty ()) {
407
- return op->emitOpError () << " must request specific values in the list if "
408
- " 'all' is not specified" ;
409
- }
410
- SmallVector<int64_t > rawVector = llvm::to_vector (raw);
411
- auto *it = std::unique (rawVector.begin (), rawVector.end ());
412
- if (it != rawVector.end ())
413
- return op->emitOpError () << " expected the listed values to be unique" ;
414
-
415
- return success ();
416
- }
417
-
418
333
// ===----------------------------------------------------------------------===//
419
334
// MatchStructuredDimOp
420
335
// ===----------------------------------------------------------------------===//
@@ -475,8 +390,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
475
390
return emitOpError () << " cannot request the same dimension to be both "
476
391
" parallel and reduction" ;
477
392
}
478
- return verifyStructuredTransformDimsOp (getOperation (), getRawDimList (),
479
- getIsInverted (), getIsAll ());
393
+ return verifyTransformMatchDimsOp (getOperation (), getRawDimList (),
394
+ getIsInverted (), getIsAll ());
480
395
}
481
396
482
397
// ===----------------------------------------------------------------------===//
@@ -592,8 +507,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
592
507
LogicalResult transform::MatchStructuredInputOp::verify () {
593
508
if (failed (verifyStructuredOperandOp (*this )))
594
509
return failure ();
595
- return verifyStructuredTransformDimsOp (getOperation (), getRawPositionList (),
596
- getIsInverted (), getIsAll ());
510
+ return verifyTransformMatchDimsOp (getOperation (), getRawPositionList (),
511
+ getIsInverted (), getIsAll ());
597
512
}
598
513
599
514
// ===----------------------------------------------------------------------===//
@@ -665,8 +580,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
665
580
LogicalResult transform::MatchStructuredInitOp::verify () {
666
581
if (failed (verifyStructuredOperandOp (*this )))
667
582
return failure ();
668
- return verifyStructuredTransformDimsOp (getOperation (), getRawPositionList (),
669
- getIsInverted (), getIsAll ());
583
+ return verifyTransformMatchDimsOp (getOperation (), getRawPositionList (),
584
+ getIsInverted (), getIsAll ());
670
585
}
671
586
672
587
// ===----------------------------------------------------------------------===//
@@ -793,78 +708,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
793
708
build (builder, state, ValueRange ());
794
709
}
795
710
796
- // ===----------------------------------------------------------------------===//
797
- // Printing and parsing for structured match ops.
798
- // ===----------------------------------------------------------------------===//
799
-
800
- // / Keyword syntax for positional specification inversion.
801
- constexpr const static llvm::StringLiteral kDimExceptKeyword = " except" ;
802
-
803
- // / Keyword syntax for full inclusion in positional specification.
804
- constexpr const static llvm::StringLiteral kDimAllKeyword = " all" ;
805
-
806
- // / Parses a positional specification for structured transform operations. The
807
- // / following forms are accepted:
808
- // /
809
- // / - `all`: sets `isAll` and returns;
810
- // / - comma-separated-integer-list: populates `rawDimList` with the values;
811
- // / - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
812
- // / with the values and sets `isInverted`.
813
- static ParseResult parseStructuredTransformDims (OpAsmParser &parser,
814
- DenseI64ArrayAttr &rawDimList,
815
- UnitAttr &isInverted,
816
- UnitAttr &isAll) {
817
- Builder &builder = parser.getBuilder ();
818
- if (parser.parseOptionalKeyword (kDimAllKeyword ).succeeded ()) {
819
- rawDimList = builder.getDenseI64ArrayAttr ({});
820
- isInverted = nullptr ;
821
- isAll = builder.getUnitAttr ();
822
- return success ();
823
- }
824
-
825
- isAll = nullptr ;
826
- isInverted = nullptr ;
827
- if (parser.parseOptionalKeyword (kDimExceptKeyword ).succeeded ()) {
828
- isInverted = builder.getUnitAttr ();
829
- }
830
-
831
- if (isInverted) {
832
- if (parser.parseLParen ().failed ())
833
- return failure ();
834
- }
835
-
836
- SmallVector<int64_t > values;
837
- ParseResult listResult = parser.parseCommaSeparatedList (
838
- [&]() { return parser.parseInteger (values.emplace_back ()); });
839
- if (listResult.failed ())
840
- return failure ();
841
-
842
- rawDimList = builder.getDenseI64ArrayAttr (values);
843
-
844
- if (isInverted) {
845
- if (parser.parseRParen ().failed ())
846
- return failure ();
847
- }
848
- return success ();
849
- }
850
-
851
- // / Prints a positional specification for structured transform operations.
852
- static void printStructuredTransformDims (OpAsmPrinter &printer, Operation *op,
853
- DenseI64ArrayAttr rawDimList,
854
- UnitAttr isInverted, UnitAttr isAll) {
855
- if (isAll) {
856
- printer << kDimAllKeyword ;
857
- return ;
858
- }
859
- if (isInverted) {
860
- printer << kDimExceptKeyword << " (" ;
861
- }
862
- llvm::interleaveComma (rawDimList.asArrayRef (), printer.getStream (),
863
- [&](int64_t value) { printer << value; });
864
- if (isInverted) {
865
- printer << " )" ;
866
- }
867
- }
868
-
869
711
#define GET_OP_CLASSES
870
712
#include " mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
0 commit comments