Skip to content

Commit 5cde5a5

Browse files
committed
[mlir] add interchange, pad and scalarize to structured transform dialect
Add ops to the structured transform extension of the transform dialect that perform interchange, padding and scalarization on structured ops. Along with tiling that is already defined, this provides a minimal set of transformations necessary to build vectorizable code for a single structured op. Define two helper traits: one that implements TransformOpInterface by applying a function to each payload op independently and another that provides a simple "functional-style" producer/consumer list of memory effects for the transform ops. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D126374
1 parent b4dbcba commit 5cde5a5

File tree

11 files changed

+762
-40
lines changed

11 files changed

+762
-40
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,16 @@
99
#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
1010
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
1111

12+
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
1213
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1314
#include "mlir/IR/OpImplementation.h"
1415

16+
namespace mlir {
17+
namespace linalg {
18+
class LinalgOp;
19+
} // namespace linalg
20+
} // namespace mlir
21+
1522
//===----------------------------------------------------------------------===//
1623
// Linalg Transform Operations
1724
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,81 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
1616
include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/IR/OpBase.td"
1818

19+
def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
20+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
21+
TransformOpInterface, TransformEachOpTrait]> {
22+
let description = [{
23+
Interchanges the iterators of the operations pointed to by the target handle
24+
using the iterator interchange attribute.
25+
}];
26+
27+
let arguments =
28+
(ins PDL_Operation:$target,
29+
DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
30+
let results = (outs PDL_Operation:$transformed);
31+
32+
let assemblyFormat = "$target attr-dict";
33+
let hasVerifier = 1;
34+
35+
let extraClassDeclaration = [{
36+
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
37+
::mlir::linalg::LinalgOp target);
38+
}];
39+
}
40+
41+
def PadOp : Op<Transform_Dialect, "structured.pad",
42+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
43+
TransformOpInterface, TransformEachOpTrait]> {
44+
let description = [{
45+
Pads the operations pointed to by the target handle using the options
46+
provides as operation attributes.
47+
}];
48+
49+
let arguments =
50+
(ins PDL_Operation:$target,
51+
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
52+
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
53+
DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
54+
DefaultValuedAttr<I64ArrayAttr, "{}">:$hoist_paddings,
55+
DefaultValuedAttr<
56+
TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
57+
"{}">:$transpose_paddings);
58+
let results = (outs PDL_Operation:$transformed);
59+
60+
let assemblyFormat = "$target attr-dict";
61+
let hasVerifier = 1;
62+
63+
let extraClassDeclaration = [{
64+
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
65+
::mlir::linalg::LinalgOp target);
66+
}];
67+
}
68+
69+
def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
70+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
71+
TransformOpInterface, TransformEachOpTrait]> {
72+
let description = [{
73+
Indicates that ops of a specific kind in the given function should be
74+
scalarized (i.e. their dynamic dimensions tiled by 1).
75+
76+
This operation returns the tiled op but not the loops.
77+
78+
We make this design choice because it is hard to know ahead of time the
79+
number of loops that will be produced (it depends on the number of dynamic
80+
dimensions after multiple transformations have been applied).
81+
}];
82+
83+
let arguments = (ins PDL_Operation:$target);
84+
let results = (outs PDL_Operation:$result);
85+
86+
let assemblyFormat = "$target attr-dict";
87+
88+
let extraClassDeclaration = [{
89+
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
90+
::mlir::linalg::LinalgOp target);
91+
}];
92+
}
93+
1994
def TileOp : Op<Transform_Dialect, "structured.tile",
2095
[DeclareOpInterfaceMethods<TransformOpInterface>,
2196
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,31 @@ class PossibleTopLevelTransformOpTrait
397397
}
398398
};
399399

400+
/// Trait implementing the TransformOpInterface for operations applying a
401+
/// transformation to a single operation handle and producing a single operation
402+
/// handle. The op must implement a method with one of the following signatures:
403+
/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
404+
/// - LogicalResult applyToOne(OpTy)
405+
/// to perform a transformation that is applied in turn to all payload IR
406+
/// operations that correspond to the handle of the transform IR operation.
407+
/// In the functions above, OpTy is either Operation * or a concrete payload IR
408+
/// Op class that the transformation is applied to (NOT the class of the
409+
/// transform IR op). The op is expected to have one operand and zero or one
410+
/// results.
411+
template <typename OpTy>
412+
class TransformEachOpTrait
413+
: public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
414+
public:
415+
/// Calls `applyToOne` for every payload operation associated with the operand
416+
/// of this transform IR op. If `applyToOne` returns ops, associates them with
417+
/// the result of this transform op.
418+
LogicalResult apply(TransformResults &transformResults,
419+
TransformState &state);
420+
421+
/// Checks that the op matches the expectations of this trait.
422+
static LogicalResult verifyTrait(Operation *op);
423+
};
424+
400425
/// Side effect resource corresponding to the mapping between Transform IR
401426
/// values and Payload IR operations. An Allocate effect from this resource
402427
/// means creating a new mapping entry, it is always accompanied by a Write
@@ -426,9 +451,150 @@ struct PayloadIRResource
426451
StringRef getName() override { return "transform.payload_ir"; }
427452
};
428453

