Commit d6a2014
[mlir][Linalg]: Add memory space to linalg transform::PromoteOp
This patch allows supplying an optional memory space for the promoted buffer.

Differential Revision: https://reviews.llvm.org/D159074
1 parent c39b504 commit d6a2014
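
The new option threads through to `LinalgPromotionOptions`. A minimal caller-side sketch (the helper name, the `OpBuilder` argument, and the address-space value `3` are hypothetical illustrations; only `setMemorySpace` and `setAlignment` come from the patched API):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: request that promoted buffers be 16-byte aligned and
// placed in memory space 3 (an arbitrary example, e.g. a GPU address space).
static linalg::LinalgPromotionOptions makePromotionOptions(OpBuilder &b) {
  linalg::LinalgPromotionOptions options;
  options = options.setAlignment(16).setMemorySpace(b.getI64IntegerAttr(3));
  return options;
}
```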

File tree

5 files changed (+177, -44 lines)

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 42 additions & 41 deletions
@@ -165,9 +165,9 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
 //===----------------------------------------------------------------------===//
 
 def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
-    [FunctionalStyleTransformOpTrait,
+    [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
-     TransformOpInterface,
+     TransformOpInterface,
      TransformEachOpTrait,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
@@ -414,8 +414,8 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
       [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$iterator_interchange);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
-  let assemblyFormat = [{
-    $target
+  let assemblyFormat = [{
+    $target
     (`iterator_interchange` `=` $iterator_interchange^)? attr-dict
     `:` custom<SemiFunctionType>(type($target), type($transformed))
   }];
@@ -479,7 +479,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
     TransformOpInterface,
     ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Lower a tensor.unpack into empty + linalg.transpose + tensor.collapse_shape +
+    Lower a tensor.unpack into empty + linalg.transpose + tensor.collapse_shape +
     tensor.extract_slice.
 
     #### Return modes
@@ -497,7 +497,7 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
       Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op);
-  let assemblyFormat = [{
+  let assemblyFormat = [{
     $target attr-dict `:` functional-type(operands, results)
   }];
 
@@ -665,7 +665,7 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
   let description = [{
     Pack a LinalgOp by applying a data tiling transformation on the op and
     packing the operands according to the `packed_sizes` specification.
-
+
     Iterator dimensions are tiled in their canonical order in the op spec.
     Operands are packed according to the same canonical order of the op iterator
     dimensions.
@@ -700,7 +700,7 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
     // affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
     // M N m n
     // affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
-    %0 = linalg.generic_representing_some_higher_d_matmul
+    %0 = linalg.generic_representing_some_higher_d_matmul
       ins(%A, %B: tensor<?x?x2x4xf32>, tensor<?x?x4x3xf32>)
       outs( %C: tensor<?x?x2x3xf32>)
     ```
@@ -727,7 +727,7 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_packed_sizes);
   let results = (outs TransformHandleTypeInterface:$packed_op);
   let assemblyFormat = [{
-    $target
+    $target
     `packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
                                                 $static_packed_sizes,
                                                 type($packed_sizes))
@@ -756,27 +756,27 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
     Target a Linalg op and rewrite it into packed LinalgOp form by trying to
     infer whether a known suboperation is embedded
 
-    Different packing strategies are applied in order, when one applies
+    Different packing strategies are applied in order, when one applies
     successfully, the transform returns:
       1. Matmul packing: Try to infer a matmul operation embedded in the target op.
          Specifically, this looks for 2 parallel dimensions that participate in
         an outer-product and 1 reduction dimension.
         These dimensions are referred as (m, n, k) to match canonical matmul
         terminology.
-
+
         The packed sizes for (m, n, k) are specified by `matmul_packed_sizes`
         and the optional `matmul_padded_sizes_next_multiple_of`.
-        When an entry `matmul_packed_sizes[i]` is non-0, the corresponding
+        When an entry `matmul_packed_sizes[i]` is non-0, the corresponding
         dimension is packed by `matmul_packed_sizes[i]`.
         Otherwise, the dimension is merely padded to the next multiple of
         `matmul_padded_sizes_next_multiple_of[i]`.
 
        `matmul_padded_sizes_next_multiple_of` is optional and is expected to
         either be empty or of size `3`, matching the size of `matmul_packed_sizes`.
-        For each individual element of `matmul_packed_sizes` and
+        For each individual element of `matmul_packed_sizes` and
        `matmul_padded_sizes_next_multiple_of`, only one of them is allowed to
         be non-zero.
-
+
     The ordering of the packed dimensions (mm, nn, kk) is specified by the
     `matmul_inner_dims_order` attribute.
 
@@ -787,7 +787,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
         the most minor indexing dimensions of the linalg.generic. The most minor
         dimensions are themselves ordered according to `inner_dims_order`.
       4. An elementwise traversal of `matmul_packed_sizes` and
-        `matmul_padded_sizes_next_multiple_of` is performed and for each
+        `matmul_padded_sizes_next_multiple_of` is performed and for each
         dimension `d`, either pack to `matmul_packed_sizes[d]` or pad to the
         `matmul_padded_sizes_next_multiple_of[d]`.
       5. Packing/padding is performed by the amounts determined in step 4. and
@@ -815,7 +815,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
                  [DenseArrayCount<3>]>:$static_matmul_packed_sizes,
     ConfinedAttr<DefaultValuedAttr<DenseI64ArrayAttr, "{}">,
                  [Attr<
-                   Or<[DenseArrayCount<0>.predicate,
+                   Or<[DenseArrayCount<0>.predicate,
                        DenseArrayCount<3>.predicate]>,
                    "with 0 or 3 elements"
                  >]>
@@ -837,7 +837,7 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
     `matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
                                                        $static_matmul_packed_sizes,
                                                        type($matmul_packed_sizes))
-    (`matmul_padded_sizes_next_multiple_of` `=`
+    (`matmul_padded_sizes_next_multiple_of` `=`
        $matmul_padded_sizes_next_multiple_of^)?
     `matmul_inner_dims_order` `=` $matmul_inner_dims_order
     )
@@ -862,7 +862,7 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
      DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
+    Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
     update the `linalg.generic` op that consumes (resp. produces) the operation.
 
     This transform allows composing a simple `structured.pack` with additional
@@ -874,7 +874,7 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
     the specified `tensor.pack` or `tensor.unpack` op.
 
     If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
-    be created along with transposed versions of the `tensor.pack` and the
+    be created along with transposed versions of the `tensor.pack` and the
     consuming `linalg.generic`, which is expected to be the sole consumer.
 
     If the `target` of this op is a `tensor.unpack` then the whole pack / compute
@@ -894,7 +894,7 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
 
     This operation returns 3 handles, one to the transformed LinalgOp, one to
     the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
-    The last handle for `tensor.unpack` is empty if `target_pack_or_unpack_op`
+    The last handle for `tensor.unpack` is empty if `target_pack_or_unpack_op`
     was not itself a `tensor.unpack`.
   }];
 
@@ -971,7 +971,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   let builders = [
     // Builder for a transform::PadOp with automatic inference of padding
    // value. Warning: this will set the value 0 for the inferred elemental
-    // type without taking the op into account and thus only work for the
+    // type without taking the op into account and thus only work for the
    // add/mul ring at the moment.
    // TODO: support other operations (e.g. min, max etc).
    OpBuilder<(ins "Value":$target,
@@ -1048,7 +1048,7 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
     Hoist the tensor.pad target operation by at most the given number of loops.
     Optionally apply the transpose attribute to the inner dimensions.
 
-    TODO: In the future, we should consider rewriting as a tensor.pack after
+    TODO: In the future, we should consider rewriting as a tensor.pack after
     hoisting since this abstraction is now available.
     TODO: Maybe also return the linalg.generic transpose created at some point.
 
@@ -1060,7 +1060,7 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
     If all the operations referred to by the `target` handle padproperly, the
     transform succeeds. Otherwise the transform silently fails.
 
-    The return handle points to only the subset of successfully hoisted
+    The return handle points to only the subset of successfully hoisted
     tensor.pad operations, which can be empty.
   }];
 
@@ -1073,9 +1073,9 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat = [{
-    $target
-    `by` $num_loops `loops`
-    (`,` `transpose` `by` $transpose^)?
+    $target
+    `by` $num_loops `loops`
+    (`,` `transpose` `by` $transpose^)?
     attr-dict
     `:` functional-type(operands, results)
   }];
@@ -1122,6 +1122,7 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
     DefaultValuedAttr<BoolArrayAttr, "{}">:$use_full_tile_buffers,
     UnitAttr:$use_full_tiles_by_default,
     UnitAttr:$use_alloca,
+    OptionalAttr<AnyAttr>:$memory_space,
     OptionalAttr<DeviceMappingArrayAttr>:$mapping,
     OptionalAttr<I64Attr>:$alignment);
   let results = (outs TransformHandleTypeInterface:$transformed);
@@ -1202,7 +1203,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$result);
 
-  let assemblyFormat =
+  let assemblyFormat =
     "$target attr-dict `:`"
     "custom<SemiFunctionType>(type($target), type($result))";
 
@@ -1248,9 +1249,9 @@ def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface
 
 def RewriteInDestinationPassingStyleOp : Op<
     Transform_Dialect, "structured.rewrite_in_destination_passing_style",
-    [FunctionalStyleTransformOpTrait,
+    [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
-     TransformOpInterface,
+     TransformOpInterface,
      TransformEachOpTrait,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
@@ -1260,7 +1261,7 @@ def RewriteInDestinationPassingStyleOp : Op<
       - tensor.pad
      - tensor.generate
      - tensor.from_elements
-    This dichotomy hints at a future interface, for now the implementation just
+    This dichotomy hints at a future interface, for now the implementation just
     switches between different implementation.
 
     #### Return modes
@@ -1271,7 +1272,7 @@ def RewriteInDestinationPassingStyleOp : Op<
     The return handle points to a subset of successfully produced operations:
       - `tensor.pad` case, the returned handle points to the tensor.insert_slice.
      - `tensor.generate` case, the returned handle points to the linalg.generic.
-      - `tensor.from_elements` case, the returned handle points to the last
+      - `tensor.from_elements` case, the returned handle points to the last
       `tensor.insert`.
   }];
 
@@ -1483,7 +1484,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
                      TransformHandleTypeInterface:$split_linalg_op,
                      TransformHandleTypeInterface:$combining_linalg_op);
 
-  let assemblyFormat =
+  let assemblyFormat =
     "$target attr-dict `:`"
     "functional-type(operands, results)";
 
@@ -1990,7 +1991,7 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
     DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
   let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
                       Variadic<TransformHandleTypeInterface>:$loops);
-
+
   let builders = [
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
@@ -2057,7 +2058,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
-  let assemblyFormat =
+  let assemblyFormat =
     "$target attr-dict `:`"
     "functional-type(operands, results)";
 
@@ -2279,16 +2280,16 @@ def HoistRedundantTensorSubsetsOp :
      TransformOpInterface,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Hoists supported tensor subset extract/insert operation pairs out of
+    Hoists supported tensor subset extract/insert operation pairs out of
     immediately enclosing loop iteratively, if the following conditions
     are true:
       1. The 2 ops access the same tensor subset.
      2. All operands are invariant under the enclosing loop.
-
+
     The supported subset extract/insert operation pairs currently comprise:
       - tensor.extract_slice / tensor.insert_slice
      - vector.transfer_read / vector.transfer_write on tensors
-
+
     Only scf.for loops are currently supported.
 
     When applied to:
@@ -2304,8 +2305,8 @@ def HoistRedundantTensorSubsetsOp :
   let results = (outs);
 
   let assemblyFormat = [{
-    $target
-    attr-dict
+    $target
+    attr-dict
     `:` functional-type(operands, results)
   }];
 
@@ -2328,15 +2329,15 @@ def InsertSliceToCopyOp :
    TransformEachOpTrait, TransformOpInterface]> {
   let description = [{
     Targeted rewrite of an tensor.insert_slice to linalg.copy.
-    This is useful to materialize copies explicitly before bufferization and
+    This is useful to materialize copies explicitly before bufferization and
     transform them, avoiding the need to rediscover them after bufferization.
 
     If the insert_slice source is already a linalg.copy, only return the source
     op (i.e. do not create an additional linalg.copy op).
 
     #### Return modes:
 
-    The operation always succeeds and returns a handle to the relevant
+    The operation always succeeds and returns a handle to the relevant
     linalg.copy op.
  }];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
@@ -362,6 +362,13 @@ struct LinalgPromotionOptions {
     alignment = align;
     return *this;
   }
+  /// Memory space of promoted buffer. If `std::nullopt` do not specify memory
+  /// space.
+  std::optional<Attribute> memorySpace;
+  LinalgPromotionOptions &setMemorySpace(Attribute memorySpc) {
+    memorySpace = memorySpc;
+    return *this;
+  }
   /// Use alloca with the default allocation scheme.
   bool useAlloca = false;
   LinalgPromotionOptions &setUseAlloca(bool use) {
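
The setter above mirrors the other `LinalgPromotionOptions` builders. An illustrative sketch of the default-versus-set semantics documented in the new comment (the function name, the context pointer, and the space value `3` are assumptions):

```cpp
#include <cassert>

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

void memorySpaceSemantics(MLIRContext *ctx) {
  linalg::LinalgPromotionOptions options;
  // Default: no memory space is specified on the promoted buffer's type.
  assert(!options.memorySpace.has_value());
  // Once set, promotion tags the promoted buffer's memref type with it.
  options = options.setMemorySpace(
      IntegerAttr::get(IntegerType::get(ctx, 64), 3)); // hypothetical space 3
  assert(options.memorySpace.has_value());
}
```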

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 2 additions & 0 deletions
@@ -1883,6 +1883,8 @@ transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
       llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
   if (getAlignment().has_value())
     promotionOptions = promotionOptions.setAlignment(*getAlignment());
+  if (getMemorySpace().has_value())
+    promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
 
   if (getMapping().has_value()) {
     // The mapping should only contain an element

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 18 additions & 3 deletions
@@ -54,10 +54,16 @@ static Value allocBuffer(ImplicitLocOpBuilder &b,
   if (alignment.has_value())
     alignmentAttr = b.getI64IntegerAttr(alignment.value());
 
+  Attribute memorySpaceAttr;
+  if (options.memorySpace.has_value())
+    memorySpaceAttr = *options.memorySpace;
+
   // Static buffer.
   if (std::optional<int64_t> cst = getConstantIntValue(allocSize)) {
     auto staticBufferType =
         MemRefType::get(width * cst.value(), b.getIntegerType(8));
+    staticBufferType =
+        MemRefType::Builder(staticBufferType).setMemorySpace(memorySpaceAttr);
     if (options.useAlloca) {
       return b.create<memref::AllocaOp>(staticBufferType, ValueRange{},
                                         alignmentAttr);
@@ -69,6 +75,8 @@ static Value allocBuffer(ImplicitLocOpBuilder &b,
   // Fallback dynamic buffer.
   auto dynamicBufferType =
       MemRefType::get(ShapedType::kDynamic, b.getIntegerType(8));
+  dynamicBufferType =
+      MemRefType::Builder(dynamicBufferType).setMemorySpace(memorySpaceAttr);
   Value mul = b.createOrFold<arith::MulIOp>(
       b.create<arith::ConstantIndexOp>(width), allocSize);
   if (options.useAlloca)
@@ -89,16 +97,23 @@ static std::optional<Value> defaultAllocBufferCallBack(
   auto zero = b.create<arith::ConstantIndexOp>(0);
   auto one = b.create<arith::ConstantIndexOp>(1);
 
+  Attribute memorySpaceAttr;
+  if (options.memorySpace.has_value())
+    memorySpaceAttr = *options.memorySpace;
+
   Value allocSize = one;
   for (const auto &size : llvm::enumerate(boundingSubViewSize))
     allocSize = b.createOrFold<arith::MulIOp>(allocSize, size.value());
   Value buffer = allocBuffer(b, options, viewType.getElementType(), allocSize,
                              layout, alignment);
   SmallVector<int64_t, 4> dynSizes(boundingSubViewSize.size(),
                                    ShapedType::kDynamic);
-  Value view = b.createOrFold<memref::ViewOp>(
-      MemRefType::get(dynSizes, viewType.getElementType()), buffer, zero,
-      boundingSubViewSize);
+
+  auto viewMemRefType = MemRefType::get(dynSizes, viewType.getElementType());
+  viewMemRefType =
+      MemRefType::Builder(viewMemRefType).setMemorySpace(memorySpaceAttr);
+  Value view = b.createOrFold<memref::ViewOp>(viewMemRefType, buffer, zero,
+                                              boundingSubViewSize);
   return view;
 }
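
Throughout the patch, existing buffer and view types are rebuilt with `MemRefType::Builder`, which copies a type and overrides selected fields. A standalone sketch of that mechanism (the function name and the space value `3` are made up for illustration):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Rebuild an i8 buffer type with an explicit memory space, mirroring what
// allocBuffer does above for the promoted buffer.
static MemRefType withMemorySpace(MLIRContext *ctx) {
  auto bufferType = MemRefType::get({512}, IntegerType::get(ctx, 8)); // memref<512xi8>
  auto space = IntegerAttr::get(IntegerType::get(ctx, 64), 3);        // hypothetical
  return MemRefType::Builder(bufferType).setMemorySpace(space);       // memref<512xi8, 3>
}
```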
