Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 35843c7

Browse files
authored
fix: optimize the handling of embedding weight (#859)
1 parent 6ad46bb commit 35843c7

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

‎clip.hpp‎

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,10 +553,9 @@ class CLIPEmbeddings : public GGMLBlock {
553553
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
554554
enum ggml_type token_wtype = GGML_TYPE_F32;
555555
if (!force_clip_f32) {
556-
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
557-
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
558-
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
559-
token_wtype = tensor_type->second;
556+
token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
557+
if (!support_get_rows(token_wtype)) {
558+
token_wtype = GGML_TYPE_F32;
560559
}
561560
}
562561
enum ggml_type position_wtype = GGML_TYPE_F32;

‎ggml_extend.hpp‎

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1967,13 +1967,24 @@ class Linear : public UnaryBlock {
19671967
}
19681968
};
19691969

1970+
/**
 * Reports whether `wtype` is one of the weight types accepted for tensors
 * that are read through a row lookup (embedding tables).
 *
 * Accepted types: F16, Q8_0, Q5_1, Q5_0, Q4_1, Q4_0. Every other type,
 * including GGML_TYPE_F32 itself, yields false; the visible callers replace
 * the type with GGML_TYPE_F32 whenever this returns false, so F32 needs no
 * entry here.
 *
 * @param wtype  Candidate storage type for the weight tensor.
 * @return true if `wtype` is in the accepted list, false otherwise.
 */
__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
    // A switch over the enumerators avoids constructing a heap-allocated
    // std::set on every call just to test membership in a fixed list.
    switch (wtype) {
        case GGML_TYPE_F16:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q4_0:
            return true;
        default:
            return false;
    }
}
1977+
19701978
class Embedding : public UnaryBlock {
19711979
protected:
19721980
int64_t embedding_dim;
19731981
int64_t num_embeddings;
19741982
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
19751983
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
1976-
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
1984+
if (!support_get_rows(wtype)) {
1985+
wtype = GGML_TYPE_F32;
1986+
}
1987+
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
19771988
}
19781989

19791990
public:

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /