Skip to content

[mlir][EmitC] Add support for pointer and opaque types to subscript op #86266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 3, 2024

Conversation

simon-camp
Copy link
Contributor

For pointer types the indices are restricted to one integer-like operand.
For opaque types no further restrictions are made.

For pointer types the indices are restricted to one integer-like operand.
For opaque types no further restrictions are made.
@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-mlir

Author: Simon Camphausen (simon-camp)

Changes

For pointer types the indices are restricted to one integer-like operand.
For opaque types no further restrictions are made.


Full diff: https://github.com/llvm/llvm-project/pull/86266.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.h (+6)
  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+18-12)
  • (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+4-2)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+51-5)
  • (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+1-1)
  • (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+2-2)
  • (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+42-2)
  • (modified) mlir/test/Dialect/EmitC/ops.mlir (+7)
  • (modified) mlir/test/Target/Cpp/subscript.mlir (+26-6)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 725a1bcb4e6cb1..d2f20b642b26b2 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,8 +30,14 @@
 namespace mlir {
 namespace emitc {
 void buildTerminatedBody(OpBuilder &builder, Location loc);
+
 /// Determines whether \p type is a valid integer type in EmitC.
 bool isSupportedIntegerType(mlir::Type type);
+
+/// Determines whether \p type is integer like, i.e. it's a supported integer,
+/// an index or opaque type.
+bool isIntegerLikeType(Type type);
+
 /// Determines whether \p type is a valid floating-point type in EmitC.
 bool isSupportedFloatType(mlir::Type type);
 } // namespace emitc
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..539a4f3e9805e1 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,35 +1155,41 @@ def EmitC_IfOp : EmitC_Op<"if",
   let hasCustomAssemblyFormat = 1;
 }
 
-def EmitC_SubscriptOp : EmitC_Op<"subscript",
-  [TypesMatchWith<"result type matches element type of 'array'",
-                  "array", "result",
-                  "::llvm::cast<ArrayType>($_self).getElementType()">]> {
-  let summary = "Array subscript operation";
+def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
+  let summary = "Subscript operation";
   let description = [{
     With the `subscript` operation the subscript operator `[]` can be applied
-    to variables or arguments of array type.
+    to variables or arguments of array, pointer and opaque type.
 
     Example:
 
     ```mlir
     %i = index.constant 1
     %j = index.constant 7
-    %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+    %0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
+    %1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
     ```
   }];
-  let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
-                       Variadic<IntegerIndexOrOpaqueType>:$indices);
+  let arguments = (ins Arg<AnyTypeOf<[
+      EmitC_ArrayType,
+      EmitC_OpaqueType,
+      EmitC_PointerType]>,
+    "the reference to load from">:$ref,
+    Variadic<AnyType>:$indices);
   let results = (outs AnyType:$result);
 
   let builders = [
-    OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
-      build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+    OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
+      build($_builder, $_state, array.getType().getElementType(), array, indices);
+    }]>,
+    OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
+      build($_builder, $_state, pointer.getType().getPointee(), pointer,
+            ValueRange{index});
     }]>
   ];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+  let assemblyFormat = "$ref `[` $indices `]` attr-dict `:` functional-type(operands, results)";
 }
 
 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..3a2405a6195437 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -63,7 +63,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
     }
 
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), operands.getMemref(), operands.getIndices());
+        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+        operands.getIndices());
 
     auto noInit = emitc::OpaqueAttr::get(getContext(), "");
     auto var =
@@ -83,7 +84,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), operands.getMemref(), operands.getIndices());
+        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+        operands.getIndices());
     rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
                                                  operands.getValue());
     return success();
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f364573552fe97 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
   return false;
 }
 
