Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Support for Flux Controls + Flex.2 #692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
stduhpf wants to merge 6 commits into leejet:master
base: master
Choose a base branch
Loading
from stduhpf:concat-controls
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/cli/main.cpp
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,8 @@ int main(int argc, const char* argv[]) {
input_image_buffer};

sd_image_t* control_image = NULL;
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
if (params.control_image_path.size() > 0) {
printf("load image from '%s'\n", params.control_image_path.c_str());
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
Expand Down
38 changes: 34 additions & 4 deletions flux.hpp
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ namespace Flux {
struct ggml_tensor* pe,
struct ggml_tensor* mod_index_arange = NULL,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<int> skip_layers = {}) {
std::vector<int> skip_layers = {},
SDVersion version = VERSION_FLUX) {
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
Expand All @@ -1007,14 +1008,38 @@ namespace Flux {
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];

if (c_concat != NULL) {
if (version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = process_img(ctx, masked);
mask = process_img(ctx, mask);

img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
} else if (version == VERSION_FLEX_2) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));

masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);

masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);
control = patchify(ctx, control, patch_size);

img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
} else if (version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != NULL);

ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);

control = patchify(ctx, control, patch_size);

img = ggml_concat(ctx, img, control, 0);
}

if (ref_latents.size() > 0) {
Expand Down Expand Up @@ -1055,13 +1080,17 @@ namespace Flux {
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: GGMLRunner(backend), use_mask(use_mask) {
: GGMLRunner(backend), version(version), use_mask(use_mask) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
} else if (version == VERSION_FLUX_CONTROLS) {
flux_params.in_channels = 128;
} else if (version == VERSION_FLEX_2) {
flux_params.in_channels = 196;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
Expand Down Expand Up @@ -1171,7 +1200,8 @@ namespace Flux {
pe,
mod_index_arange,
ref_latents,
skip_layers);
skip_layers,
version);

ggml_build_forward_expand(gf, out);

Expand Down
14 changes: 10 additions & 4 deletions ggml_extend.hpp
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -380,18 +380,24 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,

__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
struct ggml_tensor* output) {
struct ggml_tensor* output,
float masked_value = 0.5f) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
float rescale_mx = mask->ne[0]/output->ne[0];
float rescale_my = mask->ne[1]/output->ne[1];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
int mx = (int)(ix * rescale_mx);
int my = (int)(iy * rescale_my);
float m = ggml_tensor_get_f32(mask, mx, my);
m = round(m); // inpaint models need binary masks
ggml_tensor_set_f32(mask, m, ix, iy);
ggml_tensor_set_f32(mask, m, mx, my);
for (int k = 0; k < channels; k++) {
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
value = (1 - m) * (value - masked_value) + masked_value;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
Expand Down
9 changes: 7 additions & 2 deletions model.cpp
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -1685,10 +1685,15 @@ SDVersion ModelLoader::get_sd_version() {
}

if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if(input_block_weight.ne[0] == 196){
return VERSION_FLEX_2;
}
return VERSION_FLUX;
}

Expand Down
12 changes: 9 additions & 3 deletions model.h
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ enum SDVersion {
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_FILL,
VERSION_FLUX_CONTROLS,
VERSION_FLEX_2,
VERSION_COUNT,
};

static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
return true;
}
return false;
Expand Down Expand Up @@ -70,7 +72,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
}

static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
return true;
}
return false;
Expand All @@ -87,8 +89,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
}

static inline bool sd_version_is_control(SDVersion version) {
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
}

static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
}

enum PMVersion {
Expand Down
Loading
Loading

AltStyle によって変換されたページ (->オリジナル) /