@@ -446,16 +446,23 @@ static Type *getResolvedVectorElementType(Type *matrixType) {
446
446
return ty->getElementType ();
447
447
}
448
448
449
- static int getSliceSize (const JointMatrixTypeDescription *desc) {
449
+ static int getSliceSize (const JointMatrixTypeDescription *desc, Type *matTy) {
450
+ IGCLLVM::FixedVectorType *ty = dyn_cast<IGCLLVM::FixedVectorType>(matTy);
451
+ IGC_ASSERT_MESSAGE (ty, " Expecting vector type in calculating slice size" );
452
+
453
+ IntegerType *vecElemType = dyn_cast<IntegerType>(ty->getElementType ());
454
+ IGC_ASSERT_MESSAGE (vecElemType, " Expecting integer type for vector element." );
455
+
456
+ unsigned contribTypeWidth = vecElemType->getBitWidth ();
450
457
if (desc->layout == LayoutRowMajor) {
451
458
return desc->rows ;
452
459
}
453
460
if (desc->bitWidth != 0 ) {
454
461
if (desc->layout == LayoutPackedA) {
455
- return desc->rows * (32 / desc->bitWidth );
462
+ return desc->rows * (contribTypeWidth / desc->bitWidth );
456
463
}
457
464
if (desc->layout == LayoutPackedB) {
458
- return 8 * (32 / desc->bitWidth );
465
+ return 8 * (contribTypeWidth / desc->bitWidth );
459
466
}
460
467
}
461
468
IGC_ASSERT_MESSAGE (true , " Unexpected matrix layout." );
@@ -511,7 +518,7 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
511
518
Type *matTy = ResolveType (CI->getType (), &desc);
512
519
513
520
IRBuilder builder (CI);
514
- const int sliceSize = getSliceSize (&desc);
521
+ const int sliceSize = getSliceSize (&desc, matTy );
515
522
const int vectorSize = getResolvedVectorSize (matTy);
516
523
/* Case with packing: */
517
524
if (sliceSize > vectorSize) {
@@ -523,6 +530,14 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
523
530
IGC_ASSERT_MESSAGE (false , " Malformed matrix slice." );
524
531
}
525
532
533
+ if (fillValue->getType ()->isPointerTy ())
534
+ {
535
+ IntegerType *vectorElementType = dyn_cast<IntegerType>(getResolvedVectorElementType (matTy));
536
+ PointerType *PT = dyn_cast<PointerType>(fillValue->getType ());
537
+ fillValue = builder.CreateBitCast (fillValue, PointerType::get (vectorElementType, PT->getAddressSpace ()));
538
+ fillValue = builder.CreateLoad (vectorElementType, fillValue);
539
+ }
540
+
526
541
Value *slice = UndefValue::get (matTy);
527
542
for (int i = 0 ; i < vectorSize; i++) {
528
543
slice = builder.CreateInsertElement (slice, fillValue, i);
@@ -534,9 +549,9 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
534
549
535
550
Value *JointMatrixFuncsResolutionPass::ResolveWILength (CallInst *CI) {
536
551
JointMatrixTypeDescription desc;
537
- ResolveType (CI->getArgOperand (0 )->getType (), &desc);
552
+ Type *matTy = ResolveType (CI->getArgOperand (0 )->getType (), &desc);
538
553
539
- const int sliceSize = getSliceSize (&desc);
554
+ const int sliceSize = getSliceSize (&desc, matTy );
540
555
Value *lenght = ConstantInt::get (CI->getType (), sliceSize, " matrix.slice.size" );
541
556
542
557
CI->replaceAllUsesWith (lenght);
@@ -546,8 +561,8 @@ Value *JointMatrixFuncsResolutionPass::ResolveWILength(CallInst *CI) {
546
561
547
562
template <class BuilderT >
548
563
static Value *createSliceExtract
549
- (BuilderT *builder, Value *matrix, Value *index, const JointMatrixTypeDescription *desc) {
550
- const int sliceSize = getSliceSize (desc);
564
+ (BuilderT *builder, Value *matrix, Value *index, const JointMatrixTypeDescription *desc, Type *matTy ) {
565
+ const int sliceSize = getSliceSize (desc, matTy );
551
566
const int vectorSize = getResolvedVectorSize (matrix->getType ());
552
567
/* Unpacking: */
553
568
if (sliceSize > vectorSize) {
@@ -568,12 +583,12 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
568
583
IGCLLVM::FixedVectorType *matTy = dyn_cast<IGCLLVM::FixedVectorType>(rawMatTy);
569
584
570
585
IRBuilder builder (CI);
571
- const int sliceSize = getSliceSize (&desc);
586
+ const int sliceSize = getSliceSize (&desc, rawMatTy );
572
587
const int vectorSize = getResolvedVectorSize (matTy);
573
588
574
589
Value *slice = nullptr ;
575
590
if (sliceSize > vectorSize) {
576
- Value *element = createSliceExtract (&builder, matrix, index, &desc);
591
+ Value *element = createSliceExtract (&builder, matrix, index, &desc, rawMatTy );
577
592
if (!isa<IntegerType>(element->getType ())) {
578
593
unsigned vecElemSize = matTy->getElementType ()->getScalarSizeInBits ();
579
594
element = builder.CreateBitCast (element, Type::getIntNTy (builder.getContext (), vecElemSize));
@@ -603,6 +618,10 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
603
618
component = builder.CreateShl (component, offset);
604
619
component = builder.CreateOr (element, component);
605
620
}
621
+
622
+ IntegerType *vectorElementType = dyn_cast<IntegerType>(getResolvedVectorElementType (rawMatTy));
623
+ component = builder.CreateBitCast (component, vectorElementType);
624
+
606
625
slice = builder.CreateInsertElement (matrix, component, index);
607
626
608
627
InstsToErase.insert (CI);
@@ -617,9 +636,9 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
617
636
Type *matTy = ResolveType (CI->getArgOperand (0 )->getType (), &desc);
618
637
619
638
IRBuilder builder (CI);
620
- Value *element = createSliceExtract (&builder, matrix, index, &desc);
639
+ Value *element = createSliceExtract (&builder, matrix, index, &desc, matTy );
621
640
/* Unpacking: */
622
- const int sliceSize = getSliceSize (&desc);
641
+ const int sliceSize = getSliceSize (&desc, matTy );
623
642
const int vectorSize = getResolvedVectorSize (matTy);
624
643
if (sliceSize > vectorSize) {
625
644
index = builder.CreateTruncOrBitCast (index, element->getType ());
@@ -634,6 +653,10 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
634
653
element = builder.CreateBitCast (element, CI->getType ());
635
654
}
636
655
656
+ // We need the bitcast, especially for half, as the function call that is
657
+ // being replaces has a half return type and the vectorElementType is i16
658
+ element = builder.CreateBitCast (element, CI->getType ());
659
+
637
660
CI->replaceAllUsesWith (element);
638
661
InstsToErase.insert (CI);
639
662
return element;
0 commit comments