Skip to content

Commit d492166

Browse files
committed
Move getConvOpsAccType from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18 to LLVM
1 parent eb44789 commit d492166

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ Value getTosaConstShape(ImplicitLocOpBuilder &builder,
9090
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
9191
llvm::ArrayRef<int64_t> shape);
9292

93+
// Get accumulator type for TOSA convolution ops
94+
LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
95+
RankedTensorType inputTy,
96+
RankedTensorType weightTy,
97+
RankedTensorType outputTy, TypeAttr &accType);
98+
9399
namespace {
94100

95101
// Creates a TOSA operation and performs shape inference on the individual

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
14+
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
1415
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1516

1617
using namespace mlir;
@@ -182,3 +183,51 @@ Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
182183
ImplicitLocOpBuilder builder(loc, rewriter);
183184
return getTosaConstShape(builder, shape);
184185
}
186+
187+
// AMD: Picked from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18
188+
// Get accumulator type for TOSA convolution ops
189+
LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter,
190+
RankedTensorType inputTy,
191+
RankedTensorType weightTy,
192+
RankedTensorType outputTy,
193+
TypeAttr &accType) {
194+
auto inputElemTy = inputTy.getElementType();
195+
auto weightElemTy = weightTy.getElementType();
196+
auto outputElemTy = outputTy.getElementType();
197+
198+
auto quantTy = dyn_cast<quant::QuantizedType>(inputElemTy);
199+
if (quantTy)
200+
inputElemTy = quantTy.getStorageType();
201+
202+
// Get TOSA conv ops acc type based on input, weight, and output types
203+
// according to the spec:
204+
// https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
205+
// https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
206+
// https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d
207+
//
208+
// For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the
209+
// output type but does not offer any guarantee on the numerical precision
210+
// since such cases will fail TOSA validation.
211+
if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) ||
212+
(inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) ||
213+
(inputElemTy.isBF16() && weightElemTy.isBF16() &&
214+
outputElemTy.isBF16())) {
215+
accType = mlir::TypeAttr::get(rewriter.getF32Type());
216+
} else if (inputElemTy.isInteger(8) &&
217+
(weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) &&
218+
outputElemTy.isInteger(32)) {
219+
accType = mlir::TypeAttr::get(rewriter.getIntegerType(32));
220+
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
221+
outputElemTy.isInteger(48)) {
222+
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
223+
} else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
224+
outputElemTy.isF16()) ||
225+
(inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
226+
outputElemTy.isF16())) {
227+
accType = mlir::TypeAttr::get(rewriter.getF16Type());
228+
} else {
229+
accType = mlir::TypeAttr::get(outputElemTy);
230+
}
231+
232+
return success();
233+
}

0 commit comments

Comments
 (0)