454+
/// Trait implementing the MemoryEffectOpInterface for single-operand
455+
/// single-result operations that "consume" their operand and produce a new
456+
/// result.
457+
template <typename OpTy>
458+
class FunctionalStyleTransformOpTrait
459+
: public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
460+
public:
461+
/// This op "consumes" the operand by reading and freeing it, "produces" the
462+
/// result by allocating and writing it and reads/writes the payload IR in the
463+
/// process.
464+
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
465+
effects.emplace_back(MemoryEffects::Read::get(),
466+
this->getOperation()->getOperand(0),
467+
TransformMappingResource::get());
468+
effects.emplace_back(MemoryEffects::Free::get(),
469+
this->getOperation()->getOperand(0),
470+
TransformMappingResource::get());
471+
effects.emplace_back(MemoryEffects::Allocate::get(),
472+
this->getOperation()->getResult(0),
473+
TransformMappingResource::get());
474+
effects.emplace_back(MemoryEffects::Write::get(),
475+
this->getOperation()->getResult(0),
476+
TransformMappingResource::get());
477+
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
478+
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
479+
}
480+
481+
/// Checks that the op matches the expectations of this trait.
482+
static LogicalResult verifyTrait(Operation *op) {
483+
static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
484+
"expected single-operand op");
485+
static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
486+
"expected single-result op");
487+
if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
488+
op->emitError()
489+
<< "FunctionalStyleTransformOpTrait should only be attached to ops "
490+
"that implement MemoryEffectOpInterface";
491+
}
492+
return success();
493+
}
494+
};
495+
429496
} // namespace transform
430497
} // namespace mlir
431498

432499
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
433500

501+
namespace mlir {
502+
namespace transform {
503+
namespace detail {
504+
/// Appends `result` to the vector assuming it corresponds to the success state
505+
/// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
506+
/// `LogicalResult`, does nothing.
507+
template <typename Ty>
508+
std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
509+
appendTransformResultToVector(Ty result,
510+
SmallVectorImpl<Operation *> &results) {
511+
return result;
512+
}
513+
template <typename Ty>
514+
std::enable_if_t<!std::is_same<Ty, LogicalResult>::value, LogicalResult>
515+
appendTransformResultToVector(Ty result,
516+
SmallVectorImpl<Operation *> &results) {
517+
static_assert(
518+
std::is_convertible<typename Ty::value_type, Operation *>::value,
519+
"expected transform function to return operations");
520+
if (failed(result))
521+
return failure();
522+
523+
results.push_back(*result);
524+
return success();
525+
}
526+
527+
/// Applies a one-to-one transform to each of the given targets. Puts the
528+
/// results of transforms, if any, in `results` in the same order. Fails if any
529+
/// of the application fails. Individual transforms must be callable with
530+
/// one of the following signatures:
531+
/// - FailureOr<convertible-to-Operation*>(OpTy)
532+
/// - LogicalResult(OpTy)
533+
/// where OpTy is either
534+
/// - Operation *, in which case the transform is always applied;
535+
/// - a concrete Op class, in which case a check is performed whether
536+
/// `targets` contains operations of the same class and a failure is reported
537+
/// if it does not.
538+
template <typename FnTy>
539+
LogicalResult applyTransformToEach(ArrayRef<Operation *> targets,
540+
SmallVectorImpl<Operation *> &results,
541+
FnTy transform) {
542+
using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
543+
static_assert(std::is_convertible<OpTy, Operation *>::value,
544+
"expected transform function to take an operation");
545+
using RetTy = typename llvm::function_traits<FnTy>::result_t;
546+
static_assert(std::is_convertible<RetTy, LogicalResult>::value,
547+
"expected transform function to return LogicalResult or "
548+
"FailureOr<convertible-to-Operation*>");
549+
for (Operation *target : targets) {
550+
auto specificOp = dyn_cast<OpTy>(target);
551+
if (!specificOp)
552+
return failure();
553+
554+
auto result = transform(specificOp);
555+
if (failed(appendTransformResultToVector(result, results)))
556+
return failure();
557+
}
558+
return success();
559+
}
560+
} // namespace detail
561+
} // namespace transform
562+
} // namespace mlir
563+
564+
template <typename OpTy>
565+
mlir::LogicalResult mlir::transform::TransformEachOpTrait<OpTy>::apply(
566+
TransformResults &transformResults, TransformState &state) {
567+
using TransformOpType = typename llvm::function_traits<
568+
decltype(&OpTy::applyToOne)>::template arg_t<0>;
569+
ArrayRef<Operation *> targets =
570+
state.getPayloadOps(this->getOperation()->getOperand(0));
571+
SmallVector<Operation *> results;
572+
if (failed(detail::applyTransformToEach(
573+
targets, results, [&](TransformOpType specificOp) {
574+
return static_cast<OpTy *>(this)->applyToOne(specificOp);
575+
})))
576+
return failure();
577+
if (OpTy::template hasTrait<OpTrait::OneResult>()) {
578+
transformResults.set(
579+
this->getOperation()->getResult(0).template cast<OpResult>(), results);
580+
}
581+
return success();
582+
}
583+
584+
template <typename OpTy>
585+
mlir::LogicalResult
586+
mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
587+
static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
588+
"expected single-operand op");
589+
static_assert(OpTy::template hasTrait<OpTrait::OneResult>() ||
590+
OpTy::template hasTrait<OpTrait::ZeroResults>(),
591+
"expected zero- or single-result op");
592+
if (!op->getName().getInterface<TransformOpInterface>()) {
593+
return op->emitError() << "TransformEachOpTrait should only be attached to "
594+
"ops that implement TransformOpInterface";
595+
}
596+
597+
return success();
598+
}
599+
434600
#endif // DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,13 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
4949
];
5050
}
5151

52+
def FunctionalStyleTransformOpTrait
53+
: NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
54+
let cppNamespace = "::mlir::transform";
55+
}
56+
57+
def TransformEachOpTrait : NativeOpTrait<"TransformEachOpTrait"> {
58+
let cppNamespace = "::mlir::transform";
59+
}
60+
5261
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD

0 commit comments

Comments
 (0)