-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][acc] Introduce varType to acc data clause operations #119007
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The acc data clause operations hold an operand named `varPtr`. This was intended to hold a pointer to a variable - where the element type of that pointer specifies the type of the variable. However, for both memref and llvm dialects, this assumption is not true. This is because memref element type for cases like memref<10xf32> is simply f32 and for LLVM, after opaque pointers, the variable type is no longer recoverable. Thus, introduce varType to ensure that appropriate semantics are kept. Both the parser and printer for this new type attribute allow it to not be specified in cases where a dialect's getElementType() applied to `varPtr`'s type has a recoverable type. And more specifically, for FIR, no changes are needed in the MLIR unit tests.
@llvm/pr-subscribers-mlir-openacc @llvm/pr-subscribers-mlir Author: Razvan Lupusoru (razvanlupusoru) ChangesThe acc data clause operations hold an operand named Thus, introduce varType to ensure that appropriate semantics are kept. Both the parser and printer for this new type attribute allow it to not be specified in cases where a dialect's getElementType() applied to Patch is 71.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119007.diff 9 Files Affected:
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 878dccc4ecbc4b..75dcf6ec3e1107 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -139,6 +139,8 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
op.setStructured(structured);
op.setImplicit(implicit);
op.setDataClause(dataClause);
+ op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType())
+ .getElementType());
op->setAttr(Op::getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr(operandSegments));
if (!asyncDeviceTypes.empty())
@@ -266,8 +268,8 @@ static void createDeclareDeallocFuncWithArg(
if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
- entryOp.getVarPtr(), entryOp.getBounds(),
- entryOp.getAsyncOperands(),
+ entryOp.getVarPtr(), entryOp.getVarType(),
+ entryOp.getBounds(), entryOp.getAsyncOperands(),
entryOp.getAsyncOperandsDeviceTypeAttr(),
entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
/*structured=*/false, /*implicit=*/false,
@@ -450,7 +452,7 @@ static void genDataExitOperations(fir::FirOpBuilder &builder,
std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
builder.create<ExitOp>(
entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
- entryOp.getBounds(), entryOp.getAsyncOperands(),
+ entryOp.getVarType(), entryOp.getBounds(), entryOp.getAsyncOperands(),
entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
entryOp.getDataClause(), structured, entryOp.getImplicit(),
builder.getStringAttr(*entryOp.getName()));
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 8d7e27405cfa46..d089519d7fd808 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -381,17 +381,18 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
OpenACC_Op<mnemonic, !listconcat(traits,
[AttrSizedOperandSegments,
MemoryEffects<[MemRead<OpenACC_CurrentDeviceIdResource>]>])> {
- let arguments = !con(additionalArgs,
- (ins
- Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
- Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
- Variadic<IntOrIndex>:$asyncOperands,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
- OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
- DefaultValuedAttr<OpenACC_DataClauseAttr,clause>:$dataClause,
- DefaultValuedAttr<BoolAttr, "true">:$structured,
- DefaultValuedAttr<BoolAttr, "false">:$implicit,
- OptionalAttr<StrAttr>:$name));
+ let arguments = !con(
+ additionalArgs,
+ (ins TypeAttr:$varType,
+ Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
+ Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+ DefaultValuedAttr<OpenACC_DataClauseAttr, clause>:$dataClause,
+ DefaultValuedAttr<BoolAttr, "true">:$structured,
+ DefaultValuedAttr<BoolAttr, "false">:$implicit,
+ OptionalAttr<StrAttr>:$name));
let description = !strconcat(extraDescription, [{
Description of arguments:
@@ -458,7 +459,7 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
}];
let assemblyFormat = [{
- `varPtr` `(` $varPtr `:` type($varPtr) `)`
+ `varPtr` `(` $varPtr `:` custom<varPtrTypes>(type($varPtr), $varType) `)`
oilist(
`varPtrPtr` `(` $varPtrPtr `:` type($varPtrPtr) `)`
| `bounds` `(` $bounds `)`
@@ -469,32 +470,35 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
let hasVerifier = 1;
- let builders = [
- OpBuilder<(ins "::mlir::Value":$varPtr,
- "bool":$structured,
- "bool":$implicit,
- CArg<"::mlir::ValueRange", "{}">:$bounds), [{
- build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
- bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
+ let builders = [OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
+ "bool":$implicit,
+ CArg<"::mlir::ValueRange", "{}">:$bounds),
+ [{
+ build($_builder, $_state, varPtr.getType(), varPtr,
+ /*varType=*/::mlir::TypeAttr::get(
+ ::mlir::cast<::mlir::acc::PointerLikeType>(
+ varPtr.getType()).getElementType()),
+ /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
+ /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
/*structured=*/$_builder.getBoolAttr(structured),
/*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
- }]
- >,
- OpBuilder<(ins "::mlir::Value":$varPtr,
- "bool":$structured,
- "bool":$implicit,
- "const ::llvm::Twine &":$name,
- CArg<"::mlir::ValueRange", "{}">:$bounds), [{
- build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
- bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
+ "bool":$implicit, "const ::llvm::Twine &":$name,
+ CArg<"::mlir::ValueRange", "{}">:$bounds),
+ [{
+ build($_builder, $_state, varPtr.getType(), varPtr,
+ /*varType=*/::mlir::TypeAttr::get(
+ ::mlir::cast<::mlir::acc::PointerLikeType>(
+ varPtr.getType()).getElementType()),
+ /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
+ /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
/*structured=*/$_builder.getBoolAttr(structured),
/*implicit=*/$_builder.getBoolAttr(implicit),
/*name=*/$_builder.getStringAttr(name));
- }]
- >
- ];
+ }]>];
}
//===----------------------------------------------------------------------===//
@@ -794,63 +798,58 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
}
}];
- let assemblyFormat = [{
- `accPtr` `(` $accPtr `:` type($accPtr) `)`
- oilist(
- `bounds` `(` $bounds `)`
- | `to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
- | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
- type($asyncOperands), $asyncOperandsDeviceType) `)`
- ) attr-dict
- }];
-
let hasVerifier = 1;
}
-class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause> :
- OpenACC_DataExitOp<mnemonic, clause,
- "- `varPtr`: The address of variable to copy back to.",
- [MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
- MemWrite<OpenACC_RuntimeCounters>]>],
- (ins Arg<OpenACC_PointerLikeTypeInterface,"Address of device variable",[MemRead]>:$accPtr,
- Arg<OpenACC_PointerLikeTypeInterface,"Address of variable",[MemWrite]>:$varPtr)> {
+class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause>
+ : OpenACC_DataExitOp<
+ mnemonic, clause,
+ "- `varPtr`: The address of variable to copy back to.",
+ [MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
+ MemWrite<OpenACC_RuntimeCounters>]>],
+ (ins Arg<OpenACC_PointerLikeTypeInterface,
+ "Address of device variable", [MemRead]>:$accPtr,
+ Arg<OpenACC_PointerLikeTypeInterface,
+ "Address of variable", [MemWrite]>:$varPtr,
+ TypeAttr:$varType)> {
let assemblyFormat = [{
`accPtr` `(` $accPtr `:` type($accPtr) `)`
(`bounds` `(` $bounds^ `)` )?
(`async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType)^ `)`)?
- `to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
+ `to` `varPtr` `(` $varPtr `:` custom<varPtrTypes>(type($varPtr), $varType) `)`
attr-dict
}];
- let builders = [
- OpBuilder<(ins "::mlir::Value":$accPtr,
- "::mlir::Value":$varPtr,
- "bool":$structured,
- "bool":$implicit,
- CArg<"::mlir::ValueRange", "{}">:$bounds), [{
+ let builders = [OpBuilder<(ins "::mlir::Value":$accPtr,
+ "::mlir::Value":$varPtr, "bool":$structured,
+ "bool":$implicit,
+ CArg<"::mlir::ValueRange", "{}">:$bounds),
+ [{
build($_builder, $_state, accPtr, varPtr,
+ /*varType=*/::mlir::TypeAttr::get(
+ ::mlir::cast<::mlir::acc::PointerLikeType>(
+ varPtr.getType()).getElementType()),
bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
/*structured=*/$_builder.getBoolAttr(structured),
/*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
- }]
- >,
- OpBuilder<(ins "::mlir::Value":$accPtr,
- "::mlir::Value":$varPtr,
- "bool":$structured,
- "bool":$implicit,
- "const ::llvm::Twine &":$name,
- CArg<"::mlir::ValueRange", "{}">:$bounds), [{
+ }]>,
+ OpBuilder<(ins "::mlir::Value":$accPtr,
+ "::mlir::Value":$varPtr, "bool":$structured,
+ "bool":$implicit, "const ::llvm::Twine &":$name,
+ CArg<"::mlir::ValueRange", "{}">:$bounds),
+ [{
build($_builder, $_state, accPtr, varPtr,
+ /*varType=*/::mlir::TypeAttr::get(
+ ::mlir::cast<::mlir::acc::PointerLikeType>(
+ varPtr.getType()).getElementType()),
bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
/*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
/*structured=*/$_builder.getBoolAttr(structured),
/*implicit=*/$_builder.getBoolAttr(implicit),
/*name=*/$_builder.getStringAttr(name));
- }]
- >
- ];
+ }]>];
}
class OpenACC_DataExitOpNoVarPtr<string mnemonic, string clause> :
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 280260e0485bb5..4daba2679bd91c 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
@@ -18,6 +19,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
using namespace mlir;
using namespace acc;
@@ -190,6 +192,43 @@ static LogicalResult checkWaitAndAsyncConflict(Op op) {
return success();
}
+static ParseResult parsevarPtrTypes(mlir::OpAsmParser &parser,
+ mlir::Type &varPtrRawType,
+ mlir::TypeAttr &varTypeAttr) {
+ if (failed(parser.parseType(varPtrRawType))) {
+ return failure();
+ }
+
+ // If there is no comma, it means that the varType is implied from the
+ // element type of varPtr.
+ if (succeeded(parser.parseOptionalComma())) {
+ mlir::Type varType;
+ if (failed(parser.parseType(varType)))
+ return failure();
+ varTypeAttr = mlir::TypeAttr::get(varType);
+ } else {
+ varTypeAttr = mlir::TypeAttr::get(
+ mlir::cast<mlir::acc::PointerLikeType>(varPtrRawType).getElementType());
+ }
+
+ return success();
+}
+
+static void printvarPtrTypes(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::Type varPtrType,
+ mlir::TypeAttr varTypeAttr) {
+ p.printType(varPtrType);
+ mlir::Type varType = varTypeAttr.getValue();
+
+ // Avoid printing the varType if it is already captured as the element type
+ // of varPtr's type.
+ if (mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType() !=
+ varType) {
+ p << ", ";
+ p.printType(varType);
+ }
+}
+
//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
index d8e89f64f8bc04..d83baf3df114bf 100644
--- a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
- %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
return
}
@@ -14,7 +14,7 @@ func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
// -----
func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
- %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
acc.delete accPtr(%0 : memref<f32>)
return
@@ -28,7 +28,7 @@ func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
// -----
func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
- %0 = acc.update_device varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.update_device varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.update if(%ifCond) dataOperands(%0 : memref<f32>)
return
}
@@ -42,7 +42,7 @@ func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
func.func @update_true(%arg0: memref<f32>) {
%true = arith.constant true
- %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+ %0 = acc.update_device varPtr(%arg0 : memref<f32>, f32) -> memref<f32>
acc.update if(%true) dataOperands(%0 : memref<f32>)
return
}
@@ -55,7 +55,7 @@ func.func @update_true(%arg0: memref<f32>) {
func.func @update_false(%arg0: memref<f32>) {
%false = arith.constant false
- %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+ %0 = acc.update_device varPtr(%arg0 : memref<f32>, f32) -> memref<f32>
acc.update if(%false) dataOperands(%0 : memref<f32>)
return
}
@@ -67,7 +67,7 @@ func.func @update_false(%arg0: memref<f32>) {
func.func @enter_data_true(%d1 : memref<f32>) {
%true = arith.constant true
- %0 = acc.create varPtr(%d1 : memref<f32>) -> memref<f32>
+ %0 = acc.create varPtr(%d1 : memref<f32>, f32) -> memref<f32>
acc.enter_data if(%true) dataOperands(%0 : memref<f32>) attributes {async}
return
}
@@ -80,7 +80,7 @@ func.func @enter_data_true(%d1 : memref<f32>) {
func.func @enter_data_false(%d1 : memref<f32>) {
%false = arith.constant false
- %0 = acc.create varPtr(%d1 : memref<f32>) -> memref<f32>
+ %0 = acc.create varPtr(%d1 : memref<f32>, f32) -> memref<f32>
acc.enter_data if(%false) dataOperands(%0 : memref<f32>) attributes {async}
return
}
@@ -92,7 +92,7 @@ func.func @enter_data_false(%d1 : memref<f32>) {
func.func @exit_data_true(%d1 : memref<f32>) {
%true = arith.constant true
- %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>) -> memref<f32>
+ %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>, f32) -> memref<f32>
acc.exit_data if(%true) dataOperands(%0 : memref<f32>) attributes {async}
acc.delete accPtr(%0 : memref<f32>)
return
@@ -106,7 +106,7 @@ func.func @exit_data_true(%d1 : memref<f32>) {
func.func @exit_data_false(%d1 : memref<f32>) {
%false = arith.constant false
- %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>) -> memref<f32>
+ %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>, f32) -> memref<f32>
acc.exit_data if(%false) dataOperands(%0 : memref<f32>) attributes {async}
acc.delete accPtr(%0 : memref<f32>)
return
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
index e43a27f6b9e89a..c5272f579c1d23 100644
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -2,7 +2,7 @@
func.func @testenterdataop(%a: memref<f32>) -> () {
%ifCond = arith.constant true
- %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
return
}
@@ -13,7 +13,7 @@ func.func @testenterdataop(%a: memref<f32>) -> () {
func.func @testenterdataop(%a: memref<f32>) -> () {
%ifCond = arith.constant false
- %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
return
}
@@ -25,7 +25,7 @@ func.func @testenterdataop(%a: memref<f32>) -> () {
func.func @testexitdataop(%a: memref<f32>) -> () {
%ifCond = arith.constant true
- %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
acc.delete accPtr(%0 : memref<f32>)
return
@@ -37,7 +37,7 @@ func.func @testexitdataop(%a: memref<f32>) -> () {
func.func @testexitdataop(%a: memref<f32>) -> () {
%ifCond = arith.constant false
- %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
acc.delete accPtr(%0 : memref<f32>)
return
@@ -49,8 +49,8 @@ func.func @testexitdataop(%a: memref<f32>) -> () {
// -----
func.func @testupdateop(%a: memref<f32>) -> () {
- %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
- acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
+ %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
+ acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>, f32)
%ifCond = arith.constant true
acc.update if(%ifCond) dataOperands(%0: memref<f32>)
return
@@ -61,8 +61,8 @@ func.func @testupdateop(%a: memref<f32>) -> () {
// -----
func.func @testupdateop(%a: memref<f32>) -> () {
- %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
- acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
+ %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
+ acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>, f32)
%ifCond = arith.constant false
acc.update if(%ifCond) dataOperands(%0: memref<f32>)
return
@@ -74,7 +74,7 @@ func.func @testupdateop(%a: memref<f32>) -> () {
// -----
func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
- %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
return
}
@@ -85,7 +85,7 @@ func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
// -----
func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
- %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+ %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
acc.delete accPtr(%0 : memref<f32...
[truncated]
|
Note that the invasive changes in the .td file is because the latest git-clang-format now handles .td files. It formats the context around my new changes. |
Mostly looks good for me Razvan and it makes sense. Just thinking if we can have a syntax that is more obvious. Maybe a keyword would help? Or just keep it printed as an attribute? I'm fine going with this and refine if we think it needs to be. |
What syntax do you have in mind? |
|
I attempted to address your concern by adding the "varType" keyword to the printing. It is still optional (if the element's type is encoded in the type of varPtr's target), but at least this fixes the ambiguity when reading the IR. What do you think? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Thanks for addressing my comment.
The acc data clause operations hold an operand named
varPtr
. This was intended to hold a pointer to a variable - where the element type of that pointer specifies the type of the variable. However, for both memref and llvm dialects, this assumption is not true. This is because memref element type for cases like memref<10xf32> is simply f32 and for LLVM, after opaque pointers, the variable type is no longer recoverable.Thus, introduce varType to ensure that appropriate semantics are kept.
Both the parser and printer for this new type attribute allow it to not be specified in cases where a dialect's getElementType() applied to
varPtr
's type has a recoverable type. And more specifically, for FIR, no changes are needed in the MLIR unit tests.