
Commit a0989a7

[mlir][Linalg] Fix Linalg behavior in the context of vector elemental types (#71041)

1 parent 835c885 commit a0989a7

5 files changed: 64 additions, 12 deletions


mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 22 additions & 6 deletions
@@ -344,17 +344,25 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the `opOperand` rank or zero for scalars.
+        Return the `opOperand` rank or zero for scalars or vectors not wrapped within a tensor or a memref.
       }],
       /*retTy=*/"int64_t",
       /*methodName=*/"getRank",
       /*args=*/(ins "OpOperand*":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        if (auto shapedType =
-              ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type, do not consider its rank for the operand.
+        if (isa<VectorType>(t))
+          return 0;
+        // Tensor and Memref container types have a rank.
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          // Failsafe.
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgInterface::getRank");
           return shapedType.getRank();
+        }
         return 0;
       }]
     >,
@@ -384,17 +392,25 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the `opOperand` shape or an empty vector for scalars.
+        Return the `opOperand` shape or an empty vector for scalars or vectors
+        not wrapped within a tensor or a memref.
       }],
       /*retTy=*/"ArrayRef<int64_t>",
       /*methodName=*/"getShape",
       /*args=*/(ins "OpOperand*":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        if (auto shapedType =
-              ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type, do not consider its rank for the operand.
+        if (isa<VectorType>(t))
+          return {};
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          // Failsafe.
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgInterface::getRank");
           return shapedType.getShape();
+        }
         return {};
       }]
     >,
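Taken together, these two methods now classify a bare vector operand as elemental: getRank returns 0 and getShape returns an empty ArrayRef, exactly as for scalars. A hypothetical example, not taken from this commit's tests and assumed to verify only with this patch applied: the bare vector input uses a zero-result indexing map, while the tensor-of-vector init still contributes the tensor's rank.

// Hypothetical IR, not from this commit's tests; assumed to verify only
// after this patch. The bare vector %v is rank 0 and takes the zero-result
// map; the tensor-of-vector init contributes the tensor's rank (1).
func.func @vector_as_elemental(%v: vector<4xf32>,
                               %init: tensor<8xvector<4xf32>>) -> tensor<8xvector<4xf32>> {
  %0 = linalg.generic
      {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]}
      ins(%v : vector<4xf32>) outs(%init : tensor<8xvector<4xf32>>) {
  ^bb0(%in: vector<4xf32>, %out: vector<4xf32>):
    linalg.yield %in : vector<4xf32>
  } -> tensor<8xvector<4xf32>>
  return %0 : tensor<8xvector<4xf32>>
}

Before this change, getRank went through ShapedType unconditionally, so a bare vector<4xf32> operand reported rank 1 and the scalar-style map above would have been rejected.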

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 3 additions & 1 deletion
@@ -1130,7 +1130,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
                        "arguments as the number of input/output operands");

   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
-    Type elementType = getElementTypeOrSelf(opOperand->get());
+    Type elementType = opOperand->get().getType();
+    if (isa<MemRefType, RankedTensorType>(elementType))
+      elementType = getElementTypeOrSelf(opOperand->get().getType());
     Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
     if (elementType != argType)
       return op->emitOpError("expected type of bb argument #")
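The verifier now unwraps an operand type into its element type only for memref and ranked tensor operands; any other type, a bare vector included, must match the corresponding block argument as-is. A hedged sketch, assumed rather than taken from the commit, of IR that satisfies this rule on the buffer side:

// Hedged sketch, not from this commit: bb arguments for memref-of-vector
// operands are the unwrapped vector element type.
func.func @copy_vectors(%src: memref<4xvector<8xf32>>, %dst: memref<4xvector<8xf32>>) {
  linalg.generic
      {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]}
      ins(%src : memref<4xvector<8xf32>>) outs(%dst : memref<4xvector<8xf32>>) {
  ^bb0(%in: vector<8xf32>, %out: vector<8xf32>):
    linalg.yield %in : vector<8xf32>
  }
  return
}

This memref-of-vector path already worked before the patch, since getElementTypeOrSelf unwraps any ShapedType; the change preserves it while no longer unwrapping bare vectors.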

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 8 additions & 5 deletions
@@ -122,13 +122,12 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   assert(llvm::all_of(outputTypes,
                       [](Type t) { return llvm::isa<ShapedType>(t); }));

-  // TODO: atm all operands go through getElementTypeOrSelf,
-  // reconsider when we have evidence we need to.
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
   for (auto containers : {inputTypes, outputTypes}) {
     for (auto t : containers) {
-      argTypes.push_back(getElementTypeOrSelf(t));
+      argTypes.push_back(
+          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

       // TODO: Pass in a proper location here.
       argLocs.push_back(opBuilder.getUnknownLoc());
@@ -826,7 +825,9 @@ static void buildGenericRegion(
   SmallVector<Location, 4> blockArgLocs;
   for (ValueRange container : {inputs, outputs}) {
     for (Value v : container) {
-      blockArgTypes.push_back(getElementTypeOrSelf(v));
+      Type t = v.getType();
+      blockArgTypes.push_back(
+          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
       blockArgLocs.push_back(v.getLoc());
     }
   }
@@ -1927,7 +1928,9 @@ static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
   for (OpOperand &opOperand : op->getOpOperands()) {
     OpOperand *outputOperand =
         linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
-    Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
+    Type elementType = outputOperand->get().getType();
+    if (isa<MemRefType, RankedTensorType>(elementType))
+      elementType = getElementTypeOrSelf(outputOperand->get().getType());
     if (opOperand.get().getType() != elementType)
       return op.emitOpError("type of yield operand ")
              << (opOperand.getOperandNumber() + 1) << " ("
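All three hunks apply one rule when building or checking region signatures: call getElementTypeOrSelf only on memref and ranked tensor types, and pass every other type through untouched. For a rank-0 tensor-of-vector init this means the region must yield the whole vector. A hypothetical generic illustrating the yield rule, assumed to verify only after the patch:

// Hypothetical, not from this commit: the init is a rank-0 tensor of vector,
// so the region must yield the full vector<2x4xf32> value.
func.func @yield_vector(%v: vector<2x4xf32>,
                        %init: tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>> {
  %0 = linalg.generic
      {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
       iterator_types = []}
      ins(%v : vector<2x4xf32>) outs(%init : tensor<vector<2x4xf32>>) {
  ^bb0(%in: vector<2x4xf32>, %out: vector<2x4xf32>):
    linalg.yield %in : vector<2x4xf32>
  } -> tensor<vector<2x4xf32>>
  return %0 : tensor<vector<2x4xf32>>
}

This is, in effect, the generic form the new generalize-named-ops test below checks for.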

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 20 additions & 0 deletions
@@ -587,3 +587,23 @@ func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
 // CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
 // CHECK-NEXT:      %[[max:.+]] = arith.maximumf %[[BBARG0]], %[[BBARG1]] : f32
 // CHECK-NEXT:      linalg.yield %[[max]] : f32
+
+// -----
+
+
+// CHECK-LABEL: func @fill_tensor
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+  %e0 = tensor.empty() : tensor<f32>
+  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+// CHECK: linalg.generic
+// CHECK:   ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:   linalg.yield %[[BBARG0]] : f32
+
+  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+// CHECK: linalg.generic
+// CHECK:   ^{{.+}}(%[[BBARG0:.+]]: vector<2x4xf32>, %[[BBARG1:.+]]: vector<2x4xf32>)
+// CHECK-NEXT:   linalg.yield %[[BBARG0]] : vector<2x4xf32>
+
+  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 11 additions & 0 deletions
@@ -1585,3 +1585,14 @@ func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
   %1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @fill_tensor
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+  %e0 = tensor.empty() : tensor<f32>
+  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}
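The new round-trip test covers only the tensor path. A hedged variant on buffers, assumed to verify after this patch but not part of the commit:

// Hedged variant, not part of the commit: linalg.fill with a vector value on
// a memref-of-vector buffer, relying on the same elemental treatment.
func.func @fill_buffer(%v: vector<2x4xf32>, %buf: memref<vector<2x4xf32>>) {
  linalg.fill ins(%v : vector<2x4xf32>) outs(%buf : memref<vector<2x4xf32>>)
  return
}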
