Skip to content

Commit a0eb794

Browse files
[MLIR][acc] Introduce varType to acc data clause operations (#119007)
The acc data clause operations hold an operand named `varPtr`. This was intended to hold a pointer to a variable - where the element type of that pointer specifies the type of the variable. However, for both memref and llvm dialects, this assumption is not true. This is because memref element type for cases like memref<10xf32> is simply f32 and for LLVM, after opaque pointers, the variable type is no longer recoverable. Thus, introduce varType to ensure that appropriate semantics are kept. Both the parser and printer for this new type attribute allow it to not be specified in cases where a dialect's getElementType() applied to `varPtr`'s type has a recoverable type. And more specifically, for FIR, no changes are needed in the MLIR unit tests.
1 parent 0e70289 commit a0eb794

File tree

6 files changed

+253
-208
lines changed

6 files changed

+253
-208
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
139139
op.setStructured(structured);
140140
op.setImplicit(implicit);
141141
op.setDataClause(dataClause);
142+
op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType())
143+
.getElementType());
142144
op->setAttr(Op::getOperandSegmentSizeAttr(),
143145
builder.getDenseI32ArrayAttr(operandSegments));
144146
if (!asyncDeviceTypes.empty())
@@ -266,8 +268,8 @@ static void createDeclareDeallocFuncWithArg(
266268
if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
267269
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
268270
builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
269-
entryOp.getVarPtr(), entryOp.getBounds(),
270-
entryOp.getAsyncOperands(),
271+
entryOp.getVarPtr(), entryOp.getVarType(),
272+
entryOp.getBounds(), entryOp.getAsyncOperands(),
271273
entryOp.getAsyncOperandsDeviceTypeAttr(),
272274
entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
273275
/*structured=*/false, /*implicit=*/false,
@@ -450,7 +452,7 @@ static void genDataExitOperations(fir::FirOpBuilder &builder,
450452
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
451453
builder.create<ExitOp>(
452454
entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
453-
entryOp.getBounds(), entryOp.getAsyncOperands(),
455+
entryOp.getVarType(), entryOp.getBounds(), entryOp.getAsyncOperands(),
454456
entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
455457
entryOp.getDataClause(), structured, entryOp.getImplicit(),
456458
builder.getStringAttr(*entryOp.getName()));

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 65 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -381,17 +381,18 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
381381
OpenACC_Op<mnemonic, !listconcat(traits,
382382
[AttrSizedOperandSegments,
383383
MemoryEffects<[MemRead<OpenACC_CurrentDeviceIdResource>]>])> {
384-
let arguments = !con(additionalArgs,
385-
(ins
386-
Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
387-
Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
388-
Variadic<IntOrIndex>:$asyncOperands,
389-
OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
390-
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
391-
DefaultValuedAttr<OpenACC_DataClauseAttr,clause>:$dataClause,
392-
DefaultValuedAttr<BoolAttr, "true">:$structured,
393-
DefaultValuedAttr<BoolAttr, "false">:$implicit,
394-
OptionalAttr<StrAttr>:$name));
384+
let arguments = !con(
385+
additionalArgs,
386+
(ins TypeAttr:$varType,
387+
Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
388+
Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
389+
Variadic<IntOrIndex>:$asyncOperands,
390+
OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
391+
OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
392+
DefaultValuedAttr<OpenACC_DataClauseAttr, clause>:$dataClause,
393+
DefaultValuedAttr<BoolAttr, "true">:$structured,
394+
DefaultValuedAttr<BoolAttr, "false">:$implicit,
395+
OptionalAttr<StrAttr>:$name));
395396

396397
let description = !strconcat(extraDescription, [{
397398
Description of arguments:
@@ -458,7 +459,7 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
458459
}];
459460

460461
let assemblyFormat = [{
461-
`varPtr` `(` $varPtr `:` type($varPtr) `)`
462+
`varPtr` `(` $varPtr `:` custom<VarPtrType>(type($varPtr), $varType)
462463
oilist(
463464
`varPtrPtr` `(` $varPtrPtr `:` type($varPtrPtr) `)`
464465
| `bounds` `(` $bounds `)`
@@ -469,32 +470,35 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
469470

470471
let hasVerifier = 1;
471472

472-
let builders = [
473-
OpBuilder<(ins "::mlir::Value":$varPtr,
474-
"bool":$structured,
475-
"bool":$implicit,
476-
CArg<"::mlir::ValueRange", "{}">:$bounds), [{
477-
build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
478-
bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
473+
let builders = [OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
474+
"bool":$implicit,
475+
CArg<"::mlir::ValueRange", "{}">:$bounds),
476+
[{
477+
build($_builder, $_state, varPtr.getType(), varPtr,
478+
/*varType=*/::mlir::TypeAttr::get(
479+
::mlir::cast<::mlir::acc::PointerLikeType>(
480+
varPtr.getType()).getElementType()),
481+
/*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
482+
/*asyncOperandsDeviceType=*/nullptr,
479483
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
480484
/*structured=*/$_builder.getBoolAttr(structured),
481485
/*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
482-
}]
483-
>,
484-
OpBuilder<(ins "::mlir::Value":$varPtr,
485-
"bool":$structured,
486-
"bool":$implicit,
487-
"const ::llvm::Twine &":$name,
488-
CArg<"::mlir::ValueRange", "{}">:$bounds), [{
489-
build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
490-
bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
486+
}]>,
487+
OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
488+
"bool":$implicit, "const ::llvm::Twine &":$name,
489+
CArg<"::mlir::ValueRange", "{}">:$bounds),
490+
[{
491+
build($_builder, $_state, varPtr.getType(), varPtr,
492+
/*varType=*/::mlir::TypeAttr::get(
493+
::mlir::cast<::mlir::acc::PointerLikeType>(
494+
varPtr.getType()).getElementType()),
495+
/*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
496+
/*asyncOperandsDeviceType=*/nullptr,
491497
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
492498
/*structured=*/$_builder.getBoolAttr(structured),
493499
/*implicit=*/$_builder.getBoolAttr(implicit),
494500
/*name=*/$_builder.getStringAttr(name));
495-
}]
496-
>
497-
];
501+
}]>];
498502
}
499503

500504
//===----------------------------------------------------------------------===//
@@ -794,63 +798,58 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
794798
}
795799
}];
796800

797-
let assemblyFormat = [{
798-
`accPtr` `(` $accPtr `:` type($accPtr) `)`
799-
oilist(
800-
`bounds` `(` $bounds `)`
801-
| `to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
802-
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
803-
type($asyncOperands), $asyncOperandsDeviceType) `)`
804-
) attr-dict
805-
}];
806-
807801
let hasVerifier = 1;
808802
}
809803

810-
class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause> :
811-
OpenACC_DataExitOp<mnemonic, clause,
812-
"- `varPtr`: The address of variable to copy back to.",
813-
[MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
814-
MemWrite<OpenACC_RuntimeCounters>]>],
815-
(ins Arg<OpenACC_PointerLikeTypeInterface,"Address of device variable",[MemRead]>:$accPtr,
816-
Arg<OpenACC_PointerLikeTypeInterface,"Address of variable",[MemWrite]>:$varPtr)> {
804+
class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause>
805+
: OpenACC_DataExitOp<
806+
mnemonic, clause,
807+
"- `varPtr`: The address of variable to copy back to.",
808+
[MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
809+
MemWrite<OpenACC_RuntimeCounters>]>],
810+
(ins Arg<OpenACC_PointerLikeTypeInterface,
811+
"Address of device variable", [MemRead]>:$accPtr,
812+
Arg<OpenACC_PointerLikeTypeInterface,
813+
"Address of variable", [MemWrite]>:$varPtr,
814+
TypeAttr:$varType)> {
817815
let assemblyFormat = [{
818816
`accPtr` `(` $accPtr `:` type($accPtr) `)`
819817
(`bounds` `(` $bounds^ `)` )?
820818
(`async` `(` custom<DeviceTypeOperands>($asyncOperands,
821819
type($asyncOperands), $asyncOperandsDeviceType)^ `)`)?
822-
`to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
820+
`to` `varPtr` `(` $varPtr `:` custom<VarPtrType>(type($varPtr), $varType)
823821
attr-dict
824822
}];
825823

826-
let builders = [
827-
OpBuilder<(ins "::mlir::Value":$accPtr,
828-
"::mlir::Value":$varPtr,
829-
"bool":$structured,
830-
"bool":$implicit,
831-
CArg<"::mlir::ValueRange", "{}">:$bounds), [{
824+
let builders = [OpBuilder<(ins "::mlir::Value":$accPtr,
825+
"::mlir::Value":$varPtr, "bool":$structured,
826+
"bool":$implicit,
827+
CArg<"::mlir::ValueRange", "{}">:$bounds),
828+
[{
832829
build($_builder, $_state, accPtr, varPtr,
830+
/*varType=*/::mlir::TypeAttr::get(
831+
::mlir::cast<::mlir::acc::PointerLikeType>(
832+
varPtr.getType()).getElementType()),
833833
bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
834834
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
835835
/*structured=*/$_builder.getBoolAttr(structured),
836836
/*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
837-
}]
838-
>,
839-
OpBuilder<(ins "::mlir::Value":$accPtr,
840-
"::mlir::Value":$varPtr,
841-
"bool":$structured,
842-
"bool":$implicit,
843-
"const ::llvm::Twine &":$name,
844-
CArg<"::mlir::ValueRange", "{}">:$bounds), [{
837+
}]>,
838+
OpBuilder<(ins "::mlir::Value":$accPtr,
839+
"::mlir::Value":$varPtr, "bool":$structured,
840+
"bool":$implicit, "const ::llvm::Twine &":$name,
841+
CArg<"::mlir::ValueRange", "{}">:$bounds),
842+
[{
845843
build($_builder, $_state, accPtr, varPtr,
844+
/*varType=*/::mlir::TypeAttr::get(
845+
::mlir::cast<::mlir::acc::PointerLikeType>(
846+
varPtr.getType()).getElementType()),
846847
bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
847848
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
848849
/*structured=*/$_builder.getBoolAttr(structured),
849850
/*implicit=*/$_builder.getBoolAttr(implicit),
850851
/*name=*/$_builder.getStringAttr(name));
851-
}]
852-
>
853-
];
852+
}]>];
854853
}
855854

856855
class OpenACC_DataExitOpNoVarPtr<string mnemonic, string clause> :

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1212
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1313
#include "mlir/IR/Builders.h"
14+
#include "mlir/IR/BuiltinAttributes.h"
1415
#include "mlir/IR/BuiltinTypes.h"
1516
#include "mlir/IR/DialectImplementation.h"
1617
#include "mlir/IR/Matchers.h"
1718
#include "mlir/IR/OpImplementation.h"
1819
#include "mlir/Transforms/DialectConversion.h"
1920
#include "llvm/ADT/SmallSet.h"
2021
#include "llvm/ADT/TypeSwitch.h"
22+
#include "llvm/Support/LogicalResult.h"
2123

2224
using namespace mlir;
2325
using namespace acc;
@@ -190,6 +192,48 @@ static LogicalResult checkWaitAndAsyncConflict(Op op) {
190192
return success();
191193
}
192194

195+
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
196+
mlir::Type &varPtrType,
197+
mlir::TypeAttr &varTypeAttr) {
198+
if (failed(parser.parseType(varPtrType)))
199+
return failure();
200+
if (failed(parser.parseRParen()))
201+
return failure();
202+
203+
if (succeeded(parser.parseOptionalKeyword("varType"))) {
204+
if (failed(parser.parseLParen()))
205+
return failure();
206+
mlir::Type varType;
207+
if (failed(parser.parseType(varType)))
208+
return failure();
209+
varTypeAttr = mlir::TypeAttr::get(varType);
210+
if (failed(parser.parseRParen()))
211+
return failure();
212+
} else {
213+
// Set `varType` from the element type of the type of `varPtr`.
214+
varTypeAttr = mlir::TypeAttr::get(
215+
mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
216+
}
217+
218+
return success();
219+
}
220+
221+
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
222+
mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
223+
p.printType(varPtrType);
224+
p << ")";
225+
226+
// Print the `varType` only if it differs from the element type of
227+
// `varPtr`'s type.
228+
mlir::Type varType = varTypeAttr.getValue();
229+
if (mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType() !=
230+
varType) {
231+
p << " varType(";
232+
p.printType(varType);
233+
p << ")";
234+
}
235+
}
236+
193237
//===----------------------------------------------------------------------===//
194238
// DataBoundsOp
195239
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)