Skip to content

Commit 8c9a2ed

Browse files
committed
Use new TensorMapDescriptorType
1 parent 27f441b commit 8c9a2ed

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,18 @@ def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
4242
@property
4343
def tensormap_descriptor_ty(self):
4444
"""Returns a tensormap descriptor type."""
45-
memref_str = f"memref<{self.tma_box_shape[0]}x{self.tma_box_shape[1]}x{self.memref_ty.element_type}, 3>"
46-
parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
47-
swizzle = {self.swizzle},\
48-
l2promo = {self.l2promo},\
49-
oob = {self.oob},\
50-
interleave = {self.interleave}>"
51-
return ir.Type.parse(parse_str)
45+
tensorMemrefType = ir.MemRefType.get(
46+
self.tma_box_shape,
47+
self.memref_ty.element_type,
48+
memory_space=ir.Attribute.parse("3"),
49+
)
50+
return nvgpu.TensorMapDescriptorType.get(
51+
tensorMemrefType,
52+
self.swizzle,
53+
self.l2promo,
54+
self.oob,
55+
self.interleave,
56+
)
5257

5358
def tma_descriptor_op(self, device_ptr):
5459
"""Returns a tensormap descriptor op."""

0 commit comments

Comments
 (0)