Skip to content

Commit 6d085d7

Browse files
author
Simon Camphausen
committed
Add lvalue type
1 parent 54b20cb commit 6d085d7

File tree

24 files changed

+672
-325
lines changed

24 files changed

+672
-325
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
9797
}];
9898
let arguments = (ins
9999
Arg<StrAttr, "the operator to apply">:$applicableOperator,
100-
EmitCType:$operand
100+
AnyTypeOf<[EmitCType, EmitC_LValueType]>:$operand
101101
);
102-
let results = (outs EmitCType:$result);
102+
let results = (outs AnyTypeOf<[EmitCType, EmitC_LValueType]>:$result);
103103
let assemblyFormat = [{
104104
$applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
105105
}];
@@ -835,6 +835,20 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
835835
let assemblyFormat = "operands attr-dict `:` type(operands)";
836836
}
837837

838+
def EmitC_LValueLoadOp : EmitC_Op<"lvalue_load", [
839+
TypesMatchWith<"result type matches value type of 'operand'",
840+
"operand", "result",
841+
"::llvm::cast<LValueType>($_self).getValue()">
842+
]> {
843+
let summary = "load an lvalue by assigning it to a local variable";
844+
let description = [{}];
845+
846+
let arguments = (ins EmitC_LValueType:$operand);
847+
let results = (outs AnyType:$result);
848+
849+
let assemblyFormat = "$operand attr-dict `:` type($operand)";
850+
}
851+
838852
def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
839853
let summary = "Multiplication operation";
840854
let description = [{
@@ -1009,7 +1023,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
10091023
}];
10101024

10111025
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
1012-
let results = (outs EmitCType);
1026+
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>);
10131027

10141028
let hasVerifier = 1;
10151029
}
@@ -1079,7 +1093,7 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global",
10791093
}];
10801094

10811095
let arguments = (ins FlatSymbolRefAttr:$name);
1082-
let results = (outs EmitCType:$result);
1096+
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
10831097
let assemblyFormat = "$name `:` type($result) attr-dict";
10841098
}
10851099

@@ -1137,7 +1151,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
11371151
```
11381152
}];
11391153

1140-
let arguments = (ins EmitCType:$var, EmitCType:$value);
1154+
let arguments = (ins EmitC_LValueType:$var, EmitCType:$value);
11411155
let results = (outs);
11421156

11431157
let hasVerifier = 1;
@@ -1243,15 +1257,26 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
12431257
EmitC_PointerType]>,
12441258
"the value to subscript">:$value,
12451259
Variadic<EmitCType>:$indices);
1246-
let results = (outs EmitCType:$result);
1260+
let results = (outs EmitC_LValueType:$result);
12471261

12481262
let builders = [
12491263
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
1250-
build($_builder, $_state, array.getType().getElementType(), array, indices);
1264+
build(
1265+
$_builder,
1266+
$_state,
1267+
emitc::LValueType::get(array.getType().getElementType()),
1268+
array,
1269+
indices
1270+
);
12511271
}]>,
12521272
OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
1253-
build($_builder, $_state, pointer.getType().getPointee(), pointer,
1254-
ValueRange{index});
1273+
build(
1274+
$_builder,
1275+
$_state,
1276+
emitc::LValueType::get(pointer.getType().getPointee()),
1277+
pointer,
1278+
ValueRange{index}
1279+
);
12551280
}]>
12561281
];
12571282

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,23 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
8383
let hasCustomAssemblyFormat = 1;
8484
}
8585

86+
def EmitC_LValueType : EmitC_Type<"LValue", "lvalue"> {
87+
let summary = "EmitC lvalue type";
88+
89+
let description = [{
90+
Values of this type can be assigned to and their address can be taken.
91+
}];
92+
93+
let parameters = (ins "Type":$value);
94+
let builders = [
95+
TypeBuilderWithInferredContext<(ins "Type":$value), [{
96+
return $_get(value.getContext(), value);
97+
}]>
98+
];
99+
let assemblyFormat = "`<` qualified($value) `>`";
100+
let genVerifyDecl = 1;
101+
}
102+
86103
def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
87104
let summary = "EmitC opaque type";
88105

@@ -128,6 +145,7 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
128145
}]>
129146
];
130147
let assemblyFormat = "`<` qualified($pointee) `>`";
148+
let genVerifyDecl = 1;
131149
}
132150

133151
#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,7 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
137137
auto subscript = rewriter.create<emitc::SubscriptOp>(
138138
op.getLoc(), arrayValue, operands.getIndices());
139139

140-
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
141-
auto var =
142-
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
143-
144-
rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
145-
rewriter.replaceOp(op, var);
140+
rewriter.replaceOpWithNewOp<emitc::LValueLoadOp>(op, resultTy, subscript);
146141
return success();
147142
}
148143
};

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ static SmallVector<Value> createVariablesForResults(T op,
6363

6464
for (OpResult result : op.getResults()) {
6565
Type resultType = result.getType();
66+
Type varType = emitc::LValueType::get(resultType);
6667
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
6768
emitc::VariableOp var =
68-
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
69+
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
6970
resultVariables.push_back(var);
7071
}
7172

@@ -100,8 +101,6 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
100101

101102
// Create an emitc::variable op for each result. These variables will be
102103
// assigned to by emitc::assign ops within the loop body.
103-
SmallVector<Value> resultVariables =
104-
createVariablesForResults(forOp, rewriter);
105104
SmallVector<Value> iterArgsVariables =
106105
createVariablesForResults(forOp, rewriter);
107106

@@ -115,18 +114,36 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
115114
// Erase the auto-generated terminator for the lowered for op.
116115
rewriter.eraseOp(loweredBody->getTerminator());
117116

117+
SmallVector<Value> iterArgsValues;
118+
{
119+
PatternRewriter::InsertionGuard guard(rewriter);
120+
rewriter.setInsertionPointToEnd(loweredBody);
121+
122+
for (auto &arg : iterArgsVariables) {
123+
Type type = cast<emitc::LValueType>(arg.getType()).getValue();
124+
iterArgsValues.push_back(
125+
rewriter.create<emitc::LValueLoadOp>(loc, type, arg));
126+
}
127+
}
128+
118129
SmallVector<Value> replacingValues;
119130
replacingValues.push_back(loweredFor.getInductionVar());
120-
replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
131+
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
121132

122133
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
123134
lowerYield(iterArgsVariables, rewriter,
124135
cast<scf::YieldOp>(loweredBody->getTerminator()));
125136

126137
// Copy iterArgs into results after the for loop.
127-
assignValues(iterArgsVariables, resultVariables, rewriter, loc);
138+
SmallVector<Value> resultValues;
128139

129-
rewriter.replaceOp(forOp, resultVariables);
140+
for (auto &arg : iterArgsVariables) {
141+
Type type = cast<emitc::LValueType>(arg.getType()).getValue();
142+
resultValues.push_back(
143+
rewriter.create<emitc::LValueLoadOp>(loc, type, arg));
144+
}
145+
146+
rewriter.replaceOp(forOp, resultValues);
130147
return success();
131148
}
132149

@@ -178,7 +195,14 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
178195
lowerRegion(elseRegion, loweredElseRegion);
179196
}
180197

181-
rewriter.replaceOp(ifOp, resultVariables);
198+
rewriter.setInsertionPointAfter(ifOp);
199+
SmallVector<Value> results;
200+
for (auto &resVar : resultVariables) {
201+
Type type = cast<emitc::LValueType>(resVar.getType()).getValue();
202+
results.push_back(rewriter.create<emitc::LValueLoadOp>(loc, type, resVar));
203+
}
204+
205+
rewriter.replaceOp(ifOp, results);
182206
return success();
183207
}
184208

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 83 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
6161
bool mlir::emitc::isSupportedEmitCType(Type type) {
6262
if (llvm::isa<emitc::OpaqueType>(type))
6363
return true;
64+
if (auto lType = llvm::dyn_cast<emitc::LValueType>(type))
65+
// lvalue types are only allowed in a few places.
66+
return false;
6467
if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
6568
return isSupportedEmitCType(ptrType.getPointee());
6669
if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
@@ -140,6 +143,8 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
140143
<< "string attributes are not supported, use #emitc.opaque instead";
141144

142145
Type resultType = op->getResult(0).getType();
146+
if (auto lType = dyn_cast<LValueType>(resultType))
147+
resultType = lType.getValue();
143148
Type attrType = cast<TypedAttr>(value).getType();
144149

145150
if (resultType != attrType)
@@ -188,9 +193,19 @@ LogicalResult ApplyOp::verify() {
188193
if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
189194
return emitOpError("applicable operator is illegal");
190195

191-
Operation *op = getOperand().getDefiningOp();
192-
if (op && dyn_cast<ConstantOp>(op))
193-
return emitOpError("cannot apply to constant");
196+
Type operandType = getOperand().getType();
197+
Type resultType = getResult().getType();
198+
if (applicableOperatorStr == "&") {
199+
if (!llvm::isa<emitc::LValueType>(operandType))
200+
return emitOpError("operand type must be an lvalue when applying `&`");
201+
if (!llvm::isa<emitc::PointerType>(resultType))
202+
return emitOpError("result type must be a pointer when applying `&`");
203+
} else {
204+
if (!llvm::isa<emitc::PointerType>(operandType))
205+
return emitOpError("operand type must be a pointer when applying `*`");
206+
if (!llvm::isa<emitc::LValueType>(resultType))
207+
return emitOpError("result type must be an lvalue when applying `*`");
208+
}
194209

195210
return success();
196211
}
@@ -202,20 +217,18 @@ LogicalResult ApplyOp::verify() {
202217
/// The assign op requires that the assigned value's type matches the
203218
/// assigned-to variable type.
204219
LogicalResult emitc::AssignOp::verify() {
205-
Value variable = getVar();
206-
Operation *variableDef = variable.getDefiningOp();
207-
if (!variableDef ||
208-
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
209-
return emitOpError() << "requires first operand (" << variable
210-
<< ") to be a Variable or subscript";
211-
212-
Value value = getValue();
213-
if (variable.getType() != value.getType())
214-
return emitOpError() << "requires value's type (" << value.getType()
215-
<< ") to match variable's type (" << variable.getType()
216-
<< ")";
217-
if (isa<ArrayType>(variable.getType()))
218-
return emitOpError() << "cannot assign to array type";
220+
TypedValue<emitc::LValueType> variable = getVar();
221+
222+
if (!variable.getDefiningOp())
223+
return emitOpError() << "cannot assign to block argument";
224+
225+
Type valueType = getValue().getType();
226+
Type variableType = variable.getType().getValue();
227+
if (variableType != valueType)
228+
return emitOpError() << "requires value's type (" << valueType
229+
<< ") to match variable's type (" << variableType
230+
<< ")\n variable: " << variable
231+
<< "\n value: " << getValue() << "\n";
219232
return success();
220233
}
221234

@@ -842,9 +855,10 @@ LogicalResult emitc::SubscriptOp::verify() {
842855
}
843856
// Check element type.
844857
Type elementType = arrayType.getElementType();
845-
if (elementType != getType()) {
858+
Type resultType = getType().getValue();
859+
if (elementType != resultType) {
846860
return emitOpError() << "on array operand requires element type ("
847-
<< elementType << ") and result type (" << getType()
861+
<< elementType << ") and result type (" << resultType
848862
<< ") to match";
849863
}
850864
return success();
@@ -868,9 +882,10 @@ LogicalResult emitc::SubscriptOp::verify() {
868882
}
869883
// Check pointee type.
870884
Type pointeeType = pointerType.getPointee();
871-
if (pointeeType != getType()) {
885+
Type resultType = getType().getValue();
886+
if (pointeeType != resultType) {
872887
return emitOpError() << "on pointer operand requires pointee type ("
873-
<< pointeeType << ") and result type (" << getType()
888+
<< pointeeType << ") and result type (" << resultType
874889
<< ") to match";
875890
}
876891
return success();
@@ -964,6 +979,25 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
964979
return emitc::ArrayType::get(*shape, elementType);
965980
}
966981

982+
//===----------------------------------------------------------------------===//
983+
// LValueType
984+
//===----------------------------------------------------------------------===//
985+
986+
LogicalResult mlir::emitc::LValueType::verify(
987+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
988+
mlir::Type value) {
989+
// Check that the wrapped type is valid. This especially forbids nested lvalue
990+
// types.
991+
if (!isSupportedEmitCType(value))
992+
return emitError()
993+
<< "!emitc.lvalue must wrap supported emitc type, but got " << value;
994+
995+
if (llvm::isa<emitc::ArrayType>(value))
996+
return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
997+
998+
return success();
999+
}
1000+
9671001
//===----------------------------------------------------------------------===//
9681002
// OpaqueType
9691003
//===----------------------------------------------------------------------===//
@@ -981,6 +1015,18 @@ LogicalResult mlir::emitc::OpaqueType::verify(
9811015
return success();
9821016
}
9831017

1018+
//===----------------------------------------------------------------------===//
1019+
// PointerType
1020+
//===----------------------------------------------------------------------===//
1021+
1022+
LogicalResult mlir::emitc::PointerType::verify(
1023+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError, Type value) {
1024+
if (llvm::isa<emitc::LValueType>(value))
1025+
return emitError() << "pointers to lvalues are not allowed";
1026+
1027+
return success();
1028+
}
1029+
9841030
//===----------------------------------------------------------------------===//
9851031
// GlobalOp
9861032
//===----------------------------------------------------------------------===//
@@ -1078,9 +1124,22 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10781124
<< getName() << "' does not reference a valid emitc.global";
10791125

10801126
Type resultType = getResult().getType();
1081-
if (global.getType() != resultType)
1082-
return emitOpError("result type ")
1083-
<< resultType << " does not match type " << global.getType()
1127+
Type globalType = global.getType();
1128+
1129+
// global has array type
1130+
if (llvm::isa<ArrayType>(globalType)) {
1131+
if (globalType != resultType)
1132+
return emitOpError("on array type expects result type ")
1133+
<< resultType << " to match type " << globalType
1134+
<< " of the global @" << getName();
1135+
return success();
1136+
}
1137+
1138+
// global has non-array type
1139+
auto lvalueType = dyn_cast<LValueType>(resultType);
1140+
if (!lvalueType || lvalueType.getValue() != globalType)
1141+
return emitOpError("on non-array type expects result inner type ")
1142+
<< lvalueType.getValue() << " to match type " << globalType
10841143
<< " of the global @" << getName();
10851144
return success();
10861145
}

mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ struct FormExpressionsPass
3838
auto matchFun = [&](Operation *op) {
3939
if (op->hasTrait<OpTrait::emitc::CExpression>() &&
4040
!op->getParentOfType<emitc::ExpressionOp>() &&
41-
op->getNumResults() == 1)
41+
op->getNumResults() == 1 &&
42+
isSupportedEmitCType(op->getResult(0).getType()))
4243
createExpression(op, builder);
4344
};
4445
rootOp->walk(matchFun);

0 commit comments

Comments
 (0)