Skip to content

Reapply "[MLIR][Python] add ctype python binding support for bf16" #101271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 31, 2024

Conversation

xurui1995
Copy link
Contributor

Reapply the PR which was reverted due to built-bots, and now the bots get updated.
https://discourse.llvm.org/t/need-a-help-with-the-built-bots/79437
original PR: #92489, reverted in #93771

@llvmbot
Copy link
Member

llvmbot commented Jul 31, 2024

@llvm/pr-subscribers-mlir

Author: Bimo (xurui1995)

Changes

Reapply the PR which was reverted due to built-bots, and now the bots get updated.
https://discourse.llvm.org/t/need-a-help-with-the-built-bots/79437
original PR: #92489, reverted in #93771


Full diff: https://github.com/llvm/llvm-project/pull/101271.diff

3 Files Affected:

  • (modified) mlir/python/mlir/runtime/np_to_memref.py (+19)
  • (modified) mlir/python/requirements.txt (+2-1)
  • (modified) mlir/test/python/execution_engine.py (+40)
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index f6b706f9bc8ae..882b2751921bf 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -7,6 +7,12 @@
 import numpy as np
 import ctypes
 
+try:
+    import ml_dtypes
+except ModuleNotFoundError:
+    # The third-party ml_dtypes provides some optional low precision data-types for NumPy.
+    ml_dtypes = None
+
 
 class C128(ctypes.Structure):
     """A ctype representation for MLIR's Double Complex."""
@@ -26,6 +32,12 @@ class F16(ctypes.Structure):
     _fields_ = [("f16", ctypes.c_int16)]
 
 
+class BF16(ctypes.Structure):
+    """A ctype representation for MLIR's BFloat16."""
+
+    _fields_ = [("bf16", ctypes.c_int16)]
+
+
 # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
 def as_ctype(dtp):
     """Converts dtype to ctype."""
@@ -35,6 +47,8 @@ def as_ctype(dtp):
         return C64
     if dtp == np.dtype(np.float16):
         return F16
+    if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
+        return BF16
     return np.ctypeslib.as_ctypes_type(dtp)
 
 
@@ -46,6 +60,11 @@ def to_numpy(array):
         return array.view("complex64")
     if array.dtype == F16:
         return array.view("float16")
+    assert not (
+        array.dtype == BF16 and ml_dtypes is None
+    ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
+    if array.dtype == BF16:
+        return array.view("bfloat16")
     return array
 
 
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index acd6dbb25edaf..6ec63e43adf89 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,3 +1,4 @@
 numpy>=1.19.5, <=1.26
 pybind11>=2.9.0, <=2.10.3
-PyYAML>=5.3.1, <=6.0.1
\ No newline at end of file
+PyYAML>=5.3.1, <=6.0.1
+ml_dtypes   # provides several NumPy dtype extensions, including the bf16
\ No newline at end of file
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e8b47007a8907..8125bf3fb8fc9 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -5,6 +5,7 @@
 from mlir.passmanager import *
 from mlir.execution_engine import *
 from mlir.runtime import *
+from ml_dtypes import bfloat16
 
 
 # Log everything to stderr and flush so that we have a unified stream to match
@@ -521,6 +522,45 @@ def testComplexUnrankedMemrefAdd():
 run(testComplexUnrankedMemrefAdd)
 
 
+# Test bf16 memrefs
+# CHECK-LABEL: TEST: testBF16Memref
+def testBF16Memref():
+    with Context():
+        module = Module.parse(
+            """
+    module  {
+      func.func @main(%arg0: memref<1xbf16>,
+                      %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
+        %0 = arith.constant 0 : index
+        %1 = memref.load %arg0[%0] : memref<1xbf16>
+        memref.store %1, %arg1[%0] : memref<1xbf16>
+        return
+      }
+    } """
+        )
+
+        arg1 = np.array([0.5]).astype(bfloat16)
+        arg2 = np.array([0.0]).astype(bfloat16)
+
+        arg1_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg1))
+        )
+        arg2_memref_ptr = ctypes.pointer(
+            ctypes.pointer(get_ranked_memref_descriptor(arg2))
+        )
+
+        execution_engine = ExecutionEngine(lowerToLLVM(module))
+        execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
+
+        # test to-numpy utility
+        # CHECK: [0.5]
+        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        log(npout)
+
+
+run(testBF16Memref)
+
+
 #  Test addition of two 2d_memref
 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
 def testDynamicMemrefAdd2D():

@xurui1995 xurui1995 requested a review from joker-eph July 31, 2024 01:38
@xurui1995
Copy link
Contributor Author

@joker-eph, hi, could you help approve this?
again, thank your effort on built-bots !

@joker-eph
Copy link
Collaborator

I'll merge it now and keep an eye on the bot in case.

@joker-eph joker-eph merged commit 5ef087b into llvm:main Jul 31, 2024
10 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 31, 2024

LLVM Buildbot has detected a new failure on builder mlir-rocm-mi200 running on mi200-buildbot while building mlir at step 6 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/2495

Here is the relevant piece of the build log for the reference:

Step 6 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: python/execution_engine.py' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
/usr/bin/python3.8 /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py 2>&1 | /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/FileCheck /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py
# executed command: /usr/bin/python3.8 /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py
# note: command had no output on stdout or stderr
# error: command failed with exit status: 1
# executed command: /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/FileCheck /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py
# .---command stderr------------
# | /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py:26:16: error: CHECK-LABEL: expected string not found in input
# | # CHECK-LABEL: TEST: testCapsule
# |                ^
# | <stdin>:1:1: note: scanning from here
# | Traceback (most recent call last):
# | ^
# | <stdin>:2:62: note: possible intended match here
# |  File "/vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py", line 8, in <module>
# |                                                              ^
# | 
# | Input file: <stdin>
# | Check file: /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             1: Traceback (most recent call last): 
# | label:26'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
# |             2:  File "/vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/python/execution_engine.py", line 8, in <module> 
# | label:26'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# | label:26'1                                                                  ?                                                             possible intended match
# |             3:  from ml_dtypes import bfloat16 
# | label:26'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# |             4: ModuleNotFoundError: No module named 'ml_dtypes' 
# | label:26'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# | >>>>>>
# `-----------------------------
# error: command failed with exit status: 1

--

********************


@xurui1995
Copy link
Contributor Author

I'll merge it now and keep an eye on the bot in case.

same problem in running mlir-rocm-mi200

@xurui1995
Copy link
Contributor Author

Previous

image

Now

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants