@@ -314,14 +314,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
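For reference, the renamed check above enforces that each converted (wide) element packs a whole number of original (narrow) elements. Below is a minimal standalone sketch of that arithmetic; the names oldBits/newBits mirror the patch, while computeScale is a hypothetical helper, not part of the patch or the MLIR API:

#include <cassert>
#include <cstdio>

// Per-element alignment rule from the hunk above: emulation only works
// when the wide element width is an exact multiple of the narrow width.
static int computeScale(int oldBits, int newBits) {
  assert(newBits % oldBits == 0 && "unaligned element types");
  return newBits / oldBits; // narrow elements packed per wide element
}

int main() {
  // e.g. emulating i4 on an i8 memref: each i8 holds 2 i4 values.
  printf("i4 in i8:  scale = %d\n", computeScale(4, 8));  // 2
  printf("i2 in i32: scale = %d\n", computeScale(2, 32)); // 16
  return 0;
}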
@@ -346,7 +346,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -385,15 +385,15 @@ struct ConvertVectorMaskedStore final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
 
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
     if (origElements % scale != 0)
       return failure();
@@ -404,7 +404,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -493,14 +493,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -541,7 +541,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -596,14 +596,14 @@ struct ConvertVectorMaskedLoad final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -657,7 +657,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -758,14 +758,14 @@ struct ConvertVectorTransferRead final
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int scale = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
@@ -781,7 +781,7 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
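The getLinearizedMemRefOffsetAndSize calls updated across these hunks feed the two widths into index linearization: a 1-D access to narrow element i on the emulated memref lands in wide element i / scale, at narrow-element slot i % scale within it. A minimal sketch of that mapping, using a hypothetical helper rather than the real memref API (which operates on OpFoldResults and strided metadata):

#include <cassert>
#include <cstdio>

struct NarrowAccess {
  int wideIndex;   // index into the wide-element memref
  int innerOffset; // narrow-element slot inside that wide element
};

// Illustrative mapping only; the actual lowering goes through
// memref::getLinearizedMemRefOffsetAndSize as shown in the diff.
static NarrowAccess linearize(int narrowIndex, int oldBits, int newBits) {
  assert(newBits % oldBits == 0 && "unaligned element types");
  int scale = newBits / oldBits;
  return {narrowIndex / scale, narrowIndex % scale};
}

int main() {
  // i4 on i8 (scale = 2): narrow index 5 lives in byte 2, second nibble.
  NarrowAccess a = linearize(5, 4, 8);
  printf("wideIndex = %d, innerOffset = %d\n", a.wideIndex, a.innerOffset);
  return 0;
}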