@@ -397,6 +397,31 @@ class PossibleTopLevelTransformOpTrait
397
397
}
398
398
};
399
399
400
+ // / Trait implementing the TransformOpInterface for operations applying a
401
+ // / transformation to a single operation handle and producing a single operation
402
+ // / handle. The op must implement a method with one of the following signatures:
403
+ // / - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
404
+ // / - LogicalResult applyToOne(OpTy)
405
+ // / to perform a transformation that is applied in turn to all payload IR
406
+ // / operations that correspond to the handle of the transform IR operation.
407
+ // / In the functions above, OpTy is either Operation * or a concrete payload IR
408
+ // / Op class that the transformation is applied to (NOT the class of the
409
+ // / transform IR op). The op is expected to have one operand and zero or one
410
+ // / results.
411
+ template <typename OpTy>
412
+ class TransformEachOpTrait
413
+ : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
414
+ public:
415
+ // / Calls `applyToOne` for every payload operation associated with the operand
416
+ // / of this transform IR op. If `applyToOne` returns ops, associates them with
417
+ // / the result of this transform op.
418
+ LogicalResult apply (TransformResults &transformResults,
419
+ TransformState &state);
420
+
421
+ // / Checks that the op matches the expectations of this trait.
422
+ static LogicalResult verifyTrait (Operation *op);
423
+ };
424
+
400
425
// / Side effect resource corresponding to the mapping between Transform IR
401
426
// / values and Payload IR operations. An Allocate effect from this resource
402
427
// / means creating a new mapping entry, it is always accompanied by a Write
@@ -426,9 +451,150 @@ struct PayloadIRResource
426
451
StringRef getName () override { return " transform.payload_ir" ; }
427
452
};
428
453
454
+ // / Trait implementing the MemoryEffectOpInterface for single-operand
455
+ // / single-result operations that "consume" their operand and produce a new
456
+ // / result.
457
+ template <typename OpTy>
458
+ class FunctionalStyleTransformOpTrait
459
+ : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
460
+ public:
461
+ // / This op "consumes" the operand by reading and freeing it, "produces" the
462
+ // / result by allocating and writing it and reads/writes the payload IR in the
463
+ // / process.
464
+ void getEffects (SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
465
+ effects.emplace_back (MemoryEffects::Read::get (),
466
+ this ->getOperation ()->getOperand (0 ),
467
+ TransformMappingResource::get ());
468
+ effects.emplace_back (MemoryEffects::Free::get (),
469
+ this ->getOperation ()->getOperand (0 ),
470
+ TransformMappingResource::get ());
471
+ effects.emplace_back (MemoryEffects::Allocate::get (),
472
+ this ->getOperation ()->getResult (0 ),
473
+ TransformMappingResource::get ());
474
+ effects.emplace_back (MemoryEffects::Write::get (),
475
+ this ->getOperation ()->getResult (0 ),
476
+ TransformMappingResource::get ());
477
+ effects.emplace_back (MemoryEffects::Read::get (), PayloadIRResource::get ());
478
+ effects.emplace_back (MemoryEffects::Write::get (), PayloadIRResource::get ());
479
+ }
480
+
481
+ // / Checks that the op matches the expectations of this trait.
482
+ static LogicalResult verifyTrait (Operation *op) {
483
+ static_assert (OpTy::template hasTrait<OpTrait::OneOperand>(),
484
+ " expected single-operand op" );
485
+ static_assert (OpTy::template hasTrait<OpTrait::OneResult>(),
486
+ " expected single-result op" );
487
+ if (!op->getName ().getInterface <MemoryEffectOpInterface>()) {
488
+ op->emitError ()
489
+ << " FunctionalStyleTransformOpTrait should only be attached to ops "
490
+ " that implement MemoryEffectOpInterface" ;
491
+ }
492
+ return success ();
493
+ }
494
+ };
495
+
429
496
} // namespace transform
430
497
} // namespace mlir
431
498
432
499
#include " mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
433
500
501
+ namespace mlir {
502
+ namespace transform {
503
+ namespace detail {
504
+ // / Appends `result` to the vector assuming it corresponds to the success state
505
+ // / in `FailureOr<convertible-to-Operation*>`. If `result` is just a
506
+ // / `LogicalResult`, does nothing.
507
+ template <typename Ty>
508
+ std::enable_if_t <std::is_same<Ty, LogicalResult>::value, LogicalResult>
509
+ appendTransformResultToVector (Ty result,
510
+ SmallVectorImpl<Operation *> &results) {
511
+ return result;
512
+ }
513
+ template <typename Ty>
514
+ std::enable_if_t <!std::is_same<Ty, LogicalResult>::value, LogicalResult>
515
+ appendTransformResultToVector (Ty result,
516
+ SmallVectorImpl<Operation *> &results) {
517
+ static_assert (
518
+ std::is_convertible<typename Ty::value_type, Operation *>::value,
519
+ " expected transform function to return operations" );
520
+ if (failed (result))
521
+ return failure ();
522
+
523
+ results.push_back (*result);
524
+ return success ();
525
+ }
526
+
527
+ // / Applies a one-to-one transform to each of the given targets. Puts the
528
+ // / results of transforms, if any, in `results` in the same order. Fails if any
529
+ // / of the application fails. Individual transforms must be callable with
530
+ // / one of the following signatures:
531
+ // / - FailureOr<convertible-to-Operation*>(OpTy)
532
+ // / - LogicalResult(OpTy)
533
+ // / where OpTy is either
534
+ // / - Operation *, in which case the transform is always applied;
535
+ // / - a concrete Op class, in which case a check is performed whether
536
+ // / `targets` contains operations of the same class and a failure is reported
537
+ // / if it does not.
538
+ template <typename FnTy>
539
+ LogicalResult applyTransformToEach (ArrayRef<Operation *> targets,
540
+ SmallVectorImpl<Operation *> &results,
541
+ FnTy transform) {
542
+ using OpTy = typename llvm::function_traits<FnTy>::template arg_t <0 >;
543
+ static_assert (std::is_convertible<OpTy, Operation *>::value,
544
+ " expected transform function to take an operation" );
545
+ using RetTy = typename llvm::function_traits<FnTy>::result_t ;
546
+ static_assert (std::is_convertible<RetTy, LogicalResult>::value,
547
+ " expected transform function to return LogicalResult or "
548
+ " FailureOr<convertible-to-Operation*>" );
549
+ for (Operation *target : targets) {
550
+ auto specificOp = dyn_cast<OpTy>(target);
551
+ if (!specificOp)
552
+ return failure ();
553
+
554
+ auto result = transform (specificOp);
555
+ if (failed (appendTransformResultToVector (result, results)))
556
+ return failure ();
557
+ }
558
+ return success ();
559
+ }
560
+ } // namespace detail
561
+ } // namespace transform
562
+ } // namespace mlir
563
+
564
+ template <typename OpTy>
565
+ mlir::LogicalResult mlir::transform::TransformEachOpTrait<OpTy>::apply(
566
+ TransformResults &transformResults, TransformState &state) {
567
+ using TransformOpType = typename llvm::function_traits<
568
+ decltype (&OpTy::applyToOne)>::template arg_t <0 >;
569
+ ArrayRef<Operation *> targets =
570
+ state.getPayloadOps (this ->getOperation ()->getOperand (0 ));
571
+ SmallVector<Operation *> results;
572
+ if (failed (detail::applyTransformToEach (
573
+ targets, results, [&](TransformOpType specificOp) {
574
+ return static_cast <OpTy *>(this )->applyToOne (specificOp);
575
+ })))
576
+ return failure ();
577
+ if (OpTy::template hasTrait<OpTrait::OneResult>()) {
578
+ transformResults.set (
579
+ this ->getOperation ()->getResult (0 ).template cast <OpResult>(), results);
580
+ }
581
+ return success ();
582
+ }
583
+
584
+ template <typename OpTy>
585
+ mlir::LogicalResult
586
+ mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
587
+ static_assert (OpTy::template hasTrait<OpTrait::OneOperand>(),
588
+ " expected single-operand op" );
589
+ static_assert (OpTy::template hasTrait<OpTrait::OneResult>() ||
590
+ OpTy::template hasTrait<OpTrait::ZeroResults>(),
591
+ " expected zero- or single-result op" );
592
+ if (!op->getName ().getInterface <TransformOpInterface>()) {
593
+ return op->emitError () << " TransformEachOpTrait should only be attached to "
594
+ " ops that implement TransformOpInterface" ;
595
+ }
596
+
597
+ return success ();
598
+ }
599
+
434
600
#endif // DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
0 commit comments