Fixed JOIN ... USING and JOIN ... ON handling

dev
Bill 2 years ago
parent b666d6d9b2
commit 136b3b6c6c

@ -152,6 +152,7 @@ See ./tests/ for more examples.
- [x] Bug fixes: type deduction misaligned in Hybrid Engine
- [ ] Investigation: Using postproc only for q1 in Hybrid Engine (make is_special always on)
- [ ] C++ Meta-Programming: Eliminate template recursions as much as possible.
- [ ] Functionality: Basic helper functions in aquery
- [x] Functionality: Basic helper functions in aquery
- [ ] Bug: Join-Aware Column management
- [ ] Bug: Order By after Group By
- [ ] Functionality: Having clause

@ -2,7 +2,7 @@
## GLOBAL CONFIGURATION FLAGS
version_string = '0.4.7a'
version_string = '0.4.8a'
add_path_to_ldpath = True
rebuild_backend = False
run_backend = True

@ -145,7 +145,10 @@ class projection(ast_node):
if not proj_expr.is_special:
if proj_expr.node == '*':
name = [c.get_full_name() for c in self.datasource.rec]
name = [(c.get_name()
if self.datasource.single_table
else c.get_full_name()
) for c in self.datasource.rec]
this_type = [c.type for c in self.datasource.rec]
compound = [c.compound for c in self.datasource.rec]
proj_expr = [expr(self, c.name) for c in self.datasource.rec]
@ -288,7 +291,9 @@ class projection(ast_node):
self.group_node and
(self.group_node.use_sp_gb and
val[2].cols_mentioned.intersection(
self.datasource.all_cols().difference(self.group_node.refs))
self.datasource.all_cols().difference(
self.datasource.get_joint_cols(self.group_node.refs)
))
) and val[2].is_compound # compound val not in key
# or
# val[2].is_compound > 1
@ -366,7 +371,7 @@ class select_into(ast_node):
raise Exception('No out_table found.')
else:
self.context.headers.add('"./server/table_ext_monetdb.hpp"')
self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->alt_server, \"{node}\");'
self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->alt_server, \"{node.lower()}\");'
def produce_sql(self, node):
self.sql = f' INTO {node}'
@ -443,6 +448,7 @@ class groupby_c(ast_node):
self.proj : projection = self.parent
self.glist : List[Tuple[expr, Set[ColRef]]] = node
return super().init(node)
def produce(self, node : List[Tuple[expr, Set[ColRef]]]):
self.context.headers.add('"./server/hasher.h"')
self.context.headers.add('unordered_map')
@ -505,9 +511,14 @@ class groupby_c(ast_node):
gscanner.add(f'auto &{len_var} = {val_var}.size;', position = 'front')
def get_key_idx (varname : str):
ex = expr(self, varname)
joint_cols = set()
if ex.is_ColExpr and ex.raw_col:
joint_cols = self.datasource.get_joint_cols([ex.raw_col])
for i, g in enumerate(self.glist):
if varname == g[0].eval():
return i
if (varname == g[0].eval()) or (g[0].is_ColExpr and g[0].raw_col and
g[0].raw_col in joint_cols):
return i
return var_table[varname]
def get_var_names (varname : str):
@ -598,14 +609,15 @@ class groupby(ast_node):
if self.use_sp_gb:
self.dedicated_gb = groupby_c(self.parent, self.dedicated_glist)
self.dedicated_gb.finalize(cexprs, var_table)
class join(ast_node):
name = 'join'
def get_joint_cols(self, cols : List[ColRef]):
    """Return `cols` plus every column recorded as joined to one of them.

    Looks each column up in ``self.joint_cols`` and unions the stored sets
    into the result.

    Args:
        cols: the columns to expand.

    Returns:
        A new set containing every element of `cols` and, for each of them,
        the contents of its ``self.joint_cols`` entry. A column without an
        entry contributes only itself rather than raising ``KeyError``.
    """
    # Seed with the inputs so a column with no join partner is still returned.
    joint_cols = set(cols)
    for col in cols:
        # .get: not every column takes part in a join condition.
        joint_cols |= self.joint_cols.get(col, set())
    return joint_cols
def init(self, _):
@ -661,21 +673,36 @@ class join(ast_node):
tbl.add_alias(node['name'])
self.append(tbl, alias)
else:
keys = list(node.keys())
keys : List[str] = list(node.keys())
if keys[0].lower().endswith('join'):
self.have_sep = True
j = join(self, node[keys[0]])
self.tables += j.tables
self.tables_dir = {**self.tables_dir, **j.tables_dir}
self.join_conditions += j.join_conditions
_tbl_union = _tmp_join_union(self.context, self.parent, self)
tablename = f' {keys[0]} {j}'
if len(keys) > 1 :
_ex = expr(self, node[keys[1]])
if keys[1].lower() == 'on':
jcond = node[keys[1]]
sqoute = '\''
if type(jcond) is list:
_ex = [expr(_tbl_union, j) for j in jcond]
else:
_ex = expr(_tbl_union, jcond)
if keys[1].lower() == 'on': # postpone join condition evaluation after consume
self.join_conditions += _ex.join_conditions
tablename += f' ON {_ex}'
tablename += f" ON {_ex.eval().replace(sqoute, '')}"
elif keys[1].lower() == 'using':
if _ex.is_ColExpr:
self.join_conditions.append( (_ex.raw_col, j.get_cols(_ex.raw_col.name)) )
tablename += f' USING {_ex}'
_ex = enlist(_ex)
lst_jconds = []
for _e in _ex:
if _e.is_ColExpr:
cl = _e.raw_col
cr = j.get_cols(_e.raw_col.name)
self.join_conditions.append( (cl, cr) )
lst_jconds += [f'{cl.get_full_name()} = {cr.get_full_name()}']
tablename += f' ON {" and ".join(lst_jconds)}'
if keys[0].lower().startswith('natural'):
ltbls : List[TableInfo] = []
if isinstance(self.parent, join):
@ -688,8 +715,7 @@ class join(ast_node):
if cr:
self.join_conditions.append( (cl, cr) )
self.joins.append((tablename, self.have_sep))
self.tables += j.tables
self.tables_dir = {**self.tables_dir, **j.tables_dir}
elif type(node) is str:
if node in self.context.tables_byname:
@ -719,7 +745,11 @@ class join(ast_node):
ret = datasource.parse_col_names(parsedColExpr[1])
datasource.rec = None
return ret
@property
def single_table(self):
    """Whether ``self.tables`` holds exactly one table."""
    table_count = len(self.tables)
    return table_count == 1
