master
Bill 2 years ago
parent c5bf4c46e4
commit acc610280e

@ -33,6 +33,7 @@ def common_parser():
return parser(ansi_string | aquery_doublequote_string, combined_ident)
def parser(literal_string, ident):
with Whitespace() as engine:
engine.add_ignore(Literal("--") + restOfLine)
@ -569,8 +570,26 @@ def parser(literal_string, ident):
+ index_type
+ index_column_names
+ index_options
)("create index")
)("create_index")
create_trigger = (
keyword("create trigger")
+ var_name("name")
+ ((
ON
+ var_name("table")
+ keyword("action")
+ var_name("action")
+ WHEN
+ var_name("query") )
| (
keyword("action")
+ var_name("action")
+ INTERVAL
+ int_num("interval")
))
)("create_trigger")
cache_options = Optional((
keyword("options").suppress()
+ LB
@ -693,7 +712,7 @@ def parser(literal_string, ident):
sql_stmts = delimited_list( (
query
| (insert | update | delete | load)
| (create_table | create_view | create_cache | create_index)
| (create_table | create_view | create_cache | create_index | create_trigger)
| (drop_table | drop_view | drop_index)
)("stmts"), ";")
@ -707,6 +726,5 @@ def parser(literal_string, ident):
|other_stmt
| keyword(";").suppress() # empty stmt
)
return stmts.finalize()

@ -8,7 +8,7 @@ nums = '0123456789'
base62alp = nums + lower_alp + upper_alp
reserved_monet = ['month']
session_context = None
class CaseInsensitiveDict(MutableMapping):
def __init__(self, data=None, **kwargs):
@ -158,3 +158,35 @@ def get_innermost(sl):
return get_innermost(sl[0])
else:
return sl
def send_to_server(payload : str):
from prompt import PromptState
cxt : PromptState = session_context
if cxt is None:
raise RuntimeError("Error! no session specified.")
else:
from ctypes import c_char_p
cxt.payload = (c_char_p*1)(c_char_p(bytes(payload, 'utf-8')))
cxt.cfg.has_dll = 0
cxt.send(1, cxt.payload)
cxt.set_ready()
def get_storedproc(name : str):
from prompt import PromptState, StoredProcedure
cxt : PromptState = session_context
if cxt is None:
raise RuntimeError("Error! no session specified.")
else:
ret : StoredProcedure = cxt.get_storedproc(bytes(name, 'utf-8'))
if (
ret.name.value and
ret.name.value.decode('utf-8') != name
):
print(f'Procedure {name} mismatch in server {ret.name.value}')
return None
else:
return ret
def execute_procedure(proc):
pass

@ -119,6 +119,15 @@ class Backend_Type(enum.Enum):
BACKEND_MonetDB = 1
BACKEND_MariaDB = 2
class StoredProcedure(ctypes.Structure):
_fields_ = [
('cnt', ctypes.c_uint32),
('postproc_modules', ctypes.c_uint32),
('queries', ctypes.POINTER(ctypes.c_char_p)),
('name', ctypes.c_char_p),
('__rt_loaded_modules', ctypes.POINTER(ctypes.c_void_p)),
]
@dataclass
class QueryStats:
last_time : int = time.time()
@ -229,6 +238,7 @@ class PromptState():
server_bin = 'server.bin' if server_mode == RunType.IPC else 'server.so'
wait_engine = lambda: None
wake_engine = lambda: None
get_storedproc = lambda : StoredProcedure()
set_ready = lambda: None
get_ready = lambda: None
server_status = lambda: False
@ -322,6 +332,8 @@ def init_threaded(state : PromptState):
state.send = server_so['receive_args']
state.wait_engine = server_so['wait_engine']
state.wake_engine = server_so['wake_engine']
state.get_storedproc = server_so['get_procedure']
state.get_storedproc.restype = StoredProcedure
aquery_config.have_hge = server_so['have_hge']()
if aquery_config.have_hge != 0:
from engine.types import get_int128_support
@ -330,9 +342,11 @@ def init_threaded(state : PromptState):
state.th.start()
def init_prompt() -> PromptState:
from engine.utils import session_context
aquery_config.init_config()
state = PromptState()
session_context = state
# if aquery_config.rebuild_backend:
# try:
# os.remove(state.server_bin)
@ -454,7 +468,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
continue
elif q.startswith('xexec') or q.startswith('exec'): # generate build and run (MonetDB Engine)
state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value
cxt = xengine.exec(state.stmts, cxt, keep)
cxt = xengine.exec(state.stmts, cxt, keep, parser.parse)
this_udf = cxt.finalize_udf()
if this_udf:
@ -613,11 +627,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
elif q.startswith('procedure'):
qs = re.split(r'[ \t\r\n]', q)
procedure_help = '''Usage: procedure <procedure_name> [record|stop|run|remove|save|load]'''
def send_to_server(payload : str):
state.payload = (ctypes.c_char_p*1)(ctypes.c_char_p(bytes(payload, 'utf-8')))
state.cfg.has_dll = 0
state.send(1, state.payload)
state.set_ready()
from engine.utils import send_to_server
if len(qs) > 2:
if qs[2].lower() =='record':
if state.current_procedure is not None and state.current_procedure != qs[1]:

@ -18,10 +18,11 @@ def generate(ast, cxt):
if k in ast_node.types.keys():
ast_node.types[k](None, ast, cxt)
def exec(stmts, cxt = None, keep = False):
def exec(stmts, cxt = None, keep = False, parser = None):
if 'stmts' not in stmts:
return
cxt = initialize(cxt, keep)
cxt.parser = parser
stmts_stmts = stmts['stmts']
if type(stmts_stmts) is list:
for s in stmts_stmts:

@ -1081,7 +1081,62 @@ class create_table(ast_node):
self.sql += ')'
if self.context.use_columnstore:
self.sql += ' engine=ColumnStore'
class create_trigger(ast_node):
name = 'create_trigger'
first_order = name
class Type (Enum):
Interval = auto()
Callback = auto()
def produce(self, node):
from engine.utils import send_to_server, get_storedproc
node = node['create_trigger']
self.trigger_name = node['name']
self.action_name = node['action']
self.action = get_storedproc(self.action_name)
if self.trigger_name in self.context.triggers:
raise ValueError(f'trigger {self.trigger_name} exists')
elif self.action:
raise ValueError(f'Stored Procedure {self.action_name} do not exist')
if 'interval' in node: # executed periodically from server
self.type = self.Type.Interval
self.interval = node['interval']
send_to_server(f'TI{self.trigger_name}{self.action_name}{self.interval}')
else: # executed from sql backend
self.type = self.Type.Callback
self.query_name = node['query']
self.table_name = node['table']
self.procedure = get_storedproc(self.query_name)
if self.procedure and self.table_name in self.context.tables_byname:
self.table = self.context.tables_byname[self.table_name]
self.table.triggers.add(self)
else:
return
self.context.triggers[self.trigger_name] = self
# manually execute trigger
def register(self):
if self.type != self.Type.Callback:
self.context.triggers.pop(self.trigger_name)
raise ValueError(f'Trigger {self.trigger_name} is not a callback based trigger')
self.context.triggers_active.add(self)
def execute(self):
from engine.utils import send_to_server
send_to_server(f'TC{self.query_name}{self.action_name}')
def remove(self):
from engine.utils import send_to_server
send_to_server(f'TR{self.trigger_name}')
class drop_trigger(ast_node):
name = 'create_trigger'
first_order = name
def produce(self, node):
...
class drop(ast_node):
name = 'drop'
first_order = name
@ -1111,9 +1166,11 @@ class insert(ast_node):
complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
if any([kw in values for kw in complex_query_kw]):
values['into'] = node['insert']
proj_cls = (select_distinct
if 'select_distinct' in values
else projection)
proj_cls = (
select_distinct
if 'select_distinct' in values
else projection
)
proj_cls(None, values, self.context)
self.produce = lambda*_:None
self.spawn = lambda*_:None
@ -1147,6 +1204,11 @@ class insert(ast_node):
keys = f'({", ".join(keys)})' if keys else ''
tbl = node['insert']
if tbl not in self.context.tables_byname:
print('Warning: {tbl} not registered in aquery compiler.')
tbl_obj = self.context.tables_byname[tbl]
for t in tbl_obj.triggers:
t.register()
self.sql = f'INSERT INTO {tbl}{keys} VALUES'
# if len(values) != table.n_cols:
# raise ValueError("Column Mismatch")
@ -1161,7 +1223,7 @@ class insert(ast_node):
list_values.append(f"({', '.join(inner_list_values)})")
self.sql += ', '.join(list_values)
class delete_from(ast_node):
name = 'delete'
@ -1624,6 +1686,11 @@ class passthru_sql(ast_node):
seprator = re.compile(r'''((?:[^;"']|"[^"]*"|'[^']*')+)''')
def __init__(self, _, node, context:Context):
sqls = passthru_sql.seprator.split(node['sql'])
try:
if callable(context.parser):
parsed = context.parser(node['sql'])
except BaseException:
parsed = None
for sql in sqls:
sq = sql.strip(' \t\n\r;')
if sq:

@ -64,12 +64,14 @@ class ColRef:
class TableInfo:
def __init__(self, table_name, cols, cxt:'Context'):
from reconstruct.ast import create_trigger
# statics
self.table_name : str = table_name
self.contextname_cpp : str = ''
self.alias : Set[str] = set([table_name])
self.columns_byname : CaseInsensitiveDict[str, ColRef] = CaseInsensitiveDict() # column_name, type
self.columns : List[ColRef] = []
self.triggers : Set[create_trigger] = set()
self.cxt = cxt
# keep track of temp vars
self.rec = None
@ -83,7 +85,7 @@ class TableInfo:
def add_cols(self, cols, new = True):
for c in enlist(cols):
self.add_col(c, new)
def add_col(self, c, new = True):
_ty = c['type']
_ty_args = None
@ -156,9 +158,11 @@ class Context:
self.module_init_loc = 0
self.special_gb = False
self.has_dll = False
self.triggers_active.clear()
def __init__(self):
self.tables_byname = dict()
from .ast import create_trigger
self.tables_byname : Dict[str, TableInfo] = dict()
self.col_byname = dict()
self.tables : Set[TableInfo] = set()
self.cols = []
@ -174,6 +178,9 @@ class Context:
self.have_hge = False
self.Error = lambda *args: print(*args)
self.Info = lambda *_: None
self.triggers : Dict[str, create_trigger] = dict()
self.triggers_active = set()
self.stored_proceudres = dict()
# self.new() called everytime new query batch is started
def get_scan_var(self):
@ -256,7 +263,14 @@ class Context:
limit = limit.to_bytes(4, 'little').decode('latin-1')
self.queries.append(
'O' + limit + sep + end)
def remove_trigger(self, name : str):
from reconstruct.ast import create_trigger
val = self.triggers.pop(name, None)
if val.type == create_trigger.Type.Callback:
val.table.triggers.remove(val)
val.remove()
def abandon_postproc(self):
self.ccode = ''
self.finalize_query()

@ -86,7 +86,8 @@ extern "C" int __DLLEXPORT__ binary_info() {
#endif
}
__AQEXPORT__(bool) have_hge(){
__AQEXPORT__(bool)
have_hge() {
#if defined(__MONETDB_CONN_H__)
return Server::havehge();
#else
@ -94,6 +95,19 @@ __AQEXPORT__(bool) have_hge(){
#endif
}
__AQEXPORT__(StoredProcedure)
get_procedure(Context* cxt, const char* name) {
auto res = cxt->stored_proc.find(name);
if (res == cxt->stored_proc.end())
return { .cnt = 0,
.postproc_modules = 0,
.queries = nullptr,
.name = nullptr,
.__rt_loaded_modules = nullptr
};
return res->second;
}
using prt_fn_t = char* (*)(void*, char*);
// This function contains heap allocations, free after use

Loading…
Cancel
Save