Commit 11ed38f

support reduce op with fast implementation

1 parent d72b1f1 commit 11ed38f

File tree

16 files changed: +553 -113 lines changed

scripts/correctness.sh

Lines changed: 9 additions & 0 deletions

@@ -10,6 +10,15 @@ python3 -m benchgc --verbose 0 --driver linalg --case matmul --md 0:32x128xbf16
 
 # f32
 
+
+# reduce
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.add --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.mul --md 0:128x8xf32 --md 1:128xf32 --dimensions=1 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.max --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.min --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.l1 --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.l2_square --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+
 # misc
 python3 -m benchgc --verbose 0 --driver linalg --case fill --md 0:f32 --md 1:32x4096xf32 --cmp 1:P:0:0 || FAIL=1
 python3 -m benchgc --verbose 0 --driver linalg --case copy --md 0:1024x1024xf32 --md 1:1024x1024xbf16 || FAIL=1
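Note: each case pairs an input descriptor (--md 0:shape x dtype) with an output descriptor (--md 1:...) plus the axes to collapse (--dimensions, repeatable). As a sanity check, here is a sketch of what the reduce.add case computes, assuming the reference follows torch's dim-wise reduction semantics (the actual reference lives in the new benchgc/linalg/reduce.py, which is among the changed files but not shown in this excerpt):

import torch

# reduce.add over dims 1 and 2: 128x64x8xf32 -> 128xf32
src = torch.rand(128, 64, 8, dtype=torch.float32)
ref = src.sum(dim=(1, 2))  # mirrors --dimensions=1 --dimensions=2
assert ref.shape == (128,)

# reduce.l2_square: sum of squares over the reduced dims
l2_square = src.square().sum(dim=(1, 2))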

test/benchgc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -39,3 +39,4 @@ add_subdirectory("src/benchgc/mlir")
 add_subdirectory("src/benchgc/linalg")
 add_subdirectory("src/benchgc/tensor")
 add_subdirectory("src/benchgc/arith")
+add_subdirectory("src/benchgc/math")

test/benchgc/src/benchgc/__main__.py

Lines changed: 4 additions & 2 deletions

@@ -189,7 +189,10 @@
 
         from .linalg import mlir_op
 
-        mlir_func = mlir_op[flags.case]
+        if flags.case.startswith("reduce."):
+            mlir_func = mlir_op["reduce"]
+        else:
+            mlir_func = mlir_op[flags.case]
         module = mlir_func(flags, args)
     else:
         raise Exception(f"unsupported driver {flags.driver}")
@@ -207,7 +210,6 @@
             raise Exception("Wrong cmp format: %s", cmp)
         idx = int(cmp[:colon])
         args[idx].set_cmp(cmp[colon + 1 :])
-
     entry = benchgc.mlir.util.get_entry(module)
 
     for i, arg in enumerate(args):
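This keeps every reduce.* case on a single MLIR builder; the builder can then recover the concrete variant from the case name itself. A minimal sketch of that dispatch pattern (build_reduce is a hypothetical stand-in for the builder registered by benchgc.linalg.reduce):

def build_reduce(case: str) -> str:
    # "reduce.add" -> "add": picks the combiner for linalg.reduce
    variant = case.split(".", 1)[1]
    return f"linalg.reduce with combiner '{variant}'"

mlir_op = {"reduce": build_reduce}  # hypothetical registry excerpt

case = "reduce.l2_square"
builder = mlir_op["reduce"] if case.startswith("reduce.") else mlir_op[case]
print(builder(case))  # linalg.reduce with combiner 'l2_square'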

test/benchgc/src/benchgc/arg/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -23,6 +23,7 @@
 import benchgc.arg.eltwise as eltwise
 import benchgc.arg.matmul as matmul
 import benchgc.arg.pool as pool
+import benchgc.arg.reduce as reduce
 import benchgc.arg.softmax as softmax
 import benchgc.util
 import torch
@@ -35,6 +36,7 @@
     "softmax": softmax,
     "conv": conv,
     "pool": pool,
+    "reduce": reduce,
 }
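The dict above acts as a per-family registry: each entry points at the module providing default_fill/default_compare hooks for that op family. A sketch of how a lookup might resolve, assuming the family name is taken from the case prefix (the lookup detail is an assumption; only the registration itself is shown in this diff):

from types import SimpleNamespace

# Stand-in for the real module object (benchgc.arg.reduce).
reduce = SimpleNamespace(default_fill=lambda flags, arg, arglist: None)
registry = {"reduce": reduce}  # illustrative excerpt of the dict above

case = "reduce.max"
handler = registry[case.split(".", 1)[0]]  # family "reduce" -> module
# handler.default_fill(flags, arg, arglist) then fills in arg.fill_param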
test/benchgc/src/benchgc/arg/reduce.py

Lines changed: 45 additions & 12 deletions

@@ -14,12 +14,42 @@
 # limitations under the License.
 ################################################################################
 
-from typing import List, Tuple
+import argparse
+from typing import List, Set, Tuple
 
 import benchgc.arg
 import benchgc.util
 import torch
-
+from benchgc.arg.arg import Arg
+from benchgc.arg.compare import p2p
+
+op: Set[str] = set(
+    [
+        "linalg.reduce.add",
+        "linalg.reduce.mul",
+        "linalg.reduce.max",
+        "linalg.reduce.min",
+        "linalg.reduce.l1",
+        "linalg.reduce.l2_square",
+    ]
+)
+
+
+def default_fill(
+    flags: argparse.Namespace,
+    arg: Arg,
+    arglist: List[Arg],
+):
+    if arg.index > 0:
+        raise Exception("reduce fill: dst filling is not allowed")
+    arg.fill_param = [
+        "reduce",
+        flags.case,
+        arglist[0].dtype,
+        arglist[1].dtype,
+        str(arglist[0].nelem() // arglist[1].nelem()),
+    ]
+    arg.fill_type = "D"
 
 def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor:
 
@@ -30,22 +60,17 @@ def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor
 
     safe_to_reduce_elems: int = benchgc.util.get_problem_bounds(op, sdtype)[0]
 
-    neutral_value: float = 1.0 if op == "mul" else 0.0
+    neutral_value: float = 1.0 if op == "reduce.mul" else 0.0
 
     shift: float = (
         1.0
-        if (
-            op == "mean"
-            or op == "min"
-            and not sdtype.is_signed
-            and not ddtype.is_signed
-        )
+        if (op == "reduce.min" and not sdtype.is_signed and not ddtype.is_signed)
         else 0.0
     )
 
     value_range: int = benchgc.util.get_problem_bounds(op, sdtype)[1]
 
-    is_mul_fp: bool = op == "mul" and sdtype.is_floating_point
+    is_mul_fp: bool = op == "reduce.mul" and sdtype.is_floating_point
     min_range: int = -value_range if is_mul_fp else 1
 
     index = torch.arange(benchgc.util.nelem(shape)).reshape(shape)
@@ -69,10 +94,18 @@ def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor
     return value.to(dtype)
 
 
+def default_compare(
+    flags: argparse.Namespace,
+    arg: Arg,
+    arglist: List[Arg],
+):
+    arg.cmp_type = "D"
+    arg.cmp_param = ["reduce", arg.dtype, flags.case]
+
 def compare(
-    ref: torch.Tensor, res: torch.Tensor, verbose: int
+    param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int
 ) -> Tuple[bool, bool | None]:
     dtype = ref.dtype
     ref = ref.to(torch.float)
     res = res.to(torch.float)
-    return benchgc.arg.p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose)
+    return p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose)
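The fill path builds deterministic inputs that keep the reduction numerically safe: most elements are the op's neutral value, and only a bounded number of non-neutral values (safe_to_reduce_elems) feed each output. A condensed, self-contained sketch of the idea, loosely modeled on the fill logic above (constants and helper names are illustrative; the real bounds come from benchgc.util.get_problem_bounds):

import torch

def sketch_fill(shape, neutral=0.0, value_range=16, safe_elems=64):
    # Deterministic pattern: small values in [1, value_range], thinned out so
    # an f32 accumulation cannot overflow or lose too much precision.
    index = torch.arange(torch.tensor(shape).prod().item()).reshape(shape)
    value = index % value_range + 1
    keep = index % max(index.numel() // safe_elems, 1) == 0
    return torch.where(keep, value.to(torch.float), torch.tensor(neutral))

src = sketch_fill([128, 64, 8])  # mostly the neutral element, sparse small values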

test/benchgc/src/benchgc/arith/basic.py

Lines changed: 49 additions & 0 deletions

@@ -42,6 +42,19 @@ def ref_constant(
             )
         else:
             raise Exception("only support splat value now")
+    elif isinstance(value, gc_mlir._mlir_libs._mlir.ir.IntegerAttr):
+        return (torch.full(size=tuple(), fill_value=value.__int__(), dtype=torch.int),)
+    elif isinstance(value, gc_mlir._mlir_libs._mlir.ir.DenseIntElementsAttr):
+        if value.is_splat:
+            return (
+                torch.full(
+                    size=tuple(value.type.shape),
+                    fill_value=value.get_splat_value().value,
+                    dtype=benchgc.util.get_dtype(str(value.get_splat_value().type)),
+                ),
+            )
+        else:
+            raise Exception("only support splat value now")
     else:
         raise Exception("Not support constant type %s", type(value))
 
@@ -56,3 +69,39 @@ def ref_addf(
     cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor]
 ) -> Tuple[torch.Tensor, ...]:
     return (var[cache.opr[0]] + var[cache.opr[1]],)
+
+
+def ref_maxf(
+    cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),)
+
+
+def ref_minf(
+    cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),)
+
+
+def ref_muli(
+    cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (var[cache.opr[0]] * var[cache.opr[1]],)
+
+
+def ref_addi(
+    cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (var[cache.opr[0]] + var[cache.opr[1]],)
+
+
+def ref_maxsi(
+    cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),)
+
+
+def ref_minsi(
+    cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),)
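Each ref_* handler evaluates one arith op against the torch tensors in var, keyed by the SSA value names recorded in the MLIRCache while walking the module. A minimal illustration of the calling convention (the stub cache here is hypothetical; the real one is built by the runner):

import torch
from types import SimpleNamespace

# Stub standing in for MLIRCache: opr holds the operands' SSA value names.
cache = SimpleNamespace(opr=["%lhs", "%rhs"])
var = {"%lhs": torch.tensor([1.0, 5.0]), "%rhs": torch.tensor([3.0, 2.0])}

# Same computation as ref_maxf: elementwise max of the two operands.
result = (torch.max(var[cache.opr[0]], var[cache.opr[1]]),)
print(result[0])  # tensor([3., 5.])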

test/benchgc/src/benchgc/linalg/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@
     "softmax",
     "conv",
     "pool",
+    "reduce",
 ]:
     mod = importlib.import_module(f"benchgc.linalg.{dri}")
     for key in mod.__dict__:

test/benchgc/src/benchgc/linalg/generic.py

Lines changed: 0 additions & 94 deletions

@@ -140,97 +140,3 @@ def ref_generic(
     return result_tensors
 
 
-def reduce_loop(
-    cache: MLIRCache,
-    op: gc_mlir.ir.OpView,
-    depth: int,
-    in_shape: List[int],
-    var: Dict[str, torch.Tensor],
-    in_idx: List[int],
-    out_idx: List[int],
-    reduced_axis: int,
-    result_tensor: torch.Tensor,
-):
-    if depth == len(in_shape):
-        # we need to execute the block here
-        # we will need to read the block argument name and save it into the cache
-
-        block: gc_mlir.ir.Block = op.regions[0].blocks[0]
-
-        if len(cache.next) == 0:
-            # region cache
-            cache.next.append(MLIRCache())
-        if len(cache.next[0].next) == 0:
-            # region->block cache
-            cache.next[0].next.append(MLIRCache())
-            for arg in block.arguments:
-                cache.next[0].next[0].arg.append(arg.get_name())
-
-        block_arg: Dict[str, torch.Tensor] = {
-            # set input
-            cache.next[0].next[0].arg[0]: var[cache.opr[0]][tuple(in_idx)],
-            # set output
-            cache.next[0].next[0].arg[1]: result_tensor[tuple(out_idx)],
-        }
-
-        res: Tuple[torch.Tensor, ...] = benchgc.runner.dfs_block(
-            cache.next[0].next[0], op.regions[0].blocks[0], var | block_arg
-        )
-
-        # perform the yield operation
-        result_tensor[tuple(out_idx)] = res[0]
-    else:
-        dimensions: gc_mlir.ir.DenseI64ArrayAttr = op.attributes["dimensions"]
-        reduce_axis: bool = depth in list(dimensions)
-
-        for i in range(in_shape[depth]):
-            if reduce_axis:
-                in_idx[depth] = i
-                reduce_loop(
-                    cache,
-                    op,
-                    depth + 1,
-                    in_shape,
-                    var,
-                    in_idx,
-                    out_idx,
-                    reduced_axis + 1,
-                    result_tensor,
-                )
-            else:
-                in_idx[depth] = i
-                out_idx[depth - reduced_axis] = i
-                reduce_loop(
-                    cache,
-                    op,
-                    depth + 1,
-                    in_shape,
-                    var,
-                    in_idx,
-                    out_idx,
-                    reduced_axis,
-                    result_tensor,
-                )
-
-
-def ref_reduce(
-    cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor]
-) -> Tuple[torch.Tensor, ...]:
-    # create the buffer for result tensors
-    tensors[cache.res[0]] = tensors[cache.opr[-1]].clone()
-    in_shape: List[int] = list(op.operands[0].type.shape)
-    out_shape: List[int] = list(op.result.type.shape)
-
-    result_tensor: torch.Tensor = tensors[cache.opr[-1]].clone()
-    reduce_loop(
-        cache,
-        op,
-        0,
-        in_shape,
-        tensors,
-        [0] * len(in_shape),
-        [0] * len(out_shape),
-        0,
-        result_tensor,
-    )
-    return (result_tensor,)
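This drops the element-by-element recursive reference for linalg.reduce, which executed the op's region once per input element. Per the commit title it is superseded by a fast implementation (presumably the new benchgc/linalg/reduce.py among the 16 changed files, not shown in this excerpt). A sketch of the vectorized shape such a replacement would plausibly take, using torch's built-in dim-wise reductions instead of a Python loop:

import torch
from typing import List

def fast_reduce(src: torch.Tensor, dims: List[int], variant: str) -> torch.Tensor:
    # One torch call per variant; names follow the CLI cases (reduce.<variant>).
    if variant == "add":
        return src.sum(dim=dims)
    if variant == "mul":
        out = src
        for d in sorted(dims, reverse=True):  # prod only accepts a single dim
            out = out.prod(dim=d)
        return out
    if variant == "max":
        return src.amax(dim=dims)
    if variant == "min":
        return src.amin(dim=dims)
    if variant == "l1":
        return src.abs().sum(dim=dims)
    if variant == "l2_square":
        return src.square().sum(dim=dims)
    raise ValueError(f"unsupported reduce variant: {variant}")

x = torch.rand(128, 64, 8)
assert fast_reduce(x, [1, 2], "add").shape == (128,)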
