Skip to content

Commit 9d273c0

Browse files
committed
[mlir] Harden verifiers for DMA ops
DMA operation classes in the Standard dialect (`DmaStartOp` and `DmaWaitOp`) provide helper functions that make numerous assumptions about the number and order of operands, and about their types. However, these assumptions were not checked in the verifier, leading to assertion failures or crashes when helper functions were used on ill-formed ops. Some of the assuptions were checked in the custom parser (and thus could not check assumption violations in ops constructed programmatically, e.g., during rewrites) and others were not checked at all. Introduce the verifiers for all these assumptions and drop unnecessary checks in the parser that are now covered by the verifier. Addresses PR45560. Differential Revision: https://reviews.llvm.org/D79408
1 parent 0195b3a commit 9d273c0

File tree

3 files changed

+214
-37
lines changed

3 files changed

+214
-37
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ class DmaWaitOp
286286
void print(OpAsmPrinter &p);
287287
LogicalResult fold(ArrayRef<Attribute> cstOperands,
288288
SmallVectorImpl<OpFoldResult> &results);
289+
LogicalResult verify();
289290
};
290291

291292
/// Prints dimension and symbol list.

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 86 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,49 +1444,82 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
14441444
parser.resolveOperands(tagIndexInfos, indexType, result.operands))
14451445
return failure();
14461446

1447-
auto memrefType0 = types[0].dyn_cast<MemRefType>();
1448-
if (!memrefType0)
1449-
return parser.emitError(parser.getNameLoc(),
1450-
"expected source to be of memref type");
1451-
1452-
auto memrefType1 = types[1].dyn_cast<MemRefType>();
1453-
if (!memrefType1)
1454-
return parser.emitError(parser.getNameLoc(),
1455-
"expected destination to be of memref type");
1456-
1457-
auto memrefType2 = types[2].dyn_cast<MemRefType>();
1458-
if (!memrefType2)
1459-
return parser.emitError(parser.getNameLoc(),
1460-
"expected tag to be of memref type");
1461-
14621447
if (isStrided) {
14631448
if (parser.resolveOperands(strideInfo, indexType, result.operands))
14641449
return failure();
14651450
}
14661451

1467-
// Check that source/destination index list size matches associated rank.
1468-
if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
1469-
static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
1470-
return parser.emitError(parser.getNameLoc(),
1471-
"memref rank not equal to indices count");
1472-
if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
1473-
return parser.emitError(parser.getNameLoc(),
1474-
"tag memref rank not equal to indices count");
14751452

14761453
return success();
14771454
}
14781455

14791456
LogicalResult DmaStartOp::verify() {
1457+
unsigned numOperands = getNumOperands();
1458+
1459+
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1460+
// the number of elements.
1461+
if (numOperands < 4)
1462+
return emitOpError("expected at least 4 operands");
1463+
1464+
// Check types of operands. The order of these calls is important: the later
1465+
// calls rely on some type properties to compute the operand position.
1466+
// 1. Source memref.
1467+
if (!getSrcMemRef().getType().isa<MemRefType>())
1468+
return emitOpError("expected source to be of memref type");
1469+
if (numOperands < getSrcMemRefRank() + 4)
1470+
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1471+
<< " operands";
1472+
if (!getSrcIndices().empty() &&
1473+
!llvm::all_of(getSrcIndices().getTypes(),
1474+
[](Type t) { return t.isIndex(); }))
1475+
return emitOpError("expected source indices to be of index type");
1476+
1477+
// 2. Destination memref.
1478+
if (!getDstMemRef().getType().isa<MemRefType>())
1479+
return emitOpError("expected destination to be of memref type");
1480+
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1481+
if (numOperands < numExpectedOperands)
1482+
return emitOpError() << "expected at least " << numExpectedOperands
1483+
<< " operands";
1484+
if (!getDstIndices().empty() &&
1485+
!llvm::all_of(getDstIndices().getTypes(),
1486+
[](Type t) { return t.isIndex(); }))
1487+
return emitOpError("expected destination indices to be of index type");
1488+
1489+
// 3. Number of elements.
1490+
if (!getNumElements().getType().isIndex())
1491+
return emitOpError("expected num elements to be of index type");
1492+
1493+
// 4. Tag memref.
1494+
if (!getTagMemRef().getType().isa<MemRefType>())
1495+
return emitOpError("expected tag to be of memref type");
1496+
numExpectedOperands += getTagMemRefRank();
1497+
if (numOperands < numExpectedOperands)
1498+
return emitOpError() << "expected at least " << numExpectedOperands
1499+
<< " operands";
1500+
if (!getTagIndices().empty() &&
1501+
!llvm::all_of(getTagIndices().getTypes(),
1502+
[](Type t) { return t.isIndex(); }))
1503+
return emitOpError("expected tag indices to be of index type");
1504+
14801505
// DMAs from different memory spaces supported.
14811506
if (getSrcMemorySpace() == getDstMemorySpace())
14821507
return emitOpError("DMA should be between different memory spaces");
14831508

1484-
if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
1485-
getDstMemRefRank() + 3 + 1 &&
1486-
getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
1487-
getDstMemRefRank() + 3 + 1 + 2) {
1509+
// Optional stride-related operands must be either both present or both
1510+
// absent.
1511+
if (numOperands != numExpectedOperands &&
1512+
numOperands != numExpectedOperands + 2)
14881513
return emitOpError("incorrect number of operands");
1514+
1515+
// 5. Strides.
1516+
if (isStrided()) {
1517+
if (!getStride().getType().isIndex() ||
1518+
!getNumElementsPerStride().getType().isIndex())
1519+
return emitOpError(
1520+
"expected stride and num elements per stride to be of type index");
14891521
}
1522+
14901523
return success();
14911524
}
14921525

@@ -1536,15 +1569,6 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
15361569
parser.resolveOperand(numElementsInfo, indexType, result.operands))
15371570
return failure();
15381571

1539-
auto memrefType = type.dyn_cast<MemRefType>();
1540-
if (!memrefType)
1541-
return parser.emitError(parser.getNameLoc(),
1542-
"expected tag to be of memref type");
1543-
1544-
if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
1545-
return parser.emitError(parser.getNameLoc(),
1546-
"tag memref rank not equal to indices count");
1547-
15481572
return success();
15491573
}
15501574

@@ -1554,6 +1578,32 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
15541578
return foldMemRefCast(*this);
15551579
}
15561580

1581+
LogicalResult DmaWaitOp::verify() {
1582+
// Mandatory non-variadic operands are tag and the number of elements.
1583+
if (getNumOperands() < 2)
1584+
return emitOpError() << "expected at least 2 operands";
1585+
1586+
// Check types of operands. The order of these calls is important: the later
1587+
// calls rely on some type properties to compute the operand position.
1588+
if (!getTagMemRef().getType().isa<MemRefType>())
1589+
return emitOpError() << "expected tag to be of memref type";
1590+
1591+
if (getNumOperands() != 2 + getTagMemRefRank())
1592+
return emitOpError() << "expected " << 2 + getTagMemRefRank()
1593+
<< " operands";
1594+
1595+
if (!getTagIndices().empty() &&
1596+
!llvm::all_of(getTagIndices().getTypes(),
1597+
[](Type t) { return t.isIndex(); }))
1598+
return emitOpError() << "expected tag indices to be of index type";
1599+
1600+
if (!getNumElements().getType().isIndex())
1601+
return emitOpError()
1602+
<< "expected the number of elements to be of index type";
1603+
1604+
return success();
1605+
}
1606+
15571607
//===----------------------------------------------------------------------===//
15581608
// ExtractElementOp
15591609
//===----------------------------------------------------------------------===//

mlir/test/IR/invalid-ops.mlir

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,13 +303,38 @@ func @invalid_cmp_shape(%idx : () -> ()) {
303303

304304
// -----
305305

306+
func @dma_start_not_enough_operands() {
307+
// expected-error@+1 {{expected at least 4 operands}}
308+
"std.dma_start"() : () -> ()
309+
}
310+
311+
// -----
312+
306313
func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
307314
// expected-error@+1 {{expected source to be of memref type}}
308315
dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
309316
}
310317

