master
Bill 2 years ago
parent c5bf4c46e4
commit acc610280e

@ -33,6 +33,7 @@ def common_parser():
return parser(ansi_string | aquery_doublequote_string, combined_ident) return parser(ansi_string | aquery_doublequote_string, combined_ident)
def parser(literal_string, ident): def parser(literal_string, ident):
with Whitespace() as engine: with Whitespace() as engine:
engine.add_ignore(Literal("--") + restOfLine) engine.add_ignore(Literal("--") + restOfLine)
@ -569,7 +570,25 @@ def parser(literal_string, ident):
+ index_type + index_type
+ index_column_names + index_column_names
+ index_options + index_options
)("create index") )("create_index")
create_trigger = (
keyword("create trigger")
+ var_name("name")
+ ((
ON
+ var_name("table")
+ keyword("action")
+ var_name("action")
+ WHEN
+ var_name("query") )
| (
keyword("action")
+ var_name("action")
+ INTERVAL
+ int_num("interval")
))
)("create_trigger")
cache_options = Optional(( cache_options = Optional((
keyword("options").suppress() keyword("options").suppress()
@ -693,7 +712,7 @@ def parser(literal_string, ident):
sql_stmts = delimited_list( ( sql_stmts = delimited_list( (
query query
| (insert | update | delete | load) | (insert | update | delete | load)
| (create_table | create_view | create_cache | create_index) | (create_table | create_view | create_cache | create_index | create_trigger)
| (drop_table | drop_view | drop_index) | (drop_table | drop_view | drop_index)
)("stmts"), ";") )("stmts"), ";")
@ -707,6 +726,5 @@ def parser(literal_string, ident):
|other_stmt |other_stmt
| keyword(";").suppress() # empty stmt | keyword(";").suppress() # empty stmt
) )
return stmts.finalize() return stmts.finalize()

@ -8,7 +8,7 @@ nums = '0123456789'
base62alp = nums + lower_alp + upper_alp base62alp = nums + lower_alp + upper_alp
reserved_monet = ['month'] reserved_monet = ['month']
session_context = None
class CaseInsensitiveDict(MutableMapping): class CaseInsensitiveDict(MutableMapping):
def __init__(self, data=None, **kwargs): def __init__(self, data=None, **kwargs):
@ -158,3 +158,35 @@ def get_innermost(sl):
return get_innermost(sl[0]) return get_innermost(sl[0])
else: else:
return sl return sl
def send_to_server(payload : str):
from prompt import PromptState
cxt : PromptState = session_context
if cxt is None:
raise RuntimeError("Error! no session specified.")
else:
from ctypes import c_char_p
cxt.payload = (c_char_p*1)(c_char_p(bytes(payload, 'utf-8')))
cxt.cfg.has_dll = 0
cxt.send(1, cxt.payload)
cxt.set_ready()
def get_storedproc(name : str):
from prompt import PromptState, StoredProcedure
cxt : PromptState = session_context
if cxt is None:
raise RuntimeError("Error! no session specified.")
else:
ret : StoredProcedure = cxt.get_storedproc(bytes(name, 'utf-8'))
if (
ret.name.value and
ret.name.value.decode('utf-8') != name
):
print(f'Procedure {name} mismatch in server {ret.name.value}')
return None
else:
return ret
def execute_procedure(proc):
pass

@ -119,6 +119,15 @@ class Backend_Type(enum.Enum):
BACKEND_MonetDB = 1 BACKEND_MonetDB = 1
BACKEND_MariaDB = 2 BACKEND_MariaDB = 2
class StoredProcedure(ctypes.Structure):
_fields_ = [
('cnt', ctypes.c_uint32),
('postproc_modules', ctypes.c_uint32),
('queries', ctypes.POINTER(ctypes.c_char_p)),
('name', ctypes.c_char_p),
('__rt_loaded_modules', ctypes.POINTER(ctypes.c_void_p)),
]
@dataclass @dataclass
class QueryStats: class QueryStats:
last_time : int = time.time() last_time : int = time.time()
@ -229,6 +238,7 @@ class PromptState():
server_bin = 'server.bin' if server_mode == RunType.IPC else 'server.so' server_bin = 'server.bin' if server_mode == RunType.IPC else 'server.so'
wait_engine = lambda: None wait_engine = lambda: None
wake_engine = lambda: None wake_engine = lambda: None
get_storedproc = lambda : StoredProcedure()
set_ready = lambda: None set_ready = lambda: None
get_ready = lambda: None get_ready = lambda: None
server_status = lambda: False server_status = lambda: False
@ -322,6 +332,8 @@ def init_threaded(state : PromptState):
state.send = server_so['receive_args'] state.send = server_so['receive_args']
state.wait_engine = server_so['wait_engine'] state.wait_engine = server_so['wait_engine']
state.wake_engine = server_so['wake_engine'] state.wake_engine = server_so['wake_engine']
state.get_storedproc = server_so['get_procedure']
state.get_storedproc.restype = StoredProcedure
aquery_config.have_hge = server_so['have_hge']() aquery_config.have_hge = server_so['have_hge']()
if aquery_config.have_hge != 0: if aquery_config.have_hge != 0:
from engine.types import get_int128_support from engine.types import get_int128_support
@ -330,9 +342,11 @@ def init_threaded(state : PromptState):
state.th.start() state.th.start()
def init_prompt() -> PromptState: def init_prompt() -> PromptState:
from engine.utils import session_context
aquery_config.init_config() aquery_config.init_config()
state = PromptState() state = PromptState()
session_context = state
# if aquery_config.rebuild_backend: # if aquery_config.rebuild_backend:
# try: # try:
# os.remove(state.server_bin) # os.remove(state.server_bin)
@ -454,7 +468,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
continue continue
elif q.startswith('xexec') or q.startswith('exec'): # generate build and run (MonetDB Engine) elif q.startswith('xexec') or q.startswith('exec'): # generate build and run (MonetDB Engine)
state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value
cxt = xengine.exec(state.stmts, cxt, keep) cxt = xengine.exec(state.stmts, cxt, keep, parser.parse)
this_udf = cxt.finalize_udf() this_udf = cxt.finalize_udf()
if this_udf: if this_udf:
@ -613,11 +627,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
elif q.startswith('procedure'): elif q.startswith('procedure'):
qs = re.split(r'[ \t\r\n]', q) qs = re.split(r'[ \t\r\n]', q)
procedure_help = '''Usage: procedure <procedure_name> [record|stop|run|remove|save|load]''' procedure_help = '''Usage: procedure <procedure_name> [record|stop|run|remove|save|load]'''
def send_to_server(payload : str): from engine.utils import send_to_server
state.payload = (ctypes.c_char_p*1)(ctypes.c_char_p(bytes(payload, 'utf-8')))
state.cfg.has_dll = 0
state.send(1, state.payload)
state.set_ready()
if len(qs) > 2: if len(qs) > 2:
if qs[2].lower() =='record': if qs[2].lower() =='record':
if state.current_procedure is not None and state.current_procedure != qs[1]: if state.current_procedure is not None and state.current_procedure != qs[1]:

@ -18,10 +18,11 @@ def generate(ast, cxt):
if k in ast_node.types.keys(): if k in ast_node.types.keys():
ast_node.types[k](None, ast, cxt) ast_node.types[k](None, ast, cxt)
def exec(stmts, cxt = None, keep = False): def exec(stmts, cxt = None, keep = False, parser = None):
if 'stmts' not in stmts: if 'stmts' not in stmts:
return return
cxt = initialize(cxt, keep) cxt = initialize(cxt, keep)
cxt.parser = parser
stmts_stmts = stmts['stmts'] stmts_stmts = stmts['stmts']
if type(stmts_stmts) is list: if type(stmts_stmts) is list:
for s in stmts_stmts: for s in stmts_stmts:

@ -1082,6 +1082,61 @@ class create_table(ast_node):
if self.context.use_columnstore: if self.context.use_columnstore:
self.sql += ' engine=ColumnStore' self.sql += ' engine=ColumnStore'
class create_trigger(ast_node):
name = 'create_trigger'
first_order = name
class Type (Enum):
Interval = auto()
Callback = auto()
def produce(self, node):
from engine.utils import send_to_server, get_storedproc
node = node['create_trigger']
self.trigger_name = node['name']
self.action_name = node['action']
self.action = get_storedproc(self.action_name)
if self.trigger_name in self.context.triggers:
raise ValueError(f'trigger {self.trigger_name} exists')
elif self.action:
raise ValueError(f'Stored Procedure {self.action_name} do not exist')
if 'interval' in node: # executed periodically from server
self.type = self.Type.Interval
self.interval = node['interval']
send_to_server(f'TI{self.trigger_name}{self.action_name}{self.interval}')
else: # executed from sql backend
self.type = self.Type.Callback
self.query_name = node['query']
self.table_name = node['table']
self.procedure = get_storedproc(self.query_name)
if self.procedure and self.table_name in self.context.tables_byname:
self.table = self.context.tables_byname[self.table_name]
self.table.triggers.add(self)
else:
return
self.context.triggers[self.trigger_name] = self
# manually execute trigger
def register(self):
if self.type != self.Type.Callback:
self.context.triggers.pop(self.trigger_name)
raise ValueError(f'Trigger {self.trigger_name} is not a callback based trigger')
self.context.triggers_active.add(self)
def execute(self):
from engine.utils import send_to_server
send_to_server(f'TC{self.query_name}{self.action_name}')
def remove(self):
from engine.utils import send_to_server
send_to_server(f'TR{self.trigger_name}')
class drop_trigger(ast_node):
name = 'create_trigger'
first_order = name
def produce(self, node):
...
class drop(ast_node): class drop(ast_node):
name = 'drop' name = 'drop'
first_order = name first_order = name
@ -1111,9 +1166,11 @@ class insert(ast_node):
complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit'] complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
if any([kw in values for kw in complex_query_kw]): if any([kw in values for kw in complex_query_kw]):
values['into'] = node['insert'] values['into'] = node['insert']
proj_cls = (select_distinct proj_cls = (
if 'select_distinct' in values select_distinct
else projection) if 'select_distinct' in values
else projection
)
proj_cls(None, values, self.context) proj_cls(None, values, self.context)
self.produce = lambda*_:None self.produce = lambda*_:None
self.spawn = lambda*_:None self.spawn = lambda*_:None
@ -1147,6 +1204,11 @@ class insert(ast_node):
keys = f'({", ".join(keys)})' if keys else '' keys = f'({", ".join(keys)})' if keys else ''
tbl = node['insert'] tbl = node['insert']
if tbl not in self.context.tables_byname:
print('Warning: {tbl} not registered in aquery compiler.')
tbl_obj = self.context.tables_byname[tbl]
for t in tbl_obj.triggers:
t.register()
self.sql = f'INSERT INTO {tbl}{keys} VALUES' self.sql = f'INSERT INTO {tbl}{keys} VALUES'
# if len(values) != table.n_cols: # if len(values) != table.n_cols:
# raise ValueError("Column Mismatch") # raise ValueError("Column Mismatch")
@ -1624,6 +1686,11 @@ class passthru_sql(ast_node):
seprator = re.compile(r'''((?:[^;"']|"[^"]*"|'[^']*')+)''') seprator = re.compile(r'''((?:[^;"']|"[^"]*"|'[^']*')+)''')
def __init__(self, _, node, context:Context): def __init__(self, _, node, context:Context):
sqls = passthru_sql.seprator.split(node['sql']) sqls = passthru_sql.seprator.split(node['sql'])
try:
if callable(context.parser):
parsed = context.parser(node['sql'])
except BaseException:
parsed = None
for sql in sqls: for sql in sqls:
sq = sql.strip(' \t\n\r;') sq = sql.strip(' \t\n\r;')
if sq: if sq:

@ -64,12 +64,14 @@ class ColRef:
class TableInfo: class TableInfo:
def __init__(self, table_name, cols, cxt:'Context'): def __init__(self, table_name, cols, cxt:'Context'):
from reconstruct.ast import create_trigger
# statics # statics
self.table_name : str = table_name self.table_name : str = table_name
self.contextname_cpp : str = '' self.contextname_cpp : str = ''
self.alias : Set[str] = set([table_name]) self.alias : Set[str] = set([table_name])
self.columns_byname : CaseInsensitiveDict[str, ColRef] = CaseInsensitiveDict() # column_name, type self.columns_byname : CaseInsensitiveDict[str, ColRef] = CaseInsensitiveDict() # column_name, type
self.columns : List[ColRef] = [] self.columns : List[ColRef] = []
self.triggers : Set[create_trigger] = set()
self.cxt = cxt self.cxt = cxt
# keep track of temp vars # keep track of temp vars
self.rec = None self.rec = None
@ -156,9 +158,11 @@ class Context:
self.module_init_loc = 0 self.module_init_loc = 0
self.special_gb = False self.special_gb = False
self.has_dll = False self.has_dll = False
self.triggers_active.clear()
def __init__(self): def __init__(self):
self.tables_byname = dict() from .ast import create_trigger
self.tables_byname : Dict[str, TableInfo] = dict()
self.col_byname = dict() self.col_byname = dict()
self.tables : Set[TableInfo] = set() self.tables : Set[TableInfo] = set()
self.cols = [] self.cols = []
@ -174,6 +178,9 @@ class Context:
self.have_hge = False self.have_hge = False
self.Error = lambda *args: print(*args) self.Error = lambda *args: print(*args)
self.Info = lambda *_: None self.Info = lambda *_: None
self.triggers : Dict[str, create_trigger] = dict()
self.triggers_active = set()
self.stored_proceudres = dict()
# self.new() called everytime new query batch is started # self.new() called everytime new query batch is started
def get_scan_var(self): def get_scan_var(self):
@ -257,6 +264,13 @@ class Context:
self.queries.append( self.queries.append(
'O' + limit + sep + end) 'O' + limit + sep + end)
def remove_trigger(self, name : str):
from reconstruct.ast import create_trigger
val = self.triggers.pop(name, None)
if val.type == create_trigger.Type.Callback:
val.table.triggers.remove(val)
val.remove()
def abandon_postproc(self): def abandon_postproc(self):
self.ccode = '' self.ccode = ''
self.finalize_query() self.finalize_query()

@ -86,7 +86,8 @@ extern "C" int __DLLEXPORT__ binary_info() {
#endif #endif
} }
__AQEXPORT__(bool) have_hge(){ __AQEXPORT__(bool)
have_hge() {
#if defined(__MONETDB_CONN_H__) #if defined(__MONETDB_CONN_H__)
return Server::havehge(); return Server::havehge();
#else #else
@ -94,6 +95,19 @@ __AQEXPORT__(bool) have_hge(){
#endif #endif
} }
__AQEXPORT__(StoredProcedure)
get_procedure(Context* cxt, const char* name) {
auto res = cxt->stored_proc.find(name);
if (res == cxt->stored_proc.end())
return { .cnt = 0,
.postproc_modules = 0,
.queries = nullptr,
.name = nullptr,
.__rt_loaded_modules = nullptr
};
return res->second;
}
using prt_fn_t = char* (*)(void*, char*); using prt_fn_t = char* (*)(void*, char*);
// This function contains heap allocations, free after use // This function contains heap allocations, free after use

Loading…
Cancel
Save