@@ -279,6 +279,121 @@ struct CastAwayTransferWriteLeadingOneDim
279
279
}
280
280
};
281
281
282
+ } // namespace
283
+
284
+ LogicalResult
285
+ mlir::vector::castAwayContractionLeadingOneDim (vector::ContractionOp contractOp,
286
+ RewriterBase &rewriter) {
287
+ VectorType oldAccType = contractOp.getAccType ().dyn_cast <VectorType>();
288
+ if (oldAccType == nullptr )
289
+ return failure ();
290
+ if (oldAccType.getRank () < 2 )
291
+ return failure ();
292
+ if (oldAccType.getShape ()[0 ] != 1 )
293
+ return failure ();
294
+ // currently we support only dropping one dim but the pattern can be applied
295
+ // greedily to drop more.
296
+ int64_t dropDim = 1 ;
297
+
298
+ auto oldIndexingMaps = contractOp.getIndexingMapsArray ();
299
+ SmallVector<AffineMap> newIndexingMaps;
300
+
301
+ auto oldIteratorTypes = contractOp.getIteratorTypes ();
302
+ SmallVector<Attribute> newIteratorTypes;
303
+
304
+ int64_t dimToDrop = oldIndexingMaps[2 ].getDimPosition (0 );
305
+
306
+ if (!isParallelIterator (oldIteratorTypes[dimToDrop]))
307
+ // only parallel type iterators can be dropped.
308
+ return failure ();
309
+
310
+ for (const auto &it : llvm::enumerate (oldIteratorTypes)) {
311
+ int64_t currDim = it.index ();
312
+ if (currDim == dimToDrop)
313
+ continue ;
314
+ newIteratorTypes.push_back (it.value ());
315
+ }
316
+
317
+ SmallVector<Value> operands = {contractOp.getLhs (), contractOp.getRhs (),
318
+ contractOp.getAcc ()};
319
+ SmallVector<Value> newOperands;
320
+
321
+ for (const auto &it : llvm::enumerate (oldIndexingMaps)) {
322
+ // Check if the dim to be dropped exists as a leading dim in the operand
323
+ // if it does then we use vector.extract to drop it.
324
+ bool validExtract = false ;
325
+ SmallVector<AffineExpr> results;
326
+ auto map = it.value ();
327
+ int64_t orginalZeroDim = it.value ().getDimPosition (0 );
328
+ if (orginalZeroDim != dimToDrop) {
329
+ // There are two reasons to be in this path, 1. We need to
330
+ // tranpose the operand to make the dim to be dropped
331
+ // leading. 2. The dim to be dropped does not exist and in
332
+ // that case we dont want to add a unit tranpose but we must
333
+ // check all the indices to make sure this is the case.
334
+ bool tranposeNeeded = false ;
335
+ SmallVector<int64_t > perm;
336
+ SmallVector<AffineExpr> transposeResults;
337
+
338
+ for (int64_t i = 0 , e = map.getNumResults (); i < e; ++i) {
339
+ int64_t currDim = map.getDimPosition (i);
340
+ if (currDim == dimToDrop) {
341
+ tranposeNeeded = true ;
342
+ perm.insert (perm.begin (), i);
343
+ auto targetExpr = rewriter.getAffineDimExpr (currDim);
344
+ transposeResults.insert (transposeResults.begin (), targetExpr);
345
+ } else {
346
+ perm.push_back (i);
347
+ auto targetExpr = rewriter.getAffineDimExpr (currDim);
348
+ transposeResults.push_back (targetExpr);
349
+ }
350
+ }
351
+ // Do the tranpose now if needed so that we can drop the
352
+ // correct dim using extract later.
353
+ if (tranposeNeeded) {
354
+ map = AffineMap::get (map.getNumDims (), 0 , transposeResults,
355
+ contractOp.getContext ());
356
+ operands[it.index ()] = rewriter.create <vector::TransposeOp>(
357
+ contractOp.getLoc (), operands[it.index ()], perm);
358
+ }
359
+ }
360
+ // We have taken care to have the dim to be dropped be
361
+ // the leading dim. If its still not leading that means it
362
+ // does not exist in this operand and hence we do not need
363
+ // an extract.
364
+ if (map.getDimPosition (0 ) == dimToDrop)
365
+ validExtract = true ;
366
+
367
+ for (int64_t i = 0 , e = map.getNumResults (); i < e; ++i) {
368
+ int64_t currDim = map.getDimPosition (i);
369
+ if (currDim == dimToDrop)
370
+ // This is the dim we are dropping.
371
+ continue ;
372
+ auto targetExpr = rewriter.getAffineDimExpr (
373
+ currDim < dimToDrop ? currDim : currDim - 1 );
374
+ results.push_back (targetExpr);
375
+ }
376
+ newIndexingMaps.push_back (AffineMap::get (map.getNumDims () - 1 , 0 , results,
377
+ contractOp.getContext ()));
378
+ // Extract if its a valid extraction, otherwise use the operand
379
+ // without extraction.
380
+ newOperands.push_back (
381
+ validExtract ? rewriter.create <vector::ExtractOp>(contractOp.getLoc (),
382
+ operands[it.index ()],
383
+ splatZero (dropDim))
384
+ : operands[it.index ()]);
385
+ }
386
+ auto newContractOp = rewriter.create <vector::ContractionOp>(
387
+ contractOp.getLoc (), newOperands[0 ], newOperands[1 ], newOperands[2 ],
388
+ rewriter.getAffineMapArrayAttr (newIndexingMaps),
389
+ rewriter.getArrayAttr (newIteratorTypes), contractOp.getKind ());
390
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
391
+ contractOp, contractOp->getResultTypes ()[0 ], newContractOp);
392
+ return success ();
393
+ }
394
+
395
+ namespace {
396
+
282
397
// / Turns vector.contract on vector with leading 1 dimensions into
283
398
// / vector.extract followed by vector.contract on vector without leading
284
399
// / 1 dimensions. Also performs tranpose of lhs and rhs operands if required
@@ -289,112 +404,7 @@ struct CastAwayContractionLeadingOneDim
289
404
290
405
LogicalResult matchAndRewrite (vector::ContractionOp contractOp,
291
406
PatternRewriter &rewriter) const override {
292
- VectorType oldAccType = contractOp.getAccType ().dyn_cast <VectorType>();
293
- if (oldAccType == nullptr )
294
- return failure ();
295
- if (oldAccType.getRank () < 2 )
296
- return failure ();
297
- if (oldAccType.getShape ()[0 ] != 1 )
298
- return failure ();
299
- // currently we support only dropping one dim but the pattern can be applied
300
- // greedily to drop more.
301
- int64_t dropDim = 1 ;
302
-
303
- auto oldIndexingMaps = contractOp.getIndexingMapsArray ();
304
- SmallVector<AffineMap> newIndexingMaps;
305
-
306
- auto oldIteratorTypes = contractOp.getIteratorTypes ();
307
- SmallVector<Attribute> newIteratorTypes;
308
-
309
- int64_t dimToDrop = oldIndexingMaps[2 ].getDimPosition (0 );
310
-
311
- if (!isParallelIterator (oldIteratorTypes[dimToDrop]))
312
- // only parallel type iterators can be dropped.
313
- return failure ();
314
-
315
- for (const auto &it : llvm::enumerate (oldIteratorTypes)) {
316
- int64_t currDim = it.index ();
317
- if (currDim == dimToDrop)
318
- continue ;
319
- newIteratorTypes.push_back (it.value ());
320
- }
321
-
322
- SmallVector<Value> operands = {contractOp.getLhs (), contractOp.getRhs (),
323
- contractOp.getAcc ()};
324
- SmallVector<Value> newOperands;
325
-
326
- for (const auto &it : llvm::enumerate (oldIndexingMaps)) {
327
- // Check if the dim to be dropped exists as a leading dim in the operand
328
- // if it does then we use vector.extract to drop it.
329
- bool validExtract = false ;
330
- SmallVector<AffineExpr> results;
331
- auto map = it.value ();
332
- int64_t orginalZeroDim = it.value ().getDimPosition (0 );
333
- if (orginalZeroDim != dimToDrop) {
334
- // There are two reasons to be in this path, 1. We need to
335
- // tranpose the operand to make the dim to be dropped
336
- // leading. 2. The dim to be dropped does not exist and in
337
- // that case we dont want to add a unit tranpose but we must
338
- // check all the indices to make sure this is the case.
339
- bool tranposeNeeded = false ;
340
- SmallVector<int64_t > perm;
341
- SmallVector<AffineExpr> transposeResults;
342
-
343
- for (int64_t i = 0 , e = map.getNumResults (); i < e; ++i) {
344
- int64_t currDim = map.getDimPosition (i);
345
- if (currDim == dimToDrop) {
346
- tranposeNeeded = true ;
347
- perm.insert (perm.begin (), i);
348
- auto targetExpr = rewriter.getAffineDimExpr (currDim);
349
- transposeResults.insert (transposeResults.begin (), targetExpr);
350
- } else {
351
- perm.push_back (i);
352
- auto targetExpr = rewriter.getAffineDimExpr (currDim);
353
- transposeResults.push_back (targetExpr);
354
- }
355
- }
356
- // Do the tranpose now if needed so that we can drop the
357
- // correct dim using extract later.
358
- if (tranposeNeeded) {
359
- map = AffineMap::get (map.getNumDims (), 0 , transposeResults,
360
- contractOp.getContext ());
361
- operands[it.index ()] = rewriter.create <vector::TransposeOp>(
362
- contractOp.getLoc (), operands[it.index ()], perm);
363
- }
364
- }
365
- // We have taken care to have the dim to be dropped be
366
- // the leading dim. If its still not leading that means it
367
- // does not exist in this operand and hence we do not need
368
- // an extract.
369
- if (map.getDimPosition (0 ) == dimToDrop)
370
- validExtract = true ;
371
-
372
- for (int64_t i = 0 , e = map.getNumResults (); i < e; ++i) {
373
- int64_t currDim = map.getDimPosition (i);
374
- if (currDim == dimToDrop)
375
- // This is the dim we are dropping.
376
- continue ;
377
- auto targetExpr = rewriter.getAffineDimExpr (
378
- currDim < dimToDrop ? currDim : currDim - 1 );
379
- results.push_back (targetExpr);
380
- }
381
- newIndexingMaps.push_back (AffineMap::get (map.getNumDims () - 1 , 0 , results,
382
- contractOp.getContext ()));
383
- // Extract if its a valid extraction, otherwise use the operand
384
- // without extraction.
385
- newOperands.push_back (validExtract
386
- ? rewriter.create <vector::ExtractOp>(
387
- contractOp.getLoc (), operands[it.index ()],
388
- splatZero (dropDim))
389
- : operands[it.index ()]);
390
- }
391
- auto newContractOp = rewriter.create <vector::ContractionOp>(
392
- contractOp.getLoc (), newOperands[0 ], newOperands[1 ], newOperands[2 ],
393
- rewriter.getAffineMapArrayAttr (newIndexingMaps),
394
- rewriter.getArrayAttr (newIteratorTypes), contractOp.getKind ());
395
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
396
- contractOp, contractOp->getResultTypes ()[0 ], newContractOp);
397
- return success ();
407
+ return castAwayContractionLeadingOneDim (contractOp, rewriter);
398
408
}
399
409
};
400
410
0 commit comments