Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c457adf

Browse files
committed
Support Flex-2
1 parent 1896b28 commit c457adf

File tree

6 files changed

+114
-16
lines changed

6 files changed

+114
-16
lines changed

flux.hpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,8 @@ namespace Flux {
984984
struct ggml_tensor* pe,
985985
struct ggml_tensor* mod_index_arange = NULL,
986986
std::vector<ggml_tensor*> ref_latents = {},
987-
std::vector<int> skip_layers = {}) {
987+
std::vector<int> skip_layers = {},
988+
SDVersion version = VERSION_FLUX) {
988989
// Forward pass of DiT.
989990
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
990991
// timestep: (N,) tensor of diffusion timesteps
@@ -1007,14 +1008,30 @@ namespace Flux {
10071008
auto img = process_img(ctx, x);
10081009
uint64_t img_tokens = img->ne[1];
10091010

1010-
if (c_concat != NULL) {
1011+
if (version == VERSION_FLUX_FILL) {
1012+
GGML_ASSERT(c_concat != NULL);
10111013
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
10121014
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
10131015

10141016
masked = process_img(ctx, masked);
10151017
mask = process_img(ctx, mask);
10161018

10171019
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
1020+
} else if (version == VERSION_FLEX_2) {
1021+
GGML_ASSERT(c_concat != NULL);
1022+
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
1023+
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
1024+
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
1025+
1026+
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
1027+
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
1028+
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
1029+
1030+
masked = patchify(ctx, masked, patch_size);
1031+
mask = patchify(ctx, mask, patch_size);
1032+
control = patchify(ctx, control, patch_size);
1033+
1034+
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
10181035
}
10191036

10201037
if (ref_latents.size() > 0) {
@@ -1055,13 +1072,15 @@ namespace Flux {
10551072
SDVersion version = VERSION_FLUX,
10561073
bool flash_attn = false,
10571074
bool use_mask = false)
1058-
: GGMLRunner(backend), use_mask(use_mask) {
1075+
: GGMLRunner(backend), version(version), use_mask(use_mask) {
10591076
flux_params.flash_attn = flash_attn;
10601077
flux_params.guidance_embed = false;
10611078
flux_params.depth = 0;
10621079
flux_params.depth_single_blocks = 0;
10631080
if (version == VERSION_FLUX_FILL) {
10641081
flux_params.in_channels = 384;
1082+
} else if (version == VERSION_FLEX_2) {
1083+
flux_params.in_channels = 196;
10651084
}
10661085
for (auto pair : tensor_types) {
10671086
std::string tensor_name = pair.first;
@@ -1171,7 +1190,8 @@ namespace Flux {
11711190
pe,
11721191
mod_index_arange,
11731192
ref_latents,
1174-
skip_layers);
1193+
skip_layers,
1194+
version);
11751195

11761196
ggml_build_forward_expand(gf, out);
11771197

ggml_extend.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
384384
int64_t width = output->ne[0];
385385
int64_t height = output->ne[1];
386386
int64_t channels = output->ne[2];
387+
float rescale_mx = mask->ne[0]/output->ne[0];
388+
float rescale_my = mask->ne[1]/output->ne[1];
387389
GGML_ASSERT(output->type == GGML_TYPE_F32);
388390
for (int ix = 0; ix < width; ix++) {
389391
for (int iy = 0; iy < height; iy++) {

model.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,6 +1689,9 @@ SDVersion ModelLoader::get_sd_version() {
16891689
if (is_inpaint) {
16901690
return VERSION_FLUX_FILL;
16911691
}
1692+
if(input_block_weight.ne[0] == 196){
1693+
return VERSION_FLEX_2;
1694+
}
16921695
return VERSION_FLUX;
16931696
}
16941697

model.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLEX_2,
3435
VERSION_COUNT,
3536
};
3637

3738
static inline bool sd_version_is_flux(SDVersion version) {
38-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
39+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
3940
return true;
4041
}
4142
return false;
@@ -70,7 +71,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
7071
}
7172

7273
static inline bool sd_version_is_inpaint(SDVersion version) {
73-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
74+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
7475
return true;
7576
}
7677
return false;

stable-diffusion.cpp

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class StableDiffusionGGML {
9595
std::shared_ptr<DiffusionModel> diffusion_model;
9696
std::shared_ptr<AutoEncoderKL> first_stage_model;
9797
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
98-
std::shared_ptr<ControlNet> control_net;
98+
std::shared_ptr<ControlNet> control_net = NULL;
9999
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
100100
std::shared_ptr<LoraModel> pmid_lora;
101101
std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -297,6 +297,11 @@ class StableDiffusionGGML {
297297
// TODO: shift_factor
298298
}
299299

300+
if(version == VERSION_FLEX_2){
301+
// Might need vae encode for control cond
302+
vae_decode_only = false;
303+
}
304+
300305
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
301306

302307
if (version == VERSION_SVD) {
@@ -933,7 +938,7 @@ class StableDiffusionGGML {
933938

934939
std::vector<struct ggml_tensor*> controls;
935940

936-
if (control_hint != NULL) {
941+
if (control_hint != NULL && control_net != NULL) {
937942
control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
938943
controls = control_net->controls;
939944
// print_ggml_tensor(controls[12]);
@@ -972,7 +977,7 @@ class StableDiffusionGGML {
972977
float* negative_data = NULL;
973978
if (has_unconditioned) {
974979
// uncond
975-
if (control_hint != NULL) {
980+
if (control_hint != NULL && control_net != NULL) {
976981
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
977982
controls = control_net->controls;
978983
}
@@ -1721,6 +1726,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17211726
int64_t mask_channels = 1;
17221727
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
17231728
mask_channels = 8 * 8; // flatten the whole mask
1729+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1730+
mask_channels = 1 + init_latent->ne[2];
17241731
}
17251732
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
17261733
// no mask, set the whole image as masked
@@ -1734,6 +1741,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17341741
for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
17351742
ggml_tensor_set_f32(empty_latent, 1, x, y, c);
17361743
}
1744+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1745+
for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
1746+
// 0x16,1x1,0x16
1747+
ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
1748+
}
17371749
} else {
17381750
ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
17391751
for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -1742,12 +1754,42 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17421754
}
17431755
}
17441756
}
1745-
if (concat_latent == NULL) {
1757+
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1758+
bool no_inpaint = concat_latent == NULL;
1759+
if (no_inpaint) {
1760+
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
1761+
}
1762+
// fill in the control image here
1763+
struct ggml_tensor* control_latents = NULL;
1764+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1765+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1766+
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1767+
} else {
1768+
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1769+
}
1770+
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1771+
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1772+
if (no_inpaint) {
1773+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1774+
// 0x16,1x1,0x16
1775+
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
1776+
}
1777+
}
1778+
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1779+
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1780+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1781+
}
1782+
}
1783+
}
1784+
// Disable controlnet
1785+
image_hint = NULL;
1786+
} else if (concat_latent == NULL) {
17461787
concat_latent = empty_latent;
17471788
}
17481789
cond.c_concat = concat_latent;
17491790
uncond.c_concat = empty_latent;
17501791
denoise_mask = NULL;
1792+
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
17511793
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
17521794
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
17531795
ggml_set_f32(empty_latent, 0);
@@ -1935,10 +1977,19 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19351977
sd_mask_to_tensor(sd_img_gen_params->mask_image.data, mask_img);
19361978
sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);
19371979

1980+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1981+
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1982+
init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1983+
} else {
1984+
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1985+
}
1986+
19381987
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
19391988
int64_t mask_channels = 1;
19401989
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
19411990
mask_channels = 8 * 8; // flatten the whole mask
1991+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1992+
mask_channels = 1 + init_latent->ne[2];
19421993
}
19431994
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
19441995
sd_apply_mask(init_img, mask_img, masked_img);
@@ -1973,6 +2024,32 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19732024
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
19742025
}
19752026
}
2027+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2028+
float m = ggml_tensor_get_f32(mask_img, mx, my);
2029+
// masked image
2030+
for (int k = 0; k < masked_latent->ne[2]; k++) {
2031+
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2032+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2033+
}
2034+
// downsampled mask
2035+
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2036+
// control (todo: support this)
2037+
for (int k = 0; k < masked_latent->ne[2]; k++) {
2038+
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
2039+
}
2040+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2041+
float m = ggml_tensor_get_f32(mask_img, mx, my);
2042+
// masked image
2043+
for (int k = 0; k < masked_latent->ne[2]; k++) {
2044+
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2045+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2046+
}
2047+
// downsampled mask
2048+
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2049+
// control (todo: support this)
2050+
for (int k = 0; k < masked_latent->ne[2]; k++) {
2051+
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
2052+
}
19762053
} else {
19772054
float m = ggml_tensor_get_f32(mask_img, mx, my);
19782055
ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
@@ -1998,12 +2075,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19982075
}
19992076
}
20002077

2001-
if (!sd_ctx->sd->use_tiny_autoencoder) {
2002-
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
2003-
init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
2004-
} else {
2005-
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
2006-
}
20072078
} else {
20082079
LOG_INFO("TXT2IMG");
20092080
if (sd_version_is_inpaint(sd_ctx->sd->version)) {

vae.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ struct AutoEncoderKL : public GGMLRunner {
559559
bool decode_graph,
560560
struct ggml_tensor** output,
561561
struct ggml_context* output_ctx = NULL) {
562+
GGML_ASSERT(!decode_only || decode_graph);
562563
auto get_graph = [&]() -> struct ggml_cgraph* {
563564
return build_graph(z, decode_graph);
564565
};

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /