@@ -192,6 +192,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
192
192
return VecStart;
193
193
}
194
194
195
+ namespace {
196
+ struct ShapeInfo {
197
+ unsigned NumRows;
198
+ unsigned NumColumns;
199
+
200
+ bool IsColumnMajor;
201
+
202
+ ShapeInfo (unsigned NumRows = 0 , unsigned NumColumns = 0 )
203
+ : NumRows(NumRows), NumColumns(NumColumns),
204
+ IsColumnMajor (MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
205
+
206
+ ShapeInfo (Value *NumRows, Value *NumColumns)
207
+ : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
208
+ cast<ConstantInt>(NumColumns)->getZExtValue()) {}
209
+
210
+ bool operator ==(const ShapeInfo &other) {
211
+ return NumRows == other.NumRows && NumColumns == other.NumColumns ;
212
+ }
213
+ bool operator !=(const ShapeInfo &other) { return !(*this == other); }
214
+
215
+ // / Returns true if shape-information is defined, meaning both dimensions
216
+ // / are != 0.
217
+ operator bool () const {
218
+ assert (NumRows == 0 || NumColumns != 0 );
219
+ return NumRows != 0 ;
220
+ }
221
+
222
+ unsigned getStride () const {
223
+ if (IsColumnMajor)
224
+ return NumRows;
225
+ return NumColumns;
226
+ }
227
+
228
+ unsigned getNumVectors () const {
229
+ if (IsColumnMajor)
230
+ return NumColumns;
231
+ return NumRows;
232
+ }
233
+
234
+ // / Returns the transposed shape.
235
+ ShapeInfo t () const { return ShapeInfo (NumColumns, NumRows); }
236
+ };
237
+ } // namespace
238
+
239
+ static bool isUniformShape (Value *V) {
240
+ Instruction *I = dyn_cast<Instruction>(V);
241
+ if (!I)
242
+ return true ;
243
+
244
+ switch (I->getOpcode ()) {
245
+ case Instruction::FAdd:
246
+ case Instruction::FSub:
247
+ case Instruction::FMul: // Scalar multiply.
248
+ case Instruction::FNeg:
249
+ case Instruction::Add:
250
+ case Instruction::Mul:
251
+ case Instruction::Sub:
252
+ return true ;
253
+ default :
254
+ return false ;
255
+ }
256
+ }
257
+
258
+ // / Return the ShapeInfo for the result of \p I, it it can be determined.
259
+ static std::optional<ShapeInfo>
260
+ computeShapeInfoForInst (Instruction *I,
261
+ const ValueMap<Value *, ShapeInfo> &ShapeMap) {
262
+ Value *M;
263
+ Value *N;
264
+ Value *K;
265
+ if (match (I, m_Intrinsic<Intrinsic::matrix_multiply>(
266
+ m_Value (), m_Value (), m_Value (M), m_Value (N), m_Value (K))))
267
+ return ShapeInfo (M, K);
268
+ if (match (I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value (), m_Value (M),
269
+ m_Value (N)))) {
270
+ // Flip dimensions.
271
+ return ShapeInfo (N, M);
272
+ }
273
+ if (match (I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
274
+ m_Value (), m_Value (), m_Value (), m_Value (), m_Value (M),
275
+ m_Value (N))))
276
+ return ShapeInfo (N, M);
277
+ if (match (I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
278
+ m_Value (), m_Value (), m_Value (), m_Value (M), m_Value (N))))
279
+ return ShapeInfo (M, N);
280
+ Value *MatrixA;
281
+ if (match (I, m_Store (m_Value (MatrixA), m_Value ()))) {
282
+ auto OpShape = ShapeMap.find (MatrixA);
283
+ if (OpShape != ShapeMap.end ())
284
+ return OpShape->second ;
285
+ }
286
+
287
+ if (isUniformShape (I)) {
288
+ // Find the first operand that has a known shape and use that.
289
+ for (auto &Op : I->operands ()) {
290
+ auto OpShape = ShapeMap.find (Op.get ());
291
+ if (OpShape != ShapeMap.end ())
292
+ return OpShape->second ;
293
+ }
294
+ }
295
+ return std::nullopt;
296
+ }
297
+
195
298
// / LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
196
299
// /
197
300
// / Currently, the lowering for each matrix intrinsic is done as follows:
@@ -383,48 +486,6 @@ class LowerMatrixIntrinsics {
383
486
}
384
487
};
385
488
386
- struct ShapeInfo {
387
- unsigned NumRows;
388
- unsigned NumColumns;
389
-
390
- bool IsColumnMajor;
391
-
392
- ShapeInfo (unsigned NumRows = 0 , unsigned NumColumns = 0 )
393
- : NumRows(NumRows), NumColumns(NumColumns),
394
- IsColumnMajor (MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
395
-
396
- ShapeInfo (Value *NumRows, Value *NumColumns)
397
- : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
398
- cast<ConstantInt>(NumColumns)->getZExtValue()) {}
399
-
400
- bool operator ==(const ShapeInfo &other) {
401
- return NumRows == other.NumRows && NumColumns == other.NumColumns ;
402
- }
403
- bool operator !=(const ShapeInfo &other) { return !(*this == other); }
404
-
405
- // / Returns true if shape-information is defined, meaning both dimensions
406
- // / are != 0.
407
- operator bool () const {
408
- assert (NumRows == 0 || NumColumns != 0 );
409
- return NumRows != 0 ;
410
- }
411
-
412
- unsigned getStride () const {
413
- if (IsColumnMajor)
414
- return NumRows;
415
- return NumColumns;
416
- }
417
-
418
- unsigned getNumVectors () const {
419
- if (IsColumnMajor)
420
- return NumColumns;
421
- return NumRows;
422
- }
423
-
424
- // / Returns the transposed shape.
425
- ShapeInfo t () const { return ShapeInfo (NumColumns, NumRows); }
426
- };
427
-
428
489
// / Maps instructions to their shape information. The shape information
429
490
// / describes the shape to be used while lowering. This matches the shape of
430
491
// / the result value of the instruction, with the only exceptions being store
@@ -554,25 +615,6 @@ class LowerMatrixIntrinsics {
554
615
return true ;
555
616
}
556
617
557
- bool isUniformShape (Value *V) {
558
- Instruction *I = dyn_cast<Instruction>(V);
559
- if (!I)
560
- return true ;
561
-
562
- switch (I->getOpcode ()) {
563
- case Instruction::FAdd:
564
- case Instruction::FSub:
565
- case Instruction::FMul: // Scalar multiply.
566
- case Instruction::FNeg:
567
- case Instruction::Add:
568
- case Instruction::Mul:
569
- case Instruction::Sub:
570
- return true ;
571
- default :
572
- return false ;
573
- }
574
- }
575
-
576
618
// / Returns true if shape information can be used for \p V. The supported
577
619
// / instructions must match the instructions that can be lowered by this pass.
578
620
bool supportsShapeInfo (Value *V) {
@@ -610,43 +652,8 @@ class LowerMatrixIntrinsics {
610
652
611
653
// New entry, set the value and insert operands
612
654
bool Propagate = false ;
613
-
614
- Value *MatrixA;
615
- Value *MatrixB;
616
- Value *M;
617
- Value *N;
618
- Value *K;
619
- if (match (Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
620
- m_Value (MatrixA), m_Value (MatrixB), m_Value (M),
621
- m_Value (N), m_Value (K)))) {
622
- Propagate = setShapeInfo (Inst, {M, K});
623
- } else if (match (Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
624
- m_Value (MatrixA), m_Value (M), m_Value (N)))) {
625
- // Flip dimensions.
626
- Propagate = setShapeInfo (Inst, {N, M});
627
- } else if (match (Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
628
- m_Value (MatrixA), m_Value (), m_Value (),
629
- m_Value (), m_Value (M), m_Value (N)))) {
630
- Propagate = setShapeInfo (Inst, {N, M});
631
- } else if (match (Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
632
- m_Value (), m_Value (), m_Value (), m_Value (M),
633
- m_Value (N)))) {
634
- Propagate = setShapeInfo (Inst, {M, N});
635
- } else if (match (Inst, m_Store (m_Value (MatrixA), m_Value ()))) {
636
- auto OpShape = ShapeMap.find (MatrixA);
637
- if (OpShape != ShapeMap.end ())
638
- setShapeInfo (Inst, OpShape->second );
639
- continue ;
640
- } else if (isUniformShape (Inst)) {
641
- // Find the first operand that has a known shape and use that.
642
- for (auto &Op : Inst->operands ()) {
643
- auto OpShape = ShapeMap.find (Op.get ());
644
- if (OpShape != ShapeMap.end ()) {
645
- Propagate |= setShapeInfo (Inst, OpShape->second );
646
- break ;
647
- }
648
- }
649
- }
655
+ if (auto SI = computeShapeInfoForInst (Inst, ShapeMap))
656
+ Propagate = setShapeInfo (Inst, *SI);
650
657
651
658
if (Propagate) {
652
659
NewWorkList.push_back (Inst);
0 commit comments