Skip to content

Commit 27046ba

Browse files
authored
[mlir][XeGPU] Add a builder for xegpu.create_nd_tdesc op. (#116472)
The builder is needed to support dynamic meref as source operand in xegpu.create_nd_tdesc op.
1 parent df13acf commit 27046ba

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
130130
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
131131
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
132132

133+
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
134+
"llvm::ArrayRef<OpFoldResult>": $offsets,
135+
"llvm::ArrayRef<OpFoldResult>": $shape,
136+
"llvm::ArrayRef<OpFoldResult>": $strides)>,
137+
133138
OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
134139
"llvm::ArrayRef<OpFoldResult>": $offsets,
135140
"llvm::ArrayRef<OpFoldResult>": $shape,

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,33 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
9393
{} /* empty const strides*/);
9494
}
9595

96+
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
97+
Type tdesc, TypedValue<MemRefType> source,
98+
llvm::ArrayRef<OpFoldResult> offsets,
99+
llvm::ArrayRef<OpFoldResult> shape,
100+
llvm::ArrayRef<OpFoldResult> strides) {
101+
assert(shape.size() && offsets.size() && strides.size() &&
102+
shape.size() == strides.size() && shape.size() == offsets.size());
103+
104+
llvm::SmallVector<int64_t> staticOffsets;
105+
llvm::SmallVector<int64_t> staticShape;
106+
llvm::SmallVector<int64_t> staticStrides;
107+
llvm::SmallVector<Value> dynamicOffsets;
108+
llvm::SmallVector<Value> dynamicShape;
109+
llvm::SmallVector<Value> dynamicStrides;
110+
111+
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
112+
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
113+
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
114+
115+
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
116+
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
117+
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
118+
119+
build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
120+
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
121+
}
122+
96123
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
97124
Type tdesc, TypedValue<IntegerType> source,
98125
llvm::ArrayRef<OpFoldResult> offsets,

0 commit comments

Comments
 (0)