Skip to content

Commit ea55503

Browse files
clementvaljeanPerierschweitzpgi
committed
[fir] Add fir.extract_value and fir.insert_value conversion
This patch add the conversion pattern for fir.extract_value and fir.insert_value. fir.extract_value is lowered to llvm.extractvalue anf fir.insert_value is lowered to llvm.insertvalue. This patch also adds the type conversion for the BoxType and RecordType needed to have some comprehensive tests. This patch is part of the upstreaming effort from fir-dev branch. Reviewed By: awarzynski Differential Revision: https://reviews.llvm.org/D112961 Co-authored-by: Jean Perier <[email protected]> Co-authored-by: Eric Schweitz <[email protected]>
1 parent cf838eb commit ea55503

File tree

4 files changed

+337
-4
lines changed

4 files changed

+337
-4
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,105 @@ struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
296296
}
297297
};
298298

299+
// Code shared between insert_value and extract_value Ops.
300+
struct ValueOpCommon {
301+
// Translate the arguments pertaining to any multidimensional array to
302+
// row-major order for LLVM-IR.
303+
static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs,
304+
mlir::Type ty) {
305+
assert(ty && "type is null");
306+
const auto end = attrs.size();
307+
for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
308+
if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
309+
const auto dim = getDimension(seq);
310+
if (dim > 1) {
311+
auto ub = std::min(i + dim, end);
312+
std::reverse(attrs.begin() + i, attrs.begin() + ub);
313+
i += dim - 1;
314+
}
315+
ty = getArrayElementType(seq);
316+
} else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
317+
ty = st.getBody()[attrs[i].cast<mlir::IntegerAttr>().getInt()];
318+
} else {
319+
llvm_unreachable("index into invalid type");
320+
}
321+
}
322+
}
323+
324+
static llvm::SmallVector<mlir::Attribute>
325+
collectIndices(mlir::ConversionPatternRewriter &rewriter,
326+
mlir::ArrayAttr arrAttr) {
327+
llvm::SmallVector<mlir::Attribute> attrs;
328+
for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
329+
if (i->isa<mlir::IntegerAttr>()) {
330+
attrs.push_back(*i);
331+
} else {
332+
auto fieldName = i->cast<mlir::StringAttr>().getValue();
333+
++i;
334+
auto ty = i->cast<mlir::TypeAttr>().getValue();
335+
auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName);
336+
attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index));
337+
}
338+
}
339+
return attrs;
340+
}
341+
342+
private:
343+
static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
344+
unsigned result = 1;
345+
for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
346+
eleTy;
347+
eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
348+
++result;
349+
return result;
350+
}
351+
352+
static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
353+
auto eleTy = ty.getElementType();
354+
while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
355+
eleTy = arrTy.getElementType();
356+
return eleTy;
357+
}
358+
};
359+
360+
/// Extract a subobject value from an ssa-value of aggregate type
361+
struct ExtractValueOpConversion
362+
: public FIROpAndTypeConversion<fir::ExtractValueOp>,
363+
public ValueOpCommon {
364+
using FIROpAndTypeConversion::FIROpAndTypeConversion;
365+
366+
mlir::LogicalResult
367+
doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
368+
mlir::ConversionPatternRewriter &rewriter) const override {
369+
auto attrs = collectIndices(rewriter, extractVal.coor());
370+
toRowMajor(attrs, adaptor.getOperands()[0].getType());
371+
auto position = mlir::ArrayAttr::get(extractVal.getContext(), attrs);
372+
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
373+
extractVal, ty, adaptor.getOperands()[0], position);
374+
return success();
375+
}
376+
};
377+
378+
/// InsertValue is the generalized instruction for the composition of new
379+
/// aggregate type values.
380+
struct InsertValueOpConversion
381+
: public FIROpAndTypeConversion<fir::InsertValueOp>,
382+
public ValueOpCommon {
383+
using FIROpAndTypeConversion::FIROpAndTypeConversion;
384+
385+
mlir::LogicalResult
386+
doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
387+
mlir::ConversionPatternRewriter &rewriter) const override {
388+
auto attrs = collectIndices(rewriter, insertVal.coor());
389+
toRowMajor(attrs, adaptor.getOperands()[0].getType());
390+
auto position = mlir::ArrayAttr::get(insertVal.getContext(), attrs);
391+
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
392+
insertVal, ty, adaptor.getOperands()[0], adaptor.getOperands()[1],
393+
position);
394+
return success();
395+
}
396+
};
397+
299398
/// InsertOnRange inserts a value into a sequence over a range of offsets.
300399
struct InsertOnRangeOpConversion
301400
: public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
@@ -389,10 +488,11 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
389488
auto *context = getModule().getContext();
390489
fir::LLVMTypeConverter typeConverter{getModule()};
391490
mlir::OwningRewritePatternList pattern(context);
392-
pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
393-
InsertOnRangeOpConversion, SelectOpConversion,
394-
SelectRankOpConversion, UnreachableOpConversion,
395-
ZeroOpConversion, UndefOpConversion>(typeConverter);
491+
pattern.insert<
492+
AddrOfOpConversion, ExtractValueOpConversion, HasValueOpConversion,
493+
GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
494+
SelectOpConversion, SelectRankOpConversion, UndefOpConversion,
495+
UnreachableOpConversion, ZeroOpConversion>(typeConverter);
396496
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
397497
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
398498
pattern);
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//===-- DescriptorModel.h -- model of descriptors for codegen ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// LLVM IR dialect models of C++ types.
9+
//
10+
// This supplies a set of model builders to decompose the C declaration of a
11+
// descriptor (as encoded in ISO_Fortran_binding.h and elsewhere) and
12+
// reconstruct that type in the LLVM IR dialect.
13+
//
14+
// TODO: It is understood that this is deeply incorrect as far as building a
15+
// portability layer for cross-compilation as these reflected types are those of
16+
// the build machine and not necessarily that of either the host or the target.
17+
// This assumption that build == host == target is actually pervasive across the
18+
// compiler (https://llvm.org/PR52418).
19+
//
20+
//===----------------------------------------------------------------------===//
21+
22+
#ifndef OPTIMIZER_DESCRIPTOR_MODEL_H
23+
#define OPTIMIZER_DESCRIPTOR_MODEL_H
24+
25+
#include "flang/ISO_Fortran_binding.h"
26+
#include "flang/Runtime/descriptor.h"
27+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
28+
#include "llvm/Support/ErrorHandling.h"
29+
#include <tuple>
30+
31+
namespace fir {
32+
33+
using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
34+
35+
/// Get the LLVM IR dialect model for building a particular C++ type, `T`.
36+
template <typename T>
37+
TypeBuilderFunc getModel();
38+
39+
template <>
40+
TypeBuilderFunc getModel<void *>() {
41+
return [](mlir::MLIRContext *context) -> mlir::Type {
42+
return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8));
43+
};
44+
}
45+
template <>
46+
TypeBuilderFunc getModel<unsigned>() {
47+
return [](mlir::MLIRContext *context) -> mlir::Type {
48+
return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
49+
};
50+
}
51+
template <>
52+
TypeBuilderFunc getModel<int>() {
53+
return [](mlir::MLIRContext *context) -> mlir::Type {
54+
return mlir::IntegerType::get(context, sizeof(int) * 8);
55+
};
56+
}
57+
template <>
58+
TypeBuilderFunc getModel<unsigned long>() {
59+
return [](mlir::MLIRContext *context) -> mlir::Type {
60+
return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
61+
};
62+
}
63+
template <>
64+
TypeBuilderFunc getModel<unsigned long long>() {
65+
return [](mlir::MLIRContext *context) -> mlir::Type {
66+
return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
67+
};
68+
}
69+
template <>
70+
TypeBuilderFunc getModel<long long>() {
71+
return [](mlir::MLIRContext *context) -> mlir::Type {
72+
return mlir::IntegerType::get(context, sizeof(long long) * 8);
73+
};
74+
}
75+
template <>
76+
TypeBuilderFunc getModel<Fortran::ISO::CFI_rank_t>() {
77+
return [](mlir::MLIRContext *context) -> mlir::Type {
78+
return mlir::IntegerType::get(context,
79+
sizeof(Fortran::ISO::CFI_rank_t) * 8);
80+
};
81+
}
82+
template <>
83+
TypeBuilderFunc getModel<Fortran::ISO::CFI_type_t>() {
84+
return [](mlir::MLIRContext *context) -> mlir::Type {
85+
return mlir::IntegerType::get(context,
86+
sizeof(Fortran::ISO::CFI_type_t) * 8);
87+
};
88+
}
89+
template <>
90+
TypeBuilderFunc getModel<Fortran::ISO::CFI_index_t>() {
91+
return [](mlir::MLIRContext *context) -> mlir::Type {
92+
return mlir::IntegerType::get(context,
93+
sizeof(Fortran::ISO::CFI_index_t) * 8);
94+
};
95+
}
96+
template <>
97+
TypeBuilderFunc getModel<Fortran::ISO::CFI_dim_t>() {
98+
return [](mlir::MLIRContext *context) -> mlir::Type {
99+
auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context);
100+
return mlir::LLVM::LLVMArrayType::get(indexTy, 3);
101+
};
102+
}
103+
template <>
104+
TypeBuilderFunc
105+
getModel<Fortran::ISO::cfi_internal::FlexibleArray<Fortran::ISO::CFI_dim_t>>() {
106+
return getModel<Fortran::ISO::CFI_dim_t>();
107+
}
108+
109+
//===----------------------------------------------------------------------===//
110+
// Descriptor reflection
111+
//===----------------------------------------------------------------------===//
112+
113+
/// Get the type model of the field number `Field` in an ISO CFI descriptor.
114+
template <int Field>
115+
static constexpr TypeBuilderFunc getDescFieldTypeModel() {
116+
Fortran::ISO::Fortran_2018::CFI_cdesc_t dummyDesc{};
117+
// check that the descriptor is exactly 8 fields as specified in CFI_cdesc_t
118+
// in flang/include/flang/ISO_Fortran_binding.h.
119+
auto [a, b, c, d, e, f, g, h] = dummyDesc;
120+
auto tup = std::tie(a, b, c, d, e, f, g, h);
121+
auto field = std::get<Field>(tup);
122+
return getModel<decltype(field)>();
123+
}
124+
125+
/// An extended descriptor is defined by a class in runtime/descriptor.h. The
126+
/// three fields in the class are hard-coded here, unlike the reflection used on
127+
/// the ISO parts, which are a POD.
128+
template <int Field>
129+
static constexpr TypeBuilderFunc getExtendedDescFieldTypeModel() {
130+
if constexpr (Field == 8) {
131+
return getModel<void *>();
132+
} else if constexpr (Field == 9) {
133+
return getModel<Fortran::runtime::typeInfo::TypeParameterValue>();
134+
} else {
135+
llvm_unreachable("extended ISO descriptor only has 10 fields");
136+
}
137+
}
138+
139+
} // namespace fir
140+
141+
#endif // OPTIMIZER_DESCRIPTOR_MODEL_H

flang/lib/Optimizer/CodeGen/TypeConverter.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#ifndef FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
1414
#define FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
1515

16+
#include "DescriptorModel.h"
17+
#include "flang/Lower/Todo.h" // remove when TODO's are done
18+
#include "llvm/ADT/StringMap.h"
1619
#include "llvm/Support/Debug.h"
1720

1821
namespace fir {
@@ -26,10 +29,35 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
2629
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
2730

2831
// Each conversion should return a value of type mlir::Type.
32+
addConversion(
33+
[&](fir::RecordType derived) { return convertRecordType(derived); });
2934
addConversion(
3035
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
3136
addConversion(
3237
[&](SequenceType sequence) { return convertSequenceType(sequence); });
38+
addConversion([&](mlir::TupleType tuple) {
39+
LLVM_DEBUG(llvm::dbgs() << "type convert: " << tuple << '\n');
40+
llvm::SmallVector<mlir::Type> inMembers;
41+
tuple.getFlattenedTypes(inMembers);
42+
llvm::SmallVector<mlir::Type> members;
43+
for (auto mem : inMembers)
44+
members.push_back(convertType(mem).cast<mlir::Type>());
45+
return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
46+
/*isPacked=*/false);
47+
});
48+
}
49+
50+
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
51+
mlir::Type convertRecordType(fir::RecordType derived) {
52+
auto name = derived.getName();
53+
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
54+
llvm::SmallVector<mlir::Type> members;
55+
for (auto mem : derived.getTypeList()) {
56+
members.push_back(convertType(mem.second).cast<mlir::Type>());
57+
}
58+
if (mlir::succeeded(st.setBody(members, /*isPacked=*/false)))
59+
return st;
60+
return mlir::Type();
3361
}
3462

3563
template <typename A>

flang/test/Fir/convert-to-llvm.fir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,67 @@ func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
259259
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
260260
// CHECK: 4: ^bb4(%[[C1]] : i32)
261261
// CHECK: ]
262+
263+
// -----
264+
265+
// Test fir.extract_value operation conversion with derived type.
266+
267+
func @extract_derived_type() -> f32 {
268+
%0 = fir.undefined !fir.type<derived{f:f32}>
269+
%1 = fir.extract_value %0, ["f", !fir.type<derived{f:f32}>] : (!fir.type<derived{f:f32}>) -> f32
270+
return %1 : f32
271+
}
272+
273+
// CHECK-LABEL: llvm.func @extract_derived_type
274+
// CHECK: %[[STRUCT:.*]] = llvm.mlir.undef : !llvm.struct<"derived", (f32)>
275+
// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[STRUCT]][0 : i32] : !llvm.struct<"derived", (f32)>
276+
// CHECK: llvm.return %[[VALUE]] : f32
277+
278+
// -----
279+
280+
// Test fir.extract_value operation conversion with a multi-dimensional array
281+
// of tuple.
282+
283+
func @extract_array(%a : !fir.array<10x10xtuple<i32, f32>>) -> f32 {
284+
%0 = fir.extract_value %a, [5 : index, 4 : index, 1 : index] : (!fir.array<10x10xtuple<i32, f32>>) -> f32
285+
return %0 : f32
286+
}
287+
288+
// CHECK-LABEL: llvm.func @extract_array(
289+
// CHECK-SAME: %[[ARR:.*]]: !llvm.array<10 x array<10 x struct<(i32, f32)>>>
290+
// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[ARR]][4 : index, 5 : index, 1 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
291+
// CHECK: llvm.return %[[VALUE]] : f32
292+
293+
// -----
294+
295+
// Test fir.insert_value operation conversion with a multi-dimensional array
296+
// of tuple.
297+
298+
func @extract_array(%a : !fir.array<10x10xtuple<i32, f32>>) {
299+
%f = arith.constant 2.0 : f32
300+
%i = arith.constant 1 : i32
301+
%0 = fir.insert_value %a, %i, [5 : index, 4 : index, 0 : index] : (!fir.array<10x10xtuple<i32, f32>>, i32) -> !fir.array<10x10xtuple<i32, f32>>
302+
%1 = fir.insert_value %a, %f, [5 : index, 4 : index, 1 : index] : (!fir.array<10x10xtuple<i32, f32>>, f32) -> !fir.array<10x10xtuple<i32, f32>>
303+
return
304+
}
305+
306+
// CHECK-LABEL: llvm.func @extract_array(
307+
// CHECK-SAME: %[[ARR:.*]]: !llvm.array<10 x array<10 x struct<(i32, f32)>>>
308+
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[ARR]][4 : index, 5 : index, 0 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
309+
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[ARR]][4 : index, 5 : index, 1 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
310+
// CHECK: llvm.return
311+
312+
// -----
313+
314+
// Test fir.insert_value operation conversion with derived type.
315+
316+
func @insert_tuple(%a : tuple<i32, f32>) {
317+
%f = arith.constant 2.0 : f32
318+
%1 = fir.insert_value %a, %f, [1 : index] : (tuple<i32, f32>, f32) -> tuple<i32, f32>
319+
return
320+
}
321+
322+
// CHECK-LABEL: func @insert_tuple(
323+
// CHECK-SAME: %[[TUPLE:.*]]: !llvm.struct<(i32, f32)>
324+
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[TUPLE]][1 : index] : !llvm.struct<(i32, f32)>
325+
// CHECK: llvm.return

0 commit comments

Comments
 (0)