Skip to content

Commit 3953175

Browse files
mikekgfbmalfet
authored andcommitted
macos12 full build (x86) (#125)
* macos12 full build (x86) * add support for setting precision via --dtype
1 parent a245940 commit 3953175

File tree

6 files changed

+29
-6
lines changed

6 files changed

+29
-6
lines changed

.github/workflows/compile.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
run-tinystories:
1212
strategy:
1313
matrix:
14-
runner: [ubuntu-latest, macos-14]
14+
runner: [ubuntu-latest, macos-14, macos-12]
1515
runs-on: ${{matrix.runner}}
1616
steps:
1717
- name: Checkout repo

export.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import torch.nn as nn
1313
from torch.export import Dim, export
1414

15+
from quantize import quantize_model, name_to_dtype, set_precision, get_precision
16+
1517
try:
1618
executorch_export_available = True
1719
from export_et import export_model as export_model_et
@@ -62,8 +64,9 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
6264
assert checkpoint_path.is_file(), checkpoint_path
6365

6466
print(f"Using device={device}")
65-
precision = torch.float # bfloat16
66-
67+
precision = name_to_dtype(args.dtype) # torch.float # bfloat16
68+
set_precision(precision)
69+
6770
print("Loading model ...")
6871
t0 = time.time()
6972
model = _load_model(

generate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch._dynamo.config
1414
import torch._inductor.config
1515

16-
from quantize import quantize_model, name_to_dtype
16+
from quantize import quantize_model, name_to_dtype, set_precision, get_precision
1717

1818

1919
def device_sync(device):
@@ -344,7 +344,8 @@ def main(
344344
# print = lambda *args, **kwargs: None
345345

346346
print(f"Using device={device}")
347-
precision = torch.float # bfloat16
347+
precision = name_to_dtype(model_dtype)
348+
set_precision(precision)
348349
is_speculative = draft_checkpoint_path is not None
349350
is_chat = "chat" in str(checkpoint_path)
350351

model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch import Tensor
1212
from torch.nn import functional as F
1313

14+
from quantize import get_precision
1415

1516
def find_multiple(n: int, k: int) -> int:
1617
if n % k == 0:
@@ -99,8 +100,11 @@ def from_name(cls, name: str):
99100

100101
class KVCache(nn.Module):
101102
def __init__(
102-
self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float): # bfloat16 ):
103+
self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=None):
104+
# torch.float): # bfloat16 ):
103105
super().__init__()
106+
if not dtype:
107+
dtype=get_precision()
104108
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
105109
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
106110
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
File renamed without changes.

quantize.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@
2323
##########################################################################
2424
### dtype name to torch.dtype mapping ###
2525

26+
precision = torch.float
27+
28+
def set_precision(dtype):
29+
global precision
30+
precision = dtype
31+
32+
def get_precision():
33+
global precision
34+
return precision
35+
2636
def name_to_dtype(name):
2737
if name in name_to_dtype_dict:
2838
return name_to_dtype_dict[name]
@@ -33,6 +43,11 @@ def name_to_dtype(name):
3343
"fp32" : torch.float,
3444
"fp16" : torch.float16,
3545
"bf16" : torch.bfloat16,
46+
"float" : torch.float,
47+
"half" : torch.float16,
48+
"float32" : torch.float,
49+
"float16" : torch.float16,
50+
"bfloat16" : torch.bfloat16,
3651
}
3752

3853
##########################################################################

0 commit comments

Comments
 (0)