fixed join using, join on

dev
Bill 2 years ago
parent b666d6d9b2
commit 136b3b6c6c

@ -152,6 +152,7 @@ See ./tests/ for more examples.
- [x] Bug fixes: type deduction misaligned in Hybrid Engine - [x] Bug fixes: type deduction misaligned in Hybrid Engine
- [ ] Investigation: Using postproc only for q1 in Hybrid Engine (make is_special always on) - [ ] Investigation: Using postproc only for q1 in Hybrid Engine (make is_special always on)
- [ ] C++ Meta-Programming: Eliminate template recursions as much as possible. - [ ] C++ Meta-Programming: Eliminate template recursions as much as possible.
- [ ] Functionality: Basic helper functions in aquery - [x] Functionality: Basic helper functions in aquery
- [ ] Bug: Join-Aware Column management - [ ] Bug: Join-Aware Column management
- [ ] Bug: Order By after Group By - [ ] Bug: Order By after Group By
- [ ] Functionality: Having clause

@ -2,7 +2,7 @@
## GLOBAL CONFIGURATION FLAGS ## GLOBAL CONFIGURATION FLAGS
version_string = '0.4.7a' version_string = '0.4.8a'
add_path_to_ldpath = True add_path_to_ldpath = True
rebuild_backend = False rebuild_backend = False
run_backend = True run_backend = True

@ -145,7 +145,10 @@ class projection(ast_node):
if not proj_expr.is_special: if not proj_expr.is_special:
if proj_expr.node == '*': if proj_expr.node == '*':
name = [c.get_full_name() for c in self.datasource.rec] name = [(c.get_name()
if self.datasource.single_table
else c.get_full_name()
) for c in self.datasource.rec]
this_type = [c.type for c in self.datasource.rec] this_type = [c.type for c in self.datasource.rec]
compound = [c.compound for c in self.datasource.rec] compound = [c.compound for c in self.datasource.rec]
proj_expr = [expr(self, c.name) for c in self.datasource.rec] proj_expr = [expr(self, c.name) for c in self.datasource.rec]
@ -288,7 +291,9 @@ class projection(ast_node):
self.group_node and self.group_node and
(self.group_node.use_sp_gb and (self.group_node.use_sp_gb and
val[2].cols_mentioned.intersection( val[2].cols_mentioned.intersection(
self.datasource.all_cols().difference(self.group_node.refs)) self.datasource.all_cols().difference(
self.datasource.get_joint_cols(self.group_node.refs)
))
) and val[2].is_compound # compound val not in key ) and val[2].is_compound # compound val not in key
# or # or
# val[2].is_compound > 1 # val[2].is_compound > 1
@ -366,7 +371,7 @@ class select_into(ast_node):
raise Exception('No out_table found.') raise Exception('No out_table found.')
else: else:
self.context.headers.add('"./server/table_ext_monetdb.hpp"') self.context.headers.add('"./server/table_ext_monetdb.hpp"')
self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->alt_server, \"{node}\");' self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->alt_server, \"{node.lower()}\");'
def produce_sql(self, node): def produce_sql(self, node):
self.sql = f' INTO {node}' self.sql = f' INTO {node}'
@ -443,6 +448,7 @@ class groupby_c(ast_node):
self.proj : projection = self.parent self.proj : projection = self.parent
self.glist : List[Tuple[expr, Set[ColRef]]] = node self.glist : List[Tuple[expr, Set[ColRef]]] = node
return super().init(node) return super().init(node)
def produce(self, node : List[Tuple[expr, Set[ColRef]]]): def produce(self, node : List[Tuple[expr, Set[ColRef]]]):
self.context.headers.add('"./server/hasher.h"') self.context.headers.add('"./server/hasher.h"')
self.context.headers.add('unordered_map') self.context.headers.add('unordered_map')
@ -505,9 +511,14 @@ class groupby_c(ast_node):
gscanner.add(f'auto &{len_var} = {val_var}.size;', position = 'front') gscanner.add(f'auto &{len_var} = {val_var}.size;', position = 'front')
def get_key_idx (varname : str): def get_key_idx (varname : str):
ex = expr(self, varname)
joint_cols = set()
if ex.is_ColExpr and ex.raw_col:
joint_cols = self.datasource.get_joint_cols([ex.raw_col])
for i, g in enumerate(self.glist): for i, g in enumerate(self.glist):
if varname == g[0].eval(): if (varname == g[0].eval()) or (g[0].is_ColExpr and g[0].raw_col and
return i g[0].raw_col in joint_cols):
return i
return var_table[varname] return var_table[varname]
def get_var_names (varname : str): def get_var_names (varname : str):
@ -599,13 +610,14 @@ class groupby(ast_node):
self.dedicated_gb = groupby_c(self.parent, self.dedicated_glist) self.dedicated_gb = groupby_c(self.parent, self.dedicated_glist)
self.dedicated_gb.finalize(cexprs, var_table) self.dedicated_gb.finalize(cexprs, var_table)
class join(ast_node): class join(ast_node):
name = 'join' name = 'join'
def get_joint_cols(self, cols : List[ColRef]): def get_joint_cols(self, cols : List[ColRef]):
joint_cols = set() joint_cols = set(cols)
for col in cols: for col in cols:
joint_cols |= self.joint_cols[col] joint_cols |= self.joint_cols.get(col, set())
return joint_cols return joint_cols
def init(self, _): def init(self, _):
@ -661,21 +673,36 @@ class join(ast_node):
tbl.add_alias(node['name']) tbl.add_alias(node['name'])
self.append(tbl, alias) self.append(tbl, alias)
else: else:
keys = list(node.keys()) keys : List[str] = list(node.keys())
if keys[0].lower().endswith('join'): if keys[0].lower().endswith('join'):
self.have_sep = True self.have_sep = True
j = join(self, node[keys[0]]) j = join(self, node[keys[0]])
self.tables += j.tables
self.tables_dir = {**self.tables_dir, **j.tables_dir}
self.join_conditions += j.join_conditions self.join_conditions += j.join_conditions
_tbl_union = _tmp_join_union(self.context, self.parent, self)
tablename = f' {keys[0]} {j}' tablename = f' {keys[0]} {j}'
if len(keys) > 1 : if len(keys) > 1 :
_ex = expr(self, node[keys[1]]) jcond = node[keys[1]]
if keys[1].lower() == 'on': sqoute = '\''
if type(jcond) is list:
_ex = [expr(_tbl_union, j) for j in jcond]
else:
_ex = expr(_tbl_union, jcond)
if keys[1].lower() == 'on': # postpone join condition evaluation after consume
self.join_conditions += _ex.join_conditions self.join_conditions += _ex.join_conditions
tablename += f' ON {_ex}' tablename += f" ON {_ex.eval().replace(sqoute, '')}"
elif keys[1].lower() == 'using': elif keys[1].lower() == 'using':
if _ex.is_ColExpr: _ex = enlist(_ex)
self.join_conditions.append( (_ex.raw_col, j.get_cols(_ex.raw_col.name)) ) lst_jconds = []
tablename += f' USING {_ex}' for _e in _ex:
if _e.is_ColExpr:
cl = _e.raw_col
cr = j.get_cols(_e.raw_col.name)
self.join_conditions.append( (cl, cr) )
lst_jconds += [f'{cl.get_full_name()} = {cr.get_full_name()}']
tablename += f' ON {" and ".join(lst_jconds)}'
if keys[0].lower().startswith('natural'): if keys[0].lower().startswith('natural'):
ltbls : List[TableInfo] = [] ltbls : List[TableInfo] = []
if isinstance(self.parent, join): if isinstance(self.parent, join):
@ -688,8 +715,7 @@ class join(ast_node):
if cr: if cr:
self.join_conditions.append( (cl, cr) ) self.join_conditions.append( (cl, cr) )
self.joins.append((tablename, self.have_sep)) self.joins.append((tablename, self.have_sep))
self.tables += j.tables
self.tables_dir = {**self.tables_dir, **j.tables_dir}
elif type(node) is str: elif type(node) is str:
if node in self.context.tables_byname: if node in self.context.tables_byname:
@ -720,6 +746,10 @@ class join(ast_node):
datasource.rec = None datasource.rec = None
return ret return ret
@property
def single_table(self):
return len(self.tables) == 1
# @property # @property
def all_cols(self): def all_cols(self):
ret = set() ret = set()
@ -734,7 +764,22 @@ class join(ast_node):
def process_join_conditions(self): def process_join_conditions(self):
# This is done after both from # This is done after both from
# and where clause are processed # and where clause are processed
print(self.join_conditions) for j in self.join_conditions:
l = j[0]
r = j[1]
for k in (0, 1):
if j[k] not in self.joint_cols:
self.joint_cols[j[k]] = set([l, r])
jr = self.joint_cols[r]
jl = self.joint_cols[l]
if jl != jr:
jl |= jr
for c, jc in self.joint_cols.items():
if jc == jr:
self.joint_cols[c] = jl
# print(self.join_conditions)
# print(self.joint_cols)
def consume(self, node): def consume(self, node):
self.sql = '' self.sql = ''
@ -753,6 +798,16 @@ class join(ast_node):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
class _tmp_join_union(join):
name = '__tmp_join_union'
def __init__(self, context, l: join, r: join):
self.tables = l.tables + r.tables
self.tables_dir = {**l.tables_dir, **r.tables_dir}
self.datasource = self
self.context = context
self.parent = self
self.rec = None
class filter(ast_node): class filter(ast_node):
name = 'where' name = 'where'
@ -771,7 +826,10 @@ class create_table(ast_node):
raise ValueError("Table name not specified") raise ValueError("Table name not specified")
projection_node = node['query'] projection_node = node['query']
projection_node['into'] = node['name'] projection_node['into'] = node['name']
projection(None, projection_node, self.context) proj_cls = (select_distinct
if 'select_distinct' in projection_node
else projection)
proj_cls(None, projection_node, self.context)
self.produce = lambda *_: None self.produce = lambda *_: None
self.spawn = lambda *_: None self.spawn = lambda *_: None
self.consume = lambda *_: None self.consume = lambda *_: None
@ -819,7 +877,10 @@ class insert(ast_node):
complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit'] complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
if any([kw in values for kw in complex_query_kw]): if any([kw in values for kw in complex_query_kw]):
values['into'] = node['insert'] values['into'] = node['insert']
projection(None, values, self.context) proj_cls = (select_distinct
if 'select_distinct' in values
else projection)
proj_cls(None, values, self.context)
self.produce = lambda*_:None self.produce = lambda*_:None
self.spawn = lambda*_:None self.spawn = lambda*_:None
self.consume = lambda*_:None self.consume = lambda*_:None

@ -78,10 +78,10 @@ class expr(ast_node):
ast_node.__init__(self, parent, node, None) ast_node.__init__(self, parent, node, None)
def init(self, _): def init(self, _):
from reconstruct.ast import projection from reconstruct.ast import projection, _tmp_join_union
parent = self.parent parent = self.parent
self.is_compound = parent.is_compound if type(parent) is expr else False self.is_compound = parent.is_compound if type(parent) is expr else False
if type(parent) in [projection, expr]: if type(parent) in [projection, expr, _tmp_join_union]:
self.datasource = parent.datasource self.datasource = parent.datasource
else: else:
self.datasource = self.context.datasource self.datasource = self.context.datasource

@ -21,6 +21,16 @@ class ColRef:
self.__arr__ = (_ty, cobj, table, name, id) self.__arr__ = (_ty, cobj, table, name, id)
def get_name(self):
it_alias = iter(self.alias)
alias = next(it_alias, self.name)
try:
while alias == self.name:
alias = next(it_alias)
except StopIteration:
alias = self.name
return alias
def get_full_name(self): def get_full_name(self):
table_name = self.table.table_name table_name = self.table.table_name
it_alias = iter(self.table.alias) it_alias = iter(self.table.alias)
@ -30,7 +40,7 @@ class ColRef:
alias = next(it_alias) alias = next(it_alias)
except StopIteration: except StopIteration:
alias = table_name alias = table_name
return f'{alias}.{self.name}' return f'{alias}.{self.get_name()}'
def __getitem__(self, key): def __getitem__(self, key):
if type(key) is str: if type(key) is str:
@ -103,6 +113,10 @@ class TableInfo:
self.rec.update(self.columns) self.rec.update(self.columns)
return set(self.columns) return set(self.columns)
@property
def single_table(self):
return True
class Context: class Context:
def new(self): def new(self):
self.headers = set(['\"./server/libaquery.h\"', self.headers = set(['\"./server/libaquery.h\"',

@ -41,6 +41,6 @@ from TradedStocks
SELECT ID, avgs(10, ClosePrice) SELECT ID, avgs(10, ClosePrice)
FROM td NATURAL JOIN FROM td NATURAL JOIN
HistoricQuotes HistoricQuotes hq
ASSUMING ASC TradeDate ASSUMING ASC TradeDate
GROUP BY ID GROUP BY hq.ID
Loading…
Cancel
Save