@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -103,6 +104,172 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedLoad final
+    : OpConversionPattern<vector::MaskedLoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to load when emulating narrow types,
+    // and then cast back to the original type with vector.bitcast op.
+    // For example, to emulate i4 to i8, the following op:
+    //
+    //   %mask = vector.constant_mask [3] : vector<6xi1>
+    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
+    //     memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+    //
+    // can be replaced with
+    //
+    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
+    //   %new_pass_thru = vector.bitcast %pass_thru :
+    //     vector<6xi4> to vector<3xi8>
+    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
+    //     memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
+    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
+    //
+    // Since we are effectively loading 16 bits (2xi8) from the memref with the
+    // new mask, while originally we only wanted to effectively load 12 bits
+    // (3xi4) from the memref, we need to set the second half of the last i8
+    // that was effectively loaded (i.e. the second i8) to %pass_thru.
+    //
+    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
+    //
+    // Given these input values:
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have:
+    //
+    //   expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+    //
+    //   %new_mask = [1, 1, 0]
+    //   %new_pass_thru = [0x78, 0x9A, 0xBC]
+    //   %1 = [0x12, 0x34, 0xBC]
+    //   %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
+    //   %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+    //
+    // TODO: Currently, only loads of an even number of elements are supported.
+    // To deal with an odd number of elements, one has to extract the
+    // subvector at the proper offset after bit-casting.
+
+    auto origType = op.getVectorType();
+    auto origElements = origType.getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newType = VectorType::get(numElements, newElementType);
+
+    auto maskOp = op.getMask().getDefiningOp();
+    SmallVector<vector::ExtractOp, 2> extractOps;
+    // Finding the mask creation operation.
+    while (maskOp &&
+           !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+      if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+        maskOp = extractOp.getVector().getDefiningOp();
+        extractOps.push_back(extractOp);
+      }
+    }
+    auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+    auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+    if (!createMaskOp && !constantMaskOp)
+      return failure();
+
+    // Computing the "compressed" mask. All the emulation logic (i.e. computing
+    // new mask index) only happens on the last dimension of the vectors.
+    Operation *newMask = nullptr;
+    auto shape = llvm::to_vector(
+        maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
+    shape.push_back(numElements);
+    auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      auto numMaskOperands = maskOperands.size();
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      s0 = s0 + scale - 1;
+      s0 = s0.floorDiv(scale);
+      OpFoldResult origIndex =
+          getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+      OpFoldResult maskIndex =
+          affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+      auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
+      newMaskOperands.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+      newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+                                                      newMaskOperands);
+    } else if (constantMaskOp) {
+      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+      auto numMaskOperands = maskDimSizes.size();
+      auto origIndex =
+          cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+      auto maskIndex =
+          rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+      auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
+      newMaskDimSizes.push_back(maskIndex);
+      newMask = rewriter.create<vector::ConstantMaskOp>(
+          loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+    }
+
+    while (!extractOps.empty()) {
+      newMask = rewriter.create<vector::ExtractOp>(
+          loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+      extractOps.pop_back();
+    }
+
+    auto newPassThru =
+        rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
+
+    // Generating the new masked load.
+    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+        loc, newType, adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+        newMask->getResult(0), newPassThru);
+
+    // Setting the part that originally was not effectively loaded from memory
+    // to pass through.
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+    auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
+                                                   op.getPassThru());
+    rewriter.replaceOp(op, select->getResult(0));
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
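
The new mask size computed in both branches above is a ceiling division of the original mask bound by `scale`: the `vector.create_mask` path builds it as the affine expression `(s0 + scale - 1) floordiv scale`, while the `vector.constant_mask` path folds the same arithmetic directly on the attribute. A minimal standalone sketch of that arithmetic (illustrative only, not part of the patch; the helper name is made up):

#include <cassert>

// Number of i1 lanes in the compressed mask when `scale` narrow elements
// are packed into one wide element: a ceiling division of the original
// mask bound.
static int compressedMaskBound(int origMaskBound, int scale) {
  return (origMaskBound + scale - 1) / scale;
}

int main() {
  // i4 -> i8 emulation from the comment above: scale = 8 / 4 = 2.
  // vector.constant_mask [3] : vector<6xi1> becomes [2] : vector<3xi1>,
  // i.e. two i8s (four i4s) are actually read from memory.
  assert(compressedMaskBound(3, 2) == 2);
  // A full mask stays full: [6] -> [3].
  assert(compressedMaskBound(6, 2) == 3);
  return 0;
}

The trailing i4s that the wider load pulls in are then switched back to %pass_thru by the final arith.select, which is why the pattern keeps the original, uncompressed mask around.
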
@@ -588,8 +755,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
-      typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
+               ConvertVectorTransferRead>(typeConverter, patterns.getContext());
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
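
For context, a hedged sketch of how a conversion pass might drive these patterns, loosely modeled on MLIR's narrow-type emulation test setup; the header paths and the converter setup are assumptions, and a real pipeline would also run the arith/memref/func narrow-type conversions so that the base memref seen by the adaptor already carries the wide (i8) element type:

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Rewrites narrow (e.g. i4) vector loads, masked loads and transfer reads
// under `root` into their i8 equivalents.
static LogicalResult emulateNarrowVectorOps(Operation *root) {
  MLIRContext *ctx = root->getContext();

  // Map sub-byte integer types to an 8-bit storage type.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);

  ConversionTarget target(*ctx);
  // Ops whose operand/result types are already expressed in the wide type
  // are legal; everything else must be rewritten.
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  // Registers ConvertVectorLoad, ConvertVectorMaskedLoad and
  // ConvertVectorTransferRead (see the hunk above).
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);

  return applyPartialConversion(root, target, std::move(patterns));
}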