[MLIR][Linalg] Add more specialize patterns #91153
@@ -70,6 +70,99 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
  // Structural.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
    return std::nullopt;

  // Input should be referenced and init should not.
  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
    return std::nullopt;

  OpOperand *value = genericOp.getDpsInputOperand(0);
  if (!genericOp.isScalar(value))
    return std::nullopt;

  Block *body = genericOp.getBody();
  if (body->getOperations().size() != 1)
    return std::nullopt;

  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0) != body->getArgument(0))
    return std::nullopt;
  return value->get();
}
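For reference, a minimal sketch (hypothetical IR, not part of the patch) of the kind of linalg.generic this detector should recognize as a fill: the scalar input is mapped through a zero-result indexing map, the init is never read, and the body is a single yield of the scalar block argument.

#map_scalar = affine_map<(d0, d1) -> ()>
#map_id = affine_map<(d0, d1) -> (d0, d1)>
func.func @fill_like(%cst: f32, %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map_scalar, #map_id], iterator_types = ["parallel", "parallel"]}
      ins(%cst : f32) outs(%init : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}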
//===----------------------------------------------------------------------===//
// Elementwise Single Unary/Binary-OpInterface implementation
//===----------------------------------------------------------------------===//
static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                          unsigned arity) {
  // Check all loops are parallel and the op has pure tensor semantics.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
    return false;

  // Check there are `arity` inputs, one output, and all indexing maps are
  // identities.
  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
      !llvm::all_of(genericOp.getIndexingMapsArray(),
                    [](AffineMap map) { return map.isIdentity(); }))
    return false;

  // Init should not be referenced for elementwise operations.
  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
    return false;

  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
  // as resulting from producer-consumer fusion. Here, we restrict to two ops
  // in the body, where the first is the single elementwise op and the second
  // a yield.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() != 2)
    return false;

[Inline review thread on the two-op restriction]
Reviewer: This seems like an unnecessary restriction. You could have an "elementwise operation" that is not a single instruction but a sequence. Shouldn't matter.
Author: You are right, a truly isaElementwiseUnaryOp could be a sequence. Changed the API name to be more specific to the context (isaElemwiseSingleUnaryOrBinaryOpInterface), as the objective here is raising to a single named op, e.g. linalg.add, rather than a series of them. Actually, come to think of it, un-fuse followed by generic->named is probably the way rather than unthreading it all here. Not so much for this diff, but for binary ops the elementwise semantics are more interesting.
Reviewer: ok, thanks!
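To illustrate the fused case discussed above (hypothetical IR, not from the patch): a generic whose payload computes exp(neg(x)) has two payload ops plus the yield, so the size() != 2 check rejects it; per the thread, such cases would be un-fused first rather than specialized directly.

#map1d = affine_map<(d0) -> (d0)>
func.func @fused_exp_neg(%arg0: tensor<?xf32>, %init: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map1d, #map1d], iterator_types = ["parallel"]}
      ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %neg = arith.negf %in : f32
    %exp = math.exp %neg : f32
    linalg.yield %exp : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}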
  Operation *op = &body->front();
  if (op->getNumOperands() != arity || op->getNumResults() != 1)
    return false;

  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0).getDefiningOp() != op)
    return false;
  return true;
}

bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
  // All basic elementwise checks.
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
    return false;

  // Check the input is actually used.
  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
    return false;
  return true;
}

bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
    return false;

  // Check both inputs are used (elementwise).
  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
    return false;
  return true;
}

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
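As a rough usage sketch (not part of this patch), a caller could query the new detectors to decide which named op a generic is a candidate for. The include path and the surrounding driver are assumptions; only the three isa* entry points come from the diff above.

#include "mlir/Dialect/Linalg/IR/Linalg.h" // assumed include path
#include <optional>

using namespace mlir;

// Hypothetical helper: classify a linalg.generic using the detectors above.
static void classifyGeneric(linalg::GenericOp genericOp) {
  // Fill-like: all-parallel loops, scalar input yielded unchanged.
  if (std::optional<Value> fillValue = linalg::isaFillOpInterface(genericOp)) {
    (void)fillValue; // Candidate for linalg.fill with *fillValue as the scalar.
    return;
  }
  // Single unary payload op (e.g. math.exp): candidate for linalg.exp.
  if (linalg::isaElemwiseSingleUnaryOpInterface(genericOp))
    return;
  // Single binary payload op (e.g. arith.addf): candidate for linalg.add etc.
  if (linalg::isaElemwiseSingleBinaryOpInterface(genericOp))
    return;
}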
@@ -0,0 +1,76 @@
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.addf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_add
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.subf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_sub
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.subf %in_0, %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_sub
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.mulf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_mul
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %1 = arith.divf %in, %in_0 : f32
    linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_div
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
@@ -0,0 +1,25 @@
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = linalg.generic
         {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = math.exp %in : f32
    linalg.yield %1 : f32
  } -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: specialize_exp
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
I think this is related to https://discourse.llvm.org/t/notes-from-the-mlir-upstream-round-table-eurollvm-2024/78374/11?u=maheshravishankar . Please correct me if I am wrong, but IMO this is too restrictive. It is perfectly reasonable for binary operations to have some "explicit broadcasting support". Is this already an assumption of these ops, or is this being added here?
@MaheshRavishankar: Good point on broadcast; I hope I got your exact question right.
Implicit broadcast is not supported by the linalg.add implementation, e.g.
= linalg.add ins(%arg0, %arg1 : tensor<10xf32>, tensor<10x100xf32>) outs(%arg2: tensor<10x100xf32>) -> tensor<10x100xf32>
error: 'linalg.add' op expected operand rank (1) to match the result rank of indexing_map #0 (2)
Ok thanks.
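For completeness, a sketch (hypothetical IR, not from the patch) of a binary generic with an explicit broadcast on its first input: its first indexing map is not an identity, so the all-identity-maps check in isaElemwiseSingleUnaryOrBinaryOpInterface does not treat it as a specializable binary elementwise op, consistent with linalg.add not supporting implicit broadcast.

#map_bcast = affine_map<(d0, d1) -> (d0)>
#map_ident = affine_map<(d0, d1) -> (d0, d1)>
func.func @add_with_broadcast(%arg0: tensor<10xf32>, %arg1: tensor<10x100xf32>, %arg2: tensor<10x100xf32>) -> tensor<10x100xf32> {
  %0 = linalg.generic {indexing_maps = [#map_bcast, #map_ident, #map_ident], iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1 : tensor<10xf32>, tensor<10x100xf32>) outs(%arg2 : tensor<10x100xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    %1 = arith.addf %a, %b : f32
    linalg.yield %1 : f32
  } -> tensor<10x100xf32>
  return %0 : tensor<10x100xf32>
}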