@@ -95,7 +95,7 @@ class StableDiffusionGGML {
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
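+    // may stay NULL: with Flex.2 the control image can be folded into the
+    // concat latent instead of going through a ControlNet (see generate_image_internal)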
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -297,6 +297,11 @@ class StableDiffusionGGML {
             // TODO: shift_factor
         }
 
+        if (version == VERSION_FLEX_2) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
         bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
 
         if (version == VERSION_SVD) {
@@ -933,7 +938,7 @@ class StableDiffusionGGML {
 
             std::vector<struct ggml_tensor*> controls;
 
-            if (control_hint != NULL) {
+            if (control_hint != NULL && control_net != NULL) {
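+                // a hint without a loaded ControlNet is now possible (e.g. Flex.2,
+                // which consumes the control image via the concat latent), so guard the call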
                 control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
                 controls = control_net->controls;
                 // print_ggml_tensor(controls[12]);
@@ -972,7 +977,7 @@ class StableDiffusionGGML {
             float* negative_data = NULL;
             if (has_unconditioned) {
                 // uncond
-                if (control_hint != NULL) {
+                if (control_hint != NULL && control_net != NULL) {
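+                    // same ControlNet guard as the conditional pass above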
                     control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                     controls = control_net->controls;
                 }
@@ -1721,6 +1726,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
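+            // Flex.2 concat layout: image latent, then 1 mask channel,
+            // then a control-latent block as wide as the image latent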
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
@@ -1734,6 +1741,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                     for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                         ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                        // 0x16,1x1,0x16: zero the latent channels, set the single mask
+                        // channel (index init_latent->ne[2]) to 1, zero the control channels
+                        ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                    }
                 } else {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                     for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -1742,12 +1754,42 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (concat_latent == NULL) {
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
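+            // no ControlNet loaded: inject the VAE-encoded control image
+            // directly into the trailing channels of the concat latent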
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            struct ggml_tensor* control_latents = NULL;
+            if (!sd_ctx->sd->use_tiny_autoencoder) {
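+                // the full VAE encoder returns distribution moments; sample them to get latents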
+                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+                control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+            } else {
+                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            }
+            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                            // 0x16,1x1,0x16: zeros for the latent channels, 1 for the mask channel
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    }
+                }
+            }
+            // disable the ControlNet path: the hint has been consumed above
+            image_hint = NULL;
+        } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
         uncond.c_concat = empty_latent;
         denoise_mask = NULL;
     } else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
         auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
         ggml_set_f32(empty_latent, 0);
@@ -1935,10 +1977,19 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
         sd_mask_to_tensor(sd_img_gen_params->mask_image.data, mask_img);
         sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);
 
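+        // encode the init image up front: the inpaint branch below now needs
+        // init_latent to size the Flex.2 mask/control channels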
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+            init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+        } else {
+            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+        }
+
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {
             int64_t mask_channels = 1;
             if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
                 mask_channels = 8 * 8;  // flatten the whole mask
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                mask_channels = 1 + init_latent->ne[2];
             }
             ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
             sd_apply_mask(init_img, mask_img, masked_img);
@@ -1973,6 +2024,32 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                             ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                         }
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
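+                    // channels: [masked-image latent | mask | zeroed control block]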
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    // masked image
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                    }
                 } else {
                     float m = ggml_tensor_get_f32(mask_img, mx, my);
                     ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
@@ -1998,12 +2075,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
             }
         }
 
-        if (!sd_ctx->sd->use_tiny_autoencoder) {
-            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-            init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
-        } else {
-            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-        }
     } else {
         LOG_INFO("TXT2IMG");
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {