Skip to content

Commit 01eb0e3

Browse files
authored
Add docstrings build_utils.py, enforce consistent precision answers over a single run (#1331)
1 parent 9af34c1 commit 01eb0e3

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

torchchat/utils/build_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,30 @@ def use_et_backend() -> bool:
111111
##########################################################################
112112
### set and get target precision for this model ###
113113

114-
precision = torch.float32
114+
precision = None
115115

116116

117117
def set_precision(dtype):
118+
"""set_precision() is a torchchat-internal API that records the dtype we're building the model for.
119+
The precision is recorded for future queries by get_precision(), so that when building a model,
120+
or performing optimizations, we can query the type the user is building the model for.
121+
This is an informational value that can be used when we want to know what type to build for (e.g., a kv cache).
122+
Changing the `precision` does not change the precision of the model.
123+
"""
124+
118125
global precision
126+
assert precision is None, "only set precision once to avoid inconsistent answers during different phases of model build and export"
119127
precision = dtype
120128

121129

122130
def get_precision():
131+
"""get_precision() is a torchchat-internal API that returns the dtype we're building the model for, as specified by the `--dtype` CLI option+,
132+
or the precision quantizer.
133+
"""
123134
global precision
135+
# if (and only if) precision has not been set, update it to the default value torch.float32
136+
if precision is None:
137+
precision = torch.float32
124138
return precision
125139

126140

0 commit comments

Comments
 (0)