From d6e3e4878ee72347e53254b528d4583624c6fdac Mon Sep 17 00:00:00 2001 From: Bill Date: Mon, 28 Nov 2022 04:32:08 +0800 Subject: [PATCH] Optimized hashtable performance; Stored procedures --- Makefile | 2 +- README.md | 4 + prompt.py | 17 +- reconstruct/ast.py | 19 +- server/aggregations.h | 8 +- server/hasher.h | 39 +- server/libaquery.h | 7 +- server/server.cpp | 179 ++++- server/table.h | 17 +- server/unordered_dense.h | 1516 ++++++++++++++++++++++++++++++++++++++ server/vector_type.hpp | 32 +- 11 files changed, 1778 insertions(+), 62 deletions(-) create mode 100644 server/unordered_dense.h diff --git a/Makefile b/Makefile index 21b55bd..c438529 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ OS_SUPPORT = MonetDB_LIB = MonetDB_INC = Defines = -CXXFLAGS = --std=c++1z +CXXFLAGS = --std=c++2a ifeq ($(AQ_DEBUG), 1) OPTFLAGS = -g3 -fsanitize=address -fsanitize=leak LINKFLAGS = diff --git a/README.md b/README.md index 14abd61..7d72430 100644 --- a/README.md +++ b/README.md @@ -343,3 +343,7 @@ SELECT * FROM my_table WHERE c1 > 10 - [MonetDB](https://www.monetdb.org)
License (Mozilla Public License): https://github.com/MonetDB/MonetDB/blob/master/license.txt + +- [ankerl::unordered_dense](https://github.com/martinus/unordered_dense)
+ Author: Martin Ankerl
+ License (MIT): http://opensource.org/licenses/MIT
diff --git a/prompt.py b/prompt.py index 3afb22c..4e72ceb 100644 --- a/prompt.py +++ b/prompt.py @@ -613,16 +613,21 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr elif q.startswith('procedure'): qs = re.split(r'[ \t\r\n]', q) procedure_help = '''Usage: procedure [record|stop|run|remove|save|load]''' - send_to_server = lambda str: state.send(1, ctypes.c_char_p(bytes(str, 'utf-8'))) + def send_to_server(payload : str): + state.payload = (ctypes.c_char_p*1)(ctypes.c_char_p(bytes(payload, 'utf-8'))) + state.cfg.has_dll = 0 + state.send(1, state.payload) + state.set_ready() if len(qs) > 2: if qs[2].lower() =='record': - if state.current_procedure != qs[1]: + if state.current_procedure is not None and state.current_procedure != qs[1]: print(f'Cannot record 2 procedures at the same time. Stop recording {state.current_procedure} first.') - elif not state.current_procedure: + elif state.current_procedure is None: state.current_procedure = qs[1] - send_to_server(f'R\0{qs[1]}', 'utf-8') + send_to_server(f'R\0{qs[1]}') elif qs[2].lower() == 'stop': send_to_server(f'RT\0{qs[1]}') + state.current_procedure = None else: if state.current_procedure: print(f'Procedure manipulation commands are disallowed during procedure recording.') @@ -635,6 +640,10 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr send_to_server(f'RS\0{qs[1]}') elif qs[2].lower() == 'load': send_to_server(f'RL\0{qs[1]}') + + elif len(qs) > 1: + if qs[1].lower() == 'display': + send_to_server(f'Rd\0') else: print(procedure_help) continue diff --git a/reconstruct/ast.py b/reconstruct/ast.py index 07efd77..870df5b 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -383,12 +383,13 @@ class projection(ast_node): if self.group_node is not None and self.group_node.use_sp_gb: gb_vartable : Dict[str, Union[str, int]] = deepcopy(self.pyname2cname) gb_cexprs : List[str] = [] - + gb_colnames : List[str] = [] for key, val in proj_map.items(): col_name = 'col_' + base62uuid(6) self.context.emitc(f'decltype(auto) {col_name} = {self.out_table.contextname_cpp}->get_col<{key}>();') gb_cexprs.append((col_name, val[2])) - self.group_node.finalize(gb_cexprs, gb_vartable) + gb_colnames.append(col_name) + self.group_node.finalize(gb_cexprs, gb_vartable, gb_colnames) else: for i, (key, val) in enumerate(proj_map.items()): if type(val[1]) is int: @@ -536,7 +537,7 @@ class groupby_c(ast_node): def produce(self, node : List[Tuple[expr, Set[ColRef]]]): self.context.headers.add('"./server/hasher.h"') - self.context.headers.add('unordered_map') + # self.context.headers.add('unordered_map') self.group = 'g' + base62uuid(7) self.group_type = 'record_type' + base62uuid(7) self.datasource = self.proj.datasource @@ -565,8 +566,9 @@ class groupby_c(ast_node): [f'{c}[{scanner_itname}]' for c in g_contents_list] ) self.context.emitc(f'typedef record<{",".join(g_contents_decltype)}> {self.group_type};') - self.context.emitc(f'unordered_map<{self.group_type}, vector_type, ' + self.context.emitc(f'ankerl::unordered_dense::map<{self.group_type}, vector_type, ' f'transTypes<{self.group_type}, hasher>> {self.group};') + self.context.emitc(f'{self.group}.reserve({first_col}.size);') self.n_grps = len(self.glist) self.scanner = scan(self, first_col + '.size', it_name=scanner_itname) self.scanner.add(f'{self.group}[forward_as_tuple({g_contents})].emplace_back({self.scanner.it_var});') @@ -581,7 +583,10 @@ class groupby_c(ast_node): # gscanner.add(f'{self.datasource.cxt_name}->order_by<{assumption.result()}>(&{val_var});') # gscanner.finalize() - def finalize(self, cexprs : List[Tuple[str, expr]], var_table : Dict[str, Union[str, int]]): + def finalize(self, cexprs : List[Tuple[str, expr]], var_table : Dict[str, Union[str, int]], col_names : List[str]): + for c in col_names: + self.context.emitc(f'{c}.reserve({self.group}.size());') + gscanner = scan(self, self.group, loop_style = 'for_each') key_var = 'key_'+base62uuid(7) val_var = 'val_'+base62uuid(7) @@ -713,10 +718,10 @@ class groupby(ast_node): # self.parent.var_table. self.parent.col_ext.update(l[1]) - def finalize(self, cexprs : List[Tuple[str, expr]], var_table : Dict[str, Union[str, int]]): + def finalize(self, cexprs : List[Tuple[str, expr]], var_table : Dict[str, Union[str, int]], col_names : List[str]): if self.use_sp_gb: self.dedicated_gb = groupby_c(self.parent, self.dedicated_glist) - self.dedicated_gb.finalize(cexprs, var_table) + self.dedicated_gb.finalize(cexprs, var_table, col_names) class join(ast_node): diff --git a/server/aggregations.h b/server/aggregations.h index cb4bcbe..0f1d8f8 100644 --- a/server/aggregations.h +++ b/server/aggregations.h @@ -31,7 +31,7 @@ double avg(const VT& v) { template class VT> VT sqrt(const VT& v) { - VT ret{ v.size }; + VT ret(v.size); for (uint32_t i = 0; i < v.size; ++i) { ret[i] = sqrt(v[i]); } @@ -52,7 +52,7 @@ VT truncate(const VT& v, const uint32_t precision) { return v.subvec_memcpy(); auto multiplier = pow(10, precision); auto max_truncate = std::numeric_limits::max()/multiplier; - VT ret{ v.size }; + VT ret(v.size); for (uint32_t i = 0; i < v.size; ++i) { // round or trunc?? ret[i] = v[i] < max_truncate ? round(v[i] * multiplier)/multiplier : v[i]; } @@ -102,7 +102,7 @@ decayed_t maxs(const VT& arr) { template class VT> decayed_t minw(uint32_t w, const VT& arr) { const uint32_t& len = arr.size; - decayed_t ret{ len }; + decayed_t ret(len); std::deque> cache; for (int i = 0; i < len; ++i) { if (!cache.empty() && cache.front().second == i - w) cache.pop_front(); @@ -194,7 +194,7 @@ decayed_t>> avgw(uint32_t w, const VT uint32_t i = 0; types::GetLongType s{}; w = w > len ? len : w; - if (len) s = ret[i++] = arr[0]; + if (len) s = ret[i++] = arr[0]; for (; i < w; ++i) ret[i] = (s += arr[i]) / (FPType)(i + 1); for (; i < len; ++i) diff --git a/server/hasher.h b/server/hasher.h index 70a97e8..30330dd 100644 --- a/server/hasher.h +++ b/server/hasher.h @@ -3,7 +3,10 @@ #include #include #include +#include #include "types.h" +// #include "robin_hood.h" +#include "unordered_dense.h" // only works for 64 bit systems constexpr size_t _FNV_offset_basis = 14695981039346656037ULL; constexpr size_t _FNV_prime = 1099511628211ULL; @@ -14,7 +17,7 @@ inline size_t append_bytes(const unsigned char* _First) noexcept { _Val ^= static_cast(*_First); _Val *= _FNV_prime; } - + return _Val; } @@ -65,37 +68,44 @@ struct hasher { #else #define _current_type current_type #endif - return std::hash<_current_type>()(std::get(record)) ^ hashi(record); + return ankerl::unordered_dense::hash<_current_type>()(std::get(record)) ^ hashi(record); } size_t operator()(const std::tuple& record) const { return hashi(record); } }; +template +struct hasher{ + size_t operator()(const std::tuple& record) const { + return ankerl::unordered_dense::hash()(std::get<0>(record)); + } +}; - -namespace std{ - +namespace ankerl::unordered_dense{ template<> struct hash { size_t operator()(const astring_view& _Keyval) const noexcept { - return append_bytes(_Keyval.str); + + return ankerl::unordered_dense::hash()(_Keyval.rstr); + //return append_bytes(_Keyval.str); + } }; template<> struct hash { size_t operator() (const types::date_t& _Keyval) const noexcept { - return std::hash()(*(unsigned int*)(&_Keyval)); + return ankerl::unordered_dense::hash()(*(unsigned int*)(&_Keyval)); } }; template<> struct hash { size_t operator() (const types::time_t& _Keyval) const noexcept { - return std::hash()(_Keyval.ms) ^ - std::hash()(_Keyval.seconds) ^ - std::hash()(_Keyval.minutes) ^ - std::hash()(_Keyval.hours) + return ankerl::unordered_dense::hash()(_Keyval.ms) ^ + ankerl::unordered_dense::hash()(_Keyval.seconds) ^ + ankerl::unordered_dense::hash()(_Keyval.minutes) ^ + ankerl::unordered_dense::hash()(_Keyval.hours) ; } }; @@ -103,8 +113,8 @@ namespace std{ template<> struct hash{ size_t operator() (const types::timestamp_t& _Keyval) const noexcept { - return std::hash()(_Keyval.date) ^ - std::hash()(_Keyval.time); + return ankerl::unordered_dense::hash()(_Keyval.date) ^ + ankerl::unordered_dense::hash()(_Keyval.time); } }; #ifdef __SIZEOF_INT128__ @@ -112,12 +122,11 @@ namespace std{ template<> struct hash{ size_t operator() (const int128_struct& _Keyval) const noexcept { - return std::hash()(_Keyval.__struct.low) ^ std::hash()(_Keyval.__struct.high); + return ankerl::unordered_dense::hash()(_Keyval.__struct.low) ^ ankerl::unordered_dense::hash()(_Keyval.__struct.high); } }; #endif template struct hash> : public hasher{ }; - } diff --git a/server/libaquery.h b/server/libaquery.h index 6bfc98c..981cd57 100644 --- a/server/libaquery.h +++ b/server/libaquery.h @@ -109,10 +109,9 @@ struct Context{ void init_session(); void end_session(); void* get_module_function(const char*); - std::unordered_map tables; - std::unordered_map cols; - std::unordered_map loaded_modules; - std::unordered_map stored_proc; + std::unordered_map tables; + std::unordered_map cols; + std::unordered_map stored_proc; }; diff --git a/server/server.cpp b/server/server.cpp index dd55597..12f0aed 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -163,6 +163,20 @@ __AQEXPORT__(bool) have_hge(){ using prt_fn_t = char* (*)(void*, char*); +// This function contains heap allocations, free after use +template +char* to_lpstr(const String_T& str){ + auto ret = static_cast(malloc(str.size() + 1)); + memcpy(ret, str.c_str(), str.size()); + ret[str.size()] = '\0'; + return ret; +} +char* copy_lpstr(const char* str){ + auto len = strlen(str); + auto ret = static_cast(malloc(len + 1)); + memcpy(ret, str, len + 1); + return ret; +} constexpr prt_fn_t monetdbe_prtfns[] = { aq_to_chars, aq_to_chars, aq_to_chars, aq_to_chars, @@ -270,7 +284,18 @@ int dll_main(int argc, char** argv, Context* cxt){ aq_timer timer; Config *cfg = reinterpret_cast(argv[0]); std::unordered_map user_module_map; + std::string pwd = std::filesystem::current_path().c_str(); + auto sep = std::filesystem::path::preferred_separator; + pwd += sep; + std::string procedure_root = pwd + "procedures" + sep; std::string procedure_name = ""; + StoredProcedure current_procedure; + vector_type recorded_queries; + vector_type recorded_libraries; + bool procedure_recording = false, + procedure_replaying = false; + uint32_t procedure_module_cursor = 0; + if (cxt->module_function_maps == nullptr) cxt->module_function_maps = new std::unordered_map(); auto module_fn_map = @@ -291,12 +316,12 @@ int dll_main(int argc, char** argv, Context* cxt){ puts(*(const char**)(alt_server->getCol(0))); cxt->alt_server = alt_server; } - bool rec = false; while(cfg->running){ ENGINE_ACQUIRE(); if (cfg->new_query) { cfg->stats.postproc_time = 0; cfg->stats.monet_time = 0; +start: void *handle = nullptr; void *user_module_handle = nullptr; @@ -306,7 +331,28 @@ int dll_main(int argc, char** argv, Context* cxt){ Server* server = reinterpret_cast(cxt->alt_server); if(n_recv > 0){ if (cfg->backend_type == BACKEND_AQuery || cfg->has_dll) { - handle = dlopen("./dll.so", RTLD_NOW); + const char* proc_name = "./dll.so"; + std::string dll_path; + if (procedure_recording) { + dll_path = procedure_root + + procedure_name + std::to_string(recorded_libraries.size) + ".so"; + + try{ + if (std::filesystem::exists(dll_path)) + std::filesystem::remove(dll_path); + std::filesystem::copy_file(proc_name, dll_path); + } catch(std::filesystem::filesystem_error& e){ + puts(e.what()); + dll_path = proc_name; + } + proc_name = dll_path.c_str(); + if(recorded_libraries.size) + recorded_queries.emplace_back(copy_lpstr("N")); + } + handle = dlopen(proc_name, RTLD_NOW); + if (procedure_recording) { + recorded_libraries.emplace_back(handle); + } } for (const auto& module : user_module_map){ initialize_module(module.first.c_str(), module.second, cxt); @@ -314,18 +360,24 @@ int dll_main(int argc, char** argv, Context* cxt){ cxt->init_session(); for(int i = 0; i < n_recv; ++i) { - //printf("%s, %d\n", n_recvd[i], n_recvd[i][0] == 'Q'); + printf("%s, %d\n", n_recvd[i], n_recvd[i][0] == 'Q'); switch(n_recvd[i][0]){ case 'Q': // SQL query for monetdbe { + if(procedure_recording){ + recorded_queries.emplace_back(copy_lpstr(n_recvd[i])); + } timer.reset(); server->exec(n_recvd[i] + 1); cfg->stats.monet_time += timer.elapsed(); - // printf("Exec Q%d: %s", i, n_recvd[i]); + printf("Exec Q%d: %s", i, n_recvd[i]); } break; case 'P': // Postprocessing procedure if(handle && !server->haserror()) { + if (procedure_recording) { + recorded_queries.emplace_back(copy_lpstr(n_recvd[i])); + } code_snippet c = reinterpret_cast(dlsym(handle, n_recvd[i]+1)); timer.reset(); c(cxt); @@ -359,6 +411,12 @@ int dll_main(int argc, char** argv, Context* cxt){ case 'O': { if(!server->haserror()){ + if (procedure_recording){ + char* buf = (char*) malloc (sizeof(char) * 6); + memcpy(buf, n_recvd[i], 5); + buf[5] = '\0'; + recorded_queries.emplace_back(buf); + } uint32_t limit; memcpy(&limit, n_recvd[i] + 1, sizeof(uint32_t)); if (limit == 0) @@ -379,36 +437,115 @@ int dll_main(int argc, char** argv, Context* cxt){ user_module_map.erase(it); } break; + case 'N': + { + if(procedure_module_cursor < current_procedure.postproc_modules) + handle = current_procedure.__rt_loaded_modules[procedure_module_cursor++]; + } + break; case 'R': //recorded procedure { - auto proc_name = n_recvd[i] + 1; + auto proc_name = n_recvd[i] + 2; proc_name = *proc_name?proc_name : proc_name + 1; - const auto& load_modules = [](StoredProcedure &p){ + puts(proc_name); + const auto& load_modules = [&](StoredProcedure &p) { if (!p.__rt_loaded_modules){ p.__rt_loaded_modules = static_cast( malloc(sizeof(void*) * p.postproc_modules)); for(uint32_t j = 0; j < p.postproc_modules; ++j){ - p.__rt_loaded_modules[j] = dlopen(p.name, RTLD_NOW); + auto pj = dlopen(p.name, RTLD_NOW); + if (pj == nullptr){ + printf("Error: failed to load module %s\n", p.name); + return true; + } + p.__rt_loaded_modules[j] = pj; } } + return false; + }; + const auto& save_proc_tofile = [&](const StoredProcedure& p) { + auto config_name = procedure_root + procedure_name + ".aqp"; + auto fp = fopen(config_name.c_str(), "wb"); + if (fp == nullptr){ + printf("Error: failed to open file %s\n", config_name.c_str()); + return true; + } + fwrite(&p.cnt, sizeof(p.cnt), 1, fp); + fwrite(&p.postproc_modules, sizeof(p.postproc_modules), 1, fp); + for(uint32_t j = 0; j < p.cnt; ++j){ + auto current_query = p.queries[j]; + auto len_query = strlen(current_query); + fwrite(current_query, len_query + 1, 1, fp); + } + fclose(fp); + return false; + }; + const auto& load_proc_fromfile = [&](StoredProcedure& p) { + auto config_name = procedure_root + p.name + ".aqp"; + auto fp = fopen(config_name.c_str(), "rb"); + if(fp == nullptr){ + puts("ERROR: Procedure not found on disk."); + return false; + } + fread(&p.cnt, sizeof(p.cnt), 1, fp); + fread(&p.postproc_modules, sizeof(p.postproc_modules), 1, fp); + auto offset_now = ftell(fp); + fseek(fp, 0, SEEK_END); + auto queries_size = ftell(fp) - offset_now; + fseek(fp, offset_now, SEEK_SET); + + p.queries = static_cast(malloc(sizeof(char*) * p.cnt)); + p.queries[0] = static_cast(malloc(sizeof(char) * queries_size)); + fread(&p.queries[0], queries_size, 1, fp); + + for(uint32_t j = 1; j < p.cnt; ++j){ + p.queries[j] = p.queries[j-1]; + while(*p.queries[j] != '\0') + ++p.queries[j]; + } + fclose(fp); + return load_modules(p); }; switch(n_recvd[i][1]){ case '\0': + current_procedure.name = copy_lpstr(proc_name); + current_procedure.cnt = 0; + current_procedure.queries = nullptr; + current_procedure.postproc_modules = 0; + current_procedure.__rt_loaded_modules = nullptr; + procedure_recording = true; procedure_name = proc_name; break; case 'T': + current_procedure.queries = recorded_queries.container; + current_procedure.cnt = recorded_queries.size; + current_procedure.name = copy_lpstr(proc_name); + current_procedure.postproc_modules = recorded_libraries.size; + current_procedure.__rt_loaded_modules = recorded_libraries.container; + recorded_queries.size = recorded_queries.capacity = 0; + recorded_queries.container = nullptr; + recorded_libraries.size = recorded_libraries.capacity = 0; + recorded_libraries.container = nullptr; + procedure_recording = false; + save_proc_tofile(current_procedure); + cxt->stored_proc.insert_or_assign(procedure_name, current_procedure); procedure_name = ""; break; case 'E': // execute procedure { - auto _proc = cxt->stored_proc.find(procedure_name.c_str()); - if (_proc == cxt->stored_proc.end()) - printf("Procedure %s not found.\n", procedure_name.c_str()); + auto _proc = cxt->stored_proc.find(proc_name); + if (_proc == cxt->stored_proc.end()){ + printf("Procedure %s not found. Trying load from disk.\n", proc_name); + if (load_proc_fromfile(current_procedure)){ + cxt->stored_proc.insert_or_assign(proc_name, current_procedure); + } + } else{ - StoredProcedure &p = _proc->second; - n_recv = p.cnt; - n_recvd = p.queries; - load_modules(p); + current_procedure = _proc->second; + n_recv = current_procedure.cnt; + n_recvd = current_procedure.queries; + load_modules(current_procedure); + goto start; // yes, I know, refactor later!! } } break; @@ -418,12 +555,22 @@ int dll_main(int argc, char** argv, Context* cxt){ break; case 'L': //load procedure break; + case 'd': // display all procedures + for(const auto& p : cxt->stored_proc){ + printf("Procedure: %s, %d queries, %d modules:\n", p.first.c_str(), + p.second.cnt, p.second.postproc_modules); + for(uint32_t j = 0; j < p.second.cnt; ++j){ + printf("\tQuery %d: %s\n", j, p.second.queries[j]); + } + puts(""); + } + break; } } break; } } - if(handle) { + if(handle && procedure_replaying) { dlclose(handle); handle = nullptr; } @@ -486,7 +633,7 @@ extern "C" int __DLLEXPORT__ main(int argc, char** argv) { #endif // puts("running"); Context* cxt = new Context(); - cxt->aquery_root_path = std::filesystem::current_path().c_str(); + cxt->aquery_root_path = to_lpstr(std::filesystem::current_path().string()); // cxt->log("%d %s\n", argc, argv[1]); #ifdef THREADING diff --git a/server/table.h b/server/table.h index a32705c..8b038f0 100644 --- a/server/table.h +++ b/server/table.h @@ -145,9 +145,19 @@ public: ColRef<_Ty>& operator =(ColRef<_Ty>&& vt) { vector_type<_Ty>::operator=(std::move(vt)); return *this; + } - ColView<_Ty> operator [](const vector_type& idxs) const { - return ColView<_Ty>(*this, idxs); + // ColView<_Ty> operator [](vector_type& idxs) const { + // return ColView<_Ty>(*this, std::move(idxs)); + // } + // ColView<_Ty> operator [](const vector_type& idxs) const { + // return ColView<_Ty>(*this, idxs); + // } + vector_type<_Ty> operator[](vector_type& idxs) const { + vector_type<_Ty> ret(idxs.size); + for (uint32_t i = 0; i < idxs.size; ++i) + ret.container[i] = this->container[idxs[i]]; + return ret; } vector_type<_Ty> operator [](const std::vector& idxs) const { vector_type<_Ty> ret (this->size); @@ -226,7 +236,7 @@ class ColView : public vector_base<_Ty> { public: typedef ColRef<_Ty> Decayed_t; const uint32_t size; - const ColRef<_Ty> orig; + const ColRef<_Ty>& orig; vector_type idxs; ColView(const ColRef<_Ty>& orig, vector_type&& idxs) : orig(orig), size(idxs.size), idxs(std::move(idxs)) {} ColView(const ColRef<_Ty>& orig, const vector_type& idxs) : orig(orig), idxs(idxs), size(idxs.size) {} @@ -274,6 +284,7 @@ public: ret[i] = orig[idxs[i]]; return ret; } + ColView<_Ty> subvec(uint32_t start, uint32_t end) const { uint32_t len = end - start; return ColView<_Ty>(orig, idxs.subvec(start, end)); diff --git a/server/unordered_dense.h b/server/unordered_dense.h new file mode 100644 index 0000000..737d12b --- /dev/null +++ b/server/unordered_dense.h @@ -0,0 +1,1516 @@ +///////////////////////// ankerl::unordered_dense::{map, set} ///////////////////////// + +// A fast & densely stored hashmap and hashset based on robin-hood backward shift deletion. +// Version 2.0.1 +// https://github.com/martinus/unordered_dense +// +// Licensed under the MIT License . +// SPDX-License-Identifier: MIT +// Copyright (c) 2022 Martin Leitner-Ankerl +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ANKERL_UNORDERED_DENSE_H +#define ANKERL_UNORDERED_DENSE_H + +// see https://semver.org/spec/v2.0.0.html +#define ANKERL_UNORDERED_DENSE_VERSION_MAJOR 2 // NOLINT(cppcoreguidelines-macro-usage) incompatible API changes +#define ANKERL_UNORDERED_DENSE_VERSION_MINOR 0 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible functionality +#define ANKERL_UNORDERED_DENSE_VERSION_PATCH 1 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible bug fixes + +// API versioning with inline namespace, see https://www.foonathan.net/2018/11/inline-namespaces/ +#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) v##major##_##minor##_##patch +#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT(major, minor, patch) ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) +#define ANKERL_UNORDERED_DENSE_NAMESPACE \ + ANKERL_UNORDERED_DENSE_VERSION_CONCAT( \ + ANKERL_UNORDERED_DENSE_VERSION_MAJOR, ANKERL_UNORDERED_DENSE_VERSION_MINOR, ANKERL_UNORDERED_DENSE_VERSION_PATCH) + +#if defined(_MSVC_LANG) +# define ANKERL_UNORDERED_DENSE_CPP_VERSION _MSVC_LANG +#else +# define ANKERL_UNORDERED_DENSE_CPP_VERSION __cplusplus +#endif + +#if defined(__GNUC__) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_PACK(decl) decl __attribute__((__packed__)) +#elif defined(_MSC_VER) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_PACK(decl) __pragma(pack(push, 1)) decl __pragma(pack(pop)) +#endif + +#if ANKERL_UNORDERED_DENSE_CPP_VERSION < 201703L +# error ankerl::unordered_dense requires C++17 or higher +#else +# include // for array +# include // for uint64_t, uint32_t, uint8_t, UINT64_C +# include // for size_t, memcpy, memset +# include // for equal_to, hash +# include // for initializer_list +# include // for pair, distance +# include // for numeric_limits +# include // for allocator, allocator_traits, shared_ptr +# include // for out_of_range +# include // for basic_string +# include // for basic_string_view, hash +# include // for forward_as_tuple +# include // for enable_if_t, declval, conditional_t, ena... +# include // for forward, exchange, pair, as_const, piece... +# include // for vector + +# define ANKERL_UNORDERED_DENSE_PMR 0 // NOLINT(cppcoreguidelines-macro-usage) +# if defined(__has_include) +# if __has_include() +# undef ANKERL_UNORDERED_DENSE_PMR +# define ANKERL_UNORDERED_DENSE_PMR 1 // NOLINT(cppcoreguidelines-macro-usage) +# include // for polymorphic_allocator +# endif +# endif + +# if defined(_MSC_VER) && defined(_M_X64) +# include +# pragma intrinsic(_umul128) +# endif + +# if defined(__GNUC__) || defined(__INTEL_COMPILER) || defined(__clang__) +# define ANKERL_UNORDERED_DENSE_LIKELY(x) __builtin_expect(x, 1) // NOLINT(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_UNLIKELY(x) __builtin_expect(x, 0) // NOLINT(cppcoreguidelines-macro-usage) +# else +# define ANKERL_UNORDERED_DENSE_LIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_UNLIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) +# endif + +namespace ankerl::unordered_dense { +inline namespace ANKERL_UNORDERED_DENSE_NAMESPACE { + +// hash /////////////////////////////////////////////////////////////////////// + +// This is a stripped-down implementation of wyhash: https://github.com/wangyi-fudan/wyhash +// No big-endian support (because different values on different machines don't matter), +// hardcodes seed and the secret, reformattes the code, and clang-tidy fixes. +namespace detail::wyhash { + +static inline void mum(uint64_t* a, uint64_t* b) { +# if defined(__SIZEOF_INT128__) + __uint128_t r = *a; + r *= *b; + *a = static_cast(r); + *b = static_cast(r >> 64U); +# elif defined(_MSC_VER) && defined(_M_X64) + *a = _umul128(*a, *b, b); +# else + uint64_t ha = *a >> 32U; + uint64_t hb = *b >> 32U; + uint64_t la = static_cast(*a); + uint64_t lb = static_cast(*b); + uint64_t hi{}; + uint64_t lo{}; + uint64_t rh = ha * hb; + uint64_t rm0 = ha * lb; + uint64_t rm1 = hb * la; + uint64_t rl = la * lb; + uint64_t t = rl + (rm0 << 32U); + auto c = static_cast(t < rl); + lo = t + (rm1 << 32U); + c += static_cast(lo < t); + hi = rh + (rm0 >> 32U) + (rm1 >> 32U) + c; + *a = lo; + *b = hi; +# endif +} + +// multiply and xor mix function, aka MUM +[[nodiscard]] static inline auto mix(uint64_t a, uint64_t b) -> uint64_t { + mum(&a, &b); + return a ^ b; +} + +// read functions. WARNING: we don't care about endianness, so results are different on big endian! +[[nodiscard]] static inline auto r8(const uint8_t* p) -> uint64_t { + uint64_t v{}; + std::memcpy(&v, p, 8U); + return v; +} + +[[nodiscard]] static inline auto r4(const uint8_t* p) -> uint64_t { + uint32_t v{}; + std::memcpy(&v, p, 4); + return v; +} + +// reads 1, 2, or 3 bytes +[[nodiscard]] static inline auto r3(const uint8_t* p, size_t k) -> uint64_t { + return (static_cast(p[0]) << 16U) | (static_cast(p[k >> 1U]) << 8U) | p[k - 1]; +} + +[[maybe_unused]] [[nodiscard]] static inline auto hash(void const* key, size_t len) -> uint64_t { + static constexpr auto secret = std::array{UINT64_C(0xa0761d6478bd642f), + UINT64_C(0xe7037ed1a0b428db), + UINT64_C(0x8ebc6af09c88c6e3), + UINT64_C(0x589965cc75374cc3)}; + + auto const* p = static_cast(key); + uint64_t seed = secret[0]; + uint64_t a{}; + uint64_t b{}; + if (ANKERL_UNORDERED_DENSE_LIKELY(len <= 16)) { + if (ANKERL_UNORDERED_DENSE_LIKELY(len >= 4)) { + a = (r4(p) << 32U) | r4(p + ((len >> 3U) << 2U)); + b = (r4(p + len - 4) << 32U) | r4(p + len - 4 - ((len >> 3U) << 2U)); + } else if (ANKERL_UNORDERED_DENSE_LIKELY(len > 0)) { + a = r3(p, len); + b = 0; + } else { + a = 0; + b = 0; + } + } else { + size_t i = len; + if (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 48)) { + uint64_t see1 = seed; + uint64_t see2 = seed; + do { + seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); + see1 = mix(r8(p + 16) ^ secret[2], r8(p + 24) ^ see1); + see2 = mix(r8(p + 32) ^ secret[3], r8(p + 40) ^ see2); + p += 48; + i -= 48; + } while (ANKERL_UNORDERED_DENSE_LIKELY(i > 48)); + seed ^= see1 ^ see2; + } + while (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 16)) { + seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); + i -= 16; + p += 16; + } + a = r8(p + i - 16); + b = r8(p + i - 8); + } + + return mix(secret[1] ^ len, mix(a ^ secret[1], b ^ seed)); +} + +[[nodiscard]] static inline auto hash(uint64_t x) -> uint64_t { + return detail::wyhash::mix(x, UINT64_C(0x9E3779B97F4A7C15)); +} + +} // namespace detail::wyhash + +template +struct hash { + auto operator()(T const& obj) const noexcept(noexcept(std::declval>().operator()(std::declval()))) + -> uint64_t { + return std::hash{}(obj); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::basic_string const& str) const noexcept -> uint64_t { + return detail::wyhash::hash(str.data(), sizeof(CharT) * str.size()); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::basic_string_view const& sv) const noexcept -> uint64_t { + return detail::wyhash::hash(sv.data(), sizeof(CharT) * sv.size()); + } +}; + +template +struct hash { + using is_avalanching = void; + auto operator()(T* ptr) const noexcept -> uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr)); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::unique_ptr const& ptr) const noexcept -> uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::shared_ptr const& ptr) const noexcept -> uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash::value>::type> { + using is_avalanching = void; + auto operator()(Enum e) const noexcept -> uint64_t { + using underlying = typename std::underlying_type_t; + return detail::wyhash::hash(static_cast(e)); + } +}; + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_HASH_STATICCAST(T) \ + template <> \ + struct hash { \ + using is_avalanching = void; \ + auto operator()(T const& obj) const noexcept -> uint64_t { \ + return detail::wyhash::hash(static_cast(obj)); \ + } \ + } + +# if defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuseless-cast" +# endif +// see https://en.cppreference.com/w/cpp/utility/hash +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(bool); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(signed char); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned char); +# if ANKERL_UNORDERED_DENSE_CPP_VERSION >= 202002L +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char8_t); +# endif +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char16_t); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char32_t); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(wchar_t); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(short); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned short); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(int); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned int); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long long); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long long); + +# if defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic pop +# endif + +// bucket_type ////////////////////////////////////////////////////////// + +namespace bucket_type { + +struct standard { + static constexpr uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint + static constexpr uint32_t fingerprint_mask = dist_inc - 1; // mask for 1 byte of fingerprint + + uint32_t m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash + uint32_t m_value_idx; // index into the m_values vector. +}; + +ANKERL_UNORDERED_DENSE_PACK(struct big { + static constexpr uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint + static constexpr uint32_t fingerprint_mask = dist_inc - 1; // mask for 1 byte of fingerprint + + uint32_t m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash + size_t m_value_idx; // index into the m_values vector. +}); + +} // namespace bucket_type + +namespace detail { + +struct nonesuch {}; + +template class Op, class... Args> +struct detector { + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> { + using value_t = std::true_type; + using type = Op; +}; + +template