from typing import Optional, Set

from common.types import *
from engine.ast import ast_node
from engine.storage import ColRef, Context
from common.utils import Backend_Type

# TODO: Decouple expr and upgrade architecture
# C_CODE : get ccode/sql code?
# projections : C/SQL/decltype string
# orderby/joins/where : SQL only
# assumption/groupby : C/sql
# is_udfexpr: C only

class expr(ast_node):
    name = 'expr'
    # Join-condition tracking states used by produce(): entries 0 and 1 list the
    # operators still allowed ('not' flips between them); state 2 disables the check.
    valid_joincond = {
        0 : ('and', 'eq', 'not'),
        1 : ('or', 'neq', 'not'),
        2 : ('', '', '')
    }

    @property
    def udf_decltypecall(self):
        return self._udf_decltypecall if self._udf_decltypecall else self.sql

    @udf_decltypecall.setter
    def udf_decltypecall(self, val):
        self._udf_decltypecall = val

    @property
    def need_decltypestr(self):
        return self._udf_decltypecall is not None

    def __init__(self, parent, node, *, c_code = None, supress_undefined = False):
        from engine.ast import projection, udf

        # gen2 exprs are processed in multiple passes:
        # the first pass parses json into an expr tree;
        # target code is generated in later passes on demand.
        self.children = []
        self.opname = ''
        self.curr_code = ''
        self.counts = {}
        self.type = None
        self.raw_col = None
        self.udf : Optional[udf] = None
        self.inside_agg = False
        self.is_special = False
        self.is_ColExpr = False
        self.is_recursive_call_inudf = False
        self.codlets : list = []
        self.codebuf : Optional[str] = None
        self._udf_decltypecall = None
        self.node = node
        self.supress_undefined = supress_undefined
        if type(parent) is expr:
            self.next_valid = parent.next_valid
            self.inside_agg = parent.inside_agg
            self.is_udfexpr = parent.is_udfexpr
            self.is_agg_func = parent.is_agg_func
            self.root : expr = parent.root
            self.c_code = parent.c_code
            self.builtin_vars = parent.builtin_vars
        else:
            self.join_conditions = []
            self.next_valid = 0
            self.is_agg_func = False
            self.is_udfexpr = type(parent) is udf
            self.root : expr = self
            self.c_code = self.is_udfexpr or type(parent) is projection
            if self.is_udfexpr:
                self.udf : udf = parent
                self.builtin_vars = self.udf.builtin.keys()
            else:
                self.builtin_vars = []
        if type(c_code) is bool:
            self.c_code = c_code

        self.udf_called = None
        self.cols_mentioned : Optional[Set[ColRef]] = None
        ast_node.__init__(self, parent, node, None)

    def init(self, _):
        from engine.ast import _tmp_join_union, projection
        parent = self.parent
        self.is_compound = parent.is_compound if type(parent) is expr else False
        if type(parent) in [projection, expr, _tmp_join_union]:
            self.datasource = parent.datasource
        else:
            self.datasource = self.context.datasource
        self.udf_map = parent.context.udf_map
        self.func_maps = {**builtin_func, **self.udf_map, **user_module_func}
        self.operators = {**builtin_operators, **self.udf_map, **user_module_func}
        self.ext_aggfuncs = ['sum', 'avg', 'count', 'min', 'max',
                             'last', 'first', 'prev', 'next', 'var',
                             'stddev']

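    # Illustrative note (added commentary, not consulted by the engine): judging from
    # the handling in produce() below, parsed-query nodes arrive in shapes roughly like
    # the following; the exact layout and operator names depend on the SQL parser's output.
    #   {'add': ['a', 1]}              -> operator key with a list of operands
    #   {'count': {'distinct': 'a'}}   -> aggregate with a distinct wrapper
    #   {'literal': 'hello'}           -> literal wrapper, unwrapped and handled as a string
    #   'tbl.col' or '*'               -> bare column reference or wildcard
    #   True, 42, 3.14                 -> constants typed as bool/int/float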
    def produce(self, node):
        from common.utils import enlist
        from engine.ast import udf, projection

        if type(node) is dict:
            if 'literal' in node:
                node = node['literal']
            else:
                if len(node) > 1:
                    print(f'Parser Error: {node} has more than 1 dict entry.')

                is_joincond = False
                for key, val in node.items():
                    key = key.lower()
                    # Track whether this subtree can still be a join condition
                    # (see valid_joincond above).
                    if key not in self.valid_joincond[self.next_valid]:
                        self.next_valid = 2
                    else:
                        if key == self.valid_joincond[self.next_valid][2]:
                            self.next_valid = not self.next_valid
                        elif key == self.valid_joincond[self.next_valid][1]:
                            self.next_valid = 2
                        is_joincond = True
                    if key in self.operators:
                        if key in builtin_func:
                            if self.is_agg_func:
                                self.root.is_special = True # Nested Aggregation
                            else:
                                self.is_agg_func = True

                        op = self.operators[key]
                        count_distinct = False
                        if key == 'count' and type(val) is dict and 'distinct' in val:
                            count_distinct = True
                            val = val['distinct']

                        val = enlist(val)
                        exp_vals = []
                        for v in val:
                            if (
                                type(v) is str and
                                '*' in v and
                                key != 'count'
                            ):
                                # Expand wildcards into one child expr per matched column.
                                cols = self.datasource.get_cols(v)
                                if cols:
                                    for c in cols:
                                        exp_vals.append(expr(self, c.name, c_code=self.c_code))
                            else:
                                exp_vals.append(expr(self, v, c_code=self.c_code))

                        self.children = exp_vals
                        self.opname = key

                        str_vals = [e.sql for e in exp_vals]
                        type_vals = [e.type for e in exp_vals]
                        is_compound = max([e.is_compound for e in exp_vals])
                        if key in self.ext_aggfuncs:
                            self.is_compound = max(0, is_compound - 1)
                        else:
                            self.is_compound = is_compound
                        try:
                            self.type = op.return_type(*type_vals)
                        except AttributeError as e:
                            if type(self.root.parent) is not udf:
                                # TODO: do something when this is not an error
                                print(f'alert: {e}')
                            self.type = AnyT

                        if count_distinct: # inject distinct col later
                            self.sql = f'{{{op(self.c_code, *str_vals, True)}}}'
                        else:
                            self.sql = op(self.c_code, *str_vals)

                        special_func = [*self.context.udf_map.keys(), *self.context.module_map.keys(),
                                        "maxs", "mins", "avgs", "sums", "deltas", "last", "first",
                                        "stddevs", "vars", "ratios", "pack", "truncate", "subvec"]

                        if (
                            self.context.special_gb
                            or
                            (
                                type(self.root.parent) is projection
                                and
                                self.root.parent.force_use_spgb
                            )
                            or
                            self.context.system_state.cfg.backend_type == Backend_Type.BACKEND_AQuery.value
                        ):
                            special_func = [*special_func, *self.ext_aggfuncs]

                        if key in special_func and not self.is_special:
                            self.is_special = True
                            if key in self.context.udf_map:
                                self.root.udf_called = self.context.udf_map[key]
                                if self.is_udfexpr and key == self.root.udf.name:
                                    self.root.is_recursive_call_inudf = True
                            elif key in user_module_func.keys():
                                udf.try_init_udf(self.context)
                            # TODO: make udf_called a set!
                            p = self.parent
                            while type(p) is expr and not p.udf_called:
                                p.udf_called = self.udf_called
                                p = p.parent
                            p = self.parent
                            while type(p) is expr and not p.is_special:
                                p.is_special = True
                                p = p.parent

                        need_decltypestr = any([e.need_decltypestr for e in exp_vals])
                        if need_decltypestr or (self.udf_called and type(op) is udf):
                            decltypestr_vals = [e.udf_decltypecall for e in exp_vals]
                            self.udf_decltypecall = op(self.c_code, *decltypestr_vals)

                            if self.udf_called and type(op) is udf:
                                self.udf_decltypecall = op.decltypecall(self.c_code, *decltypestr_vals)

                    elif self.is_udfexpr:
                        var_table = self.root.udf.var_table
                        vec = key.split('.')
                        _vars = [*var_table, *self.builtin_vars]
                        def get_vname (node):
                            if node in self.builtin_vars:
                                self.root.udf.builtin[node].enabled = True
                                self.builtin_var = node
                                return node
                            else:
                                return var_table[node]
                        if vec[0] not in _vars:
                            # print(f'Use of undefined variable {vec[0]}')
                            # TODO: do something when this is not an error
                            pass
                        else:
                            vname = get_vname(vec[0])
                            val = enlist(val)
                            if len(val) > 2:
                                print('Warning: more than 2 indexes found for subvec operator.')
                            ex = [expr(self, v, c_code = self.c_code) for v in val]
                            idxs = ', '.join([e.sql for e in ex])
                            self.sql = f'{vname}.subvec({idxs})'
                            if any([e.need_decltypestr for e in ex]):
                                self.udf_decltypecall = f'{vname}.subvec({", ".join([e.udf_decltypecall for e in ex])})'
                        if key == 'get' and len(val) > 1:
                            ex_vname = expr(self, val[0], c_code=self.c_code)
                            self.sql = f'{ex_vname.sql}[{expr(self, val[1], c_code=self.c_code).sql}]'
                            if hasattr(ex_vname, 'builtin_var'):
                                if not hasattr(self, 'builtin_var'):
                                    self.builtin_var = []
                                self.builtin_var = [*self.builtin_var, *ex_vname.builtin_var]
                                self.udf_decltypecall = ex_vname.sql
                    else:
                        print(f'Undefined expr: {key}{val}')

                if (is_joincond and len(self.children) == 2
                        and all([c.is_ColExpr for c in self.children])):
                    self.root.join_conditions.append(
                        (self.children[0].raw_col, self.children[1].raw_col)
                    )

        if type(node) is str:
            if self.is_udfexpr:
                curr_udf : udf = self.root.udf
                var_table = curr_udf.var_table
                split = node.split('.')
                if split[0] in var_table:
                    varname = var_table[split[0]]
                    if curr_udf.agg and varname in curr_udf.vecs:
                        if len(split) > 1:
                            if split[1] == 'vec':
                                self.sql += varname
                            elif split[1] == 'len':
                                self.sql += f'{varname}.size'
                            else:
                                print(f'no member {split[1]} in object {varname}')
                        else:
                            self.sql += f'{varname}[{curr_udf.idx_var}]'
                    else:
                        self.sql += varname
                elif self.supress_undefined or split[0] in self.builtin_vars:
                    self.sql += node
                    if split[0] in self.builtin_vars:
                        curr_udf.builtin[split[0]].enabled = True
                        self.builtin_var = split[0]
                else:
                    print(f'Undefined varname: {split[0]}')

            # get the column from the datasource in SQL context
            else:
                if self.datasource is not None:
                    if (node == '*' and
                        not (type(self.parent) is expr
                             and 'count' in self.parent.node)):
                        self.datasource.all_cols(ordered = True)
                    else:
                        self.raw_col = self.datasource.parse_col_names(node)
                        self.raw_col = self.raw_col if type(self.raw_col) is ColRef else None
                        if self.raw_col is not None:
                            self.is_ColExpr = True
                            table_name = ''
                            if '.' in node:
                                table_name = self.raw_col.table.table_name
                                if self.raw_col.table.alias:
                                    # Prefer an alias that differs from the table's own name.
                                    alias = iter(self.raw_col.table.alias)
                                    try:
                                        a = next(alias)
                                        while not a or a == table_name:
                                            a = next(alias)
                                        if a and a != table_name:
                                            table_name = a
                                    except StopIteration:
                                        pass
                            if table_name:
                                table_name = table_name + '.'
                            self.sql = table_name + self.raw_col.name
                            self.type = self.raw_col.type
                            self.is_compound = True
                            self.is_compound += self.raw_col.compound
                            self.opname = self.raw_col
                        else:
                            self.sql = '\'' + node + '\'' if node != '*' else '*'
                            self.type = StrT
                            self.opname = self.sql
                if self.c_code and self.datasource is not None:
                    if (type(self.parent) is expr and
                            'distinct' in self.parent.node and
                            not self.is_special):
                        # this node is executed by monetdb
                        # gb condition, not special
                        self.sql = f'distinct({self.sql})'
                    self.sql = f'{{y(\"{self.sql}\")}}'
        elif type(node) is bool:
            self.type = BoolT
            self.opname = node
            if self.c_code:
                self.sql = '1' if node else '0'
            else:
                self.sql = 'TRUE' if node else 'FALSE'
        elif type(node) is not dict:
            self.sql = f'{node}'
            self.opname = node
            if type(node) is int:
                # Choose an integer type wide enough to hold the literal.
                if node >= 2**63 - 1 or node <= -2**63:
                    self.type = HgeT
                elif node >= 2**31 - 1 or node <= -2**31:
                    self.type = LongT
                elif node >= 2**15 - 1 or node <= -2**15:
                    self.type = IntT
                elif node >= 2**7 - 1 or node <= -2**7:
                    self.type = ShortT
                else:
                    self.type = ByteT
            elif type(node) is float:
                self.type = DoubleT
                # Deferred f-string: re-evaluated later in eval(), where c_code is in scope.
                self.sql = f'{{"CAST({node} AS DOUBLE)" if not c_code else "{node}f"}}'

    def finalize(self, override = False):
        from engine.ast import udf
        if self.codebuf is None or override:
            self.codebuf = ''
            for c in self.codlets:
                if type(c) is str:
                    self.codebuf += c
                elif type(c) is udf:
                    self.codebuf += c()
                elif type(c) is expr:
                    self.codebuf += c.finalize(override=override)
        return self.codebuf

    def codegen(self, delegate):
        self.curr_code = ''
        for c in self.children:
            self.curr_code += c.codegen(delegate)
        return self.curr_code

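    # Illustrative note (added commentary): remake_binary() below patches an already
    # generated expression so its result is written into ret_expr. Assuming 'mul' is
    # listed in builtin_binary_ops, a node with opname 'mul' whose children render to
    # 'a' and 'b' would be rewritten to 'aqop_mul(a, b, out)' for ret_expr == 'out';
    # for builtin_vecfunc entries, ret_expr is instead appended as a trailing argument
    # before the closing parenthesis of the existing call.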
    def remake_binary(self, ret_expr):
        if self.root:
            self.oldsql = self.sql
            if self.opname in builtin_binary_ops:
                patched_opname = 'aqop_' + self.opname
                self.sql = (f'{patched_opname}({self.children[0].sql}, '
                            f'{self.children[1].sql}, {ret_expr})')
                return True
            elif self.opname in builtin_vecfunc:
                self.sql = self.sql[:self.sql.rindex(')')]
                self.sql += ', ' + ret_expr + ')'
                return True
        return False

    def __str__(self):
        return self.sql

    def __repr__(self):
        return self.__str__()

    # builtins are read-only, so it's okay to use an object as the default value
    # eval is only called at the root expr.
    def eval(self, c_code = None, y = lambda t: t,
             materialize_builtin = False, _decltypestr = False,
             count = lambda : 'count', var_inject = None,
             *,
             gettype = False):
        assert(self.is_root)
        def call(decltypestr = False) -> str:
            nonlocal c_code, y, materialize_builtin, count, var_inject
            if var_inject:
                for k, v in var_inject.items():
                    locals()[k] = v
            if self.udf_called is not None:
                loc = locals()
                builtin_vars = self.udf_called.builtin_used
                for b in self.udf_called.builtin_var.all:
                    exec(f'loc["{b}"] = lambda: "{{{b}()}}"')
                if builtin_vars:
                    if type(materialize_builtin) is dict:
                        for b in builtin_vars:
                            exec(f'loc["{b}"] = lambda: "{materialize_builtin[b]}"')
                    elif self.is_recursive_call_inudf:
                        for b in builtin_vars:
                            exec(f'loc["{b}"] = lambda : "{b}"')

            x = self.c_code if c_code is None else c_code
            from common.utils import escape_qoutes
            if decltypestr:
                return eval('f\'' + escape_qoutes(self.udf_decltypecall) + '\'')
            self.sql.replace("'", "\\'")
            return eval('f\'' + escape_qoutes(self.sql) + '\'')
        if self.is_recursive_call_inudf or (self.need_decltypestr and self.is_udfexpr) or gettype:
            return call
        else:
            return call(_decltypestr)

    @property
    def is_root(self):
        return self.root == self

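# Illustrative sketch (added commentary, not part of the engine): strings stored in
# self.sql are deferred f-string templates, and eval() above renders them with the
# caller's callbacks in scope. A column reference emitted in c_code mode is stored as
# '{y("price")}', so a hypothetical caller could run
#     rendered = root_expr.eval(c_code = True, y = lambda s: f'cols["{s}"]')
# and the placeholder would expand to cols["price"]. The y callback shown here is
# invented for illustration; real callers (e.g. projection nodes) supply their own.
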
# For UDFs: first check if the agg variable is used as a vector;
# if not, then check if its length is used
class fastscan(expr):
    name = 'fastscan'

    def init(self, _):
        self.vec_vars = set()
        self.requested_lens = set()
        super().init(_)

    def process(self, key : str):
        segs = key.split('.')
        var_table = self.root.udf.var_table
        if segs[0] in var_table and len(segs) > 1:
            if segs[1] == 'vec':
                self.vec_vars.add(segs[0])
            elif segs[1] == 'len':
                self.requested_lens.add(segs[0])

    def produce(self, node):
        from common.utils import enlist
        if type(node) is dict:
            for key, val in node.items():
                if key in self.operators:
                    val = enlist(val)
                elif self.is_udfexpr:
                    self.process(key)
                [fastscan(self, v, c_code = self.c_code) for v in val]

        elif type(node) is str:
            self.process(node)

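# Illustrative note on fastscan above (added commentary): for a UDF body node such as
# {'avg': 'x.vec'}, the pass records 'x' in vec_vars, while a reference like 'x.len'
# ends up in requested_lens. The 'avg'/'x' names are invented for illustration; any
# variable present in the UDF's var_table is tracked the same way.
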
class getrefs(expr):
    name = 'getrefs'

    def init(self, _):
        self.datasource.rec = set()
        self.rec = None

    def produce(self, node):
        from common.utils import enlist
        if type(node) is dict:
            for key, val in node.items():
                if key in self.operators:
                    val = enlist(val)
                    [getrefs(self, v, c_code = self.c_code) for v in val]

        elif type(node) is str:
            self.datasource.parse_col_names(node)

    def consume(self, _):
        if self.root == self:
            self.rec = self.datasource.rec
            self.datasource.rec = None