# @property
def all_cols(self):
ret = set()
@ -734,7 +764,22 @@ class join(ast_node):
def process_join_conditions(self):
    """Group the columns named in the join conditions into equivalence classes.

    This is done after both the FROM and the WHERE clause are processed.
    Each entry of ``self.join_conditions`` is read as a pair whose first two
    items are the left and right column of one join equality.

    On return, ``self.joint_cols`` maps every column seen in a condition to
    the set of all columns joined to it, directly or transitively. Every
    member of a class maps to the *same* set object, so no member is left
    pointing at a stale, smaller set after a later merge.
    """
    for cond in self.join_conditions:
        left, right = cond[0], cond[1]
        # Union of both sides' current classes; a column not seen before
        # has no entry yet and contributes only itself.
        merged = {left, right}
        merged |= self.joint_cols.get(left, set())
        merged |= self.joint_cols.get(right, set())
        # Repoint every member at the one shared set. Matching classes by
        # set equality instead would miss members that hold an
        # equal-but-distinct copy.
        for col in merged:
            self.joint_cols[col] = merged
def consume(self, node):
self.sql = ''
@ -753,7 +798,17 @@ class join(ast_node):
def __repr__(self):
return self.__str__()
class _tmp_join_union(join):
    """Stand-in join exposing the combined tables of two existing joins.

    The constructor does not call the base-class initialiser; it only sets
    the attributes listed below.
    """

    name = '__tmp_join_union'

    def __init__(self, context, l: join, r: join):
        """Build the union view.

        Args:
            context: stored as ``self.context``.
            l: join whose ``tables`` / ``tables_dir`` come first.
            r: join whose ``tables`` are appended and whose ``tables_dir``
                entries override those of `l` on a key clash.
        """
        self.context = context
        # The node is its own parent and its own datasource.
        self.parent = self
        self.datasource = self
        self.rec = None

        self.tables = l.tables + r.tables
        combined_dir = {**l.tables_dir}
        combined_dir.update(r.tables_dir)
        self.tables_dir = combined_dir
class filter(ast_node):
name = 'where'
def produce(self, node):
@ -771,7 +826,10 @@ class create_table(ast_node):
raise ValueError("Table name not specified")
projection_node = node['query']
projection_node['into'] = node['name']
projection(None, projection_node, self.context)
proj_cls = (select_distinct
if 'select_distinct' in projection_node
else projection)
proj_cls(None, projection_node, self.context)
self.produce = lambda *_: None
self.spawn = lambda *_: None
self.consume = lambda *_: None
@ -819,7 +877,10 @@ class insert(ast_node):
complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
if any([kw in values for kw in complex_query_kw]):
values['into'] = node['insert']
projection(None, values, self.context)
proj_cls = (select_distinct
if 'select_distinct' in values
else projection)
proj_cls(None, values, self.context)
self.produce = lambda*_:None
self.spawn = lambda*_:None
self.consume = lambda*_:None

@ -78,10 +78,10 @@ class expr(ast_node):
ast_node.__init__(self, parent, node, None)
def init(self, _):
from reconstruct.ast import projection
from reconstruct.ast import projection, _tmp_join_union
parent = self.parent
self.is_compound = parent.is_compound if type(parent) is expr else False
if type(parent) in [projection, expr]:
if type(parent) in [projection, expr, _tmp_join_union]:
self.datasource = parent.datasource
else:
self.datasource = self.context.datasource

@ -21,6 +21,16 @@ class ColRef:
self.__arr__ = (_ty, cobj, table, name, id)
def get_name(self):
    """Return the name this column is referred to by.

    The first entry of ``self.alias`` that differs from ``self.name`` is
    returned; if there is no such entry (no aliases, or every alias equals
    the name), ``self.name`` is returned.
    """
    # next() with a default replaces the manual iterator / StopIteration loop.
    return next(
        (alias for alias in self.alias if alias != self.name),
        self.name,
    )
def get_full_name(self):
table_name = self.table.table_name
it_alias = iter(self.table.alias)
@ -30,7 +40,7 @@ class ColRef:
alias = next(it_alias)
except StopIteration:
alias = table_name
return f'{alias}.{self.name}'
return f'{alias}.{self.get_name()}'
def __getitem__(self, key):
if type(key) is str:
@ -103,6 +113,10 @@ class TableInfo:
self.rec.update(self.columns)
return set(self.columns)
@property
def single_table(self):
    """Always ``True``."""
    return True
class Context:
def new(self):
self.headers = set(['\"./server/libaquery.h\"',

@ -41,6 +41,6 @@ from TradedStocks
SELECT ID, avgs(10, ClosePrice)
FROM td NATURAL JOIN
HistoricQuotes
HistoricQuotes hq
ASSUMING ASC TradeDate
GROUP BY ID
GROUP BY hq.ID
Loading…
Cancel
Save