regression: nested aggregation support

dev
Bill 2 years ago
parent 3a56c19e2e
commit 44ccc0b835

@ -30,6 +30,8 @@ info:
$(info $(Threading)) $(info $(Threading))
$(info "test") $(info "test")
$(info $(CXX)) $(info $(CXX))
libaquery.a:
$(CXX) -c server/server.cpp server/io.cpp server/table.cpp $(OS_SUPPORT) $(Threading) $(OPTFLAGS) $(CXXFLAGS) -o server.bin
server.bin: server.bin:
$(CXX) server/server.cpp server/io.cpp server/table.cpp $(OS_SUPPORT) $(Threading) $(OPTFLAGS) $(CXXFLAGS) -o server.bin $(CXX) server/server.cpp server/io.cpp server/table.cpp $(OS_SUPPORT) $(Threading) $(OPTFLAGS) $(CXXFLAGS) -o server.bin
server.so: server.so:

@ -4,6 +4,8 @@
AQuery++ Database is a cross-platform, In-Memory Column-Store Database that incorporates compiled query execution. AQuery++ Database is a cross-platform, In-Memory Column-Store Database that incorporates compiled query execution.
## Architecture ## Architecture
![Architecture](./docs/arch-hybrid.svg)
### AQuery Compiler ### AQuery Compiler
- The query is first processed by the AQuery Compiler which is composed of a frontend that parses the query into AST and a backend that generates target code that delivers the query. - The query is first processed by the AQuery Compiler which is composed of a frontend that parses the query into AST and a backend that generates target code that delivers the query.
- Front end of AQuery++ Compiler is built on top of [mo-sql-parsing](https://github.com/klahnakoski/mo-sql-parsing) with modifications to handle AQuery dialect and extension. - Front end of AQuery++ Compiler is built on top of [mo-sql-parsing](https://github.com/klahnakoski/mo-sql-parsing) with modifications to handle AQuery dialect and extension.

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

After

Width:  |  Height:  |  Size: 291 KiB

@ -278,6 +278,7 @@ def is_null_call_behavior(op:OperatorBase, c_code : bool, x : str):
spnull = OperatorBase('missing', 1, logical, cname = "", sqlname = "", call = is_null_call_behavior) spnull = OperatorBase('missing', 1, logical, cname = "", sqlname = "", call = is_null_call_behavior)
# cstdlib # cstdlib
# If in aggregation functions, using monetdb builtins. If in nested agg, inside udfs, using cstdlib.
fnsqrt = OperatorBase('sqrt', 1, lambda *_ : DoubleT, cname = 'sqrt', sqlname = 'SQRT', call = fn_behavior) fnsqrt = OperatorBase('sqrt', 1, lambda *_ : DoubleT, cname = 'sqrt', sqlname = 'SQRT', call = fn_behavior)
fnlog = OperatorBase('log', 2, lambda *_ : DoubleT, cname = 'log', sqlname = 'LOG', call = fn_behavior) fnlog = OperatorBase('log', 2, lambda *_ : DoubleT, cname = 'log', sqlname = 'LOG', call = fn_behavior)
fnsin = OperatorBase('sin', 1, lambda *_ : DoubleT, cname = 'sin', sqlname = 'SIN', call = fn_behavior) fnsin = OperatorBase('sin', 1, lambda *_ : DoubleT, cname = 'sin', sqlname = 'SIN', call = fn_behavior)

@ -348,7 +348,7 @@ def prompt(running = lambda:True, next = input, state = None):
else: else:
print(prompt_help) print(prompt_help)
continue continue
elif q == 'xexec': # generate build and run (MonetDB Engine) elif q.startswith('xexec'): # generate build and run (MonetDB Engine)
state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value
cxt = xengine.exec(state.stmts, cxt, keep) cxt = xengine.exec(state.stmts, cxt, keep)
if state.server_mode == RunType.Threaded: if state.server_mode == RunType.Threaded:
@ -362,17 +362,22 @@ def prompt(running = lambda:True, next = input, state = None):
state.send(sz, payload) state.send(sz, payload)
except TypeError as e: except TypeError as e:
print(e) print(e)
if cxt.udf is not None: this_udf = cxt.finalize_udf()
with open('udf.hpp', 'wb') as outfile:
outfile.write(cxt.udf.encode('utf-8'))
if this_udf:
with open('udf.hpp', 'wb') as outfile:
outfile.write(this_udf.encode('utf-8'))
qs = re.split(r'[ \t]', q)
build_this = not(len(qs) > 1 and qs[1].startswith('n'))
if cxt.has_dll: if cxt.has_dll:
with open('out.cpp', 'wb') as outfile: with open('out.cpp', 'wb') as outfile:
outfile.write((cxt.finalize()).encode('utf-8')) outfile.write((cxt.finalize()).encode('utf-8'))
if build_this:
subprocess.call(['make', 'snippet'], stdout = nullstream) subprocess.call(['make', 'snippet'], stdout = nullstream)
state.cfg.has_dll = 1 state.cfg.has_dll = 1
else: else:
state.cfg.has_dll = 0 state.cfg.has_dll = 0
if build_this:
state.set_ready() state.set_ready()
continue continue
@ -465,13 +470,14 @@ def prompt(running = lambda:True, next = input, state = None):
except SystemExit: except SystemExit:
print("\nBye.") print("\nBye.")
raise raise
except: except BaseException as e:
import code, traceback import code, traceback
sh = code.InteractiveConsole({**globals(), **locals()}) sh = code.InteractiveConsole({**globals(), **locals()})
sh.interact(banner = traceback.format_exc(), exitmsg = 'debugging session ended.') sh.interact(banner = traceback.format_exc(), exitmsg = 'debugging session ended.')
save('', cxt) save('', cxt)
rm(state) rm(state)
raise raise e
rm(state) rm(state)
## FUNCTIONS END ## FUNCTIONS END

@ -70,14 +70,14 @@ class projection(ast_node):
self.context.postproc_begin(self.postproc_fname) self.context.postproc_begin(self.postproc_fname)
def spawn(self, node): def spawn(self, node):
self.datasource = None # datasource is Join instead of TableInfo self.datasource = join(self, [], self.context) # datasource is Join instead of TableInfo
self.assumptions = []
if 'from' in node: if 'from' in node:
from_clause = node['from'] from_clause = node['from']
self.datasource = join(self, from_clause) self.datasource = join(self, from_clause)
if 'assumptions' in from_clause: if 'assumptions' in from_clause:
self.assumptions = enlist(from_clause['assumptions']) self.assumptions = enlist(from_clause['assumptions'])
else:
self.assumptions = []
if self.datasource is not None: if self.datasource is not None:
self.datasource_changed = True self.datasource_changed = True
self.prev_datasource = self.context.datasource self.prev_datasource = self.context.datasource
@ -157,7 +157,6 @@ class projection(ast_node):
def finialize(astnode:ast_node): def finialize(astnode:ast_node):
if(astnode is not None): if(astnode is not None):
self.add(astnode.sql) self.add(astnode.sql)
self.add('FROM')
finialize(self.datasource) finialize(self.datasource)
finialize(self.where) finialize(self.where)
if self.group_node and not self.group_node.use_sp_gb: if self.group_node and not self.group_node.use_sp_gb:
@ -469,6 +468,7 @@ class join(ast_node):
self.tables : List[TableInfo] = [] self.tables : List[TableInfo] = []
self.tables_dir = dict() self.tables_dir = dict()
self.rec = None self.rec = None
self.top_level = self.parent and type(self.parent) is projection
# self.tmp_name = 'join_' + base62uuid(4) # self.tmp_name = 'join_' + base62uuid(4)
# self.datasource = TableInfo(self.tmp_name, [], self.context) # self.datasource = TableInfo(self.tmp_name, [], self.context)
def append(self, tbls, __alias = ''): def append(self, tbls, __alias = ''):
@ -547,9 +547,12 @@ class join(ast_node):
@property @property
def all_cols(self): def all_cols(self):
return set([c for t in self.tables for c in t.columns]) return set([c for t in self.tables for c in t.columns])
def consume(self, _): def consume(self, node):
self.sql = ', '.join(self.joins) self.sql = ', '.join(self.joins)
return super().consume(_) if node and self.sql and self.top_level:
self.sql = ' FROM ' + self.sql
return super().consume(node)
def __str__(self): def __str__(self):
return ', '.join(self.joins) return ', '.join(self.joins)
def __repr__(self): def __repr__(self):
@ -644,18 +647,18 @@ class load(ast_node):
ret_type = Types.decode(f['ret_type']) ret_type = Types.decode(f['ret_type'])
nargs = 0 nargs = 0
arglist = '' arglist = ''
if 'var' in f: if 'vars' in f:
arglist = [] arglist = []
for v in enlist(f['var']): for v in enlist(f['vars']):
arglist.append(f'{Types.decode(v["type"]).cname} {v["arg"]}') arglist.append(f'{Types.decode(v["type"]).cname} {v["arg"]}')
nargs = len(arglist) nargs = len(arglist)
arglist = ', '.join(arglist) arglist = ', '.join(arglist)
# create c++ stub # create c++ stub
cpp_stub = f'{ret_type.cname} (*{fname})({arglist});' cpp_stub = f'{ret_type.cname} (*{fname})({arglist}) = nullptr;'
self.context.module_stubs += cpp_stub + '\n' self.context.module_stubs += cpp_stub + '\n'
self.context.module_map[fname] = cpp_stub self.context.module_map[fname] = cpp_stub
#registration for parser #registration for parser
self.functions[fname] = user_module_function(fname, nargs, ret_type) self.functions[fname] = user_module_function(fname, nargs, ret_type, self.context)
def produce_aq(self, node): def produce_aq(self, node):
node = node['load'] node = node['load']
@ -723,6 +726,12 @@ class outfile(ast_node):
class udf(ast_node): class udf(ast_node):
name = 'udf' name = 'udf'
first_order = name first_order = name
@staticmethod
def try_init_udf(context : Context):
if context.udf is None:
context.udf = '/*UDF Start*/\n'
context.headers.add('\"./udf.hpp\"')
@dataclass @dataclass
class builtin_var: class builtin_var:
enabled : bool = False enabled : bool = False
@ -754,13 +763,7 @@ class udf(ast_node):
} }
self.var_table = {} self.var_table = {}
self.args = [] self.args = []
if self.context.udf is None: udf.try_init_udf(self.context)
self.context.udf = (
Context.udf_head
+ self.context.module_stubs
+ self.context.get_init_func()
)
self.context.headers.add('\"./udf.hpp\"')
self.vecs = set() self.vecs = set()
self.code_list = [] self.code_list = []
self.builtin_used = None self.builtin_used = None
@ -983,10 +986,11 @@ class udf(ast_node):
return udf.ReturnPattern.bulk_return return udf.ReturnPattern.bulk_return
class user_module_function(OperatorBase): class user_module_function(OperatorBase):
def __init__(self, name, nargs, ret_type): def __init__(self, name, nargs, ret_type, context : Context):
super().__init__(name, nargs, lambda: ret_type, call=fn_behavior) super().__init__(name, nargs, lambda *_: ret_type, call=fn_behavior)
user_module_func[name] = self user_module_func[name] = self
builtin_operators[name] = self # builtin_operators[name] = self
udf.try_init_udf(context)
def include(objs): def include(objs):
import inspect import inspect

@ -42,10 +42,12 @@ class expr(ast_node):
if(type(parent) is expr): if(type(parent) is expr):
self.inside_agg = parent.inside_agg self.inside_agg = parent.inside_agg
self.is_udfexpr = parent.is_udfexpr self.is_udfexpr = parent.is_udfexpr
self.is_agg_func = parent.is_agg_func
self.root : expr = parent.root self.root : expr = parent.root
self.c_code = parent.c_code self.c_code = parent.c_code
self.builtin_vars = parent.builtin_vars self.builtin_vars = parent.builtin_vars
else: else:
self.is_agg_func = False
self.is_udfexpr = type(parent) is udf self.is_udfexpr = type(parent) is udf
self.root : expr = self self.root : expr = self
self.c_code = self.is_udfexpr or type(parent) is projection self.c_code = self.is_udfexpr or type(parent) is projection
@ -71,8 +73,8 @@ class expr(ast_node):
else: else:
self.datasource = self.context.datasource self.datasource = self.context.datasource
self.udf_map = parent.context.udf_map self.udf_map = parent.context.udf_map
self.func_maps = {**builtin_func, **self.udf_map} self.func_maps = {**builtin_func, **self.udf_map, **user_module_func}
self.operators = {**builtin_operators, **self.udf_map} self.operators = {**builtin_operators, **self.udf_map, **user_module_func}
def produce(self, node): def produce(self, node):
from engine.utils import enlist from engine.utils import enlist
@ -81,6 +83,12 @@ class expr(ast_node):
if type(node) is dict: if type(node) is dict:
for key, val in node.items(): for key, val in node.items():
if key in self.operators: if key in self.operators:
if key in builtin_func:
if self.is_agg_func:
self.root.is_special = True # Nested Aggregation
else:
self.is_agg_func = True
op = self.operators[key] op = self.operators[key]
val = enlist(val) val = enlist(val)
@ -95,13 +103,15 @@ class expr(ast_node):
self.type = AnyT self.type = AnyT
self.sql = op(self.c_code, *str_vals) self.sql = op(self.c_code, *str_vals)
special_func = [*self.context.udf_map.keys(), "maxs", "mins", "avgs", "sums"] special_func = [*self.context.udf_map.keys(), *self.context.module_map.keys(), "maxs", "mins", "avgs", "sums"]
if key in special_func and not self.is_special: if key in special_func and not self.is_special:
self.is_special = True self.is_special = True
if key in self.context.udf_map: if key in self.context.udf_map:
self.root.udf_called = self.context.udf_map[key] self.root.udf_called = self.context.udf_map[key]
if self.is_udfexpr and key == self.root.udf.name: if self.is_udfexpr and key == self.root.udf.name:
self.root.is_recursive_call_inudf = True self.root.is_recursive_call_inudf = True
elif key in user_module_func.keys():
udf.try_init_udf(self.context)
# TODO: make udf_called a set! # TODO: make udf_called a set!
p = self.parent p = self.parent
while type(p) is expr and not p.udf_called: while type(p) is expr and not p.udf_called:
@ -201,7 +211,7 @@ class expr(ast_node):
if self.c_code and self.datasource is not None: if self.c_code and self.datasource is not None:
self.sql = f'{{y(\"{self.sql}\")}}' self.sql = f'{{y(\"{self.sql}\")}}'
elif type(node) is bool: elif type(node) is bool:
self.type = ByteT self.type = BoolT
if self.c_code: if self.c_code:
self.sql = '1' if node else '0' self.sql = '1' if node else '0'
else: else:

@ -136,20 +136,24 @@ class Context:
'#include \"./server/libaquery.h\"\n' '#include \"./server/libaquery.h\"\n'
'#include \"./server/aggregations.h\"\n\n' '#include \"./server/aggregations.h\"\n\n'
) )
def get_init_func(self): def get_init_func(self):
if self.module_map: if not self.module_map:
return '' return ''
ret = 'void init(Context* cxt){\n' ret = '__AQEXPORT__(void) __builtin_init_user_module(Context* cxt){\n'
for fname in self.module_map.keys(): for fname in self.module_map.keys():
ret += f'{fname} = (decltype({fname}))(cxt->get_module_function("{fname}"));\n' ret += f'{fname} = (decltype({fname}))(cxt->get_module_function("{fname}"));\n'
self.queries.insert(0, f'P__builtin_init_user_module')
return ret + '}\n' return ret + '}\n'
def sql_begin(self): def sql_begin(self):
self.sql = '' self.sql = ''
def sql_end(self): def sql_end(self):
if self.sql.strip():
self.queries.append('Q' + self.sql) self.queries.append('Q' + self.sql)
self.sql = '' self.sql = ''
def postproc_begin(self, proc_name: str): def postproc_begin(self, proc_name: str):
self.ccode = self.function_deco + proc_name + self.function_head self.ccode = self.function_deco + proc_name + self.function_head
@ -158,6 +162,16 @@ class Context:
self.ccode = '' self.ccode = ''
self.queries.append('P' + proc_name) self.queries.append('P' + proc_name)
def finalize_udf(self):
if self.udf is not None:
return (Context.udf_head
+ self.module_stubs
+ self.get_init_func()
+ self.udf
)
else:
return None
def finalize(self): def finalize(self):
if not self.finalized: if not self.finalized:
headers = '' headers = ''

