Skip to content

[MLIR][acc] Introduce varType to acc data clause operations #119007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 9, 2024

Conversation

razvanlupusoru
Copy link
Contributor

The acc data clause operations hold an operand named varPtr. This was intended to hold a pointer to a variable - where the element type of that pointer specifies the type of the variable. However, for both memref and llvm dialects, this assumption is not true. This is because memref element type for cases like memref<10xf32> is simply f32 and for LLVM, after opaque pointers, the variable type is no longer recoverable.

Thus, introduce varType to ensure that appropriate semantics are kept.

Both the parser and printer for this new type attribute allow it to not be specified in cases where a dialect's getElementType() applied to varPtr's type has a recoverable type. And more specifically, for FIR, no changes are needed in the MLIR unit tests.

The acc data clause operations hold an operand named `varPtr`. This
was intended to hold a pointer to a variable - where the element type
of that pointer specifies the type of the variable. However, for
both memref and llvm dialects, this assumption is not true. This is
because memref element type for cases like memref<10xf32> is simply
f32 and for LLVM, after opaque pointers, the variable type is no longer
recoverable.

Thus, introduce varType to ensure that appropriate semantics are kept.

Both the parser and printer for this new type attribute allow it to
not be specified in cases where a dialect's getElementType() applied
to `varPtr`'s type has a recoverable type. And more specifically, for
FIR, no changes are needed in the MLIR unit tests.
@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

Changes

The acc data clause operations hold an operand named varPtr. This was intended to hold a pointer to a variable - where the element type of that pointer specifies the type of the variable. However, for both memref and llvm dialects, this assumption is not true. This is because memref element type for cases like memref<10xf32> is simply f32 and for LLVM, after opaque pointers, the variable type is no longer recoverable.

Thus, introduce varType to ensure that appropriate semantics are kept.

Both the parser and printer for this new type attribute allow it to not be specified in cases where a dialect's getElementType() applied to varPtr's type has a recoverable type. And more specifically, for FIR, no changes are needed in the MLIR unit tests.


Patch is 71.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119007.diff

9 Files Affected:

  • (modified) flang/lib/Lower/OpenACC.cpp (+5-3)
  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+65-66)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+39)
  • (modified) mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir (+9-9)
  • (modified) mlir/test/Dialect/OpenACC/canonicalize.mlir (+14-14)
  • (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+7-7)
  • (modified) mlir/test/Dialect/OpenACC/legalize-data.mlir (+16-16)
  • (modified) mlir/test/Dialect/OpenACC/ops.mlir (+113-113)
  • (modified) mlir/test/Target/LLVMIR/openacc-llvm.mlir (+10-10)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 878dccc4ecbc4b..75dcf6ec3e1107 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -139,6 +139,8 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
   op.setStructured(structured);
   op.setImplicit(implicit);
   op.setDataClause(dataClause);
+  op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType())
+                    .getElementType());
   op->setAttr(Op::getOperandSegmentSizeAttr(),
               builder.getDenseI32ArrayAttr(operandSegments));
   if (!asyncDeviceTypes.empty())
@@ -266,8 +268,8 @@ static void createDeclareDeallocFuncWithArg(
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
                 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
     builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
-                           entryOp.getVarPtr(), entryOp.getBounds(),
-                           entryOp.getAsyncOperands(),
+                           entryOp.getVarPtr(), entryOp.getVarType(),
+                           entryOp.getBounds(), entryOp.getAsyncOperands(),
                            entryOp.getAsyncOperandsDeviceTypeAttr(),
                            entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
                            /*structured=*/false, /*implicit=*/false,
@@ -450,7 +452,7 @@ static void genDataExitOperations(fir::FirOpBuilder &builder,
                   std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
       builder.create<ExitOp>(
           entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
-          entryOp.getBounds(), entryOp.getAsyncOperands(),
+          entryOp.getVarType(), entryOp.getBounds(), entryOp.getAsyncOperands(),
           entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
           entryOp.getDataClause(), structured, entryOp.getImplicit(),
           builder.getStringAttr(*entryOp.getName()));
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 8d7e27405cfa46..d089519d7fd808 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -381,17 +381,18 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
     OpenACC_Op<mnemonic, !listconcat(traits,
         [AttrSizedOperandSegments,
          MemoryEffects<[MemRead<OpenACC_CurrentDeviceIdResource>]>])> {
-  let arguments = !con(additionalArgs,
-                      (ins
-                       Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
-                       Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
-                       Variadic<IntOrIndex>:$asyncOperands,
-                       OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
-                       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
-                       DefaultValuedAttr<OpenACC_DataClauseAttr,clause>:$dataClause,
-                       DefaultValuedAttr<BoolAttr, "true">:$structured,
-                       DefaultValuedAttr<BoolAttr, "false">:$implicit,
-                       OptionalAttr<StrAttr>:$name));
+  let arguments = !con(
+      additionalArgs,
+      (ins TypeAttr:$varType,
+          Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
+          Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
+          Variadic<IntOrIndex>:$asyncOperands,
+          OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+          OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+          DefaultValuedAttr<OpenACC_DataClauseAttr, clause>:$dataClause,
+          DefaultValuedAttr<BoolAttr, "true">:$structured,
+          DefaultValuedAttr<BoolAttr, "false">:$implicit,
+          OptionalAttr<StrAttr>:$name));
 
   let description = !strconcat(extraDescription, [{
     Description of arguments:
@@ -458,7 +459,7 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
   }];
 
   let assemblyFormat = [{
-    `varPtr` `(` $varPtr `:` type($varPtr) `)`
+    `varPtr` `(` $varPtr `:` custom<varPtrTypes>(type($varPtr), $varType)  `)`
     oilist(
         `varPtrPtr` `(` $varPtrPtr `:` type($varPtrPtr) `)`
       | `bounds` `(` $bounds `)`
@@ -469,32 +470,35 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
 
   let hasVerifier = 1;
 
-  let builders = [
-    OpBuilder<(ins "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
-        build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
-          bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
+  let builders = [OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
+        build($_builder, $_state, varPtr.getType(), varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+            ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
+          /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
+          /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
-      }]
-    >,
-    OpBuilder<(ins "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      "const ::llvm::Twine &":$name,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
-        build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
-          bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
+      }]>,
+                  OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit, "const ::llvm::Twine &":$name,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
+        build($_builder, $_state, varPtr.getType(), varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+            ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
+          /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
+          /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit),
           /*name=*/$_builder.getStringAttr(name));
-      }]
-    >
-  ];
+      }]>];
 }
 
 //===----------------------------------------------------------------------===//
@@ -794,63 +798,58 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
     }
   }];
 
-  let assemblyFormat = [{
-    `accPtr` `(` $accPtr `:` type($accPtr) `)`
-    oilist(
-        `bounds` `(` $bounds `)`
-      | `to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
-      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
-            type($asyncOperands), $asyncOperandsDeviceType) `)`
-    ) attr-dict
-  }];
-
   let hasVerifier = 1;
 }
 
-class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause> :
-    OpenACC_DataExitOp<mnemonic, clause,
-      "- `varPtr`: The address of variable to copy back to.",
-      [MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
-                      MemWrite<OpenACC_RuntimeCounters>]>],
-      (ins Arg<OpenACC_PointerLikeTypeInterface,"Address of device variable",[MemRead]>:$accPtr,
-           Arg<OpenACC_PointerLikeTypeInterface,"Address of variable",[MemWrite]>:$varPtr)> {
+class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause>
+    : OpenACC_DataExitOp<
+          mnemonic, clause,
+          "- `varPtr`: The address of variable to copy back to.",
+          [MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
+                          MemWrite<OpenACC_RuntimeCounters>]>],
+          (ins Arg<OpenACC_PointerLikeTypeInterface,
+                   "Address of device variable", [MemRead]>:$accPtr,
+              Arg<OpenACC_PointerLikeTypeInterface,
+                  "Address of variable", [MemWrite]>:$varPtr,
+              TypeAttr:$varType)> {
   let assemblyFormat = [{
     `accPtr` `(` $accPtr `:` type($accPtr) `)`
     (`bounds` `(` $bounds^ `)` )?
     (`async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType)^ `)`)?
-    `to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
+    `to` `varPtr` `(` $varPtr `:` custom<varPtrTypes>(type($varPtr), $varType) `)`
     attr-dict
   }];
 
-  let builders = [
-    OpBuilder<(ins "::mlir::Value":$accPtr,
-      "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
+  let builders = [OpBuilder<(ins "::mlir::Value":$accPtr,
+                                "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
         build($_builder, $_state, accPtr, varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+          ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
           bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
-      }]
-    >,
-    OpBuilder<(ins "::mlir::Value":$accPtr,
-      "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      "const ::llvm::Twine &":$name,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
+      }]>,
+                  OpBuilder<(ins "::mlir::Value":$accPtr,
+                                "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit, "const ::llvm::Twine &":$name,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
         build($_builder, $_state, accPtr, varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+          ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
           bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit),
           /*name=*/$_builder.getStringAttr(name));
-      }]
-    >
-  ];
+      }]>];
 }
 
 class OpenACC_DataExitOpNoVarPtr<string mnemonic, string clause> :
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 280260e0485bb5..4daba2679bd91c 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -18,6 +19,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace acc;
@@ -190,6 +192,43 @@ static LogicalResult checkWaitAndAsyncConflict(Op op) {
   return success();
 }
 
+static ParseResult parsevarPtrTypes(mlir::OpAsmParser &parser,
+                                    mlir::Type &varPtrRawType,
+                                    mlir::TypeAttr &varTypeAttr) {
+  if (failed(parser.parseType(varPtrRawType))) {
+    return failure();
+  }
+
+  // If there is no comma, it means that the varType is implied from the
+  // element type of varPtr.
+  if (succeeded(parser.parseOptionalComma())) {
+    mlir::Type varType;
+    if (failed(parser.parseType(varType)))
+      return failure();
+    varTypeAttr = mlir::TypeAttr::get(varType);
+  } else {
+    varTypeAttr = mlir::TypeAttr::get(
+        mlir::cast<mlir::acc::PointerLikeType>(varPtrRawType).getElementType());
+  }
+
+  return success();
+}
+
+static void printvarPtrTypes(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                             mlir::Type varPtrType,
+                             mlir::TypeAttr varTypeAttr) {
+  p.printType(varPtrType);
+  mlir::Type varType = varTypeAttr.getValue();
+
+  // Avoid printing the varType if it is already captured as the element type
+  // of varPtr's type.
+  if (mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType() !=
+      varType) {
+    p << ", ";
+    p.printType(varType);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // DataBoundsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
index d8e89f64f8bc04..d83baf3df114bf 100644
--- a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
 
 func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -14,7 +14,7 @@ func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
 // -----
 
 func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -28,7 +28,7 @@ func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
 // -----
 
 func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.update_device varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.update_device varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.update if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -42,7 +42,7 @@ func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
 
 func.func @update_true(%arg0: memref<f32>) {
   %true = arith.constant true
-  %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+  %0 = acc.update_device varPtr(%arg0 : memref<f32>, f32) -> memref<f32>
   acc.update if(%true) dataOperands(%0 : memref<f32>)
   return
 }
@@ -55,7 +55,7 @@ func.func @update_true(%arg0: memref<f32>) {
 
 func.func @update_false(%arg0: memref<f32>) {
   %false = arith.constant false
-  %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+  %0 = acc.update_device varPtr(%arg0 : memref<f32>, f32) -> memref<f32>
   acc.update if(%false) dataOperands(%0 : memref<f32>)
   return
 }
@@ -67,7 +67,7 @@ func.func @update_false(%arg0: memref<f32>) {
 
 func.func @enter_data_true(%d1 : memref<f32>) {
   %true = arith.constant true
-  %0 = acc.create varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%true) dataOperands(%0 : memref<f32>) attributes {async}
   return
 }
@@ -80,7 +80,7 @@ func.func @enter_data_true(%d1 : memref<f32>) {
 
 func.func @enter_data_false(%d1 : memref<f32>) {
   %false = arith.constant false
-  %0 = acc.create varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%false) dataOperands(%0 : memref<f32>) attributes {async}
   return
 }
@@ -92,7 +92,7 @@ func.func @enter_data_false(%d1 : memref<f32>) {
 
 func.func @exit_data_true(%d1 : memref<f32>) {
   %true = arith.constant true
-  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%true) dataOperands(%0 : memref<f32>) attributes {async}
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -106,7 +106,7 @@ func.func @exit_data_true(%d1 : memref<f32>) {
 
 func.func @exit_data_false(%d1 : memref<f32>) {
   %false = arith.constant false
-  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%false) dataOperands(%0 : memref<f32>) attributes {async}
   acc.delete accPtr(%0 : memref<f32>)
   return
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
index e43a27f6b9e89a..c5272f579c1d23 100644
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -2,7 +2,7 @@
 
 func.func @testenterdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant true
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -13,7 +13,7 @@ func.func @testenterdataop(%a: memref<f32>) -> () {
 
 func.func @testenterdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant false
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -25,7 +25,7 @@ func.func @testenterdataop(%a: memref<f32>) -> () {
 
 func.func @testexitdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant true
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -37,7 +37,7 @@ func.func @testexitdataop(%a: memref<f32>) -> () {
 
 func.func @testexitdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant false
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -49,8 +49,8 @@ func.func @testexitdataop(%a: memref<f32>) -> () {
 // -----
 
 func.func @testupdateop(%a: memref<f32>) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
-  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
+  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>, f32)
   %ifCond = arith.constant true
   acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   return
@@ -61,8 +61,8 @@ func.func @testupdateop(%a: memref<f32>) -> () {
 // -----
 
 func.func @testupdateop(%a: memref<f32>) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
-  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
+  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>, f32)
   %ifCond = arith.constant false
   acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   return
@@ -74,7 +74,7 @@ func.func @testupdateop(%a: memref<f32>) -> () {
 // -----
 
 func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -85,7 +85,7 @@ func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
 // -----
 
 func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32...
[truncated]

@razvanlupusoru
Copy link
Contributor Author

Note that the invasive changes in the .td file is because the latest git-clang-format now handles .td files. It formats the context around my new changes.

@clementval
Copy link
Contributor

clementval commented Dec 8, 2024

Mostly looks good for me Razvan and it makes sense. Just thinking if we can have a syntax that is more obvious. Maybe a keyword would help? Or just keep it printed as an attribute? I'm fine going with this and refine if we think it needs to be.

@razvanlupusoru
Copy link
Contributor Author

Mostly looks good for me Razvan and it makes sense. Just thinking if we can have a syntax that is more obvious. Maybe a keyword would help? Or just keep it printed as an attribute? I'm fine going with this and refine if we think it needs to be.

What syntax do you have in mind?

@clementval
Copy link
Contributor

Mostly looks good for me Razvan and it makes sense. Just thinking if we can have a syntax that is more obvious. Maybe a keyword would help? Or just keep it printed as an attribute? I'm fine going with this and refine if we think it needs to be.

What syntax do you have in mind?

  • Maybe just leave the standard syntax with the attribute would be clearer but not a strong opinion on that. As I didn't really think about this the syntax proposed didn't feel really obvious but maybe another one would also be ambiguous. Adding a keyword might help but as I mentioned that is just over thinking and could be dealt with in follow up patches.

@Dinistro Dinistro changed the title [acc] Introduce varType to acc data clause operations [MLIR][acc] Introduce varType to acc data clause operations Dec 9, 2024
@razvanlupusoru
Copy link
Contributor Author

* Maybe just leave the standard syntax with the attribute would be clearer but not a strong opinion on that. As I didn't really think about this the syntax proposed didn't feel really obvious but maybe another one would also be ambiguous. Adding a keyword might help but as I mentioned that is just over thinking and could be dealt with in follow up patches.

I attempted to address your concern by adding the "varType" keyword to the printing. It is still optional (if the element's type is encoded in the type of varPtr's target), but at least this fixes the ambiguity when reading the IR. What do you think?

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Thanks for addressing my comment.

@razvanlupusoru razvanlupusoru merged commit a0eb794 into llvm:main Dec 9, 2024
6 of 7 checks passed
@razvanlupusoru razvanlupusoru deleted the accvartype1 branch December 9, 2024 23:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:llvm mlir:openacc mlir openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants