[mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. #104783
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository; in which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-linalg

Author: Md Asghar Ahmad Shahid (shahidact)

Changes

The main goal of this patch is to extend the semantics of the 'linalg.matmul' named op to include per-operand transpose semantics, while also laying out a way to move op definitions from OpDSL to tablegen. Hence, it is implemented in tablegen.

The transpose semantics are as follows. By default, 'linalg.matmul' behavior remains as is. Transpose semantics can be applied per input operand by explicitly specifying the optional permutation attributes (namely 'permutationA' for the 1st input and 'permutationB' for the 2nd input) as needed. By default, no transpose is mandated for either input operand.
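Example (from the patch description; `permutationA = [1, 0]` transposes the first input, while `permutationB = [0, 1]` is the identity):
```
%val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                     outs(%arg2: memref<3x7xf32>)
                     permutationA = [1, 0]
                     permutationB = [0, 1]
```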
Patch is 23.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104783.diff

8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 46b3ec0f60ebfa..8e2e827a12cc4e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: matmul
- cpp_class_name: MatmulOp
- doc: |-
- Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- implements:
- - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: A
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- - !LinalgOperandDefConfig
- name: B
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- - !LinalgOperandDefConfig
- name: C
- kind: output_tensor
- type_var: U
- shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
- - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
- iterator_types:
- - parallel
- - parallel
- - reduction
- assignments:
- - !ScalarAssign
- arg: C
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: add
- operands:
- - !ScalarExpression
- scalar_arg: C
- - !ScalarExpression
- scalar_fn:
- kind: binary
- fn_name: mul
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: A
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: B
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: quantized_matmul
cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a9007c8db3078e..4ca7c5f0f1f676 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -271,4 +271,5 @@ def Linalg_WinogradOutputTransformOp :
let hasVerifier = 1;
}
+
#endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..5e6940b42db976 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -534,6 +534,106 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", !listconcat([AttrSizedOperandSegments],
+ /*extraInterfaces=*/[LinalgContractionOpInterface])> {
+
+ let summary = [{Performs a matrix multiplication of two 2D inputs without transpose.}];
+ let description = [{Numeric casting is performed on the operands to the inner multiply,
+ promoting them to the same data type as the accumulator/output.
+
+ Per input operand transpose can be performed by specifying the required permutation
+ attributes (namely 'permutationA' for 1st input and 'permutationB' for 2nd input) for
+ each operand explicitly. By default, no transpose is mandated for each input operand.
+
+ Example:
+ ```
+ %val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
+ outs(%arg2: memref<3x7xf32>)
+ permutationA = [1, 0]
+ permutationB = [0, 1]
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationA,
+ ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{0, 1}">, [DenseArrayCount<2>]>:$permutationB,
+ DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, resultTensorTypes,
+ inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addOperands(operands);
+ $_state.addAttributes(attributes);
+ $_state.addTypes(resultTensorTypes);
+ (void)$_state.addRegion();
+ }]>,
+ OpBuilder<
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+ "ValueRange":$outputs, "DenseI64ArrayAttr":$permutationA, "DenseI64ArrayAttr":$permutationB, "Attribute":$cast,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ $_state.addAttribute("permutationA", permutationA);
+ $_state.addAttribute("permutationB", permutationB);
+ $_state.addAttribute("cast", cast);
+ buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+ attributes, MatmulOp::getRegionBuilder());
+ }]>
+
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ // Auto-generated.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ ArrayAttr getIndexingMaps();
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ static unsigned getNumRegionArgs();
+ std::string getLibraryCallName();
+ bool hasDynamicIndexingMaps();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 99b625d99fec2e..bef19e737ca6c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -303,6 +303,26 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
+ if (parser.parseOptionalKeyword("permutationA").succeeded()) {
+ if (parser.parseEqual())
+ return failure();
+
+ result.attributes.set("permutationA",
+ DenseI64ArrayAttr::parse(parser, Type{}));
+ }
+
+ if (parser.parseOptionalKeyword("permutationB").succeeded()) {
+ if (parser.parseEqual())
+ return failure();
+
+ result.attributes.set("permutationB",
+ DenseI64ArrayAttr::parse(parser, Type{}));
+ }
+
+ // Parse optional attributes.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
// TODO: consider merging results parsing into region parsing.
// Need to wait for declarative assembly resolution to decide.
SmallVector<Type, 1> outputTensorsTypes;
@@ -334,7 +354,8 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
/*elidedAttrs=*/{"operandSegmentSizes",
// See generated code in
// LinalgNamedStructuredOps.yamlgen.cpp.inc
- "linalg.memoized_indexing_maps"});
+ "linalg.memoized_indexing_maps", "permutationA",
+ "permutationB"});
// Printing is shared with generic ops, except for the region and
// attributes.
@@ -2980,3 +3001,132 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+
+namespace mlir {
+namespace linalg {
+//===----------------------------------------------------------------------===//
+// MatMulOp
+//===----------------------------------------------------------------------===//
+SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
+ return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
+ utils::IteratorType::parallel,
+ utils::IteratorType::reduction};
+}
+
+ArrayAttr MatmulOp::getIndexingMaps() {
+ static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+ ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+ if (cached)
+ return cached;
+
+ MLIRContext *context = getContext();
+ SmallVector<AffineMap> maps;
+
+ unsigned numResults;
+ SmallVector<AffineExpr, 3> dimReplacements;
+ AffineMap originalMap =
+ llvm::cast<AffineMapAttr>(
+ mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d2)>", context))
+ .getValue();
+ numResults = originalMap.getNumResults();
+ for (unsigned i = 0; i < numResults; i++) {
+ AffineExpr expr = originalMap.getResult(getPermutationA()[i]);
+ dimReplacements.push_back(expr);
+ }
+
+ AffineMap newMap =
+ AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+ dimReplacements, context);
+ maps.push_back(newMap);
+ maps.back() =
+ simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+ originalMap =
+ llvm::cast<AffineMapAttr>(
+ mlir::parseAttribute("affine_map<(d0, d1, d2)->(d2, d1)>", context))
+ .getValue();
+ numResults = originalMap.getNumResults();
+ dimReplacements.clear();
+ for (unsigned i = 0; i < numResults; i++) {
+ AffineExpr expr = originalMap.getResult(getPermutationB()[i]);
+ dimReplacements.push_back(expr);
+ }
+
+ newMap = AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
+ dimReplacements, context);
+ maps.push_back(newMap);
+ maps.back() =
+ simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+
+ maps.push_back(
+ llvm::cast<AffineMapAttr>(
+ mlir::parseAttribute("affine_map<(d0, d1, d2)->(d0, d1)>", context))
+ .getValue());
+ maps.back() =
+ simplifyAffineMap(maps.back().replaceDimsAndSymbols({}, {}, 3, 0));
+ cached = Builder(context).getAffineMapArrayAttr(maps);
+ getOperation()->setAttr(memoizeAttr, cached);
+ return cached;
+}
+
+unsigned MatmulOp::getNumRegionArgs() { return 3; }
+
+std::string MatmulOp::getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+}
+
+bool MatmulOp::hasDynamicIndexingMaps() { return true; }
+
+void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ assert(3 > 0 && block.getNumArguments() == 3 &&
+ "MatmulOp regionBuilder expects 3 (>=0) args");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+
+ TypeFn castVal = TypeFn::cast_signed;
+ auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+ return attr.getName() == "cast";
+ });
+ if (castIter != attrs.end()) {
+ if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+ castVal = attr.getValue();
+ }
+
+ Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(0));
+ Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+ block.getArgument(1));
+ Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
+ Value value4 =
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
+ yields.push_back(value4);
+ helper.yieldOutputs(yields);
+}
+
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+ return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
+ MatmulOp::getRegionBuilder());
+}
+void MatmulOp::print(OpAsmPrinter &p) {
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
+ if (!getPermutationA().empty())
+ printDenseI64ArrayAttr(p, getPermutationAAttrName(), getPermutationA());
+
+ if (!getPermutationB().empty())
+ printDenseI64ArrayAttr(p, getPermutationBAttrName(), getPermutationB());
+}
+
+LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+void MatmulOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+} // namespace linalg
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 67bde8f736ef46..7ef5de12de5ad3 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -383,23 +383,6 @@ def select(
O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])
-@linalg_structured_op
-def matmul(
- A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Performs a matrix multiplication of two 2D inputs.
-
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
-
-
@linalg_structured_op
def quantized_matmul(
A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 31fac9b4b41659..7c95d9592481e6 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -864,3 +864,65 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationB = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK: linalg.generic
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0] permutationB = [1, 0]
+ return
+}
+
+// -----
+
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 02ecbed232c8b5..e702125667acc7 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1201,6 +1201,39 @@ func.func @matmul_transpose_a(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %a
// -----
+// CHECK-LABEL: func @matmul_transpose_a_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_b_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>) permutationB = [1, 0]
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_a_b_explicit
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+ linalg.matmul ins...
[truncated]
Please rebase the branch on top of trunk.
I think we have discussed this before and I don't see a point in adding a matmul operation that is just folding the transpositions. Also, this is just for 2D matmuls. What about batch matmuls? I have said previously that we should just have a single contraction op that has indexing maps to represent transpositions and broadcasts and batching. It should be fairly simple to have utility methods that give you
- batching dimension(s)
- M dimension(s)
- N dimension(s)
- K dimension(s)

from the indexing maps, and you could easily constrain things for the 3D-specific cases. I don't really see the point of adding yet another matmul operation that is only generalized along a specific dimension.
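For illustration, here is a minimal sketch (not from this PR; shapes, value names, and map choices are assumptions) of how a single generic contraction with explicit indexing maps can encode batching, a transposed LHS, and an RHS broadcast along the batch dimension, with the dimension roles recoverable from the maps alone:
```
// d0 = batch, d1 = M, d2 = N, d3 = K (reduction). Utility methods could
// classify these roles purely from the indexing maps.
#lhs = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>  // transposed LHS layout
#rhs = affine_map<(d0, d1, d2, d3) -> (d3, d2)>      // broadcast over batch
#out = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @contraction_sketch(%A: tensor<8x4x16xf32>, %B: tensor<8x32xf32>,
                              %C: tensor<4x16x32xf32>) -> tensor<4x16x32xf32> {
  %0 = linalg.generic
         {indexing_maps = [#lhs, #rhs, #out],
          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<8x4x16xf32>, tensor<8x32xf32>)
         outs(%C : tensor<4x16x32xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %m = arith.mulf %a, %b : f32
    %r = arith.addf %c, %m : f32
    linalg.yield %r : f32
  } -> tensor<4x16x32xf32>
  return %0 : tensor<4x16x32xf32>
}
```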
Second concern: I think we should try to limit changes to opdsl. Leave the opdsl ops as is, and have a set of named ops that don't conflict but build up towards V2 of named ops, to minimize downstream breakages. At some point in the future (if ever), when all uses have reached a consensus on what to use, we can deprecate the opdsl path.
Can you also add some negative tests? Showing wrong attributes, wrong values in attributes, etc?
Once we agree on the design, we'll do the remaining ones. No point in changing too much at the same time. The whole point of this is to have batch/reduce matmuls with transposes without an explosion of ops. We will get there, just need one step at a time.
That's a different problem. We need to simplify matmuls first, then work on contract. Holding this simple refactoring hostage to a much larger change makes no sense. Right now, let's not have that discussion again, please.
This does nothing to
Once we're all using the new syntax, we can remove the rest from opDSL (if at all). Right now, we don't. But we had to remove
I don't care about opDSL. Those variations could be in there forever. I want to build a simple set of named ops in Linalg that isn't restricted by opDSL. This is what you, me, and @stellaraccident agreed on last time we talked. I'm not sure what changed since.
For the record, this is what we have been proposing for almost 2 years, so we're very much aligned that we need this. Just nothing to do with this PR.
Yeah, but this seems too transient in terms of dialect/op design. We have a permutation map that is not an indexing map. It changes all the assembly formatting, and then we have to change that again.
I understand the frustration, but I really don't see the point of adding this new operation. It's just combining transposes, but then what about fusing with broadcasts? We have discussed this previously: pretty much any forward-looking change has to account for broadcasting behavior of operands to matmuls. Without that, I think this is serving a very narrow use case. Please let me know if I have ever indicated anything to the contrary. Every time this has come up, I have mentioned that we might as well just add indexing maps to contraction operations and call it a day. All I am suggesting is that instead of adding a
We might have a different understanding of what was agreed upon. I have always pushed back against permutation vectors or broadcast dimension lists. Those all seem very fragile and not future-proof. I am suggesting just simple modifications to this PR
I agree the permutation map isn't the best way to represent this. But the indexing map is also a "change in syntax". So, perhaps we should just pause and talk about the final
This is not a new operation. It's just

We're not combining transposes, we're starting the discussion on how to represent both transpose and broadcast. Perhaps the attribute has thrown it off course, so @shahidact I recommend we stick to affine maps from now on and get both transpose and broadcast in one PR. However, I'd prefer to have this discussion over
We're not adding a new op, there is no name collision. All users of
That's a non-starter. Contraction is way more complicated than matmul and not viable in the short term. We'll be stuck in bike-shedding for months before any of this becomes useful and we'll end up with multiple variations of syntax. We will do that too very soon! But these must be orthogonal changes for us to make forward progress on the existing models.
We'll update this PR with affine maps for both transpose and broadcast (that was work in progress), but on the existing |
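For reference, the `indexing_maps` form this PR later converged on (shown in the squashed commit message further below) expresses a transposed first operand like this:
```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2 : memref<3x7xf32>)
```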
This PR is closed, but just rounding out the discussion.
Ok, my only concern here is that current op definition for
We should discuss that more. I don't think I understand all the issues you anticipate, but that's for later.
I'm having trouble keeping pace with which PR we're going forward with, but I think I wouldn't insist on such a strict approach. If we agree on the op definition, including the ASM and C++ API, then whether it is implemented via opdsl or directly in ODS/C++ is immaterial, given prior discussions (where we said we don't care about some of the theoretical things you give up without OpDSL). And doing it purely in ODS/C++ is a code cleanup for the project. I expect the first one of these we do will likely have some bumps -- it might be better to try that with one op and work the kinks out. I think the hard part is a) agreeing on the op definition, and b) deciding that yes, the pure C++ version is a good drop-in replacement. (a) seems to have plenty of traction, and (b) just requires careful review and some collaboration on testing. Let me know when this PR is respun, and I'm happy to help with (b).
+1 on this; let's be sure we provide a gentle path to deprecating opdsl. It feels premature to just replace the existing right now.
Thanks! First round of comments.
Force-pushed 2ba79c9 to 121cc0b
I'll take a look tomorrow. Thanks!
Thanks! This is looking much cleaner. I think I have just one comment remaining. I can do a quick turnaround of review for just that comment. Everything else looks good.
Force-pushed ab94960 to c1caf74
…nd transpose semantic.

Goals:
1. To add syntax to matmul without changing any of the existing syntax expectations for current usage. matmul is still just matmul.
2. To expose broadcast and transpose semantics on the three matmul variations: matmul, batch_matmul and batch_reduce_matmul.

Scope of this patch: to expose broadcast and transpose semantics on 'matmul'. The broadcast and transpose semantics are as follows: by default, 'linalg.matmul' behavior remains as is. Broadcast and transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified.

Example Transpose:
```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2 : memref<3x7xf32>)
```

Example Broadcast:
```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2)>, // broadcast
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
              outs(%arg2 : memref<3x7xf32>)
```
Force-pushed c1caf74 to f7d61d4
@MaheshRavishankar Thanks for the review. I have addressed all your points.
Great! Thanks a lot!
@shahidact Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
GCC issues are being looked at in #111869
Fixed in 99c8557
This seems to have broken two tests. See https://lab.llvm.org/buildbot/#/builders/138/builds/4872
I vote for deleting those tests, as the incompatibility with opdsl was a key point of the consensus and these tests are (hopelessly) tied to the linalg.matmul op being built with that mechanism. (Whether we revert then delete or delete-forward, I leave to the discretion of those impacted.)
If this is now the intended behavior and the path forward is to delete the test, I'm fine with a delete-forward. I know nothing of this code, but it appears to me that the test in transform.py is trying to test
…' ops. (#104783)"

This reverts commit 0348373 and 99c8557, which is a fix-up on top of the former.

I'm reverting because this commit broke two tests:
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/test/python/integration/dialects/transform.py

See https://lab.llvm.org/buildbot/#/builders/138/builds/4872

I'm not familiar with the tests, so I'm leaving it to the original author to either remove or adapt the broken tests, as discussed here: #104783 (comment)
…lvm#104783)

The main goal of this patch is to extend the semantics of the 'linalg.matmul' named op to include per-operand transpose semantics while also laying out a way to move op definitions from OpDSL to tablegen. Hence, it is implemented in tablegen.

The transpose semantics are as follows. By default, 'linalg.matmul' behavior remains as is. Transpose semantics can be applied per input operand by specifying the optional permutation attributes (namely 'permutationA' for the 1st input and 'permutationB' for the 2nd input) for each operand explicitly as needed. By default, no transpose is mandated for any of the input operands.

Example:
```
%val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                     outs(%arg2: memref<3x7xf32>)
                     permutationA = [1, 0]
                     permutationB = [0, 1]
```
…' ops. (llvm#104783)"

This reverts commit 0348373 and 99c8557, which is a fix-up on top of the former.

I'm reverting because this commit broke two tests:
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/test/python/integration/dialects/transform.py

See https://lab.llvm.org/buildbot/#/builders/138/builds/4872

I'm not familiar with the tests, so I'm leaving it to the original author to either remove or adapt the broken tests, as discussed here: llvm#104783 (comment)
…ul' ops. (llvm#104783)

The main goal of this patch is to extend the semantics of the 'linalg.matmul' named op to include per-operand transpose semantics while also laying out a way to move op definitions from OpDSL to tablegen. Hence, it is implemented in tablegen.

The transpose semantics are as follows. By default, 'linalg.matmul' behavior remains as is. Transpose/broadcast semantics can be applied explicitly by specifying the optional indexing_map attribute. By default, no transpose/broadcast is mandated.

Example:
```
linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
              outs(%arg2 : memref<3x7xf32>)

linalg.matmul indexing_maps = [
                affine_map<(d0, d1, d2) -> (d2)>, // broadcast
                affine_map<(d0, d1, d2) -> (d2, d1)>,
                affine_map<(d0, d1, d2) -> (d0, d1)>
              ]
              ins(%arg0, %arg1 : memref<3xf32>, memref<5x7xf32>)
              outs(%arg2 : memref<3x7xf32>)
```
…atmul.

The earlier PR (llvm#104783) was reverted due to two failing OpDSL tests for linalg.matmul. Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL, these tests started failing and need to be removed/updated.

This commit removes/updates the failing obsolete tests in the files below.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"
…ling obsolete OpDSL tests. (#115319)

The earlier PR (#104783), which introduces transpose and broadcast semantics to linalg.matmul, was reverted due to two failing OpDSL tests for linalg.matmul. Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL, these tests started failing and need to be removed/updated.

This commit removes/updates the failing obsolete tests in the files below. All other files were part of the earlier PR and are just cherry-picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"

---------

Co-authored-by: Renato Golin <[email protected]>
…ling obsolete OpDSL tests. (llvm#115319)

The earlier PR (llvm#104783), which introduces transpose and broadcast semantics to linalg.matmul, was reverted due to two failing OpDSL tests for linalg.matmul. Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL, these tests started failing and need to be removed/updated.

This commit removes/updates the failing obsolete tests in the files below. All other files were part of the earlier PR and are just cherry-picked.
"mlir/test/python/integration/dialects/linalg/opsrun.py"
"mlir/test/python/integration/dialects/transform.py"

---------

Co-authored-by: Renato Golin <[email protected]>
The main goal of this patch is to extend the semantics of the 'linalg.matmul' named op to include per-operand transpose semantics, while also laying out a way to move op definitions from OpDSL to tablegen. Hence, it is implemented in tablegen. The transpose semantics are as follows.

By default, 'linalg.matmul' behavior remains as is. Transpose semantics can be applied per input operand by specifying the optional permutation attributes (namely 'permutationA' for the 1st input and 'permutationB' for the 2nd input) for each operand explicitly as needed. By default, no transpose is mandated for any of the input operands.
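Example:
```
%val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                     outs(%arg2: memref<3x7xf32>)
                     permutationA = [1, 0]
                     permutationB = [0, 1]
```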