@@ -38,11 +38,11 @@ Type getScalarType(Type inputType) {
38
38
return inputType;
39
39
}
40
40
41
- // Return the shape of an input value as a list of attributes (static dimensions)
42
- // and values (dynamic dimensions). If 'input' is a scalar, an empty list is
43
- // returned. If 'input' is a tensor, its shape is returned.
44
- SmallVector<OpFoldResult>
45
- getScalarOrTensorShape (OpBuilder &builder, Location loc, Value input) {
41
+ // Return the shape of an input value as a list of attributes (static
42
+ // dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
43
+ // list is returned. If 'input' is a tensor, its shape is returned.
44
+ SmallVector<OpFoldResult> getScalarOrTensorShape (OpBuilder &builder,
45
+ Location loc, Value input) {
46
46
if (isa<TensorType>(input.getType ()))
47
47
return tensor::getMixedSizes (builder, loc, input);
48
48
return {};
@@ -100,16 +100,16 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
100
100
101
101
// Turn input size into 1D tensor
102
102
auto flatShapeType = shape::getExtentTensorType (context, 1 );
103
- auto flatInputShape = builder. create <tensor::FromElementsOp>(
104
- loc, flatShapeType, inputSize);
103
+ auto flatInputShape =
104
+ builder. create <tensor::FromElementsOp>( loc, flatShapeType, inputSize);
105
105
106
106
// Reshape input tensor into 1D
107
107
auto inputType = cast<UnrankedTensorType>(input.getType ());
108
108
auto elementType = inputType.getElementType ();
109
109
auto flatInputType =
110
110
RankedTensorType::get ({ShapedType::kDynamic }, elementType);
111
- auto flatInput = builder.create <tensor::ReshapeOp>(
112
- loc, flatInputType, input, flatInputShape);
111
+ auto flatInput = builder.create <tensor::ReshapeOp>(loc, flatInputType, input,
112
+ flatInputShape);
113
113
return std::make_pair (flatInput, inputShape);
114
114
}
115
115
@@ -135,11 +135,9 @@ std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
135
135
// - inputShape
136
136
// 1D extent tensor containing the shape of the original unranked input.
137
137
//
138
- std::pair<Value, Value> flattenUnrankedTensorAroundAxis (OpBuilder &builder,
139
- Location loc,
140
- Value input,
141
- int64_t axis,
142
- int64_t axisSize) {
138
+ std::pair<Value, Value>
139
+ flattenUnrankedTensorAroundAxis (OpBuilder &builder, Location loc, Value input,
140
+ int64_t axis, int64_t axisSize) {
143
141
// Get full tensor shape
144
142
auto *context = builder.getContext ();
145
143
auto indexType = builder.getIndexType ();
@@ -149,16 +147,20 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
149
147
// Get shape and sizes on left and right of axis
150
148
auto axisValue = builder.create <arith::ConstantIndexOp>(loc, axis);
151
149
auto axisNextValue = builder.create <arith::ConstantIndexOp>(loc, axis + 1 );
152
- auto shapeLeft = builder.create <shape::SplitAtOp>(
153
- loc, TypeRange{shapeType, shapeType}, inputShape, axisValue)
154
- .getResult (0 );
155
- auto sizeLeft = builder.create <shape::NumElementsOp>(
156
- loc, indexType, shapeLeft);
157
- auto shapeRight = builder.create <shape::SplitAtOp>(
158
- loc, TypeRange{shapeType, shapeType}, inputShape, axisNextValue)
159
- .getResult (1 );
160
- auto sizeRight = builder.create <shape::NumElementsOp>(
161
- loc, indexType, shapeRight);
150
+ auto shapeLeft =
151
+ builder
152
+ .create <shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
153
+ inputShape, axisValue)
154
+ .getResult (0 );
155
+ auto sizeLeft =
156
+ builder.create <shape::NumElementsOp>(loc, indexType, shapeLeft);
157
+ auto shapeRight =
158
+ builder
159
+ .create <shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
160
+ inputShape, axisNextValue)
161
+ .getResult (1 );
162
+ auto sizeRight =
163
+ builder.create <shape::NumElementsOp>(loc, indexType, shapeRight);
162
164
163
165
// Compute flat input shape as a 3-element 1D tensor
164
166
auto axisSizeValue = builder.create <arith::ConstantIndexOp>(loc, axisSize);
@@ -171,8 +173,8 @@ std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
171
173
auto elementType = inputType.getElementType ();
172
174
auto flatInputType = RankedTensorType::get (
173
175
{ShapedType::kDynamic , axisSize, ShapedType::kDynamic }, elementType);
174
- auto flatInput = builder.create <tensor::ReshapeOp>(
175
- loc, flatInputType, input, flatInputShape);
176
+ auto flatInput = builder.create <tensor::ReshapeOp>(loc, flatInputType, input,
177
+ flatInputShape);
176
178
177
179
return std::make_pair (flatInput, inputShape);
178
180
}
@@ -190,7 +192,8 @@ Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
190
192
auto inputType = cast<RankedTensorType>(input.getType ());
191
193
auto elementType = inputType.getElementType ();
192
194
auto unrankedType = UnrankedTensorType::get (elementType);
193
- return builder.create <tensor::ReshapeOp>(loc, unrankedType, input, inputShape);
195
+ return builder.create <tensor::ReshapeOp>(loc, unrankedType, input,
196
+ inputShape);
194
197
}
195
198
196
199
// Create a tensor constant containing all scales in a per-channel quantized
@@ -209,7 +212,8 @@ Value materializePerChannelScales(OpBuilder &builder, Location loc,
209
212
auto scaleAttrs = llvm::map_to_vector (scales, [&](double scale) -> Attribute {
210
213
return builder.getFloatAttr (expressedType, scale);
211
214
});
212
- auto tensorType = RankedTensorType::get ({(int64_t ) scales.size ()}, expressedType);
215
+ auto tensorType =
216
+ RankedTensorType::get ({(int64_t )scales.size ()}, expressedType);
213
217
auto scalesAttr = DenseElementsAttr::get (tensorType, scaleAttrs);
214
218
return builder.create <arith::ConstantOp>(loc, tensorType, scalesAttr);
215
219
}
@@ -228,9 +232,8 @@ Value materializePerChannelZeroPoints(
228
232
UniformQuantizedPerAxisType quantizedType) {
229
233
auto zeroPoints = quantizedType.getZeroPoints ();
230
234
auto storageType = quantizedType.getStorageType ();
231
- auto zeroPointAttrs = llvm::map_to_vector (
232
- zeroPoints,
233
- [&](int64_t zeroPoint) -> Attribute {
235
+ auto zeroPointAttrs =
236
+ llvm::map_to_vector (zeroPoints, [&](int64_t zeroPoint) -> Attribute {
234
237
return builder.getIntegerAttr (storageType, zeroPoint);
235
238
});
236
239
auto tensorType =
@@ -239,6 +242,54 @@ Value materializePerChannelZeroPoints(
239
242
return builder.create <arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
240
243
}
241
244
245
+ // Create a tensor constant containing all scales in a sub-channel quantized
246
+ // type. Example:
247
+ //
248
+ // !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
249
+ //
250
+ // produces
251
+ //
252
+ // %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
253
+ //
254
+ Value materializeSubChannelScales ( // Emits one constant tensor holding every scale of 'quantizedType'.
255
+ OpBuilder &builder, Location loc, // 'builder' creates the attributes and the constant op; 'loc' is attached to that op
256
+ UniformQuantizedSubChannelType quantizedType) { // supplies both the scales and the expressed type
257
+ auto scales = quantizedType.getScales (); // queried below for its APFloat values and for its shape
258
+ auto expressedType = quantizedType.getExpressedType (); // used as attribute type and as tensor element type
259
+ auto scaleAttrs = llvm::map_to_vector (
260
+ scales.getValues <APFloat>(), [&](APFloat scale) -> Attribute { // one attribute per scale
261
+ return builder.getFloatAttr (expressedType, scale); // re-wrap the scale as a float attribute of the expressed type
262
+ });
263
+ auto tensorType = // result type: same shape as 'scales', element type = expressed type
264
+ RankedTensorType::get (scales.getType ().getShape (), expressedType);
265
+ auto scalesAttr = DenseElementsAttr::get (tensorType, scaleAttrs); // pack the per-scale attributes into a dense attribute
266
+ return builder.create <arith::ConstantOp>(loc, tensorType, scalesAttr); // the new arith.constant is returned as a Value
267
+ }
268
+
269
+ // Create a tensor constant containing all zero points in a sub-channel
270
+ // quantized type. Example:
271
+ //
272
+ // !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
273
+ //
274
+ // produces
275
+ //
276
+ // %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
277
+ //
278
+ Value materializeSubChannelZeroPoints ( // Emits one constant tensor holding every zero point of 'quantizedType'.
279
+ OpBuilder &builder, Location loc, // 'builder' creates the attributes and the constant op; 'loc' is attached to that op
280
+ UniformQuantizedSubChannelType quantizedType) { // supplies both the zero points and the storage type
281
+ auto zeroPoints = quantizedType.getZeroPoints (); // queried below for its APInt values and for its shape
282
+ auto storageType = quantizedType.getStorageType (); // used as attribute type and as tensor element type
283
+ auto zeroPointAttrs = llvm::map_to_vector (
284
+ zeroPoints.getValues <APInt>(), [&](APInt zeroPoint) -> Attribute { // one attribute per zero point
285
+ return builder.getIntegerAttr (storageType, zeroPoint); // re-wrap the zero point as an integer attribute of the storage type
286
+ });
287
+ auto tensorType = // result type: same shape as 'zeroPoints', element type = storage type
288
+ RankedTensorType::get (zeroPoints.getType ().getShape (), storageType);
289
+ auto zeroPointsAttr = DenseElementsAttr::get (tensorType, zeroPointAttrs); // pack the per-entry attributes into a dense attribute
290
+ return builder.create <arith::ConstantOp>(loc, tensorType, zeroPointsAttr); // the new arith.constant is returned as a Value
291
+ }
292
+
242
293
// Clamp the given scalar or tensor input using the storage bounds encoded in
243
294
// the given quantized type, if present.
244
295
//
@@ -299,7 +350,7 @@ Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
299
350
return builder.create <arith::UIToFPOp>(loc, resultType, input);
300
351
}
301
352
302
- // Quantize a scalar or ranked tensor value. The stored value is clamped using
353
+ // Quantize a scalar or ranked tensor value. The stored value is clamped using
303
354
// the storage bounds encoded in the given quantized type.
304
355
//
305
356
// See function 'convertRanked()' below for a description of the arguments.
@@ -308,8 +359,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
308
359
Value zeroPoint, QuantizedType quantizedType) {
309
360
// Convert scale to tensor if necessary
310
361
auto inputType = input.getType ();
311
- scale = getScalarOrTensorConstant (
312
- builder, loc, scale, inputType, inputShape);
362
+ scale = getScalarOrTensorConstant (builder, loc, scale, inputType, inputShape);
313
363
314
364
// Scale input
315
365
auto scaledValue = builder.create <arith::DivFOp>(loc, input, scale);
@@ -322,8 +372,7 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
322
372
inputShape);
323
373
324
374
// Convert zero point from storage to expressed type
325
- zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint,
326
- scale.getType (),
375
+ zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint, scale.getType (),
327
376
quantizedType.isSigned ());
328
377
329
378
// Add zero point to stored value
@@ -334,9 +383,9 @@ Value quantizeValue(OpBuilder &builder, Location loc, Value input,
334
383
// Convert stored value to storage type
335
384
auto storageScalarOrTensorType =
336
385
getScalarOrTensorType (quantizedType.getStorageType (), inputType);
337
- auto storedValueInt = convertFloatToInteger (
338
- builder, loc, storedValueFloat, storageScalarOrTensorType,
339
- quantizedType.isSigned ());
386
+ auto storedValueInt = convertFloatToInteger (builder, loc, storedValueFloat,
387
+ storageScalarOrTensorType,
388
+ quantizedType.isSigned ());
340
389
341
390
// Clamp stored value if the storage type is bounded
342
391
auto storedValueClamped = clampScalarOrTensor (builder, loc, storedValueInt,
@@ -352,12 +401,11 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
352
401
Value zeroPoint, QuantizedType quantizedType) {
353
402
// Convert scale to tensor if necessary
354
403
auto inputType = input.getType ();
355
- scale = getScalarOrTensorConstant (
356
- builder, loc, scale, inputType, inputShape);
404
+ scale = getScalarOrTensorConstant (builder, loc, scale, inputType, inputShape);
357
405
358
406
// Convert stored value to float
359
- auto result = convertIntegerToFloat (
360
- builder, loc, input, scale. getType (), quantizedType.isSigned ());
407
+ auto result = convertIntegerToFloat (builder, loc, input, scale. getType (),
408
+ quantizedType.isSigned ());
361
409
362
410
// Skip unnecessary computations if no zero point is given
363
411
if (!matchPattern (zeroPoint, m_Zero ())) {
@@ -366,8 +414,7 @@ Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
366
414
inputShape);
367
415
368
416
// Convert zero point from storage to expressed type
369
- zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint,
370
- scale.getType (),
417
+ zeroPoint = convertIntegerToFloat (builder, loc, zeroPoint, scale.getType (),
371
418
quantizedType.isSigned ());
372
419
373
420
// Subtract zero point from stored value
@@ -501,35 +548,33 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
501
548
auto initShape = tensor::getMixedSizes (builder, loc, input);
502
549
Value init = builder.create <tensor::EmptyOp>(loc, initShape, elementType);
503
550
504
- SmallVector<utils::IteratorType> iteratorTypes (
505
- inputRank, utils::IteratorType::parallel);
551
+ SmallVector<utils::IteratorType> iteratorTypes (inputRank,
552
+ utils::IteratorType::parallel);
506
553
auto channelAxisAffineMap = AffineMap::get (
507
554
inputRank, 0 , builder.getAffineDimExpr (channelAxis), context);
508
555
SmallVector<AffineMap> indexingMaps{
509
- builder.getMultiDimIdentityMap (inputRank),
510
- channelAxisAffineMap,
511
- channelAxisAffineMap,
512
- builder.getMultiDimIdentityMap (inputRank)
513
- };
514
- auto result = builder.create <linalg::GenericOp>(
515
- loc,
516
- init.getType (), // resultType
517
- ValueRange{input, scales, zeroPoints}, // inputs
518
- ValueRange{init}, // outputs
519
- indexingMaps,
520
- iteratorTypes,
521
- [&](OpBuilder& builder, Location loc, ValueRange args) {
522
- assert (args.size () == 4 );
523
- auto input = args[0 ];
524
- auto scale = args[1 ];
525
- auto zeroPoint = args[2 ];
526
-
527
- auto result = convertRanked (builder, loc, op, input, {}, scale,
528
- zeroPoint, quantizedType);
529
-
530
- builder.create <linalg::YieldOp>(loc, result);
531
- })
532
- .getResult (0 );
556
+ builder.getMultiDimIdentityMap (inputRank), channelAxisAffineMap,
557
+ channelAxisAffineMap, builder.getMultiDimIdentityMap (inputRank)};
558
+ auto result = builder
559
+ .create <linalg::GenericOp>(
560
+ loc,
561
+ init.getType (), // resultType
562
+ ValueRange{input, scales, zeroPoints}, // inputs
563
+ ValueRange{init}, // outputs
564
+ indexingMaps, iteratorTypes,
565
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
566
+ assert (args.size () == 4 );
567
+ auto input = args[0 ];
568
+ auto scale = args[1 ];
569
+ auto zeroPoint = args[2 ];
570
+
571
+ auto result =
572
+ convertRanked (builder, loc, op, input, {}, scale,
573
+ zeroPoint, quantizedType);
574
+
575
+ builder.create <linalg::YieldOp>(loc, result);
576
+ })
577
+ .getResult (0 );
533
578
534
579
return result;
535
580
}
@@ -551,7 +596,7 @@ Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
551
596
// Flatten unranked tensor into a 3D ranked tensor if necessary
552
597
bool isUnranked = isa<UnrankedTensorType>(input.getType ());
553
598
int64_t channelAxis = quantizedType.getQuantizedDimension ();
554
- int64_t channelAxisSize = (int64_t ) quantizedType.getScales ().size ();
599
+ int64_t channelAxisSize = (int64_t )quantizedType.getScales ().size ();
555
600
Value inputShape;
556
601
if (isUnranked) {
557
602
std::tie (input, inputShape) = flattenUnrankedTensorAroundAxis (
@@ -660,11 +705,17 @@ Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
660
705
return convertPerChannel (builder, loc, op, input,
661
706
uniformQuantizedPerAxisType);
662
707
708
+ if (auto uniformQuantizedSubChannelType =
709
+ dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
710
+ return convertSubChannel (builder, loc, op, input,
711
+ uniformQuantizedSubChannelType);
712
+
663
713
llvm_unreachable (" unexpected quantized type" );
664
714
}
665
715
666
716
// Lowering pattern for 'quant.dcast'
667
- struct DequantizeCastOpConversion : public OpConversionPattern <quant::DequantizeCastOp> {
717
+ struct DequantizeCastOpConversion
718
+ : public OpConversionPattern<quant::DequantizeCastOp> {
668
719
using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
669
720
670
721
LogicalResult
@@ -689,7 +740,8 @@ struct DequantizeCastOpConversion : public OpConversionPattern<quant::Dequantize
689
740
};
690
741
691
742
// Lowering pattern for 'quant.qcast'
692
- struct QuantizeCastOpConversion : public OpConversionPattern <quant::QuantizeCastOp> {
743
+ struct QuantizeCastOpConversion
744
+ : public OpConversionPattern<quant::QuantizeCastOp> {
693
745
using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
694
746
695
747
LogicalResult
@@ -717,12 +769,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
717
769
ConversionTarget target (getContext ());
718
770
target.addLegalOp <quant::StorageCastOp>();
719
771
target.addIllegalDialect <quant::QuantDialect>();
720
- target.addLegalDialect <
721
- arith::ArithDialect,
722
- linalg::LinalgDialect,
723
- shape::ShapeDialect,
724
- tensor::TensorDialect
725
- >();
772
+ target.addLegalDialect <arith::ArithDialect, linalg::LinalgDialect,
773
+ shape::ShapeDialect, tensor::TensorDialect>();
726
774
727
775
if (failed (applyPartialConversion (getOperation (), target,
728
776
std::move (patterns))))
@@ -733,10 +781,8 @@ struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
733
781
} // namespace
734
782
735
783
void populateLowerQuantOpsPatterns (RewritePatternSet &patterns) { // Adds the 'quant.dcast' and 'quant.qcast' lowering patterns to 'patterns'.
736
- patterns.add <
737
- DequantizeCastOpConversion,
738
- QuantizeCastOpConversion
739
- >(patterns.getContext ());
784
+ patterns.add <DequantizeCastOpConversion, QuantizeCastOpConversion>( // both patterns are registered in a single call
785
+ patterns.getContext ()); // the context is taken from the pattern set itself
740
786
}
741
787
742
788
} // namespace quant
0 commit comments