Skip to content

Commit f65fc29

Browse files
committed
[mlir][MemRef] Add more ops to narrow type support, strided metadata
- Add support fef memory_space_cast to strided metadata expansion and narrow type emulation - Add support for expand_shape to narrow type emulation (like collapse_shape, it's a noop after linearization) and to expand-strided-metadata (mirroring the collapse_shape pattern) - Add support for memref.dealloc to narrow type emulation (it is a trivial rewrite) and for memref.copy (which is unsupported when it is used for a layout change but a trivial rewrite otherwise)
1 parent 10d7805 commit f65fc29

File tree

4 files changed

+278
-3
lines changed

4 files changed

+278
-3
lines changed

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,46 @@ struct ConvertMemRefAssumeAlignment final
235235
}
236236
};
237237

238+
//===----------------------------------------------------------------------===//
239+
// ConvertMemRefCopy
240+
//===----------------------------------------------------------------------===//
241+
242+
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
243+
using OpConversionPattern::OpConversionPattern;
244+
245+
LogicalResult
246+
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
247+
ConversionPatternRewriter &rewriter) const override {
248+
auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
249+
auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
250+
if (maybeRankedSource && maybeRankedDest &&
251+
maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
252+
return rewriter.notifyMatchFailure(
253+
op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
254+
"and {1}) is currently unimplemented",
255+
maybeRankedSource.getLayout(),
256+
maybeRankedDest.getLayout()));
257+
rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
258+
adaptor.getTarget());
259+
return success();
260+
}
261+
};
262+
263+
//===----------------------------------------------------------------------===//
264+
// ConvertMemRefDealloc
265+
//===----------------------------------------------------------------------===//
266+
267+
struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
268+
using OpConversionPattern::OpConversionPattern;
269+
270+
LogicalResult
271+
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
272+
ConversionPatternRewriter &rewriter) const override {
273+
rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
274+
return success();
275+
}
276+
};
277+
238278
//===----------------------------------------------------------------------===//
239279
// ConvertMemRefLoad
240280
//===----------------------------------------------------------------------===//
@@ -301,6 +341,30 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
301341
}
302342
};
303343

344+
//===----------------------------------------------------------------------===//
345+
// ConvertMemRefMemorySpaceCast
346+
//===----------------------------------------------------------------------===//
347+
348+
struct ConvertMemRefMemorySpaceCast final
349+
: OpConversionPattern<memref::MemorySpaceCastOp> {
350+
using OpConversionPattern::OpConversionPattern;
351+
352+
LogicalResult
353+
matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
354+
ConversionPatternRewriter &rewriter) const override {
355+
Type newTy = getTypeConverter()->convertType(op.getDest().getType());
356+
if (!newTy) {
357+
return rewriter.notifyMatchFailure(
358+
op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
359+
op.getDest().getType()));
360+
}
361+
362+
rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
363+
adaptor.getSource());
364+
return success();
365+
}
366+
};
367+
304368
//===----------------------------------------------------------------------===//
305369
// ConvertMemRefReinterpretCast
306370
//===----------------------------------------------------------------------===//
@@ -492,6 +556,28 @@ struct ConvertMemRefCollapseShape final
492556
}
493557
};
494558

559+
/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
560+
/// that we flatten memrefs to a single dimension as part of the emulation and
561+
/// the expansion would just have been undone.
562+
struct ConvertMemRefExpandShape final
563+
: OpConversionPattern<memref::ExpandShapeOp> {
564+
using OpConversionPattern::OpConversionPattern;
565+
566+
LogicalResult
567+
matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
568+
ConversionPatternRewriter &rewriter) const override {
569+
Value srcVal = adaptor.getSrc();
570+
auto newTy = dyn_cast<MemRefType>(srcVal.getType());
571+
if (!newTy)
572+
return failure();
573+
574+
if (newTy.getRank() != 1)
575+
return failure();
576+
577+
rewriter.replaceOp(expandShapeOp, srcVal);
578+
return success();
579+
}
580+
};
495581
} // end anonymous namespace
496582

497583
//===----------------------------------------------------------------------===//
@@ -504,9 +590,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
504590

505591
// Populate `memref.*` conversion patterns.
506592
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
507-
ConvertMemRefAllocation<memref::AllocaOp>,
508-
ConvertMemRefCollapseShape, ConvertMemRefLoad,
509-
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
593+
ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
594+
ConvertMemRefDealloc, ConvertMemRefCollapseShape,
595+
ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
596+
ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
510597
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
511598
typeConverter, patterns.getContext());
512599
memref::populateResolveExtractStridedMetadataPatterns(patterns);

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,41 @@ struct ExtractStridedMetadataOpCollapseShapeFolder
726726
}
727727
};
728728

729+
/// Pattern to replace `extract_strided_metadata(expand_shape)`
730+
/// with the results of computing the sizes and strides on the expanded shape
731+
/// and dividing up dimensions into static and dynamic parts as needed.
732+
struct ExtractStridedMetadataOpExpandShapeFolder
733+
: OpRewritePattern<memref::ExtractStridedMetadataOp> {
734+
using OpRewritePattern::OpRewritePattern;
735+
736+
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
737+
PatternRewriter &rewriter) const override {
738+
auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
739+
if (!expandShapeOp)
740+
return failure();
741+
742+
FailureOr<StridedMetadata> stridedMetadata =
743+
resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
744+
rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
745+
if (failed(stridedMetadata)) {
746+
return rewriter.notifyMatchFailure(
747+
op, "failed to resolve metadata in terms of source expand_shape op");
748+
}
749+
750+
Location loc = expandShapeOp.getLoc();
751+
SmallVector<Value> results;
752+
results.push_back(stridedMetadata->basePtr);
753+
results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
754+
stridedMetadata->offset));
755+
results.append(
756+
getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
757+
results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
758+
stridedMetadata->strides));
759+
rewriter.replaceOp(op, results);
760+
return success();
761+
}
762+
};
763+
729764
/// Replace `base, offset, sizes, strides =
730765
/// extract_strided_metadata(allocLikeOp)`
731766
///
@@ -1060,6 +1095,49 @@ class ExtractStridedMetadataOpCastFolder
10601095
}
10611096
};
10621097

1098+
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
1099+
/// memory_space_cast(src) to dstTy)`
1100+
/// with
1101+
/// ```
1102+
/// oldBase, offset, sizes, strides = extract_strided_metadata(src)
1103+
/// destBaseTy = type(oldBase) with memory space from destTy
1104+
/// base = memory_space_cast(oldBase) to destBaseTy
1105+
/// ```
1106+
///
1107+
/// In other words, propagate metadata extraction accross memory space casts.
1108+
class ExtractStridedMetadataOpMemorySpaceCastFolder
1109+
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
1110+
using OpRewritePattern::OpRewritePattern;
1111+
1112+
LogicalResult
1113+
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1114+
PatternRewriter &rewriter) const override {
1115+
Location loc = extractStridedMetadataOp.getLoc();
1116+
Value source = extractStridedMetadataOp.getSource();
1117+
auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
1118+
if (!memSpaceCastOp)
1119+
return failure();
1120+
auto newExtractStridedMetadata =
1121+
rewriter.create<memref::ExtractStridedMetadataOp>(
1122+
loc, memSpaceCastOp.getSource());
1123+
SmallVector<Value> results(newExtractStridedMetadata.getResults());
1124+
if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
1125+
auto baseBuffer = results[0];
1126+
auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
1127+
MemRefType::Builder newTypeBuilder(baseBufferType);
1128+
newTypeBuilder.setMemorySpace(
1129+
memSpaceCastOp.getResult().getType().getMemorySpace());
1130+
results[0] = rewriter.create<memref::MemorySpaceCastOp>(
1131+
loc, Type{newTypeBuilder}, baseBuffer);
1132+
} else {
1133+
// Don't create spurious casts for values that are going away.
1134+
results[0] = nullptr;
1135+
}
1136+
rewriter.replaceOp(extractStridedMetadataOp, results);
1137+
return success();
1138+
}
1139+
};
1140+
10631141
/// Replace `base, offset =
10641142
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
10651143
/// With
@@ -1099,11 +1177,13 @@ void memref::populateExpandStridedMetadataPatterns(
10991177
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
11001178
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
11011179
ExtractStridedMetadataOpCollapseShapeFolder,
1180+
ExtractStridedMetadataOpExpandShapeFolder,
11021181
ExtractStridedMetadataOpGetGlobalFolder,
11031182
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
11041183
ExtractStridedMetadataOpReinterpretCastFolder,
11051184
ExtractStridedMetadataOpSubviewFolder,
11061185
ExtractStridedMetadataOpCastFolder,
1186+
ExtractStridedMetadataOpMemorySpaceCastFolder,
11071187
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
11081188
patterns.getContext());
11091189
}
@@ -1113,11 +1193,13 @@ void memref::populateResolveExtractStridedMetadataPatterns(
11131193
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
11141194
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
11151195
ExtractStridedMetadataOpCollapseShapeFolder,
1196+
ExtractStridedMetadataOpExpandShapeFolder,
11161197
ExtractStridedMetadataOpGetGlobalFolder,
11171198
ExtractStridedMetadataOpSubviewFolder,
11181199
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
11191200
ExtractStridedMetadataOpReinterpretCastFolder,
11201201
ExtractStridedMetadataOpCastFolder,
1202+
ExtractStridedMetadataOpMemorySpaceCastFolder,
11211203
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
11221204
patterns.getContext());
11231205
}

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ func.func @memref_i8() -> i8 {
66
%c3 = arith.constant 3 : index
77
%m = memref.alloc() : memref<4xi8, 1>
88
%v = memref.load %m[%c3] : memref<4xi8, 1>
9+
memref.dealloc %m : memref<4xi8, 1>
910
return %v : i8
1011
}
1112
// CHECK-LABEL: func @memref_i8()
1213
// CHECK: %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
1314
// CHECK-NEXT: %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
15+
// CHECK-NEXT: memref.dealloc %[[M]]
1416
// CHECK-NEXT: return %[[V]]
1517

1618
// CHECK32-LABEL: func @memref_i8()
@@ -21,6 +23,7 @@ func.func @memref_i8() -> i8 {
2123
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
2224
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
2325
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
26+
// CHECK32-NEXT: memref.dealloc %[[M]]
2427
// CHECK32-NEXT: return %[[TRUNC]]
2528

2629
// -----
@@ -485,3 +488,68 @@ func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
485488
// CHECK32-NOT: memref.collapse_shape
486489
// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
487490

491+
// -----
492+
493+
func.func @memref_expand_shape_i4(%idx0 : index, %idx1 : index, %idx2 : index) -> i4 {
494+
%arr = memref.alloc() : memref<256x128xi4>
495+
%expand = memref.expand_shape %arr[[0, 1], [2]] output_shape [32, 8, 128] : memref<256x128xi4> into memref<32x8x128xi4>
496+
%1 = memref.load %expand[%idx0, %idx1, %idx2] : memref<32x8x128xi4>
497+
return %1 : i4
498+
}
499+
500+
// CHECK-LABEL: func.func @memref_expand_shape_i4(
501+
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
502+
// CHECK-NOT: memref.expand_shape
503+
// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
504+
505+
// CHECK32-LABEL: func.func @memref_expand_shape_i4(
506+
// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
507+
// CHECK32-NOT: memref.expand_shape
508+
// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
509+
510+
// -----
511+
512+
func.func @memref_memory_space_cast_i4(%arg0: memref<32x128xi4, 1>) -> memref<32x128xi4> {
513+
%cast = memref.memory_space_cast %arg0 : memref<32x128xi4, 1> to memref<32x128xi4>
514+
return %cast : memref<32x128xi4>
515+
}
516+
517+
// CHECK-LABEL: func.func @memref_memory_space_cast_i4(
518+
// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>
519+
// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<2048xi8, 1> to memref<2048xi8>
520+
// CHECK: return %[[CAST]]
521+
522+
// CHECK32-LABEL: func.func @memref_memory_space_cast_i4(
523+
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>
524+
// CHECK32: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<512xi32, 1> to memref<512xi32>
525+
// CHECK32: return %[[CAST]]
526+
527+
// -----
528+
529+
func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) {
530+
memref.copy %arg0, %arg1 : memref<32x128xi4, 1> to memref<32x128xi4>
531+
return
532+
}
533+
534+
// CHECK-LABEL: func.func @memref_copy_i4(
535+
// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>, %[[ARG1:.*]]: memref<2048xi8>
536+
// CHECK: memref.copy %[[ARG0]], %[[ARG1]]
537+
// CHECK: return
538+
539+
// CHECK32-LABEL: func.func @memref_copy_i4(
540+
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
541+
// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
542+
// CHECK32: return
543+
544+
// -----
545+
546+
!colMajor = memref<8x8xi4, strided<[1, 8]>>
547+
func.func @copy_distinct_layouts(%idx : index) -> i4 {
548+
%c0 = arith.constant 0 : index
549+
%arr = memref.alloc() : memref<8x8xi4>
550+
%arr2 = memref.alloc() : !colMajor
551+
// expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
552+
memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
553+
%ld = memref.load %arr2[%c0, %c0] : !colMajor
554+
return %ld : i4
555+
}

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,3 +1553,41 @@ func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
15531553
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
15541554
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
15551555
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index
1556+
1557+
// -----
1558+
1559+
func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>)
1560+
-> (memref<f32, 1>, index, index, index) {
1561+
1562+
%memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
1563+
1564+
%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
1565+
memref<20xf32, 1> -> memref<f32, 1>, index, index, index
1566+
1567+
return %base_buffer, %offset, %size, %stride :
1568+
memref<f32, 1>, index, index, index
1569+
}
1570+
1571+
// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast
1572+
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
1573+
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
1574+
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
1575+
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
1576+
// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[BASE]]
1577+
// CHECK: return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32, 1>, index, index, index
1578+
1579+
// -----
1580+
1581+
func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>)
1582+
-> (index, index, index) {
1583+
1584+
%memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
1585+
1586+
%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
1587+
memref<20xf32, 1> -> memref<f32, 1>, index, index, index
1588+
1589+
return %offset, %size, %stride : index, index, index
1590+
}
1591+
1592+
// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
1593+
// CHECK-NOT: memref.memory_space_cast

0 commit comments

Comments
 (0)