同步操作将从 Gitee 极速下载/Halide 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
#include "AssociativeOpsTable.h"
#include "IRPrinter.h"

#include <map>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

namespace Halide {
namespace Internal {

using std::map;
using std::string;
using std::vector;

namespace {

/** Kind of IR node found at the root of the first expression handed to
 * get_ops_table(). Used as one component of the pattern-table key. */
enum class RootExpr {
    Add = 0,
    Mul = 1,
    Max = 2,
    Min = 3,
    Sub = 4,
    Select = 5,
    And = 6,
    Or = 7,
    Cast = 8,
    Unknown = 9,  // Not supported IR type
};

/** Scalar value types that can appear in a pattern-table key. */
enum class ValType {
    UInt1 = 0,
    UInt8 = 1,
    UInt16 = 2,
    UInt32 = 3,
    UInt64 = 4,
    Int8 = 5,
    Int16 = 6,
    Int32 = 7,
    Int64 = 8,
    Float16 = 9,
    Float32 = 10,
    Float64 = 11,
    All = 12,  // General type (including all previous types)
};

/** Map a scalar, non-handle Halide type onto the matching ValType.
 * Asserts if the type is a vector or a handle, or if its bit width is not
 * one of the widths enumerated in ValType for its type class. */
ValType convert_halide_type_to_val_type(const Type &halide_t) {
    internal_assert(halide_t.is_scalar() && !halide_t.is_handle());

    if (halide_t.is_uint()) {
        switch (halide_t.bits()) {
        case 1:  // Bool
            return ValType::UInt1;
        case 8:
            return ValType::UInt8;
        case 16:
            return ValType::UInt16;
        case 32:
            return ValType::UInt32;
        default:
            internal_assert(halide_t.bits() == 64);
            return ValType::UInt64;
        }
    }

    if (halide_t.is_int()) {
        switch (halide_t.bits()) {
        case 8:
            return ValType::Int8;
        case 16:
            return ValType::Int16;
        case 32:
            return ValType::Int32;
        default:
            internal_assert(halide_t.bits() == 64);
            return ValType::Int64;
        }
    }

    internal_assert(halide_t.is_float());
    switch (halide_t.bits()) {
    case 16:
        return ValType::Float16;
    case 32:
        return ValType::Float32;
    default:
        internal_assert(halide_t.bits() == 64);
        return ValType::Float64;
    }
}

/** Element-wise convert_halide_type_to_val_type() over a list of types. */
vector<ValType> convert_halide_types_to_val_types(const vector<Type> &halide_types) {
    vector<ValType> val_types;
    val_types.reserve(halide_types.size());
    for (const Type &t : halide_types) {
        val_types.push_back(convert_halide_type_to_val_type(t));
    }
    return val_types;
}

/** Key for both the populate-function map and the pattern-table cache:
 * the value types, the root IR node kind, and the tuple dimension. */
struct TableKey {
    vector<ValType> types;
    RootExpr root;
    size_t dim;

    /** Key holding a single value type (used with ValType::All for the
     * general tables, regardless of dim). */
    TableKey(ValType t, RootExpr r, size_t d)
        : types({t}), root(r), dim(d) {
    }
    TableKey(const vector<ValType> &t, RootExpr r, size_t d)
        : types(t), root(r), dim(d) {
    }

    bool operator==(const TableKey &other) const {
        return (types == other.types) && (root == other.root) && (dim == other.dim);
    }

    /** Lexicographic order on (types, root, dim), so TableKey can be used
     * as a std::map key. */
    bool operator<(const TableKey &other) const {
        return std::tie(types, root, dim) < std::tie(other.types, other.root, other.dim);
    }
};

/** Cache of already-populated pattern tables. Only touched from
 * get_ops_table_helper(), which get_ops_table() calls under a mutex. */
map<TableKey, vector<AssociativePattern>> pattern_tables;

// Declares, for tuple element `index` of type `t`, the wildcard variables
// x<index>, y<index>, k<index> and the constants zero_, one_, neg_one_,
// tmax_ and tmin_ used to spell out the patterns below.
#define declare_vars(t, index)                                        \
    Expr x##index = Variable::make((t), "x" + std::to_string(index)); \
    Expr y##index = Variable::make((t), "y" + std::to_string(index)); \
    Expr k##index = Variable::make((t), "k" + std::to_string(index)); \
    Expr zero_##index = make_const((t), 0);                           \
    Expr one_##index = make_const((t), 1);                            \
    Expr neg_one_##index = make_const((t), -1);                       \
    Expr tmax_##index = (t).max();                                    \
    Expr tmin_##index = (t).min()

// Asserts that `types` has exactly one element and declares its variables.
#define declare_vars_single(types)        \
    internal_assert((types).size() == 1); \
    declare_vars((types)[0], 0)

// Asserts that `types` has exactly two elements and declares the variables
// for both of them.
#define declare_vars_double(types)        \
    internal_assert((types).size() == 2); \
    declare_vars((types)[0], 0);          \
    declare_vars((types)[1], 1)

// Each populate_ops_table_* function below appends its patterns to `table`.
// NOTE(review): the AssociativePattern arguments look like
// (op, identity, is_commutative) -- confirm against AssociativeOpsTable.h.

void populate_ops_table_single_general_add(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 + y0, zero_0, true);
}

void populate_ops_table_single_general_mul(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 * y0, one_0, true);
}

void populate_ops_table_single_general_max(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(max(x0, y0), tmin_0, true);
}

void populate_ops_table_single_general_min(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(min(x0, y0), tmax_0, true);
}

// No general single-element Sub patterns; only the size check runs.
void populate_ops_table_single_general_sub(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
}

// No general single-element Select patterns; only the size check runs.
void populate_ops_table_single_general_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
}

void populate_ops_table_double_general_add(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    // The pattern mixes x0 with y1, so it is only added when both tuple
    // elements share a type.
    if (types[0] == types[1]) {
        table.push_back({{x0 + y0, x0 + y1}, {zero_0, zero_1}, true});
    }
}

// No general two-element Mul patterns; only the size check runs.
void populate_ops_table_double_general_mul(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
}

void populate_ops_table_double_general_max(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    table.push_back({{max(x0, y0), select(y0 < x0, x1, y1)}, {tmin_0, zero_1}, true});
}

void populate_ops_table_double_general_min(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    table.push_back({{min(x0, y0), select(x0 < y0, x1, y1)}, {tmax_0, zero_1}, true});
}

void populate_ops_table_double_general_sub(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
    // Both patterns mix element 0 and element 1 operands, so they are only
    // added when the two types match. The second is the first with the
    // operands of the x1/y1 products swapped.
    if (types[0] == types[1]) {
        table.push_back({{x0 * y0 - x1 * y1, x1 * y0 + x0 * y1}, {one_0, zero_1}, true});
        table.push_back({{x0 * y0 - y1 * x1, x1 * y0 + y1 * x0}, {one_0, zero_1}, true});
    }
}

// No general two-element Select patterns; only the size check runs.
void populate_ops_table_double_general_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_double(types);
}

void populate_ops_table_single_uint1_and(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 && y0, one_0, true);
}

void populate_ops_table_single_uint1_or(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(x0 || y0, zero_0, true);
}

void populate_ops_table_single_uint8_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    // One "k0" wildcard per wider unsigned type the sum may be cast to.
    Expr k0_uint16 = Variable::make(UInt(16), "k0");
    Expr k0_uint32 = Variable::make(UInt(32), "k0");
    Expr k0_uint64 = Variable::make(UInt(64), "k0");
    table.emplace_back(cast<uint8_t>(min(cast<uint16_t>(x0 + y0), k0_uint16)), zero_0, true);
    table.emplace_back(cast<uint8_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
    table.emplace_back(cast<uint8_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint8_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true);  // Saturating add
    table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true);          // Saturating add
}

void populate_ops_table_single_uint16_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    // One "k0" wildcard per wider unsigned type the sum may be cast to.
    Expr k0_uint32 = Variable::make(UInt(32), "k0");
    Expr k0_uint64 = Variable::make(UInt(64), "k0");
    table.emplace_back(cast<uint16_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
    table.emplace_back(cast<uint16_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint16_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true);  // Saturating add
    table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true);          // Saturating add
}

void populate_ops_table_single_uint32_cast(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    Expr k0_uint64 = Variable::make(UInt(64), "k0");
    table.emplace_back(cast<uint32_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<AssociativePattern> &table) {
    declare_vars_single(types);
    table.emplace_back(select(x0 > tmax_0 - y0, tmax_0, y0), zero_0, true);  // Saturating add
    table.emplace_back(select(x0 < -y0, y0, tmax_0), zero_0, true);          // Saturating add
}

/** Signature shared by every populate_ops_table_* function. */
using PopulateFn = void (*)(const vector<Type> &types, vector<AssociativePattern> &);

/** Dispatch table from (value types, root node, dimension) to the function
 * that appends the corresponding patterns. ValType::All entries are the
 * general tables; the rest are type-specific additions. */
const map<TableKey, PopulateFn> val_type_to_populate_luts_fn = {
    {TableKey(ValType::All, RootExpr::Add, 1), &populate_ops_table_single_general_add},
    {TableKey(ValType::All, RootExpr::Mul, 1), &populate_ops_table_single_general_mul},
    {TableKey(ValType::All, RootExpr::Max, 1), &populate_ops_table_single_general_max},
    {TableKey(ValType::All, RootExpr::Min, 1), &populate_ops_table_single_general_min},
    {TableKey(ValType::All, RootExpr::Sub, 1), &populate_ops_table_single_general_sub},
    {TableKey(ValType::All, RootExpr::Select, 1), &populate_ops_table_single_general_select},
    {TableKey(ValType::All, RootExpr::Add, 2), &populate_ops_table_double_general_add},
    {TableKey(ValType::All, RootExpr::Mul, 2), &populate_ops_table_double_general_mul},
    {TableKey(ValType::All, RootExpr::Max, 2), &populate_ops_table_double_general_max},
    {TableKey(ValType::All, RootExpr::Min, 2), &populate_ops_table_double_general_min},
    {TableKey(ValType::All, RootExpr::Sub, 2), &populate_ops_table_double_general_sub},
    {TableKey(ValType::All, RootExpr::Select, 2), &populate_ops_table_double_general_select},

    {TableKey(ValType::UInt1, RootExpr::And, 1), &populate_ops_table_single_uint1_and},
    {TableKey(ValType::UInt1, RootExpr::Or, 1), &populate_ops_table_single_uint1_or},

    {TableKey(ValType::UInt8, RootExpr::Cast, 1), &populate_ops_table_single_uint8_cast},
    {TableKey(ValType::UInt8, RootExpr::Select, 1), &populate_ops_table_single_uint8_select},

    {TableKey(ValType::UInt16, RootExpr::Cast, 1), &populate_ops_table_single_uint16_cast},
    {TableKey(ValType::UInt16, RootExpr::Select, 1), &populate_ops_table_single_uint16_select},

    {TableKey(ValType::UInt32, RootExpr::Cast, 1), &populate_ops_table_single_uint32_cast},
    {TableKey(ValType::UInt32, RootExpr::Select, 1), &populate_ops_table_single_uint32_select},
};

/** Return the cached pattern table for (types, root, dim), building it on
 * first use: general (ValType::All) patterns first, then any type-specific
 * ones. Not synchronized; the caller must hold the lock that guards
 * pattern_tables. */
const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, RootExpr root, size_t dim) {
    const TableKey key(convert_halide_types_to_val_types(types), root, dim);

    const auto cached = pattern_tables.find(key);
    if (cached != pattern_tables.end()) {
        return cached->second;
    }

    // Populate the table if we haven't done so previously
    vector<AssociativePattern> &table = pattern_tables[key];

    // Populate the general associative op LUT
    const TableKey gen_key(ValType::All, root, dim);
    const auto gen_iter = val_type_to_populate_luts_fn.find(gen_key);
    if (gen_iter != val_type_to_populate_luts_fn.end()) {
        gen_iter->second(types, table);
    }

    // Populate the type-specific associative op LUT
    const auto iter = val_type_to_populate_luts_fn.find(key);
    if (iter != val_type_to_populate_luts_fn.end()) {
        iter->second(types, table);
    }

    return table;
}

