[mlir] Use StringRef::operator== instead of StringRef::equals (NFC) #91560

Merged

Conversation

kazutakahirata
Contributor

I'm planning to remove StringRef::equals in favor of
StringRef::operator==.

  • StringRef::operator==/!= outnumber StringRef::equals by a factor of
    10 under mlir/ in terms of their usage.

  • The elimination of StringRef::equals brings StringRef closer to
    std::string_view, which has operator== but not equals.

  • S == "foo" is more readable than S.equals("foo"), especially for
    !Long.Expression.equals("str") vs Long.Expression != "str".

@llvmbot
Member

llvmbot commented May 9, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-affine

Author: Kazu Hirata (kazutakahirata)

Full diff: https://github.com/llvm/llvm-project/pull/91560.diff

14 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+1-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+8-8)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+6-8)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+5-8)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h (+1-2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-8)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+3-3)
  • (modified) mlir/lib/IR/AttributeDetail.h (+1-1)
  • (modified) mlir/lib/TableGen/Builder.cpp (+1-1)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp (+2-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp (+1-1)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+1-1)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+2-2)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3a4fc7d8063f4..82bfa9514a884 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -926,7 +926,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
 static bool isDefinedByCallTo(Value value, StringRef functionName) {
   assert(isa<LLVM::LLVMPointerType>(value.getType()));
   if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
-    return defOp.getCallee()->equals(functionName);
+    return *defOp.getCallee() == functionName;
   return false;
 }
 
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 775dd1e609037..b7fd454c60902 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -42,11 +42,11 @@ static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
 static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
 
 static NVVM::MMAFrag convertOperand(StringRef operandName) {
-  if (operandName.equals("AOp"))
+  if (operandName == "AOp")
     return NVVM::MMAFrag::a;
-  if (operandName.equals("BOp"))
+  if (operandName == "BOp")
     return NVVM::MMAFrag::b;
-  if (operandName.equals("COp"))
+  if (operandName == "COp")
     return NVVM::MMAFrag::c;
   llvm_unreachable("Unknown operand name");
 }
@@ -55,8 +55,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   if (type.getElementType().isF16())
     return NVVM::MMATypes::f16;
   if (type.getElementType().isF32())
-    return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
-                                           : NVVM::MMATypes::tf32;
+    return type.getOperand() == "COp" ? NVVM::MMATypes::f32
+                                      : NVVM::MMATypes::tf32;
 
   if (type.getElementType().isSignedInteger(8))
     return NVVM::MMATypes::s8;
@@ -99,15 +99,15 @@ struct WmmaLoadOpToNVVMLowering
     NVVM::MMATypes eltype = getElementType(retType);
     // NVVM intrinsics require to give mxnxk dimensions, infer the missing
     // dimension based on the valid intrinsics available.
-    if (retType.getOperand().equals("AOp")) {
+    if (retType.getOperand() == "AOp") {
       m = retTypeShape[0];
       k = retTypeShape[1];
       n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
-    } else if (retType.getOperand().equals("BOp")) {
+    } else if (retType.getOperand() == "BOp") {
       k = retTypeShape[0];
       n = retTypeShape[1];
       m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
-    } else if (retType.getOperand().equals("COp")) {
+    } else if (retType.getOperand() == "COp") {
       m = retTypeShape[0];
       n = retTypeShape[1];
       k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index f8485e02a2208..19f02297bfbb7 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -261,7 +261,7 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
 template <typename OpTy>
 static bool isTensorOp(OpTy xferOp) {
   if (isa<RankedTensorType>(xferOp.getShapedType())) {
-    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
+    if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
       // TransferWriteOps on tensors have a result.
       assert(xferOp->getNumResults() > 0);
     }
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c9c0a7b4cc686..2e31487bd55a0 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3585,20 +3585,18 @@ ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
       parser.resolveOperands(mapOperands, indexTy, result.operands))
     return failure();
 
-  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
+  if (readOrWrite != "read" && readOrWrite != "write")
     return parser.emitError(parser.getNameLoc(),
                             "rw specifier has to be 'read' or 'write'");
-  result.addAttribute(
-      AffinePrefetchOp::getIsWriteAttrStrName(),
-      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
+  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
+                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));
 
-  if (!cacheType.equals("data") && !cacheType.equals("instr"))
+  if (cacheType != "data" && cacheType != "instr")
     return parser.emitError(parser.getNameLoc(),
                             "cache type has to be 'data' or 'instr'");
 
-  result.addAttribute(
-      AffinePrefetchOp::getIsDataCacheAttrStrName(),
-      parser.getBuilder().getBoolAttr(cacheType.equals("data")));
+  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
+                      parser.getBuilder().getBoolAttr(cacheType == "data"));
 
   return success();
 }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index f1b9ca5c50020..0c2590d711301 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -152,8 +152,7 @@ LogicalResult
 MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                       ArrayRef<int64_t> shape, Type elementType,
                       StringRef operand) {
-  if (!operand.equals("AOp") && !operand.equals("BOp") &&
-      !operand.equals("COp"))
+  if (operand != "AOp" && operand != "BOp" && operand != "COp")
     return emitError() << "operand expected to be one of AOp, BOp or COp";
 
   if (shape.size() != 2)
@@ -1941,8 +1940,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
     return emitError(
         "expected source memref most minor dim must have unit stride");
 
-  if (!operand.equals("AOp") && !operand.equals("BOp") &&
-      !operand.equals("COp"))
+  if (operand != "AOp" && operand != "BOp" && operand != "COp")
     return emitError("only AOp, BOp and COp can be loaded");
 
   return success();
@@ -1962,7 +1960,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
     return emitError(
         "expected destination memref most minor dim must have unit stride");
 
-  if (!srcMatrixType.getOperand().equals("COp"))
+  if (srcMatrixType.getOperand() != "COp")
     return emitError(
         "expected the operand matrix being stored to have 'COp' operand type");
 
@@ -1980,9 +1978,8 @@ LogicalResult SubgroupMmaComputeOp::verify() {
   opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
   opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
 
-  if (!opTypes[A].getOperand().equals("AOp") ||
-      !opTypes[B].getOperand().equals("BOp") ||
-      !opTypes[C].getOperand().equals("COp"))
+  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
+      opTypes[C].getOperand() != "COp")
     return emitError("operands must be in the order AOp, BOp, COp");
 
   ArrayRef<int64_t> aShape, bShape, cShape;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
index 2040d0a06b2e3..8767b1c3ffc5b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -131,8 +131,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
     /// Compares two keys.
     bool operator==(const Key &other) const {
       if (isIdentified())
-        return other.isIdentified() &&
-               other.getIdentifier().equals(getIdentifier());
+        return other.isIdentified() && other.getIdentifier() == getIdentifier();
 
       return !other.isIdentified() && other.isPacked() == isPacked() &&
              other.getTypeList() == getTypeList();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c9a85919ec799..199e7330a233c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1742,20 +1742,18 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.resolveOperands(indexInfo, indexTy, result.operands))
     return failure();
 
-  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
+  if (readOrWrite != "read" && readOrWrite != "write")
     return parser.emitError(parser.getNameLoc(),
                             "rw specifier has to be 'read' or 'write'");
-  result.addAttribute(
-      PrefetchOp::getIsWriteAttrStrName(),
-      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
+  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
+                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));
 
-  if (!cacheType.equals("data") && !cacheType.equals("instr"))
+  if (cacheType != "data" && cacheType != "instr")
     return parser.emitError(parser.getNameLoc(),
                             "cache type has to be 'data' or 'instr'");
 
-  result.addAttribute(
-      PrefetchOp::getIsDataCacheAttrStrName(),
-      parser.getBuilder().getBoolAttr(cacheType.equals("data")));
+  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
+                      parser.getBuilder().getBoolAttr(cacheType == "data"));
 
   return success();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 92e5efaa81049..39f5cf1a75082 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -89,11 +89,11 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
            "expected valid level property (e.g. nonordered, nonunique or high)")
