Skip to content

Commit 2a603de

Browse files
ubfxftynse
authored andcommitted
[mlir][Python] Fix conversion of non-zero offset memrefs to np.arrays
Memref descriptors contain an `offset` field that denotes the start of the content of the memref relative to the `alignedPtr`. This offset is not considered when converting a memref descriptor to a np.array in the Python runtime library, essentially treating all memrefs as if they had an offset of zero. This patch introduces the necessary pointer arithmetic to find the actual beginning of the memref contents to the memref->numpy conversion functions. There is an ongoing discussion about whether the `offset` field is needed at all in the memref descriptor. Until that is decided, the Python runtime and CRunnerUtils should still correctly implement the offset handling. Related: https://reviews.llvm.org/D157008 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D158494
1 parent 06918a9 commit 2a603de

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

mlir/python/mlir/runtime/np_to_memref.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,21 @@ def get_unranked_memref_descriptor(nparray):
114114
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
115115
return d
116116

117+
def move_aligned_ptr_by_offset(aligned_ptr, offset):
118+
"""Moves the supplied ctypes pointer ahead by `offset` elements."""
119+
aligned_addr = ctypes.addressof(aligned_ptr.contents)
120+
elem_size = ctypes.sizeof(aligned_ptr.contents)
121+
shift = offset * elem_size
122+
content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
123+
return content_ptr
117124

118125
def unranked_memref_to_numpy(unranked_memref, np_dtype):
119126
"""Converts unranked memrefs to numpy arrays."""
120127
ctp = as_ctype(np_dtype)
121128
descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
122129
val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
123-
np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
130+
content_ptr = move_aligned_ptr_by_offset(val[0].aligned, val[0].offset)
131+
np_arr = np.ctypeslib.as_array(content_ptr, shape=val[0].shape)
124132
strided_arr = np.lib.stride_tricks.as_strided(
125133
np_arr,
126134
np.ctypeslib.as_array(val[0].shape),
@@ -131,8 +139,9 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
131139

132140
def ranked_memref_to_numpy(ranked_memref):
133141
"""Converts ranked memrefs to numpy arrays."""
142+
content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
134143
np_arr = np.ctypeslib.as_array(
135-
ranked_memref[0].aligned, shape=ranked_memref[0].shape
144+
content_ptr, shape=ranked_memref[0].shape
136145
)
137146
strided_arr = np.lib.stride_tricks.as_strided(
138147
np_arr,

mlir/test/python/execution_engine.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,87 @@ def callback(a):
245245
run(testRankedMemRefCallback)
246246

247247

248+
# Test callback with a ranked memref with non-zero offset.
249+
# CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
250+
def testRankedMemRefWithOffsetCallback():
251+
# Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
252+
@ctypes.CFUNCTYPE(
253+
None,
254+
ctypes.POINTER(
255+
make_nd_memref_descriptor(1, np.ctypeslib.as_ctypes_type(np.float32))
256+
),
257+
)
258+
def callback(a):
259+
arr = ranked_memref_to_numpy(a)
260+
log("Inside Callback: ")
261+
log(arr)
262+
263+
with Context():
264+
# The module takes a subview of the argument memref and calls the callback with it
265+
module = Module.parse(
266+
r"""
267+
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
268+
%base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
269+
%reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
270+
%cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
271+
call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
272+
return
273+
}
274+
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
275+
"""
276+
)
277+
execution_engine = ExecutionEngine(lowerToLLVM(module))
278+
execution_engine.register_runtime("some_callback_into_python", callback)
279+
inp_arr = np.array([0, 0, 0, 1, 2], np.float32)
280+
# CHECK: Inside Callback:
281+
# CHECK{LITERAL}: [1. 2.]
282+
execution_engine.invoke(
283+
"callback_memref",
284+
ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
285+
)
286+
287+
288+
run(testRankedMemRefWithOffsetCallback)
289+
290+
291+
# Test callback with an unranked memref with non-zero offset
292+
# CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
293+
def testUnrankedMemRefWithOffsetCallback():
294+
# Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
295+
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
296+
def callback(a):
297+
arr = unranked_memref_to_numpy(a, np.float32)
298+
log("Inside callback: ")
299+
log(arr)
300+
301+
with Context():
302+
# The module takes a subview of the argument memref, casts it to an unranked memref and
303+
# calls the callback with it.
304+
module = Module.parse(
305+
r"""
306+
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
307+
%base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
308+
%reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
309+
%cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
310+
call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
311+
return
312+
}
313+
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
314+
"""
315+
)
316+
execution_engine = ExecutionEngine(lowerToLLVM(module))
317+
execution_engine.register_runtime("some_callback_into_python", callback)
318+
inp_arr = np.array([1, 2, 3, 4, 5], np.float32)
319+
# CHECK: Inside callback:
320+
# CHECK{LITERAL}: [4. 5.]
321+
execution_engine.invoke(
322+
"callback_memref",
323+
ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
324+
)
325+
326+
run(testUnrankedMemRefWithOffsetCallback)
327+
328+
248329
# Test addition of two memrefs.
249330
# CHECK-LABEL: TEST: testMemrefAdd
250331
def testMemrefAdd():

0 commit comments

Comments
 (0)