Skip to content

Commit 36c9c9e

Browse files
committed
Add target description query and verifier pass
1 parent de0376f commit 36c9c9e

File tree

14 files changed

+423
-26
lines changed

14 files changed

+423
-26
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ endif()
106106
set(GC_LIB_LINKED_LIBS
107107
GCJitWrapper
108108
GCCpuRuntime
109+
GCPasses
110+
GCAnalysis
109111
)
110112
add_mlir_library(graph_compiler SHARED ${GC_LIB_SOURCES})
111113
target_include_directories(graph_compiler PUBLIC ${GC_LIB_INCLUDES})
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//===-- TargetDescriptionAnalysis.h - target description class --*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
10+
#define MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
11+
12+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "mlir/Dialect/DLTI/DLTI.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
16+
#include "llvm/ADT/StringRef.h"
17+
18+
namespace mlir {
19+
namespace gc {
20+
21+
using namespace mlir;
22+
23+
class TargetDescriptionAnalysisBase {
24+
public:
25+
TargetDescriptionAnalysisBase(Operation *op, std::string device)
26+
: ctx(op->getContext()), device(device),
27+
layout(isa<ModuleOp>(op) ? dyn_cast<ModuleOp>(op)
28+
: op->getParentOfType<ModuleOp>()),
29+
loc(op->getLoc()) {}
30+
// get the device ID
31+
std::string getDevice() { return device; }
32+
33+
// get the MLIR context
34+
MLIRContext *getContext() { return ctx; }
35+
36+
// get the data layout
37+
DataLayout getLayout() { return layout; }
38+
39+
// get the property value by key
40+
std::optional<Attribute> getPropertyValue(StringRef key);
41+
42+
// get the location
43+
Location getLocation() { return loc; }
44+
45+
// check if the property exists
46+
bool hasProperty(StringRef key) { return getPropertyValue(key).has_value(); }
47+
48+
private:
49+
MLIRContext *ctx;
50+
std::string device;
51+
DataLayout layout;
52+
Location loc;
53+
};
54+
55+
class CPUTargetDescriptionAnalysis : public TargetDescriptionAnalysisBase {
56+
public:
57+
static constexpr StringLiteral kL1CacheSize = "L1_cache_size_in_bytes";
58+
static constexpr StringLiteral kL2CacheSize = "L2_cache_size_in_bytes";
59+
static constexpr StringLiteral kL3CacheSize = "L3_cache_size_in_bytes";
60+
static constexpr StringLiteral kMaxVectorWidth = "max_vector_width";
61+
static constexpr StringLiteral kNumThreads = "num_threads";
62+
63+
// get runtime OMP_NUM_THREADS
64+
size_t getNumThreads();
65+
66+
// get cache size by cacheLevel
67+
size_t getCacheSize(uint8_t cacheLevel);
68+
69+
// get the maximum vector length in bits
70+
size_t getMaxVectorWidth();
71+
72+
CPUTargetDescriptionAnalysis(Operation *op)
73+
: TargetDescriptionAnalysisBase(op, "CPU") {}
74+
};
75+
76+
} // namespace gc
77+
} // namespace mlir
78+
79+
#endif

include/gc/Transforms/Passes.td

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,45 +18,46 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
1818
}
1919

2020
def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
21-
let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
22-
let description = [{
23-
Lowers the `onednn_graph` ops to `linalg` ops.
24-
}];
21+
let summary =
22+
"Lower the operations from the oneDNN Graph dialect into Linalg";
23+
let description = [{Lowers the `onednn_graph` ops to `linalg` ops.}];
2524
let dependentDialects = [
26-
"func::FuncDialect",
27-
"math::MathDialect",
28-
"arith::ArithDialect",
29-
"tensor::TensorDialect",
30-
"linalg::LinalgDialect",
31-
"linalgx::LinalgxDialect"
25+
"func::FuncDialect", "math::MathDialect", "arith::ArithDialect",
26+
"tensor::TensorDialect", "linalg::LinalgDialect", "linalgx::LinalgxDialect"
3227
];
3328
}
3429

3530
#ifdef GC_USE_GPU
3631
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
3732
let summary = "Convert linalg dialect to XeGPU dialect.";
38-
let description = [{
39-
Lower linalg ops to XeGPU dialect.
40-
}];
41-
let dependentDialects = ["linalg::LinalgDialect",
42-
"gpu::GPUDialect",
43-
"xegpu::XeGPUDialect",
44-
"scf::SCFDialect",
45-
"memref::MemRefDialect",
46-
"arith::ArithDialect",
47-
"math::MathDialect",
48-
"vector::VectorDialect"];
33+
let description = [{Lower linalg ops to XeGPU dialect.}];
34+
let dependentDialects = [
35+
"linalg::LinalgDialect", "gpu::GPUDialect", "xegpu::XeGPUDialect",
36+
"scf::SCFDialect", "memref::MemRefDialect", "arith::ArithDialect",
37+
"math::MathDialect", "vector::VectorDialect"
38+
];
4939
let options = [
5040
Option<"kTile", "k-tile", "int64_t",
51-
/*default=*/"32",
52-
"GEMM tile size for reduction dimension.">,
41+
/*default=*/"32", "GEMM tile size for reduction dimension.">,
5342
Option<"stages", "stages", "int64_t",
54-
/*default=*/"1",
55-
"Number of cooperative prefetch stages.">,
43+
/*default=*/"1", "Number of cooperative prefetch stages.">,
5644
ListOption<"dpasTile", "dpas-tile", "int64_t",
5745
"DPAS register block sizes MxNxK">,
5846
];
5947
}
6048
#endif
6149

50+
def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
51+
let summary = "Verify the target description from ModuleOp DLTI attribute.";
52+
let description = [{
53+
Verify the target description from ModuleOp DLTI attribute. Raise error for unexpected input(such as a negative number of num_threads), and raise warn for missing fields, and provide a default value(such as 32K for L1_cache_size).
54+
}];
55+
let dependentDialects = ["DLTIDialect"];
56+
let options = [
57+
Option<"device", "device", "std::string",
58+
/*default=*/"\"CPU\"",
59+
"The device to verify. Supported device: CPU, ">,
60+
];
61+
}
62+
6263
#endif // GC_DIALECT_GC_PASSES

lib/gc/Analysis/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
2+
MLIRIR
3+
MLIRSupport)
4+
5+
add_mlir_library(GCAnalysis
6+
TargetDescriptionAnalysis.cpp
7+
8+
ADDITIONAL_HEADER_DIRS
9+
${PROJECT_SOURCE_DIR}/include
10+
11+
DEPENDS
12+
GraphCompilerPassIncGen
13+
14+
LINK_LIBS PUBLIC
15+
${mlir_dialect_libs}
16+
${MLIR_LINK_COMPONENTS}
17+
)
18+
19+
set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===-- TargetDescriptionAnalysis.cpp - target description impl -*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Analysis/TargetDescriptionAnalysis.h"
10+
#include <limits>
11+
#include <llvm/Support/Debug.h>
12+
13+
namespace mlir {
14+
namespace gc {
15+
16+
#define DEBUG_TYPE "target-description-analysis"
17+
18+
// default values for properties
19+
static llvm::DenseMap<StringRef, int64_t> CPUTargetDeafultValueMap = {
20+
{CPUTargetDescriptionAnalysis::kNumThreads, 1},
21+
{CPUTargetDescriptionAnalysis::kL1CacheSize, 32 * 1024},
22+
{CPUTargetDescriptionAnalysis::kL2CacheSize, 32 * 32 * 1024},
23+
{CPUTargetDescriptionAnalysis::kL3CacheSize, 32 * 32 * 1024},
24+
{CPUTargetDescriptionAnalysis::kMaxVectorWidth, 512},
25+
};
26+
27+
static void emitNotFoundWarning(Location loc, StringRef key) {
28+
mlir::emitWarning(loc) << key << " not found, using default value "
29+
<< CPUTargetDeafultValueMap[key];
30+
}
31+
32+
static int64_t getIntFromAttribute(Attribute attr) {
33+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
34+
if (intAttr.getType().isSignedInteger())
35+
return intAttr.getSInt();
36+
else if (intAttr.getType().isUnsignedInteger())
37+
return intAttr.getUInt();
38+
else
39+
return intAttr.getInt();
40+
}
41+
llvm_unreachable("Not an integer attribute");
42+
}
43+
44+
std::optional<Attribute>
45+
TargetDescriptionAnalysisBase::getPropertyValue(StringRef key) {
46+
return layout.getDevicePropertyValue(
47+
Builder(getContext()).getStringAttr(getDevice() /* device ID*/),
48+
Builder(getContext()).getStringAttr(key));
49+
}
50+
51+
size_t CPUTargetDescriptionAnalysis::getNumThreads() {
52+
std::optional<Attribute> numThreads = getPropertyValue(kNumThreads);
53+
54+
if (numThreads && isa<IntegerAttr>(*numThreads))
55+
return getIntFromAttribute(*numThreads);
56+
emitNotFoundWarning(getLocation(), kNumThreads);
57+
return CPUTargetDeafultValueMap[kNumThreads];
58+
}
59+
60+
size_t CPUTargetDescriptionAnalysis::getCacheSize(uint8_t cacheLevel) {
61+
assert(cacheLevel > 0 && cacheLevel < 4 && "Invalid cache level");
62+
StringLiteral key = "";
63+
if (cacheLevel == 1)
64+
key = kL1CacheSize;
65+
else if (cacheLevel == 2)
66+
key = kL2CacheSize;
67+
else if (cacheLevel == 3)
68+
key = kL3CacheSize;
69+
70+
std::optional<Attribute> cacheSize = getPropertyValue(key);
71+
if (cacheSize && isa<IntegerAttr>(*cacheSize))
72+
return getIntFromAttribute(*cacheSize);
73+
74+
emitNotFoundWarning(getLocation(), key);
75+
return CPUTargetDeafultValueMap[key];
76+
}
77+
78+
size_t CPUTargetDescriptionAnalysis::getMaxVectorWidth() {
79+
std::optional<Attribute> maxVectorWidth = getPropertyValue(kMaxVectorWidth);
80+
if (maxVectorWidth && isa<IntegerAttr>(*maxVectorWidth))
81+
return getIntFromAttribute(*maxVectorWidth);
82+
emitNotFoundWarning(getLocation(), kMaxVectorWidth);
83+
return CPUTargetDeafultValueMap[kMaxVectorWidth];
84+
}
85+
86+
} // namespace gc
87+
} // namespace mlir

lib/gc/CAPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ set(GC_ALL_LIBS
22
MLIROneDNNGraph
33
MLIRCPURuntimeDialect
44
GCPasses
5+
GCAnalysis
56
MLIRCPURuntimeTransforms)
67

78
if(GC_USE_GPU)

lib/gc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ endif()
44

55
include(functions)
66

7+
add_subdirectory(Analysis)
78
add_subdirectory(CAPI)
89
add_subdirectory(Dialect)
910
add_subdirectory(Transforms)

lib/gc/ExecutionEngine/Driver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,6 @@ add_mlir_library(GCJitWrapper
4242
${dialect_libs}
4343
${conversion_libs}
4444
${GC_PASSES}
45+
GCAnalysis
4546
)
4647

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_library(GCPasses
1313
OneDNNGraphToLinalg.cpp
1414
Pipeline.cpp
1515
TileNamed.cpp
16+
VerifyTargetDescription.cpp
1617

1718
ADDITIONAL_HEADER_DIRS
1819
${PROJECT_SOURCE_DIR}/include

0 commit comments

Comments
 (0)