Skip to content

Commit 1aadeb3

Browse files
author
Simon Camphausen
committed
[WIP][mlir][EmitC] Model lvalues as a type in EmitC
1 parent de18f5e commit 1aadeb3

File tree

17 files changed

+390
-161
lines changed

17 files changed

+390
-161
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,22 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
835835
let assemblyFormat = "operands attr-dict `:` type(operands)";
836836
}
837837

838+
def EmitC_LValueToRValueOp : EmitC_Op<"lvalue_to_rvalue", [
839+
TypesMatchWith<"result type matches value type of 'operand'",
840+
"operand", "result",
841+
"::llvm::cast<LValueType>($_self).getValue()">
842+
]> {
843+
let summary = "lvalue to rvalue conversion operation";
844+
let description = [{}];
845+
846+
let arguments = (ins EmitC_LValueType:$operand);
847+
let results = (outs AnyType:$result);
848+
849+
let assemblyFormat = "$operand attr-dict `:` type($operand)";
850+
851+
let hasVerifier = 1;
852+
}
853+
838854
def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
839855
let summary = "Multiplication operation";
840856
let description = [{
@@ -1009,7 +1025,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
10091025
}];
10101026

10111027
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
1012-
let results = (outs EmitCType);
1028+
let results = (outs EmitC_LValueType);
10131029

