FFmpeg: libavfilter/dnn/dnn_backend_torch.cpp Source File
Go to the documentation of this file. 1 /*
2 * Copyright (c) 2024
3 *
4 * This file is part of FFmpeg.
5 *
6 * FFmpeg is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * FFmpeg is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with FFmpeg; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21 /**
22 * @file
23 * DNN Torch backend implementation.
24 */
25
26 #include <torch/torch.h>
27 #include <torch/script.h>
28
29 extern "C" {
30 #include "../internal.h"
36 }
37
42
47
56
61
67
68
69 #define OFFSET(x) offsetof(THContext, x)
70 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM
75 };
76
78
80 {
84 if (!lltask) {
87 }
90 lltask->task = task;
95 }
96 return 0;
97 }
98
100 {
101 if (!request)
102 return;
106 }
110 }
111 return;
112 }
113
115 {
118 return;
119 }
126 }
127
129 {
131 if (!model || !*model)
132 return;
133
138 }
140
144 }
146
152 }
158 }
159
161 {
169 return 0;
170 }
171
173 {
175 }
176
178 {
184 int ret, width_idx, height_idx, channel_idx;
185
187 if (!lltask) {
189 goto err;
190 }
194
197 goto err;
198 }
205 input.dims[channel_idx] *
sizeof(
float));
209 infer_request->
output =
new torch::Tensor();
210
217 } else {
219 }
220 }
221 break;
222 default:
224 break;
225 }
227 {1, input.dims[channel_idx], input.dims[height_idx], input.dims[width_idx]},
229 return 0;
230
231 err:
234 }
235
237 {
244 std::vector<torch::jit::IValue>
inputs;
245 torch::NoGradGuard no_grad;
246
247 if (!request) {
250 }
256
257 if (
ctx->options.optimize)
258 torch::jit::setGraphExecutorOptimize(true);
259 else
260 torch::jit::setGraphExecutorOptimize(false);
261
265 }
267
269
270 return 0;
271 }
272
281
286 if (
sizes.size() == 4) {
287 // 4 dimensions: [batch_size, channel, height, width]
288 // this format of data is normally used for video frame SR
293 } else {
295 goto err;
296 }
297
305 } else {
307 }
308 } else {
311 }
312 break;
313 default:
315 goto err;
316 }
319 err:
321
324 av_log(&th_model->
ctx,
AV_LOG_ERROR,
"Unable to push back request_queue when failed to start inference.\n");
325 }
326 }
327
329 {
334
337 return 0;
338 }
339
341 if (lltask ==
NULL) {
344 goto err;
345 }
348
351 goto err;
352 }
355 } else {
358 goto err;
359 }
362 }
363
364 err:
368 }
370 }
371
372 static int get_output_th(
void *model,
const char *input_name,
int input_width,
int input_height,
373 const char *output_name, int *output_width, int *output_height)
374 {
382 .output_names = &output_name,
383 .nb_output = 1,
386 };
389 goto err;
390 }
391
395 goto err;
396 }
397
399 if (!request) {
402 goto err;
403 }
404
408
409 err:
413 }
414
416 {
418 if (!request) {
420 }
423 return request;
424 }
425
427 {
432
434 if (!model) {
436 }
437
439 if (!th_model) {
442 }
443 th_model->
model = model;
444 model->
model = th_model;
447 //parse options
452 }
453
454 c10::Device device = c10::Device(
ctx->options.device_name);
455 if (!device.is_cpu()) {
458 }
459
460 try {
461 th_model->
jit_model =
new torch::jit::Module;
462 (*th_model->
jit_model) = torch::jit::load(model_filename);
463 } catch (const c10::Error& e) {
466 }
467
471 }
472
474 if (!item) {
476 }
482 }
486
489 }
491
495 }
496
500 }
501
507 return model;
508
510 if (item) {
513 }
516 }
517
519 {
525
530 }
531
533 if (!task) {
536 }
537
543 }
544
550 }
551
556 }
557
559 if (!request) {
562 }
563
565 }
566
568 {
571 }
572
574 {
577
579 // no pending task need to flush
580 return 0;
581
583 if (!request) {
586 }
587
589 }
590
597 };
LastLevelTaskItem * lltask
THInferRequest * infer_request
void av_opt_set_defaults(void *s)
Set the values of all AVOption fields to their default values.
Filter: the word "frame" indicates either a video frame or a group of audio samples, as stored in an AVFrame structure. Format: for each input and each output, the list of supported formats. For video that means pixel format. For audio that means channel layout, sample format and sample rate. The lists are not just lists, they are references to shared objects. When the negotiation mechanism computes the intersection of the formats supported at each end of a link, all references to both lists are replaced with a reference to the intersection. And when a single format is eventually chosen for a link amongst the remaining list, all references to the list are updated. That means that if a filter requires that its input and output have the same format amongst a supported list, all it has to do is use a reference to the same list of formats. query_formats can leave some formats unset and return AVERROR(EAGAIN) to cause the negotiation mechanism to try again later. That can be used by filters with complex requirements to use the format negotiated on one link to set the formats supported on another. Frame references ownership and permissions.
void * ff_safe_queue_pop_front(SafeQueue *sq)
Remove and free first element from the queue in SafeQueue.
static void deleter(void *arg)
Common Async Execution Mechanism for the DNN Backends.
filter_frame: For filters that do not use the activate() callback, this method is called when a frame is pushed to the filter's input. It can be called at any time except in a reentrant way. If the input frame is enough to produce output, then the filter should push the output frames on the output link immediately.
void * ff_queue_pop_front(Queue *q)
Remove and free first element from the Queue.
int ff_check_exec_params(void *ctx, DNNBackendType backend, DNNFunctionType func_type, DNNExecBaseParams *exec_params)
size_t ff_queue_size(Queue *q)
Return the length of the Queue.
#define DNN_GENERIC_ERROR
void av_frame_free(AVFrame **frame)
Free the frame and any dynamically allocated objects in it, e.g. extended_data.
const DNNModule ff_dnn_backend_torch
This structure describes decoded (raw) audio or video data.
Double-ended queue with mutex locks ensuring data consistency while multithreading.
static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params)
int av_opt_set_from_string(void *ctx, const char *opts, const char *const *shorthand, const char *key_val_sep, const char *pairs_sep)
Parse the key-value pairs list in opts.
DNNModel *(* load_model)(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
FramePrePostProc frame_pre_proc
void(* callback)(void *args)
Completion Callback for the backend.
static DNNModel * dnn_load_model_th(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
static int get_input_th(void *model, DNNData *input, const char *input_name)
AVFilterContext * filter_ctx
Queue * ff_queue_create(void)
Create a Queue instance.
static int dnn_get_width_idx_by_layout(DNNLayout layout)
void av_opt_free(void *obj)
Free all allocated objects in obj.
static FilteringContext * filter_ctx
Linear double-ended data structure.
int ff_queue_push_back(Queue *q, void *v)
Add data to the tail of the queue.
torch::jit::Module * jit_model
#define AV_LOG_ERROR
Something went wrong and cannot losslessly be recovered.
static void destroy_request_item(THRequestItem **arg)
static THInferRequest * th_create_inference_request(void)
void ff_queue_destroy(Queue *q)
Destroy the Queue instance.
int ff_dnn_fill_gettingoutput_task(TaskItem *task, DNNExecBaseParams *exec_params, void *backend_model, int input_height, int input_width, void *ctx)
Allocate input and output frames and fill the Task with execution parameters.
size_t ff_safe_queue_size(SafeQueue *sq)
Return the length of the SafeQueue.
int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
Describe the class of an AVClass context structure.
DNNAsyncExecModule exec_module
static const int sizes[][2]
SafeQueue * ff_safe_queue_create(void)
Create and initialize a SafeQueue instance.
FramePrePostProc frame_post_proc
int ff_dnn_async_module_cleanup(DNNAsyncExecModule *async_module)
Join the Async Execution thread and set module pointers to NULL.
static void infer_completion_callback(void *args)
static int extract_lltask_from_task(TaskItem *task, Queue *lltask_queue)
These buffered frames must be flushed immediately if a new input produces new output. If the input frame is not enough to produce output, the filter must not call request_frame to get more. It must just process the frame or queue it. The task of requesting more frames is left to the filter's request_frame method or the application. If a filter has several inputs, the filter must be ready for frames arriving randomly on any input.
const OptionDef options[]
DNNFunctionType func_type
void avpriv_report_missing_feature(void *avc, const char *msg,...) av_printf_format(2
Log a generic warning message about a missing feature.
void ff_safe_queue_destroy(SafeQueue *sq)
Destroy the SafeQueue instance.
static DNNAsyncStatusType dnn_get_result_th(const DNNModel *model, AVFrame **in, AVFrame **out)
int ff_dnn_fill_task(TaskItem *task, DNNExecBaseParams *exec_params, void *backend_model, int async, int do_ioproc)
Fill the Task for Backend Execution.
and forward the test the status of outputs and forward it to the corresponding return FFERROR_NOT_READY If the filters stores internally one or a few frame for some input
int ff_safe_queue_push_back(SafeQueue *sq, void *v)
Add data to the tail of queue in the SafeQueue after locking mutex.
static int th_start_inference(void *args)
torch::Tensor * input_tensor
int(* start_inference)(void *request)
Synchronous inference function for the backend with corresponding request item as the argument.
void * args
Argument for the execution functions.
void * av_mallocz(size_t size)
Allocate a memory block with alignment suitable for all memory accesses (including vectors if available on the CPU) and zero all the bytes of the block.
AVFILTER_DEFINE_CLASS(dnn_th)
static const AVFilterPad outputs[]
static int get_output_th(void *model, const char *input_name, int input_width, int input_height, const char *output_name, int *output_width, int *output_height)
int(* get_input)(void *model, DNNData *input, const char *input_name)
static const AVOption dnn_th_options[]
static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
DNNAsyncStatusType ff_dnn_get_result_common(Queue *task_queue, AVFrame **in, AVFrame **out)
Extract input and output frame from the Task Queue after asynchronous inference.
void * ff_queue_peek_front(Queue *q)
Return a pointer to the data at the head of the queue.
static int dnn_get_height_idx_by_layout(DNNLayout layout)
static int dnn_flush_th(const DNNModel *model)
static int dnn_get_channel_idx_by_layout(DNNLayout layout)
static void dnn_free_model_th(DNNModel **model)
int(* get_output)(void *model, const char *input_name, int input_width, int input_height, const char *output_name, int *output_width, int *output_height)
static int fill_model_input_th(THModel *th_model, THRequestItem *request)
SafeQueue * request_queue
int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
static void th_free_request(THInferRequest *request)
Generated on Thu Sep 26 2024 23:15:33 for FFmpeg by
doxygen
1.8.17