Skip to content

[mlir][spirv] Use ODS generated attribute names for op definitions #81552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ StringRef stringifyTypeName<FloatType>() {
}

// Verifies an atomic update op.
template <typename ExpectedElementType>
template <typename AtomicOpTy, typename ExpectedElementType>
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
auto elementType = ptrType.getPointeeType();
Expand All @@ -42,8 +42,10 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;

StringAttr semanticsAttrName =
AtomicOpTy::getSemanticsAttrName(op->getName());
auto memorySemantics =
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
.getValue();
if (failed(verifyMemorySemantics(op, memorySemantics))) {
return failure();
Expand All @@ -56,95 +58,95 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
//===----------------------------------------------------------------------===//

LogicalResult AtomicAndOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicAndOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicIAddOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicIAddOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicIAddOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.EXT.AtomicFAddOp
//===----------------------------------------------------------------------===//

LogicalResult EXTAtomicFAddOp::verify() {
return verifyAtomicUpdateOp<FloatType>(getOperation());
return verifyAtomicUpdateOp<EXTAtomicFAddOp, FloatType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicIDecrementOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicIDecrementOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicIDecrementOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicIIncrementOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicIIncrementOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicIIncrementOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicISubOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicISubOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicISubOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicOrOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicOrOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicOrOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicSMaxOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicSMaxOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicSMaxOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicSMinOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicSMinOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicSMinOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicUMaxOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicUMaxOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicUMaxOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicUMinOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicUMinOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicUMinOp, IntegerType>(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.AtomicXorOp
//===----------------------------------------------------------------------===//

LogicalResult AtomicXorOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
return verifyAtomicUpdateOp<AtomicXorOp, IntegerType>(getOperation());
}

} // namespace mlir::spirv
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
parser.parseRSquare())
return failure();

result.addAttribute(kBranchWeightAttrName,
StringAttr branchWeightsAttrName =
BranchConditionalOp::getBranchWeightsAttrName(result.name);
result.addAttribute(branchWeightsAttrName,
builder.getArrayAttr({trueWeight, falseWeight}));
}

Expand Down Expand Up @@ -199,11 +201,11 @@ LogicalResult FunctionCallOp::verify() {
}

CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
}

void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
(*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
}

Operation::operand_range FunctionCallOp::getArgOperands() {
Expand Down
Loading