311318
// -----
312319

320+
func @dma_start_not_enough_operands_for_src(
321+
%src: memref<2x2x2xf32>, %idx: index) {
322+
// expected-error@+1 {{expected at least 7 operands}}
323+
"std.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
324+
}
325+
326+
// -----
327+
328+
func @dma_start_src_index_wrong_type(
329+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
330+
%tag: memref<i32,2>, %flt: f32) {
331+
// expected-error@+1 {{expected source indices to be of index type}}
332+
"std.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
333+
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
334+
}
335+
336+
// -----
337+
313338
func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
314339
%mref = alloc() : memref<8 x f32>
315340
// expected-error@+1 {{expected destination to be of memref type}}
@@ -318,6 +343,36 @@ func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
318343

319344
// -----
320345

346+
func @dma_start_not_enough_operands_for_dst(
347+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
348+
%tag: memref<i32,2>) {
349+
// expected-error@+1 {{expected at least 7 operands}}
350+
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
351+
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
352+
}
353+
354+
// -----
355+
356+
func @dma_start_dst_index_wrong_type(
357+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
358+
%tag: memref<i32,2>, %flt: f32) {
359+
// expected-error@+1 {{expected destination indices to be of index type}}
360+
"std.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
361+
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
362+
}
363+
364+
// -----
365+
366+
func @dma_start_dst_index_wrong_type(
367+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
368+
%tag: memref<i32,2>, %flt: f32) {
369+
// expected-error@+1 {{expected num elements to be of index type}}
370+
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
371+
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
372+
}
373+
374+
// -----
375+
321376
func @dma_no_tag_memref(%tag : f32, %c0 : index) {
322377
%mref = alloc() : memref<8 x f32>
323378
// expected-error@+1 {{expected tag to be of memref type}}
@@ -326,9 +381,80 @@ func @dma_no_tag_memref(%tag : f32, %c0 : index) {
326381

327382
// -----
328383

384+
func @dma_start_not_enough_operands_for_tag(
385+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
386+
%tag: memref<2xi32,2>) {
387+
// expected-error@+1 {{expected at least 8 operands}}
388+
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
389+
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
390+
}
391+
392+
// -----
393+
394+
func @dma_start_dst_index_wrong_type(
395+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
396+
%tag: memref<2xi32,2>, %flt: f32) {
397+
// expected-error@+1 {{expected tag indices to be of index type}}
398+
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
399+
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
400+
}
401+
402+
// -----
403+
404+
func @dma_start_same_space(
405+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32>,
406+
%tag: memref<i32,2>) {
407+
// expected-error@+1 {{DMA should be between different memory spaces}}
408+
dma_start %src[%idx, %idx], %dst[%idx], %idx, %tag[] : memref<2x2xf32>, memref<2xf32>, memref<i32,2>
409+
}
410+
411+
// -----
412+
413+
func @dma_start_too_many_operands(
414+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
415+
%tag: memref<i32,2>) {
416+
// expected-error@+1 {{incorrect number of operands}}
417+
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
418+
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
419+
}
420+
421+
422+
// -----
423+
424+
func @dma_start_wrong_stride_type(
425+
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
426+
%tag: memref<i32,2>, %flt: f32) {
427+
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
428+
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
429+
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
430+
}
431+
432+
// -----
433+
434+
func @dma_wait_not_enough_operands() {
435+
// expected-error@+1 {{expected at least 2 operands}}
436+
"std.dma_wait"() : () -> ()
437+
}
438+
439+
// -----
440+
329441
func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
330442
// expected-error@+1 {{expected tag to be of memref type}}
331-
dma_wait %tag[%c0], %arg0 : f32
443+
"std.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
444+
}
445+
446+
// -----
447+
448+
func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
449+
// expected-error@+1 {{expected tag indices to be of index type}}
450+
"std.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
451+
}
452+
453+
// -----
454+
455+
func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
456+
// expected-error@+1 {{expected the number of elements to be of index type}}
457+
"std.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
332458
}
333459

334460
// -----

0 commit comments

Comments
 (0)