Skip to content

Commit 8c2b709

Browse files
committed
[mlir][xegpu] TensorDesc verifier
Adds an XeGPU tensor descriptor type verifier. The checks focus on ensuring that the provided subgroup map is valid with respect to the underlying data.
1 parent ac87d6b commit 8c2b709

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
179179
}];
180180

181181
let hasCustomAssemblyFormat = true;
182-
182+
let genVerifyDecl = 1;
183183
}
184184

185185

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
175175
if (parser.parseGreater())
176176
return {};
177177

178-
return TensorDescType::get(parser.getContext(), shape, elementType,
179-
encoding.value_or(mlir::Attribute()),
180-
sg_map.value_or(mlir::Attribute()));
178+
return TensorDescType::getChecked(
179+
[&]() { return parser.emitError(parser.getNameLoc()); },
180+
parser.getContext(), shape, elementType,
181+
encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
181182
}
182183

183184
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
223224
return Base::get(context, shape, elementType, attr, sg_map);
224225
}
225226

227+
/// Verifies the parameters of a tensor descriptor type.
///
/// Checks performed, in order:
///   * the descriptor shape has rank at most 2;
///   * when `sg_map` is an SGMapAttr:
///       - for a 1D shape, the outer (index 0) wi_layout and wi_data entries
///         are both 1;
///       - every dimension (a 1D shape is treated as 1xN) is at least as large
///         as, and evenly divisible by, wi_layout[i] * wi_data[i];
///       - when `encoding` is a ScatterTensorDescAttr, wi_data[0] is 1 and
///         wi_data[1] does not exceed the attribute's chunk size.
///
/// On the first violated check a diagnostic is emitted through `emitError` and
/// the resulting failure is returned; otherwise returns success().
/// `elementType` is not inspected.
LogicalResult TensorDescType::verify(
    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
    mlir::Attribute encoding, mlir::Attribute sg_map) {
  size_t rank = shape.size();
  if (rank > 2)
    return emitError() << "desc shape rank exceeds 2";

  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
    // NOTE(review): the indexing below assumes both arrays hold two entries;
    // presumably this is enforced by the SGMapAttr verifier - confirm.
    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();

    if (rank == 1) {
      if (wiLayout[0] != 1 || wiData[0] != 1)
        return emitError() << "outer layout and data mapping must be 1 "
                              "for 1D tensor";
    }

    // For 1D tensor, pad the shape with an outer unit dimension to allow common
    // validation logic.
    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
    if (rank == 1)
      tensorShape = {1, tensorShape.back()};

    size_t dims = tensorShape.size();
    for (size_t i = 0; i < dims; ++i) {
      // Widen before multiplying: the product of two 32-bit values can wrap in
      // uint32_t and make an unmappable shape look valid.
      const uint64_t numElemPerWi =
          static_cast<uint64_t>(wiLayout[i]) * wiData[i];
      const int64_t dimSize = tensorShape[i];
      // A zero tile size would divide by zero in the modulo below, and a
      // negative dimension must not be reinterpreted as a huge unsigned value,
      // so both are rejected before the unsigned comparison.
      const bool mappable =
          numElemPerWi != 0 && dimSize >= 0 &&
          static_cast<uint64_t>(dimSize) >= numElemPerWi &&
          static_cast<uint64_t>(dimSize) % numElemPerWi == 0;
      if (!mappable)
        return emitError() << "cannot map " << tensorShape[i]
                           << " elements into " << wiLayout[i] << " by "
                           << wiData[i] << " tiles";
    }

    if (auto scatterAttr =
            llvm::dyn_cast_if_present<ScatterTensorDescAttr>(encoding)) {
      if (wiData[0] != 1)
        return emitError()
               << "cannot map over non-contiguous scattered elements";

      // Keep the chunk size in 64 bits so it is not narrowed before the
      // comparison.
      const int64_t chunkSize = scatterAttr.getChunkSize().getInt();
      if (static_cast<int64_t>(wiData[1]) > chunkSize)
        return emitError()
               << "too few contiguous elements for work item mapping";
    }
  }

  return success();
}
275+
226276
} // namespace xegpu
227277
} // namespace mlir
228278

0 commit comments

Comments
 (0)