10141030
let hasVerifier = 1;
10151031
}
@@ -1137,7 +1153,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
11371153
```
11381154
}];
11391155

1140-
let arguments = (ins EmitCType:$var, EmitCType:$value);
1156+
let arguments = (ins EmitC_LValueType:$var, EmitCType:$value);
11411157
let results = (outs);
11421158

11431159
let hasVerifier = 1;
@@ -1243,7 +1259,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
12431259
EmitC_PointerType]>,
12441260
"the value to subscript">:$value,
12451261
Variadic<EmitCType>:$indices);
1246-
let results = (outs EmitCType:$result);
1262+
let results = (outs EmitC_LValueType:$result);
12471263

12481264
let builders = [
12491265
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,23 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
8383
let hasCustomAssemblyFormat = 1;
8484
}
8585

86+
def EmitC_LValueType : EmitC_Type<"LValue", "lvalue"> {
87+
let summary = "EmitC lvalue type";
88+
89+
let description = [{
90+
Values of this type can be assigned to and their address can be taken.
91+
}];
92+
93+
let parameters = (ins "Type":$value);
94+
let builders = [
95+
TypeBuilderWithInferredContext<(ins "Type":$value), [{
96+
return $_get(value.getContext(), value);
97+
}]>
98+
];
99+
let assemblyFormat = "`<` qualified($value) `>`";
100+
let genVerifyDecl = 1;
101+
}
102+
86103
def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
87104
let summary = "EmitC opaque type";
88105

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 75 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ static SmallVector<Value> createVariablesForResults(T op,
6363

6464
for (OpResult result : op.getResults()) {
6565
Type resultType = result.getType();
66+
Type varType = emitc::LValueType::get(resultType);
6667
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
6768
emitc::VariableOp var =
68-
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
69+
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
6970
resultVariables.push_back(var);
7071
}
7172

@@ -76,57 +77,98 @@ static SmallVector<Value> createVariablesForResults(T op,
7677
// the current insertion point of given rewriter.
7778
static void assignValues(ValueRange values, SmallVector<Value> &variables,
7879
PatternRewriter &rewriter, Location loc) {
79-
for (auto [value, var] : llvm::zip(values, variables))
80-
rewriter.create<emitc::AssignOp>(loc, var, value);
80+
for (auto [value, var] : llvm::zip(values, variables)) {
81+
assert(isa<emitc::LValueType>(var.getType()) &&
82+
"expected var to be an lvalue type");
83+
assert(!isa<emitc::LValueType>(value.getType()) &&
84+
"expected value to not be an lvalue type");
85+
auto assign = rewriter.create<emitc::AssignOp>(loc, var, value);
86+
87+
// TODO: Make sure this is safe, as this moves operations with memory
88+
// effects.
89+
if (auto op = dyn_cast_if_present<emitc::LValueToRValueOp>(
90+
value.getDefiningOp())) {
91+
rewriter.moveOpBefore(op, assign);
92+
}
93+
}
8194
}
8295

83-
static void lowerYield(SmallVector<Value> &resultVariables,
84-
PatternRewriter &rewriter, scf::YieldOp yield) {
96+
static void lowerYield(SmallVector<Value> &variables, PatternRewriter &rewriter,
97+
scf::YieldOp yield) {
8598
Location loc = yield.getLoc();
8699
ValueRange operands = yield.getOperands();
87100

88101
OpBuilder::InsertionGuard guard(rewriter);
89102
rewriter.setInsertionPoint(yield);
90103

91-
assignValues(operands, resultVariables, rewriter, loc);
104+
assignValues(operands, variables, rewriter, loc);
92105

93106
rewriter.create<emitc::YieldOp>(loc);
94107
rewriter.eraseOp(yield);
95108
}
96109

110+
static void replaceUsers(PatternRewriter &rewriter,
111+
SmallVector<Value> fromValues,
112+
SmallVector<Value> toValues) {
113+
OpBuilder::InsertionGuard guard(rewriter);
114+
for (auto [from, to] : llvm::zip(fromValues, toValues)) {
115+
assert(from.getType() == cast<emitc::LValueType>(to.getType()).getValue() &&
116+
"expected types to match");
117+
118+
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
119+
Operation *op = operand.getOwner();
120+
// Skip yield ops, as these get rewritten anyways.
121+
if (isa<scf::YieldOp>(op)) {
122+
continue;
123+
}
124+
Location loc = op->getLoc();
125+
126+
rewriter.setInsertionPoint(op);
127+
Value rValue =
128+
rewriter.create<emitc::LValueToRValueOp>(loc, from.getType(), to);
129+
operand.set(rValue);
130+
}
131+
}
132+
}
133+
97134
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
98135
PatternRewriter &rewriter) const {
99136
Location loc = forOp.getLoc();
100137

101-
// Create an emitc::variable op for each result. These variables will be
102-
// assigned to by emitc::assign ops within the loop body.
103-
SmallVector<Value> resultVariables =
104-
createVariablesForResults(forOp, rewriter);
105-
SmallVector<Value> iterArgsVariables =
106-
createVariablesForResults(forOp, rewriter);
138+
// Create an emitc::variable op for each result. These variables will be used
139+
// for the results of the operations as well as the iter_args. They are
140+
// assigned to by emitc::assign ops before the loop and at the end of the loop
141+
// body.
142+
SmallVector<Value> variables = createVariablesForResults(forOp, rewriter);
107143

108-
assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
144+
// Assign initial values to the iter arg variables.
145+
assignValues(forOp.getInits(), variables, rewriter, loc);
109146

110-
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
111-
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
147+
// Replace users of the iter args with variables.
148+
SmallVector<Value> iterArgs;
149+
for (BlockArgument arg : forOp.getRegionIterArgs()) {
150+
iterArgs.push_back(arg);
151+
}
112152

113-
Block *loweredBody = loweredFor.getBody();
153+
replaceUsers(rewriter, iterArgs, variables);
114154

115-
// Erase the auto-generated terminator for the lowered for op.
116-
rewriter.eraseOp(loweredBody->getTerminator());
155+
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
156+
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
157+
rewriter.eraseBlock(loweredFor.getBody());
117158

118-
SmallVector<Value> replacingValues;
119-
replacingValues.push_back(loweredFor.getInductionVar());
120-
replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
159+
rewriter.inlineRegionBefore(forOp.getRegion(), loweredFor.getRegion(),
160+
loweredFor.getRegion().end());
161+
Operation *terminator = loweredFor.getRegion().back().getTerminator();
162+
lowerYield(variables, rewriter, cast<scf::YieldOp>(terminator));
121163

122-
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
123-
lowerYield(iterArgsVariables, rewriter,
124-
cast<scf::YieldOp>(loweredBody->getTerminator()));
164+
// Erase block arguments for iter_args.
165+
loweredFor.getRegion().back().eraseArguments(1, variables.size());
125166

126-
// Copy iterArgs into results after the for loop.
127-
assignValues(iterArgsVariables, resultVariables, rewriter, loc);
167+
// Replace all users of the results with lazily created lvalue-to-rvalue
168+
// ops.
169+
replaceUsers(rewriter, forOp.getResults(), variables);
128170

129-
rewriter.replaceOp(forOp, resultVariables);
171+
rewriter.eraseOp(forOp);
130172
return success();
131173
}
132174

@@ -167,7 +209,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
167209

168210
bool hasElseBlock = !elseRegion.empty();
169211

170-
auto loweredIf =
212+
emitc::IfOp loweredIf =
171213
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
172214

173215
Region &loweredThenRegion = loweredIf.getThenRegion();
@@ -178,7 +220,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
178220
lowerRegion(elseRegion, loweredElseRegion);
179221
}
180222

181-
rewriter.replaceOp(ifOp, resultVariables);
223+
// Replace all users of the results with lazily created lvalue-to-rvalue
224+
// ops.
225+
replaceUsers(rewriter, ifOp.getResults(), resultVariables);
226+
227+
rewriter.eraseOp(ifOp);
182228
return success();
183229
}
184230

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
140140
<< "string attributes are not supported, use #emitc.opaque instead";
141141

142142
Type resultType = op->getResult(0).getType();
143+
if (auto lType = dyn_cast<LValueType>(resultType))
144+
resultType = lType.getValue();
143145
Type attrType = cast<TypedAttr>(value).getType();
144146

145147
if (resultType != attrType)
@@ -203,18 +205,21 @@ LogicalResult ApplyOp::verify() {
203205
/// assigned-to variable type.
204206
LogicalResult emitc::AssignOp::verify() {
205207
Value variable = getVar();
206-
Operation *variableDef = variable.getDefiningOp();
207-
if (!variableDef ||
208-
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
208+
209+
if (!variable.getDefiningOp())
210+
return emitOpError() << "cannot assign to block argument";
211+
if (!llvm::isa<emitc::LValueType>(variable.getType()))
209212
return emitOpError() << "requires first operand (" << variable
210-
<< ") to be a Variable or subscript";
211-
212-
Value value = getValue();
213-
if (variable.getType() != value.getType())
214-
return emitOpError() << "requires value's type (" << value.getType()
215-
<< ") to match variable's type (" << variable.getType()
216-
<< ")";
217-
if (isa<ArrayType>(variable.getType()))
213+
<< ") to be an lvalue";
214+
215+
Type valueType = getValue().getType();
216+
Type variableType = variable.getType().cast<emitc::LValueType>().getValue();
217+
if (variableType != valueType)
218+
return emitOpError() << "requires value's type (" << valueType
219+
<< ") to match variable's type (" << variableType
220+
<< ")\n variable: " << variable
221+
<< "\n value: " << getValue() << "\n";
222+
if (isa<ArrayType>(variableType))
218223
return emitOpError() << "cannot assign to array type";
219224
return success();
220225
}
@@ -769,6 +774,47 @@ LogicalResult emitc::LiteralOp::verify() {
769774
return emitOpError() << "value must not be empty";
770775
return success();
771776
}
777+
778+
//===----------------------------------------------------------------------===//
779+
// LValueToRValueOp
780+
//===----------------------------------------------------------------------===//
781+
782+
LogicalResult emitc::LValueToRValueOp::verify() {
783+
Type operandType = getOperand().getType();
784+
Type resultType = getResult().getType();
785+
if (!llvm::isa<emitc::LValueType>(operandType))
786+
return emitOpError("operand must be a lvalue");
787+
if (llvm::cast<emitc::LValueType>(operandType).getValue() != resultType)
788+
return emitOpError("types must match");
789+
790+
Value result = getResult();
791+
if (!result.hasOneUse()) {
792+
int numUses = std::distance(result.use_begin(), result.use_end());
793+
return emitOpError("must have exactly one use, but got ") << numUses;
794+
}
795+
796+
Block *block = result.getParentBlock();
797+
798+
Operation *user = *result.getUsers().begin();
799+
Block *userBlock = user->getBlock();
800+
801+
if (block != userBlock) {
802+
return emitOpError("user must be in the same block");
803+
}
804+
805+
// for (auto it = block.begin(), e = std::prev(block.end()); it != e; it++) {
806+
// if (*it == this)
807+
// }
808+
809+
// TODO: To model this op correctly as a memory read of the lvalue, we
810+
// should additionally ensure that the single use of the op follows immediatly
811+
// on this definition. Alternativly we could alter emitc ops to implicitly
812+
// support lvalues. This would make it harder to do partial conversions and
813+
// mix dialects though.
814+
815+
return success();
816+
}
817+
772818
//===----------------------------------------------------------------------===//
773819
// SubOp
774820
//===----------------------------------------------------------------------===//
@@ -964,6 +1010,20 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
9641010
return emitc::ArrayType::get(*shape, elementType);
9651011
}
9661012

1013+
//===----------------------------------------------------------------------===//
1014+
// LValueType
1015+
//===----------------------------------------------------------------------===//
1016+
1017+
LogicalResult mlir::emitc::LValueType::verify(
1018+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1019+
mlir::Type value) {
1020+
if (llvm::isa<emitc::LValueType>(value)) {
1021+
return emitError()
1022+
<< "!emitc.lvalue type cannot be nested inside another type";
1023+
}
1024+
return success();
1025+
}
1026+
9671027
//===----------------------------------------------------------------------===//
9681028
// OpaqueType
9691029
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)