Skip to content

Commit 8ba5c7a

Browse files
[mlir][Vector] Add initial support for inlining in the presence of vector ops (#70942)
1 parent c449a64 commit 8ba5c7a

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/IR/TypeUtilities.h"
3434
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
3535
#include "mlir/Support/LLVM.h"
36+
#include "mlir/Transforms/InliningUtils.h"
3637
#include "llvm/ADT/ArrayRef.h"
3738
#include "llvm/ADT/STLExtras.h"
3839
#include "llvm/ADT/SmallVector.h"
@@ -348,6 +349,19 @@ struct BitmaskEnumStorage : public AttributeStorage {
348349
// VectorDialect
349350
//===----------------------------------------------------------------------===//
350351

352+
namespace {
353+
/// This class defines the interface for handling inlining with vector dialect
354+
/// operations.
355+
struct VectorInlinerInterface : public DialectInlinerInterface {
356+
using DialectInlinerInterface::DialectInlinerInterface;
357+
358+
/// All vector dialect ops can be inlined.
359+
bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
360+
return true;
361+
}
362+
};
363+
} // namespace
364+
351365
void VectorDialect::initialize() {
352366
addAttributes<
353367
#define GET_ATTRDEF_LIST
@@ -358,6 +372,8 @@ void VectorDialect::initialize() {
358372
#define GET_OP_LIST
359373
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
360374
>();
375+
376+
addInterfaces<VectorInlinerInterface>();
361377
}
362378

363379
/// Materialize a single constant operation from a given attribute value with
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-opt %s -inline | FileCheck %s
2+
3+
func.func @inner_func_inlinable(%v: f32) -> vector<4xf32> {
4+
%1 = vector.broadcast %v : f32 to vector<4xf32>
5+
return %1 : vector<4xf32>
6+
}
7+
8+
// CHECK-LABEL: func.func @test_inline(
9+
// CHECK-NOT: func.call
10+
// CHECK-NEXT: vector.broadcast
11+
func.func @test_inline(%v: f32) -> vector<4xf32> {
12+
%0 = call @inner_func_inlinable(%v) : (f32) -> vector<4xf32>
13+
return %0 : vector<4xf32>
14+
}

0 commit comments

Comments
 (0)