@@ -600,30 +600,33 @@ static Value createLinalgBodyCalculationForElementwiseOp(
600
600
static Value expandRank (PatternRewriter &rewriter, Location loc, Value tensor,
601
601
int64_t rank) {
602
602
// No need to expand if we are already at the desired rank
603
- auto shapedType = dyn_cast<ShapedType>(tensor.getType ());
604
- assert (shapedType && shapedType.hasRank () && " expected a ranked shaped type" );
605
- int64_t numExtraDims = rank - shapedType.getRank ();
603
+ auto tensorType = dyn_cast<RankedTensorType>(tensor.getType ());
604
+ assert (tensorType && " expected a ranked tensor type" );
605
+ int64_t tensorRank = tensorType.getRank ();
606
+ int64_t numExtraDims = rank - tensorRank;
606
607
assert (numExtraDims >= 0 && " cannot expand tensor to a lower rank" );
607
608
if (!numExtraDims)
608
609
return tensor;
609
610
610
611
// Compute reassociation indices
611
- SmallVector<SmallVector<int64_t , 2 >> reassociationIndices (
612
- shapedType.getRank ());
612
+ SmallVector<ReassociationIndices> reassociationIndices (tensorRank);
613
613
int64_t index = 0 ;
614
- for (index = 0 ; index <= numExtraDims; index++)
615
- reassociationIndices[0 ].push_back (index);
616
- for (size_t position = 1 ; position < reassociationIndices.size (); position++)
617
- reassociationIndices[position].push_back (index++);
614
+ if (tensorRank != 0 ) {
615
+ for (index = 0 ; index <= numExtraDims; index++)
616
+ reassociationIndices[0 ].push_back (index);
617
+ for (size_t position = 1 ; position < reassociationIndices.size ();
618
+ position++)
619
+ reassociationIndices[position].push_back (index++);
620
+ }
618
621
619
622
// Compute result type
620
623
SmallVector<int64_t > resultShape;
621
624
for (index = 0 ; index < numExtraDims; index++)
622
625
resultShape.push_back (1 );
623
- for (auto size : shapedType .getShape ())
626
+ for (auto size : tensorType .getShape ())
624
627
resultShape.push_back (size);
625
628
auto resultType =
626
- RankedTensorType::get (resultShape, shapedType .getElementType ());
629
+ RankedTensorType::get (resultShape, tensorType .getElementType ());
627
630
628
631
// Emit 'tensor.expand_shape' op
629
632
return rewriter.create <tensor::ExpandShapeOp>(loc, resultType, tensor,
0 commit comments