Skip to content

Commit bd49e3a

Browse files
authored
benchgc: support transpose op (#339)
1 parent be1bcb9 commit bd49e3a

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

scripts/correctness.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ python3 -m benchgc --verbose 0 --driver linalg --case reduce.l2_square --md 0:12
2323
python3 -m benchgc --verbose 0 --driver linalg --case fill --md 0:f32 --md 1:32x4096xf32 --cmp 1:P:0:0 || FAIL=1
2424
python3 -m benchgc --verbose 0 --driver linalg --case copy --md 0:1024x1024xf32 --md 1:1024x1024xbf16 || FAIL=1
2525
python3 -m benchgc --verbose 0 --driver linalg --case broadcast --md 0:1024xf32 --md 1:2x32x1024xf32 --dimensions=0 --dimensions=1 || FAIL=1
26+
python3 -m benchgc --verbose 0 --driver linalg --case transpose --md 0:32x64x128xf32 --md 1:64x128x32xf32 --permutation=1 --permutation=2 --permutation=0 || FAIL=1
2627

2728
# matmul
2829
python3 -m benchgc --verbose 0 --driver linalg --case batch_matmul --md 0:16x512x64xf32 --md 1:16x64x32xf32 --md 2:16x512x32xf32 || FAIL=1

test/benchgc/src/benchgc/__main__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ def add_common_options(parser: argparse.ArgumentParser):
187187
type=int,
188188
)
189189

190+
parser.add_argument(
191+
"--permutation",
192+
required=False,
193+
default=None,
194+
action="append",
195+
help="define the permutation attribute in linalg op",
196+
type=int,
197+
)
190198

191199
def add_bench_options(parser: argparse.ArgumentParser):
192200
"""add options for bench mode"""

test/benchgc/src/benchgc/linalg/misc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,23 @@ def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> ir.Module:
9797
)
9898
],
9999
)
100+
101+
102+
def ref_transpose(
103+
cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
104+
) -> Tuple[torch.Tensor, ...]:
105+
permutation: List[int] = [int(d) for d in op.attributes["permutation"]]
106+
return (var[cache.opr[0]].permute(permutation).contiguous(),)
107+
108+
109+
def mlir_transpose(flags: argparse.Namespace, args: List[Arg]) -> ir.Module:
110+
return init_module(
111+
flags.entry,
112+
(args[0],),
113+
(args[1],),
114+
lambda ctx, arg0: [
115+
linalg.transpose(
116+
arg0, outs=[args[1].get_zero_op(ctx)], permutation=flags.permutation
117+
)
118+
],
119+
)

0 commit comments

Comments
 (0)