Skip to content

Commit 770f96a

Browse files
ashjeongAnthony Tran
authored andcommitted
switch type and value ordering for arith Constant[XX]Op (llvm#144636)
This change standardizes the order of the parameters for `Constant[XXX] Ops` to match with all other `Op` `build()` constructors. In all instances of generated code for the MLIR dialects's Ops (that is the TableGen using the .td files to create the .h.inc/.cpp.inc files), the desired result type is always specified before the value. Examples: ``` // ArithOps.h.inc class ConstantOp : public ::mlir::Op<ConstantOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::OpTrait::ConstantLike, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpAsmOpInterface::Trait, ::mlir::InferIntRangeInterface::Trait, ::mlir::InferTypeOpInterface::Trait> { public: .... static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::TypedAttr value); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypedAttr value); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::TypedAttr value); static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); ... ``` ``` // ArithOps.h.inc class SubIOp : public ::mlir::Op<SubIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::BytecodeOpInterface::Trait, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::InferIntRangeInterface::Trait, ::mlir::arith::ArithIntegerOverflowFlagsInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> { public: ... static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::IntegerOverflowFlagsAttr overflowFlags); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::IntegerOverflowFlagsAttr overflowFlags); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::IntegerOverflowFlagsAttr overflowFlags); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::IntegerOverflowFlags overflowFlags = ::mlir::arith::IntegerOverflowFlags::none); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::IntegerOverflowFlags overflowFlags = ::mlir::arith::IntegerOverflowFlags::none); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::IntegerOverflowFlags overflowFlags = ::mlir::arith::IntegerOverflowFlags::none); static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); ... ``` In comparison, in the distinct case of `ConstantIntOp` and `ConstantFloatOp`, the ordering of the result type and the value is switched. Thus, this PR corrects the ordering of the aforementioned `Constant[XXX]Ops` to match with other constructors.
1 parent 9579038 commit 770f96a

File tree

11 files changed

+31
-33
lines changed

11 files changed

+31
-33
lines changed

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ class ConstantIntOp : public arith::ConstantOp {
6262

6363
/// Build a constant int op that produces an integer of the specified type,
6464
/// which must be an integer type.
65-
static void build(OpBuilder &builder, OperationState &result, int64_t value,
66-
Type type);
65+
static void build(OpBuilder &builder, OperationState &result, Type type,
66+
int64_t value);
6767

6868
inline int64_t value() {
6969
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -79,8 +79,8 @@ class ConstantFloatOp : public arith::ConstantOp {
7979
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
8080

8181
/// Build a constant float op that produces a float of the specified type.
82-
static void build(OpBuilder &builder, OperationState &result,
83-
const APFloat &value, FloatType type);
82+
static void build(OpBuilder &builder, OperationState &result, FloatType type,
83+
const APFloat &value);
8484

8585
inline APFloat value() {
8686
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
244244

245245
// Clamp to the negation range.
246246
Value min = rewriter.create<arith::ConstantIntOp>(
247-
loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
248-
intermediateType);
247+
loc, intermediateType,
248+
APInt::getSignedMinValue(inputBitWidth).getSExtValue());
249249
Value max = rewriter.create<arith::ConstantIntOp>(
250-
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
251-
intermediateType);
250+
loc, intermediateType,
251+
APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
252252
auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
253253

254254
// Truncate to the final value.

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,11 +1073,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
10731073
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
10741074

10751075
auto min = rewriter.create<arith::ConstantIntOp>(
1076-
loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
1077-
accETy);
1076+
loc, accETy,
1077+
APInt::getSignedMinValue(outBitwidth).getSExtValue());
10781078
auto max = rewriter.create<arith::ConstantIntOp>(
1079-
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
1080-
accETy);
1079+
loc, accETy,
1080+
APInt::getSignedMaxValue(outBitwidth).getSExtValue());
10811081
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
10821082
/*isUnsigned=*/false);
10831083

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,7 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
257257
}
258258

259259
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
260-
int64_t value, Type type) {
261-
assert(type.isSignlessInteger() &&
262-
"ConstantIntOp can only have signless integer type values");
260+
Type type, int64_t value) {
263261
arith::ConstantOp::build(builder, result, type,
264262
builder.getIntegerAttr(type, value));
265263
}
@@ -271,7 +269,7 @@ bool arith::ConstantIntOp::classof(Operation *op) {
271269
}
272270

273271
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
274-
const APFloat &value, FloatType type) {
272+
FloatType type, const APFloat &value) {
275273
arith::ConstantOp::build(builder, result, type,
276274
builder.getFloatAttr(type, value));
277275
}
@@ -2363,7 +2361,7 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
23632361
rewriter.create<arith::XOrIOp>(
23642362
op.getLoc(), op.getCondition(),
23652363
rewriter.create<arith::ConstantIntOp>(
2366-
op.getLoc(), 1, op.getCondition().getType())));
2364+
op.getLoc(), op.getCondition().getType(), 1)));
23672365
return success();
23682366
}
23692367

mlir/lib/Dialect/Arith/Utils/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
216216
from = b.create<arith::TruncFOp>(toFpTy, from);
217217
}
218218
Value zero = b.create<mlir::arith::ConstantFloatOp>(
219-
mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
219+
toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
220220
return b.create<complex::CreateOp>(targetType, from, zero);
221221
}
222222

@@ -229,7 +229,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
229229
from = b.create<arith::SIToFPOp>(toFpTy, from);
230230
}
231231
Value zero = b.create<mlir::arith::ConstantFloatOp>(
232-
mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
232+
toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
233233
return b.create<complex::CreateOp>(targetType, from, zero);
234234
}
235235

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -820,13 +820,13 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
820820
const float initialOvershardingFactor = 8.0f;
821821

822822
Value scalingFactor = b.create<arith::ConstantFloatOp>(
823-
llvm::APFloat(initialOvershardingFactor), b.getF32Type());
823+
b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
824824
for (const std::pair<int, float> &p : overshardingBrackets) {
825825
Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
826826
Value inBracket = b.create<arith::CmpIOp>(
827827
arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
828828
Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
829-
llvm::APFloat(p.second), b.getF32Type());
829+
b.getF32Type(), llvm::APFloat(p.second));
830830
scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
831831
scalingFactor);
832832
}

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ struct GpuAllReduceRewriter {
8383

8484
// Compute lane id (invocation id withing the subgroup).
8585
Value subgroupMask =
86-
create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
86+
create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
8787
Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
8888
Value isFirstLane =
8989
create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
90-
create<arith::ConstantIntOp>(0, int32Type));
90+
create<arith::ConstantIntOp>(int32Type, 0));
9191

9292
Value numThreadsWithSmallerSubgroupId =
9393
create<arith::SubIOp>(invocationIdx, laneId);
@@ -282,7 +282,7 @@ struct GpuAllReduceRewriter {
282282
/// The first lane returns the result, all others return values are undefined.
283283
Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
284284
AccumulatorFactory &accumFactory) {
285-
Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
285+
Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
286286
Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
287287
activeWidth, subgroupSize);
288288
std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
@@ -296,7 +296,7 @@ struct GpuAllReduceRewriter {
296296
// lane is within the active range. The accumulated value is available
297297
// in the first lane.
298298
for (int i = 1; i < kSubgroupSize; i <<= 1) {
299-
Value offset = create<arith::ConstantIntOp>(i, int32Type);
299+
Value offset = create<arith::ConstantIntOp>(int32Type, i);
300300
auto shuffleOp = create<gpu::ShuffleOp>(
301301
shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
302302
// Skip the accumulation if the shuffle op read from a lane outside
@@ -318,7 +318,7 @@ struct GpuAllReduceRewriter {
318318
[&] {
319319
Value value = operand;
320320
for (int i = 1; i < kSubgroupSize; i <<= 1) {
321-
Value offset = create<arith::ConstantIntOp>(i, int32Type);
321+
Value offset = create<arith::ConstantIntOp>(int32Type, i);
322322
auto shuffleOp =
323323
create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
324324
gpu::ShuffleMode::XOR);
@@ -331,7 +331,7 @@ struct GpuAllReduceRewriter {
331331

332332
/// Returns value divided by the subgroup size (i.e. 32).
333333
Value getDivideBySubgroupSize(Value value) {
334-
Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
334+
Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
335335
return create<arith::DivSIOp>(int32Type, value, subgroupSize);
336336
}
337337

mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,13 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) {
133133
assert(elementType.isIntOrIndexOrFloat() &&
134134
"expected scalar type while computing zero value");
135135
if (isa<IntegerType>(elementType))
136-
return b.create<arith::ConstantIntOp>(loc, 0, elementType);
136+
return b.create<arith::ConstantIntOp>(loc, elementType, 0);
137137
if (elementType.isIndex())
138138
return b.create<arith::ConstantIndexOp>(loc, 0);
139139
// Assume float.
140140
auto floatType = cast<FloatType>(elementType);
141141
return b.create<arith::ConstantFloatOp>(
142-
loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
142+
loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
143143
}
144144

145145
GenericOp

mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,9 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
315315
auto inputType = input.getType();
316316
auto storageType = quantizedType.getStorageType();
317317
auto storageMinScalar = builder.create<arith::ConstantIntOp>(
318-
loc, quantizedType.getStorageTypeMin(), storageType);
318+
loc, storageType, quantizedType.getStorageTypeMin());
319319
auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
320-
loc, quantizedType.getStorageTypeMax(), storageType);
320+
loc, storageType, quantizedType.getStorageTypeMax());
321321
auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
322322
inputType, inputShape);
323323
auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,

mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
141141
b.setInsertionPointToStart(innerLoop.getBody());
142142
// Insert in-bound check
143143
Value inbound =
144-
b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
144+
b.create<arith::ConstantIntOp>(op.getLoc(), b.getIntegerType(1), 1);
145145
for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
146146
llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
147147
innerLoop.getInductionVars(), innerLoop.getStep())) {

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
240240
if (isa<IndexType>(step.getType())) {
241241
one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
242242
} else {
243-
one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
243+
one = rewriter.create<arith::ConstantIntOp>(loc, step.getType(), 1);
244244
}
245245

246246
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);

0 commit comments

Comments
 (0)