This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Ensure DTensor wraps inner float8 Tensors #224

Closed
wants to merge 8 commits
5 changes: 4 additions & 1 deletion float8_experimental/float8_dynamic_linear.py
@@ -30,7 +30,10 @@ def forward(

@staticmethod
def backward(ctx, gradY):
fp8_tensor = to_fp8_no_autograd(gradY, torch.float8_e5m2, ctx.emulate)
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
)
return fp8_tensor, None
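With this change the caller computes the gradient scale and passes it to to_fp8_no_autograd, instead of the scale being derived inside the cast. A minimal sketch of the same calling convention outside of autograd (the tensor shape and CUDA device are assumptions for illustration):

import torch
from float8_experimental.float8_tensor import to_fp8_no_autograd
from float8_experimental.float8_utils import tensor_to_scale

grad_hp = torch.randn(16, 16, device="cuda")              # hypothetical high-precision gradient
grad_scale = tensor_to_scale(grad_hp, torch.float8_e5m2)  # the caller computes the scale ...
grad_fp8 = to_fp8_no_autograd(grad_hp, grad_scale, torch.float8_e5m2, emulate=False)  # ... and passes it in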


7 changes: 6 additions & 1 deletion float8_experimental/float8_ops.py
@@ -78,9 +78,14 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):

@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
a = args[0]
b = args[1]

assert isinstance(a, Float8Tensor) and isinstance(
b, Float8Tensor
), "Expecting both Float8Tensor for mm inputs but found {} and {}".format(
type(a), type(b)
)
a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
output_dtype = a._orig_dtype
if a._emulate:
150 changes: 118 additions & 32 deletions float8_experimental/float8_tensor.py
@@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
from typing import Dict, Optional

import torch

@@ -13,9 +13,91 @@
to_fp8_saturated,
)

from torch.distributed._tensor import DTensor

aten = torch.ops.aten


def to_fp8_no_autograd(
x: torch.Tensor, x_scale: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
) -> "Float8Tensor":
"""Convert a tensor to float8 without autograd
This is used in multiple places in the codebase to convert a tensor to float8

This function will apply the scaling, and then convert to a Float8Tensor

Note:
We will call this function with a DTensor subclass. Ideally this would be an aten OP
that DTensor could overload to ensure proper semantics. There are some technical issues
with how that composes with FakeTensor, so we special case here.

DTensor Invariant: DTensor must always be the outermost tensor subclass

Args:
x: the tensor to convert
x_scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
emulate: whether to emulate the matmuls in fp32
"""
x_scaled = x * x_scale
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)

if isinstance(bits_fp8, DTensor):
assert isinstance(
x, DTensor
), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
bits_mesh = bits_fp8.device_mesh
bits_placements = bits_fp8.placements
local_bits = bits_fp8.to_local()
local_scale = x_scale.to_local()
inner_float8_tensor = Float8Tensor(
local_bits, local_scale, x.dtype, emulate=emulate
)
return DTensor.from_local(
inner_float8_tensor,
bits_mesh,
bits_placements,
run_check=False,
shape=bits_fp8.size(),
stride=bits_fp8.stride(),
)

return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
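To make the Note's invariant concrete, here is a minimal hypothetical sketch of the intended layering, assuming a 1-D device mesh has already been initialized: the Float8Tensor holds the local shard, and DTensor stays the outermost subclass.

import torch
from torch.distributed._tensor import DTensor, Shard
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

def wrap_local_fp8_in_dtensor(local_hp: torch.Tensor, mesh) -> DTensor:
    # Cast the local shard to float8 first ...
    scale = tensor_to_scale(local_hp, torch.float8_e4m3fn)
    local_fp8 = Float8Tensor.to_float8(local_hp, scale, torch.float8_e4m3fn)
    # ... then wrap it, keeping DTensor as the outermost wrapper around the Float8Tensor.
    return DTensor.from_local(local_fp8, mesh, [Shard(0)], run_check=False)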


def from_fp8_no_autograd(x: torch.Tensor) -> torch.Tensor:
"""Convert a tensor from float8 without autograd

This function will handle 3 cases:
1. If the tensor is a DTensor, it will convert the inner tensor to the original precision
2. If the tensor is a Float8Tensor, it will convert the tensor to the original precision
3. If the tensor is a regular tensor, it will pass through this tensor

Args:
x: the tensor to convert
"""

def to_original_precision(grad):
if isinstance(grad, Float8Tensor):
return grad.to_original_precision()
else:
return grad

if isinstance(x, DTensor):
local_grad = x.to_local()
original_precision_grad = to_original_precision(local_grad)
return DTensor.from_local(
original_precision_grad,
x.device_mesh,
x.placements,
run_check=False,
shape=x.size(),
stride=x.stride(),
)
else:
return to_original_precision(x)
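A short usage sketch of the three cases above (shapes and the CUDA device are assumptions; case 1 is only described, since constructing a DTensor here would require an initialized mesh):

import torch
from float8_experimental.float8_tensor import Float8Tensor, from_fp8_no_autograd
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(8, 8, device="cuda")
x_fp8 = Float8Tensor.to_float8(x, tensor_to_scale(x, torch.float8_e4m3fn), torch.float8_e4m3fn)

hp = from_fp8_no_autograd(x_fp8)                   # case 2: Float8Tensor -> original precision
passthrough = from_fp8_no_autograd(torch.ones(4))  # case 3: regular tensor is returned unchanged
# case 1: a DTensor input is unwrapped with to_local(), converted like case 2 or 3,
# and re-wrapped with DTensor.from_local() on the same mesh and placements.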


@torch._dynamo.allow_in_graph
class ToFloat8ConstrFunc(torch.autograd.Function):
"""
@@ -25,25 +107,29 @@ class ToFloat8ConstrFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
tensor,
scale: float,
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype=torch.float8_e4m3fn,
amax_buffer=None,
amax_buffer: Optional[torch.Tensor] = None,
emulate: bool = False,
):
"""Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
Args
tensor: the tensor to convert
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
emulate: whether to emulate the matmuls in fp32
"""
if amax_buffer is not None:
amax_buffer.fill_(tensor_to_amax(tensor))

tensor_scaled = tensor * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
return to_fp8_no_autograd(tensor, scale, float8_dtype, emulate)

@staticmethod
def backward(ctx, g):
if isinstance(g, Float8Tensor):
return g.to_original_precision(), None, None, None, None
else:
return g, None, None, None, None
grad = from_fp8_no_autograd(g)
return grad, None, None, None, None
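For context, a hedged usage sketch of this autograd-enabled cast as it is typically reached through Float8Tensor.to_float8 (the shapes, CUDA device, and zero-initialized amax buffer are assumptions for illustration):

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, 16, device="cuda", requires_grad=True)
scale = tensor_to_scale(x, torch.float8_e4m3fn)
amax_buffer = torch.zeros(1, device="cuda")  # gets filled with amax(x) before the cast

x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn, amax_buffer)
x_fp8.to_original_precision().sum().backward()  # gradients flow back to x in the original precision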


@torch._dynamo.allow_in_graph
@@ -95,7 +181,11 @@ def __new__(
orig_dtype: torch.dtype,
emulate=False,
):
assert scale.numel() == 1
assert (
scale.numel() == 1
), "Scale should contain a single value, but got: {} elements".format(
scale.numel()
)

self = torch.Tensor._make_wrapper_subclass(
cls,
@@ -138,7 +228,13 @@ def to_original_precision(self):

@staticmethod
@torch._dynamo.allow_in_graph
def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
def to_float8(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype: torch.dtype,
amax_buffer: Optional[torch.Tensor] = None,
emulate: bool = False,
):
"""Converts a higher precision tensor to float8 in a differentiable way.

Args:
@@ -168,28 +264,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
# Lazy import to avoid circular dependency
from float8_experimental.float8_ops import FLOAT8_OPS_TABLE

# All ops in the FLOAT8_OPS_TABLE expect Float8Tensor as inputs
# and don't support mixed tensor subclasses. This will trigger the handler for
# the next type in the dispatch list. torch._C._TensorMeta is for FakeTensor
def allowed_subclasses(type):
return issubclass(cls, type) or isinstance(type, torch._C._TensorMeta)

if not all(allowed_subclasses(t) for t in types):
return NotImplemented

if func in FLOAT8_OPS_TABLE:
return FLOAT8_OPS_TABLE[func](func, args, kwargs)
raise NotImplementedError(f"attempting to run {func}, this is not supported")

# Do not force the Float8Tensor type on the returned tensor
__torch_function__ = torch._C._disabled_torch_function_impl


def to_fp8_no_autograd(
x: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
) -> Float8Tensor:
"""Convert a tensor to float8 without autograd
This is used in multiple places in the codebase to convert a tensor to float8

This function will calculate the scale, do the scaling, and then convert to a Float8Tensor
Args:
x: the tensor to convert
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
emulate: whether to emulate the matmuls in fp32
"""
x_scale = tensor_to_scale(x, float8_dtype)
x_scaled = x * x_scale
bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
62 changes: 59 additions & 3 deletions test/test_dtensor.py
@@ -12,10 +12,14 @@
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import NoopFwToFloat8E5M2Bw
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed import init_process_group
from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.testing._internal.distributed.fake_pg import FakeStore
from tqdm import tqdm


def setup_distributed():
@@ -92,10 +96,62 @@ def test_fp8_redistribute(mesh: DeviceMesh, size=16):
assert out_local._data.dtype == fp8_dtype


def test_dtensor_cast_to_fp8(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn

x_fp32 = torch.rand(size, size, device=device)
dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])

dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float()
assert isinstance(dist_x_scale, DTensor)

dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype)
assert isinstance(dist_x_fp8, DTensor)


def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = torch.float8_e4m3fn

x_fp32 = torch.rand(size, size, device=device, requires_grad=True)
local_weight = torch.rand(2 * size, size, device=device, requires_grad=True)
target = torch.rand(size, 2 * size, device=device)

dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)])
dist_x_scale = tensor_to_scale(dist_x_fp32, fp8_dtype).float()

dist_weight_fp32 = distribute_tensor(local_weight, mesh, [Shard(0)])
dist_weight_scale = tensor_to_scale(dist_weight_fp32, fp8_dtype).float()
dist_target = distribute_tensor(target, mesh, [Shard(0)])

dist_x_fp8 = Float8Tensor.to_float8(dist_x_fp32, dist_x_scale, fp8_dtype)
dist_weight_fp8 = Float8Tensor.to_float8(
dist_weight_fp32, dist_weight_scale, fp8_dtype
)

out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
out = NoopFwToFloat8E5M2Bw.apply(out, False)
assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}"
loss = torch.sum(torch.abs(out - dist_target))
loss.backward()


if __name__ == "__main__":
# float8 only works on CUDA H100, so we only test CUDA. We follow
# other test files in not using TestCase and instead add the test
# cases in the main func.
device_mesh = setup_distributed()
test_scaled_mm(device_mesh)
test_fp8_redistribute(device_mesh)
tests = [
test_scaled_mm,
test_fp8_redistribute,
test_dtensor_cast_to_fp8,
test_dtensor_fp8_autograd,
]

for test in tqdm(tests, desc="Running tests"):
try:
test(device_mesh)
except Exception as e:
print(f"Test {test.__name__} failed with error: {e}")
raise e