Skip to content

[mlir][py] Add NVGPU's TensorMapDescriptorType in py bindings #88855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlir/include/mlir-c/Dialect/NVGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@
#define MLIR_C_DIALECT_NVGPU_H

#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu);

//===---------------------------------------------------------------------===//
// TensorMapDescriptorType
//===---------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type);

MLIR_CAPI_EXPORTED MlirType mlirNVGPUTensorMapDescriptorTypeGet(
MlirContext ctx, MlirType tensorMemrefType, int swizzle, int l2promo,
int oobFill, int interleave);

#ifdef __cplusplus
}
#endif
Expand Down
41 changes: 41 additions & 0 deletions mlir/lib/Bindings/Python/DialectNVGPU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===--- DialectNvgpu.cpp - Pybind module for Nvgpu dialect API support ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/NVGPU.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include <pybind11/pybind11.h>

namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;

static void populateDialectNvgpuSubmodule(const pybind11::module &m) {
auto nvgpuTensorMapDescriptorType = mlir_type_subclass(
m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType);

nvgpuTensorMapDescriptorType.def_classmethod(
"get",
[](py::object cls, MlirType tensorMemrefType, int swizzle, int l2promo,
int oobFill, int interleave, MlirContext ctx) {
return cls(mlirNVGPUTensorMapDescriptorTypeGet(
ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave));
},
"Gets an instance of TensorMapDescriptorType in the same context",
py::arg("cls"), py::arg("tensor_type"), py::arg("swizzle"),
py::arg("l2promo"), py::arg("oob_fill"), py::arg("interleave"),
py::arg("ctx") = py::none());
}

PYBIND11_MODULE(_mlirDialectsNvgpu, m) {
m.doc() = "MLIR NVGPU dialect.";

populateDialectNvgpuSubmodule(m);
}
18 changes: 18 additions & 0 deletions mlir/lib/CAPI/Dialect/NVGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,23 @@
#include "mlir-c/Dialect/NVGPU.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::nvgpu;

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu, mlir::nvgpu::NVGPUDialect)

bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type) {
return isa<nvgpu::TensorMapDescriptorType>(unwrap(type));
}

MlirType mlirNVGPUTensorMapDescriptorTypeGet(MlirContext ctx,
MlirType tensorMemrefType,
int swizzle, int l2promo,
int oobFill, int interleave) {
return wrap(nvgpu::TensorMapDescriptorType::get(
unwrap(ctx), cast<MemRefType>(unwrap(tensorMemrefType)),
TensorMapSwizzleKind(swizzle), TensorMapL2PromoKind(l2promo),
TensorMapOOBKind(oobFill), TensorMapInterleaveKind(interleave)));
}
13 changes: 13 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MLIRCAPIQuant
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
MODULE_NAME _mlirDialectsNvgpu
ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectNVGPU.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPINVGPU
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
MODULE_NAME _mlirDialectsPDL
ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/nvgpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@

from ._nvgpu_ops_gen import *
from ._nvgpu_enum_gen import *
from .._mlir_libs._mlirDialectsNvgpu import *
17 changes: 17 additions & 0 deletions mlir/test/python/dialects/nvgpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ def constructAndPrintInModule(f):
return f


# CHECK-LABEL: testTypes
@constructAndPrintInModule
def testTypes():
tensorMemrefType = MemRefType.get(
(128, 64), F16Type.get(), memory_space=Attribute.parse("3")
)
# CHECK: !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = l2promo_256b, oob = nan, interleave = none>
tma_desc = nvgpu.TensorMapDescriptorType.get(
tensorMemrefType,
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_256B,
nvgpu.TensorMapOOBKind.OOB_NAN,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
)
print(tma_desc)


# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
Expand Down