Skip to content

Commit 4e0a2fc

Browse files
committed
Lowering to linalg ops
1 parent 0aaadb4 commit 4e0a2fc

File tree

2 files changed

+194
-84
lines changed

2 files changed

+194
-84
lines changed

mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp

Lines changed: 130 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ Type getScalarType(Type inputType) {
3838
return inputType;
3939
}
4040

41-
// Return the shape of an input value as a list of attributes (static dimensions)
42-
// and values (dynamic dimensions). If 'input' is a scalar, an empty list is
43-
// returned. If 'input' is a tensor, its shape is returned.
44-
SmallVector<OpFoldResult>
45-
getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
41+
// Return the shape of an input value as a list of attributes (static
42+
// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
43+
// list is returned. If 'input' is a tensor, its shape is returned.
44+
SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
45+
Location loc, Value input) {
4646
if (isa<TensorType>(input.getType()))
4747
return tensor::getMixedSizes(builder, loc, input);
4848
return {};
@@ -100,16 +100,16 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
100100

101101
// Turn input size into 1D tensor
102102
auto flatShapeType = shape::getExtentTensorType(context, 1);
103-
auto flatInputShape = builder.create<tensor::FromElementsOp>(
104-
loc, flatShapeType, inputSize);
103+
auto flatInputShape =
104+
builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
105105

106106
// Reshape input tensor into 1D
107107
auto inputType = cast<UnrankedTensorType>(input.getType());
108108
auto elementType = inputType.getElementType();
109109
auto flatInputType =
110110
RankedTensorType::get({ShapedType::kDynamic}, elementType);
111-
auto flatInput = builder.create<tensor::ReshapeOp>(
112-
loc, flatInputType, input, flatInputShape);
111+
auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
112+
flatInputShape);
113113
return std::make_pair(flatInput, inputShape);
114114
}
115115

@@ -135,11 +135,9 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
135135
// - inputShape
136136
// 1D extent tensor containing the shape of the original unranked input.
137137
//
138-
std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
139-
Location loc,
140-
Value input,
141-
int64_t axis,
142-
int64_t axisSize) {
138+
std::pair<Value, Value>
139+
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
140+
int64_t axis, int64_t axisSize) {
143141
// Get full tensor shape
144142
auto *context = builder.getContext();
145143
auto indexType = builder.getIndexType();
@@ -149,16 +147,20 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
149147
// Get shape and sizes on left and right of axis
150148
auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
151149
auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
152-
auto shapeLeft = builder.create<shape::SplitAtOp>(
153-
loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
154-
.getResult(0);
155-
auto sizeLeft = builder.create<shape::NumElementsOp>(
156-
loc, indexType, shapeLeft);
157-
auto shapeRight = builder.create<shape::SplitAtOp>(
158-
loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
159-
.getResult(1);
160-
auto sizeRight = builder.create<shape::NumElementsOp>(
161-
loc, indexType, shapeRight);
150+
auto shapeLeft =
151+
builder
152+
.create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
153+
inputShape, axisValue)
154+
.getResult(0);
155+
auto sizeLeft =
156+
builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
157+
auto shapeRight =
158+
builder
159+
.create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
160+
inputShape, axisNextValue)
161+
.getResult(1);
162+
auto sizeRight =
163+
builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
162164

163165
// Compute flat input shape as a 3-element 1D tensor
164166
auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
@@ -171,8 +173,8 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
171173
auto elementType = inputType.getElementType();
172174
auto flatInputType = RankedTensorType::get(
173175
{ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
174-
auto flatInput = builder.create<tensor::ReshapeOp>(
175-
loc, flatInputType, input, flatInputShape);
176+
auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
177+
flatInputShape);
176178

177179
return std::make_pair(flatInput, inputShape);
178180
}
@@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
190192
auto inputType = cast<RankedTensorType>(input.getType());
191193
auto elementType = inputType.getElementType();
192194
auto unrankedType = UnrankedTensorType::get(elementType);
193-
return builder.create<tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
195+
return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
196+
inputShape);
194197
}
195198

196199
// Create a tensor constant containing all scales in a per-channel quantized
@@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
209212
auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
210213
return builder.getFloatAttr(expressedType, scale);
211214
});
212-
auto tensorType = RankedTensorType::get({(int64_t) scales.size()}, expressedType);
215+
auto tensorType =
216+
RankedTensorType::get({(int64_t)scales.size()}, expressedType);
213217
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
214218
return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
215219
}
@@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints(
228232
UniformQuantizedPerAxisType quantizedType) {
229233
auto zeroPoints = quantizedType.getZeroPoints();
230234
auto storageType = quantizedType.getStorageType();
231-
auto zeroPointAttrs = llvm::map_to_vector(
232-
zeroPoints,
233-
[&](int64_t zeroPoint) -> Attribute {
235+
auto zeroPointAttrs =
236+
llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
234237
return builder.getIntegerAttr(storageType, zeroPoint);
235238
});
236239
auto tensorType =
@@ -239,6 +242,54 @@ Value materializePerChannelZeroPoints(
239242
return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
240243
}
241244

245+
// Create a tensor constant containing all scales in a sub-channel quantized
246+
// type. Example:
247+
//
248+
// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
249+
//
250+
// produces
251+
//
252+
// %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
253+
//
254+
Value materializeSubChannelScales(
255+
OpBuilder &builder, Location loc,
256+
UniformQuantizedSubChannelType quantizedType) {
257+
auto scales = quantizedType.getScales();
258+
auto expressedType = quantizedType.getExpressedType();
259+
auto scaleAttrs = llvm::map_to_vector(
260+
scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
261+
return builder.getFloatAttr(expressedType, scale);
262+
});
263+
auto tensorType =
264+
RankedTensorType::get(scales.getType().getShape(), expressedType);
265+
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
266+
return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
267+
}
268+
269+
// Create a tensor constant containing all zero points in a sub-channel
270+
// quantized type. Example:
271+
//
272+
// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
273+
//
274+
// produces
275+
//
276+
// %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
277+
//
278+
Value materializeSubChannelZeroPoints(
279+
OpBuilder &builder, Location loc,
280+
UniformQuantizedSubChannelType quantizedType) {
281+
auto zeroPoints = quantizedType.getZeroPoints();
282+
auto storageType = quantizedType.getStorageType();
283+
auto zeroPointAttrs = llvm::map_to_vector(
284+
zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
285+
return builder.getIntegerAttr(storageType, zeroPoint);
286+
});
287+
auto tensorType =
288+
RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
289+
auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
290+
return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
291+
}
292+
242293
// Clamp the given scalar or tensor input using the storage bounds encoded in
243294
// the given quantized type, if present.
244295
//
@@ -299,7 +350,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
299350
return builder.create<arith::UIToFPOp>(loc, resultType, input);
300351
}
301352

302-
// Quantize a scalar or ranked tensor value. The stored value is clamped using
353+
// Quantize a scalar or ranked tensor value. The stored value is clamped using
303354
// the storage bounds encoded in the given quantized type.
304355
//
305356
// See function 'convertRanked()' below for a description of the arguments.
@@ -308,8 +359,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
308359
Value zeroPoint, QuantizedType quantizedType) {
309360
// Convert scale to tensor if necessary
310361
auto inputType = input.getType();
311-
scale = getScalarOrTensorConstant(
312-
builder, loc, scale, inputType, inputShape);
362+
scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
313363

314364
// Scale input
315365
auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
@@ -322,8 +372,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
322372
inputShape);
323373

324374
// Convert zero point from storage to expressed type
325-
zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
326-
scale.getType(),
375+
zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
327376
quantizedType.isSigned());
328377

329378
// Add zero point to stored value
@@ -334,9 +383,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
334383
// Convert stored value to storage type
335384
auto storageScalarOrTensorType =
336385
getScalarOrTensorType(quantizedType.getStorageType(), inputType);
337-
auto storedValueInt = convertFloatToInteger(
338-
builder, loc, storedValueFloat, storageScalarOrTensorType,
339-
quantizedType.isSigned());
386+
auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
387+
storageScalarOrTensorType,
388+
quantizedType.isSigned());
340389

341390
// Clamp stored value it if the storage type is bound
342391
auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
@@ -352,12 +401,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
352401
Value zeroPoint, QuantizedType quantizedType) {
353402
// Convert scale to tensor if necessary
354403
auto inputType = input.getType();
355-
scale = getScalarOrTensorConstant(
356-
builder, loc, scale, inputType, inputShape);
404+
scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
357405

358406
// Convert stored value to float
359-
auto result = convertIntegerToFloat(
360-
builder, loc, input, scale.getType(), quantizedType.isSigned());
407+
auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
408+
quantizedType.isSigned());
361409

362410
// Skip unnecessary computations if no zero point is given
363411
if (!matchPattern(zeroPoint, m_Zero())) {
@@ -366,8 +414,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
366414
inputShape);
367415

368416
// Convert zero point from storage to expressed type
369-
zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint,
370-
scale.getType(),
417+
zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
371418
quantizedType.isSigned());
372419

373420
// Subtract zero point to stored value
@@ -501,35 +548,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
501548
auto initShape = tensor::getMixedSizes(builder, loc, input);
502549
Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
503550

504-
SmallVector<utils::IteratorType> iteratorTypes(
505-
inputRank, utils::IteratorType::parallel);
551+
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
552+
utils::IteratorType::parallel);
506553
auto channelAxisAffineMap = AffineMap::get(
507554
inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
508555
SmallVector<AffineMap> indexingMaps{
509-
builder.getMultiDimIdentityMap(inputRank),
510-
channelAxisAffineMap,
511-
channelAxisAffineMap,
512-
builder.getMultiDimIdentityMap(inputRank)
513-
};
514-
auto result = builder.create<linalg::GenericOp>(
515-
loc,
516-
init.getType(), // resultType
517-
ValueRange{input, scales, zeroPoints}, // inputs
518-
ValueRange{init}, // outputs
519-
indexingMaps,
520-
iteratorTypes,
521-
[&](OpBuilder& builder, Location loc, ValueRange args) {
522-
assert(args.size() == 4);
523-
auto input = args[0];
524-
auto scale = args[1];
525-
auto zeroPoint = args[2];
526-
527-
auto result = convertRanked(builder, loc, op, input, {}, scale,
528-
zeroPoint, quantizedType);
529-
530-
builder.create<linalg::YieldOp>(loc, result);
531-
})
532-
.getResult(0);
556+
builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
557+
channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
558+
auto result = builder
559+
.create<linalg::GenericOp>(
560+
loc,
561+
init.getType(), // resultType
562+
ValueRange{input, scales, zeroPoints}, // inputs
563+
ValueRange{init}, // outputs
564+
indexingMaps, iteratorTypes,
565+
[&](OpBuilder &builder, Location loc, ValueRange args) {
566+
assert(args.size() == 4);
567+
auto input = args[0];
568+
auto scale = args[1];
569+
auto zeroPoint = args[2];
570+
571+
auto result =
572+
convertRanked(builder, loc, op, input, {}, scale,
573+
zeroPoint, quantizedType);
574+
575+
builder.create<linalg::YieldOp>(loc, result);
576+
})
577+
.getResult(0);
533578

534579
return result;
535580
}
@@ -551,7 +596,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
551596
// Flatten unranked tensor into a 3D ranked tensor if necessary
552597
bool isUnranked = isa<UnrankedTensorType>(input.getType());
553598
int64_t channelAxis = quantizedType.getQuantizedDimension();
554-
int64_t channelAxisSize = (int64_t) quantizedType.getScales().size();
599+
int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
555600
Value inputShape;
556601
if (isUnranked) {
557602
std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
@@ -660,11 +705,17 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
660705
return convertPerChannel(builder, loc, op, input,
661706
uniformQuantizedPerAxisType);
662707

708+
if (auto uniformQuantizedSubChannelType =
709+
dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
710+
return convertSubChannel(builder, loc, op, input,
711+
uniformQuantizedSubChannelType);
712+
663713
llvm_unreachable("unexpected quantized type");
664714
}
665715

666716
// Lowering pattern for 'quant.dcast'
667-
struct DequantizeCastOpConversion : public OpConversionPattern<quant::DequantizeCastOp> {
717+
struct DequantizeCastOpConversion
718+
: public OpConversionPattern<quant::DequantizeCastOp> {
668719
using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
669720

670721
LogicalResult
@@ -689,7 +740,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
689740
};
690741

691742
// Lowering pattern for 'quant.qcast'
692-
struct QuantizeCastOpConversion : public OpConversionPattern<quant::QuantizeCastOp> {
743+
struct QuantizeCastOpConversion
744+
: public OpConversionPattern<quant::QuantizeCastOp> {
693745
using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
694746

695747
LogicalResult
@@ -717,12 +769,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
717769
ConversionTarget target(getContext());
718770
target.addLegalOp<quant::StorageCastOp>();
719771
target.addIllegalDialect<quant::QuantDialect>();
720-
target.addLegalDialect<
721-
arith::ArithDialect,
722-
linalg::LinalgDialect,
723-
shape::ShapeDialect,
724-
tensor::TensorDialect
725-
>();
772+
target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
773+
shape::ShapeDialect, tensor::TensorDialect>();
726774

727775
if (failed(applyPartialConversion(getOperation(), target,
728776
std::move(patterns))))
@@ -733,10 +781,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
733781
} // namespace
734782

735783
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
736-
patterns.add<
737-
DequantizeCastOpConversion,
738-
QuantizeCastOpConversion
739-
>(patterns.getContext());
784+
patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
785+
patterns.getContext());
740786
}
741787

742788
} // namespace quant

0 commit comments

Comments
 (0)