
Commit f8fe4e7

fix: add flash attn support check (#803)

1 parent 1c07fb6 commit f8fe4e7

File tree: 11 files changed, +191 −120 lines
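The pattern repeated across these files is simple: a `ggml_backend_t backend` handle is threaded through every `forward()` on the path down to the attention helper, so that `ggml_nn_attention_ext` can check at graph-build time whether the device can actually execute the fused flash-attention op, instead of blindly trusting the caller's `flash_attn` flag. Below is a minimal sketch of such a capability probe; `ggml_flash_attn_ext` and `ggml_backend_supports_op` are stock ggml APIs, while the helper name and layout assumptions are illustrative, not the exact code this commit adds (presumably to `ggml_extend.hpp`, which is outside this excerpt):

```cpp
#include <math.h>

#include "ggml.h"
#include "ggml-backend.h"

// Build the fused GGML_OP_FLASH_ATTN_EXT node, then ask the backend whether
// it can execute it. q/k/v are assumed to already be in the
// [d_head, seq_len, n_head, N] layout that ggml_flash_attn_ext expects.
static bool backend_supports_flash_attn(ggml_backend_t backend,
                                        struct ggml_context* ctx,
                                        struct ggml_tensor* q,
                                        struct ggml_tensor* k,
                                        struct ggml_tensor* v) {
    const float scale = 1.0f / sqrtf((float)q->ne[0]);  // 1/sqrt(d_head)
    struct ggml_tensor* fa = ggml_flash_attn_ext(ctx, q, k, v, NULL, scale, 0.0f, 0.0f);
    return ggml_backend_supports_op(backend, fa);
}
```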

clip.hpp

17 additions, 9 deletions

```diff
@@ -488,14 +488,14 @@ struct CLIPLayer : public GGMLBlock {
         blocks["mlp"] = std::shared_ptr<GGMLBlock>(new CLIPMLP(d_model, intermediate_size));
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = true) {
+    struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) {
         // x: [N, n_token, d_model]
         auto self_attn   = std::dynamic_pointer_cast<MultiheadAttention>(blocks["self_attn"]);
         auto layer_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm1"]);
         auto layer_norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm2"]);
         auto mlp         = std::dynamic_pointer_cast<CLIPMLP>(blocks["mlp"]);
 
-        x = ggml_add(ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask));
+        x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask));
         x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x)));
         return x;
     }
@@ -517,7 +517,11 @@ struct CLIPEncoder : public GGMLBlock {
         }
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                int clip_skip = -1,
+                                bool mask = true) {
         // x: [N, n_token, d_model]
         int layer_idx = n_layer - 1;
         // LOG_DEBUG("clip_skip %d", clip_skip);
@@ -532,7 +536,7 @@ struct CLIPEncoder : public GGMLBlock {
             }
             std::string name = "layers." + std::to_string(i);
             auto layer       = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
-            x                = layer->forward(ctx, x, mask);  // [N, n_token, d_model]
+            x                = layer->forward(ctx, backend, x, mask);  // [N, n_token, d_model]
             // LOG_DEBUG("layer %d", i);
         }
         return x;
@@ -712,6 +716,7 @@ class CLIPTextModel : public GGMLBlock {
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* tkn_embeddings,
                                 size_t max_token_idx = 0,
@@ -722,7 +727,7 @@ class CLIPTextModel : public GGMLBlock {
         auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
 
         auto x = embeddings->forward(ctx, input_ids, tkn_embeddings);  // [N, n_token, hidden_size]
-        x      = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
+        x      = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true);
         if (return_pooled || with_final_ln) {
             x = final_layer_norm->forward(ctx, x);
         }
@@ -775,6 +780,7 @@ class CLIPVisionModel : public GGMLBlock {
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
                                 int clip_skip = -1) {
@@ -786,7 +792,7 @@ class CLIPVisionModel : public GGMLBlock {
 
         auto x = embeddings->forward(ctx, pixel_values);  // [N, num_positions, embed_dim]
         x      = pre_layernorm->forward(ctx, x);
-        x      = encoder->forward(ctx, x, clip_skip, false);
+        x      = encoder->forward(ctx, backend, x, clip_skip, false);
         // print_ggml_tensor(x, true, "ClipVisionModel x: ");
         auto last_hidden_state = x;
         x                      = post_layernorm->forward(ctx, x);  // [N, n_token, hidden_size]
@@ -855,6 +861,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* pixel_values,
                                 bool return_pooled = true,
                                 int clip_skip = -1) {
@@ -863,7 +870,7 @@ class CLIPVisionModelProjection : public GGMLBlock {
         auto vision_model      = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
         auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
 
-        auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip);  // [N, hidden_size] or [N, n_token, hidden_size]
+        auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip);  // [N, hidden_size] or [N, n_token, hidden_size]
 
         if (return_pooled) {
             x = visual_projection->forward(ctx, x);  // [N, projection_dim]
@@ -900,6 +907,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
     }
 
    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* embeddings,
                                 size_t max_token_idx = 0,
@@ -911,7 +919,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
         }
 
-        return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled);
+        return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled);
     }
 
     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
@@ -937,7 +945,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
             embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
         }
 
-        struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, embeddings, max_token_idx, return_pooled);
+        struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled);
 
         ggml_build_forward_expand(gf, hidden_states);
```
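A detail worth noting in the last hunk above: only `GGMLRunner` subclasses own a device handle (`runtime_backend`), so `build_graph` is where the concrete backend enters the call chain; every `GGMLBlock::forward` in between merely forwards it. A toy illustration of the idiom, with invented names (`ToyLayer`, `ToyRunner`) rather than anything from the repository:

```cpp
#include "ggml.h"
#include "ggml-backend.h"

struct ToyLayer {
    // blocks receive the backend as a parameter...
    struct ggml_tensor* forward(struct ggml_context* ctx,
                                ggml_backend_t backend,
                                struct ggml_tensor* x) {
        (void)backend;  // a real layer would pass it on to ggml_nn_attention_ext
        return x;
    }
};

struct ToyRunner {
    ggml_backend_t runtime_backend;  // ...while runners own it
    ToyLayer layer;

    struct ggml_tensor* build(struct ggml_context* ctx, struct ggml_tensor* x) {
        // inject the handle once, at graph-build time
        return layer.forward(ctx, runtime_backend, x);
    }
};
```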
common.hpp

16 additions, 7 deletions

```diff
@@ -270,7 +270,10 @@ class CrossAttention : public GGMLBlock {
         // to_out_1 is nn.Dropout(), skip for inference
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* context) {
         // x: [N, n_token, query_dim]
         // context: [N, n_context, context_dim]
         // return: [N, n_token, query_dim]
@@ -288,7 +291,7 @@ class CrossAttention : public GGMLBlock {
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]
 
-        x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn);  // [N, n_token, inner_dim]
+        x = ggml_nn_attention_ext(ctx, backend, q, k, v, n_head, NULL, false, false, flash_attn);  // [N, n_token, inner_dim]
 
         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;
@@ -327,7 +330,10 @@ class BasicTransformerBlock : public GGMLBlock {
         }
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    struct ggml_tensor* forward(struct ggml_context* ctx,
+                                ggml_backend_t backend,
+                                struct ggml_tensor* x,
+                                struct ggml_tensor* context) {
         // x: [N, n_token, query_dim]
         // context: [N, n_context, context_dim]
         // return: [N, n_token, query_dim]
@@ -352,11 +358,11 @@ class BasicTransformerBlock : public GGMLBlock {
 
         auto r = x;
         x      = norm1->forward(ctx, x);
-        x      = attn1->forward(ctx, x, x);  // self-attention
+        x      = attn1->forward(ctx, backend, x, x);  // self-attention
         x      = ggml_add(ctx, x, r);
         r      = x;
         x      = norm2->forward(ctx, x);
-        x      = attn2->forward(ctx, x, context);  // cross-attention
+        x      = attn2->forward(ctx, backend, x, context);  // cross-attention
         x      = ggml_add(ctx, x, r);
         r      = x;
         x      = norm3->forward(ctx, x);
@@ -401,7 +407,10 @@ class SpatialTransformer : public GGMLBlock {
         blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
     }
 
-    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
+    virtual struct ggml_tensor* forward(struct ggml_context* ctx,
+                                        ggml_backend_t backend,
+                                        struct ggml_tensor* x,
+                                        struct ggml_tensor* context) {
         // x: [N, in_channels, h, w]
         // context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
         auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
@@ -424,7 +433,7 @@ class SpatialTransformer : public GGMLBlock {
             std::string name       = "transformer_blocks." + std::to_string(i);
             auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
 
-            x = transformer_block->forward(ctx, x, context);
+            x = transformer_block->forward(ctx, backend, x, context);
         }
 
         x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));  // [N, inner_dim, h * w]
```
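The `ggml_nn_attention_ext` call above is where the threaded handle pays off. Its implementation is not part of this excerpt, but the following hedged sketch shows the kind of dispatch the extra parameter enables: try the fused op, and degrade to naive attention when the backend rejects it. The function itself and the tensor layouts are assumptions; `ggml_mul_mat`, `ggml_soft_max_ext`, `ggml_flash_attn_ext`, and `ggml_backend_supports_op` are real ggml APIs:

```cpp
#include <math.h>

#include "ggml.h"
#include "ggml-backend.h"

// Assumed layouts: q [d_head, L_q, n_head, N], k/v [d_head, L_k, n_head, N].
static struct ggml_tensor* attention_with_fallback(struct ggml_context* ctx,
                                                   ggml_backend_t backend,
                                                   struct ggml_tensor* q,
                                                   struct ggml_tensor* k,
                                                   struct ggml_tensor* v,
                                                   bool flash_attn) {
    const float scale = 1.0f / sqrtf((float)q->ne[0]);  // 1/sqrt(d_head)

    if (flash_attn) {
        struct ggml_tensor* fa = ggml_flash_attn_ext(ctx, q, k, v, NULL, scale, 0.0f, 0.0f);
        if (ggml_backend_supports_op(backend, fa)) {
            return fa;  // fused path, [d_head, n_head, L_q, N]
        }
        // the backend cannot run the fused op: fall through to the naive path
    }

    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);                            // [L_k, L_q, n_head, N]
    kq = ggml_soft_max_ext(ctx, kq, NULL, scale, 0.0f);                          // scaled softmax over keys
    struct ggml_tensor* vt  = ggml_cont(ctx, ggml_permute(ctx, v, 1, 0, 2, 3));  // [L_k, d_head, n_head, N]
    struct ggml_tensor* kqv = ggml_mul_mat(ctx, vt, kq);                         // [d_head, L_q, n_head, N]
    return ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));                   // match the fused layout
}
```

The decision is thus made against the device that will actually run the graph, so a build with `flash_attn` enabled degrades gracefully on backends without flash-attention support.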

conditioner.hpp

1 addition, 1 deletion

```diff
@@ -639,7 +639,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
 
         pixel_values = to_backend(pixel_values);
 
-        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values, return_pooled, clip_skip);
+        struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip);
 
         ggml_build_forward_expand(gf, hidden_states);
 
```
control.hpp

8 additions, 5 deletions

```diff
@@ -174,10 +174,11 @@ class ControlNetBlock : public GGMLBlock {
 
     struct ggml_tensor* attention_layer_forward(std::string name,
                                                 struct ggml_context* ctx,
+                                                ggml_backend_t backend,
                                                 struct ggml_tensor* x,
                                                 struct ggml_tensor* context) {
         auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
-        return block->forward(ctx, x, context);
+        return block->forward(ctx, backend, x, context);
     }
 
     struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx,
@@ -199,6 +200,7 @@ class ControlNetBlock : public GGMLBlock {
     }
 
     std::vector<struct ggml_tensor*> forward(struct ggml_context* ctx,
+                                             ggml_backend_t backend,
                                              struct ggml_tensor* x,
                                              struct ggml_tensor* hint,
                                              struct ggml_tensor* guided_hint,
@@ -272,7 +274,7 @@ class ControlNetBlock : public GGMLBlock {
             h = resblock_forward(name, ctx, h, emb);  // [N, mult*model_channels, h, w]
             if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
                 std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
-                h                = attention_layer_forward(name, ctx, h, context);  // [N, mult*model_channels, h, w]
+                h                = attention_layer_forward(name, ctx, backend, h, context);  // [N, mult*model_channels, h, w]
             }
 
             auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
@@ -296,9 +298,9 @@ class ControlNetBlock : public GGMLBlock {
         // [N, 4*model_channels, h/8, w/8]
 
         // middle_block
-        h = resblock_forward("middle_block.0", ctx, h, emb);              // [N, 4*model_channels, h/8, w/8]
-        h = attention_layer_forward("middle_block.1", ctx, h, context);   // [N, 4*model_channels, h/8, w/8]
-        h = resblock_forward("middle_block.2", ctx, h, emb);              // [N, 4*model_channels, h/8, w/8]
+        h = resblock_forward("middle_block.0", ctx, h, emb);                       // [N, 4*model_channels, h/8, w/8]
+        h = attention_layer_forward("middle_block.1", ctx, backend, h, context);   // [N, 4*model_channels, h/8, w/8]
+        h = resblock_forward("middle_block.2", ctx, h, emb);                       // [N, 4*model_channels, h/8, w/8]
 
         // out
         outs.push_back(middle_block_out->forward(ctx, h));
@@ -403,6 +405,7 @@ struct ControlNet : public GGMLRunner {
         timesteps = to_backend(timesteps);
 
         auto outs = control_net.forward(compute_ctx,
+                                        runtime_backend,
                                         x,
                                         hint,
                                         guided_hint_cached ? guided_hint : NULL,
```
0 commit comments