|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
13 | 13 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
|
| 14 | +#include "mlir/Dialect/Quant/IR/QuantTypes.h" |
14 | 15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
15 | 16 |
|
16 | 17 | using namespace mlir;
|
@@ -182,3 +183,51 @@ Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
|
182 | 183 | ImplicitLocOpBuilder builder(loc, rewriter);
|
183 | 184 | return getTosaConstShape(builder, shape);
|
184 | 185 | }
|
| 186 | + |
| 187 | +// AMD: Picked from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18 |
| 188 | +// Get accumulator type for TOSA convolution ops |
| 189 | +LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter, |
| 190 | + RankedTensorType inputTy, |
| 191 | + RankedTensorType weightTy, |
| 192 | + RankedTensorType outputTy, |
| 193 | + TypeAttr &accType) { |
| 194 | + auto inputElemTy = inputTy.getElementType(); |
| 195 | + auto weightElemTy = weightTy.getElementType(); |
| 196 | + auto outputElemTy = outputTy.getElementType(); |
| 197 | + |
| 198 | + auto quantTy = dyn_cast<quant::QuantizedType>(inputElemTy); |
| 199 | + if (quantTy) |
| 200 | + inputElemTy = quantTy.getStorageType(); |
| 201 | + |
| 202 | + // Get TOSA conv ops acc type based on input, weight, and output types |
| 203 | + // according to the spec: |
| 204 | + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d |
| 205 | + // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d |
| 206 | + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d |
| 207 | + // |
| 208 | + // For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the |
| 209 | + // output type but does not offer any guarantee on the numerical precision |
| 210 | + // since such cases will fail TOSA validation. |
| 211 | + if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) || |
| 212 | + (inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) || |
| 213 | + (inputElemTy.isBF16() && weightElemTy.isBF16() && |
| 214 | + outputElemTy.isBF16())) { |
| 215 | + accType = mlir::TypeAttr::get(rewriter.getF32Type()); |
| 216 | + } else if (inputElemTy.isInteger(8) && |
| 217 | + (weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) && |
| 218 | + outputElemTy.isInteger(32)) { |
| 219 | + accType = mlir::TypeAttr::get(rewriter.getIntegerType(32)); |
| 220 | + } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && |
| 221 | + outputElemTy.isInteger(48)) { |
| 222 | + accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); |
| 223 | + } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() && |
| 224 | + outputElemTy.isF16()) || |
| 225 | + (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() && |
| 226 | + outputElemTy.isF16())) { |
| 227 | + accType = mlir::TypeAttr::get(rewriter.getF16Type()); |
| 228 | + } else { |
| 229 | + accType = mlir::TypeAttr::get(outputElemTy); |
| 230 | + } |
| 231 | + |
| 232 | + return success(); |
| 233 | +} |
0 commit comments