Skip to content

Commit 3852cc2

Browse files
committed
[mlir][emitc] Add an option to cast array type to ptr type
1 parent 564b9b7 commit 3852cc2

File tree

4 files changed

+23
-10
lines changed

4 files changed

+23
-10
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
266266

267267
def EmitC_CastOp : EmitC_Op<"cast",
268268
[CExpression,
269-
DeclareOpInterfaceMethods<CastOpInterface>,
270-
SameOperandsAndResultShape]> {
269+
DeclareOpInterfaceMethods<CastOpInterface>]> {
271270
let summary = "Cast operation";
272271
let description = [{
273272
The `emitc.cast` operation performs an explicit type conversion and is emitted

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,12 @@ LogicalResult emitc::AssignOp::verify() {
247247
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
248248
Type input = inputs.front(), output = outputs.front();
249249

250-
return (
251-
(emitc::isIntegerIndexOrOpaqueType(input) ||
252-
emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
253-
(emitc::isIntegerIndexOrOpaqueType(output) ||
254-
emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
250+
return ((emitc::isIntegerIndexOrOpaqueType(input) ||
251+
emitc::isSupportedFloatType(input) ||
252+
isa<emitc::PointerType>(input) || isa<emitc::ArrayType>(input)) &&
253+
(emitc::isIntegerIndexOrOpaqueType(output) ||
254+
emitc::isSupportedFloatType(output) ||
255+
isa<emitc::PointerType>(output)));
255256
}
256257

257258
//===----------------------------------------------------------------------===//

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,17 @@ func.func @cast_tensor(%arg : tensor<f32>) {
130130

131131
// -----
132132

133-
func.func @cast_array(%arg : !emitc.array<4xf32>) {
134-
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
135-
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
133+
func.func @cast_to_array(%arg : f32) {
134+
// expected-error @+1 {{'emitc.cast' op operand type 'f32' and result type '!emitc.array<4xf32>' are cast incompatible}}
135+
%1 = emitc.cast %arg: f32 to !emitc.array<4xf32>
136+
return
137+
}
138+
139+
// -----
140+
141+
func.func @cast_pointer_to_array(%arg : !emitc.ptr<i32>) {
142+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr<i32>' and result type '!emitc.array<3xi32>' are cast incompatible}}
143+
%1 = emitc.cast %arg: !emitc.ptr<i32> to !emitc.array<3xi32>
136144
return
137145
}
138146

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ func.func @cast(%arg0: i32) {
3939
return
4040
}
4141

42+
func.func @cast_array_to_pointer(%arg0: !emitc.array<3xi32>) {
43+
%1 = emitc.cast %arg0: !emitc.array<3xi32> to !emitc.ptr<i32>
44+
return
45+
}
46+
4247
func.func @c() {
4348
%1 = "emitc.constant"(){value = 42 : i32} : () -> i32
4449
%2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t

0 commit comments

Comments
 (0)