Skip to content

Commit b43d0bc

Browse files
committed
remove a few more PVC logics
1 parent c007496 commit b43d0bc

File tree

1 file changed

+32
-97
lines changed

1 file changed

+32
-97
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 32 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,6 @@
2525
namespace mlir {
2626
namespace xegpu {
2727

28-
const int MAX_2D_BLOCK_WIDTH_IN_ELEMENTS = 64;
29-
const int MIN_2D_BLOCK_WIDTH_IN_ELEMENTS = 1;
30-
const int MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS = 32;
31-
const int MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS = 1;
32-
// TODO: Generalize shapes for different architecture.
33-
const int MAX_TM_SIZE = 8;
34-
const int TN_SIZE = 16;
35-
const int TK_SIZE_FOR_D16 = 16;
36-
const int TK_SIZE_FOR_D8 = 32;
37-
3828
extern bool printDefaultValues();
3929

4030
static size_t getRankOf(Value value) {
@@ -70,15 +60,14 @@ static std::string makeString(T array, bool breakline = false) {
7060
return buf;
7161
}
7262

73-
7463
template <typename CustomEnum, typename CustomEnumAttr>
7564
static ParseResult parseCustomEnumAttr(OpAsmParser &parser,
7665
OperationState &result,
7766
llvm::StringRef attrKeyword) {
7867
auto loc = parser.getCurrentLocation();
7968
auto attrOptional = FieldParser<CustomEnum, CustomEnum>::parse(parser);
8069
if (failed(attrOptional))
81-
return parser.emitError(loc, "invalid ") << "attribute specification";
70+
return parser.emitError(loc, "invalid attribute specification");
8271
auto attr =
8372
CustomEnumAttr::get(parser.getBuilder().getContext(), *attrOptional);
8473
result.addAttribute(attrKeyword, attr);
@@ -94,13 +83,12 @@ static ParseResult parseBoolAndIntegerAttr(OpAsmParser &parser,
9483

9584
if (std::is_same<AttrType, BoolAttr>::value) {
9685
ty = parser.getBuilder().getIntegerType(1);
97-
9886
} else if (std::is_same<AttrType, IntegerAttr>::value) {
9987
ty = parser.getBuilder().getIntegerType(32);
10088
} else if (std::is_same<AttrType, DenseI64ArrayAttr>::value) {
10189
ty = Type{};
10290
} else {
103-
assert(0 && "Unreachable.\n");
91+
llvm_unreachable("Unsupported Attribute Type.");
10492
}
10593

10694
if (parser.parseCustomAttributeWithFallback(attr, ty))
@@ -129,8 +117,7 @@ parseOptionalAttrDict(OpAsmParser &parser, OperationState &result,
129117
auto loc = parser.getCurrentLocation();
130118
llvm::StringRef nameId;
131119
if (parser.parseOptionalKeyword(&nameId, allowedKeywords))
132-
return parser.emitError(loc, "invalid")
133-
<< "attribute keyword: " << nameId << ".\n";
120+
return parser.emitError(loc, "invalid attribute keyword: ") << nameId << ".\n";
134121

135122
if (parser.parseEqual())
136123
return failure();
@@ -155,10 +142,9 @@ parseOptionalAttrDict(OpAsmParser &parser, OperationState &result,
155142
return parseBoolAndIntegerAttr<BoolAttr>(parser, result, nameId);
156143

157144
if (nameId == "transpose")
158-
return parseBoolAndIntegerAttr<DenseI64ArrayAttr>(parser, result,
159-
nameId);
145+
return parseBoolAndIntegerAttr<DenseI64ArrayAttr>(parser, result, nameId);
160146

161-
assert(0 && "Unreachable!");
147+
llvm_unreachable("Unsupported attribute keyword.");
162148
};
163149

164150
if (parser.parseCommaSeparatedList(parseElt))
@@ -549,8 +535,7 @@ llvm::SmallVector<OpFoldResult> CreateNdDescOp::getStrides() {
549535
}
550536
return strides;
551537
}
552-
emitOpError("The strides information is missing.");
553-
llvm_unreachable("Unexpected error in CreateNdDescOp.\n");
538+
llvm_unreachable("Unexpected error in CreateNdDescOp. The strides information is missing.\n");
554539
}
555540

556541
/// Return the element type of the TensorDesc
@@ -808,9 +793,9 @@ LogicalResult LoadNDOp::verify() {
808793
auto tdescTy = getTensorDescType();
809794
auto valueTy = getValueType();
810795

811-
if (tdescTy.getRank() > 2)
796+
if (tdescTy.getRank() != 2)
812797
return emitOpError(
813-
"The TensorDesc for LoadNDOp should be a 2D/1D TensorDesc.");
798+
"The TensorDesc for LoadNDOp should be a 2D TensorDesc.");
814799

815800
if (!valueTy)
816801
return emitOpError("Invalid result, it should be a VectorType.\n");
@@ -822,31 +807,6 @@ LogicalResult LoadNDOp::verify() {
822807
return emitOpError(
823808
"Value should have the same element type as TensorDesc.");
824809

825-
if (tdescTy.getRank() == 2) {
826-
// TODO: The following logic are architecture
827-
// dependent, pending to be moved out
828-
auto width = tdescTy.getShape()[1];
829-
auto height = tdescTy.getShape()[0];
830-
auto elemTyByteWidth = tdescElemTy.getIntOrFloatBitWidth() / 8;
831-
832-
if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS ||
833-
width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS ||
834-
(width * elemTyByteWidth) % 4 != 0) {
835-
return emitOpError(
836-
"Invalid width size for 2D block load. "
837-
"The specification expects the value to "
838-
"be in range [1, 64], and The the total "
839-
"data size (width * elemTyBytes) to be multiple of 4.\n");
840-
}
841-
842-
if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS ||
843-
height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) {
844-
return emitOpError("Invalid height size for 2D block load. The "
845-
"specification expects the "
846-
"value to be in range [1, 32].\n");
847-
}
848-
}
849-
850810
auto mode = getMode();
851811
auto tdescShape = tdescTy.getShape().vec();
852812
auto valueShape = valueTy.getShape().vec();
@@ -993,10 +953,10 @@ void StoreNDOp::print(OpAsmPrinter &printer) {
993953
}
994954

995955
LogicalResult StoreNDOp::verify() {
996-
auto dstTy = getTensorDesc().getType(); // Tile
956+
auto dstTy = getTensorDesc().getType(); // Tile
997957
auto valTy = llvm::dyn_cast<VectorType>(getValue().getType()); // Vector
998958

999-
if (dstTy.getRank() > 2)
959+
if (dstTy.getRank() != 2)
1000960
return emitOpError(
1001961
"The TensorDesc for StoreNdOp should be a 2D TensorDesc.");
1002962

@@ -1011,30 +971,6 @@ LogicalResult StoreNDOp::verify() {
1011971
"the elem type of memory (dst) shape.\n");
1012972
}
1013973

1014-
if (dstTy.getRank() == 2) { // TODO: The following logic are architecture
1015-
// dependent, pending to be moved
1016-
// out
1017-
auto width = dstTy.getShape()[1];
1018-
auto height = dstTy.getShape()[0];
1019-
auto elemTyByteWidth = dstElemTy.getIntOrFloatBitWidth() / 8;
1020-
if (width < MIN_2D_BLOCK_WIDTH_IN_ELEMENTS ||
1021-
width > MAX_2D_BLOCK_WIDTH_IN_ELEMENTS ||
1022-
(width * elemTyByteWidth) % 4 != 0) {
1023-
return emitOpError(
1024-
"Invalid width size for 2D block write. "
1025-
"The specification expects the value to "
1026-
"be in range [1, 64], and The the total "
1027-
"data size (width * elemTyBytes) to be multiple of 4.\n");
1028-
}
1029-
1030-
if (height < MIN_2D_BLOCK_HEIGHT_IN_ELEMENTS ||
1031-
height > MAX_2D_BLOCK_HEIGHT_IN_ELEMENTS) {
1032-
return emitOpError(
1033-
"Invalid height size for 2D block write. The specification"
1034-
"expects the value to be in range [1, 32].\n");
1035-
}
1036-
}
1037-
1038974
auto mode = getMode();
1039975

1040976
if (mode == Mode::VC) { // for VC mode, no attr attached
@@ -1285,7 +1221,7 @@ LogicalResult LoadGatherOp::verify() {
12851221
return llvm::dyn_cast<VectorType>(type).getElementType();
12861222
else if (llvm::isa<TensorDescType>(type))
12871223
return llvm::dyn_cast<TensorDescType>(type).getElementType();
1288-
assert(0 && "Unreachable !!!");
1224+
llvm_unreachable("Unsupported type.");
12891225
return type;
12901226
};
12911227

@@ -1295,19 +1231,20 @@ LogicalResult LoadGatherOp::verify() {
12951231
return emitOpError(
12961232
"Value should have the same element type as TensorDesc.");
12971233

1298-
auto getShape = [&](Type type, std::vector<int64_t> &shape) -> void {
1234+
auto getShape = [&](Type type) -> std::vector<int64_t> {
1235+
std::vector<int64_t> shape;
12991236
if (type.isIntOrIndexOrFloat())
13001237
shape.push_back(1);
13011238
else if (llvm::isa<VectorType>(type))
13021239
shape = llvm::dyn_cast<VectorType>(type).getShape().vec();
13031240
else
1304-
assert(0 && "Unreachable !!!");
1241+
llvm_unreachable("Unsupported type.");
1242+
return shape;
13051243
};
13061244

1307-
std::vector<int64_t> maskShape, valueShape;
1308-
getShape(maskTy, maskShape);
1309-
getShape(valueTy, valueShape);
1310-
auto tdescShape = tdescTy.getShape().vec();
1245+
std::vector<int64_t> maskShape = getShape(maskTy);
1246+
std::vector<int64_t> valueShape = getShape(valueTy);
1247+
std::vector<int64_t> tdescShape = tdescTy.getShape().vec();
13111248

13121249
if (tdescShape != maskShape)
13131250
return emitOpError("Mask should have the same shape as TensorDesc.");
@@ -1508,41 +1445,39 @@ void StoreScatterOp::print(OpAsmPrinter &printer) {
15081445
}
15091446

15101447
LogicalResult StoreScatterOp::verify() {
1511-
auto valueTy = getValue().getType();
15121448
auto tdescTy = getTensorDesc().getType();
1449+
auto valueTy = getValue().getType();
15131450
auto maskTy = getMask().getType();
1451+
auto mode = getMode();
1452+
auto mapping = tdescTy.getMapping();
1453+
1454+
if (mode != Mode::VC || mapping)
1455+
return emitOpError("StoreScatterOp only supports VC mode and mapping "
1456+
"attribute of TensorDesc is not expected.\n");
15141457

15151458
if (!tdescTy.getScattered())
15161459
return emitOpError("Invalid TensorDesc. StoreScatterOp only works on "
15171460
"TensorDescs with ScatteredAttr.");
15181461

1519-
std::vector<int64_t> valueShape, maskShape;
1520-
auto getShape = [&](Type type, std::vector<int64_t> &shape) -> void {
1462+
auto getShape = [&](Type type) -> std::vector<int64_t> {
1463+
std::vector<int64_t> shape;
15211464
if (type.isIntOrIndexOrFloat())
15221465
shape.push_back(1);
15231466
else if (llvm::isa<VectorType>(type))
15241467
shape = llvm::dyn_cast<VectorType>(type).getShape().vec();
15251468
else
1526-
assert(0 && "Unreachable !!!");
1469+
llvm_unreachable("Unsupported type.");
1470+
return shape;
15271471
};
15281472

1529-
getShape(valueTy, valueShape);
1530-
getShape(maskTy, maskShape);
1473+
std::vector<int64_t> maskShape = getShape(maskTy);
1474+
std::vector<int64_t> valueShape = getShape(valueTy);
1475+
std::vector<int64_t> tdescShape = tdescTy.getShape().vec();
15311476

15321477
if (valueShape != maskShape) {
15331478
return emitOpError("Mask and value should have the same shape/size");
15341479
}
15351480

1536-
auto tdescShape = tdescTy.getShape().vec();
1537-
1538-
auto mode = getMode();
1539-
auto mapping = tdescTy.getMapping();
1540-
1541-
if (mode != Mode::VC || mapping) {
1542-
return emitOpError("StoreScatterOp only supports VC mode and mapping "
1543-
"attribute of TensorDesc is not expected.\n");
1544-
}
1545-
15461481
if (tdescShape != valueShape) {
15471482
return emitOpError("TensorDesc shape and value shape doesn't match. ")
15481483
<< "The expected/derived value shape is: " << makeString(tdescShape)

0 commit comments

Comments
 (0)