+bool mlir::emitc::isIntegerLikeType(Type type) {
+  return isSupportedIntegerType(type) ||
+         llvm::isa<IndexType, emitc::OpaqueType>(type);
+}
+
 bool mlir::emitc::isSupportedFloatType(Type type) {
   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
     switch (floatType.getWidth()) {
@@ -781,11 +786,52 @@ LogicalResult emitc::YieldOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult emitc::SubscriptOp::verify() {
-  if (getIndices().size() != (size_t)getArray().getType().getRank()) {
-    return emitOpError() << "requires number of indices ("
-                         << getIndices().size()
-                         << ") to match the rank of the array type ("
-                         << getArray().getType().getRank() << ")";
+  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getRef().getType())) {
+    // Check arity of indices.
+    if (getIndices().size() != (size_t)arrayType.getRank()) {
+      return emitOpError() << "requires number of indices ("
+                           << getIndices().size()
+                           << ") to match the rank of the array type ("
+                           << arrayType.getRank() << ")";
+    }
+    // Check types of index operands.
+    for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+      Type type = getIndices()[i].getType();
+      if (!isIntegerLikeType(type)) {
+        return emitOpError() << "requires index operand " << i
+                             << " to be integer-like, but got " << type;
+      }
+    }
+    // Check element type.
+    Type elementType = arrayType.getElementType();
+    if (elementType != getType()) {
+      return emitOpError() << "requires element type (" << elementType
+                           << ") and result type (" << getType()
+                           << ") to match";
+    }
+  } else if (auto pointerType =
+                 llvm::dyn_cast<emitc::PointerType>(getRef().getType())) {
+    // Check arity of indices.
+    if (getIndices().size() != 1) {
+      return emitOpError() << "requires one index operand, but got "
+                           << getIndices().size();
+    }
+    // Check types of index operand.
+    Type type = getIndices()[0].getType();
+    if (!isIntegerLikeType(type)) {
+      return emitOpError()
+             << "requires index operand to be integer-like, but got " << type;
+    }
+    // Check pointee type.
+    Type pointeeType = pointerType.getPointee();
+    if (pointeeType != getType()) {
+      return emitOpError() << "requires pointee type (" << pointeeType
+                           << ") and result type (" << getType()
+                           << ") to match";
+    }
+  } else {
+    // The reference has opaque type, so we can't assume anything about arity or
+    // types of index operands.
   }
   return success();
 }
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..8fd04b7d1a51e0 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1105,7 +1105,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
   std::string out;
   llvm::raw_string_ostream ss(out);
-  ss << getOrCreateName(op.getArray());
+  ss << getOrCreateName(op.getRef());
   for (auto index : op.getIndices()) {
     ss << "[" << getOrCreateName(index) << "]";
   }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..7aa2ba88843a2a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
   %0 = memref.alloca() : memref<4x8xf32>
 
-  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
   memref.store %v, %0[%i, %j] : memref<4x8xf32>
   return
@@ -19,7 +19,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
   %0 = memref.alloca() : memref<4x8xf32>
 
-  // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
   // CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
   %1 = memref.load %0[%i, %j] : memref<4x8xf32>
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..321e4c01110e82 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -390,8 +390,48 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
 
 // -----
 
-func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
   // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
-  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+  %0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
+  // expected-error @+1 {{'emitc.subscript' op requires index operand 1 to be integer-like, but got 'f32'}}
+  %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires element type ('f32') and result type ('i32') to match}}
+  %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires one index operand, but got 2}}
+  %0 = emitc.subscript %arg0[%arg2, %arg2] : (!emitc.ptr<f32>, index, index) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: f64) {
+  // expected-error @+1 {{'emitc.subscript' op requires index operand to be integer-like, but got 'f64'}}
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, f64) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires pointee type ('f32') and result type ('f64') to match}}
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f64
   return
 }
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..ace3670426afa5 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
   return
 }
 
+func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
+  %0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
+  %1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
+  %2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  return
+}
+
 emitc.verbatim "#ifdef __cplusplus"
 emitc.verbatim "extern \"C\" {"
 emitc.verbatim "#endif  // __cplusplus"
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
index a6c82df9111a79..0b388953c80d37 100644
--- a/mlir/test/Target/Cpp/subscript.mlir
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -1,24 +1,44 @@
 // RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
 // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
 
-func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
-  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
-  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
+  %1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
   emitc.assign %0 : f32 to %1 : f32
   return
 }
-// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
 // CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
 // CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
 
+func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
+  %1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
+  emitc.assign %0 : f32 to %1 : f32
+  return
+}
+// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
+// CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];
+
+func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  %1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
+  return
+}
+// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
+// CHECK-SAME:            char [[I:[^ ]*]], char [[J:[^ ]*]])
+// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];
+
 emitc.func @func1(%arg0 : f32) {
   emitc.return
 }
 
 emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
                      %k: i8) {
-  %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
-  %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+  %0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
+  %1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32
 
   emitc.call @func1 (%0) : (f32) -> ()
   emitc.call_opaque "func2" (%1) : (f32) -> ()

@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-mlir-emitc

Author: Simon Camphausen (simon-camp)

Changes

For pointer types the indices are restricted to one integer-like operand.
For opaque types no further restrictions are made.


Full diff: https://github.com/llvm/llvm-project/pull/86266.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.h (+6)
  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+18-12)
  • (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+4-2)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+51-5)
  • (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+1-1)
  • (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+2-2)
  • (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+42-2)
  • (modified) mlir/test/Dialect/EmitC/ops.mlir (+7)
  • (modified) mlir/test/Target/Cpp/subscript.mlir (+26-6)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 725a1bcb4e6cb1..d2f20b642b26b2 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,8 +30,14 @@
 namespace mlir {
 namespace emitc {
 void buildTerminatedBody(OpBuilder &builder, Location loc);
+
 /// Determines whether \p type is a valid integer type in EmitC.
 bool isSupportedIntegerType(mlir::Type type);
+
+/// Determines whether \p type is integer like, i.e. it's a supported integer,
+/// an index or opaque type.
+bool isIntegerLikeType(Type type);
+
 /// Determines whether \p type is a valid floating-point type in EmitC.
 bool isSupportedFloatType(mlir::Type type);
 } // namespace emitc
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..539a4f3e9805e1 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,35 +1155,41 @@ def EmitC_IfOp : EmitC_Op<"if",
   let hasCustomAssemblyFormat = 1;
 }
 
-def EmitC_SubscriptOp : EmitC_Op<"subscript",
-  [TypesMatchWith<"result type matches element type of 'array'",
-                  "array", "result",
-                  "::llvm::cast<ArrayType>($_self).getElementType()">]> {
-  let summary = "Array subscript operation";
+def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
+  let summary = "Subscript operation";
   let description = [{
     With the `subscript` operation the subscript operator `[]` can be applied
-    to variables or arguments of array type.
+    to variables or arguments of array, pointer and opaque type.
 
     Example:
 
     ```mlir
     %i = index.constant 1
     %j = index.constant 7
-    %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+    %0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
+    %1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
     ```
   }];
-  let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
-                       Variadic<IntegerIndexOrOpaqueType>:$indices);
+  let arguments = (ins Arg<AnyTypeOf<[
+      EmitC_ArrayType,
+      EmitC_OpaqueType,
+      EmitC_PointerType]>,
+    "the reference to load from">:$ref,
+    Variadic<AnyType>:$indices);
   let results = (outs AnyType:$result);
 
   let builders = [
-    OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
-      build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+    OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
+      build($_builder, $_state, array.getType().getElementType(), array, indices);
+    }]>,
+    OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
+      build($_builder, $_state, pointer.getType().getPointee(), pointer,
+            ValueRange{index});
     }]>
   ];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+  let assemblyFormat = "$ref `[` $indices `]` attr-dict `:` functional-type(operands, results)";
 }
 
 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..3a2405a6195437 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -63,7 +63,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
     }
 
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), operands.getMemref(), operands.getIndices());
+        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+        operands.getIndices());
 
     auto noInit = emitc::OpaqueAttr::get(getContext(), "");
     auto var =
@@ -83,7 +84,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), operands.getMemref(), operands.getIndices());
+        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+        operands.getIndices());
     rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
                                                  operands.getValue());
     return success();
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f364573552fe97 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
   return false;
 }
 
+bool mlir::emitc::isIntegerLikeType(Type type) {
+  return isSupportedIntegerType(type) ||
+         llvm::isa<IndexType, emitc::OpaqueType>(type);
+}
+
 bool mlir::emitc::isSupportedFloatType(Type type) {
   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
     switch (floatType.getWidth()) {
@@ -781,11 +786,52 @@ LogicalResult emitc::YieldOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult emitc::SubscriptOp::verify() {
-  if (getIndices().size() != (size_t)getArray().getType().getRank()) {
-    return emitOpError() << "requires number of indices ("
-                         << getIndices().size()
-                         << ") to match the rank of the array type ("
-                         << getArray().getType().getRank() << ")";
+  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getRef().getType())) {
+    // Check arity of indices.
+    if (getIndices().size() != (size_t)arrayType.getRank()) {
+      return emitOpError() << "requires number of indices ("
+                           << getIndices().size()
+                           << ") to match the rank of the array type ("
+                           << arrayType.getRank() << ")";
+    }
+    // Check types of index operands.
+    for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+      Type type = getIndices()[i].getType();
+      if (!isIntegerLikeType(type)) {
+        return emitOpError() << "requires index operand " << i
+                             << " to be integer-like, but got " << type;
+      }
+    }
+    // Check element type.
+    Type elementType = arrayType.getElementType();
+    if (elementType != getType()) {
+      return emitOpError() << "requires element type (" << elementType
+                           << ") and result type (" << getType()
+                           << ") to match";
+    }
+  } else if (auto pointerType =
+                 llvm::dyn_cast<emitc::PointerType>(getRef().getType())) {
+    // Check arity of indices.
+    if (getIndices().size() != 1) {
+      return emitOpError() << "requires one index operand, but got "
+                           << getIndices().size();
+    }
+    // Check types of index operand.
+    Type type = getIndices()[0].getType();
+    if (!isIntegerLikeType(type)) {
+      return emitOpError()
+             << "requires index operand to be integer-like, but got " << type;
+    }
+    // Check pointee type.
+    Type pointeeType = pointerType.getPointee();
+    if (pointeeType != getType()) {
+      return emitOpError() << "requires pointee type (" << pointeeType
+                           << ") and result type (" << getType()
+                           << ") to match";
+    }
+  } else {
+    // The reference has opaque type, so we can't assume anything about arity or
+    // types of index operands.
   }
   return success();
 }
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..8fd04b7d1a51e0 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1105,7 +1105,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
   std::string out;
   llvm::raw_string_ostream ss(out);
