Skip to content

Commit ca371a8

Browse files
committed
[mlir][memref][spirv] Add conversion for memref.extract_aligned_pointer_as_index to SPIR-V
Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.
1 parent da6cc4a commit ca371a8

File tree

2 files changed

+70
-6
lines changed

2 files changed

+70
-6
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
308308
}
309309
};
310310

311+
/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
312+
class ExtractAlignedPointerAsIndexOpPattern
313+
: public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
314+
public:
315+
using OpConversionPattern::OpConversionPattern;
316+
317+
LogicalResult
318+
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
319+
OpAdaptor adaptor,
320+
ConversionPatternRewriter &rewriter) const override;
321+
};
311322
} // namespace
312323

313324
//===----------------------------------------------------------------------===//
@@ -922,17 +933,32 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
922933
return success();
923934
}
924935

936+
//===----------------------------------------------------------------------===//
937+
// ExtractAlignedPointerAsIndexOp
938+
//===----------------------------------------------------------------------===//
939+
940+
LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
941+
memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
942+
ConversionPatternRewriter &rewriter) const {
943+
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
944+
Type indexType = typeConverter.getIndexType();
945+
rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
946+
adaptor.getSource());
947+
return success();
948+
}
949+
925950
//===----------------------------------------------------------------------===//
926951
// Pattern population
927952
//===----------------------------------------------------------------------===//
928953

929954
namespace mlir {
930955
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
931956
RewritePatternSet &patterns) {
932-
patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
933-
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
934-
LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
935-
ReinterpretCastPattern, CastPattern>(typeConverter,
936-
patterns.getContext());
957+
patterns
958+
.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
959+
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
960+
MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
961+
CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
962+
typeConverter, patterns.getContext());
937963
}
938964
} // namespace mlir

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
1+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
2+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
23

34
// Check that with proper compute and storage extensions, we don't need to
45
// perform special tricks.
@@ -414,6 +415,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
414415

415416
}
416417

418+
// -----
419+
420+
module attributes {
421+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
422+
} {
423+
// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
424+
func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
425+
%0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
426+
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
427+
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
428+
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
429+
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
430+
431+
// CHECK: return %[[R:.*]] : index
432+
return %0: index
433+
}
434+
}
435+
436+
// -----
437+
438+
module attributes {
439+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader, Int64, Addresses], []>, #spirv.resource_limits<>>
440+
} {
441+
// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
442+
func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
443+
%0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
444+
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
445+
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
446+
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
447+
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
448+
449+
// CHECK: return %[[R:.*]] : index
450+
return %0: index
451+
}
452+
}
453+
454+
417455
// -----
418456

419457
// Check nontemporal attribute

0 commit comments

Comments
 (0)