@@ -819,9 +819,15 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
819
819
return isa<vector::TransferReadOp>(op) && op->hasOneUse ();
820
820
});
821
821
if (!operand)
822
- return failure ();
822
+ return rewriter.notifyMatchFailure (
823
+ warpOp, " warp result is not a vector.transfer_read op" );
823
824
auto read = operand->get ().getDefiningOp <vector::TransferReadOp>();
824
825
826
+ // Source must be defined outside of the region.
827
+ if (!warpOp.isDefinedOutsideOfRegion (read.getSource ()))
828
+ return rewriter.notifyMatchFailure (
829
+ read, " source must be defined outside of the region" );
830
+
825
831
unsigned operandIndex = operand->getOperandNumber ();
826
832
Value distributedVal = warpOp.getResult (operandIndex);
827
833
@@ -832,10 +838,25 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
832
838
AffineMap map = calculateImplicitMap (sequentialType, distributedType);
833
839
AffineMap indexMap = map.compose (read.getPermutationMap ());
834
840
835
- // Distribute the mask if present.
841
+ // Try to delinearize the lane ID to match the rank expected for
842
+ // distribution.
843
+ SmallVector<Value> delinearizedIds;
844
+ if (!delinearizeLaneId (rewriter, read.getLoc (), sequentialType.getShape (),
845
+ distributedType.getShape (), warpOp.getWarpSize (),
846
+ warpOp.getLaneid (), delinearizedIds)) {
847
+ return rewriter.notifyMatchFailure (
848
+ read, " cannot delinearize lane ID for distribution" );
849
+ }
850
+ assert (!delinearizedIds.empty () || map.getNumResults () == 0 );
851
+
852
+ // Distribute indices and the mask (if present).
836
853
OpBuilder::InsertionGuard g (rewriter);
837
- WarpExecuteOnLane0Op newWarpOp = warpOp;
838
- Value newMask = read.getMask ();
854
+ SmallVector<Value> additionalResults (indices.begin (), indices.end ());
855
+ SmallVector<Type> additionalResultTypes (indices.size (),
856
+ rewriter.getIndexType ());
857
+ additionalResults.push_back (read.getPadding ());
858
+ additionalResultTypes.push_back (read.getPadding ().getType ());
859
+
839
860
bool hasMask = false ;
840
861
if (read.getMask ()) {
841
862
hasMask = true ;
@@ -846,42 +867,26 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
846
867
// by shape information on the warp op, and thus requires materializing
847
868
// the permutation in IR.
848
869
if (!mlir::compressUnusedDims (read.getPermutationMap ()).isIdentity ())
849
- return failure ();
870
+ return rewriter.notifyMatchFailure (
871
+ read, " non-trivial permutation maps not supported" );
850
872
VectorType maskType =
851
873
getDistributedType (read.getMaskType (), map, warpOp.getWarpSize ());
852
- SmallVector<size_t > newRetIndices;
853
- newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
854
- rewriter, warpOp, ValueRange{read.getMask ()}, TypeRange{maskType},
855
- newRetIndices);
856
- newMask = newWarpOp.getResult (newRetIndices[0 ]);
857
- distributedVal = newWarpOp.getResult (operandIndex);
858
- } else {
859
- // This pattern does not actually change the warp op directly. Instead it
860
- // just rewrites a new transfer read (when not masked) outside of the warp
861
- // op and replaces the correponding result. There are then follow up
862
- // patterns to erase now dead results of the warp op. This erasure allows
863
- // propagation to continue, but this pattern on its own never actually
864
- // tells the pattern rewriter that the warp op "changed." Notify the
865
- // rewriter here that the warp op is changing. Similar situations are
866
- // noted in following patterns.
867
- rewriter.startRootUpdate (warpOp);
874
+ additionalResults.push_back (read.getMask ());
875
+ additionalResultTypes.push_back (maskType);
868
876
}
869
877
870
- rewriter.setInsertionPointAfter (newWarpOp);
878
+ SmallVector<size_t > newRetIndices;
879
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
880
+ rewriter, warpOp, additionalResults, additionalResultTypes,
881
+ newRetIndices);
882
+ distributedVal = newWarpOp.getResult (operandIndex);
871
883
872
- // Try to delinearize the lane ID to match the rank expected for
873
- // distribution.
874
- SmallVector<Value> delinearizedIds;
875
- if (!delinearizeLaneId (rewriter, read.getLoc (), sequentialType.getShape (),
876
- distributedType.getShape (), newWarpOp.getWarpSize (),
877
- newWarpOp.getLaneid (), delinearizedIds)) {
878
- if (!hasMask)
879
- rewriter.cancelRootUpdate (warpOp);
880
- return rewriter.notifyMatchFailure (
881
- read, " cannot delinearize lane ID for distribution" );
882
- }
883
- assert (!delinearizedIds.empty () || map.getNumResults () == 0 );
884
+ // Distributed indices were appended first.
885
+ SmallVector<Value> newIndices;
886
+ for (int64_t i = 0 , e = indices.size (); i < e; ++i)
887
+ newIndices.push_back (newWarpOp.getResult (newRetIndices[i]));
884
888
889
+ rewriter.setInsertionPointAfter (newWarpOp);
885
890
for (auto it : llvm::zip_equal (indexMap.getResults (), map.getResults ())) {
886
891
AffineExpr d0, d1;
887
892
bindDims (read.getContext (), d0, d1);
@@ -891,42 +896,23 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
891
896
unsigned indexPos = indexExpr.getPosition ();
892
897
unsigned vectorPos = cast<AffineDimExpr>(std::get<1 >(it)).getPosition ();
893
898
int64_t scale = distributedType.getDimSize (vectorPos);
894
- indices [indexPos] = affine::makeComposedAffineApply (
899
+ newIndices [indexPos] = affine::makeComposedAffineApply (
895
900
rewriter, read.getLoc (), d0 + scale * d1,
896
- {indices [indexPos], delinearizedIds[vectorPos]});
901
+ {newIndices [indexPos], delinearizedIds[vectorPos]});
897
902
}
903
+
904
+ // Distributed padding value was appended right after the indices.
905
+ Value newPadding = newWarpOp.getResult (newRetIndices[indices.size ()]);
906
+ // Distributed mask value was added at the end (if the op has a mask).
907
+ Value newMask =
908
+ hasMask ? newWarpOp.getResult (newRetIndices[newRetIndices.size () - 1 ])
909
+ : Value ();
898
910
auto newRead = rewriter.create <vector::TransferReadOp>(
899
- read.getLoc (), distributedVal.getType (), read.getSource (), indices ,
900
- read.getPermutationMapAttr (), read. getPadding () , newMask,
911
+ read.getLoc (), distributedVal.getType (), read.getSource (), newIndices ,
912
+ read.getPermutationMapAttr (), newPadding , newMask,
901
913
read.getInBoundsAttr ());
902
914
903
- // Check that the produced operation is legal.
904
- // The transfer op may be reading from values that are defined within
905
- // warpOp's body, which is illegal.
906
- // We do the check late because incdices may be changed by
907
- // makeComposeAffineApply. This rewrite may remove dependencies from
908
- // warpOp's body.
909
- // E.g., warpop {
910
- // %idx = affine.apply...[%outsideDef]
911
- // ... = transfer_read ...[%idx]
912
- // }
913
- // will be rewritten in:
914
- // warpop {
915
- // }
916
- // %new_idx = affine.apply...[%outsideDef]
917
- // ... = transfer_read ...[%new_idx]
918
- if (!llvm::all_of (newRead->getOperands (), [&](Value value) {
919
- return (newRead.getMask () && value == newRead.getMask ()) ||
920
- newWarpOp.isDefinedOutsideOfRegion (value);
921
- })) {
922
- if (!hasMask)
923
- rewriter.cancelRootUpdate (warpOp);
924
- return failure ();
925
- }
926
-
927
915
rewriter.replaceAllUsesWith (distributedVal, newRead);
928
- if (!hasMask)
929
- rewriter.finalizeRootUpdate (warpOp);
930
916
return success ();
931
917
}
932
918
};
@@ -1315,6 +1301,12 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1315
1301
unsigned int operandNumber = operand->getOperandNumber ();
1316
1302
auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp>();
1317
1303
VectorType extractSrcType = extractOp.getSourceVectorType ();
1304
+ // TODO: Supported shuffle types should be parameterizable, similar to
1305
+ // `WarpShuffleFromIdxFn`.
1306
+ if (!extractSrcType.getElementType ().isF32 () &&
1307
+ !extractSrcType.getElementType ().isInteger (32 ))
1308
+ return rewriter.notifyMatchFailure (
1309
+ extractOp, " only f32/i32 element types are supported" );
1318
1310
bool is0dOrVec1Extract = extractSrcType.getNumElements () == 1 ;
1319
1311
Type elType = extractSrcType.getElementType ();
1320
1312
VectorType distributedVecType;
0 commit comments