Skip to content

Commit f799862

Browse files
authored
[FXML-4440] Introduce a pass that annotates the type of the argument as an attribute (#167)
Create a pass that annotates all func op inputs with their type as attributes.
1 parent 29e1ec6 commit f799862

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

mlir/include/mlir/Dialect/Func/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,15 @@ def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
5151
let constructor = "mlir::func::createDuplicateFunctionEliminationPass()";
5252
}
5353

54+
def AnnotateFunctionType: Pass<"annotate-function-type", "func::FuncOp"> {
55+
let summary = "Annotate the function type as type attributes";
56+
let description = [{
57+
Annotates all the inputs and outputs of func.func operators with a type
58+
attribute. The type attribute mirrors the actual type of the inputs/outputs.
59+
60+
This pass can be used to trace back the original types of func.func
61+
operators in case they need to be modified.
62+
}];
63+
}
64+
5465
#endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===- AnnotateInputTypes.cpp - Type attribute annotation for func ops ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass that creates type attributes for func parameters,
10+
// that mirror the actual type. This is useful when the func op input types
11+
// might change.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Func/Transforms/Passes.h"
16+
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/IR/BuiltinAttributes.h"
19+
#include "mlir/Pass/Pass.h"
20+
21+
using namespace mlir;
22+
23+
namespace mlir::func {
24+
#define GEN_PASS_DEF_ANNOTATEFUNCTIONTYPE
25+
#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
26+
} // namespace mlir::func
27+
28+
namespace {
29+
struct AnnotateFunctionTypePass
30+
: public mlir::func::impl::AnnotateFunctionTypeBase<
31+
AnnotateFunctionTypePass> {
32+
33+
void runOnOperation() override {
34+
func::FuncOp func = getOperation();
35+
auto inputs = func.getArgumentTypes();
36+
auto results = func.getResultTypes();
37+
38+
for (const auto [argNum, type] : llvm::enumerate(inputs)) {
39+
func.setArgAttr(argNum, "func.orig_type", TypeAttr::get(type));
40+
}
41+
42+
for (const auto [resultNum, type] : llvm::enumerate(results)) {
43+
func.setResultAttr(resultNum, "func.orig_type", TypeAttr::get(type));
44+
}
45+
}
46+
};
47+
} // namespace

mlir/lib/Dialect/Func/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRFuncTransforms
2+
AnnotateFunctionType.cpp
23
DecomposeCallGraphTypes.cpp
34
DuplicateFunctionElimination.cpp
45
FuncBufferize.cpp
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt %s --split-input-file --annotate-function-type | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @one_arg(%arg0: tensor<f32> {func.orig_type = tensor<f32>}) -> (tensor<f32> {func.orig_type = tensor<f32>}) {
4+
func.func @one_arg(%arg0: tensor<f32>) -> tensor<f32> {
5+
return %arg0 : tensor<f32>
6+
}
7+
8+
// -----
9+
10+
// CHECK-LABEL: func.func @one_arg_int(%arg0: tensor<ui8> {func.orig_type = tensor<ui8>}) -> (tensor<ui8> {func.orig_type = tensor<ui8>}) {
11+
func.func @one_arg_int(%arg0: tensor<ui8>) -> tensor<ui8> {
12+
return %arg0 : tensor<ui8>
13+
}
14+
15+
// -----
16+
17+
// CHECK-LABEL: func.func @n_rank_tensor(%arg0: tensor<3x4x5xf32> {func.orig_type = tensor<3x4x5xf32>}) -> (tensor<3x4x5xf32> {func.orig_type = tensor<3x4x5xf32>}) {
18+
func.func @n_rank_tensor(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
19+
return %arg0 : tensor<3x4x5xf32>
20+
}
21+
22+
// -----
23+
24+
// CHECK-LABEL: func.func @two_args(%arg0: f32 {func.orig_type = f32}, %arg1: f32 {func.orig_type = f32}) -> (f32 {func.orig_type = f32}) {
25+
func.func @two_args(%arg0: f32, %arg1: f32) -> f32 {
26+
return %arg0 : f32
27+
}

0 commit comments

Comments
 (0)