master
billsun 1 year ago
parent 52afa95e94
commit 91a1cc80cd

@ -47,6 +47,7 @@ def init_config():
os_platform = 'bsd' os_platform = 'bsd'
elif sys.platform == 'cygwin' or sys.platform == 'msys': elif sys.platform == 'cygwin' or sys.platform == 'msys':
os_platform = 'cygwin' os_platform = 'cygwin'
# deal with msys dependencies: # deal with msys dependencies:
if os_platform == 'win': if os_platform == 'win':
add_dll_dir(os.path.abspath('./msc-plugin')) add_dll_dir(os.path.abspath('./msc-plugin'))
@ -73,8 +74,9 @@ def init_config():
if build_driver == 'Auto': if build_driver == 'Auto':
build_driver = 'Makefile' build_driver = 'Makefile'
if os_platform == 'linux': if os_platform == 'linux':
os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/lib' os.environ['PATH'] += os.pathsep + '/usr/lib'
if os_platform == 'cygwin': if os_platform == 'cygwin':
add_dll_dir('./lib') add_dll_dir('./lib')
os.environ['LD_LIBRARY_PATH'] += os.pathsep + os.getcwd()+ os.sep + 'deps'
__config_initialized__ = True __config_initialized__ = True

@ -20,7 +20,7 @@ __AQEXPORT__(int) action(Context* cxt) {
if (fit_inc == nullptr) if (fit_inc == nullptr)
fit_inc = (decltype(fit_inc))(cxt->get_module_function("fit_inc")); fit_inc = (decltype(fit_inc))(cxt->get_module_function("fit_inc"));
auto server = static_cast<DataSource*>(cxt->alt_server); auto server = reinterpret_cast<DataSource*>(cxt->alt_server);
auto len = uint32_t(monetdbe_get_size(*((void**)server->server), "source")); auto len = uint32_t(monetdbe_get_size(*((void**)server->server), "source"));
auto x_1bN = ColRef<vector_type<double>>(len, monetdbe_get_col(*((void**)(server->server)), "source", 0)); auto x_1bN = ColRef<vector_type<double>>(len, monetdbe_get_col(*((void**)(server->server)), "source", 0));
auto y_6uX = ColRef<int64_t>(len, monetdbe_get_col(*((void**)(server->server)), "source", 1)); auto y_6uX = ColRef<int64_t>(len, monetdbe_get_col(*((void**)(server->server)), "source", 1));

@ -0,0 +1,21 @@
import urllib.request
import zipfile
from aquery_config import os_platform
from os import remove
version = '0.8.1'
duckdb_os = 'windows' if os_platform == 'windows' else 'osx' if os_platform == 'darwin' else 'linux'
duckdb_plat = 'i386'
if duckdb_os == 'darwin':
duckdb_plat = 'universal'
else:
duckdb_plat = 'amd64'
duckdb_pkg = f'libduckdb-{duckdb_os}-{duckdb_plat}.zip'
# urllib.request.urlretrieve(f"https://github.com/duckdb/duckdb/releases/latest/download/{duckdb_pkg}", duckdb_pkg)
urllib.request.urlretrieve(f"https://github.com/duckdb/duckdb/releases/download/v{version}/{duckdb_pkg}", duckdb_pkg)
with zipfile.ZipFile(duckdb_pkg, 'r') as duck:
duck.extractall('deps')
remove(duckdb_pkg)

@ -629,6 +629,7 @@ class groupby_c(ast_node):
self.context.headers.add('"./server/hasher.h"') self.context.headers.add('"./server/hasher.h"')
# self.context.headers.add('unordered_map') # self.context.headers.add('unordered_map')
self.group = 'g' + base62uuid(7) self.group = 'g' + base62uuid(7)
self.group_size = 'sz_' + self.group
self.group_type = 'record_type' + base62uuid(7) self.group_type = 'record_type' + base62uuid(7)
self.datasource = self.proj.datasource self.datasource = self.proj.datasource
self.scanner = None self.scanner = None
@ -660,18 +661,25 @@ class groupby_c(ast_node):
) )
## self.context.emitc('printf("init_time: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();') ## self.context.emitc('printf("init_time: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();')
self.context.emitc(f'typedef record<{",".join(g_contents_decltype)}> {self.group_type};') self.context.emitc(f'typedef record<{",".join(g_contents_decltype)}> {self.group_type};')
self.context.emitc(f'AQHashTable<{self.group_type}, ' # self.context.emitc(f'AQHashTable<{self.group_type}, '
f'transTypes<{self.group_type}, hasher>> {self.group} {{{self.total_sz}}};') # f'transTypes<{self.group_type}, hasher>> {self.group} {{{self.total_sz}}};')
self.n_grps = len(self.glist) self.n_grps = len(self.glist)
self.context.emitc(f'auto {self.group} = '
f'HashTableFactory<{self.group_type}, transTypes<{self.group_type}, hasher>>::'
f'get<{", ".join([f"decays<decltype({c})>" for c in g_contents_list])}>({g_contents});')
# self.scanner = scan(self, self.total_sz, it_name=scanner_itname) # self.scanner = scan(self, self.total_sz, it_name=scanner_itname)
# self.scanner.add(f'{self.group}.hashtable_push(forward_as_tuple({g_contents}), {self.scanner.it_var});') # self.scanner.add(f'{self.group}.hashtable_push(forward_as_tuple({g_contents}), {self.scanner.it_var});')
self.context.emitc(f'{self.group}.hashtable_push_all<{", ".join([f"decays<decltype({c})>" for c in g_contents_list])}>({g_contents}, {self.total_sz});') # self.context.emitc(f'{self.group}.hashtable_push_all<{", ".join([f"decays<decltype({c})>" for c in g_contents_list])}>({g_contents}, {self.total_sz});')
def consume(self, _): def consume(self, _):
# self.scanner.finalize() # self.scanner.finalize()
## self.context.emitc('printf("ht_construct: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();') ## self.context.emitc('printf("ht_construct: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();')
self.context.emitc(f'auto {self.vecs} = {self.group}.ht_postproc({self.total_sz});') self.context.emitc(f'auto {self.group_size} = {self.group}.size;')
self.context.emitc(f'auto {self.vecs} = {self.group}.values;')#{self.group}.ht_postproc({self.total_sz});')
## self.context.emitc('printf("ht_postproc: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();') ## self.context.emitc('printf("ht_postproc: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();')
# def deal_with_assumptions(self, assumption:assumption, out:TableInfo): # def deal_with_assumptions(self, assumption:assumption, out:TableInfo):
# gscanner = scan(self, self.group) # gscanner = scan(self, self.group)
@ -685,33 +693,33 @@ class groupby_c(ast_node):
tovec_columns = set() tovec_columns = set()
for i, c in enumerate(col_names): for i, c in enumerate(col_names):
if col_tovec[i]: # and type(col_types[i]) is VectorT: if col_tovec[i]: # and type(col_types[i]) is VectorT:
self.context.emitc(f'{c}.resize({self.group}.size());') self.context.emitc(f'{c}.resize({self.group_size});')
typename : Types = col_types[i] # .inner_type typename : Types = col_types[i] # .inner_type
self.context.emitc(f'auto buf_{c} = static_cast<{typename.cname} *>(calloc({self.total_sz}, sizeof({typename.cname})));') self.context.emitc(f'auto buf_{c} = static_cast<{typename.cname} *>(calloc({self.total_sz}, sizeof({typename.cname})));')
tovec_columns.add(c) tovec_columns.add(c)
else: else:
self.context.emitc(f'{c}.resize({self.group}.size());') self.context.emitc(f'{c}.resize({self.group_size});')
self.arr_len = 'arrlen_' + base62uuid(3) # self.arr_len = 'arrlen_' + base62uuid(3)
self.arr_values = 'arrvals_' + base62uuid(3) self.arr_values = {self.group.keys}#'arrvals_' + base62uuid(3)
self.context.emitc(f'auto {self.arr_len} = {self.group}.size();') # self.context.emitc(f'auto {self.arr_len} = {self.group_size};')
self.context.emitc(f'auto {self.arr_values} = {self.group}.values();') # self.context.emitc(f'auto {self.arr_values} = {self.group}.values();')
if len(tovec_columns): if len(tovec_columns): # do this in seperate loops.
preproc_scanner = scan(self, self.arr_len) preproc_scanner = scan(self, self.group_size)
preproc_scanner_it = preproc_scanner.it_var preproc_scanner_it = preproc_scanner.it_var
for c in tovec_columns: for c in tovec_columns:
preproc_scanner.add(f'{c}[{preproc_scanner_it}].init_from' preproc_scanner.add(f'{c}[{preproc_scanner_it}].init_from'
f'({self.vecs}[{preproc_scanner_it}].size,' f'({self.vecs}[{preproc_scanner_it}].size,'
f' {"buf_" + c} + {self.group}.ht_base' f' {"buf_" + c} + {self.group}.offsets'
f'[{preproc_scanner_it}]);' f'[{preproc_scanner_it}]);'
) )
preproc_scanner.finalize() preproc_scanner.finalize()
self.context.emitc('GC::scratch_space = GC::gc_handle ? &(GC::gc_handle->scratch) : nullptr;') self.context.emitc('GC::scratch_space = GC::gc_handle ? &(GC::gc_handle->scratch) : nullptr;')
# gscanner = scan(self, self.group, loop_style = scan.LoopStyle.foreach) # gscanner = scan(self, self.group, loop_style = scan.LoopStyle.foreach)
gscanner = scan(self, self.arr_len) gscanner = scan(self, self.group_size)
key_var = 'key_'+base62uuid(7) key_var = 'key_'+base62uuid(7)
val_var = 'val_'+base62uuid(7) val_var = 'val_'+base62uuid(7)

