Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[1/N] add fp8 fp32 scale support for custom RL model #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
yiakwy-xpu-ml-framework-team wants to merge 2 commits into antirez:main
base: main
Choose a base branch
Loading
from yiakwy-xpu-ml-framework-team:add_fp8_fp32_scale_support
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions download_model.sh
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
set -e

REPO="antirez/deepseek-v4-gguf"

Q2="DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2.gguf"

Q2_IMATRIX_FILE="DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf"
Q4_IMATRIX_FILE="DeepSeek-V4-Flash-Q4KExperts-F16HC-F16Compressor-F16Indexer-Q8Attn-Q8Shared-Q8Out-chat-v2-imatrix.gguf"
Q2_Q4_IMATRIX_FILE="DeepSeek-V4-Flash-Layers37-42Q4KExperts-OtherExpertLayersIQ2XXSGateUp-Q2KDown-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix-fixed.gguf"
Expand Down Expand Up @@ -103,6 +106,7 @@ MODEL_FILES=
LINK_MODEL=1

case "$MODEL" in
q2) MODEL_FILE=$Q2 ;;
q2-imatrix) MODEL_FILE=$Q2_IMATRIX_FILE ;;
q2-q4-imatrix) MODEL_FILE=$Q2_Q4_IMATRIX_FILE ;;
q4-imatrix) MODEL_FILE=$Q4_IMATRIX_FILE ;;
Expand Down
54 changes: 48 additions & 6 deletions gguf-tools/deepseek4-quantize.c
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -681,20 +681,35 @@ static float *tensor_to_f32(const st_value *t, int64_t *n_out) {
}

static float *dequant_fp8_weight(const st_value *w, const st_value *scale, int64_t *n_out) {
if (strcmp(w->dtype, "F8_E4M3") != 0 || strcmp(scale->dtype, "F8_E8M0") != 0) die("bad FP8 weight/scale dtype");
if (strcmp(w->dtype, "F8_E4M3") != 0) die("bad FP8 weight dtype");
if (strcmp(scale->dtype, "F8_E8M0") != 0 && strcmp(scale->dtype, "F32") != 0) die("bad FP8 scale dtype, expected F8_E8M0 or F32");

if (w->n_dims != 2 || scale->n_dims != 2) die("FP8 tensor must be 2D");

const int64_t out_dim = w->shape[0];
const int64_t in_dim = w->shape[1];
const int64_t block_out = 128;
const int64_t block_in = 128;

if (out_dim % block_out || in_dim % block_in) die("FP8 dims are not divisible by 128");

const int64_t scale_rows = out_dim / block_out;
const int64_t scale_cols = in_dim / block_in;

if (scale->shape[0] != scale_rows || scale->shape[1] != scale_cols) die("FP8 scale shape mismatch");

float *out = xmalloc((size_t)out_dim * (size_t)in_dim * sizeof(float));
const int is_fp8_scale = (strcmp(scale->dtype, "F8_E8M0") == 0);
for (int64_t ob = 0; ob < scale_rows; ob++) {
for (int64_t ib = 0; ib < scale_cols; ib++) {
const float s = e8m0_to_f32(scale->data[(size_t)ob * (size_t)scale_cols + (size_t)ib]);
const size_t scale_idx = (size_t)ob * (size_t)scale_cols + (size_t)ib;
float s;
if (is_fp8_scale) {
s = e8m0_to_f32(scale->data[scale_idx]);
} else {
s = ((float *)scale->data)[scale_idx];
}

for (int64_t r = 0; r < block_out; r++) {
const int64_t row = ob * block_out + r;
const size_t base = (size_t)row * (size_t)in_dim + (size_t)ib * (size_t)block_in;
Expand Down Expand Up @@ -1131,6 +1146,13 @@ static byte_buf f32_to_type(const float *src, int64_t n, ds4q_type type, int64_t
}

static byte_buf i64_to_i32(const st_value *src) {
// TODO (yiakwy) : remove this redundant copy
if (strcmp(src->dtype, "I32") == 0) {
if (src->nbytes > SIZE_MAX) die("source too large for I32 conversion");
byte_buf out = { .size = src->nbytes, .data = xmalloc(src->nbytes) };
memcpy(out.data, src->data, src->nbytes);
return out;
};
if (strcmp(src->dtype, "I64") != 0) die("expected I64 source for I32 tensor");
const int64_t n = value_nelements(src);
if (src->nbytes != (size_t)n * sizeof(int64_t)) die("bad I64 byte size");
Expand Down Expand Up @@ -1179,11 +1201,21 @@ static byte_buf generate_regular(st_db *db, const char *gguf_name, const tensor_
if (!is_quantizable_target(target)) die("unsupported regular target type");
int64_t n = 0;
float *f32 = NULL;
if (strcmp(te->info.dtype, "F8_E4M3") == 0) {

bool should_dequant = true;
// NOTE (yiakwy) : for these patterns, we don't have fp8 scale
if (strstr(hf_name, "attn.indexer.weights_proj.weight") != NULL) {
should_dequant = false;
}

if (strcmp(te->info.dtype, "F8_E4M3") == 0 && should_dequant) {
if (!str_ends(hf_name, ".weight")) die("FP8 tensor without .weight suffix");
char *scale_name = xstrdup(hf_name);
strcpy(scale_name + strlen(scale_name) - strlen(".weight"), ".scale");
if (!db_has(db, scale_name)) die("missing FP8 scale tensor");
if (!db_has(db, scale_name)) {
fprintf(stderr, "missing fp8 scale %s for weight %s\n", scale_name, hf_name);
die("missing FP8 scale tensor");
}
st_value w = db_read(db, hf_name);
st_value s = db_read(db, scale_name);
f32 = dequant_fp8_weight(&w, &s, &n);
Expand Down Expand Up @@ -1230,9 +1262,19 @@ static void generate_one_expert(expert_job *j, int xid) {
snprintf(scale_name, sizeof(scale_name), "%s.scale", prefix);
st_value w = db_read(j->db, weight_name);
st_value s = db_read(j->db, scale_name);
if (w.n_dims != 2 || w.shape[0] != j->nrows || w.shape[1] * 2 != j->ncols) die("expert shape mismatch");
if (w.n_dims != 2 || w.shape[0] != j->nrows) die("expert shape mismatch");

int64_t n = 0;
float *f32 = dequant_fp4_weight(&w, &s, &n);
float *f32 = NULL;

if (w.shape[1] * 2 == j->ncols) {
dequant_fp4_weight(&w, &s, &n);
} else if (w.shape[1] == j->ncols) {
f32 = dequant_fp8_weight(&w, &s, &n);
} else {
die("expert shape mismatch");
}

const char *names[3] = { j->gguf_name, weight_name, NULL };
const float *imat = imatrix_find(j->imatrix, names, 2, j->ncols, xid, j->n_experts);
byte_buf q = f32_to_type(f32, n, j->target, j->ncols, imat);
Expand Down

AltStyle によって変換されたページ (->オリジナル) /