Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c587a43

Browse files
stduhpf and leejet authored
feat: support incrementing ref image index (omni-kontext) (#755)
* kontext: support ref images indices
* lora: support x_embedder
* update help message
* Support for negative indices
* support for OmniControl (offsets at index 0)
* c++11 compat
* add --increase-ref-index option
* simplify the logic and fix some issues
* update README.md
* remove unused variable

---------

Co-authored-by: leejet <leejet714@gmail.com>
1 parent f8fe4e7 commit c587a43

File tree

8 files changed

+48
-12
lines changed

8 files changed

+48
-12
lines changed

‎README.md‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ arguments:
319319
-i, --end-img [IMAGE] path to the end image, required by flf2v
320320
--control-image [IMAGE] path to image condition, control net
321321
-r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times)
322+
--increase-ref-index automatically increase the indices of reference images based on the order they are listed (starting with 1).
322323
-o, --output OUTPUT path to write result image to (default: ./output.png)
323324
-p, --prompt [PROMPT] the prompt to render
324325
-n, --negative-prompt PROMPT the negative prompt (default: "")

‎diffusion_model.hpp‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct DiffusionModel {
1616
struct ggml_tensor* y,
1717
struct ggml_tensor* guidance,
1818
std::vector<ggml_tensor*> ref_latents = {},
19+
bool increase_ref_index = false,
1920
int num_video_frames = -1,
2021
std::vector<struct ggml_tensor*> controls = {},
2122
float control_strength = 0.f,
@@ -77,6 +78,7 @@ struct UNetModel : public DiffusionModel {
7778
struct ggml_tensor* y,
7879
struct ggml_tensor* guidance,
7980
std::vector<ggml_tensor*> ref_latents = {},
81+
bool increase_ref_index = false,
8082
int num_video_frames = -1,
8183
std::vector<struct ggml_tensor*> controls = {},
8284
float control_strength = 0.f,
@@ -133,6 +135,7 @@ struct MMDiTModel : public DiffusionModel {
133135
struct ggml_tensor* y,
134136
struct ggml_tensor* guidance,
135137
std::vector<ggml_tensor*> ref_latents = {},
138+
bool increase_ref_index = false,
136139
int num_video_frames = -1,
137140
std::vector<struct ggml_tensor*> controls = {},
138141
float control_strength = 0.f,
@@ -191,13 +194,14 @@ struct FluxModel : public DiffusionModel {
191194
struct ggml_tensor* y,
192195
struct ggml_tensor* guidance,
193196
std::vector<ggml_tensor*> ref_latents = {},
197+
bool increase_ref_index = false,
194198
int num_video_frames = -1,
195199
std::vector<struct ggml_tensor*> controls = {},
196200
float control_strength = 0.f,
197201
struct ggml_tensor** output = NULL,
198202
struct ggml_context* output_ctx = NULL,
199203
std::vector<int> skip_layers = std::vector<int>()) {
200-
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers);
204+
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, output, output_ctx, skip_layers);
201205
}
202206
};
203207

@@ -250,6 +254,7 @@ struct WanModel : public DiffusionModel {
250254
struct ggml_tensor* y,
251255
struct ggml_tensor* guidance,
252256
std::vector<ggml_tensor*> ref_latents = {},
257+
bool increase_ref_index = false,
253258
int num_video_frames = -1,
254259
std::vector<struct ggml_tensor*> controls = {},
255260
float control_strength = 0.f,

‎examples/cli/main.cpp‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct SDParams {
7474
std::string mask_image_path;
7575
std::string control_image_path;
7676
std::vector<std::string> ref_image_paths;
77+
bool increase_ref_index = false;
7778

7879
std::string prompt;
7980
std::string negative_prompt;
@@ -156,6 +157,7 @@ void print_params(SDParams params) {
156157
for (auto& path : params.ref_image_paths) {
157158
printf(" %s\n", path.c_str());
158159
};
160+
printf(" increase_ref_index: %s\n", params.increase_ref_index ? "true" : "false");
159161
printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false");
160162
printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false");
161163
printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false");
@@ -222,6 +224,7 @@ void print_usage(int argc, const char* argv[]) {
222224
printf(" -i, --end-img [IMAGE] path to the end image, required by flf2v\n");
223225
printf(" --control-image [IMAGE] path to image condition, control net\n");
224226
printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
227+
printf(" --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).\n");
225228
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
226229
printf(" -p, --prompt [PROMPT] the prompt to render\n");
227230
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
@@ -536,6 +539,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
536539
{"", "--color", "", true, &params.color},
537540
{"", "--chroma-disable-dit-mask", "", false, &params.chroma_use_dit_mask},
538541
{"", "--chroma-enable-t5-mask", "", true, &params.chroma_use_t5_mask},
542+
{"", "--increase-ref-index", "", true, &params.increase_ref_index},
539543
};
540544

541545
auto on_mode_arg = [&](int argc, const char** argv, int index) {
@@ -1207,6 +1211,7 @@ int main(int argc, const char* argv[]) {
12071211
init_image,
12081212
ref_images.data(),
12091213
(int)ref_images.size(),
1214+
params.increase_ref_index,
12101215
mask_image,
12111216
params.width,
12121217
params.height,

‎flux.hpp‎

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,7 @@ namespace Flux {
960960
struct ggml_tensor* y,
961961
struct ggml_tensor* guidance,
962962
std::vector<ggml_tensor*> ref_latents = {},
963+
bool increase_ref_index = false,
963964
std::vector<int> skip_layers = {}) {
964965
GGML_ASSERT(x->ne[3] == 1);
965966
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
@@ -999,6 +1000,7 @@ namespace Flux {
9991000
x->ne[3],
10001001
context->ne[1],
10011002
ref_latents,
1003+
increase_ref_index,
10021004
flux_params.theta,
10031005
flux_params.axes_dim);
10041006
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
@@ -1035,6 +1037,7 @@ namespace Flux {
10351037
struct ggml_tensor* y,
10361038
struct ggml_tensor* guidance,
10371039
std::vector<ggml_tensor*> ref_latents = {},
1040+
bool increase_ref_index = false,
10381041
struct ggml_tensor** output = NULL,
10391042
struct ggml_context* output_ctx = NULL,
10401043
std::vector<int> skip_layers = std::vector<int>()) {
@@ -1044,7 +1047,7 @@ namespace Flux {
10441047
// y: [N, adm_in_channels] or [1, adm_in_channels]
10451048
// guidance: [N, ]
10461049
auto get_graph = [&]() -> struct ggml_cgraph* {
1047-
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, skip_layers);
1050+
return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, skip_layers);
10481051
};
10491052

10501053
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -1084,7 +1087,7 @@ namespace Flux {
10841087
struct ggml_tensor* out = NULL;
10851088

10861089
int t0 = ggml_time_ms();
1087-
compute(8, x, timesteps, context, NULL, y, guidance, {}, &out, work_ctx);
1090+
compute(8, x, timesteps, context, NULL, y, guidance, {}, false, &out, work_ctx);
10881091
int t1 = ggml_time_ms();
10891092

10901093
print_ggml_tensor(out);

‎lora.hpp‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ struct LoraModel : public GGMLRunner {
5858
{"x_block.attn.proj", "attn.to_out.0"},
5959
{"x_block.attn2.proj", "attn2.to_out.0"},
6060
// flux
61+
{"img_in", "x_embedder"},
6162
// singlestream
6263
{"linear2", "proj_out"},
6364
{"modulation.lin", "norm.linear"},

‎rope.hpp‎

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,25 +156,33 @@ struct Rope {
156156
int patch_size,
157157
int bs,
158158
int context_len,
159-
std::vector<ggml_tensor*> ref_latents) {
159+
std::vector<ggml_tensor*> ref_latents,
160+
bool increase_ref_index) {
160161
auto txt_ids = gen_txt_ids(bs, context_len);
161162
auto img_ids = gen_img_ids(h, w, patch_size, bs);
162163

163164
auto ids = concat_ids(txt_ids, img_ids, bs);
164165
uint64_t curr_h_offset = 0;
165166
uint64_t curr_w_offset = 0;
167+
int index = 1;
166168
for (ggml_tensor* ref : ref_latents) {
167169
uint64_t h_offset = 0;
168170
uint64_t w_offset = 0;
169-
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
170-
w_offset = curr_w_offset;
171-
} else {
172-
h_offset = curr_h_offset;
171+
if (!increase_ref_index) {
172+
if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) {
173+
w_offset = curr_w_offset;
174+
} else {
175+
h_offset = curr_h_offset;
176+
}
173177
}
174178

175-
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset);
179+
auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, index, h_offset, w_offset);
176180
ids = concat_ids(ids, ref_ids, bs);
177181

182+
if (increase_ref_index) {
183+
index++;
184+
}
185+
178186
curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset);
179187
curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset);
180188
}
@@ -188,9 +196,10 @@ struct Rope {
188196
int bs,
189197
int context_len,
190198
std::vector<ggml_tensor*> ref_latents,
199+
bool increase_ref_index,
191200
int theta,
192201
const std::vector<int>& axes_dim) {
193-
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents);
202+
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
194203
return embed_nd(ids, bs, theta, axes_dim);
195204
}
196205

‎stable-diffusion.cpp‎

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ class StableDiffusionGGML {
775775

776776
int64_t t0 = ggml_time_ms();
777777
struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
778-
diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, -1, {}, 0.f, &out);
778+
diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, false, -1, {}, 0.f, &out);
779779
diffusion_model->free_compute_buffer();
780780

781781
double result = 0.f;
@@ -1032,6 +1032,7 @@ class StableDiffusionGGML {
10321032
int start_merge_step,
10331033
SDCondition id_cond,
10341034
std::vector<ggml_tensor*> ref_latents = {},
1035+
bool increase_ref_index = false,
10351036
ggml_tensor* denoise_mask = nullptr) {
10361037
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
10371038

@@ -1126,6 +1127,7 @@ class StableDiffusionGGML {
11261127
cond.c_vector,
11271128
guidance_tensor,
11281129
ref_latents,
1130+
increase_ref_index,
11291131
-1,
11301132
controls,
11311133
control_strength,
@@ -1139,6 +1141,7 @@ class StableDiffusionGGML {
11391141
id_cond.c_vector,
11401142
guidance_tensor,
11411143
ref_latents,
1144+
increase_ref_index,
11421145
-1,
11431146
controls,
11441147
control_strength,
@@ -1160,6 +1163,7 @@ class StableDiffusionGGML {
11601163
uncond.c_vector,
11611164
guidance_tensor,
11621165
ref_latents,
1166+
increase_ref_index,
11631167
-1,
11641168
controls,
11651169
control_strength,
@@ -1177,6 +1181,7 @@ class StableDiffusionGGML {
11771181
img_cond.c_vector,
11781182
guidance_tensor,
11791183
ref_latents,
1184+
increase_ref_index,
11801185
-1,
11811186
controls,
11821187
control_strength,
@@ -1198,6 +1203,7 @@ class StableDiffusionGGML {
11981203
cond.c_vector,
11991204
guidance_tensor,
12001205
ref_latents,
1206+
increase_ref_index,
12011207
-1,
12021208
controls,
12031209
control_strength,
@@ -1710,6 +1716,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
17101716
"\n"
17111717
"batch_count: %d\n"
17121718
"ref_images_count: %d\n"
1719+
"increase_ref_index: %s\n"
17131720
"control_strength: %.2f\n"
17141721
"style_strength: %.2f\n"
17151722
"normalize_input: %s\n"
@@ -1724,6 +1731,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
17241731
sd_img_gen_params->seed,
17251732
sd_img_gen_params->batch_count,
17261733
sd_img_gen_params->ref_images_count,
1734+
BOOL_STR(sd_img_gen_params->increase_ref_index),
17271735
sd_img_gen_params->control_strength,
17281736
sd_img_gen_params->style_strength,
17291737
BOOL_STR(sd_img_gen_params->normalize_input),
@@ -1797,6 +1805,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17971805
bool normalize_input,
17981806
std::string input_id_images_path,
17991807
std::vector<ggml_tensor*> ref_latents,
1808+
bool increase_ref_index,
18001809
ggml_tensor* concat_latent = NULL,
18011810
ggml_tensor* denoise_mask = NULL) {
18021811
if (seed < 0) {
@@ -2054,6 +2063,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
20542063
start_merge_step,
20552064
id_cond,
20562065
ref_latents,
2066+
increase_ref_index,
20572067
denoise_mask);
20582068
// print_ggml_tensor(x_0);
20592069
int64_t sampling_end = ggml_time_ms();
@@ -2304,7 +2314,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23042314
LOG_INFO("EDIT mode");
23052315
}
23062316

2307-
std::vector<structggml_tensor*> ref_latents;
2317+
std::vector<ggml_tensor*> ref_latents;
23082318
for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) {
23092319
ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
23102320
GGML_TYPE_F32,
@@ -2359,6 +2369,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23592369
sd_img_gen_params->normalize_input,
23602370
sd_img_gen_params->input_id_images_path,
23612371
ref_latents,
2372+
sd_img_gen_params->increase_ref_index,
23622373
concat_latent,
23632374
denoise_mask);
23642375

‎stable-diffusion.h‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ typedef struct {
182182
sd_image_t init_image;
183183
sd_image_t* ref_images;
184184
int ref_images_count;
185+
bool increase_ref_index;
185186
sd_image_t mask_image;
186187
int width;
187188
int height;

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /