@@ -489,35 +489,18 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
489
489
}
490
490
491
491
// ===----------------------------------------------------------------------===//
492
- // SetLengthOp
492
+ // ReductionOp
493
493
// ===----------------------------------------------------------------------===//
494
494
495
- void hlfir::SetLengthOp::build (mlir::OpBuilder &builder,
496
- mlir::OperationState &result, mlir::Value string,
497
- mlir::Value len) {
498
- fir::CharacterType::LenType resultTypeLen = fir::CharacterType::unknownLen ();
499
- if (auto cstLen = fir::getIntIfConstant (len))
500
- resultTypeLen = *cstLen;
501
- unsigned kind = getCharacterKind (string.getType ());
502
- auto resultType = hlfir::ExprType::get (
503
- builder.getContext (), hlfir::ExprType::Shape{},
504
- fir::CharacterType::get (builder.getContext (), kind, resultTypeLen),
505
- false );
506
- build (builder, result, resultType, string, len);
507
- }
508
-
509
- // ===----------------------------------------------------------------------===//
510
- // SumOp
511
- // ===----------------------------------------------------------------------===//
512
-
513
- mlir::LogicalResult hlfir::SumOp::verify () {
514
- mlir::Operation *op = getOperation ();
495
+ template <typename ReductionOp>
496
+ static mlir::LogicalResult verifyReductionOp (ReductionOp reductionOp) {
497
+ mlir::Operation *op = reductionOp->getOperation ();
515
498
516
499
auto results = op->getResultTypes ();
517
500
assert (results.size () == 1 );
518
501
519
- mlir::Value array = getArray ();
520
- mlir::Value mask = getMask ();
502
+ mlir::Value array = reductionOp-> getArray ();
503
+ mlir::Value mask = reductionOp-> getMask ();
521
504
522
505
fir::SequenceType arrayTy =
523
506
hlfir::getFortranElementOrSequenceType (array.getType ())
@@ -537,7 +520,7 @@ mlir::LogicalResult hlfir::SumOp::verify() {
537
520
538
521
if (!maskShape.empty ()) {
539
522
if (maskShape.size () != arrayShape.size ())
540
- return emitWarning (" MASK must be conformable to ARRAY" );
523
+ return reductionOp-> emitWarning (" MASK must be conformable to ARRAY" );
541
524
static_assert (fir::SequenceType::getUnknownExtent () ==
542
525
hlfir::ExprType::getUnknownExtent ());
543
526
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent ();
@@ -546,32 +529,67 @@ mlir::LogicalResult hlfir::SumOp::verify() {
546
529
int64_t maskExtent = maskShape[i];
547
530
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
548
531
(maskExtent != unknownExtent))
549
- return emitWarning (" MASK must be conformable to ARRAY" );
532
+ return reductionOp-> emitWarning (" MASK must be conformable to ARRAY" );
550
533
}
551
534
}
552
535
}
553
536
554
537
if (resultTy.isArray ()) {
555
538
// Result is of the same type as ARRAY
556
539
if (resultTy.getEleTy () != numTy)
557
- return emitOpError (
540
+ return reductionOp-> emitOpError (
558
541
" result must have the same element type as ARRAY argument" );
559
542
560
543
llvm::ArrayRef<int64_t > resultShape = resultTy.getShape ();
561
544
562
545
// Result has rank n-1
563
546
if (resultShape.size () != (arrayShape.size () - 1 ))
564
- return emitOpError (" result rank must be one less than ARRAY" );
547
+ return reductionOp->emitOpError (
548
+ " result rank must be one less than ARRAY" );
565
549
} else {
566
550
// Result is of the same type as ARRAY
567
551
if (resultTy.getElementType () != numTy)
568
- return emitOpError (
552
+ return reductionOp-> emitOpError (
569
553
" result must have the same element type as ARRAY argument" );
570
554
}
571
555
572
556
return mlir::success ();
573
557
}
574
558
559
+ // ===----------------------------------------------------------------------===//
560
+ // ProductOp
561
+ // ===----------------------------------------------------------------------===//
562
+
563
+ mlir::LogicalResult hlfir::ProductOp::verify () {
564
+ return verifyReductionOp<hlfir::ProductOp *>(this );
565
+ }
566
+
567
+ // ===----------------------------------------------------------------------===//
568
+ // SetLengthOp
569
+ // ===----------------------------------------------------------------------===//
570
+
571
+ void hlfir::SetLengthOp::build (mlir::OpBuilder &builder,
572
+ mlir::OperationState &result, mlir::Value string,
573
+ mlir::Value len) {
574
+ fir::CharacterType::LenType resultTypeLen = fir::CharacterType::unknownLen ();
575
+ if (auto cstLen = fir::getIntIfConstant (len))
576
+ resultTypeLen = *cstLen;
577
+ unsigned kind = getCharacterKind (string.getType ());
578
+ auto resultType = hlfir::ExprType::get (
579
+ builder.getContext (), hlfir::ExprType::Shape{},
580
+ fir::CharacterType::get (builder.getContext (), kind, resultTypeLen),
581
+ false );
582
+ build (builder, result, resultType, string, len);
583
+ }
584
+
585
+ // ===----------------------------------------------------------------------===//
586
+ // SumOp
587
+ // ===----------------------------------------------------------------------===//
588
+
589
+ mlir::LogicalResult hlfir::SumOp::verify () {
590
+ return verifyReductionOp<hlfir::SumOp *>(this );
591
+ }
592
+
575
593
// ===----------------------------------------------------------------------===//
576
594
// MatmulOp
577
595
// ===----------------------------------------------------------------------===//
0 commit comments