Skip to content

Commit 2199c2f

Browse files
committed
[mlir][emitc] Strengthen type and rank checks
1 parent 3852cc2 commit 2199c2f

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,19 @@ LogicalResult emitc::AssignOp::verify() {
247247
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
248248
Type input = inputs.front(), output = outputs.front();
249249

250-
return ((emitc::isIntegerIndexOrOpaqueType(input) ||
251-
emitc::isSupportedFloatType(input) ||
252-
isa<emitc::PointerType>(input) || isa<emitc::ArrayType>(input)) &&
253-
(emitc::isIntegerIndexOrOpaqueType(output) ||
254-
emitc::isSupportedFloatType(output) ||
255-
isa<emitc::PointerType>(output)));
250+
if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
251+
if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
252+
return (arrayType.getElementType() == pointerType.getPointee()) &&
253+
arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
254+
}
255+
return false;
256+
}
257+
258+
return (
259+
(emitc::isIntegerIndexOrOpaqueType(input) ||
260+
emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
261+
(emitc::isIntegerIndexOrOpaqueType(output) ||
262+
emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
256263
}
257264

258265
//===----------------------------------------------------------------------===//
@@ -700,9 +707,9 @@ void IfOp::print(OpAsmPrinter &p) {
700707

701708
/// Given the region at `index`, or the parent operation if `index` is None,
702709
/// return the successor regions. These are the regions that may be selected
703-
/// during the flow of control. `operands` is a set of optional attributes that
704-
/// correspond to a constant value for each operand, or null if that operand is
705-
/// not a constant.
710+
/// during the flow of control. `operands` is a set of optional attributes
711+
/// that correspond to a constant value for each operand, or null if that
712+
/// operand is not a constant.
706713
void IfOp::getSuccessorRegions(RegionBranchPoint point,
707714
SmallVectorImpl<RegionSuccessor> &regions) {
708715
// The `then` and the `else` region branch back to the parent operation.
@@ -1000,8 +1007,8 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
10001007
LogicalResult mlir::emitc::LValueType::verify(
10011008
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
10021009
mlir::Type value) {
1003-
// Check that the wrapped type is valid. This especially forbids nested lvalue
1004-
// types.
1010+
// Check that the wrapped type is valid. This especially forbids nested
1011+
// lvalue types.
10051012
if (!isSupportedEmitCType(value))
10061013
return emitError()
10071014
<< "!emitc.lvalue must wrap supported emitc type, but got " << value;

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,30 @@ func.func @cast_to_array(%arg : f32) {
138138

139139
// -----
140140

141+
func.func @cast_multidimensional_array(%arg : !emitc.array<1x2xi32>) {
142+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<1x2xi32>' and result type '!emitc.ptr<i32>' are cast incompatible}}
143+
%1 = emitc.cast %arg: !emitc.array<1x2xi32> to !emitc.ptr<i32>
144+
return
145+
}
146+
147+
// -----
148+
149+
func.func @cast_array_zero_rank(%arg : !emitc.array<0xi32>) {
150+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<0xi32>' and result type '!emitc.ptr<i32>' are cast incompatible}}
151+
%1 = emitc.cast %arg: !emitc.array<0xi32> to !emitc.ptr<i32>
152+
return
153+
}
154+
155+
// -----
156+
157+
func.func @cast_array_to_pointer_types_mismatch(%arg : !emitc.array<3xi32>) {
158+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<3xi32>' and result type '!emitc.ptr<f16>' are cast incompatible}}
159+
%1 = emitc.cast %arg: !emitc.array<3xi32> to !emitc.ptr<f16>
160+
return
161+
}
162+
163+
// -----
164+
141165
func.func @cast_pointer_to_array(%arg : !emitc.ptr<i32>) {
142166
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr<i32>' and result type '!emitc.array<3xi32>' are cast incompatible}}
143167
%1 = emitc.cast %arg: !emitc.ptr<i32> to !emitc.array<3xi32>

0 commit comments

Comments
 (0)