Skip to content

Commit 5b5ef2e

Browse files
schweitzpgijeanPerierclementval
committed
[fir] Add fir.save_result op
Add the fir.save_result operation. It is use to save an array, box, or record function result SSA-value to a memory location Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D110407 Co-authored-by: Jean Perier <[email protected]> Co-authored-by: Valentin Clement <[email protected]>
1 parent 764d9aa commit 5b5ef2e

File tree

5 files changed

+199
-0
lines changed

5 files changed

+199
-0
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,55 @@ def fir_LoadOp : fir_OneResultOp<"load"> {
352352
}];
353353
}
354354

355+
def fir_SaveResultOp : fir_Op<"save_result", [AttrSizedOperandSegments]> {
356+
let summary = [{
357+
save an array, box, or record function result SSA-value to a memory location
358+
}];
359+
360+
let description = [{
361+
Save the result of a function returning an array, box, or record type value
362+
into a memory location given the shape and length parameters of the result.
363+
364+
Function results of type fir.box, fir.array, or fir.rec are abstract values
365+
that require a storage to be manipulated on the caller side. This operation
366+
allows associating such abstract result to a storage. In later lowering of
367+
the function interfaces, this storage might be used to pass the result in
368+
memory.
369+
370+
For arrays, result, it is required to provide the shape of the result. For
371+
character arrays and derived types with length parameters, the length
372+
parameter values must be provided.
373+
374+
The fir.save_result associated to a function call must immediately follow
375+
the call and be in the same block.
376+
377+
```mlir
378+
%buffer = fir.alloca fir.array<?xf32>, %c100
379+
%shape = fir.shape %c100
380+
%array_result = fir.call @foo() : () -> fir.array<?xf32>
381+
fir.save_result %array_result to %buffer(%shape)
382+
%coor = fir.array_coor %buffer%(%shape), %c5
383+
%fifth_element = fir.load %coor : f32
384+
```
385+
386+
The above fir.save_result allows saving a fir.array function result into
387+
a buffer to later access its 5th element.
388+
389+
}];
390+
391+
let arguments = (ins ArrayOrBoxOrRecord:$value,
392+
Arg<AnyReferenceLike, "", [MemWrite]>:$memref,
393+
Optional<AnyShapeType>:$shape,
394+
Variadic<AnyIntegerType>:$typeparams);
395+
396+
let assemblyFormat = [{
397+
$value `to` $memref (`(` $shape^ `)`)? (`typeparams` $typeparams^)?
398+
attr-dict `:` type(operands)
399+
}];
400+
401+
let verifier = [{ return ::verify(*this); }];
402+
}
403+
355404
def fir_StoreOp : fir_Op<"store", []> {
356405
let summary = "store an SSA-value to a memory location";
357406

flang/include/flang/Optimizer/Dialect/FIRTypes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,4 +551,9 @@ def AnyCoordinateType : Type<AnyCoordinateLike.predicate, "coordinate type">;
551551
def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
552552
FunctionType.predicate]>, "any addressable">;
553553

554+
def ArrayOrBoxOrRecord : TypeConstraint<Or<[fir_SequenceType.predicate,
555+
fir_BoxType.predicate, fir_RecordType.predicate]>,
556+
"fir.box, fir.array or fir.type">;
557+
558+
554559
#endif // FIR_DIALECT_FIR_TYPES

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,63 @@ static mlir::LogicalResult verify(fir::ResultOp op) {
13611361
return success();
13621362
}
13631363

1364+
//===----------------------------------------------------------------------===//
1365+
// SaveResultOp
1366+
//===----------------------------------------------------------------------===//
1367+
1368+
static mlir::LogicalResult verify(fir::SaveResultOp op) {
1369+
auto resultType = op.value().getType();
1370+
if (resultType != fir::dyn_cast_ptrEleTy(op.memref().getType()))
1371+
return op.emitOpError("value type must match memory reference type");
1372+
if (fir::isa_unknown_size_box(resultType))
1373+
return op.emitOpError("cannot save !fir.box of unknown rank or type");
1374+
1375+
if (resultType.isa<fir::BoxType>()) {
1376+
if (op.shape() || !op.typeparams().empty())
1377+
return op.emitOpError(
1378+
"must not have shape or length operands if the value is a fir.box");
1379+
return mlir::success();
1380+
}
1381+
1382+
// fir.record or fir.array case.
1383+
unsigned shapeTyRank = 0;
1384+
if (auto shapeOp = op.shape()) {
1385+
auto shapeTy = shapeOp.getType();
1386+
if (auto s = shapeTy.dyn_cast<fir::ShapeType>())
1387+
shapeTyRank = s.getRank();
1388+
else
1389+
shapeTyRank = shapeTy.cast<fir::ShapeShiftType>().getRank();
1390+
}
1391+
1392+
auto eleTy = resultType;
1393+
if (auto seqTy = resultType.dyn_cast<fir::SequenceType>()) {
1394+
if (seqTy.getDimension() != shapeTyRank)
1395+
op.emitOpError("shape operand must be provided and have the value rank "
1396+
"when the value is a fir.array");
1397+
eleTy = seqTy.getEleTy();
1398+
} else {
1399+
if (shapeTyRank != 0)
1400+
op.emitOpError(
1401+
"shape operand should only be provided if the value is a fir.array");
1402+
}
1403+
1404+
if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) {
1405+
if (recTy.getNumLenParams() != op.typeparams().size())
1406+
op.emitOpError("length parameters number must match with the value type "
1407+
"length parameters");
1408+
} else if (auto charTy = eleTy.dyn_cast<fir::CharacterType>()) {
1409+
if (op.typeparams().size() > 1)
1410+
op.emitOpError("no more than one length parameter must be provided for "
1411+
"character value");
1412+
} else {
1413+
if (!op.typeparams().empty())
1414+
op.emitOpError(
1415+
"length parameters must not be provided for this value type");
1416+
}
1417+
1418+
return mlir::success();
1419+
}
1420+
13641421
//===----------------------------------------------------------------------===//
13651422
// SelectOp
13661423
//===----------------------------------------------------------------------===//

