Skip to content

Commit dc297cb

Browse files
authored
[mlir][memref][spirv] Add conversion for memref.extract_aligned_pointer_as_index to SPIR-V (#86750)
Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.
1 parent aab79c4 commit dc297cb

File tree

2 files changed

+71
-6
lines changed

2 files changed

+71
-6
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
307307
}
308308
};
309309

310+
/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
311+
class ExtractAlignedPointerAsIndexOpPattern final
312+
: public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
313+
public:
314+
using OpConversionPattern::OpConversionPattern;
315+
316+
LogicalResult
317+
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
318+
OpAdaptor adaptor,
319+
ConversionPatternRewriter &rewriter) const override;
320+
};
310321
} // namespace
311322

312323
//===----------------------------------------------------------------------===//
@@ -921,17 +932,32 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
921932
return success();
922933
}
923934

935+
//===----------------------------------------------------------------------===//
936+
// ExtractAlignedPointerAsIndexOp
937+
//===----------------------------------------------------------------------===//
938+
939+
LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
940+
memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
941+
ConversionPatternRewriter &rewriter) const {
942+
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
943+
Type indexType = typeConverter.getIndexType();
944+
rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
945+
adaptor.getSource());
946+
return success();
947+
}
948+
924949
//===----------------------------------------------------------------------===//
925950
// Pattern population
926951
//===----------------------------------------------------------------------===//
927952

928953
namespace mlir {
929954
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
930955
RewritePatternSet &patterns) {
931-
patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
932-
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
933-
LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
934-
ReinterpretCastPattern, CastPattern>(typeConverter,
935-
patterns.getContext());
956+
patterns
957+
.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
958+
DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
959+
MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
960+
CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
961+
typeConverter, patterns.getContext());
936962
}
937963
} // namespace mlir

mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
1+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
2+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \
3+
// RUN: | FileCheck --check-prefix=CHECK64 %s
24

35
// Check that with proper compute and storage extensions, we don't need to
46
// perform special tricks.
@@ -420,6 +422,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
420422

421423
}
422424

425+
// -----
426+
427+
module attributes {
428+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Kernel, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
429+
} {
430+
// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
431+
func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
432+
%0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
433+
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
434+
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
435+
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
436+
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
437+
438+
// CHECK: return %[[R:.*]] : index
439+
return %0: index
440+
}
441+
}
442+
443+
// -----
444+
445+
module attributes {
446+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
447+
} {
448+
// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
449+
func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
450+
%0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
451+
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
452+
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
453+
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
454+
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
455+
456+
// CHECK: return %[[R:.*]] : index
457+
return %0: index
458+
}
459+
}
460+
461+
423462
// -----
424463

425464
// Check nontemporal attribute

0 commit comments

Comments
 (0)