support reduce op with fast implementation #314

Merged
merged 1 commit into from Sep 10, 2024

12 changes: 10 additions & 2 deletions scripts/correctness.sh

@@ -10,6 +10,15 @@ python3 -m benchgc --verbose 0 --driver linalg --case matmul --md 0:32x128xbf16

 # f32

+# reduce
+
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.add --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.mul --md 0:128x8xf32 --md 1:128xf32 --dimensions=1 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.max --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.min --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.l1 --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+python3 -m benchgc --verbose 0 --driver linalg --case reduce.l2_square --md 0:128x64x8xf32 --md 1:128xf32 --dimensions=1 --dimensions=2 || FAIL=1
+
 # misc
 python3 -m benchgc --verbose 0 --driver linalg --case fill --md 0:f32 --md 1:32x4096xf32 --cmp 1:P:0:0 || FAIL=1
 python3 -m benchgc --verbose 0 --driver linalg --case copy --md 0:1024x1024xf32 --md 1:1024x1024xbf16 || FAIL=1

@@ -92,9 +101,8 @@ python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_max --md 0:4x3
 python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_sum --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1
 python3 -m benchgc --verbose 0 --driver linalg --case pooling_nwc_min --md 0:4x32x4xf32 --md 1:4xf32 --md 2:4x13x4xf32 --strides 2 --dilations 2 || FAIL=1

-# generic / reduce
+# generic
 python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/generic.mlir || FAIL=1
-python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/reduce.mlir || FAIL=1

 # softmax
 # python3 -m benchgc --verbose 0 --driver linalg --case softmax --md 0:32x4096xf32 --md 1:32x4096xf32 --dimension 1 || FAIL=1
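
For reference, a minimal PyTorch sketch of what each new reduce case is expected to compute on the 128x64x8xf32 input (the torch equivalents are inferred from the case names and are an assumption, not code from this PR):

import torch

x = torch.randn(128, 64, 8)

add = x.sum(dim=(1, 2))                # reduce.add
mul = x.reshape(128, -1).prod(dim=1)   # reduce.mul (torch.prod reduces one dim at a time)
mx = x.amax(dim=(1, 2))                # reduce.max
mn = x.amin(dim=(1, 2))                # reduce.min
l1 = x.abs().sum(dim=(1, 2))           # reduce.l1
l2s = x.square().sum(dim=(1, 2))       # reduce.l2_square
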
1 change: 1 addition & 0 deletions test/benchgc/CMakeLists.txt

@@ -40,3 +40,4 @@ add_subdirectory("src/benchgc/linalg")
 add_subdirectory("src/benchgc/tensor")
 add_subdirectory("src/benchgc/arith")
 add_subdirectory("src/benchgc/pattern")
+add_subdirectory("src/benchgc/math")
12 changes: 0 additions & 12 deletions test/benchgc/cases/reduce.mlir

This file was deleted.

18 changes: 11 additions & 7 deletions test/benchgc/src/benchgc/__main__.py

@@ -192,7 +192,7 @@ def add_pattern_options(parser: argparse.ArgumentParser):
         get_pattern_clz(pattern_name).add_args(parser)
 
 
-def get_module_and_args(flags):
+def get_module_and_args(flags: argparse.Namespace):
     args: List[Arg] = []
     if flags.driver in ["mlir", "pattern"]:
         # we need to find all args by reading the entry function

@@ -203,6 +203,8 @@ def get_module_and_args(flags):
         elif flags.driver == "pattern":
             pattern_clz = get_pattern_clz(flags.case)
             module = pattern_clz(ctx, flags).ir_module
+        else:
+            raise Exception("unexpected error")
 
         entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry)
         idx: int = 0

@@ -235,7 +237,10 @@ def get_module_and_args(flags):
 
         from .linalg import mlir_op
 
-        mlir_func = mlir_op[flags.case]
+        if flags.case.startswith("reduce."):
+            mlir_func = mlir_op["reduce"]
+        else:
+            mlir_func = mlir_op[flags.case]
         module = mlir_func(flags, args)
     else:
         raise Exception(f"unsupported driver {flags.driver}")

@@ -269,7 +274,7 @@ def get_module_and_args(flags):
     return module, args
 
 
-def correctness_testing(flags, module, args):
+def correctness_testing(flags: argparse.Namespace, module: ir.Module, args: List[Arg]):
     ref_args: List[torch.Tensor] = []
     gc_args: List[torch.Tensor | int] = []
     ref_tensors: Dict[str, torch.Tensor] = {}

@@ -290,9 +295,8 @@ def correctness_testing(flags, module, args):
     ref_out = runner.ref_run(entry, ref_tensors)
 
     # we need to swap the result into the args if some arg is the return value
-    if ref_out is not None:
-        for i in range(len(ref_out)):
-            ref_args[0 - i - 1] = ref_out[0 - i - 1]
+    for i in range(len(ref_out)):
+        ref_args[0 - i - 1] = ref_out[0 - i - 1]
 
     mlir_args = get_mlir_args(gc_args)
     passes = "any(gc-cpu-pipeline)"

@@ -323,7 +327,7 @@ def correctness_testing(flags, module, args):
         print(f"PASSED: {flags.driver}.{flags.case}")
 
 
-def performance_testing(flags, module, args):
+def performance_testing(flags: argparse.Namespace, module: ir.Module, args: List[Arg]):
     gc_args: List[torch.Tensor | int] = []
     gc_tensors: Dict[str, torch.Tensor] = {}
     for i in range(len(args)):
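
A standalone illustration of the result-swap loop in correctness_testing (the list contents here are made up): the negative indices copy the i-th-from-last reference output into the i-th-from-last argument slot, so the trailing args end up holding the reference results.

ref_args = ["in0", "in1", "out0", "out1"]   # hypothetical: two inputs, two results
ref_out = ["r0", "r1"]
for i in range(len(ref_out)):
    ref_args[0 - i - 1] = ref_out[0 - i - 1]
assert ref_args == ["in0", "in1", "r0", "r1"]
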
2 changes: 2 additions & 0 deletions test/benchgc/src/benchgc/arg/__init__.py

@@ -23,6 +23,7 @@
 import benchgc.arg.eltwise as eltwise
 import benchgc.arg.matmul as matmul
 import benchgc.arg.pool as pool
+import benchgc.arg.reduce as reduce
 import benchgc.arg.softmax as softmax
 import benchgc.util
 import torch

@@ -36,6 +37,7 @@
     "softmax": softmax,
     "conv": conv,
     "pool": pool,
+    "reduce": reduce,
 }
57 changes: 45 additions & 12 deletions test/benchgc/src/benchgc/arg/reduce.py

@@ -14,12 +14,42 @@
 # limitations under the License.
 ################################################################################
 
-from typing import List, Tuple
+import argparse
+from typing import List, Set, Tuple
 
-import benchgc.arg
 import benchgc.util
 import torch
 
+from benchgc.arg.arg import Arg
+from benchgc.arg.compare import p2p
+
+op: Set[str] = set(
+    [
+        "linalg.reduce.add",
+        "linalg.reduce.mul",
+        "linalg.reduce.max",
+        "linalg.reduce.min",
+        "linalg.reduce.l1",
+        "linalg.reduce.l2_square",
+    ]
+)
+
+
+def default_fill(
+    flags: argparse.Namespace,
+    arg: Arg,
+    arglist: List[Arg],
+):
+    if arg.index > 0:
+        raise Exception("reduce fill: dst filling is not allowed")
+    arg.fill_param = [
+        "reduce",
+        flags.case,
+        arglist[0].dtype,
+        arglist[1].dtype,
+        str(arglist[0].nelem() // arglist[1].nelem()),
+    ]
+    arg.fill_type = "D"
+
 
 def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor:

@@ -30,22 +60,17 @@ def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor:
 
     safe_to_reduce_elems: int = benchgc.util.get_problem_bounds(op, sdtype)[0]
 
-    neutral_value: float = 1.0 if op == "mul" else 0.0
+    neutral_value: float = 1.0 if op == "reduce.mul" else 0.0
 
     shift: float = (
         1.0
-        if (
-            op == "mean"
-            or op == "min"
-            and not sdtype.is_signed
-            and not ddtype.is_signed
-        )
+        if (op == "reduce.min" and not sdtype.is_signed and not ddtype.is_signed)
         else 0.0
     )
 
     value_range: int = benchgc.util.get_problem_bounds(op, sdtype)[1]
 
-    is_mul_fp: bool = op == "mul" and sdtype.is_floating_point
+    is_mul_fp: bool = op == "reduce.mul" and sdtype.is_floating_point
     min_range: int = -value_range if is_mul_fp else 1
 
     index = torch.arange(benchgc.util.nelem(shape)).reshape(shape)

@@ -69,10 +94,18 @@ def fill(shape: List[int], dtype: torch.dtype, params: List[str]) -> torch.Tensor:
     return value.to(dtype)
 
 
+def default_compare(
+    flags: argparse.Namespace,
+    arg: Arg,
+    arglist: List[Arg],
+):
+    arg.cmp_type = "D"
+    arg.cmp_param = ["reduce", arg.dtype, flags.case]
+
 def compare(
-    ref: torch.Tensor, res: torch.Tensor, verbose: int
+    param: List[str], ref: torch.Tensor, res: torch.Tensor, verbose: int
 ) -> Tuple[bool, bool | None]:
     dtype = ref.dtype
     ref = ref.to(torch.float)
     res = res.to(torch.float)
-    return benchgc.arg.p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose)
+    return p2p(benchgc.util.get_eps(dtype), 30.0, ref, res, verbose)
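
As a worked example, for the correctness.sh case above (--case reduce.add --md 0:128x64x8xf32 --md 1:128xf32), default_fill computes 65536 // 128 and records the following (assuming the dtype fields are the plain strings from --md):

src_nelem = 128 * 64 * 8    # 65536
dst_nelem = 128
fill_param = ["reduce", "reduce.add", "f32", "f32", str(src_nelem // dst_nelem)]
# fill_param == ["reduce", "reduce.add", "f32", "f32", "512"]
# i.e. each output element accumulates 512 source elements
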
50 changes: 49 additions & 1 deletion test/benchgc/src/benchgc/arith/basic.py

@@ -17,7 +17,6 @@
 from typing import Dict, Tuple
 
 import benchgc.util
-import gc_mlir._mlir_libs._mlir.ir
 import torch
 from benchgc.mlir.util import MLIRCache
 from gc_mlir import ir

@@ -42,6 +41,19 @@ def ref_constant(
             )
         else:
             raise Exception("only support splat value now")
+    elif isinstance(value, ir.IntegerAttr):
+        return (torch.full(size=tuple(), fill_value=value.__int__(), dtype=torch.int),)
+    elif isinstance(value, ir.DenseIntElementsAttr):
+        if value.is_splat:
+            return (
+                torch.full(
+                    size=tuple(value.type.shape),
+                    fill_value=value.get_splat_value().value,
+                    dtype=benchgc.util.get_dtype(str(value.get_splat_value().type)),
+                ),
+            )
+        else:
+            raise Exception("only support splat value now")
     else:
         raise Exception("Not support constant type %s", type(value))

@@ -56,3 +68,39 @@ def ref_addf(
     cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
 ) -> Tuple[torch.Tensor, ...]:
     return (var[cache.opr[0]] + var[cache.opr[1]],)
+
+
+def ref_maxf(
+    cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),)
+
+
+def ref_minf(
+    cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),)
+
+
+def ref_muli(
+    cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (var[cache.opr[0]] * var[cache.opr[1]],)
+
+
+def ref_addi(
+    cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (var[cache.opr[0]] + var[cache.opr[1]],)
+
+
+def ref_maxsi(
+    cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),)
+
+
+def ref_minsi(
+    cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor]
+) -> Tuple[torch.Tensor, ...]:
+    return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),)
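
All the new ref_* handlers share one contract: look up the operand tensors recorded in the cache and return a tuple of results. A minimal sketch of what ref_maxsi computes, with made-up tensor names (the dispatch machinery itself is not part of this diff):

import torch

var = {"a": torch.tensor([3, -1]), "b": torch.tensor([2, 5])}
out = (torch.max(var["a"], var["b"]),)
# out == (tensor([3, 5]),) -- elementwise signed max of the two operands
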
18 changes: 9 additions & 9 deletions test/benchgc/src/benchgc/bench.py

@@ -18,7 +18,7 @@
 import ctypes
 import random
 import timeit
-from typing import List, Tuple
+from typing import Any, List, Tuple
 
 import numpy as np
 from benchgc.mlir.util import (

@@ -34,10 +34,10 @@ def py_timeit_bench(
     ir_module: ir.Module,
     entry_name: str,
     pipeline: str,
-    mlir_args: list,
-    ir_printing=False,
-    repeat_time=100,
-    warm_up=20,
+    mlir_args: List[Any],
+    ir_printing: bool = False,
+    repeat_time: int = 100,
+    warm_up: int = 20,
 ) -> Tuple[float, float]:
     """benchmark mlir with python timeit."""
     compiler = GraphCompiler(pipeline)

@@ -64,10 +64,10 @@ def mlir_wrapper_bench(
     ir_module: ir.Module,
     entry_name: str,
     pipeline: str,
-    mlir_args: list,
-    ir_printing=False,
-    repeat_time=100,
-    warm_up=20,
+    mlir_args: List[Any],
+    ir_printing: bool = False,
+    repeat_time: int = 100,
+    warm_up: int = 20,
 ) -> Tuple[float, float]:
     """benchmark mlir with a wrapper func."""
     kernel_func = get_kernel_func_from_module(ir_module, entry_name)
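
A hypothetical call site for the annotated signature (the entry name and argument packing are assumptions; the pipeline string is the one correctness_testing uses, and the meaning of the two returned floats is not shown in this diff):

timing = py_timeit_bench(
    ir_module=module,
    entry_name="entry",
    pipeline="any(gc-cpu-pipeline)",
    mlir_args=mlir_args,
    ir_printing=False,
    repeat_time=100,
    warm_up=20,
)
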
1 change: 1 addition & 0 deletions test/benchgc/src/benchgc/linalg/__init__.py

@@ -41,6 +41,7 @@
     "softmax",
     "conv",
     "pool",
+    "reduce",
 ]:
     mod = importlib.import_module(f"benchgc.linalg.{dri}")
     for key in mod.__dict__: