Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8d5d16a

Browse files
committed
support for flux controls
1 parent c457adf commit 8d5d16a

File tree

5 files changed

+72
-59
lines changed

5 files changed

+72
-59
lines changed

‎examples/cli/main.cpp‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,8 @@ int main(int argc, const char* argv[]) {
905905
input_image_buffer};
906906

907907
sd_image_t* control_image = NULL;
908-
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
908+
if (params.control_image_path.size() > 0) {
909+
printf("load image from '%s'\n", params.control_image_path.c_str());
909910
int c = 0;
910911
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
911912
if (control_image_buffer == NULL) {

‎flux.hpp‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,14 @@ namespace Flux {
10321032
control = patchify(ctx, control, patch_size);
10331033

10341034
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
1035+
} else if (version == VERSION_FLUX_CONTROLS) {
1036+
GGML_ASSERT(c_concat != NULL);
1037+
1038+
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
1039+
1040+
control = patchify(ctx, control, patch_size);
1041+
1042+
img = ggml_concat(ctx, img, control, 0);
10351043
}
10361044

10371045
if (ref_latents.size() > 0) {
@@ -1079,6 +1087,8 @@ namespace Flux {
10791087
flux_params.depth_single_blocks = 0;
10801088
if (version == VERSION_FLUX_FILL) {
10811089
flux_params.in_channels = 384;
1090+
} else if (version == VERSION_FLUX_CONTROLS) {
1091+
flux_params.in_channels = 128;
10821092
} else if (version == VERSION_FLEX_2) {
10831093
flux_params.in_channels = 196;
10841094
}

‎model.cpp‎

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,10 +1685,12 @@ SDVersion ModelLoader::get_sd_version() {
16851685
}
16861686

16871687
if (is_flux) {
1688-
is_inpaint = input_block_weight.ne[0] == 384;
1689-
if (is_inpaint) {
1688+
if (input_block_weight.ne[0] == 384) {
16901689
return VERSION_FLUX_FILL;
16911690
}
1691+
if (input_block_weight.ne[0] == 128) {
1692+
return VERSION_FLUX_CONTROLS;
1693+
}
16921694
if(input_block_weight.ne[0] == 196){
16931695
return VERSION_FLEX_2;
16941696
}

‎model.h‎

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLUX_CONTROLS,
3435
VERSION_FLEX_2,
3536
VERSION_COUNT,
3637
};
3738

3839
static inline bool sd_version_is_flux(SDVersion version) {
39-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
40+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
4041
return true;
4142
}
4243
return false;
@@ -88,8 +89,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
8889
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
8990
}
9091

92+
static inline bool sd_version_is_control(SDVersion version) {
93+
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
94+
}
95+
9196
static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
92-
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
97+
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
9398
}
9499

95100
enum PMVersion {

‎stable-diffusion.cpp‎

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ class StableDiffusionGGML {
297297
// TODO: shift_factor
298298
}
299299

300-
if(version == VERSION_FLEX_2){
300+
if (sd_version_is_control(version)) {
301301
// Might need vae encode for control cond
302302
vae_decode_only = false;
303303
}
@@ -1722,6 +1722,17 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17221722
int W = width / 8;
17231723
int H = height / 8;
17241724
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
1725+
1726+
struct ggml_tensor* control_latent = NULL;
1727+
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
1728+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1729+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1730+
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1731+
} else {
1732+
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1733+
}
1734+
}
1735+
17251736
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
17261737
int64_t mask_channels = 1;
17271738
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1754,50 +1765,53 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17541765
}
17551766
}
17561767
}
1757-
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1768+
1769+
if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
17581770
bool no_inpaint = concat_latent == NULL;
17591771
if (no_inpaint) {
17601772
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
17611773
}
17621774
// fill in the control image here
1763-
struct ggml_tensor* control_latents = NULL;
1764-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1765-
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1766-
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1767-
} else {
1768-
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1769-
}
1770-
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1771-
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1775+
for (int64_t x = 0; x < control_latent->ne[0]; x++) {
1776+
for (int64_t y = 0; y < control_latent->ne[1]; y++) {
17721777
if (no_inpaint) {
1773-
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1778+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
17741779
// 0x16,1x1,0x16
17751780
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
17761781
}
17771782
}
1778-
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1779-
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1780-
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1783+
for (int64_t c = 0; c < control_latent->ne[2]; c++) {
1784+
float v = ggml_tensor_get_f32(control_latent, x, y, c);
1785+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
17811786
}
17821787
}
17831788
}
1784-
// Disable controlnet
1785-
image_hint = NULL;
17861789
} else if (concat_latent == NULL) {
17871790
concat_latent = empty_latent;
17881791
}
17891792
cond.c_concat = concat_latent;
17901793
uncond.c_concat = empty_latent;
17911794
denoise_mask = NULL;
1792-
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
17931795
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
17941796
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
17951797
ggml_set_f32(empty_latent, 0);
17961798
uncond.c_concat = empty_latent;
1797-
if (concat_latent == NULL) {
1798-
concat_latent = empty_latent;
1799+
cond.c_concat = ref_latents[0];
1800+
if (cond.c_concat == NULL) {
1801+
cond.c_concat = empty_latent;
1802+
}
1803+
} else if (sd_version_is_control(sd_ctx->sd->version)) {
1804+
LOG_DEBUG("HERE");
1805+
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
1806+
ggml_set_f32(empty_latent, 0);
1807+
uncond.c_concat = empty_latent;
1808+
if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
1809+
cond.c_concat = control_latent;
17991810
}
1800-
cond.c_concat = ref_latents[0];
1811+
if (cond.c_concat == NULL) {
1812+
cond.c_concat = empty_latent;
1813+
}
1814+
LOG_DEBUG("HERE");
18011815
}
18021816
SDCondition img_cond;
18031817
if (uncond.c_crossattn != NULL &&
@@ -1956,6 +1970,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19561970
size_t t0 = ggml_time_ms();
19571971

19581972
ggml_tensor* init_latent = NULL;
1973+
ggml_tensor* init_moments = NULL;
19591974
ggml_tensor* concat_latent = NULL;
19601975
ggml_tensor* denoise_mask = NULL;
19611976
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps);
@@ -1978,8 +1993,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19781993
sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);
19791994

19801995
if (!sd_ctx->sd->use_tiny_autoencoder) {
1981-
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1982-
init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1996+
init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1997+
init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
19831998
} else {
19841999
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
19852000
}
@@ -1988,8 +2003,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19882003
int64_t mask_channels = 1;
19892004
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
19902005
mask_channels = 8 * 8; // flatten the whole mask
1991-
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1992-
mask_channels = 1 + init_latent->ne[2];
2006+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2007+
mask_channels = 1 + init_latent->ne[2];
19932008
}
19942009
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
19952010
sd_apply_mask(init_img, mask_img, masked_img);
@@ -2024,38 +2039,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20242039
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
20252040
}
20262041
}
2027-
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2028-
float m = ggml_tensor_get_f32(mask_img, mx, my);
2029-
// masked image
2030-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2031-
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2032-
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2033-
}
2034-
// downsampled mask
2035-
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2036-
// control (todo: support this)
2037-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2038-
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
2039-
}
2040-
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2041-
float m = ggml_tensor_get_f32(mask_img, mx, my);
2042-
// masked image
2043-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2044-
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2045-
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2046-
}
2047-
// downsampled mask
2048-
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2049-
// control (todo: support this)
2050-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2051-
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
2052-
}
2053-
} else {
2042+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
20542043
float m = ggml_tensor_get_f32(mask_img, mx, my);
2055-
ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
2044+
// masked image
20562045
for (int k = 0; k < masked_latent->ne[2]; k++) {
20572046
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2058-
ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
2047+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2048+
}
2049+
// downsampled mask
2050+
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2051+
// control (todo: support this)
2052+
for (int k = 0; k < masked_latent->ne[2]; k++) {
2053+
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
20592054
}
20602055
}
20612056
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /