add fp8 e3m4 support #559

Closed
wants to merge 1 commit
151 changes: 79 additions & 72 deletions model.cpp
@@ -609,87 +609,72 @@ float bf16_to_f32(uint16_t bfloat16) {
     return *reinterpret_cast<float*>(&val_bits);
 }
 
-uint16_t f8_e4m3_to_f16(uint8_t f8) {
-    // do we need to support uz?
-
-    const uint32_t exponent_bias = 7;
-    if (f8 == 0xff) {
-        return ggml_fp32_to_fp16(-NAN);
-    } else if (f8 == 0x7f) {
-        return ggml_fp32_to_fp16(NAN);
-    }
-
-    uint32_t sign = f8 & 0x80;
-    uint32_t exponent = (f8 & 0x78) >> 3;
-    uint32_t mantissa = f8 & 0x07;
-    uint32_t result = sign << 24;
-    if (exponent == 0) {
-        if (mantissa > 0) {
-            exponent = 0x7f - exponent_bias;
-
-            // yes, 2 times
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
-            if ((mantissa & 0x04) == 0) {
-                mantissa &= 0x03;
-                mantissa <<= 1;
-                exponent -= 1;
-            }
-
-            result |= (mantissa & 0x03) << 21;
-            result |= exponent << 23;
-        }
-    } else {
-        result |= mantissa << 20;
-        exponent += 0x7f - exponent_bias;
-        result |= exponent << 23;
-    }
-
-    return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
-}
-
-uint16_t f8_e5m2_to_f16(uint8_t fp8) {
-    uint8_t sign = (fp8 >> 7) & 0x1;
-    uint8_t exponent = (fp8 >> 2) & 0x1F;
-    uint8_t mantissa = fp8 & 0x3;
-
-    uint16_t fp16_sign = sign << 15;
-    uint16_t fp16_exponent;
-    uint16_t fp16_mantissa;
-
-    if (exponent == 0 && mantissa == 0) { // zero
-        return fp16_sign;
-    }
-
-    if (exponent == 0x1F) { // NAN and INF
-        fp16_exponent = 0x1F;
-        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
-        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
-    }
-
-    if (exponent == 0) { // subnormal numbers
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-        return fp16_sign | fp16_mantissa;
-    }
-
-    // normal numbers
-    int16_t true_exponent = (int16_t)exponent - 15 + 15;
-    if (true_exponent <= 0) {
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-    } else if (true_exponent >= 0x1F) {
-        fp16_exponent = 0x1F;
-        fp16_mantissa = 0;
-    } else {
-        fp16_exponent = (uint16_t)true_exponent;
-        fp16_mantissa = mantissa << 8;
-    }
-
-    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
-}
+uint16_t f8_e3m4_to_f16(uint8_t fp8) {
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
+    }
+    const uint8_t exponent_bias = 0x3; // 2^(3-1)-1
+    const uint8_t f16_bias = 0xF;      // 2^(5-1)-1
+    const int mantissa_bits = 4;
+    const uint8_t mantissa_max = 0xF;  // 2^4-1
+
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
+
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
+        }
+        mantissa &= mantissa_max;
+    }
+    uint16_t fp16_mantissa = mantissa << 6;
+
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
+}
+
+uint16_t f8_e4m3_to_f16(uint8_t fp8) {
+    // do we need to support uz?
+    if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) {
+        // +/- 0 or NaN
+        return static_cast<uint16_t>(fp8) << 8;
+    }
+    const uint8_t exponent_bias = 0x7; // 2^(4-1)-1
+    const uint8_t f16_bias = 0xF;      // 2^(5-1)-1
+    const int mantissa_bits = 3;
+    const uint8_t mantissa_max = 0x7;  // 2^3-1
+
+    uint8_t sign = (fp8 >> 7) & 0x1;
+    uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits);
+    uint8_t mantissa = fp8 & mantissa_max;
+
+    uint16_t fp16_sign = sign << 15;
+    uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias));
+    if (exponent == 0) {
+        // subnormal numbers
+        fp16_exponent++;
+        // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0
+        while (!(mantissa >> mantissa_bits)) {
+            mantissa <<= 1;
+            fp16_exponent--;
+        }
+        mantissa &= mantissa_max;
+    }
+    uint16_t fp16_mantissa = mantissa << 7;
+
+    return fp16_sign | fp16_exponent << 10 | fp16_mantissa;
+}
+
+uint16_t f8_e5m2_to_f16(uint8_t fp8) {
+    // do we need to support fnuz?
+    return static_cast<uint16_t>(fp8) << 8;
+}

void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
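A side note on the three converters above: E5M2 has the same exponent width and bias as FP16 (5 bits, bias 15), so shifting the byte into the top half of the 16-bit word already yields the exact FP16 encoding, NaN/Inf and subnormals included; only E3M4 and E4M3 need the re-biasing and subnormal normalization. The sketch below is an illustrative, standalone reference decoder for E3M4 that can be used to cross-check f8_e3m4_to_f16 by hand — reference_e3m4_to_f32, its all-ones-is-NaN convention (mirroring the check at the top of the new functions), and the sample bytes are not part of this PR.

// Illustrative reference decoder for FP8 E3M4 (1 sign bit, 3 exponent bits with
// bias 3, 4 mantissa bits). Not part of the diff; only for cross-checking
// f8_e3m4_to_f16 by hand. Only the all-ones payload (0x7F / 0xFF) is NaN.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float reference_e3m4_to_f32(uint8_t fp8) {
    int sign     = (fp8 >> 7) & 0x1;
    int exponent = (fp8 >> 4) & 0x7;
    int mantissa = fp8 & 0xF;
    float value;
    if (exponent == 0x7 && mantissa == 0xF) {
        value = NAN;
    } else if (exponent == 0) {
        value = std::ldexp(mantissa / 16.0f, 1 - 3);                // subnormal: m/16 * 2^(1-bias)
    } else {
        value = std::ldexp(1.0f + mantissa / 16.0f, exponent - 3);  // normal
    }
    return sign ? -value : value;
}

int main() {
    // 0x30: exponent == bias, mantissa 0            -> 1.0 (f8_e3m4_to_f16 returns 0x3C00)
    // 0x01: smallest positive subnormal, m = 1      -> 1/16 * 2^-2 = 0.015625
    // 0x7E: largest finite, exponent 7, mantissa 14 -> (1 + 14/16) * 2^4 = 30.0
    printf("%g %g %g\n", reference_e3m4_to_f32(0x30),
           reference_e3m4_to_f32(0x01), reference_e3m4_to_f32(0x7E));
    return 0;
}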
@@ -699,6 +684,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
}
}

void f8_e3m4_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
dst[i] = f8_e3m4_to_f16(src[i]);
}
}
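The *_to_f16_vec helpers iterate from the last element down to the first so the conversion can run in place: the loader sizes the buffer for the f16 result and reads the n f8 source bytes into its front, and walking backwards means each 2-byte destination slot is written only after the 1-byte source it overlaps has already been read. For example, with n = 3 the source bytes s0 s1 s2 sit at offsets 0–2 of a 6-byte buffer; writing dst[2] (offsets 4–5), then dst[1] (offsets 2–3), then dst[0] (offsets 0–1) never clobbers an unread source byte, whereas a forward loop would overwrite s1 while storing dst[0].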

void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
@@ -946,6 +938,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
ttype = GGML_TYPE_F32;
} else if (dtype == "F32") {
ttype = GGML_TYPE_F32;
} else if (dtype == "F8_E3M4") {
ttype = GGML_TYPE_F16;
} else if (dtype == "F8_E4M3") {
ttype = GGML_TYPE_F16;
} else if (dtype == "F8_E5M2") {
@@ -1059,6 +1053,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
if (dtype == "BF16") {
tensor_storage.is_bf16 = true;
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else if (dtype == "F8_E3M4") {
tensor_storage.is_f8_e3m4 = true;
// f8 -> f16
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else if (dtype == "F8_E4M3") {
tensor_storage.is_f8_e4m3 = true;
// f8 -> f16
Expand Down Expand Up @@ -1461,10 +1459,10 @@ SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
bool input_block_checked = false;

bool has_multiple_encoders = false;
bool is_unet = false;

bool is_xl = false;
bool is_flux = false;

#define found_family (is_xl || is_flux)
@@ -1481,7 +1479,7 @@
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
is_unet = true;
if (has_multiple_encoders) {
is_xl = true;
if (input_block_checked) {
break;
@@ -1490,7 +1488,7 @@
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
has_multiple_encoders = true;
if (is_unet) {
is_xl = true;
if (input_block_checked) {
break;
@@ -1779,6 +1777,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e3m4) {
// inplace op
f8_e3m4_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
@@ -1793,6 +1794,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e3m4) {
// inplace op
f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
@@ -1811,6 +1815,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e3m4) {
// inplace op
f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
5 changes: 4 additions & 1 deletion model.h
@@ -89,6 +89,7 @@ struct TensorStorage {
std::string name;
ggml_type type = GGML_TYPE_F32;
bool is_bf16 = false;
bool is_f8_e3m4 = false;
bool is_f8_e4m3 = false;
bool is_f8_e5m2 = false;
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
@@ -120,7 +121,7 @@ struct TensorStorage {
}

int64_t nbytes_to_read() const {
-if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
+if (is_bf16 || is_f8_e3m4 || is_f8_e4m3 || is_f8_e5m2) {
return nbytes() / 2;
} else {
return nbytes();
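Note on the f8 bookkeeping here: nbytes() is computed from the in-memory type, and for the F8_E3M4/F8_E4M3/F8_E5M2 dtypes that type is GGML_TYPE_F16, so it counts 2 bytes per element while the file stores only 1. nbytes_to_read() therefore returns half, and the f8_*_to_f16_vec conversions later widen the data in place to fill the full buffer. For example, a 1024-element F8_E3M4 tensor reports nbytes() = 2048, reads 1024 bytes from the file, and passes the GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2) check in init_from_safetensors_file above.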
@@ -168,6 +169,8 @@ struct TensorStorage {
const char* type_name = ggml_type_name(type);
if (is_bf16) {
type_name = "bf16";
} else if (is_f8_e3m4) {
type_name = "f8_e3m4";
} else if (is_f8_e4m3) {
type_name = "f8_e4m3";
} else if (is_f8_e5m2) {