diff --git a/README.md b/README.md index f629eef..79c3deb 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,7 @@ See ./tests/ for more examples. - [x] Bug fixes: type deduction misaligned in Hybrid Engine - [ ] Investigation: Using postproc only for q1 in Hybrid Engine (make is_special always on) - [ ] C++ Meta-Programming: Eliminate template recursions as much as possible. -- [ ] Functionality: Basic helper functions in aquery +- [x] Functionality: Basic helper functions in aquery - [ ] Bug: Join-Aware Column management - [ ] Bug: Order By after Group By +- [ ] Functionality: Having clause \ No newline at end of file diff --git a/aquery_config.py b/aquery_config.py index 9f55062..4155ea8 100644 --- a/aquery_config.py +++ b/aquery_config.py @@ -2,7 +2,7 @@ ## GLOBAL CONFIGURATION FLAGS -version_string = '0.4.7a' +version_string = '0.4.8a' add_path_to_ldpath = True rebuild_backend = False run_backend = True diff --git a/reconstruct/ast.py b/reconstruct/ast.py index cfd43f2..11e4c37 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -145,7 +145,10 @@ class projection(ast_node): if not proj_expr.is_special: if proj_expr.node == '*': - name = [c.get_full_name() for c in self.datasource.rec] + name = [(c.get_name() + if self.datasource.single_table + else c.get_full_name() + ) for c in self.datasource.rec] this_type = [c.type for c in self.datasource.rec] compound = [c.compound for c in self.datasource.rec] proj_expr = [expr(self, c.name) for c in self.datasource.rec] @@ -288,7 +291,9 @@ class projection(ast_node): self.group_node and (self.group_node.use_sp_gb and val[2].cols_mentioned.intersection( - self.datasource.all_cols().difference(self.group_node.refs)) + self.datasource.all_cols().difference( + self.datasource.get_joint_cols(self.group_node.refs) + )) ) and val[2].is_compound # compound val not in key # or # val[2].is_compound > 1 @@ -366,7 +371,7 @@ class select_into(ast_node): raise Exception('No out_table found.') else: self.context.headers.add('"./server/table_ext_monetdb.hpp"') - self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->alt_server, \"{node}\");' + self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->alt_server, \"{node.lower()}\");' def produce_sql(self, node): self.sql = f' INTO {node}' @@ -443,6 +448,7 @@ class groupby_c(ast_node): self.proj : projection = self.parent self.glist : List[Tuple[expr, Set[ColRef]]] = node return super().init(node) + def produce(self, node : List[Tuple[expr, Set[ColRef]]]): self.context.headers.add('"./server/hasher.h"') self.context.headers.add('unordered_map') @@ -505,9 +511,14 @@ class groupby_c(ast_node): gscanner.add(f'auto &{len_var} = {val_var}.size;', position = 'front') def get_key_idx (varname : str): + ex = expr(self, varname) + joint_cols = set() + if ex.is_ColExpr and ex.raw_col: + joint_cols = self.datasource.get_joint_cols([ex.raw_col]) for i, g in enumerate(self.glist): - if varname == g[0].eval(): - return i + if (varname == g[0].eval()) or (g[0].is_ColExpr and g[0].raw_col and + g[0].raw_col in joint_cols): + return i return var_table[varname] def get_var_names (varname : str): @@ -598,14 +609,15 @@ class groupby(ast_node): if self.use_sp_gb: self.dedicated_gb = groupby_c(self.parent, self.dedicated_glist) self.dedicated_gb.finalize(cexprs, var_table) - + + class join(ast_node): name = 'join' def get_joint_cols(self, cols : List[ColRef]): - joint_cols = set() + joint_cols = set(cols) for col in cols: - joint_cols |= self.joint_cols[col] + joint_cols |= self.joint_cols.get(col, set()) return joint_cols def init(self, _): @@ -661,21 +673,36 @@ class join(ast_node): tbl.add_alias(node['name']) self.append(tbl, alias) else: - keys = list(node.keys()) + keys : List[str] = list(node.keys()) if keys[0].lower().endswith('join'): self.have_sep = True j = join(self, node[keys[0]]) + self.tables += j.tables + self.tables_dir = {**self.tables_dir, **j.tables_dir} self.join_conditions += j.join_conditions + _tbl_union = _tmp_join_union(self.context, self.parent, self) tablename = f' {keys[0]} {j}' if len(keys) > 1 : - _ex = expr(self, node[keys[1]]) - if keys[1].lower() == 'on': + jcond = node[keys[1]] + sqoute = '\'' + if type(jcond) is list: + _ex = [expr(_tbl_union, j) for j in jcond] + else: + _ex = expr(_tbl_union, jcond) + if keys[1].lower() == 'on': # postpone join condition evaluation after consume self.join_conditions += _ex.join_conditions - tablename += f' ON {_ex}' + tablename += f" ON {_ex.eval().replace(sqoute, '')}" elif keys[1].lower() == 'using': - if _ex.is_ColExpr: - self.join_conditions.append( (_ex.raw_col, j.get_cols(_ex.raw_col.name)) ) - tablename += f' USING {_ex}' + _ex = enlist(_ex) + lst_jconds = [] + for _e in _ex: + if _e.is_ColExpr: + cl = _e.raw_col + cr = j.get_cols(_e.raw_col.name) + self.join_conditions.append( (cl, cr) ) + lst_jconds += [f'{cl.get_full_name()} = {cr.get_full_name()}'] + tablename += f' ON {" and ".join(lst_jconds)}' + if keys[0].lower().startswith('natural'): ltbls : List[TableInfo] = [] if isinstance(self.parent, join): @@ -688,8 +715,7 @@ class join(ast_node): if cr: self.join_conditions.append( (cl, cr) ) self.joins.append((tablename, self.have_sep)) - self.tables += j.tables - self.tables_dir = {**self.tables_dir, **j.tables_dir} + elif type(node) is str: if node in self.context.tables_byname: @@ -719,7 +745,11 @@ class join(ast_node): ret = datasource.parse_col_names(parsedColExpr[1]) datasource.rec = None return ret - + + @property + def single_table(self): + return len(self.tables) == 1 + # @property def all_cols(self): ret = set() @@ -734,7 +764,22 @@ class join(ast_node): def process_join_conditions(self): # This is done after both from # and where clause are processed - print(self.join_conditions) + for j in self.join_conditions: + l = j[0] + r = j[1] + for k in (0, 1): + if j[k] not in self.joint_cols: + self.joint_cols[j[k]] = set([l, r]) + jr = self.joint_cols[r] + jl = self.joint_cols[l] + if jl != jr: + jl |= jr + for c, jc in self.joint_cols.items(): + if jc == jr: + self.joint_cols[c] = jl + + # print(self.join_conditions) + # print(self.joint_cols) def consume(self, node): self.sql = '' @@ -753,7 +798,17 @@ class join(ast_node): def __repr__(self): return self.__str__() +class _tmp_join_union(join): + name = '__tmp_join_union' + def __init__(self, context, l: join, r: join): + self.tables = l.tables + r.tables + self.tables_dir = {**l.tables_dir, **r.tables_dir} + self.datasource = self + self.context = context + self.parent = self + self.rec = None + class filter(ast_node): name = 'where' def produce(self, node): @@ -771,7 +826,10 @@ class create_table(ast_node): raise ValueError("Table name not specified") projection_node = node['query'] projection_node['into'] = node['name'] - projection(None, projection_node, self.context) + proj_cls = (select_distinct + if 'select_distinct' in projection_node + else projection) + proj_cls(None, projection_node, self.context) self.produce = lambda *_: None self.spawn = lambda *_: None self.consume = lambda *_: None @@ -819,7 +877,10 @@ class insert(ast_node): complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit'] if any([kw in values for kw in complex_query_kw]): values['into'] = node['insert'] - projection(None, values, self.context) + proj_cls = (select_distinct + if 'select_distinct' in values + else projection) + proj_cls(None, values, self.context) self.produce = lambda*_:None self.spawn = lambda*_:None self.consume = lambda*_:None diff --git a/reconstruct/expr.py b/reconstruct/expr.py index 2440430..ea2480c 100644 --- a/reconstruct/expr.py +++ b/reconstruct/expr.py @@ -78,10 +78,10 @@ class expr(ast_node): ast_node.__init__(self, parent, node, None) def init(self, _): - from reconstruct.ast import projection + from reconstruct.ast import projection, _tmp_join_union parent = self.parent self.is_compound = parent.is_compound if type(parent) is expr else False - if type(parent) in [projection, expr]: + if type(parent) in [projection, expr, _tmp_join_union]: self.datasource = parent.datasource else: self.datasource = self.context.datasource diff --git a/reconstruct/storage.py b/reconstruct/storage.py index 0ddafe1..ec5277f 100644 --- a/reconstruct/storage.py +++ b/reconstruct/storage.py @@ -21,6 +21,16 @@ class ColRef: self.__arr__ = (_ty, cobj, table, name, id) + def get_name(self): + it_alias = iter(self.alias) + alias = next(it_alias, self.name) + try: + while alias == self.name: + alias = next(it_alias) + except StopIteration: + alias = self.name + return alias + def get_full_name(self): table_name = self.table.table_name it_alias = iter(self.table.alias) @@ -30,7 +40,7 @@ class ColRef: alias = next(it_alias) except StopIteration: alias = table_name - return f'{alias}.{self.name}' + return f'{alias}.{self.get_name()}' def __getitem__(self, key): if type(key) is str: @@ -103,6 +113,10 @@ class TableInfo: self.rec.update(self.columns) return set(self.columns) + @property + def single_table(self): + return True + class Context: def new(self): self.headers = set(['\"./server/libaquery.h\"', diff --git a/tests/best_profit.a b/tests/best_profit.a index 4d242ec..907b1f5 100644 --- a/tests/best_profit.a +++ b/tests/best_profit.a @@ -41,6 +41,6 @@ from TradedStocks SELECT ID, avgs(10, ClosePrice) FROM td NATURAL JOIN - HistoricQuotes + HistoricQuotes hq ASSUMING ASC TradeDate -GROUP BY ID \ No newline at end of file +GROUP BY hq.ID \ No newline at end of file