Skip to content

[MLIR][Linalg] Introduce SpecializeOp #70326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
bool isaConvolutionOpInterface(LinalgOp linalgOp);

/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copy` op.
bool isaCopyOpInterface(LinalgOp linalgOp);

namespace detail {

/// Returns true if the block contains a contraction of the following form:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,43 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
}];
}

//===----------------------------------------------------------------------===//
// SpecializeOp
//===----------------------------------------------------------------------===//

def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     TransformOpInterface, TransformEachOpTrait,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Transforms a generic operation into the equivalent named form.

    #### Return modes

    This operation ignores non-Linalg ops and drops them in the return. If all
    the operations referred to by the `target` handle specialize, the transform
    succeeds; otherwise, the operation produces a silenceable failure. The return
    handle points to only the subset of successfully produced equivalent named
    operations, which can be empty or contain the original ops if they were already
    in named form. The supported specializations to named Linalg operations are:
    - linalg.copy of any rank.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target);
  let results = (outs TransformHandleTypeInterface:$transformed);
  let assemblyFormat =
      "$target attr-dict `:` "
      "custom<SemiFunctionType>(type($target), type($transformed))";

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure applyToOne(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::linalg::LinalgOp target,
        ::mlir::transform::ApplyToEachResultList &results,
        ::mlir::transform::TransformState &state);
  }];
}

//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,11 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
LinalgOp namedOp);

/// Create a namedOp from the given GenericOp and replace the GenericOp.
/// Currently we can specialize only trivial linalg copy operations.
FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp);

/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that
/// can be computed for the size of the result of `subView`. Returns the
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using namespace mlir::linalg;
//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//

bool linalg::detail::canOpOperandsBeDroppedImpl(
linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
SmallVector<AffineMap> indexingMaps;
Expand All @@ -48,6 +49,27 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}

//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

/// Returns true if `linalgOp` is semantically equivalent to a `linalg.copy`:
/// an all-parallel op with one input and one init, both accessed through
/// identity maps, whose body forwards the input element unchanged.
bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
  // Structural: a copy iterates over every loop dimension in parallel.
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return false;

  // Operands and maps: exactly one input and one init, both addressed with
  // identity maps (no broadcast, transpose, or reduction).
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return false;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
    return false;
  }
  // Region: the body must be a single `linalg.yield` that yields the input
  // block argument unchanged. Checking the yielded value (not just the op
  // count) rejects bodies such as `linalg.yield %out`, which satisfy the
  // structural checks above but do not copy the input.
  Block *body = linalgOp.getBlock();
  if (!llvm::hasSingleElement(body->getOperations()))
    return false;
  auto yieldOp = dyn_cast<linalg::YieldOp>(body->getTerminator());
  return yieldOp && yieldOp.getNumOperands() == 1 &&
         yieldOp.getOperand(0) == body->getArgument(0);
}

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,30 @@ transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
return emitDefaultSilenceableFailure(target);
}

//===----------------------------------------------------------------------===//
// SpecializeOp
//===----------------------------------------------------------------------===//

/// Rewrites a `linalg.generic` target into its equivalent named form.
/// Non-generic Linalg ops pass through untouched; a generic that cannot be
/// specialized produces a silenceable failure.
DiagnosedSilenceableFailure
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  auto genericOp = dyn_cast<GenericOp>(target.getOperation());
  // Ops already in named form are forwarded unchanged.
  if (!genericOp) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(genericOp);
  FailureOr<LinalgOp> namedOp = specializeGenericOp(rewriter, genericOp);
  if (failed(namedOp))
    return emitDefaultSilenceableFailure(target);
  results.push_back(namedOp->getOperation());
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
Specialize.cpp
Split.cpp
SplitReduction.cpp
SubsetHoisting.cpp
Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a method to specialize generic operations to named
// operations. Conceptually it is the opposite of generalize.cpp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-specialization"

using namespace mlir;
using namespace mlir::linalg;

/// Converts `genericOp` into an equivalent named Linalg op, replacing it in
/// the IR. Currently only trivial element-wise copies are recognized; all
/// other generics yield failure.
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  if (!isaCopyOpInterface(genericOp))
    return failure();
  LinalgOp copyOp = rewriter.replaceOpWithNewOp<CopyOp>(
      genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
  return copyOp;
}
143 changes: 143 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-specialize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

// Negative tests: none of the generics below is a trivial copy, so the
// transform.structured.specialize sequence at the end of this split must
// produce a silenceable failure (expected-error) and attach a note to each
// op it could not specialize (expected-note).

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1, d0)>

// Input map is not an identity (rank-1 input broadcast into rank-2 output).
func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
indexing_maps = [#map1, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
}
return
}

// Body computes a sum, not a plain forward of the input element.
func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%0 = arith.addf %in, %out : f32
linalg.yield %0 : f32
}
return
}

// Output map permutes the dimensions, so this is a transpose, not a copy.
func.func @transpose_op_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
indexing_maps = [#map, #map2],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
}
return
}

// Body widens f16 -> f32; the extra arith.extf disqualifies it as a copy.
func.func @copy_with_up_cast(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?x?xf16>) outs(%arg1 : memref<?x?xf32>) {
^bb0(%in: f16, %out: f32):
%0 = arith.extf %in : f16 to f32
linalg.yield %0 : f32
}
return
}

// Body narrows f32 -> f16; the extra arith.truncf disqualifies it as a copy.
func.func @copy_with_down_cast(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf16>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf16>) {
^bb0(%in: f32, %out: f16):
%0 = arith.truncf %in : f32 to f16
linalg.yield %0 : f16
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{failed to apply}}
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// Positive tests: trivial copies (identity maps, all-parallel, pass-through
// body) must be rewritten into linalg.copy, and ops already in named form
// must pass through the transform unchanged.

#map = affine_map<(d0, d1) -> (d0, d1)>

// Rank-2 memref generic that simply forwards its input -> linalg.copy.
func.func @specialize_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
}
return
}

// CHECK-LABEL: specialize_trivial_copy_memref
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)

#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// Rank-3 tensor generic that simply forwards its input -> linalg.copy.
func.func @specialize_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>,
%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = linalg.generic {
indexing_maps = [#map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: specialize_trivial_copy_tensor
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)

// Already-named memref copy: transform succeeds and leaves the op as-is.
func.func @already_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
linalg.copy ins(%arg0: memref<?x?xf32>) outs(%arg1: memref<?x?xf32>)
return
}

// CHECK-LABEL: already_trivial_copy_memref
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)

// Already-named tensor copy: transform succeeds and leaves the op as-is.
func.func @already_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>,
%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = linalg.copy ins(%arg0: tensor<?x?x?xf32>) outs(%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: already_trivial_copy_tensor
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}