@ -1,16 +1,77 @@
#include "../server/libaquery.h" #ifndef _AQUERY_H
#define _AQUERY_H
typedef void (*dealloctor_t) (void*); enum Log_level {
LOG_INFO,
LOG_ERROR,
LOG_SILENT
};
extern void* Aalloc(size_t sz); enum Backend_Type {
extern void Afree(void * mem); BACKEND_AQuery,
extern size_t register_memory(void* ptr, dealloctor_t deallocator); BACKEND_MonetDB,
BACKEND_MariaDB
};
struct Config{
int running, new_query, server_mode,
backend_type, has_dll, n_buffers;
int buffer_sizes[];
};
struct Session{ struct Session{
struct Statistic{ struct Statistic{
size_t total_active; unsigned long long total_active;
size_t cnt_object; unsigned long long cnt_object;
size_t total_alloc; unsigned long long total_alloc;
}; };
void* memory_map; void* memory_map;
}; };
struct Context{
typedef int (*printf_type) (const char *format, ...);
void* module_function_maps = 0;
Config* cfg;
int n_buffers, *sz_bufs;
void **buffers;
void* alt_server;
Log_level log_level = LOG_INFO;
Session current;
#ifdef THREADING
void* thread_pool;
#endif
printf_type print = printf;
template <class ...Types>
void log(Types... args) {
if (log_level == LOG_INFO)
print(args...);
}
template <class ...Types>
void err(Types... args) {
if (log_level <= LOG_ERROR)
print(args...);
}
void init_session();
void end_session();
void* get_module_function(const char*);
char remainder[];
};
#ifdef _WIN32
#define __DLLEXPORT__ __declspec(dllexport) __stdcall
#else
#define __DLLEXPORT__
#endif
#define __AQEXPORT__(_Ty) extern "C" _Ty __DLLEXPORT__
typedef void (*dealloctor_t) (void*);
extern void* Aalloc(size_t sz);
extern void Afree(void * mem);
extern size_t register_memory(void* ptr, dealloctor_t deallocator);
#endif

@ -32,8 +32,7 @@ struct Session{
struct Context{ struct Context{
typedef int (*printf_type) (const char *format, ...); typedef int (*printf_type) (const char *format, ...);
std::unordered_map<const char*, void*> tables;
std::unordered_map<const char*, uColRef *> cols;
void* module_function_maps = 0; void* module_function_maps = 0;
Config* cfg; Config* cfg;
@ -63,6 +62,8 @@ struct Context{
void init_session(); void init_session();
void end_session(); void end_session();
void* get_module_function(const char*); void* get_module_function(const char*);
std::unordered_map<const char*, void*> tables;
std::unordered_map<const char*, uColRef *> cols;
}; };
#ifdef _WIN32 #ifdef _WIN32

@ -248,6 +248,7 @@ int test_main()
cxt->alt_server = new Server(cxt); cxt->alt_server = new Server(cxt);
Server* server = reinterpret_cast<Server*>(cxt->alt_server); Server* server = reinterpret_cast<Server*>(cxt->alt_server);
const char* qs[]= { const char* qs[]= {
"SELECT MIN(3)-MAX(2);",
"CREATE TABLE stocks(timestamp INT, price INT);", "CREATE TABLE stocks(timestamp INT, price INT);",
"INSERT INTO stocks VALUES(1, 15);;", "INSERT INTO stocks VALUES(1, 15);;",
"INSERT INTO stocks VALUES(2,19); ", "INSERT INTO stocks VALUES(2,19); ",

@ -1,5 +1,8 @@
LOAD MODULE FROM "test.so" LOAD MODULE FROM "test.so"
FUNCTIONS ( FUNCTIONS (
div(a:int, b:int) -> double, mydiv(a:int, b:int) -> double,
mulvec(a:int, b:vecfloat) -> vecfloat mulvec(a:int, b:vecfloat) -> vecfloat
); );
select mydiv(2,3);

@ -18,15 +18,15 @@ INSERT INTO stocks VALUES(15,2)
INSERT INTO stocks VALUES(16,5) INSERT INTO stocks VALUES(16,5)
/*<k> "q1" </k>*/ /*<k> "q1" </k>*/
-- SELECT max(price-min(timestamp)) FROM stocks SELECT max(price-min(timestamp)) FROM stocks
/*<k> "q2" </k>*/ /*<k> "q2" </k>*/
-- SELECT max(price-mins(price)) FROM stocks SELECT max(price-mins(price)) FROM stocks
/*<k> "q3"</k>*/ /*<k> "q3"</k>*/
SELECT price, timestamp FROM stocks where price - timestamp > 1 and not (price*timestamp<100) SELECT price, timestamp FROM stocks where price - timestamp > 1 and not (price*timestamp<100)
/*<k> "q4"</k>*/ /*<k> "q4"</k>*/
-- SELECT max(price-mins(price)) SELECT max(price-mins(price))
-- FROM stocks FROM stocks
-- ASSUMING DESC timestamp ASSUMING DESC timestamp

Loading…
Cancel
Save