 namespace mlir {
 namespace xegpu {

-const int MAX_2D_BLOCK_WIDTH_IN_ELEMENTS = 64;
-const int MIN_2D_BLOCK_WIDTH_IN_ELEMENTS = 1;
-const int MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS = 32;
-const int MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS = 1;
-// TODO: Generalize shapes for different architecture.
-const int MAX_TM_SIZE = 8;
-const int TN_SIZE = 16;
-const int TK_SIZE_FOR_D16 = 16;
-const int TK_SIZE_FOR_D8 = 32;
-
 extern bool printDefaultValues();

 static size_t getRankOf(Value value) {
@@ -70,15 +60,14 @@ static std::string makeString(T array, bool breakline = false) {
   return buf;
 }

-
 template <typename CustomEnum, typename CustomEnumAttr>
 static ParseResult parseCustomEnumAttr(OpAsmParser &parser,
                                        OperationState &result,
                                        llvm::StringRef attrKeyword) {
   auto loc = parser.getCurrentLocation();
   auto attrOptional = FieldParser<CustomEnum, CustomEnum>::parse(parser);
   if (failed(attrOptional))
-    return parser.emitError(loc, "invalid ") << "attribute specification";
+    return parser.emitError(loc, "invalid attribute specification");
   auto attr =
       CustomEnumAttr::get(parser.getBuilder().getContext(), *attrOptional);
   result.addAttribute(attrKeyword, attr);
@@ -94,13 +83,12 @@ static ParseResult parseBoolAndIntegerAttr(OpAsmParser &parser,
   if (std::is_same<AttrType, BoolAttr>::value) {
     ty = parser.getBuilder().getIntegerType(1);
-
   } else if (std::is_same<AttrType, IntegerAttr>::value) {
     ty = parser.getBuilder().getIntegerType(32);
   } else if (std::is_same<AttrType, DenseI64ArrayAttr>::value) {
     ty = Type{};
   } else {
-    assert(0 && "Unreachable.\n");
+    llvm_unreachable("Unsupported Attribute Type.");
   }

   if (parser.parseCustomAttributeWithFallback(attr, ty))
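A note on the assert-to-llvm_unreachable changes made throughout this patch: the standalone sketch below is illustrative only (the `Kind` enum and `kindName` helper are made-up names, not part of this dialect) and shows why `llvm_unreachable` is the usual replacement for `assert(0 && "...")` on branches that must never execute.

// Minimal sketch, assuming only LLVM's Support library.
// llvm_unreachable aborts with the message in assert-enabled builds and is
// typically compiled down to a "this path is dead" hint in release builds.
// Because it is declared noreturn, the compiler also stops warning about a
// missing return after the exhaustive switch, which assert(0) does not ensure.
#include "llvm/Support/ErrorHandling.h"

enum class Kind { Bool, Integer, DenseArray };

static const char *kindName(Kind k) {
  switch (k) {
  case Kind::Bool:
    return "bool";
  case Kind::Integer:
    return "integer";
  case Kind::DenseArray:
    return "dense_i64_array";
  }
  llvm_unreachable("Unsupported Kind."); // no fall-through return needed
}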
@@ -129,8 +117,7 @@ parseOptionalAttrDict(OpAsmParser &parser, OperationState &result,
     auto loc = parser.getCurrentLocation();
     llvm::StringRef nameId;
     if (parser.parseOptionalKeyword(&nameId, allowedKeywords))
-      return parser.emitError(loc, "invalid")
-             << " attribute keyword: " << nameId << ".\n";
+      return parser.emitError(loc, "invalid attribute keyword: ") << nameId << ".\n";

     if (parser.parseEqual())
       return failure();
@@ -155,10 +142,9 @@ parseOptionalAttrDict(OpAsmParser &parser, OperationState &result,
       return parseBoolAndIntegerAttr<BoolAttr>(parser, result, nameId);

     if (nameId == "transpose")
-      return parseBoolAndIntegerAttr<DenseI64ArrayAttr>(parser, result,
-                                                        nameId);
+      return parseBoolAndIntegerAttr<DenseI64ArrayAttr>(parser, result, nameId);

-    assert(0 && "Unreachable!");
+    llvm_unreachable("Unsupported attribute keyword.");
   };

   if (parser.parseCommaSeparatedList(parseElt))
@@ -549,8 +535,7 @@ llvm::SmallVector<OpFoldResult> CreateNdDescOp::getStrides() {
     }
     return strides;
   }
-  emitOpError("The strides information is missing.");
-  llvm_unreachable("Unexpected error in CreateNdDescOp.\n");
+  llvm_unreachable("Unexpected error in CreateNdDescOp. The strides information is missing.\n");
 }

 /// Return the element type of the TensorDesc
@@ -808,9 +793,9 @@ LogicalResult LoadNDOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getValueType();

-  if (tdescTy.getRank() > 2)
+  if (tdescTy.getRank() != 2)
     return emitOpError(
-        "The TensorDesc for LoadNDOp should be a 2D/1D TensorDesc.");
+        "The TensorDesc for LoadNDOp should be a 2D TensorDesc.");

   if (!valueTy)
     return emitOpError("Invalid result, it should be a VectorType.\n");
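The rank guard here (and the matching one in StoreNDOp::verify further down) is tightened from `> 2` to `!= 2`, so 1D descriptors are now rejected as well. A minimal sketch of the behavioral difference, using hypothetical helpers that are not part of the patch:

// Illustrative only; acceptsRankOld/acceptsRankNew are made-up names.
static bool acceptsRankOld(int rank) { return !(rank > 2); }  // 1D and 2D pass
static bool acceptsRankNew(int rank) { return !(rank != 2); } // only 2D passes
// acceptsRankOld(1) == true,  acceptsRankNew(1) == false
// acceptsRankOld(2) == true,  acceptsRankNew(2) == true
// acceptsRankOld(3) == false, acceptsRankNew(3) == false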
@@ -822,31 +807,6 @@ LogicalResult LoadNDOp::verify() {
     return emitOpError(
         "Value should have the same element type as TensorDesc.");

-  if (tdescTy.getRank() == 2) {
-    // TODO: The following logic are architecture
-    // dependent, pending to be moved out
-    auto width = tdescTy.getShape()[1];
-    auto height = tdescTy.getShape()[0];
-    auto elemTyByteWidth = tdescElemTy.getIntOrFloatBitWidth() / 8;
-
-    if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS ||
-        width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS ||
-        (width * elemTyByteWidth) % 4 != 0) {
-      return emitOpError(
-          "Invalid width size for 2D block load. "
-          "The specification expects the value to "
-          "be in range [1, 64], and The the total "
-          "data size (width * elemTyBytes) to be multiple of 4.\n");
-    }
-
-    if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS ||
-        height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) {
-      return emitOpError("Invalid height size for 2D block load. The "
-                         "specification expects the "
-                         "value to be in range [1, 32].\n");
-    }
-  }
-
   auto mode = getMode();
   auto tdescShape = tdescTy.getShape().vec();
   auto valueShape = valueTy.getShape().vec();
@@ -993,10 +953,10 @@ void StoreNDOp::print(OpAsmPrinter &printer) {
 }

 LogicalResult StoreNDOp::verify() {
-  auto dstTy = getTensorDesc().getType(); // Tile
+  auto dstTy = getTensorDesc().getType();                        // Tile
   auto valTy = llvm::dyn_cast<VectorType>(getValue().getType()); // Vector

-  if (dstTy.getRank() > 2)
+  if (dstTy.getRank() != 2)
     return emitOpError(
         "The TensorDesc for StoreNdOp should be a 2D TensorDesc.");
@@ -1011,30 +971,6 @@ LogicalResult StoreNDOp::verify() {
                        "the elem type of memory (dst) shape.\n");
   }

-  if (dstTy.getRank() == 2) { // TODO: The following logic are architecture
-                              // dependent, pending to be moved
-                              // out
-    auto width = dstTy.getShape()[1];
-    auto height = dstTy.getShape()[0];
-    auto elemTyByteWidth = dstElemTy.getIntOrFloatBitWidth() / 8;
-    if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS ||
-        width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS ||
-        (width * elemTyByteWidth) % 4 != 0) {
-      return emitOpError(
-          "Invalid width size for 2D block write. "
-          "The specification expects the value to "
-          "be in range [1, 64], and The the total "
-          "data size (width * elemTyBytes) to be multiple of 4.\n");
-    }
-
-    if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS ||
-        height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) {
-      return emitOpError(
-          "Invalid height size for 2D block write. The specification"
-          " expects the value to be in range [1, 32].\n");
-    }
-  }
-
   auto mode = getMode();

   if (mode == Mode::VC) { // for VC mode, no attr attached
@@ -1285,7 +1221,7 @@ LogicalResult LoadGatherOp::verify() {
       return llvm::dyn_cast<VectorType>(type).getElementType();
     else if (llvm::isa<TensorDescType>(type))
       return llvm::dyn_cast<TensorDescType>(type).getElementType();
-    assert(0 && "Unreachable !!!");
+    llvm_unreachable("Unsupported type.");
     return type;
   };
@@ -1295,19 +1231,20 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError(
         "Value should have the same element type as TensorDesc.");

-  auto getShape = [&](Type type, std::vector<int64_t> &shape) -> void {
+  auto getShape = [&](Type type) -> std::vector<int64_t> {
+    std::vector<int64_t> shape;
     if (type.isIntOrIndexOrFloat())
       shape.push_back(1);
     else if (llvm::isa<VectorType>(type))
       shape = llvm::dyn_cast<VectorType>(type).getShape().vec();
     else
-      assert(0 && "Unreachable !!!");
+      llvm_unreachable("Unsupported type.");
+    return shape;
   };

-  std::vector<int64_t> maskShape, valueShape;
-  getShape(maskTy, maskShape);
-  getShape(valueTy, valueShape);
-  auto tdescShape = tdescTy.getShape().vec();
+  std::vector<int64_t> maskShape = getShape(maskTy);
+  std::vector<int64_t> valueShape = getShape(valueTy);
+  std::vector<int64_t> tdescShape = tdescTy.getShape().vec();

   if (tdescShape != maskShape)
     return emitOpError("Mask should have the same shape as TensorDesc.");
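The getShape lambdas in LoadGatherOp::verify and StoreScatterOp::verify (next hunk) are reworked to return the shape by value instead of filling an out-parameter, so each caller can initialize its local in a single expression. A standalone sketch of the same pattern, assuming the usual MLIR type APIs (not a new helper introduced by this patch):

// Illustrative sketch only; mirrors the refactored lambda.
#include <cstdint>
#include <vector>
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/BuiltinTypes.h"

static std::vector<int64_t> getShapeOf(mlir::Type type) {
  std::vector<int64_t> shape;
  if (type.isIntOrIndexOrFloat())
    shape.push_back(1); // treat scalars as a 1-element shape
  else if (auto vecTy = llvm::dyn_cast<mlir::VectorType>(type))
    shape = vecTy.getShape().vec(); // copy the static vector shape
  else
    llvm_unreachable("Unsupported type.");
  return shape; // NRVO/move makes returning by value cheap
}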
@@ -1508,41 +1445,39 @@ void StoreScatterOp::print(OpAsmPrinter &printer) {
 }

 LogicalResult StoreScatterOp::verify() {
-  auto valueTy = getValue().getType();
   auto tdescTy = getTensorDesc().getType();
+  auto valueTy = getValue().getType();
   auto maskTy = getMask().getType();
+  auto mode = getMode();
+  auto mapping = tdescTy.getMapping();
+
+  if (mode != Mode::VC || mapping)
+    return emitOpError("StoreScatterOp only supports VC mode and mapping "
+                       "attribute of TensorDesc is not expected.\n");

   if (!tdescTy.getScattered())
     return emitOpError("Invalid TensorDesc. StoreScatterOp only works on "
                        "TensorDescs with ScatteredAttr.");

-  std::vector<int64_t> valueShape, maskShape;
-  auto getShape = [&](Type type, std::vector<int64_t> &shape) -> void {
+  auto getShape = [&](Type type) -> std::vector<int64_t> {
+    std::vector<int64_t> shape;
     if (type.isIntOrIndexOrFloat())
       shape.push_back(1);
     else if (llvm::isa<VectorType>(type))
       shape = llvm::dyn_cast<VectorType>(type).getShape().vec();
     else
-      assert(0 && "Unreachable !!!");
+      llvm_unreachable("Unsupported type.");
+    return shape;
   };

-  getShape(valueTy, valueShape);
-  getShape(maskTy, maskShape);
+  std::vector<int64_t> maskShape = getShape(maskTy);
+  std::vector<int64_t> valueShape = getShape(valueTy);
+  std::vector<int64_t> tdescShape = tdescTy.getShape().vec();

   if (valueShape != maskShape) {
     return emitOpError("Mask and value should have the same shape/size");
   }

-  auto tdescShape = tdescTy.getShape().vec();
-
-  auto mode = getMode();
-  auto mapping = tdescTy.getMapping();
-
-  if (mode != Mode::VC || mapping) {
-    return emitOpError("StoreScatterOp only supports VC mode and mapping "
-                       "attribute of TensorDesc is not expected.\n");
-  }
-
   if (tdescShape != valueShape) {
     return emitOpError("TensorDesc shape and value shape doesn't match. ")
            << "The expected/derived value shape is: " << makeString(tdescShape)