Skip to content

[flang][openacc][NFC] Simplify lowering of recipe #68836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 74 additions & 101 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ bool isConstantBound(mlir::acc::DataBoundsOp &op) {
}

/// Return true iff all the bounds are expressed with constant values.
bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
for (auto bound : bounds) {
auto dataBound =
mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
Expand All @@ -474,27 +474,6 @@ bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
return true;
}

static fir::ShapeOp
genShapeFromBounds(mlir::Location loc, fir::FirOpBuilder &builder,
const llvm::SmallVector<mlir::Value> &args) {
assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
llvm::SmallVector<mlir::Value> extents;
mlir::Type idxTy = builder.getIndexType();
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
for (unsigned i = 0; i < args.size(); i += 3) {
mlir::Value s1 =
builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
extents.push_back(ext);
}
return builder.create<fir::ShapeOp>(loc, extents);
}

static llvm::SmallVector<mlir::Value>
genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::acc::DataBoundsOp &dataBound) {
Expand All @@ -520,6 +499,63 @@ genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
return {lb, ub, step};
}

static fir::ShapeOp genShapeFromBoundsOrArgs(
mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
llvm::SmallVector<mlir::Value> args;
if (areAllBoundConstant(bounds)) {
for (auto bound : llvm::reverse(bounds)) {
auto dataBound =
mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
args.append(genConstantBounds(builder, loc, dataBound));
}
} else {
assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
"Expect 3 block arguments per dimension");
for (auto arg : arguments.drop_front(2))
args.push_back(arg);
}

assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
llvm::SmallVector<mlir::Value> extents;
mlir::Type idxTy = builder.getIndexType();
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
for (unsigned i = 0; i < args.size(); i += 3) {
mlir::Value s1 =
builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
extents.push_back(ext);
}
return builder.create<fir::ShapeOp>(loc, extents);
}

static hlfir::DesignateOp::Subscripts
getSubscriptsFromArgs(mlir::ValueRange args) {
hlfir::DesignateOp::Subscripts triplets;
for (unsigned i = 2; i < args.size(); i += 3)
triplets.emplace_back(
hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
return triplets;
}

static hlfir::Entity genDesignateWithTriplets(
fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity,
hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
llvm::SmallVector<mlir::Value> lenParams;
hlfir::genLengthParameters(loc, builder, entity, lenParams);
auto designate = builder.create<hlfir::DesignateOp>(
loc, entity.getBase().getType(), entity, /*component=*/"",
/*componentShape=*/mlir::Value{}, triplets,
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
lenParams);
return hlfir::Entity{designate.getResult()};
}

mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
Expand Down Expand Up @@ -600,47 +636,16 @@ mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
if (!seqTy)
TODO(loc, "Unsupported boxed type in OpenACC firstprivate");

if (allConstantBound) {
for (auto bound : llvm::reverse(bounds)) {
auto dataBound =
mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
tripletArgs.append(genConstantBounds(firBuilder, loc, dataBound));
}
} else {
assert(((recipe.getCopyRegion().getArguments().size() - 2) / 3 ==
seqTy.getDimension()) &&
"Expect 3 block arguments per dimension");
for (auto arg : recipe.getCopyRegion().getArguments().drop_front(2))
tripletArgs.push_back(arg);
}
auto shape = genShapeFromBounds(loc, firBuilder, tripletArgs);
hlfir::DesignateOp::Subscripts triplets;
for (unsigned i = 2; i < recipe.getCopyRegion().getArguments().size();
i += 3)
triplets.emplace_back(hlfir::DesignateOp::Triplet{
recipe.getCopyRegion().getArgument(i),
recipe.getCopyRegion().getArgument(i + 1),
recipe.getCopyRegion().getArgument(i + 2)});

llvm::SmallVector<mlir::Value> lenParamsLeft;
auto shape = genShapeFromBoundsOrArgs(
loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
hlfir::DesignateOp::Subscripts triplets =
getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
hlfir::genLengthParameters(loc, firBuilder, leftEntity, lenParamsLeft);
auto leftDesignate = firBuilder.create<hlfir::DesignateOp>(
loc, leftEntity.getBase().getType(), leftEntity, /*component=*/"",
/*componentShape=*/mlir::Value{}, triplets,
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
shape, lenParamsLeft);
auto left = hlfir::Entity{leftDesignate.getResult()};

llvm::SmallVector<mlir::Value> lenParamsRight;
auto left =
genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
hlfir::genLengthParameters(loc, firBuilder, rightEntity, lenParamsRight);
auto rightDesignate = firBuilder.create<hlfir::DesignateOp>(
loc, rightEntity.getBase().getType(), rightEntity, /*component=*/"",
/*componentShape=*/mlir::Value{}, triplets,
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
shape, lenParamsRight);
auto right = hlfir::Entity{rightDesignate.getResult()};
auto right =
genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
firBuilder.create<hlfir::AssignOp>(loc, left, right);
}

Expand Down Expand Up @@ -1110,48 +1115,16 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
if (!seqTy)
TODO(loc, "Unsupported boxed type in OpenACC reduction");

if (allConstantBound) {
for (auto bound : llvm::reverse(bounds)) {
auto dataBound =
mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
tripletArgs.append(genConstantBounds(builder, loc, dataBound));
}
} else {
assert(((recipe.getCombinerRegion().getArguments().size() - 2) / 3 ==
seqTy.getDimension()) &&
"Expect 3 block arguments per dimension");
for (auto arg : recipe.getCombinerRegion().getArguments().drop_front(2))
tripletArgs.push_back(arg);
}
auto shape = genShapeFromBounds(loc, builder, tripletArgs);

hlfir::DesignateOp::Subscripts triplets;
for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
i += 3)
triplets.emplace_back(hlfir::DesignateOp::Triplet{
recipe.getCombinerRegion().getArgument(i),
recipe.getCombinerRegion().getArgument(i + 1),
recipe.getCombinerRegion().getArgument(i + 2)});

llvm::SmallVector<mlir::Value> lenParamsLeft;
auto shape = genShapeFromBoundsOrArgs(
loc, builder, seqTy, bounds, recipe.getCombinerRegion().getArguments());
hlfir::DesignateOp::Subscripts triplets =
getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
auto leftEntity = hlfir::Entity{value1};
hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
auto leftDesignate = builder.create<hlfir::DesignateOp>(
loc, value1.getType(), leftEntity, /*component=*/"",
/*componentShape=*/mlir::Value{}, triplets,
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
shape, lenParamsLeft);
auto left = hlfir::Entity{leftDesignate.getResult()};

llvm::SmallVector<mlir::Value> lenParamsRight;
auto left =
genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
auto rightEntity = hlfir::Entity{value2};
hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsRight);
auto rightDesignate = builder.create<hlfir::DesignateOp>(
loc, value2.getType(), rightEntity, /*component=*/"",
/*componentShape=*/mlir::Value{}, triplets,
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
shape, lenParamsRight);
auto right = hlfir::Entity{rightDesignate.getResult()};
auto right =
genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);

llvm::SmallVector<mlir::Value, 1> typeParams;
auto genKernel = [&builder, &loc, op, seqTy, &left, &right](
Expand Down