This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 607ff7b

drisspg authored and facebook-github-bot committed
Add in pre-commit config and some more CI/CD (#232)
Summary: Some quality of life changes
Pull Request resolved: #232
Reviewed By: wanchaol
Differential Revision: D54437609
Pulled By: drisspg
fbshipit-source-id: 31c27a98695ee5c092b52d59c8520c844ad2a700
1 parent b9b37f8 commit 607ff7b

20 files changed: +198 -53 lines changed

.github/workflows/python-app.yml

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# Basic flak8 + pytest workflow for Python 3.10
+
+name: Python Lint and Test
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v3
+      with:
+        python-version: "3.10"
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -e .
+        pip install -e .'[dev]'
+        pip install -e .'[test]'
+    - name: Lint with ruff
+      run: |
+        ruff check .
+    - name: Running Tests
+      run: |
+        ./test/test_everything.sh

.pre-commit-config.yaml

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+exclude: 'build'
+
+default_language_version:
+  python: python3
+
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: 6306a48f7dae5861702d573c9c247e4e9498e867
+    hooks:
+      - id: trailing-whitespace
+      - id: check-ast
+      - id: check-merge-conflict
+      - id: no-commit-to-branch
+        args: ['--branch=main']
+      - id: check-added-large-files
+        args: ['--maxkb=500']
+      - id: end-of-file-fixer
+        exclude: '^(.*\.svg)$'
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.3.0
+    hooks:
+      # Run the linter.
+      - id: ruff
+
+  - repo: https://github.com/omnilib/ufmt
+    rev: v2.3.0
+    hooks:
+      - id: ufmt
+        additional_dependencies:
+          - black == 23.3.0
+          - usort == 1.0.6

benchmarks/bench_linear_float8.py

Lines changed: 1 addition & 2 deletions
@@ -222,8 +222,7 @@ def wrapper(*args, **kwargs):
     print(data_pd_simple)
 
     sweep_path = sweep_path.with_suffix(".csv")
-    with open(sweep_path, mode="w") as file:
-        data_pd.to_csv(sweep_path)
+    data_pd.to_csv(sweep_path)
 
 
 def invoke_main() -> None:
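The deleted `with open(...)` wrapper was redundant: `pandas.DataFrame.to_csv` accepts a path and manages the file handle itself, and the opened `file` object was never used. A minimal sketch of the resulting pattern, with a hypothetical frame and path standing in for the benchmark's `data_pd` and `sweep_path`:

    from pathlib import Path

    import pandas as pd

    # Hypothetical stand-in for the benchmark's results frame.
    data_pd = pd.DataFrame({"name": ["attn.wq", "ffn.w1"], "tops_sec": [412.3, 508.9]})

    sweep_path = Path("sweep_results").with_suffix(".csv")
    # to_csv opens and closes the target file on its own; no explicit open() needed.
    data_pd.to_csv(sweep_path)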

benchmarks/bench_matmul.py

Lines changed: 0 additions & 1 deletion
@@ -66,7 +66,6 @@ def run(n_limit: Optional[int] = None):
     results = []
 
     name_to_shapes = name_to_shapes_70b
-    bsz_and_seq_len = ((4, 4096),)
     dtypes = torch.bfloat16, torch.float16
 
     for idx, (dtype, (name, (K, N))) in enumerate(

benchmarks/bench_multi_gpu.py

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ def fsdp_main(rank, world_size, args):
     base_dtype, input_global, compile = args
 
     # basic distributed data sampling
-    bsz_global = input_global.shape[0]
     assert B % world_size == 0
     bsz_local_start = int(rank / world_size * B)
     bsz_local_end = int((rank + 1) / world_size * B)
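The deleted `bsz_global` is an unused local, the kind of assignment the newly enabled ruff `F` rules flag; the surrounding lines already shard the global batch across ranks from the in-scope batch size `B`. A small self-contained sketch of that slicing arithmetic, plain Python with made-up sizes and no process group:

    def local_batch_slice(rank: int, world_size: int, B: int) -> tuple[int, int]:
        """Return the [start, end) rows of the global batch owned by this rank."""
        assert B % world_size == 0, "global batch must divide evenly across ranks"
        bsz_local_start = int(rank / world_size * B)
        bsz_local_end = int((rank + 1) / world_size * B)
        return bsz_local_start, bsz_local_end

    # With B=8 and world_size=4, rank 1 owns rows [2, 4).
    print(local_batch_slice(rank=1, world_size=4, B=8))  # (2, 4)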

float8_experimental/distributed_utils.py

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ def _transform(t):
 
 def _reduce_scatter(ctx: Any, input_: torch.Tensor):
     group = get_model_parallel_group()
-    rank = torch.distributed.get_rank(group)
     world_size = torch.distributed.get_world_size(group)
 
     assert input_.shape[0] % world_size == 0

float8_experimental/float8_linear_utils.py

Lines changed: 0 additions & 2 deletions
@@ -8,8 +8,6 @@
 from enum import auto, Enum
 from typing import List, Optional, Type
 
-import float8_experimental.config as fp8_config
-
 import torch
 import torch.distributed as dist
 import torch.nn as nn

float8_experimental/float8_ops.py

Lines changed: 8 additions & 2 deletions
@@ -14,6 +14,7 @@
 
 aten = torch.ops.aten
 c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
 
 
@@ -148,7 +149,12 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     )
 
 
-@implements([c10d_functional.all_gather_into_tensor.default])
+@implements(
+    [
+        c10d_functional.all_gather_into_tensor.default,
+        _c10d_functional.all_gather_into_tensor.default,
+    ]
+)
 def allgather_fp8(aten_op, args, kwargs=None):
     """
     override funcol with FP8 handling
@@ -166,7 +172,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     return Float8Tensor(fp8_out, fp8_input._scale, fp8_input._orig_dtype)
 
 
-@implements([c10d_functional.wait_tensor.default])
+@implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
 def wait_tensor_fp8(aten_op, args, kwargs=None):
     fp8_input = args[0]
     assert isinstance(fp8_input, Float8Tensor)
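This change registers the same FP8 handlers under both the legacy `c10d_functional` and the newer `_c10d_functional` op namespaces, so the collectives resolve to these overrides regardless of which funcol implementation PyTorch dispatches through. A stripped-down sketch of how an `implements`-style decorator can populate a dispatch table keyed by several op overloads; this is a generic illustration, not the library's actual code:

    from typing import Any, Callable, Dict

    OPS_TABLE: Dict[Any, Callable] = {}

    def implements(ops):
        """Register one handler under every op overload in `ops`."""
        def decorator(fn):
            for op in ops:
                OPS_TABLE[op] = fn
            return fn
        return decorator

    # Strings stand in for the real torch.ops overload objects.
    @implements(["c10d.all_gather_into_tensor", "_c10d.all_gather_into_tensor"])
    def allgather_fp8(aten_op, args, kwargs=None):
        ...

    # Lookup by op: either namespace resolves to the same handler.
    handler = OPS_TABLE["_c10d.all_gather_into_tensor"]
    assert handler is allgather_fp8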

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 5 deletions
@@ -7,11 +7,7 @@
 
 import torch
 
-from float8_experimental.float8_utils import (
-    tensor_to_amax,
-    tensor_to_scale,
-    to_fp8_saturated,
-)
+from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
 
 from torch.distributed._tensor import DTensor

pyproject.toml

Lines changed: 57 additions & 10 deletions
@@ -22,8 +22,9 @@ dependencies = [
 test = [
     "transformers==4.32.0",
     "pandas >= 2.0",
-    "tqdm==4.66.1",
-    "fire==0.5.0"
+    "tqdm==4.66.2",
+    "fire==0.5.0",
+    "expecttest",
 ]
 dev = [
     "black==23.3.0",
@@ -32,16 +33,62 @@ dev = [
     "libcst==1.0.1",
     "pytest==7.4.0",
     "bumpver",
-    "pip-tools"
+    "pip-tools",
+    "ruff==0.3.0"
 ]
-
-# Since we have multiple top level folders we specify what we want to be included
-# in the package
-[tool.setuptools]
-packages = ["float8_experimental"]
-
+# ---------- TOOL CONFIGURATIONS ------------
 [tool.usort]
 first_party_detection = false
 
 [tool.black]
-target-version = ["py38"]
+target-version = ["py310"]
+
+[tool.ruff]
+# Exclude a variety of commonly ignored directories.
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".git-rewrite",
+    ".hg",
+    ".ipynb_checkpoints",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pyenv",
+    ".pytest_cache",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    ".vscode",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "site-packages",
+    "venv",
+]
+
+# Same as Black.
+line-length = 88
+indent-width = 4
+
+# Assume Python 3.10
+target-version = "py310"
+
+[tool.ruff.lint]
+# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
+select = ["E4", "E7", "E9", "F"]
+ignore = ["E731"]
+
+# Allow fix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = []
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
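The `[tool.ruff.lint]` block follows ruff's documented defaults, Pyflakes (`F`) plus the `E4`/`E7`/`E9` pycodestyle groups, and additionally ignores `E731`, the rule against binding a lambda to a name instead of writing a `def`. A tiny illustration of what that ignore permits, with hypothetical names rather than code from this repo:

    # Accepted under this config because E731 is ignored.
    scale_fn = lambda amax, fp8_max: fp8_max / amax

    # The equivalent form E731 would normally steer you toward.
    def scale_fn_def(amax, fp8_max):
        return fp8_max / amax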

test/test_base.py

Lines changed: 33 additions & 22 deletions
@@ -35,6 +35,8 @@
 random.seed(0)
 torch.manual_seed(0)
 
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
 
 class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
@@ -114,13 +116,14 @@ def _test_linear_impl(
         ), f"{buffer_name} not filled, current value {buffer_value}"
 
         # verify initialization flags got updated
-        assert m_fp8.is_amax_initialized == True
+        assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
 
-    @pytest.mark.parametrize("emulate", [True, False])
+    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize("use_activation_hooks", [True, False])
     @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_nobias(
         self,
         x_shape,
@@ -142,14 +145,15 @@ def test_linear_nobias(
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
 
-    @pytest.mark.parametrize("emulate", [True, False])
+    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_activation_hooks", [True, False])
     @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_bias(
         self,
         x_shape,
@@ -172,13 +176,14 @@ def test_linear_bias(
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
 
-    @pytest.mark.parametrize("emulate", [True, False])
+    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_activation_hooks", [True, False])
     @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast_outputs(
         self,
         linear_type: LinearType,
@@ -225,31 +230,36 @@ def test_autocast_outputs(
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
-    def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
+    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_type_cast(
+        self, linear_type: LinearType, linear_dtype: torch.dtype, emulate: bool
+    ):
         emulate = (
             not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
         )
 
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = Float8Linear.from_float(m, emulate)
+        m = get_float8_linear(linear_type, m, emulate, False)
 
         # Cast the module to dtype
         m = m.to(dtype=linear_dtype)
-        # Check amax buffer types
-        for key in [
-            "fp8_amax_x",
-            "fp8_amax_history_x",
-            "fp8_scale_x",
-            "fp8_amax_w",
-            "fp8_amax_history_w",
-            "fp8_scale_w",
-            "fp8_amax_dL_dY",
-            "fp8_amax_history_dL_dY",
-            "fp8_scale_dL_dY",
-        ]:
-            assert (
-                m._buffers[key].dtype == torch.float32
-            ), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32"
+        if linear_requires_sync(linear_type):
+            # Check amax buffer types
+            for key in [
+                "fp8_amax_x",
+                "fp8_amax_history_x",
+                "fp8_scale_x",
+                "fp8_amax_w",
+                "fp8_amax_history_w",
+                "fp8_scale_w",
+                "fp8_amax_dL_dY",
+                "fp8_amax_history_dL_dY",
+                "fp8_scale_dL_dY",
+            ]:
+                assert (
+                    m._buffers[key].dtype == torch.float32
+                ), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32"
 
         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
@@ -273,7 +283,7 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
 
 class TestScaledMM:
     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
+        not is_H100,
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -321,6 +331,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype):
 
 class TestNumerics:
     @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
         # the result may not be representable in fp16. Verify that
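The recurring pattern in this file is hardware gating: the module-level `is_H100` flag shrinks the `emulate` parametrization to `[True]` on machines without compute capability 9.0, and `unittest.skipIf(not torch.cuda.is_available(), ...)` skips CUDA-dependent tests outright, which lets the suite run (emulated, or not at all) on the GPU-less ubuntu-latest job added in this commit. A minimal sketch of the same gating in isolation, using a hypothetical plain pytest function and `pytest.mark.skipif` in place of the file's `unittest.skipIf`:

    import pytest
    import torch

    # Compute capability (9, 0) corresponds to H100-class GPUs.
    is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

    # Only exercise the non-emulated path when FP8-capable hardware is present,
    # and skip entirely on CPU-only runners.
    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_hypothetical_fp8_op(emulate):
        # On a non-H100 machine only the emulated path is ever requested.
        assert emulate or is_H100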

0 commit comments