Misc fixes #35

Closed
wants to merge 20 commits into from
20 commits
d319b8c
[mlir][tosa] Fix constant folding of tosa.mul
mgehre-amd May 16, 2023
07d8cd0
Support lowering tosa.custom_op to another dialect operation.
ttjost Jun 1, 2023
e42a0b8
Merge pull request #37 from Xilinx/tiagot.tosa_custom_op_support
mgehre-amd Jun 1, 2023
88b3950
Adds lit_tests for tosa.custom_op lowering to LinAlg.
ttjost Jun 1, 2023
813e43e
Merge pull request #38 from Xilinx/tiagot.lit_test_tosa_custom_op
ljfitz Jun 1, 2023
4880bfc
Lowering for 'tosa.scatter'
rafaelubalmw May 30, 2023
5de799c
Generic support for legalizing tosa.custom_op into another dialect
ttjost Jun 1, 2023
115147d
Merge pull request #39 from Xilinx/tiagot.generic_support_tosa_custom_op
mgehre-amd Jun 2, 2023
9bccb5b
Merge pull request #40 from Xilinx/matthias.pick_upstream_tosa_scatter
mgehre-amd Jun 2, 2023
9b67e54
TOSA: Fold concat where one argument has zero elements (#41)
mgehre-amd Jun 12, 2023
0749db1
Some tosa verifiers (#42)
mgehre-amd Jun 12, 2023
20fa0e8
TorchToLinAlg: fix tosa.clamp legalization for integer types. (#43)
ttjost Jun 12, 2023
c2367d6
TosaToLinAlg: fix tosa.cast legalization of FP->Int for non FP32 type…
ttjost Jun 13, 2023
3663896
Merge commit '6cf7fe4a9a715bcdf3f4913753109e22dfc9940b' into HEAD
mgehre-amd Jun 13, 2023
bc0e73a
TOSA: Allow to transpose 7D tensors and higher
mgehre-amd Jun 14, 2023
9f46ca4
Merge pull request #49 from Xilinx/matthias.update_llvm_to_green-6cf7…
mgehre-amd Jun 14, 2023
053adf3
Merge pull request #50 from Xilinx/matthias.allow_all_tensor_transpose
mgehre-amd Jun 15, 2023
62ad0bd
Merge commit '2b4807ba' into matthias.bump_llvm_green-2b4807
mgehre-amd Jun 15, 2023
ae98bc3
Merge remote-tracking branch 'xlnx/misc_fixes' into matthias.bump_llv…
mgehre-amd Jun 16, 2023
791249e
Merge pull request #54 from Xilinx/matthias.bump_llvm_green-2b4807
mgehre-amd Jun 19, 2023
7 changes: 5 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1410,6 +1410,8 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
Tosa_Tensor:$output
);

let hasFolder = 1;

let hasCanonicalizer = 1;
let hasFolder = 1;

@@ -1552,6 +1554,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
@@ -1592,12 +1595,12 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
}];

let arguments = (ins
Tosa_Tensor1Dto6D:$input1,
Tosa_Tensor:$input1,
Tosa_Int32Or64Tensor:$perms
);

let results = (
outs Tosa_Tensor1Dto6D:$output
outs Tosa_Tensor:$output
);

let extraClassDeclaration = [{
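The tosa.transpose change above replaces the Tosa_Tensor1Dto6D constraint with plain Tosa_Tensor on both operand and result, lifting the rank-6 limit. As a minimal sketch (the rank-7 shapes below are made up for illustration and assume a perms value of [6, 5, 4, 3, 2, 1, 0]; they are not taken from this PR's tests), IR like this is rejected before the change and verifies after it:

func.func @transpose_rank7(%arg0: tensor<1x2x3x4x5x6x7xf32>, %perms: tensor<7xi32>) -> tensor<7x6x5x4x3x2x1xf32> {
  %0 = "tosa.transpose"(%arg0, %perms) : (tensor<1x2x3x4x5x6x7xf32>, tensor<7xi32>) -> tensor<7x6x5x4x3x2x1xf32>
  return %0 : tensor<7x6x5x4x3x2x1xf32>
}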
58 changes: 44 additions & 14 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -388,23 +388,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,

if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
auto intTy = cast<IntegerType>(elementTy);
int32_t min = static_cast<int32_t>(
cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
int32_t max = static_cast<int32_t>(
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
int64_t min =
cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
int64_t max =
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();

if (intTy.isUnsignedInteger()) {
min = std::max<int32_t>(min, 0);
max = std::min<int32_t>(
min = std::max(min, (int64_t)0);
max = std::min(
max,
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
} else {
min = std::max<int32_t>(
min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
max = std::min<int32_t>(
max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
min =
std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
max =
std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
}

auto minVal = rewriter.create<arith::ConstantIntOp>(
@@ -478,16 +478,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}

if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
auto intMin = rewriter.create<arith::ConstantOp>(
Value intMin = rewriter.create<arith::ConstantOp>(
loc, rewriter.getF32FloatAttr(
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));

auto intMax = rewriter.create<arith::ConstantOp>(
Value intMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getF32FloatAttr(
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));

// Since F32 constants are created, we may still need to convert them to
// the correct type.
auto convertType = [&](Type ty, Value arg) {
auto argTy = arg.getType();
bool bitExtend =
argTy.getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth();
if (ty != argTy) {
if (!bitExtend)
arg = rewriter.create<arith::TruncFOp>(loc, ty, arg);
else
arg = rewriter.create<arith::ExtFOp>(loc, ty, arg);
}
return arg;
};
intMin = convertType(srcTy, intMin);
intMax = convertType(srcTy, intMax);

auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
@@ -513,6 +530,18 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
}

// tosa::CustomOp
if (auto customOp = dyn_cast<tosa::CustomOp>(op)) {
// Only legalize tosa.custom_op's that are marked as implementable with
// 'linalg.generic' by looking at the 'implementation_attrs' attribute
auto implementationAttr = customOp.getImplementationAttrs();
if (implementationAttr == "linalg.generic") {
OperationState state(loc, customOp.getIdentifierAttr(), args,
resultTypes);
return rewriter.create(state)->getResult(0);
}
}

(void)rewriter.notifyMatchFailure(
op, "unhandled op for linalg body calculation for elementwise op");
return nullptr;
@@ -2231,6 +2260,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
PointwiseConverter<tosa::FloorOp>,
PointwiseConverter<tosa::ClampOp>,
PointwiseConverter<tosa::SigmoidOp>,
PointwiseConverter<tosa::CustomOp>,
IdentityNConverter<tosa::IdentityOp>,
ReduceConverter<tosa::ReduceAllOp>,
ReduceConverter<tosa::ReduceAnyOp>,
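The new tosa::CustomOp case only fires when the op's implementation_attrs attribute equals "linalg.generic"; the operation named by the identifier attribute is then created directly on the scalar arguments of the linalg.generic body built by the pointwise converter. A rough sketch of the intended lowering (shapes, value names, and the exact indexing maps are illustrative, not copied from this PR): a custom op such as

%0 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.sin", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>) -> tensor<1xf32>

is expected to become roughly

%init = tensor.empty() : tensor<1xf32>
%0 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
                     iterator_types = ["parallel"]}
     ins(%arg0 : tensor<1xf32>) outs(%init : tensor<1xf32>) {
^bb0(%in: f32, %out: f32):
  %1 = math.sin %in : f32
  linalg.yield %1 : f32
} -> tensor<1xf32>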
21 changes: 21 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1038,7 +1038,28 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
return {};
}

static bool hasZeroSize(Type ty) {
auto ranked = dyn_cast<RankedTensorType>(ty);
if (!ranked)
return false;
return any_of(ranked.getShape(), [](auto d) { return d == 0; });
}

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
/// Remove operands that have zero elements.
bool changed = false;
for (size_t i = 0; i < getInput1().size(); ) {
auto input = getInput1()[i];
if (hasZeroSize(input.getType())) {
getInput1Mutable().erase(i);
changed = true;
} else {
++i;
}
}
if (changed)
return getResult();

// Fold consecutive concats on the same axis into a single op.
// Keep track of the operands so we are able to construct a new concat
// later. Conservatively assume that we double the number of operands when
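The new ConcatOp::fold first erases any operand whose ranked type has a static dimension of size zero and, if it removed something, returns the op's own result to signal an in-place update. A small before/after sketch (shapes made up for illustration; the @concat_fold_zero test added below exercises the dynamic-shape variant):

// before folding
%0 = "tosa.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<0x4xf32>, tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
// after folding, the zero-sized operand is gone
%0 = "tosa.concat"(%arg1, %arg2) {axis = 0 : i64} : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>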
70 changes: 70 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -765,6 +765,76 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
return emitOpError() << "Cannot reshape " << inputElementsNum
<< " elements into " << outputElementsNum;
}

if ((int64_t)getNewShape().size() != outputType.getRank()) {
return emitOpError() << "rank of newShape (" << getNewShape().size()
<< ") and output ("
<< outputType.getRank()
<< ") must match";
}

for (int64_t dim=0; dim < outputType.getRank(); ++dim) {
if (getNewShape()[dim] != -1 && getNewShape()[dim] != outputType.getShape()[dim]) {
return emitOpError() << "newShape attribute (" << getNewShape()[dim]
<< ") does not match output type ("
<< outputType.getShape()[dim]
<< ") in dimension " << dim;
}
}
}
return mlir::success();
}

mlir::LogicalResult tosa::SliceOp::verify() {
// TODO: Complete verification
ShapedType inputType = getInput().getType().cast<ShapedType>();
ShapedType outputType = getType().cast<ShapedType>();

if (inputType.getRank() != outputType.getRank()) {
return emitOpError() << "rank of input (" << inputType.getRank()
<< ") and output ("
<< outputType.getRank()
<< ") must match";
}

if ((int64_t)getSize().size() != outputType.getRank()) {
return emitOpError() << "rank of size (" << getSize().size()
<< ") and output ("
<< outputType.getRank()
<< ") must match";
}
for (int64_t dim=0; dim < outputType.getRank(); ++dim) {
if (getSize()[dim] != -1 && !outputType.isDynamicDim(dim) &&
getSize()[dim] != outputType.getShape()[dim]) {
return emitOpError() << "size attribute (" << getSize()[dim]
<< ") does not match output type ("
<< outputType.getShape()[dim] << ") in dimension "
<< dim;
}
}

if ((int64_t)getStart().size() != inputType.getRank()) {
return emitOpError() << "rank of start (" << getStart().size()
<< ") and input ("
<< inputType.getRank()
<< ") must match";
}
if ((int64_t)getSize().size() != inputType.getRank()) {
return emitOpError() << "rank of size (" << getSize().size()
<< ") and input ("
<< inputType.getRank()
<< ") must match";
}

for (int i = 0; i < outputType.getRank(); ++i) {
auto dimSize = inputType.getShape()[i];
if (getSize()[i] != -1 && dimSize != ShapedType::kDynamic &&
getStart()[i] + getSize()[i] > inputType.getShape()[i]) {
return emitOpError() << "start (" << getStart()[i] << ") plus size ("
<< getSize()[i]
<< ") goes out of bounds of input size ("
<< inputType.getShape()[i] << ") in dimension " << i;
}
}
return mlir::success();
}
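Since no invalid-IR tests accompany the new verifiers, here is a hypothetical pair of cases (shapes made up for illustration) showing what tosa.slice now accepts and rejects: the start and size arrays must have the input's rank, size must agree with static output dimensions, and start plus size must stay within each static input dimension.

// accepted: 4 elements starting at offset 2 of a 10-element tensor
%ok = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 4>} : (tensor<10xf32>) -> tensor<4xf32>
// rejected: start (8) plus size (4) goes out of bounds of input size (10) in dimension 0
%bad = "tosa.slice"(%arg0) {start = array<i64: 8>, size = array<i64: 4>} : (tensor<10xf32>) -> tensor<4xf32>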
39 changes: 39 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -274,6 +274,17 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: arith.extf
%0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: %[[C_LOWEST:.+]] = arith.constant -2.14748365E+9
// CHECK: %[[C_MAX:.+]] = arith.constant 2.14748365E+9
// CHECK: arith.truncf %[[C_LOWEST]] : f32 to f16
// CHECK: arith.truncf %[[C_MAX]] : f32 to f16
// CHECK: math.roundeven
// CHECK: arith.minf
// CHECK: arith.maxf
// CHECK: arith.fptosi
%1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32>

return
}

@@ -1414,6 +1425,34 @@ func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %

// -----

// CHECK-LABEL: @test_custom_ops
func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: math.sin
// CHECK: linalg.generic
// CHECK: math.atan2
%2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.sin", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>) -> tensor<1xf32>
%3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

return
}


// -----

// CHECK-LABEL: @test_custom_ops_dyn
func.func @test_custom_ops_dyn(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> () {
// CHECK: linalg.generic
// CHECK: math.cos
// CHECK: linalg.generic
// CHECK: math.atan2
%2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.cos", implementation_attrs = "linalg.generic"}> : (tensor<?xf32>) -> tensor<?xf32>
%3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

return
}
// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

31 changes: 20 additions & 11 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -86,6 +86,13 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
return %1 : tensor<4xi8>
}

// CHECK-LABEL: @concat_fold_zero
func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
// CHECK: "tosa.concat"(%arg1, %arg2) <{axis = 1 : i64}>
%0 = "tosa.concat"(%arg0, %arg1, %arg2) {axis = 1 : i64}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
return %0 : tensor<?x3xf32>
}

// CHECK-LABEL: @concat_fold
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@@ -507,17 +514,19 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :

// -----

// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
%0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
%1 = "tosa.slice"(%0) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
%2 = "tosa.slice"(%0) {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
}

// xHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
// xHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
// xHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
// TODO: This upstream test case seems broken because the start of the next line (12) is out of bounds with the input shape
// xHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
// xHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
//func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
// %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
// %1 = "tosa.slice"(%0) {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
// %2 = "tosa.slice"(%0) {size = array<i64: 1, 3, 12>, start = array<i64: 0, 3, 12>} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
// return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
//}

// -----

Expand Down