Skip to content

Commit e6821dd

Browse files
authored
Revert "[MLIR][Python] add ctype python binding support for bf16" (#93771)
Reverts #92489 This broke the bots.
1 parent 49ef21d commit e6821dd

File tree

3 files changed

+1
-61
lines changed

3 files changed

+1
-61
lines changed

mlir/python/mlir/runtime/np_to_memref.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@
77
import numpy as np
88
import ctypes
99

10-
try:
11-
import ml_dtypes
12-
except ModuleNotFoundError:
13-
# The third-party ml_dtypes provides some optional low precision data-types for NumPy.
14-
ml_dtypes = None
15-
1610

1711
class C128(ctypes.Structure):
1812
"""A ctype representation for MLIR's Double Complex."""
@@ -32,12 +26,6 @@ class F16(ctypes.Structure):
3226
_fields_ = [("f16", ctypes.c_int16)]
3327

3428

35-
class BF16(ctypes.Structure):
36-
"""A ctype representation for MLIR's BFloat16."""
37-
38-
_fields_ = [("bf16", ctypes.c_int16)]
39-
40-
4129
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
4230
def as_ctype(dtp):
4331
"""Converts dtype to ctype."""
@@ -47,8 +35,6 @@ def as_ctype(dtp):
4735
return C64
4836
if dtp == np.dtype(np.float16):
4937
return F16
50-
if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
51-
return BF16
5238
return np.ctypeslib.as_ctypes_type(dtp)
5339

5440

@@ -60,11 +46,6 @@ def to_numpy(array):
6046
return array.view("complex64")
6147
if array.dtype == F16:
6248
return array.view("float16")
63-
assert not (
64-
array.dtype == BF16 and ml_dtypes is None
65-
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
66-
if array.dtype == BF16:
67-
return array.view("bfloat16")
6849
return array
6950

7051

mlir/python/requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
numpy>=1.19.5, <=1.26
22
pybind11>=2.9.0, <=2.10.3
3-
PyYAML>=5.3.1, <=6.0.1
4-
ml_dtypes # provides several NumPy dtype extensions, including the bf16
3+
PyYAML>=5.3.1, <=6.0.1

mlir/test/python/execution_engine.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from mlir.passmanager import *
66
from mlir.execution_engine import *
77
from mlir.runtime import *
8-
from ml_dtypes import bfloat16
98

109

1110
# Log everything to stderr and flush so that we have a unified stream to match
@@ -522,45 +521,6 @@ def testComplexUnrankedMemrefAdd():
522521
run(testComplexUnrankedMemrefAdd)
523522

524523

525-
# Test bf16 memrefs
526-
# CHECK-LABEL: TEST: testBF16Memref
527-
def testBF16Memref():
528-
with Context():
529-
module = Module.parse(
530-
"""
531-
module {
532-
func.func @main(%arg0: memref<1xbf16>,
533-
%arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
534-
%0 = arith.constant 0 : index
535-
%1 = memref.load %arg0[%0] : memref<1xbf16>
536-
memref.store %1, %arg1[%0] : memref<1xbf16>
537-
return
538-
}
539-
} """
540-
)
541-
542-
arg1 = np.array([0.5]).astype(bfloat16)
543-
arg2 = np.array([0.0]).astype(bfloat16)
544-
545-
arg1_memref_ptr = ctypes.pointer(
546-
ctypes.pointer(get_ranked_memref_descriptor(arg1))
547-
)
548-
arg2_memref_ptr = ctypes.pointer(
549-
ctypes.pointer(get_ranked_memref_descriptor(arg2))
550-
)
551-
552-
execution_engine = ExecutionEngine(lowerToLLVM(module))
553-
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
554-
555-
# test to-numpy utility
556-
# CHECK: [0.5]
557-
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
558-
log(npout)
559-
560-
561-
run(testBF16Memref)
562-
563-
564524
# Test addition of two 2d_memref
565525
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
566526
def testDynamicMemrefAdd2D():

0 commit comments

Comments
 (0)