@ -166,8 +166,10 @@ public:
template<typename... Keys_t> template<typename... Keys_t>
inline void hashtable_push_all(Keys_t& ... keys, uint32_t len) { inline void hashtable_push_all(Keys_t& ... keys, uint32_t len) {
#pragma omp simd
for(uint32_t i = 0; i < len; ++i) for(uint32_t i = 0; i < len; ++i)
reversemap[i] = ankerl::unordered_dense::set<Key, Hash>::hashtable_push(keys[i]...); reversemap[i] = ankerl::unordered_dense::set<Key, Hash>::hashtable_push(keys[i]...);
#pragma omp simd
for(uint32_t i = 0; i < len; ++i) for(uint32_t i = 0; i < len; ++i)
++ht_base[reversemap[i]]; ++ht_base[reversemap[i]];
} }
@ -182,10 +184,12 @@ public:
auto vecs = static_cast<vector_type<uint32_t>*>(malloc(sizeof(vector_type<uint32_t>) * len)); auto vecs = static_cast<vector_type<uint32_t>*>(malloc(sizeof(vector_type<uint32_t>) * len));
vecs[0].init_from(ht_base[0], mapbase); vecs[0].init_from(ht_base[0], mapbase);
#pragma omp simd
for (uint32_t i = 1; i < len; ++i) { for (uint32_t i = 1; i < len; ++i) {
vecs[i].init_from(ht_base[i], mapbase + ht_base[i - 1]); vecs[i].init_from(ht_base[i], mapbase + ht_base[i - 1]);
ht_base[i] += ht_base[i - 1]; ht_base[i] += ht_base[i - 1];
} }
#pragma omp simd
for (uint32_t i = 0; i < sz; ++i) { for (uint32_t i = 0; i < sz; ++i) {
auto id = reversemap[i]; auto id = reversemap[i];
mapbase[--ht_base[id]] = i; mapbase[--ht_base[id]] = i;
@ -194,11 +198,18 @@ public:
} }
}; };
template <class ... Ty>
struct HashTableComponents {
uint32_t size;
std::vector<std::tuple<Ty...>>* keys;
vector_type<uint32_t>* values;
uint32_t* offsets;
};
template < template <
typename ValueType = uint32_t, typename ValueType = uint32_t,
int PerfectHashingThreshold = 12 int PerfectHashingThreshold = 18
> > // default < 1M table size
struct PerfectHashTable { struct PerfectHashTable {
using key_t = std::conditional_t<PerfectHashingThreshold <= 8, uint8_t, using key_t = std::conditional_t<PerfectHashingThreshold <= 8, uint8_t,
std::conditional_t<PerfectHashingThreshold <= 16, uint16_t, std::conditional_t<PerfectHashingThreshold <= 16, uint16_t,
@ -207,7 +218,7 @@ struct PerfectHashTable {
>>>; >>>;
constexpr static uint32_t tbl_sz = 1 << PerfectHashingThreshold; constexpr static uint32_t tbl_sz = 1 << PerfectHashingThreshold;
template <typename ... Types, template <typename> class VT> template <typename ... Types, template <typename> class VT>
static vector_type<uint32_t>* static HashTableComponents<Types ...> //vector_type<uint32_t>*
construct(VT<Types>&... args) { // construct a hash set construct(VT<Types>&... args) { // construct a hash set
// AQTmr(); // AQTmr();
int n_cols, n_rows = 0; int n_cols, n_rows = 0;
@ -216,7 +227,7 @@ struct PerfectHashTable {
static_assert( static_assert(
(sizeof...(Types) < PerfectHashingThreshold) && (sizeof...(Types) < PerfectHashingThreshold) &&
(std::is_integral_v<Types> && ...), (std::is_integral_v<Types> && ...),
"Types must be integral and less than 12 wide in total." "Types must be integral and less than \"PerfectHashingThreshold\" wide in total."
); );
key_t* key_t*
hash_values = static_cast<key_t*>( hash_values = static_cast<key_t*>(
@ -241,9 +252,10 @@ struct PerfectHashTable {
}; };
int idx = 0; int idx = 0;
(get_hash(args, idx++), ...); (get_hash(args, idx++), ...);
uint32_t cnt[tbl_sz]; uint32_t *cnt_ext = static_cast<uint32_t*>(
calloc(tbl_sz, sizeof(uint32_t))
), *cnt = cnt_ext + 1;
uint32_t n_grps = 0; uint32_t n_grps = 0;
memset(cnt, 0, tbl_sz * sizeof(tbl_sz));
#pragma omp simd #pragma omp simd
for (uint32_t i = 0; i < n_cols; ++i) { for (uint32_t i = 0; i < n_cols; ++i) {
++cnt[hash_values[i]]; ++cnt[hash_values[i]];
@ -256,6 +268,24 @@ struct PerfectHashTable {
grp_ids[i] = n_grps++; grp_ids[i] = n_grps++;
} }
} }
std::vector<std::tuple<Types ...>>* keys = new std::vector<std::tuple<Types ...>>(n_grps); // Memory leak here, cleanup after module is done.
const char bits[] = {0, args.stats.bits ... };
auto decode = []<typename Ret>(auto &val, const char prev, const char curr) -> Ret {
val >>= prev;
const auto mask = (1 << curr) - 1;
return val & mask;
};
#pragma omp simd
for (ValueType i = 0; i < n_grps; ++ i) {
int idx2 = 1;
ValueType curr_val = grp_ids[i];
keys[i] = std::make_tuple((
decode.template operator()<Types>(
curr_val, bits[idx2 - 1], bits[idx2++]
), ...)
); // require C++20 for the calls to be executed sequentially.
}
uint32_t* idxs = static_cast<uint32_t*>( uint32_t* idxs = static_cast<uint32_t*>(
malloc(n_cols * sizeof(uint32_t)) malloc(n_cols * sizeof(uint32_t))
); );
@ -281,7 +311,47 @@ struct PerfectHashTable {
idxs_vec[i].container = idxs_ptr[i]; idxs_vec[i].container = idxs_ptr[i];
idxs_vec[i].size = cnt[i]; idxs_vec[i].size = cnt[i];
} }
free(hash_values); GC::gc_handle->reg(hash_values);
return idxs_vec;
#pragma omp simd
for(int i = 1; i < n_grps; ++ i)
cnt[i] += cnt[i - 1];
cnt_ext[0] = 0;
return {.size = n_grps, .keys = keys, .values = idxs_vec, .offset = cnt_ext};
} }
}; };
template <class>
class ColRef;
template <
class Key,
class Hash,
int PerfectHashingThreshold = 18
>
class HashTableFactory {
public:
template <class ... Ty>
static HashTableComponents<Ty ...>
get(ColRef<Ty>& ... cols) {
// To use Perfect Hash Table
if constexpr ((std::is_integral_v<Ty> && ...)) {
if ((cols.stats.bits + ...) <= PerfectHashingThreshold) {
return PerfectHashTable<
uint32_t,
PerfectHashingThreshold
>::construct(cols ...);
}
}
// Fallback to regular hash table
int n_rows = 0;
((n_rows = cols.size), ...);
AQHashTable<Key, Hash> ht{n_rows};
ht.template hashtable_push_all<decays<decltype(cols)> ...>(cols ..., n_rows);
auto vals = ht.ht_postproc(n_rows);
return {.size = ht.size(), .keys = ht.values(), .values = vals, .offset = ht.ht_base};
}
};

Loading…
Cancel
Save