-  ss << getOrCreateName(op.getArray());
+  ss << getOrCreateName(op.getRef());
   for (auto index : op.getIndices()) {
     ss << "[" << getOrCreateName(index) << "]";
   }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..7aa2ba88843a2a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
   %0 = memref.alloca() : memref<4x8xf32>
 
-  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
   memref.store %v, %0[%i, %j] : memref<4x8xf32>
   return
@@ -19,7 +19,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
   %0 = memref.alloca() : memref<4x8xf32>
 
-  // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
   // CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
   %1 = memref.load %0[%i, %j] : memref<4x8xf32>
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..321e4c01110e82 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -390,8 +390,48 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
 
 // -----
 
-func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
   // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
-  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+  %0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
+  // expected-error @+1 {{'emitc.subscript' op requires index operand 1 to be integer-like, but got 'f32'}}
+  %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires element type ('f32') and result type ('i32') to match}}
+  %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires one index operand, but got 2}}
+  %0 = emitc.subscript %arg0[%arg2, %arg2] : (!emitc.ptr<f32>, index, index) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: f64) {
+  // expected-error @+1 {{'emitc.subscript' op requires index operand to be integer-like, but got 'f64'}}
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, f64) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires pointee type ('f32') and result type ('f64') to match}}
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f64
   return
 }
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..ace3670426afa5 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
   return
 }
 
+func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
+  %0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
+  %1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
+  %2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  return
+}
+
 emitc.verbatim "#ifdef __cplusplus"
 emitc.verbatim "extern \"C\" {"
 emitc.verbatim "#endif  // __cplusplus"
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
index a6c82df9111a79..0b388953c80d37 100644
--- a/mlir/test/Target/Cpp/subscript.mlir
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -1,24 +1,44 @@
 // RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
 // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
 
-func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
-  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
-  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
+  %1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
   emitc.assign %0 : f32 to %1 : f32
   return
 }
-// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
 // CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
 // CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
 
+func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
+  %1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
+  emitc.assign %0 : f32 to %1 : f32
+  return
+}
+// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
+// CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];
+
+func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  %1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
+  return
+}
+// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
+// CHECK-SAME:            char [[I:[^ ]*]], char [[J:[^ ]*]])
+// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];
+
 emitc.func @func1(%arg0 : f32) {
   emitc.return
 }
 
 emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
                      %k: i8) {
-  %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
-  %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+  %0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
+  %1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32
 
   emitc.call @func1 (%0) : (f32) -> ()
   emitc.call_opaque "func2" (%1) : (f32) -> ()

Copy link
Member

@marbre marbre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I really like to see progress here. Some first comments, need to take a closer look after lunch.

@@ -63,7 +63,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
}

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we allow the builder to take Value and leave the verification to the verifier? This is currently more verbose to write, and when it fails it would assert instead of being a verifier error.

Copy link
Contributor Author

@simon-camp simon-camp Mar 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched to TypedValue to distinguish array and pointer values. These are the builders we have now:

static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypedValue<ArrayType> array, ValueRange indices);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypedValue<PointerType> pointer, Value index);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value value, ::mlir::ValueRange indices);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value value, ::mlir::ValueRange indices);
static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});

I will use the third one and pass the type explicitly.
Instead I now try a dyn_cast and fail gracefully.

@simon-camp simon-camp requested a review from mgehre-amd April 2, 2024 12:29
@simon-camp simon-camp merged commit 1f26809 into llvm:main Apr 3, 2024
@simon-camp simon-camp deleted the emitc.subscript-ptr branch April 3, 2024 11:06
mgehre-amd pushed a commit to Xilinx/llvm-project that referenced this pull request Apr 26, 2024
llvm#86266)

For pointer types the indices are restricted to one integer-like
operand.
For opaque types no further restrictions are made.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants