Skip to content

Commit 41b5268

Browse files
committed
[flang] add hlfir.product operation
Adds a HLFIR operation for the PRODUCT intrinsic according to the design set out in flang/doc/HighLevelFIR.md Since the PRODUCT intrinsic is essentially identical to SUM in terms of its arguments and result characteristics in the Fortran Standard, the operation definition and subsequent tests also take the same form. Differential Revision: https://reviews.llvm.org/D147624
1 parent b88023c commit 41b5268

File tree

4 files changed

+335
-28
lines changed

4 files changed

+335
-28
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,31 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
317317
let hasVerifier = 1;
318318
}
319319

320+
def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
321+
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
322+
let summary = "PRODUCT transformational intrinsic";
323+
let description = [{
324+
Multiplies the elements of an array, optionally along a particular dimension,
325+
optionally if a mask is true.
326+
}];
327+
328+
let arguments = (ins
329+
AnyFortranNumericalArrayObject:$array,
330+
Optional<AnyIntegerType>:$dim,
331+
Optional<AnyFortranLogicalOrI1ArrayObject>:$mask,
332+
DefaultValuedAttr<Arith_FastMathAttr,
333+
"::mlir::arith::FastMathFlags::none">:$fastmath
334+
);
335+
336+
let results = (outs hlfir_ExprType);
337+
338+
let assemblyFormat = [{
339+
$array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
340+
}];
341+
342+
let hasVerifier = 1;
343+
}
344+
320345
def hlfir_SetLengthOp : hlfir_Op<"set_length", []> {
321346
let summary = "change the length of a character entity";
322347
let description = [{

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -489,35 +489,18 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
489489
}
490490

491491
//===----------------------------------------------------------------------===//
492-
// SetLengthOp
492+
// ReductionOp
493493
//===----------------------------------------------------------------------===//
494494

495-
void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
496-
mlir::OperationState &result, mlir::Value string,
497-
mlir::Value len) {
498-
fir::CharacterType::LenType resultTypeLen = fir::CharacterType::unknownLen();
499-
if (auto cstLen = fir::getIntIfConstant(len))
500-
resultTypeLen = *cstLen;
501-
unsigned kind = getCharacterKind(string.getType());
502-
auto resultType = hlfir::ExprType::get(
503-
builder.getContext(), hlfir::ExprType::Shape{},
504-
fir::CharacterType::get(builder.getContext(), kind, resultTypeLen),
505-
false);
506-
build(builder, result, resultType, string, len);
507-
}
508-
509-
//===----------------------------------------------------------------------===//
510-
// SumOp
511-
//===----------------------------------------------------------------------===//
512-
513-
mlir::LogicalResult hlfir::SumOp::verify() {
514-
mlir::Operation *op = getOperation();
495+
template <typename ReductionOp>
496+
static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
497+
mlir::Operation *op = reductionOp->getOperation();
515498

516499
auto results = op->getResultTypes();
517500
assert(results.size() == 1);
518501

519-
mlir::Value array = getArray();
520-
mlir::Value mask = getMask();
502+
mlir::Value array = reductionOp->getArray();
503+
mlir::Value mask = reductionOp->getMask();
521504

522505
fir::SequenceType arrayTy =
523506
hlfir::getFortranElementOrSequenceType(array.getType())
@@ -537,7 +520,7 @@ mlir::LogicalResult hlfir::SumOp::verify() {
537520

538521
if (!maskShape.empty()) {
539522
if (maskShape.size() != arrayShape.size())
540-
return emitWarning("MASK must be conformable to ARRAY");
523+
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
541524
static_assert(fir::SequenceType::getUnknownExtent() ==
542525
hlfir::ExprType::getUnknownExtent());
543526
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
@@ -546,32 +529,67 @@ mlir::LogicalResult hlfir::SumOp::verify() {
546529
int64_t maskExtent = maskShape[i];
547530
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
548531
(maskExtent != unknownExtent))
549-
return emitWarning("MASK must be conformable to ARRAY");
532+
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
550533
}
551534
}
552535
}
553536

554537
if (resultTy.isArray()) {
555538
// Result is of the same type as ARRAY
556539
if (resultTy.getEleTy() != numTy)
557-
return emitOpError(
540+
return reductionOp->emitOpError(
558541
"result must have the same element type as ARRAY argument");
559542

560543
llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
561544

562545
// Result has rank n-1
563546
if (resultShape.size() != (arrayShape.size() - 1))
564-
return emitOpError("result rank must be one less than ARRAY");
547+
return reductionOp->emitOpError(
548+
"result rank must be one less than ARRAY");
565549
} else {
566550
// Result is of the same type as ARRAY
567551
if (resultTy.getElementType() != numTy)
568-
return emitOpError(
552+
return reductionOp->emitOpError(
569553
"result must have the same element type as ARRAY argument");
570554
}
571555

572556
return mlir::success();
573557
}
574558

559+
//===----------------------------------------------------------------------===//
560+
// ProductOp
561+
//===----------------------------------------------------------------------===//
562+
563+
mlir::LogicalResult hlfir::ProductOp::verify() {
564+
return verifyReductionOp<hlfir::ProductOp *>(this);
565+
}
566+
567+
//===----------------------------------------------------------------------===//
568+
// SetLengthOp
569+
//===----------------------------------------------------------------------===//
570+
571+
void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
572+
mlir::OperationState &result, mlir::Value string,
573+
mlir::Value len) {
574+
fir::CharacterType::LenType resultTypeLen = fir::CharacterType::unknownLen();
575+
if (auto cstLen = fir::getIntIfConstant(len))
576+
resultTypeLen = *cstLen;
577+
unsigned kind = getCharacterKind(string.getType());
578+
auto resultType = hlfir::ExprType::get(
579+
builder.getContext(), hlfir::ExprType::Shape{},
580+
fir::CharacterType::get(builder.getContext(), kind, resultTypeLen),
581+
false);
582+
build(builder, result, resultType, string, len);
583+
}
584+
585+
//===----------------------------------------------------------------------===//
586+
// SumOp
587+
//===----------------------------------------------------------------------===//
588+
589+
mlir::LogicalResult hlfir::SumOp::verify() {
590+
return verifyReductionOp<hlfir::SumOp *>(this);
591+
}
592+
575593
//===----------------------------------------------------------------------===//
576594
// MatmulOp
577595
//===----------------------------------------------------------------------===//

flang/test/HLFIR/invalid.fir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,30 @@ func.func @bad_concat_4(%arg0: !fir.ref<!fir.char<1,30>>) {
296296
return
297297
}
298298

299+
// -----
300+
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
301+
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}
302+
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<f32>
303+
}
304+
305+
// -----
306+
func.func @bad_product2(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) {
307+
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
308+
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.array<?x?x?x?x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
309+
}
310+
311+
// -----
312+
func.func @bad_product3(%arg0: !hlfir.expr<?x5x?xi32>, %arg1: i32, %arg2: !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) {
313+
// expected-warning@+1 {{MASK must be conformable to ARRAY}}
314+
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?x5x?xi32>, i32, !fir.box<!fir.array<2x6x?x!fir.logical<4>>>) -> !hlfir.expr<i32>
315+
}
316+
317+
// -----
318+
func.func @bad_product4(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
319+
// expected-error@+1 {{'hlfir.product' op result rank must be one less than ARRAY}}
320+
%0 = hlfir.product %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
321+
}
322+
299323
// -----
300324
func.func @bad_sum1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
301325
// expected-error@+1 {{'hlfir.sum' op result must have the same element type as ARRAY argument}}

0 commit comments

Comments
 (0)