diff --git a/.gitignore b/.gitignore index 508685f..739eae5 100644 --- a/.gitignore +++ b/.gitignore @@ -81,5 +81,5 @@ saves out*.cpp udf*.hpp *.ipynb - - +saved_procedures/** +procedures/** diff --git a/aquery_config.py b/aquery_config.py index caa4faa..094bc47 100644 --- a/aquery_config.py +++ b/aquery_config.py @@ -2,7 +2,7 @@ ## GLOBAL CONFIGURATION FLAGS -version_string = '0.5.3a' +version_string = '0.5.4a' add_path_to_ldpath = True rebuild_backend = False run_backend = True diff --git a/prompt.py b/prompt.py index 33240eb..3afb22c 100644 --- a/prompt.py +++ b/prompt.py @@ -246,6 +246,7 @@ class PromptState(): stats : Optional[QueryStats] = None currstats : Optional[QueryStats] = None buildmgr : Optional[build_manager]= None + current_procedure : Optional[str] = None ## CLASSES END ## FUNCTIONS BEGIN @@ -392,7 +393,7 @@ def save(q:str, cxt: xengine.Context): savefile('udf', 'udf', '.hpp') savefile('sql', 'sql') -def prompt(running = lambda:True, next = lambda:input('> '), state = None): +def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[PromptState] = None): if state is None: state = init_prompt() q = '' @@ -609,6 +610,37 @@ def prompt(running = lambda:True, next = lambda:input('> '), state = None): state.stats.need_print = True state.stats.print(clear = False) continue + elif q.startswith('procedure'): + # Syntax: procedure NAME ACTION -- qs[1] is the procedure name, qs[2] the action. + qs = q.split() + procedure_help = '''Usage: procedure NAME [record|stop|run|remove|save|load]''' + send_to_server = lambda msg: state.send(1, ctypes.c_char_p(bytes(msg, 'utf-8'))) + if len(qs) > 2: + if qs[2].lower() =='record': + if state.current_procedure and state.current_procedure != qs[1]: + print(f'Cannot record 2 procedures at the same time. 
Stop recording {state.current_procedure} first.') + elif not state.current_procedure: + state.current_procedure = qs[1] + send_to_server(f'R\0{qs[1]}') + elif qs[2].lower() == 'stop': + send_to_server(f'RT\0{qs[1]}') + # Recording has ended, so run/remove/save/load are allowed again. + state.current_procedure = None + else: + if state.current_procedure: + print(f'Procedure manipulation commands are disallowed during procedure recording.') + continue + if qs[2].lower() == 'run': + send_to_server(f'RE\0{qs[1]}') + elif qs[2].lower() == 'remove': + send_to_server(f'RD\0{qs[1]}') + elif qs[2].lower() == 'save': + send_to_server(f'RS\0{qs[1]}') + elif qs[2].lower() == 'load': + send_to_server(f'RL\0{qs[1]}') + else: + print(procedure_help) + continue trimed = ws.sub(' ', og_q).split(' ') if len(trimed) > 1 and trimed[0].lower().startswith('fi') or trimed[0].lower() == 'f': fn = 'stock.a' if len(trimed) <= 1 or len(trimed[1]) == 0 \ diff --git a/reconstruct/ast.py b/reconstruct/ast.py index e376487..07efd77 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -246,7 +246,7 @@ class projection(ast_node): self.datasource.rec = None # TODO: Type deduction in Python for t, n, c in zip(this_type, disp_name, compound): - cols.append(ColRef(t, self.out_table, None, n, len(cols), compound=c)) + cols.append(ColRef(t, None, self.out_table, n, len(cols), compound=c)) self.out_table.add_cols(cols, new = False) @@ -409,13 +409,14 @@ class projection(ast_node): if self.outfile and self.has_postproc: self.outfile.finalize() - if 'into' in node: - self.context.emitc(select_into(self, node['into']).ccode) - self.has_postproc = True if not self.distinct: - self.finalize() - - def finalize(self): + self.finalize(node) + + def deal_with_into(self, node): + if 'into' in node: + self.context.emitc(select_into(self, node['into']).ccode) + def finalize(self, node): + self.deal_with_into(node) self.context.emitc(f'puts("done.");') if self.parent is None: @@ -436,7 +437,7 @@ class select_distinct(projection): self.context.emitc( f'{self.out_table.contextname_cpp}->distinct();' ) 
- self.finalize() + self.finalize(node) class select_into(ast_node): def init(self, _): @@ -955,19 +956,31 @@ class union_all(ast_node): sql_name = 'UNION ALL' def produce(self, node): queries = node[self.name] - generated_queries : List[Optional[projection]] = [None] * len(queries) + self.generated_queries : List[Optional[projection]] = [None] * len(queries) is_standard = True for i, q in enumerate(queries): if 'select' in q: - generated_queries[i] = projection(self, q) - is_standard &= not generated_queries[i].has_postproc + self.generated_queries[i] = projection(self, q) + is_standard &= not self.generated_queries[i].has_postproc if is_standard: - self.sql = f' {self.sql_name} '.join([q.sql for q in generated_queries]) + self.sql = f' {self.sql_name} '.join([q.sql for q in self.generated_queries]) else: raise NotImplementedError(f"{self.sql_name} only support standard sql for now") def consume(self, node): + if 'into' in node: + outtable = TableInfo(node['into'], [], self.context) + lst_cols = [None] * len(self.generated_queries[0].out_table.columns) + for i, c in enumerate(self.generated_queries[0].out_table.columns): + lst_cols[i] = ColRef(c.type, None, outtable, c.name, i, compound=c.compound) + outtable.add_cols(lst_cols, new = False) + + col_names = [c.name for c in outtable.columns] + col_names = '(' + ', '.join(col_names) + ')' + self.sql = f'CREATE TABLE {node["into"]} {col_names} AS {self.sql}' super().consume(node) - self.context.direct_output() + if 'into' not in node: + self.context.direct_output() + class except_clause(union_all): name = 'except' diff --git a/reconstruct/storage.py b/reconstruct/storage.py index 840e9ac..4636bd5 100644 --- a/reconstruct/storage.py +++ b/reconstruct/storage.py @@ -91,7 +91,7 @@ class TableInfo: _ty_val = list(_ty.keys())[0] _ty_args = _ty[_ty_val] _ty = _ty_val - if new: + if new or not isinstance(c, ColRef): # entries that are not already ColRef objects get wrapped in a new ColRef col_object = ColRef(_ty, c, self, c['name'], len(self.columns), _ty_args = _ty_args) else: col_object = c diff --git 
a/server/libaquery.h b/server/libaquery.h index cc0b5a9..6bfc98c 100644 --- a/server/libaquery.h +++ b/server/libaquery.h @@ -9,6 +9,7 @@ #include #include +#include class aq_timer { private: std::chrono::high_resolution_clock::time_point now; @@ -66,6 +67,12 @@ struct Session{ void* memory_map; }; +struct StoredProcedure{ + uint32_t cnt = 0, postproc_modules = 0; + char **queries = nullptr; + const char* name = nullptr; + void **__rt_loaded_modules = nullptr; // stays null until the modules are loaded; the server tests it before loading +}; struct Context{ typedef int (*printf_type) (const char *format, ...); @@ -79,7 +86,7 @@ struct Context{ Log_level log_level = LOG_INFO; Session current; - + const char* aquery_root_path; #ifdef THREADING void* thread_pool; #endif @@ -104,6 +111,8 @@ struct Context{ void* get_module_function(const char*); std::unordered_map tables; std::unordered_map cols; + std::unordered_map loaded_modules; + std::unordered_map stored_proc; }; diff --git a/server/server.cpp b/server/server.cpp index 913ee4c..dd55597 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -7,6 +7,7 @@ #include "libaquery.h" #include "monetdb_conn.h" +#pragma region misc #ifdef THREADING #include "threading.h" #endif @@ -120,7 +121,6 @@ A_Semaphore prompt{ true }, engine{ false }; typedef int (*code_snippet)(void*); typedef void (*module_init_fn)(Context*); - int n_recv = 0; char** n_recvd = nullptr; @@ -265,11 +265,12 @@ void initialize_module(const char* module_name, void* module_handle, Context* cx printf("Warning: module %s have no session support.\n", module_name); } } - +#pragma endregion int dll_main(int argc, char** argv, Context* cxt){ aq_timer timer; Config *cfg = reinterpret_cast(argv[0]); std::unordered_map user_module_map; + std::string procedure_name = ""; if (cxt->module_function_maps == nullptr) cxt->module_function_maps = new std::unordered_map(); auto module_fn_map = @@ -279,7 +280,6 @@ int dll_main(int argc, char** argv, Context* cxt){ void** buffers = (void**) malloc (sizeof(void*) * cfg->n_buffers); for (int i = 0; i < cfg->n_buffers; i++) buffers[i] = 
static_cast(argv[i + 1]); - cxt->buffers = buffers; cxt->cfg = cfg; cxt->n_buffers = cfg->n_buffers; @@ -291,6 +291,7 @@ int dll_main(int argc, char** argv, Context* cxt){ puts(*(const char**)(alt_server->getCol(0))); cxt->alt_server = alt_server; } + bool rec = false; while(cfg->running){ ENGINE_ACQUIRE(); if (cfg->new_query) { @@ -334,7 +335,7 @@ int dll_main(int argc, char** argv, Context* cxt){ case 'M': // Load Module { auto mname = n_recvd[i] + 1; - user_module_handle = dlopen(mname, RTLD_LAZY); + user_module_handle = dlopen(mname, RTLD_NOW); //getlasterror if (!user_module_handle) @@ -360,7 +361,6 @@ int dll_main(int argc, char** argv, Context* cxt){ if(!server->haserror()){ uint32_t limit; memcpy(&limit, n_recvd[i] + 1, sizeof(uint32_t)); - // printf("Limit: %x\n", limit); if (limit == 0) continue; timer.reset(); @@ -379,6 +379,48 @@ int dll_main(int argc, char** argv, Context* cxt){ user_module_map.erase(it); } break; + case 'R': //recorded procedure + { + auto proc_name = n_recvd[i] + 1; + proc_name = *proc_name ? proc_name + 2 : proc_name + 1; // message is "R\0name" or "R<cmd>\0name": skip the optional sub-command byte and the NUL separator + const auto& load_modules = [](StoredProcedure &p){ + if (!p.__rt_loaded_modules){ + p.__rt_loaded_modules = static_cast( + malloc(sizeof(void*) * p.postproc_modules)); + for(uint32_t j = 0; j < p.postproc_modules; ++j){ + p.__rt_loaded_modules[j] = dlopen(p.name, RTLD_NOW); + } + } + }; + switch(n_recvd[i][1]){ + case '\0': + procedure_name = proc_name; + break; + case 'T': + procedure_name = ""; + break; + case 'E': // execute procedure + { + auto _proc = cxt->stored_proc.find(proc_name); // look up the name sent with this command, not the one being recorded + if (_proc == cxt->stored_proc.end()) + printf("Procedure %s not found.\n", proc_name); + else{ + StoredProcedure &p = _proc->second; + n_recv = p.cnt; + n_recvd = p.queries; + load_modules(p); + } + } + break; + case 'D': // delete procedure + break; + case 'S': //save procedure + break; + case 'L': //load procedure + break; + } + } + break; } } if(handle) { @@ -400,7 +442,7 @@ int dll_main(int argc, 
char** argv, Context* cxt){ // puts(cfg->has_dll ? "true" : "false"); if (cfg->backend_type == BACKEND_AQuery) { - handle = dlopen("./dll.so", RTLD_LAZY); + handle = dlopen("./dll.so", RTLD_NOW); code_snippet c = reinterpret_cast(dlsym(handle, "dllmain")); c(cxt); } @@ -444,6 +486,9 @@ extern "C" int __DLLEXPORT__ main(int argc, char** argv) { #endif // puts("running"); Context* cxt = new Context(); + // current_path() returns a temporary path object; keep its string in static storage so the stored pointer stays valid. + static const std::string aquery_root = std::filesystem::current_path().string(); + cxt->aquery_root_path = aquery_root.c_str(); // cxt->log("%d %s\n", argc, argv[1]); #ifdef THREADING @@ -472,7 +517,7 @@ extern "C" int __DLLEXPORT__ main(int argc, char** argv) { if(ready){ cxt->log("running: %s\n", running? "true":"false"); cxt->log("ready: %s\n", ready? "true":"false"); - void* handle = dlopen("./dll.so", RTLD_LAZY); + void* handle = dlopen("./dll.so", RTLD_NOW); cxt->log("handle: %p\n", handle); if (handle) { cxt->log("inner\n"); diff --git a/server/table.h b/server/table.h index c726017..a32705c 100644 --- a/server/table.h +++ b/server/table.h @@ -77,7 +77,6 @@ public: } template