-  if (strVal.equals(toPropString(LevelPropNonDefault::Nonunique))) {
+  if (strVal == toPropString(LevelPropNonDefault::Nonunique)) {
     *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
-  } else if (strVal.equals(toPropString(LevelPropNonDefault::Nonordered))) {
+  } else if (strVal == toPropString(LevelPropNonDefault::Nonordered)) {
     *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
-  } else if (strVal.equals(toPropString(LevelPropNonDefault::SoA))) {
+  } else if (strVal == toPropString(LevelPropNonDefault::SoA)) {
     *properties |= static_cast<uint64_t>(LevelPropNonDefault::SoA);
   } else {
     parser.emitError(loc, "unknown level property: ") << strVal;
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index dcd24af0107dd..26d40ac3a38f6 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -261,7 +261,7 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
     // Check to see if this storage represents a splat. If it doesn't then
     // combine the hash for the data starting with the first non splat element.
     for (size_t i = 1, e = data.size(); i != e; i++)
-      if (!firstElt.equals(data[i]))
+      if (firstElt != data[i])
         return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
 
     // Otherwise, this is a splat so just return the hash of the first element.
diff --git a/mlir/lib/TableGen/Builder.cpp b/mlir/lib/TableGen/Builder.cpp
index 47a2f6cc4456e..044765c726019 100644
--- a/mlir/lib/TableGen/Builder.cpp
+++ b/mlir/lib/TableGen/Builder.cpp
@@ -52,7 +52,7 @@ Builder::Builder(const llvm::Record *record, ArrayRef<SMLoc> loc)
   // Initialize the parameters of the builder.
   const llvm::DagInit *dag = def->getValueAsDag("dagParams");
   auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
-  if (!defInit || !defInit->getDef()->getName().equals("ins"))
+  if (!defInit || defInit->getDef()->getName() != "ins")
     PrintFatalError(def->getLoc(), "expected 'ins' in builders");
 
   bool seenDefaultValue = false;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 40d8253d822f6..06673965245c0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -93,7 +93,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
     return failure();
 
   // Handle function entry count metadata.
-  if (name->getString().equals("function_entry_count")) {
+  if (name->getString() == "function_entry_count") {
 
     // TODO support function entry count metadata with GUID fields.
     if (node->getNumOperands() != 2)
@@ -111,7 +111,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
            << "expected function_entry_count to be attached to a function";
   }
 
-  if (!name->getString().equals("branch_weights"))
+  if (name->getString() != "branch_weights")
     return failure();
 
   // Handle branch weights metadata.
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
index c376d6c73c645..ebaced57a24a4 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
@@ -413,7 +413,7 @@ void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
   // of inner-op), then we can print the entire region in a succinct way.
   // Here we assume that the prototype of "test.special.op" can be trivially
   // derived while parsing it back.
-  if (innerOp.getName().getStringRef().equals("test.special.op")) {
+  if (innerOp.getName().getStringRef() == "test.special.op") {
     p << " start test.special.op end";
   } else {
     p << " (";
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index b9a72119790e5..55bc0714c20ec 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -50,7 +50,7 @@ static void collectAllDefs(StringRef selectedDialect,
   } else {
     // Otherwise, generate the defs that belong to the selected dialect.
     auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
-      return def.getDialect().getName().equals(selectedDialect);
+      return def.getDialect().getName() == selectedDialect;
     });
     resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
   }
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 814008c254511..052020acdcb76 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -457,7 +457,7 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
     std::string sanitizedName = sanitizeName(namedAttr.name);
 
     // Unit attributes are handled specially.
-    if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
+    if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
       os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
                           namedAttr.name);
       os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
@@ -668,7 +668,7 @@ populateBuilderLinesAttr(const Operator &op,
       continue;
 
     // Unit attributes are handled specially.
-    if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
+    if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
       builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
                                            attribute->name, argNames[i]));
       continue;

@llvmbot
Member

llvmbot commented May 9, 2024

@llvm/pr-subscribers-mlir

@llvmbot
Member

llvmbot commented May 9, 2024

@llvm/pr-subscribers-mlir-sparse

@joker-eph (Collaborator) left a comment

Thanks.

@kazutakahirata merged commit dec8055 into llvm:main on May 9, 2024
@kazutakahirata deleted the cleanup_StringRef_equals_mlir branch on May 9, 2024 at 06:52