From acc610280ebc2bba7261556f390f93cfa31b2c47 Mon Sep 17 00:00:00 2001 From: Bill Date: Wed, 15 Feb 2023 22:17:14 +0800 Subject: [PATCH] triggers --- aquery_parser/parser.py | 26 +++++++++++--- engine/utils.py | 34 +++++++++++++++++- prompt.py | 22 ++++++++---- reconstruct/__init__.py | 3 +- reconstruct/ast.py | 77 ++++++++++++++++++++++++++++++++++++++--- reconstruct/storage.py | 22 +++++++++--- server/server.cpp | 16 ++++++++- 7 files changed, 178 insertions(+), 22 deletions(-) diff --git a/aquery_parser/parser.py b/aquery_parser/parser.py index 819a18a..711344c 100644 --- a/aquery_parser/parser.py +++ b/aquery_parser/parser.py @@ -33,6 +33,7 @@ def common_parser(): return parser(ansi_string | aquery_doublequote_string, combined_ident) + def parser(literal_string, ident): with Whitespace() as engine: engine.add_ignore(Literal("--") + restOfLine) @@ -569,8 +570,26 @@ def parser(literal_string, ident): + index_type + index_column_names + index_options - )("create index") + )("create_index") + create_trigger = ( + keyword("create trigger") + + var_name("name") + + (( + ON + + var_name("table") + + keyword("action") + + var_name("action") + + WHEN + + var_name("query") ) + | ( + keyword("action") + + var_name("action") + + INTERVAL + + int_num("interval") + )) + )("create_trigger") + cache_options = Optional(( keyword("options").suppress() + LB @@ -693,7 +712,7 @@ def parser(literal_string, ident): sql_stmts = delimited_list( ( query | (insert | update | delete | load) - | (create_table | create_view | create_cache | create_index) + | (create_table | create_view | create_cache | create_index | create_trigger) | (drop_table | drop_view | drop_index) )("stmts"), ";") @@ -707,6 +726,5 @@ def parser(literal_string, ident): |other_stmt | keyword(";").suppress() # empty stmt ) - return stmts.finalize() - + \ No newline at end of file diff --git a/engine/utils.py b/engine/utils.py index 59b1309..00597f9 100644 --- a/engine/utils.py +++ b/engine/utils.py @@ -8,7 +8,7 @@ nums = '0123456789' base62alp = nums + lower_alp + upper_alp reserved_monet = ['month'] - +session_context = None class CaseInsensitiveDict(MutableMapping): def __init__(self, data=None, **kwargs): @@ -158,3 +158,35 @@ def get_innermost(sl): return get_innermost(sl[0]) else: return sl + + +def send_to_server(payload : str): + from prompt import PromptState + cxt : PromptState = session_context + if cxt is None: + raise RuntimeError("Error! no session specified.") + else: + from ctypes import c_char_p + cxt.payload = (c_char_p*1)(c_char_p(bytes(payload, 'utf-8'))) + cxt.cfg.has_dll = 0 + cxt.send(1, cxt.payload) + cxt.set_ready() + +def get_storedproc(name : str): + from prompt import PromptState, StoredProcedure + cxt : PromptState = session_context + if cxt is None: + raise RuntimeError("Error! no session specified.") + else: + ret : StoredProcedure = cxt.get_storedproc(bytes(name, 'utf-8')) + if ( + ret.name.value and + ret.name.value.decode('utf-8') != name + ): + print(f'Procedure {name} mismatch in server {ret.name.value}') + return None + else: + return ret + +def execute_procedure(proc): + pass diff --git a/prompt.py b/prompt.py index 4e72ceb..e7fefa9 100644 --- a/prompt.py +++ b/prompt.py @@ -119,6 +119,15 @@ class Backend_Type(enum.Enum): BACKEND_MonetDB = 1 BACKEND_MariaDB = 2 +class StoredProcedure(ctypes.Structure): + _fields_ = [ + ('cnt', ctypes.c_uint32), + ('postproc_modules', ctypes.c_uint32), + ('queries', ctypes.POINTER(ctypes.c_char_p)), + ('name', ctypes.c_char_p), + ('__rt_loaded_modules', ctypes.POINTER(ctypes.c_void_p)), + ] + @dataclass class QueryStats: last_time : int = time.time() @@ -229,6 +238,7 @@ class PromptState(): server_bin = 'server.bin' if server_mode == RunType.IPC else 'server.so' wait_engine = lambda: None wake_engine = lambda: None + get_storedproc = lambda : StoredProcedure() set_ready = lambda: None get_ready = lambda: None server_status = lambda: False @@ -322,6 +332,8 @@ def init_threaded(state : PromptState): state.send = server_so['receive_args'] state.wait_engine = server_so['wait_engine'] state.wake_engine = server_so['wake_engine'] + state.get_storedproc = server_so['get_procedure'] + state.get_storedproc.restype = StoredProcedure aquery_config.have_hge = server_so['have_hge']() if aquery_config.have_hge != 0: from engine.types import get_int128_support @@ -330,9 +342,11 @@ def init_threaded(state : PromptState): state.th.start() def init_prompt() -> PromptState: + from engine.utils import session_context aquery_config.init_config() state = PromptState() + session_context = state # if aquery_config.rebuild_backend: # try: # os.remove(state.server_bin) @@ -454,7 +468,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr continue elif q.startswith('xexec') or q.startswith('exec'): # generate build and run (MonetDB Engine) state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value - cxt = xengine.exec(state.stmts, cxt, keep) + cxt = xengine.exec(state.stmts, cxt, keep, parser.parse) this_udf = cxt.finalize_udf() if this_udf: @@ -613,11 +627,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr elif q.startswith('procedure'): qs = re.split(r'[ \t\r\n]', q) procedure_help = '''Usage: procedure [record|stop|run|remove|save|load]''' - def send_to_server(payload : str): - state.payload = (ctypes.c_char_p*1)(ctypes.c_char_p(bytes(payload, 'utf-8'))) - state.cfg.has_dll = 0 - state.send(1, state.payload) - state.set_ready() + from engine.utils import send_to_server if len(qs) > 2: if qs[2].lower() =='record': if state.current_procedure is not None and state.current_procedure != qs[1]: diff --git a/reconstruct/__init__.py b/reconstruct/__init__.py index 0bbba8c..4247989 100644 --- a/reconstruct/__init__.py +++ b/reconstruct/__init__.py @@ -18,10 +18,11 @@ def generate(ast, cxt): if k in ast_node.types.keys(): ast_node.types[k](None, ast, cxt) -def exec(stmts, cxt = None, keep = False): +def exec(stmts, cxt = None, keep = False, parser = None): if 'stmts' not in stmts: return cxt = initialize(cxt, keep) + cxt.parser = parser stmts_stmts = stmts['stmts'] if type(stmts_stmts) is list: for s in stmts_stmts: diff --git a/reconstruct/ast.py b/reconstruct/ast.py index e1430f3..aa0c609 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -1081,7 +1081,62 @@ class create_table(ast_node): self.sql += ')' if self.context.use_columnstore: self.sql += ' engine=ColumnStore' - + +class create_trigger(ast_node): + name = 'create_trigger' + first_order = name + class Type (Enum): + Interval = auto() + Callback = auto() + + def produce(self, node): + from engine.utils import send_to_server, get_storedproc + node = node['create_trigger'] + self.trigger_name = node['name'] + self.action_name = node['action'] + self.action = get_storedproc(self.action_name) + if self.trigger_name in self.context.triggers: + raise ValueError(f'trigger {self.trigger_name} exists') + elif self.action: + raise ValueError(f'Stored Procedure {self.action_name} do not exist') + + if 'interval' in node: # executed periodically from server + self.type = self.Type.Interval + self.interval = node['interval'] + send_to_server(f'TI{self.trigger_name}{self.action_name}{self.interval}') + else: # executed from sql backend + self.type = self.Type.Callback + self.query_name = node['query'] + self.table_name = node['table'] + self.procedure = get_storedproc(self.query_name) + if self.procedure and self.table_name in self.context.tables_byname: + self.table = self.context.tables_byname[self.table_name] + self.table.triggers.add(self) + else: + return + self.context.triggers[self.trigger_name] = self + + # manually execute trigger + def register(self): + if self.type != self.Type.Callback: + self.context.triggers.pop(self.trigger_name) + raise ValueError(f'Trigger {self.trigger_name} is not a callback based trigger') + self.context.triggers_active.add(self) + + def execute(self): + from engine.utils import send_to_server + send_to_server(f'TC{self.query_name}{self.action_name}') + + def remove(self): + from engine.utils import send_to_server + send_to_server(f'TR{self.trigger_name}') + +class drop_trigger(ast_node): + name = 'create_trigger' + first_order = name + def produce(self, node): + ... + class drop(ast_node): name = 'drop' first_order = name @@ -1111,9 +1166,11 @@ class insert(ast_node): complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit'] if any([kw in values for kw in complex_query_kw]): values['into'] = node['insert'] - proj_cls = (select_distinct - if 'select_distinct' in values - else projection) + proj_cls = ( + select_distinct + if 'select_distinct' in values + else projection + ) proj_cls(None, values, self.context) self.produce = lambda*_:None self.spawn = lambda*_:None @@ -1147,6 +1204,11 @@ class insert(ast_node): keys = f'({", ".join(keys)})' if keys else '' tbl = node['insert'] + if tbl not in self.context.tables_byname: + print('Warning: {tbl} not registered in aquery compiler.') + tbl_obj = self.context.tables_byname[tbl] + for t in tbl_obj.triggers: + t.register() self.sql = f'INSERT INTO {tbl}{keys} VALUES' # if len(values) != table.n_cols: # raise ValueError("Column Mismatch") @@ -1161,7 +1223,7 @@ class insert(ast_node): list_values.append(f"({', '.join(inner_list_values)})") self.sql += ', '.join(list_values) - + class delete_from(ast_node): name = 'delete' @@ -1624,6 +1686,11 @@ class passthru_sql(ast_node): seprator = re.compile(r'''((?:[^;"']|"[^"]*"|'[^']*')+)''') def __init__(self, _, node, context:Context): sqls = passthru_sql.seprator.split(node['sql']) + try: + if callable(context.parser): + parsed = context.parser(node['sql']) + except BaseException: + parsed = None for sql in sqls: sq = sql.strip(' \t\n\r;') if sq: diff --git a/reconstruct/storage.py b/reconstruct/storage.py index d507342..974586a 100644 --- a/reconstruct/storage.py +++ b/reconstruct/storage.py @@ -64,12 +64,14 @@ class ColRef: class TableInfo: def __init__(self, table_name, cols, cxt:'Context'): + from reconstruct.ast import create_trigger # statics self.table_name : str = table_name self.contextname_cpp : str = '' self.alias : Set[str] = set([table_name]) self.columns_byname : CaseInsensitiveDict[str, ColRef] = CaseInsensitiveDict() # column_name, type self.columns : List[ColRef] = [] + self.triggers : Set[create_trigger] = set() self.cxt = cxt # keep track of temp vars self.rec = None @@ -83,7 +85,7 @@ class TableInfo: def add_cols(self, cols, new = True): for c in enlist(cols): self.add_col(c, new) - + def add_col(self, c, new = True): _ty = c['type'] _ty_args = None @@ -156,9 +158,11 @@ class Context: self.module_init_loc = 0 self.special_gb = False self.has_dll = False - + self.triggers_active.clear() + def __init__(self): - self.tables_byname = dict() + from .ast import create_trigger + self.tables_byname : Dict[str, TableInfo] = dict() self.col_byname = dict() self.tables : Set[TableInfo] = set() self.cols = [] @@ -174,6 +178,9 @@ class Context: self.have_hge = False self.Error = lambda *args: print(*args) self.Info = lambda *_: None + self.triggers : Dict[str, create_trigger] = dict() + self.triggers_active = set() + self.stored_proceudres = dict() # self.new() called everytime new query batch is started def get_scan_var(self): @@ -256,7 +263,14 @@ class Context: limit = limit.to_bytes(4, 'little').decode('latin-1') self.queries.append( 'O' + limit + sep + end) - + + def remove_trigger(self, name : str): + from reconstruct.ast import create_trigger + val = self.triggers.pop(name, None) + if val.type == create_trigger.Type.Callback: + val.table.triggers.remove(val) + val.remove() + def abandon_postproc(self): self.ccode = '' self.finalize_query() diff --git a/server/server.cpp b/server/server.cpp index 4438930..0966bcd 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -86,7 +86,8 @@ extern "C" int __DLLEXPORT__ binary_info() { #endif } -__AQEXPORT__(bool) have_hge(){ +__AQEXPORT__(bool) +have_hge() { #if defined(__MONETDB_CONN_H__) return Server::havehge(); #else @@ -94,6 +95,19 @@ __AQEXPORT__(bool) have_hge(){ #endif } +__AQEXPORT__(StoredProcedure) +get_procedure(Context* cxt, const char* name) { + auto res = cxt->stored_proc.find(name); + if (res == cxt->stored_proc.end()) + return { .cnt = 0, + .postproc_modules = 0, + .queries = nullptr, + .name = nullptr, + .__rt_loaded_modules = nullptr + }; + return res->second; +} + using prt_fn_t = char* (*)(void*, char*); // This function contains heap allocations, free after use