Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d1d7420

Browse files
committed
support for flux controls
1 parent fb604b7 commit d1d7420

File tree

4 files changed

+55
-32
lines changed

4 files changed

+55
-32
lines changed

‎flux.hpp‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,14 @@ namespace Flux {
10321032
control = patchify(ctx, control, patch_size);
10331033

10341034
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
1035+
} else if (version == VERSION_FLUX_CONTROLS) {
1036+
GGML_ASSERT(c_concat != NULL);
1037+
1038+
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
1039+
1040+
control = patchify(ctx, control, patch_size);
1041+
1042+
img = ggml_concat(ctx, img, control, 0);
10351043
}
10361044

10371045
if (ref_latents.size() > 0) {
@@ -1079,6 +1087,8 @@ namespace Flux {
10791087
flux_params.depth_single_blocks = 0;
10801088
if (version == VERSION_FLUX_FILL) {
10811089
flux_params.in_channels = 384;
1090+
} else if (version == VERSION_FLUX_CONTROLS) {
1091+
flux_params.in_channels = 128;
10821092
} else if (version == VERSION_FLEX_2) {
10831093
flux_params.in_channels = 196;
10841094
}

‎model.cpp‎

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,10 +1685,12 @@ SDVersion ModelLoader::get_sd_version() {
16851685
}
16861686

16871687
if (is_flux) {
1688-
is_inpaint = input_block_weight.ne[0] == 384;
1689-
if (is_inpaint) {
1688+
if (input_block_weight.ne[0] == 384) {
16901689
return VERSION_FLUX_FILL;
16911690
}
1691+
if (input_block_weight.ne[0] == 128) {
1692+
return VERSION_FLUX_CONTROLS;
1693+
}
16921694
if(input_block_weight.ne[0] == 196){
16931695
return VERSION_FLEX_2;
16941696
}

‎model.h‎

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLUX_CONTROLS,
3435
VERSION_FLEX_2,
3536
VERSION_COUNT,
3637
};
3738

3839
static inline bool sd_version_is_flux(SDVersion version) {
39-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
40+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
4041
return true;
4142
}
4243
return false;
@@ -70,15 +71,16 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
7071
return false;
7172
}
7273

73-
static inline bool sd_version_is_inpaint(SDVersion version) {
74-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
74+
75+
static inline bool sd_version_is_dit(SDVersion version) {
76+
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
7577
return true;
7678
}
7779
return false;
7880
}
7981

80-
static inline bool sd_version_is_dit(SDVersion version) {
81-
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
82+
static inline bool sd_version_is_inpaint(SDVersion version) {
83+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
8284
return true;
8385
}
8486
return false;
@@ -88,8 +90,12 @@ static inline bool sd_version_is_edit(SDVersion version) {
8890
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
8991
}
9092

93+
static inline bool sd_version_is_control(SDVersion version) {
94+
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
95+
}
96+
9197
static bool sd_version_use_concat(SDVersion version) {
92-
return sd_version_is_edit(version) || sd_version_is_inpaint(version);
98+
return sd_version_is_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
9399
}
94100

95101
enum PMVersion {

‎stable-diffusion.cpp‎

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ class StableDiffusionGGML {
314314
// TODO: shift_factor
315315
}
316316

317-
if(version == VERSION_FLEX_2){
317+
if (sd_version_is_control(version)) {
318318
// Might need vae encode for control cond
319319
vae_decode_only = false;
320320
}
@@ -840,7 +840,7 @@ class StableDiffusionGGML {
840840
int start_merge_step,
841841
SDCondition id_cond,
842842
std::vector<ggml_tensor*> ref_latents = {},
843-
ggml_tensor* denoise_mask = nullptr) {
843+
ggml_tensor* denoise_mask = nullptr) {
844844
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
845845

846846
// TODO (Pix2Pix): separate image guidance params (right now it's reusing distilled guidance)
@@ -1512,6 +1512,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15121512
int W = width / 8;
15131513
int H = height / 8;
15141514
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
1515+
1516+
struct ggml_tensor* control_latent = NULL;
1517+
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
1518+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1519+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1520+
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1521+
} else {
1522+
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1523+
}
1524+
}
1525+
15151526
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
15161527
int64_t mask_channels = 1;
15171528
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1544,50 +1555,44 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15441555
}
15451556
}
15461557
}
1547-
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1558+
1559+
if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
15481560
bool no_inpaint = concat_latent == NULL;
15491561
if (no_inpaint) {
15501562
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
15511563
}
15521564
// fill in the control image here
1553-
struct ggml_tensor* control_latents = NULL;
1554-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1555-
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1556-
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1557-
} else {
1558-
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1559-
}
1560-
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1561-
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1565+
for (int64_t x = 0; x < control_latent->ne[0]; x++) {
1566+
for (int64_t y = 0; y < control_latent->ne[1]; y++) {
15621567
if (no_inpaint) {
1563-
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1568+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
15641569
// 0x16,1x1,0x16
15651570
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
15661571
}
15671572
}
1568-
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1569-
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1570-
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1573+
for (int64_t c = 0; c < control_latent->ne[2]; c++) {
1574+
float v = ggml_tensor_get_f32(control_latent, x, y, c);
1575+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
15711576
}
15721577
}
15731578
}
1574-
// Disable controlnet
1575-
image_hint = NULL;
15761579
} else if (concat_latent == NULL) {
15771580
concat_latent = empty_latent;
15781581
}
15791582
cond.c_concat = concat_latent;
15801583
uncond.c_concat = empty_latent;
15811584
denoise_mask = NULL;
1582-
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
1585+
} else if (sd_version_is_edit(sd_ctx->sd->version) || sd_version_is_control(sd_ctx->sd->version)) {
15831586
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], init_latent->ne[3]);
15841587
ggml_set_f32(empty_latent, 0);
15851588
uncond.c_concat = empty_latent;
1589+
if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
1590+
concat_latent = control_latent;
1591+
}
15861592
if (concat_latent == NULL) {
15871593
concat_latent = empty_latent;
15881594
}
1589-
cond.c_concat = concat_latent;
1590-
1595+
cond.c_concat = concat_latent;
15911596
}
15921597
for (int b = 0; b < batch_count; b++) {
15931598
int64_t sampling_start = ggml_time_ms();
@@ -1870,7 +1875,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18701875
ggml_tensor* masked_latent = NULL;
18711876
if (!sd_ctx->sd->use_tiny_autoencoder) {
18721877
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
1873-
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1878+
masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
18741879
} else {
18751880
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
18761881
}
@@ -1941,8 +1946,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
19411946
} else {
19421947
concat_latent = init_latent;
19431948
}
1944-
}
1945-
1949+
}
1950+
19461951
{
19471952
// LOG_WARN("Inpainting with a base model is not great");
19481953
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /