@@ -75,83 +75,132 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                    int numSrcElemsPerDest,
                                                    int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "numFrontPadElems must be less than numSrcElemsPerDest");
 
-  auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+  auto numDestElems = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                       numSrcElemsPerDest;
 
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
+  // TODO: add support to `vector.splat`.
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
     }
   }
-  auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
-  auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+
+  if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+          maskOp))
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
   // new mask index) only happens on the last dimension of the vectors.
-  Operation *newMask = nullptr;
-  SmallVector<int64_t> shape(
+  SmallVector<int64_t> maskShape(
       cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
-  shape.back() = numElements;
-  auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
-  if (createMaskOp) {
-    OperandRange maskOperands = createMaskOp.getOperands();
-    size_t numMaskOperands = maskOperands.size();
-    AffineExpr s0;
-    bindSymbols(rewriter.getContext(), s0);
-    s0 = s0 + numSrcElemsPerDest - 1;
-    s0 = s0.floorDiv(numSrcElemsPerDest);
-    OpFoldResult origIndex =
-        getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
-    OpFoldResult maskIndex =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
-    SmallVector<Value> newMaskOperands(maskOperands.drop_back());
-    newMaskOperands.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
-    newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
-                                                    newMaskOperands);
-  } else if (constantMaskOp) {
-    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-    size_t numMaskOperands = maskDimSizes.size();
-    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-    int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-    int64_t maskIndex =
-        llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
-    // TODO: we only want the mask between [startIndex, maskIndex] to be true,
-    // the rest are false.
-    if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-      return failure();
-
-    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndex);
-
-    if (numFrontPadElems == 0) {
-      newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                        newMaskDimSizes);
-    } else {
-      SmallVector<bool> newMaskValues;
-      for (int64_t i = 0; i < numElements; ++i)
-        newMaskValues.push_back(i >= startIndex && i < maskIndex);
-      auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
-      newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
-    }
-  }
+  maskShape.back() = numDestElems;
+  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+  std::optional<Operation *> newMask =
+      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+          .Case<vector::CreateMaskOp>(
+              [&](auto createMaskOp) -> std::optional<Operation *> {
+                OperandRange maskOperands = createMaskOp.getOperands();
+                size_t numMaskOperands = maskOperands.size();
+                AffineExpr s0;
+                bindSymbols(rewriter.getContext(), s0);
+                s0 = s0 + numSrcElemsPerDest - 1;
+                s0 = s0.floorDiv(numSrcElemsPerDest);
+                OpFoldResult origIndex =
+                    getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+                    rewriter, loc, s0, origIndex);
+                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+                newMaskOperands.push_back(
+                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                             newMaskOperands);
+              })
+          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+                                            -> std::optional<Operation *> {
+            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+            size_t numMaskOperands = maskDimSizes.size();
+            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
+                                                 numSrcElemsPerDest);
+
+            // TODO: we only want the mask between [startIndex, maskIndex]
+            // to be true, the rest are false.
+            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+              return std::nullopt;
+
+            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+            newMaskDimSizes.push_back(maskIndex);
+
+            if (numFrontPadElems == 0)
+              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                             newMaskDimSizes);
+
+            SmallVector<bool> newMaskValues;
+            for (int64_t i = 0; i < numDestElems; ++i)
+              newMaskValues.push_back(i >= startIndex && i < maskIndex);
+            auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+                                                      denseAttr);
+          })
+          .Case<arith::ConstantOp>([&](auto constantOp)
+                                       -> std::optional<Operation *> {
+            // TODO: Support multiple dimensions.
+            if (maskShape.size() != 1)
+              return std::nullopt;
+            // Rearrange the original mask values to cover the whole potential
+            // loading region. For example, in the case of using byte-size for
+            // emulation, given the following mask:
+            //
+            // %mask = [0, 1, 0, 1, 0, 0]
+            //
+            // With a front offset of 1, the mask will be padded with 0s in the
+            // front and back so that:
+            // 1. It is aligned with the effective loading bits
+            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
+            // total coverage size is a multiple of bytes). The new mask will
+            // be like this before compressing:
+            //
+            // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
+            auto denseAttr =
+                cast<DenseIntElementsAttr>(constantOp.getValue());
+            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
+            paddedMaskValues.append(denseAttr.template value_begin<bool>(),
+                                    denseAttr.template value_end<bool>());
+            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
+
+            // Compressing by combining every `numSrcElemsPerDest` elements:
+            SmallVector<bool> compressedMaskValues;
+            for (size_t i = 0; i < paddedMaskValues.size();
+                 i += numSrcElemsPerDest) {
+              bool combinedValue = false;
+              for (int j = 0; j < numSrcElemsPerDest; ++j) {
+                combinedValue |= paddedMaskValues[i + j];
+              }
+              compressedMaskValues.push_back(combinedValue);
+            }
+            return rewriter.create<arith::ConstantOp>(
+                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+          });
+
+  if (!newMask)
+    return failure();
 
   while (!extractOps.empty()) {
     newMask = rewriter.create<vector::ExtractOp>(
-        loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
     extractOps.pop_back();
   }
 
-  return newMask;
+  return *newMask;
 }
 
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
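
To make the new `arith::ConstantOp` branch easier to check in isolation, here is a minimal standalone sketch of its pad-and-compress step. The helper name `compressMask` and the use of `std::vector<bool>` in place of `SmallVector`/`DenseElementsAttr` are illustrative assumptions, not part of the patch; only the index arithmetic mirrors the lambda above.

#include <cassert>
#include <cstdio>
#include <vector>

// Standalone model of the arith::ConstantOp branch: pad the mask with
// `numFrontPadElems` false values in front, round the length up to a
// multiple of `numSrcElemsPerDest`, then OR each group of
// `numSrcElemsPerDest` source-mask bits into one dest-mask bit.
static std::vector<bool> compressMask(const std::vector<bool> &mask,
                                      int numSrcElemsPerDest,
                                      int numFrontPadElems = 0) {
  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");
  int numSrcElems = static_cast<int>(mask.size());
  // Same ceildiv as `numDestElems` in the patch.
  int numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  std::vector<bool> padded(numFrontPadElems, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  padded.resize(static_cast<size_t>(numDestElems) * numSrcElemsPerDest, false);

  std::vector<bool> compressed;
  for (size_t i = 0; i < padded.size(); i += numSrcElemsPerDest) {
    bool combined = false;
    for (int j = 0; j < numSrcElemsPerDest; ++j)
      combined = combined || padded[i + j];
    compressed.push_back(combined);
  }
  return compressed;
}

int main() {
  // The example from the comment in the patch, with a front offset of 1
  // and two source elements per destination element.
  std::vector<bool> mask = {0, 1, 0, 1, 0, 0};
  std::vector<bool> out =
      compressMask(mask, /*numSrcElemsPerDest=*/2, /*numFrontPadElems=*/1);
  // Padded mask [0, 0, 1, 0, 1, 0, 0, 0] compresses to [0, 1, 1, 0].
  for (bool b : out)
    std::printf("%d ", b ? 1 : 0);
  std::printf("\n");
}
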
@@ -185,12 +234,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
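
The second hunk is a cleanup rather than a functional change: the C++17 `[[maybe_unused]]` attribute replaces the `(void)` casts that kept `-Wunused-variable` quiet in release builds, where `assert` expands to nothing. A small sketch of the two idioms; the functions `beforeStyle` and `afterStyle` are hypothetical, for illustration only.

#include <cassert>

// Old idiom: declare normally, then cast to void so the variable does not
// trigger -Wunused-variable when NDEBUG removes the assert.
void beforeStyle(int value) {
  auto doubled = value * 2;
  assert(doubled == value + value && "sanity check");
  (void)doubled;
}

// New idiom: [[maybe_unused]] states the intent at the declaration itself.
void afterStyle(int value) {
  [[maybe_unused]] auto doubled = value * 2;
  assert(doubled == value + value && "sanity check");
}
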