Skip to content

Commit 1d5e3b2

Browse files
[mlir][spirv] Use ODS generated attribute names for op definitions (#81552)
Since ODS generates getters functions for SPIRV operations' attribute names, we replace instances of these hardcoded strings in the SPIR-V dialect's op parser/printer with function calls for consistency. Fixes #77627 --------- Co-authored-by: Lei Zhang <[email protected]>
1 parent bc6b5be commit 1d5e3b2

File tree

8 files changed

+223
-181
lines changed

8 files changed

+223
-181
lines changed

mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ StringRef stringifyTypeName<FloatType>() {
3333
}
3434

3535
// Verifies an atomic update op.
36-
template <typename ExpectedElementType>
36+
template <typename AtomicOpTy, typename ExpectedElementType>
3737
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
3838
auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
3939
auto elementType = ptrType.getPointeeType();
@@ -42,8 +42,10 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
4242
<< stringifyTypeName<ExpectedElementType>()
4343
<< " value, found " << elementType;
4444

45+
StringAttr semanticsAttrName =
46+
AtomicOpTy::getSemanticsAttrName(op->getName());
4547
auto memorySemantics =
46-
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
48+
op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
4749
.getValue();
4850
if (failed(verifyMemorySemantics(op, memorySemantics))) {
4951
return failure();
@@ -56,95 +58,95 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
5658
//===----------------------------------------------------------------------===//
5759

5860
LogicalResult AtomicAndOp::verify() {
59-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
61+
return verifyAtomicUpdateOp<AtomicAndOp, IntegerType>(getOperation());
6062
}
6163

6264
//===----------------------------------------------------------------------===//
6365
// spirv.AtomicIAddOp
6466
//===----------------------------------------------------------------------===//
6567

6668
LogicalResult AtomicIAddOp::verify() {
67-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
69+
return verifyAtomicUpdateOp<AtomicIAddOp, IntegerType>(getOperation());
6870
}
6971

7072
//===----------------------------------------------------------------------===//
7173
// spirv.EXT.AtomicFAddOp
7274
//===----------------------------------------------------------------------===//
7375

7476
LogicalResult EXTAtomicFAddOp::verify() {
75-
return verifyAtomicUpdateOp<FloatType>(getOperation());
77+
return verifyAtomicUpdateOp<EXTAtomicFAddOp, FloatType>(getOperation());
7678
}
7779

7880
//===----------------------------------------------------------------------===//
7981
// spirv.AtomicIDecrementOp
8082
//===----------------------------------------------------------------------===//
8183

8284
LogicalResult AtomicIDecrementOp::verify() {
83-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
85+
return verifyAtomicUpdateOp<AtomicIDecrementOp, IntegerType>(getOperation());
8486
}
8587

8688
//===----------------------------------------------------------------------===//
8789
// spirv.AtomicIIncrementOp
8890
//===----------------------------------------------------------------------===//
8991

9092
LogicalResult AtomicIIncrementOp::verify() {
91-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
93+
return verifyAtomicUpdateOp<AtomicIIncrementOp, IntegerType>(getOperation());
9294
}
9395

9496
//===----------------------------------------------------------------------===//
9597
// spirv.AtomicISubOp
9698
//===----------------------------------------------------------------------===//
9799

98100
LogicalResult AtomicISubOp::verify() {
99-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
101+
return verifyAtomicUpdateOp<AtomicISubOp, IntegerType>(getOperation());
100102
}
101103

102104
//===----------------------------------------------------------------------===//
103105
// spirv.AtomicOrOp
104106
//===----------------------------------------------------------------------===//
105107

106108
LogicalResult AtomicOrOp::verify() {
107-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
109+
return verifyAtomicUpdateOp<AtomicOrOp, IntegerType>(getOperation());
108110
}
109111

110112
//===----------------------------------------------------------------------===//
111113
// spirv.AtomicSMaxOp
112114
//===----------------------------------------------------------------------===//
113115

114116
LogicalResult AtomicSMaxOp::verify() {
115-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
117+
return verifyAtomicUpdateOp<AtomicSMaxOp, IntegerType>(getOperation());
116118
}
117119

118120
//===----------------------------------------------------------------------===//
119121
// spirv.AtomicSMinOp
120122
//===----------------------------------------------------------------------===//
121123

122124
LogicalResult AtomicSMinOp::verify() {
123-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
125+
return verifyAtomicUpdateOp<AtomicSMinOp, IntegerType>(getOperation());
124126
}
125127

126128
//===----------------------------------------------------------------------===//
127129
// spirv.AtomicUMaxOp
128130
//===----------------------------------------------------------------------===//
129131

130132
LogicalResult AtomicUMaxOp::verify() {
131-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
133+
return verifyAtomicUpdateOp<AtomicUMaxOp, IntegerType>(getOperation());
132134
}
133135

134136
//===----------------------------------------------------------------------===//
135137
// spirv.AtomicUMinOp
136138
//===----------------------------------------------------------------------===//
137139

138140
LogicalResult AtomicUMinOp::verify() {
139-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
141+
return verifyAtomicUpdateOp<AtomicUMinOp, IntegerType>(getOperation());
140142
}
141143

142144
//===----------------------------------------------------------------------===//
143145
// spirv.AtomicXorOp
144146
//===----------------------------------------------------------------------===//
145147

146148
LogicalResult AtomicXorOp::verify() {
147-
return verifyAtomicUpdateOp<IntegerType>(getOperation());
149+
return verifyAtomicUpdateOp<AtomicXorOp, IntegerType>(getOperation());
148150
}
149151

150152
} // namespace mlir::spirv

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
8787
parser.parseRSquare())
8888
return failure();
8989

90-
result.addAttribute(kBranchWeightAttrName,
90+
StringAttr branchWeightsAttrName =
91+
BranchConditionalOp::getBranchWeightsAttrName(result.name);
92+
result.addAttribute(branchWeightsAttrName,
9193
builder.getArrayAttr({trueWeight, falseWeight}));
9294
}
9395

@@ -199,11 +201,11 @@ LogicalResult FunctionCallOp::verify() {
199201
}
200202

201203
CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
202-
return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
204+
return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
203205
}
204206

205207
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
206-
(*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
208+
(*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
207209
}
208210

209211
Operation::operand_range FunctionCallOp::getArgOperands() {

0 commit comments

Comments
 (0)