flang/test/Fir/fir-ops.fir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,14 @@ func @test_rebox(%arg0: !fir.box<!fir.array<?xf32>>) {
671671
fir.call @bar_rebox_test(%4) : (!fir.box<!fir.array<?x?xf32>>) -> ()
672672
return
673673
}
674+
675+
// CHECK-LABEL: @test_save_result(
676+
func @test_save_result(%buffer: !fir.ref<!fir.array<?x!fir.char<1,?>>>) {
677+
%c100 = constant 100 : index
678+
%c50 = constant 50 : index
679+
%shape = fir.shape %c100 : (index) -> !fir.shape<1>
680+
%res = fir.call @array_func() : () -> !fir.array<?x!fir.char<1,?>>
681+
// CHECK: fir.save_result %{{.*}} to %{{.*}}(%{{.*}}) typeparams %{{.*}} : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
682+
fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
683+
return
684+
}

flang/test/Fir/invalid.fir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,80 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
417417
%2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
418418
fir.has_value %2 : !fir.array<32x32xi32>
419419
}
420+
421+
// -----
422+
423+
func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf64>>, %n :index) {
424+
%res = fir.call @array_func() : () -> !fir.array<?xf32>
425+
%shape = fir.shape %n : (index) -> !fir.shape<1>
426+
// expected-error@+1 {{'fir.save_result' op value type must match memory reference type}}
427+
fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf64>>, !fir.shape<1>
428+
return
429+
}
430+
431+
// -----
432+
433+
func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<*:f32>>>) {
434+
%res = fir.call @array_func() : () -> !fir.box<!fir.array<*:f32>>
435+
// expected-error@+1 {{'fir.save_result' op cannot save !fir.box of unknown rank or type}}
436+
fir.save_result %res to %buffer : !fir.box<!fir.array<*:f32>>, !fir.ref<!fir.box<!fir.array<*:f32>>>
437+
return
438+
}
439+
440+
// -----
441+
442+
func @bad_save_result(%buffer : !fir.ref<f64>) {
443+
%res = fir.call @array_func() : () -> f64
444+
// expected-error@+1 {{'fir.save_result' op operand #0 must be fir.box, fir.array or fir.type, but got 'f64'}}
445+
fir.save_result %res to %buffer : f64, !fir.ref<f64>
446+
return
447+
}
448+
449+
// -----
450+
451+
func @bad_save_result(%buffer : !fir.ref<!fir.box<!fir.array<?xf32>>>, %n : index) {
452+
%res = fir.call @array_func() : () -> !fir.box<!fir.array<?xf32>>
453+
%shape = fir.shape %n : (index) -> !fir.shape<1>
454+
// expected-error@+1 {{'fir.save_result' op must not have shape or length operands if the value is a fir.box}}
455+
fir.save_result %res to %buffer(%shape) : !fir.box<!fir.array<?xf32>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.shape<1>
456+
return
457+
}
458+
459+
// -----
460+
461+
func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n :index) {
462+
%res = fir.call @array_func() : () -> !fir.array<?xf32>
463+
%shape = fir.shape %n, %n : (index, index) -> !fir.shape<2>
464+
// expected-error@+1 {{'fir.save_result' op shape operand must be provided and have the value rank when the value is a fir.array}}
465+
fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<2>
466+
return
467+
}
468+
469+
// -----
470+
471+
func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n :index) {
472+
%res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
473+
%shape = fir.shape %n : (index) -> !fir.shape<1>
474+
// expected-error@+1 {{'fir.save_result' op shape operand should only be provided if the value is a fir.array}}
475+
fir.save_result %res to %buffer(%shape) : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, !fir.shape<1>
476+
return
477+
}
478+
479+
// -----
480+
481+
func @bad_save_result(%buffer : !fir.ref<!fir.type<t{x:f32}>>, %n :index) {
482+
%res = fir.call @array_func() : () -> !fir.type<t{x:f32}>
483+
// expected-error@+1 {{'fir.save_result' op length parameters number must match with the value type length parameters}}
484+
fir.save_result %res to %buffer typeparams %n : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>, index
485+
return
486+
}
487+
488+
// -----
489+
490+
func @bad_save_result(%buffer : !fir.ref<!fir.array<?xf32>>, %n :index) {
491+
%res = fir.call @array_func() : () -> !fir.array<?xf32>
492+
%shape = fir.shape %n : (index) -> !fir.shape<1>
493+
// expected-error@+1 {{'fir.save_result' op length parameters must not be provided for this value type}}
494+
fir.save_result %res to %buffer(%shape) typeparams %n : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>, index
495+
return
496+
}

0 commit comments

Comments
 (0)