Skip to content

Commit 5ef087b

Browse files
authored
Reapply "[MLIR][Python] add ctype python binding support for bf16" (#101271)
Reapply the PR which was reverted due to built-bots, and now the bots get updated. https://discourse.llvm.org/t/need-a-help-with-the-built-bots/79437 original PR: #92489, reverted in #93771
1 parent 8300eaa commit 5ef087b

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

mlir/python/mlir/runtime/np_to_memref.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
import numpy as np
88
import ctypes
99

10+
try:
11+
import ml_dtypes
12+
except ModuleNotFoundError:
13+
# The third-party ml_dtypes provides some optional low precision data-types for NumPy.
14+
ml_dtypes = None
15+
1016

1117
class C128(ctypes.Structure):
1218
"""A ctype representation for MLIR's Double Complex."""
@@ -26,6 +32,12 @@ class F16(ctypes.Structure):
2632
_fields_ = [("f16", ctypes.c_int16)]
2733

2834

35+
class BF16(ctypes.Structure):
36+
"""A ctype representation for MLIR's BFloat16."""
37+
38+
_fields_ = [("bf16", ctypes.c_int16)]
39+
40+
2941
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
3042
def as_ctype(dtp):
3143
"""Converts dtype to ctype."""
@@ -35,6 +47,8 @@ def as_ctype(dtp):
3547
return C64
3648
if dtp == np.dtype(np.float16):
3749
return F16
50+
if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
51+
return BF16
3852
return np.ctypeslib.as_ctypes_type(dtp)
3953

4054

@@ -46,6 +60,11 @@ def to_numpy(array):
4660
return array.view("complex64")
4761
if array.dtype == F16:
4862
return array.view("float16")
63+
assert not (
64+
array.dtype == BF16 and ml_dtypes is None
65+
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
66+
if array.dtype == BF16:
67+
return array.view("bfloat16")
4968
return array
5069

5170

mlir/python/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy>=1.19.5, <=1.26
22
pybind11>=2.9.0, <=2.10.3
3-
PyYAML>=5.3.1, <=6.0.1
3+
PyYAML>=5.3.1, <=6.0.1
4+
ml_dtypes # provides several NumPy dtype extensions, including the bf16

mlir/test/python/execution_engine.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mlir.passmanager import *
66
from mlir.execution_engine import *
77
from mlir.runtime import *
8+
from ml_dtypes import bfloat16
89

910

1011
# Log everything to stderr and flush so that we have a unified stream to match
@@ -521,6 +522,45 @@ def testComplexUnrankedMemrefAdd():
521522
run(testComplexUnrankedMemrefAdd)
522523

523524

525+
# Test bf16 memrefs
526+
# CHECK-LABEL: TEST: testBF16Memref
527+
def testBF16Memref():
528+
with Context():
529+
module = Module.parse(
530+
"""
531+
module {
532+
func.func @main(%arg0: memref<1xbf16>,
533+
%arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
534+
%0 = arith.constant 0 : index
535+
%1 = memref.load %arg0[%0] : memref<1xbf16>
536+
memref.store %1, %arg1[%0] : memref<1xbf16>
537+
return
538+
}
539+
} """
540+
)
541+
542+
arg1 = np.array([0.5]).astype(bfloat16)
543+
arg2 = np.array([0.0]).astype(bfloat16)
544+
545+
arg1_memref_ptr = ctypes.pointer(
546+
ctypes.pointer(get_ranked_memref_descriptor(arg1))
547+
)
548+
arg2_memref_ptr = ctypes.pointer(
549+
ctypes.pointer(get_ranked_memref_descriptor(arg2))
550+
)
551+
552+
execution_engine = ExecutionEngine(lowerToLLVM(module))
553+
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
554+
555+
# test to-numpy utility
556+
# CHECK: [0.5]
557+
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
558+
log(npout)
559+
560+
561+
run(testBF16Memref)
562+
563+
524564
# Test addition of two 2d_memref
525565
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
526566
def testDynamicMemrefAdd2D():

0 commit comments

Comments
 (0)