/** Format a list of types as "{t0, t1, ...}" for debug output. */
std::string print_types(const vector<Type> &types) {
    std::ostringstream stream;
    stream << "{";
    for (size_t i = 0; i < types.size(); ++i) {
        if (i > 0) {
            stream << ", ";
        }
        stream << types[i];
    }
    stream << "}";
    return stream.str();
}

}  // anonymous namespace

/** Look up the table of associative patterns applicable to `exprs`.
 *
 * The table is selected by the types of all expressions, the IR node kind
 * at the root of exprs[0], and exprs.size(). An empty table is returned
 * when exprs has more than two elements or the root node kind is not one
 * of Add/Sub/Mul/Min/Max/Select/And/Or/Cast. Asserts if exprs is empty.
 * The returned reference points either at a function-local static empty
 * vector or at an entry of the file-level cache. */
const vector<AssociativePattern> &get_ops_table(const vector<Expr> &exprs) {
    internal_assert(!exprs.empty());

    static const vector<AssociativePattern> empty;
    if (exprs.size() > 2) {
        debug(5) << "Returning empty table since tuple size is larger than 2\n";
        return empty;
    }

    vector<Type> types;
    types.reserve(exprs.size());
    for (const Expr &e : exprs) {
        types.push_back(e.type());
    }

    // Only the first expression's root node decides which table is used.
    RootExpr root = RootExpr::Unknown;
    if (exprs[0].as<Halide::Internal::Add>()) {
        debug(5) << "Returning Add root table for type " << print_types(types) << "\n";
        root = RootExpr::Add;
    } else if (exprs[0].as<Halide::Internal::Sub>()) {
        debug(5) << "Returning Sub root table for type " << print_types(types) << "\n";
        root = RootExpr::Sub;
    } else if (exprs[0].as<Halide::Internal::Mul>()) {
        debug(5) << "Returning Mul root table for type " << print_types(types) << "\n";
        root = RootExpr::Mul;
    } else if (exprs[0].as<Halide::Internal::Min>()) {
        debug(5) << "Returning Min root table for type " << print_types(types) << "\n";
        root = RootExpr::Min;
    } else if (exprs[0].as<Halide::Internal::Max>()) {
        debug(5) << "Returning Max root table for type " << print_types(types) << "\n";
        root = RootExpr::Max;
    } else if (exprs[0].as<Halide::Internal::Select>()) {
        debug(5) << "Returning Select root table for type " << print_types(types) << "\n";
        root = RootExpr::Select;
    } else if (exprs[0].as<Halide::Internal::And>()) {
        debug(5) << "Returning And root table for type " << print_types(types) << "\n";
        root = RootExpr::And;
    } else if (exprs[0].as<Halide::Internal::Or>()) {
        debug(5) << "Returning Or root table for type " << print_types(types) << "\n";
        root = RootExpr::Or;
    } else if (exprs[0].as<Halide::Internal::Cast>()) {
        debug(5) << "Returning Cast root table for type " << print_types(types) << "\n";
        root = RootExpr::Cast;
    }

    if (root != RootExpr::Unknown) {
        // get_ops_table_helper() lazily initializes the table, so ensure
        // that multiple threads can't try to do so at the same time.
        static std::mutex ops_table_lock;
        std::lock_guard<std::mutex> lock_guard(ops_table_lock);

        const vector<AssociativePattern> &table = get_ops_table_helper(types, root, exprs.size());
        debug(7) << "Table size: " << table.size() << "\n";
        for (const auto &p : table) {
            debug(7) << p;
        }
        return table;
    }

    debug(5) << "Returning empty table\n";
    return empty;
}

}  // namespace Internal
}  // namespace Halide
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违反国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。