Skip to content

[mlir][Linalg] Fix Linalg behavior in the context of vector elemental… #71041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 2, 2023

Conversation

nicolasvasilache
Copy link
Contributor

… types

@llvmbot
Copy link
Member

llvmbot commented Nov 2, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Nicolas Vasilache (nicolasvasilache)

Changes

… types


Full diff: https://github.com/llvm/llvm-project/pull/71041.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+22-6)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+8-5)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+20)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+11)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 69ca888a8acdbe0..fbf3f19cde0e9b8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -344,7 +344,7 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the `opOperand` rank or zero for scalars.
+        Return the `opOperand` rank or zero for scalars or vectors not wrapped within a tensor or a memref.
       }],
       /*retTy=*/"int64_t",
       /*methodName=*/"getRank",
@@ -352,9 +352,17 @@ def LinalgStructuredInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        if (auto shapedType =
-              ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type, do not consider its rank for the operand.
+        if (isa<VectorType>(t))
+          return 0;
+        // Tensor and Memref container types have a rank.
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          // Failsafe.
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgInterface::getRank");
           return shapedType.getRank();
+        }
         return 0;
       }]
     >,
@@ -384,7 +392,8 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the `opOperand` shape or an empty vector for scalars.
+        Return the `opOperand` shape or an empty vector for scalars or vectors
+        not wrapped within a tensor or a memref.
       }],
       /*retTy=*/"ArrayRef<int64_t>",
       /*methodName=*/"getShape",
@@ -392,9 +401,16 @@ def LinalgStructuredInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        if (auto shapedType =
-              ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type, do not consider its rank for the operand.
+        if (isa<VectorType>(t))
+          return {};
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          // Failsafe.
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgInterface::getRank");
           return shapedType.getShape();
+        }
         return {};
       }]
     >,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dfd6b991e7da159..08d46f236f8ab3b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1130,7 +1130,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
                            "arguments as the number of input/output operands");
 
   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
-    Type elementType = getElementTypeOrSelf(opOperand->get());
+    Type elementType = opOperand->get().getType();
+    if (isa<MemRefType, RankedTensorType>(elementType))
+      elementType = getElementTypeOrSelf(opOperand->get().getType());
     Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
     if (elementType != argType)
       return op->emitOpError("expected type of bb argument #")
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5457d51db1cc180..5a593fbb2b6024d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -122,13 +122,12 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   assert(llvm::all_of(outputTypes,
                       [](Type t) { return llvm::isa<ShapedType>(t); }));
 
-  // TODO: atm all operands go through getElementTypeOrSelf,
-  // reconsider when we have evidence we need to.
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
   for (auto containers : {inputTypes, outputTypes}) {
     for (auto t : containers) {
-      argTypes.push_back(getElementTypeOrSelf(t));
+      argTypes.push_back(
+          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
 
       // TODO: Pass in a proper location here.
       argLocs.push_back(opBuilder.getUnknownLoc());
@@ -826,7 +825,9 @@ static void buildGenericRegion(
   SmallVector<Location, 4> blockArgLocs;
   for (ValueRange container : {inputs, outputs}) {
     for (Value v : container) {
-      blockArgTypes.push_back(getElementTypeOrSelf(v));
+      Type t = v.getType();
+      blockArgTypes.push_back(
+          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
       blockArgLocs.push_back(v.getLoc());
     }
   }
@@ -1927,7 +1928,9 @@ static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
   for (OpOperand &opOperand : op->getOpOperands()) {
     OpOperand *outputOperand =
         linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
-    Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
+    Type elementType = outputOperand->get().getType();
+    if (isa<MemRefType, RankedTensorType>(elementType))
+      elementType = getElementTypeOrSelf(outputOperand->get().getType());
     if (opOperand.get().getType() != elementType)
       return op.emitOpError("type of yield operand ")
              << (opOperand.getOperandNumber() + 1) << " ("
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 2259d47eb2b2b0d..e852824cdb73675 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -587,3 +587,23 @@ func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
 // CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
 // CHECK-NEXT:      %[[max:.+]] = arith.maximumf %[[BBARG0]], %[[BBARG1]] : f32
 // CHECK-NEXT:      linalg.yield %[[max]] : f32
+
+// -----
+
+
+// CHECK-LABEL: func @fill_tensor
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+  %e0 = tensor.empty() : tensor<f32>
+  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+// CHECK: linalg.generic
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      linalg.yield %[[BBARG0]] : f32
+
+  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+// CHECK: linalg.generic
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: vector<2x4xf32>, %[[BBARG1:.+]]: vector<2x4xf32>)
+// CHECK-NEXT:      linalg.yield %[[BBARG0]] : vector<2x4xf32>
+
+  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 76146b17014ebb5..5ca35155854d332 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1585,3 +1585,14 @@ func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
   %1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @fill_tensor
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+  %e0 = tensor.empty() : tensor<f32>
+  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}

@nicolasvasilache nicolasvasilache merged commit a0989a7 into llvm:main Nov 2, 2023
@nicolasvasilache nicolasvasilache deleted the tensor-of-vector branch June 18, 2024 14:09
@nicolasvasilache nicolasvasilache restored the tensor-of-vector branch June 18, 2024 14:09
@nicolasvasilache nicolasvasilache deleted the tensor-of-vector branch June 18, 2024 14:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants