Skip to content

Commit 1246465

Browse files
committed
Support named op layout propagation and pack processing
1 parent 23dfa97 commit 1246465

File tree

13 files changed

+1611
-0
lines changed

13 files changed

+1611
-0
lines changed

include/gc/Analysis/GlobalAnalysis.h

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
//===- GlobalAnalysis.h - Graph Compiler analysis pass ----------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_GLOBALANALYSIS_H
10+
#define MLIR_ANALYSIS_GLOBALANALYSIS_H
11+
12+
#include <numeric>
#include <utility>

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
20+
21+
namespace mlir {
namespace gc {
// NOTE: the previous `using namespace mlir;` directive was removed — it was
// a no-op (this code already lives inside namespace mlir) and using-directives
// in headers leak into every includer.
25+
26+
class TensorLayout {
27+
public:
28+
TensorLayout(ArrayRef<int64_t> outerAxis, ArrayRef<int64_t> innerAxis,
29+
ArrayRef<OpFoldResult> tileSizes)
30+
: outerAxis(outerAxis), innerAxis(innerAxis), tileSizes(tileSizes) {
31+
assert(innerAxis.size() == tileSizes.size());
32+
}
33+
34+
bool isPlainLayout() const {
35+
for (int64_t i = 0; i < static_cast<int64_t>(outerAxis.size()); ++i) {
36+
if (i != outerAxis[i])
37+
return false;
38+
}
39+
return tileSizes.empty() && innerAxis.empty();
40+
}
41+
42+
static TensorLayout createPlainLayout(int64_t rank) {
43+
SmallVector<int64_t> outerAxis(rank, 0);
44+
std::iota(outerAxis.begin(), outerAxis.end(), 0);
45+
return TensorLayout(outerAxis, SmallVector<int64_t>{},
46+
SmallVector<OpFoldResult>{});
47+
}
48+
49+
DenseMap<int64_t, SmallVector<int64_t>> getPlainToPackedAxisMapping() {
50+
DenseMap<int64_t, SmallVector<int64_t>> axisMapping;
51+
int64_t outerAxisSize = outerAxis.size();
52+
for (int64_t i = 0; i < outerAxisSize; ++i) {
53+
axisMapping[outerAxis[i]].push_back(i);
54+
}
55+
for (int64_t i = 0; i < static_cast<int64_t>(innerAxis.size()); ++i) {
56+
axisMapping[innerAxis[i]].push_back(outerAxisSize + i);
57+
}
58+
return axisMapping;
59+
}
60+
61+
FailureOr<int64_t> getPlainAxis(int64_t idx) {
62+
int64_t totalRank = outerAxis.size() + innerAxis.size();
63+
if (idx >= totalRank || idx < 0) {
64+
return failure();
65+
} else if (idx >= static_cast<int64_t>(outerAxis.size())) {
66+
return innerAxis[idx - outerAxis.size()];
67+
} else {
68+
return outerAxis[idx];
69+
}
70+
}
71+
72+
size_t getRank() const { return outerAxis.size(); }
73+
74+
SmallVector<int64_t> getOuterAxis() const { return outerAxis; }
75+
76+
SmallVector<int64_t> getInnerAxis() const { return innerAxis; }
77+
78+
SmallVector<OpFoldResult> getTileSizes() const { return tileSizes; }
79+
80+
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
81+
const TensorLayout &layout);
82+
83+
bool operator==(const TensorLayout &layout);
84+
85+
private:
86+
SmallVector<int64_t> outerAxis;
87+
SmallVector<int64_t> innerAxis;
88+
SmallVector<OpFoldResult> tileSizes;
89+
};
90+
91+
class OperatorLayout {
92+
public:
93+
OperatorLayout() {}
94+
95+
OperatorLayout(SmallVector<TensorLayout> inputLayouts,
96+
SmallVector<TensorLayout> outputLayouts) {
97+
supportedInputLayouts = inputLayouts;
98+
supportedOutputLayouts = outputLayouts;
99+
}
100+
101+
SmallVector<TensorLayout> getSupportedInputLayouts() const {
102+
return supportedInputLayouts;
103+
}
104+
105+
SmallVector<TensorLayout> getSupportedOutputLayouts() const {
106+
return supportedOutputLayouts;
107+
}
108+
109+
TensorLayout getOutputLayout(int64_t idx) const {
110+
assert(idx < static_cast<int64_t>(supportedOutputLayouts.size()));
111+
return supportedOutputLayouts[idx];
112+
}
113+
114+
bool isPlain() const {
115+
for (const auto &layout : llvm::concat<const TensorLayout>(
116+
supportedInputLayouts, supportedOutputLayouts)) {
117+
if (!layout.isPlainLayout())
118+
return false;
119+
}
120+
return true;
121+
}
122+
123+
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
124+
const OperatorLayout &opLayout);
125+
126+
private:
127+
SmallVector<TensorLayout> supportedInputLayouts;
128+
SmallVector<TensorLayout> supportedOutputLayouts;
129+
};
130+
131+
class GlobalAnalysis {
132+
public:
133+
explicit GlobalAnalysis(Operation *root);
134+
135+
FailureOr<OperatorLayout> getOpLayout(Operation *op) {
136+
if (layoutCache.find(op) != layoutCache.end())
137+
return layoutCache[op];
138+
else
139+
return failure("Current op does not have layout information.");
140+
}
141+
142+
private:
143+
DenseMap<Operation *, OperatorLayout> layoutCache;
144+
};
145+
146+
namespace utils {
// Predicate: whether layout packing should be applied to `op`.
// Declaration only — the exact criteria live in the out-of-view definition;
// confirm there before relying on specifics.
bool isPackableNamedOp(Operation *op);
}
149+
} // namespace gc
150+
} // namespace mlir
151+
152+
#endif

include/gc/Transforms/Passes.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,38 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
7474
let dependentDialects = ["scf::SCFDialect"];
7575
}
7676

77+
// Inserts tensor.pack and propagates it through linalg named ops and
// tensor ops. (Fixes the "propagte" typo in the user-visible summary and
// description strings.)
def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
  let summary = "Insert and propagate tensor.pack to pack the computation of linalg named ops and tensor ops.";
  let description = [{
    Insert and propagate tensor.pack on linalg named ops and tensor ops.
  }];
  let dependentDialects = [
    "mlir::tensor::TensorDialect",
    "mlir::linalg::LinalgDialect",
    "mlir::linalgx::LinalgxDialect"
  ];
}
88+
89+
// Folds and simplifies tensor.pack / tensor.unpack ops — presumably run
// after PropagateLayoutOnNamedOps; confirm intended pipeline ordering.
def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> {
  let summary = "Fold and simplify pack and unpack ops.";
  let description = [{
    Fold and simplify pack and unpack ops.
  }];
  let dependentDialects = [
    "mlir::tensor::TensorDialect",
    "mlir::linalg::LinalgDialect"
  ];
}
99+
100+
// Lowers tensor.pack / tensor.unpack into transpose- and shape-related ops.
def LowerPackUnpack : Pass<"lower-pack-unpack"> {
  let summary = "Lower pack and unpack ops.";
  let description = [{
    Lower pack and unpack into transpose and shape related ops.
  }];
  let dependentDialects = [
    "mlir::tensor::TensorDialect",
    "mlir::linalg::LinalgDialect"
  ];
}
110+
77111
#endif // GC_DIALECT_GC_PASSES

include/gc/Transforms/Transforms.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- Transforms.h - transformation utilities ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_TRANSFORMS_TRANSFORMS_H
10+
#define GC_TRANSFORMS_TRANSFORMS_H
11+
12+
#include "gc/Analysis/GlobalAnalysis.h"
13+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
14+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15+
16+
namespace mlir {
namespace gc {
// NOTE(review): declarations only — descriptions below are inferred from
// names and parameter types; confirm against the out-of-view definitions.
// Presumably packs `linalgOp` according to `opLayout` and returns the
// linalg packing result, or failure when the op cannot be packed.
FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
                                          linalg::LinalgOp linalgOp,
                                          OperatorLayout opLayout);

// Presumably rewrites `linalgOp` to operate on the layouts described by
// `opLayout`; returns failure if propagation is not possible.
// `opLayout` is passed by value here to match the definition's signature.
LogicalResult namedOpLayoutPropagation(RewriterBase &rewriter,
                                       linalg::LinalgOp linalgOp,
                                       OperatorLayout opLayout);
} // namespace gc
} // namespace mlir
27+
28+
#endif // GC_TRANSFORMS_TRANSFORMS_H

lib/gc/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
44

55
add_mlir_library(GCAnalysis
66
MatmulConfigAnalysis.cpp
7+
GlobalAnalysis.cpp
78

89
ADDITIONAL_HEADER_DIRS
910
${PROJECT_SOURCE_DIR}/include

0 commit comments

Comments
 (0)