This action will force a synchronization from Gitee 极速下载/Halide. It will overwrite any changes that you have made since you forked the repository, and those changes cannot be recovered!
The synchronization will run in the background, and the page will refresh when it has finished. Please be patient.
#include "SimplifyCorrelatedDifferences.h"

#include <string>
#include <utility>
#include <vector>

#include "CSE.h"
#include "CompilerLogger.h"
#include "ExprUsesVar.h"
#include "IRMatch.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Monotonic.h"
#include "Scope.h"
#include "Simplify.h"
#include "Solve.h"
#include "Substitute.h"

namespace Halide {
namespace Internal {

namespace {

using std::string;
using std::vector;

/** Mutator that rewrites Int(32) subtractions matching one of a fixed
 * list of patterns into a form in which the shared term is partially
 * cancelled. Subtractions of any other type, or that match no
 * pattern, are rebuilt as `a - b` from their mutated operands. */
class PartiallyCancelDifferences : public IRMutator {
    using IRMutator::visit;

    // Symbols used by rewrite rules
    IRMatcher::Wild<0> x;
    IRMatcher::Wild<1> y;
    IRMatcher::Wild<2> z;
    IRMatcher::WildConst<0> c0;
    IRMatcher::WildConst<1> c1;

    Expr visit(const Sub *op) override {
        // Operands are mutated first, so the rules below see
        // already-rewritten children.
        Expr a = mutate(op->a), b = mutate(op->b);

        // Partially cancel terms in correlated differences of
        // various kinds to get tighter bounds. We assume any
        // correlated term has already been pulled leftmost by
        // solve_expression.
        if (op->type == Int(32)) {
            auto rewrite = IRMatcher::rewriter(IRMatcher::sub(a, b), op->type);
            if (
                // Differences of quasi-affine functions
                rewrite((x + y) / c0 - (x + z) / c0, ((x % c0) + y) / c0 - ((x % c0) + z) / c0) ||
                rewrite(x / c0 - (x + z) / c0, 0 - ((x % c0) + z) / c0) ||
                rewrite((x + y) / c0 - x / c0, ((x % c0) + y) / c0) ||

                // truncated cones have a constant upper or lower
                // bound that isn't apparent when expressed in the
                // form in the LHS below
                rewrite(min(x, c0) - max(x, c1), min(min(c0 - x, x - c1), fold(min(0, c0 - c1)))) ||
                rewrite(max(x, c0) - min(x, c1), max(max(c0 - x, x - c1), fold(max(0, c0 - c1)))) ||
                rewrite(min(x, y) - max(x, z), min(min(x, y) - max(x, z), 0)) ||
                rewrite(max(x, y) - min(x, z), max(max(x, y) - min(x, z), 0)) ||

                rewrite(min(x + c0, y) - select(z, min(x, y) + c1, x), select(z, (max(min(y - x, c0), 0) - c1), min(y - x, c0)), c0 > 0) ||
                rewrite(min(y, x + c0) - select(z, min(y, x) + c1, x), select(z, (max(min(y - x, c0), 0) - c1), min(y - x, c0)), c0 > 0) ||

                false) {
                return rewrite.result;
            }
        }
        return a - b;
    }
};

/** Mutator that, inside a For loop, looks for Int(32) binary
 * operations (Sub, Add and the comparisons) whose two operands are
 * both monotonic in the loop variable in a way that makes them
 * correlated, and rewrites them by substituting in enclosing lets,
 * solving for the loop variable, and simplifying. */
class SimplifyCorrelatedDifferences : public IRMutator {
    using IRMutator::visit;

    // Name of the loop currently being analysed. Empty when no loop
    // is being analysed, in which case no cancellation is attempted.
    string loop_var;

    // Bounds on the derivative with respect to loop_var of the loop
    // variable itself and of the enclosing lets that were recorded.
    Scope<ConstantInterval> monotonic;

    // An enclosing let that was recorded while analysing loop_var.
    struct OuterLet {
        string name;
        Expr value;
        // True only if the (mutated) value is a pure Int(32) expression.
        bool may_substitute;
    };
    // Recorded enclosing lets, outermost first.
    vector<OuterLet> lets;

    /** Shared implementation for Let and LetStmt. Walks the whole
     * chain of directly nested lets iteratively, mutates the innermost
     * body, and then rebuilds the chain from the inside out. */
    template<typename LetStmtOrLet, typename StmtOrExpr>
    StmtOrExpr visit_let(const LetStmtOrLet *op) {
        // Visit an entire chain of lets in a single method to conserve stack space.
        struct Frame {
            const LetStmtOrLet *op;
            // Bound only for lets that were added to the monotonic
            // scope (and therefore also to `lets`).
            ScopedBinding<ConstantInterval> binding;
            // Defined only for lets whose value was mutated.
            Expr new_value;
            Frame(const LetStmtOrLet *op, const string &loop_var, Scope<ConstantInterval> &scope)
                : op(op),
                  binding(scope, op->name, derivative_bounds(op->value, loop_var, scope)) {
            }
            explicit Frame(const LetStmtOrLet *op)
                : op(op) {
            }
        };
        std::vector<Frame> frames;
        StmtOrExpr result;

        // Note that we must add *everything* that depends on the loop
        // var to the monotonic scope and the list of lets, even
        // things which we can never substitute in (e.g. impure
        // things). This is for two reasons. First this pass could be
        // used at a time when we still have nested lets under the
        // same name. If we decide not to add an inner let, but do add
        // the outer one, then later references to it will be
        // incorrect. Second, if we don't add something that happens
        // to be non-monotonic, then derivative_bounds finds a variable
        // that references it in a later let, it will think it's a
        // constant, not an unknown.
        do {
            result = op->body;
            if (loop_var.empty()) {
                // Not inside a loop being analysed: just remember the
                // let so that it can be rebuilt with its original value.
                frames.emplace_back(op);
                continue;
            }

            bool pure = is_pure(op->value);
            if (!pure || expr_uses_vars(op->value, monotonic) || monotonic.contains(op->name)) {
                frames.emplace_back(op, loop_var, monotonic);
                Expr new_value = mutate(op->value);
                bool may_substitute_in = new_value.type() == Int(32) && pure;
                lets.emplace_back(OuterLet{op->name, new_value, may_substitute_in});
                frames.back().new_value = std::move(new_value);
            } else {
                // Pure and constant w.r.t the loop var. Doesn't
                // shadow any outer thing already in the monotonic
                // scope.
                frames.emplace_back(op);
            }
        } while ((op = result.template as<LetStmtOrLet>()));

        result = mutate(result);

        // Rebuild the chain innermost-first. Each frame with a bound
        // binding pushed exactly one entry onto `lets`, so pop it here.
        for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
            if (it->new_value.defined()) {
                result = LetStmtOrLet::make(it->op->name, it->new_value, result);
            } else {
                result = LetStmtOrLet::make(it->op->name, it->op->value, result);
            }
            if (it->binding.bound()) {
                lets.pop_back();
            }
        }

        return result;
    }

    Expr visit(const Let *op) override {
        return visit_let<Let, Expr>(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<LetStmt, Stmt>(op);
    }

    /** If no loop is being analysed yet, analyse this loop's contents
     * with respect to this loop's variable, using a fresh monotonic
     * scope and let list, then restore the previous state. In all
     * cases the (possibly rewritten) loop is then visited once more
     * with the state that was in effect on entry. */
    Stmt visit(const For *op) override {
        Stmt s = op;
        // This is unfortunately quadratic in maximum loop nesting depth
        if (loop_var.empty()) {
            // Set the outer state aside by swapping it with empty
            // containers; it is swapped back below.
            decltype(monotonic) tmp_monotonic;
            decltype(lets) tmp_lets;
            tmp_monotonic.swap(monotonic);
            tmp_lets.swap(lets);
            loop_var = op->name;
            {
                // The loop variable is bound to the single-point
                // interval 1 in the monotonic scope.
                ScopedBinding<ConstantInterval> bind(monotonic, loop_var, ConstantInterval::single_point(1));
                s = IRMutator::visit(op);
            }
            loop_var.clear();
            tmp_monotonic.swap(monotonic);
            tmp_lets.swap(lets);
        }
        s = IRMutator::visit(s.as<For>());
        return s;
    }

    /** Rewrite the binary operation `e`, whose operands are `a` and
     * `b`, if the monotonicity of the operands in loop_var indicates a
     * term that can be cancelled.
     *
     * With `correlated` true the rewrite is attempted when both
     * operands move in the same direction (both increasing or both
     * decreasing); with `correlated` false it is attempted when they
     * move in opposite directions. Otherwise `e` is returned
     * unchanged. */
    Expr cancel_correlated_subexpression(Expr e, const Expr &a, const Expr &b, bool correlated) {
        auto ma = is_monotonic(a, loop_var, monotonic);
        auto mb = is_monotonic(b, loop_var, monotonic);

        if ((ma == Monotonic::Increasing && mb == Monotonic::Increasing && correlated) ||
            (ma == Monotonic::Decreasing && mb == Monotonic::Decreasing && correlated) ||
            (ma == Monotonic::Increasing && mb == Monotonic::Decreasing && !correlated) ||
            (ma == Monotonic::Decreasing && mb == Monotonic::Increasing && !correlated)) {

            // Wrap e in the recorded lets, innermost first.
            for (auto it = lets.rbegin(); it != lets.rend(); ++it) {
                if (expr_uses_var(e, it->name) && !it->may_substitute) {
                    // We have to stop here. Can't continue
                    // because there might be an outer let with
                    // the same name that we *can* substitute in,
                    // and then inner uses will get the wrong
                    // value.
                    break;
                }
                e = Let::make(it->name, it->value, e);
            }
            e = common_subexpression_elimination(e);
            e = solve_expression(e, loop_var).result;
            e = PartiallyCancelDifferences().mutate(e);
            e = simplify(e);

            // The extra monotonicity check is only performed when its
            // result would be reported somewhere.
            const bool check_non_monotonic = debug::debug_level() > 0 || get_compiler_logger() != nullptr;
            if (check_non_monotonic &&
                is_monotonic(e, loop_var) == Monotonic::Unknown) {
                // Might be a missed simplification opportunity. Log to help improve the simplifier.
                if (get_compiler_logger()) {
                    get_compiler_logger()->record_non_monotonic_loop_var(loop_var, e);
                }
                debug(1) << "Warning: expression is non-monotonic in loop variable "
                         << loop_var << ": " << e << "\n";
            }
        }
        return e;
    }

    /** Mutate the children of a binary op, then attempt cancellation
     * on the result if it is still a node of type T, its first operand
     * is Int(32), and a loop is being analysed. */
    template<typename T>
    Expr visit_binop(const T *op, bool correlated) {
        Expr e = IRMutator::visit(op);
        op = e.as<T>();
        if (op == nullptr ||
            op->a.type() != Int(32) ||
            loop_var.empty()) {
            return e;
        } else {
            // Bury the logic that doesn't depend on the template
            // parameter in a separate function to save code size and
            // reduce stack frame size in the recursive path.
            return cancel_correlated_subexpression(e, op->a, op->b, correlated);
        }
    }

    // Binary ops where it pays to cancel a correlated term on both
    // sides. E.g. consider the x in:
    //
    // (x*3 + y)*2 - max(x*6, 0)
    //
    // Both sides increase monotonically with x so interval arithmetic
    // will overestimate the bounds. If we subtract x*6 from both
    // sides we get:
    //
    // y*2 - max(0, x*-6)
    //
    // Now only one side depends on x and interval arithmetic becomes
    // exact.
    Expr visit(const Sub *op) override {
        return visit_binop(op, true);
    }

    Expr visit(const LT *op) override {
        return visit_binop(op, true);
    }

    Expr visit(const LE *op) override {
        return visit_binop(op, true);
    }

    Expr visit(const GT *op) override {
        return visit_binop(op, true);
    }

    Expr visit(const GE *op) override {
        return visit_binop(op, true);
    }

    Expr visit(const EQ *op) override {
        return visit_binop(op, true);
    }

    Expr visit(const NE *op) override {
        return visit_binop(op, true);
    }

    // For add you actually want to cancel any anti-correlated term
    // (e.g. x in (x*3 + y)*2 + max(x*-6, 0))
    Expr visit(const Add *op) override {
        return visit_binop(op, false);
    }
};

}  // namespace

/** Run the SimplifyCorrelatedDifferences mutator over a statement. */
Stmt simplify_correlated_differences(const Stmt &stmt) {
    return SimplifyCorrelatedDifferences().mutate(stmt);
}

/** Run only the PartiallyCancelDifferences rewrite rules over an
 * expression. */
Expr bound_correlated_differences(const Expr &expr) {
    return PartiallyCancelDifferences().mutate(expr);
}

}  // namespace Internal
}  // namespace Halide
此处可能存在不合适展示的内容，页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违反国家有关法律法规的内容，可点击提交进行申诉，我们将尽快为您处理。