
Commit 8dde3f4

feat(TosaToNamedLinalg): add FillOp conversion for Bias (#18)
1 parent fb2842e
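In short: when a conv2d bias is a splat constant, the lowering now seeds the accumulator with a scalar arith.constant and a linalg.fill instead of broadcasting the 1-D bias tensor through a broadcast-like linalg.generic. A minimal sketch of the intended lowering, with shapes borrowed from the tests in this commit:

  %bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
  // ... after conversion, roughly:
  %cst = arith.constant 4.200000e+00 : f32
  %empty = tensor.empty() : tensor<1x45x40x28xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
  // %fill becomes the outs operand of linalg.conv_2d_nhwc_fhwc.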

4 files changed: +159 −1 lines changed

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 4 additions & 0 deletions
@@ -243,6 +243,10 @@ bool getConstShapeValues(Operation *op,
 // returns a small vector of int64_t values that attr contains
 SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
                                         const int rank);
+
+// Returns the attribute that stores the constant value of a ConstantLike
+// operation. Prerequisite: `op` must be a ConstantLike operation.
+Attribute getConstantAttribute(Operation *op);
 } // namespace tosa
 } // namespace mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 78 additions & 1 deletion
@@ -23,10 +23,11 @@
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

-#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/ADT/TypeSwitch.h"

 #include <numeric>
 #include <type_traits>
@@ -118,6 +119,71 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                          /*symbolCount=*/0, sourceDims, rewriter.getContext());
 }

+static mlir::Value createScalarConstantFromTensor(PatternRewriter &rewriter,
+                                                  Operation *source,
+                                                  Value result) {
+  // Get the constant as the attribute from the constant operation.
+  Attribute value = tosa::getConstantAttribute(source);
+  auto attr = dyn_cast<SplatElementsAttr>(value);
+
+  // Ensure the constant is a splat so it can be converted to a scalar.
+  if (!attr) {
+    return Value();
+  }
+
+  // Filter for constants based on ranked tensors.
+  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
+  if (!resultTy) {
+    return Value();
+  }
+
+  // Create a scalar constant with the same type as the result tensor.
+  // We assume the result type follows the TOSA spec, in that it can be an
+  // accumulator type that is the same as or larger in bitwidth than the
+  // splat constant.
+  Value scalarValue =
+      llvm::TypeSwitch<Attribute, Value>(attr.getSplatValue<Attribute>())
+          .Case([&](FloatAttr attr) {
+            // Create a float constant with the same type as the result
+            // tensor, going through the host system's double type: APFloat
+            // checks bitwidths, so building the attribute directly from the
+            // splat value would fail when input and output types differ.
+            return rewriter
+                .create<arith::ConstantOp>(
+                    source->getLoc(),
+                    FloatAttr::get(resultTy.getElementType(),
+                                   attr.getValue().convertToDouble()))
+                .getResult();
+          })
+          .Case([&](IntegerAttr attr) {
+            // At the moment all TOSA profiles are signed, so bail out on the
+            // unsigned case should it ever occur.
+            if (resultTy.getElementType().isUnsignedInteger()) {
+              return Value();
+            }
+            // Create a scalar that follows the result type. In the case of
+            // an i8 input the result can be i32, so the widening happens at
+            // compile time.
+            return rewriter
+                .create<arith::ConstantOp>(
+                    source->getLoc(),
+                    IntegerAttr::get(resultTy.getElementType(),
+                                     attr.getValue().getSExtValue()))
+                .getResult();
+          })
+          .Default([](Attribute) { return Value(); });
+
+  // Could not create a scalar constant due to an unsupported type.
+  if (!scalarValue) {
+    return Value();
+  }
+
+  return rewriter
+      .create<linalg::FillOp>(source->getLoc(), ValueRange{scalarValue},
+                              ValueRange{result})
+      .getResult(0);
+}
+
 // Broadcast the source value to all the outer dimensions of the result value.
 // If required, the element type is expanded using an arith.extsi or arith.extf
 // operation as appropriate.
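The TypeSwitch folds the splat value directly into the accumulator element type, so no extension op is needed at runtime. A sketch of the i8-bias/i32-accumulator case exercised by the tests (output shape assumed from those tests):

  %bias = "tosa.const"() <{values = dense<42> : tensor<28xi8>}> : () -> tensor<28xi8>
  // ... the i8 splat is widened while building the scalar constant:
  %cst = arith.constant 42 : i32
  %empty = tensor.empty() : tensor<1x45x40x28xi32>
  %fill = linalg.fill ins(%cst : i32) outs(%empty : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>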
@@ -126,6 +192,17 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                              Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   const int64_t resultRank = resultTy.getRank();
+
+  // Attempt to create a linalg FillOp if the constant is a splat value.
+  if (source.getDefiningOp() &&
+      matchPattern(source.getDefiningOp(), m_Constant())) {
+    auto scalar = createScalarConstantFromTensor(
+        rewriter, source.getDefiningOp(), result);
+    if (scalar) {
+      return scalar;
+    }
+  }
+
   // Creating maps for the input and output of the broadcast-like generic op.
   SmallVector<AffineMap, 2> indexingMaps;
   indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
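When the bias is not a splat constant, createScalarConstantFromTensor returns a null Value and the pre-existing broadcast path below still runs. A hedged sketch of that fallback (an illustrative linalg.generic, not output copied from this commit; the maps match the CHECK lines at the end of the test file):

  #map0 = affine_map<(d0, d1, d2, d3) -> (d3)>
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  %broadcast = linalg.generic
      {indexing_maps = [#map0, #map1],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%bias : tensor<28xf32>) outs(%empty : tensor<1x45x40x28xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<1x45x40x28xf32>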

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 20 additions & 0 deletions
@@ -213,3 +213,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
   }
   return {};
 }
+
+Attribute mlir::tosa::getConstantAttribute(Operation *op) {
+  if (!op || !op->hasTrait<OpTrait::ConstantLike>())
+    return Attribute();
+
+  if (auto constOp = dyn_cast<ConstOp>(op)) {
+    return constOp.getValues();
+  }
+
+  // Dialects disagree on the name of the constant attribute: tosa.const
+  // stores its payload as "values", while other ConstantLike operations
+  // such as arith.constant use "value". Search for both and return the
+  // first match.
+  const SmallVector<const char *> possibleAttributes = {"value", "values"};
+  for (llvm::StringRef name : possibleAttributes) {
+    if (op->hasAttr(name)) {
+      return op->getAttr(name);
+    }
+  }
+  return Attribute();
+}
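To illustrate the two attribute spellings the fallback probes (generic-form examples written for this note, not IR taken from the commit): tosa.const stores its payload under "values", while arith.constant, also ConstantLike, uses "value":

  %a = "tosa.const"() <{values = dense<42> : tensor<28xi8>}> : () -> tensor<28xi8>
  %b = "arith.constant"() <{value = 7 : i32}> : () -> i32

getConstantAttribute returns the payload attribute for either spelling, and a null Attribute for anything else.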

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 57 additions & 0 deletions
@@ -672,6 +672,63 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2
 
 // -----
 
+// CHECK-LABEL: @conv2d_bias_broadcast_f32
+func.func @conv2d_bias_broadcast_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
+  %bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
+  // CHECK: %[[BIAS:.+]] = linalg.fill
+  // CHECK-SAME: ins(%[[CST]]
+  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xf32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME: outs(%[[BIAS]]
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_dynamic_batch_bias_broadcast_f32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x49x42x27xf32>
+func.func @conv2d_dynamic_batch_bias_broadcast_f32(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
+  %bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x49x42x27xf32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x45x40x28xf32>
+  // CHECK: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
+  // CHECK: %[[BIAS:.+]] = linalg.fill
+  // CHECK-SAME: ins(%[[CST]]
+  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<?x45x40x28xf32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME: outs(%[[BIAS]]
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_bias_broadcast_i8_acc_i32
+func.func @conv2d_bias_broadcast_i8_acc_i32(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x3x3x27xi8>) -> () {
+  %bias = "tosa.const"() <{values = dense<42> : tensor<28xi8>}> : () -> tensor<28xi8>
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 42 : i32
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
+  // CHECK: %[[BIAS:.+]] = linalg.fill
+  // CHECK-SAME: ins(%[[CST]]
+  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME: outs(%[[BIAS]]
+  %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xi8>, tensor<28x3x3x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
