From 91a1cc80cdbb3c835a198ced6b4604a160e7bb20 Mon Sep 17 00:00:00 2001 From: billsun Date: Fri, 20 Oct 2023 20:56:29 +0000 Subject: [PATCH] update --- aquery_config.py | 4 ++- demo/action.cpp | 2 +- duckdb_install.py | 21 ++++++++++++ engine/ast.py | 36 ++++++++++++-------- server/hasher.h | 86 ++++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 125 insertions(+), 24 deletions(-) create mode 100644 duckdb_install.py diff --git a/aquery_config.py b/aquery_config.py index 6b7a00f..5e2d30a 100644 --- a/aquery_config.py +++ b/aquery_config.py @@ -47,6 +47,7 @@ def init_config(): os_platform = 'bsd' elif sys.platform == 'cygwin' or sys.platform == 'msys': os_platform = 'cygwin' + # deal with msys dependencies: if os_platform == 'win': add_dll_dir(os.path.abspath('./msc-plugin')) @@ -73,8 +74,9 @@ def init_config(): if build_driver == 'Auto': build_driver = 'Makefile' if os_platform == 'linux': - os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/lib' + os.environ['PATH'] += os.pathsep + '/usr/lib' if os_platform == 'cygwin': add_dll_dir('./lib') + os.environ['LD_LIBRARY_PATH'] += os.pathsep + os.getcwd()+ os.sep + 'deps' __config_initialized__ = True diff --git a/demo/action.cpp b/demo/action.cpp index e4ef271..e3da581 100644 --- a/demo/action.cpp +++ b/demo/action.cpp @@ -20,7 +20,7 @@ __AQEXPORT__(int) action(Context* cxt) { if (fit_inc == nullptr) fit_inc = (decltype(fit_inc))(cxt->get_module_function("fit_inc")); - auto server = static_cast(cxt->alt_server); + auto server = reinterpret_cast(cxt->alt_server); auto len = uint32_t(monetdbe_get_size(*((void**)server->server), "source")); auto x_1bN = ColRef>(len, monetdbe_get_col(*((void**)(server->server)), "source", 0)); auto y_6uX = ColRef(len, monetdbe_get_col(*((void**)(server->server)), "source", 1)); diff --git a/duckdb_install.py b/duckdb_install.py new file mode 100644 index 0000000..c69fcea --- /dev/null +++ b/duckdb_install.py @@ -0,0 +1,21 @@ +import urllib.request +import zipfile 
+from aquery_config import os_platform +from os import remove + +version = '0.8.1' +duckdb_os = 'windows' if os_platform == 'windows' else 'osx' if os_platform == 'darwin' else 'linux' + +duckdb_plat = 'i386' +if duckdb_os == 'darwin': + duckdb_plat = 'universal' +else: + duckdb_plat = 'amd64' + +duckdb_pkg = f'libduckdb-{duckdb_os}-{duckdb_plat}.zip' +# urllib.request.urlretrieve(f"https://github.com/duckdb/duckdb/releases/latest/download/{duckdb_pkg}", duckdb_pkg) +urllib.request.urlretrieve(f"https://github.com/duckdb/duckdb/releases/download/v{version}/{duckdb_pkg}", duckdb_pkg) +with zipfile.ZipFile(duckdb_pkg, 'r') as duck: + duck.extractall('deps') + +remove(duckdb_pkg) diff --git a/engine/ast.py b/engine/ast.py index ab9bc4c..eae42e9 100644 --- a/engine/ast.py +++ b/engine/ast.py @@ -629,6 +629,7 @@ class groupby_c(ast_node): self.context.headers.add('"./server/hasher.h"') # self.context.headers.add('unordered_map') self.group = 'g' + base62uuid(7) + self.group_size = 'sz_' + self.group self.group_type = 'record_type' + base62uuid(7) self.datasource = self.proj.datasource self.scanner = None @@ -660,18 +661,25 @@ class groupby_c(ast_node): ) ## self.context.emitc('printf("init_time: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();') self.context.emitc(f'typedef record<{",".join(g_contents_decltype)}> {self.group_type};') - self.context.emitc(f'AQHashTable<{self.group_type}, ' - f'transTypes<{self.group_type}, hasher>> {self.group} {{{self.total_sz}}};') + # self.context.emitc(f'AQHashTable<{self.group_type}, ' + # f'transTypes<{self.group_type}, hasher>> {self.group} {{{self.total_sz}}};') self.n_grps = len(self.glist) + + self.context.emitc(f'auto {self.group} = ' + f'HashTableFactory<{self.group_type}, transTypes<{self.group_type}, hasher>>::' + f'get<{", ".join([f"decays" for c in g_contents_list])}>({g_contents});') + + # self.scanner = scan(self, self.total_sz, it_name=scanner_itname) # 
self.scanner.add(f'{self.group}.hashtable_push(forward_as_tuple({g_contents}), {self.scanner.it_var});') - self.context.emitc(f'{self.group}.hashtable_push_all<{", ".join([f"decays" for c in g_contents_list])}>({g_contents}, {self.total_sz});') + # self.context.emitc(f'{self.group}.hashtable_push_all<{", ".join([f"decays" for c in g_contents_list])}>({g_contents}, {self.total_sz});') def consume(self, _): # self.scanner.finalize() ## self.context.emitc('printf("ht_construct: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();') - self.context.emitc(f'auto {self.vecs} = {self.group}.ht_postproc({self.total_sz});') + self.context.emitc(f'auto {self.group_size} = {self.group}.size;') + self.context.emitc(f'auto {self.vecs} = {self.group}.values;')#{self.group}.ht_postproc({self.total_sz});') ## self.context.emitc('printf("ht_postproc: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();') # def deal_with_assumptions(self, assumption:assumption, out:TableInfo): # gscanner = scan(self, self.group) @@ -685,33 +693,33 @@ class groupby_c(ast_node): tovec_columns = set() for i, c in enumerate(col_names): if col_tovec[i]: # and type(col_types[i]) is VectorT: - self.context.emitc(f'{c}.resize({self.group}.size());') + self.context.emitc(f'{c}.resize({self.group_size});') typename : Types = col_types[i] # .inner_type self.context.emitc(f'auto buf_{c} = static_cast<{typename.cname} *>(calloc({self.total_sz}, sizeof({typename.cname})));') tovec_columns.add(c) else: - self.context.emitc(f'{c}.resize({self.group}.size());') + self.context.emitc(f'{c}.resize({self.group_size});') - self.arr_len = 'arrlen_' + base62uuid(3) - self.arr_values = 'arrvals_' + base62uuid(3) + # self.arr_len = 'arrlen_' + base62uuid(3) + self.arr_values = {self.group.keys}#'arrvals_' + base62uuid(3) - self.context.emitc(f'auto {self.arr_len} = {self.group}.size();') - 
self.context.emitc(f'auto {self.arr_values} = {self.group}.values();') + # self.context.emitc(f'auto {self.arr_len} = {self.group_size};') + # self.context.emitc(f'auto {self.arr_values} = {self.group}.values();') - if len(tovec_columns): - preproc_scanner = scan(self, self.arr_len) + if len(tovec_columns): # do this in seperate loops. + preproc_scanner = scan(self, self.group_size) preproc_scanner_it = preproc_scanner.it_var for c in tovec_columns: preproc_scanner.add(f'{c}[{preproc_scanner_it}].init_from' f'({self.vecs}[{preproc_scanner_it}].size,' - f' {"buf_" + c} + {self.group}.ht_base' + f' {"buf_" + c} + {self.group}.offsets' f'[{preproc_scanner_it}]);' ) preproc_scanner.finalize() self.context.emitc('GC::scratch_space = GC::gc_handle ? &(GC::gc_handle->scratch) : nullptr;') # gscanner = scan(self, self.group, loop_style = scan.LoopStyle.foreach) - gscanner = scan(self, self.arr_len) + gscanner = scan(self, self.group_size) key_var = 'key_'+base62uuid(7) val_var = 'val_'+base62uuid(7) diff --git a/server/hasher.h b/server/hasher.h index c60fa36..677967e 100644 --- a/server/hasher.h +++ b/server/hasher.h @@ -166,8 +166,10 @@ public: template inline void hashtable_push_all(Keys_t& ... 
keys, uint32_t len) { +#pragma omp simd for(uint32_t i = 0; i < len; ++i) reversemap[i] = ankerl::unordered_dense::set::hashtable_push(keys[i]...); +#pragma omp simd for(uint32_t i = 0; i < len; ++i) ++ht_base[reversemap[i]]; } @@ -182,10 +184,12 @@ public: auto vecs = static_cast*>(malloc(sizeof(vector_type) * len)); vecs[0].init_from(ht_base[0], mapbase); +#pragma omp simd for (uint32_t i = 1; i < len; ++i) { vecs[i].init_from(ht_base[i], mapbase + ht_base[i - 1]); ht_base[i] += ht_base[i - 1]; } +#pragma omp simd for (uint32_t i = 0; i < sz; ++i) { auto id = reversemap[i]; mapbase[--ht_base[id]] = i; @@ -194,11 +198,18 @@ public: } }; +template +struct HashTableComponents { + uint32_t size; + std::vector>* keys; + vector_type* values; + uint32_t* offsets; +}; template < typename ValueType = uint32_t, - int PerfectHashingThreshold = 12 -> + int PerfectHashingThreshold = 18 +> // default < 1M table size struct PerfectHashTable { using key_t = std::conditional_t>>; constexpr static uint32_t tbl_sz = 1 << PerfectHashingThreshold; template class VT> - static vector_type* + static HashTableComponents //vector_type* construct(VT&... args) { // construct a hash set // AQTmr(); int n_cols, n_rows = 0; @@ -216,7 +227,7 @@ struct PerfectHashTable { static_assert( (sizeof...(Types) < PerfectHashingThreshold) && (std::is_integral_v && ...), - "Types must be integral and less than 12 wide in total." + "Types must be integral and less than \"PerfectHashingThreshold\" wide in total." 
); key_t* hash_values = static_cast( @@ -241,9 +252,10 @@ struct PerfectHashTable { }; int idx = 0; (get_hash(args, idx++), ...); - uint32_t cnt[tbl_sz]; + uint32_t *cnt_ext = static_cast( + calloc(tbl_sz, sizeof(uint32_t)) + ), *cnt = cnt_ext + 1; uint32_t n_grps = 0; - memset(cnt, 0, tbl_sz * sizeof(tbl_sz)); #pragma omp simd for (uint32_t i = 0; i < n_cols; ++i) { ++cnt[hash_values[i]]; @@ -256,6 +268,24 @@ struct PerfectHashTable { grp_ids[i] = n_grps++; } } + std::vector>* keys = new std::vector>(n_grps); // Memory leak here, cleanup after module is done. + + const char bits[] = {0, args.stats.bits ... }; + auto decode = [](auto &val, const char prev, const char curr) -> Ret { + val >>= prev; + const auto mask = (1 << curr) - 1; + return val & mask; + }; +#pragma omp simd + for (ValueType i = 0; i < n_grps; ++ i) { + int idx2 = 1; + ValueType curr_val = grp_ids[i]; + keys[i] = std::make_tuple(( + decode.template operator()( + curr_val, bits[idx2 - 1], bits[idx2++] + ), ...) + ); // require C++20 for the calls to be executed sequentially. + } uint32_t* idxs = static_cast( malloc(n_cols * sizeof(uint32_t)) ); @@ -281,7 +311,47 @@ struct PerfectHashTable { idxs_vec[i].container = idxs_ptr[i]; idxs_vec[i].size = cnt[i]; } - free(hash_values); - return idxs_vec; + GC::gc_handle->reg(hash_values); + +#pragma omp simd + for(int i = 1; i < n_grps; ++ i) + cnt[i] += cnt[i - 1]; + cnt_ext[0] = 0; + return {.size = n_grps, .keys = keys, .values = idxs_vec, .offset = cnt_ext}; } }; + +template +class ColRef; + +template < + class Key, + class Hash, + int PerfectHashingThreshold = 18 +> +class HashTableFactory { +public: + template + static HashTableComponents + get(ColRef& ... cols) { +// To use Perfect Hash Table + if constexpr ((std::is_integral_v && ...)) { + if ((cols.stats.bits + ...) 
<= PerfectHashingThreshold) { + return PerfectHashTable< + uint32_t, + PerfectHashingThreshold + >::construct(cols ...); + } + } + +// Fallback to regular hash table + int n_rows = 0; + ((n_rows = cols.size), ...); + + AQHashTable ht{n_rows}; + ht.template hashtable_push_all ...>(cols ..., n_rows); + auto vals = ht.ht_postproc(n_rows); + + return {.size = ht.size(), .keys = ht.values(), .values = vals, .offsets = ht.ht_base}; + } +};