From 31d823ac89b9fee2e4cfd99c9573f12ad0bfde80 Mon Sep 17 00:00:00 2001 From: Bill Date: Wed, 26 Oct 2022 05:37:49 +0800 Subject: [PATCH] fixed bugs wrt sp groupbys, insert multiple values --- csv.h | 2 +- engine/utils.py | 10 ++++ prompt.py | 2 +- reconstruct/ast.py | 116 ++++++++++++++++++++++++++++++++--------- reconstruct/expr.py | 13 ++++- reconstruct/storage.py | 5 ++ tests/simple2 | 20 +++++++ tests/udf5.a | 10 ++++ 8 files changed, 150 insertions(+), 28 deletions(-) create mode 100644 tests/simple2 create mode 100644 tests/udf5.a diff --git a/csv.h b/csv.h index c0d1762..6b10915 100644 --- a/csv.h +++ b/csv.h @@ -1125,7 +1125,7 @@ namespace io{ } template, class quote_policy = no_quote_escape<','>, class overflow_policy = throw_on_overflow, diff --git a/engine/utils.py b/engine/utils.py index dc7f2bc..8e65fcd 100644 --- a/engine/utils.py +++ b/engine/utils.py @@ -148,3 +148,13 @@ def clamp(val, minval, maxval): def escape_qoutes(string : str): return re.sub(r'^\'', r'\'',re.sub(r'([^\\])\'', r'\1\'', string)) + +def get_innermost(sl): + if sl and type(sl) is dict: + if 'literal' in sl and type(sl['literal']) is str: + return f"'{get_innermost(sl['literal'])}'" + return get_innermost(next(iter(sl.values()), None)) + elif sl and type(sl) is list: + return get_innermost(sl[0]) + else: + return sl \ No newline at end of file diff --git a/prompt.py b/prompt.py index c6a00dd..9c486f1 100644 --- a/prompt.py +++ b/prompt.py @@ -576,7 +576,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state = None): state.stats.print(clear = False) continue trimed = ws.sub(' ', og_q).split(' ') - if trimed[0].lower().startswith('f'): + if len(trimed) > 1 and trimed[0].lower().startswith('fi') or trimed[0].lower() == 'f': fn = 'stock.a' if len(trimed) <= 1 or len(trimed[1]) == 0 \ else trimed[1] try: diff --git a/reconstruct/ast.py b/reconstruct/ast.py index 173399b..b8228c1 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -1,10 +1,12 @@ +from binascii import Error from copy import deepcopy from dataclasses import dataclass from enum import Enum, auto from typing import Dict, List, Optional, Set, Tuple, Union from engine.types import * -from engine.utils import base62alp, base62uuid, enlist, get_legal_name +from engine.utils import (base62alp, base62uuid, enlist, get_innermost, + get_legal_name) from reconstruct.storage import ColRef, Context, TableInfo @@ -58,6 +60,15 @@ class projection(ast_node): name = 'projection' first_order = 'select' + def __init__(self, + parent : Optional["ast_node"], + node, + context : Optional[Context] = None, + force_use_spgb : bool = False + ): + self.force_use_spgb = force_use_spgb + super().__init__(parent, node, context) + def init(self, _): # skip default init pass @@ -104,7 +115,7 @@ class projection(ast_node): if type(self.datasource) is join: self.datasource.process_join_conditions() - if 'groupby' in node: + if 'groupby' in node: # if groupby clause contains special stuff self.context.special_gb = groupby.check_special(self, node['groupby']) def consume(self, node): @@ -163,6 +174,11 @@ class projection(ast_node): this_type = [c.type for c in _datasource] compound = [c.compound for c in _datasource] proj_expr = [expr(self, c.name) for c in _datasource] + for pe in proj_expr: + if pe.is_ColExpr: + pe.cols_mentioned = {pe.raw_col} + else: + pe.cols_mentioned = set() else: y = lambda x:x count = lambda : 'count(*)' @@ -208,8 +224,14 @@ class projection(ast_node): self.out_table.add_cols(cols, new = False) + self.proj_map = proj_map + if 'groupby' in node: self.group_node = groupby(self, node['groupby']) + if self.group_node.terminate: + self.context.abandon_query() + projection(self.parent, node, self.context, True) + return if self.group_node.use_sp_gb: self.has_postproc = True else: @@ -588,6 +610,10 @@ class groupby(ast_node): return True return False + def init(self, _): + self.terminate = False + super().init(_) + def produce(self, node): if not isinstance(self.parent, projection): raise ValueError('groupby can only be used in projection') @@ -595,6 +621,7 @@ class groupby(ast_node): node = enlist(node) o_list = [] self.refs = set() + self.gb_cols = set() self.dedicated_glist : List[Tuple[expr, Set[ColRef]]] = [] self.use_sp_gb = False for g in node: @@ -612,7 +639,23 @@ class groupby(ast_node): if 'sort' in g and f'{g["sort"]}'.lower() == 'desc': g_str = g_str + ' ' + 'DESC' o_list.append(g_str) - + if g_expr.is_ColExpr: + self.gb_cols.add(g_expr.raw_col) + else: + self.gb_cols.add(g_expr.sql) + + for projs in self.parent.proj_map.values(): + if self.use_sp_gb: + break + if (projs[2].is_compound and + not ((projs[2].is_ColExpr and projs[2].raw_col in self.gb_cols) or + projs[2].sql in self.gb_cols) + ): + if self.parent.force_use_spgb: + self.use_sp_gb = True + else: + self.terminate = True + return if not self.use_sp_gb: self.dedicated_gb = None self.add(', '.join(o_list)) @@ -917,35 +960,60 @@ class insert(ast_node): name = 'insert' first_order = name def init(self, node): - values = node['query'] - complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit'] - if any([kw in values for kw in complex_query_kw]): - values['into'] = node['insert'] - proj_cls = (select_distinct - if 'select_distinct' in values - else projection) - proj_cls(None, values, self.context) - self.produce = lambda*_:None - self.spawn = lambda*_:None - self.consume = lambda*_:None + if 'query' in node: + values = node['query'] + complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit'] + if any([kw in values for kw in complex_query_kw]): + values['into'] = node['insert'] + proj_cls = (select_distinct + if 'select_distinct' in values + else projection) + proj_cls(None, values, self.context) + self.produce = lambda*_:None + self.spawn = lambda*_:None + self.consume = lambda*_:None else: super().init(node) def produce(self, node): - values = node['query']['select'] + keys = [] + if 'query' in node: + if 'select' in node['query']: + values = enlist(node['query']['select']) + if 'columns' in node: + keys = node['columns'] + values = [v['value'] for v in values] + + elif 'union_all' in node['query']: + values = [[v['select']['value']] for v in node['query']['union_all']] + if 'columns' in node: + keys = node['columns'] + else: + values = enlist(node['values']) + _vals = [] + for v in values: + if isinstance(v, dict): + keys = v.keys() + v = list(v.values()) + _vals.append(v) + values = _vals + + keys = f'({", ".join(keys)})' if keys else '' tbl = node['insert'] - self.sql = f'INSERT INTO {tbl} VALUES(' + self.sql = f'INSERT INTO {tbl}{keys} VALUES' # if len(values) != table.n_cols: # raise ValueError("Column Mismatch") - + values = [values] if isinstance(values, list) and not isinstance(values[0], list) else values list_values = [] - for i, s in enumerate(enlist(values)): - if 'value' in s: - list_values.append(f"{s['value']}") - else: - # subquery, dispatch to select astnode - pass - self.sql += ', '.join(list_values) + ')' + for l in values: + inner_list_values = [] + for s in enlist(l): + if type(s) is dict and 'value' in s: + s = s['value'] + inner_list_values.append(f"{get_innermost(s)}") + list_values.append(f"({', '.join(inner_list_values)})") + + self.sql += ', '.join(list_values) class load(ast_node): diff --git a/reconstruct/expr.py b/reconstruct/expr.py index f1e3d5a..bfd552c 100644 --- a/reconstruct/expr.py +++ b/reconstruct/expr.py @@ -94,7 +94,7 @@ class expr(ast_node): def produce(self, node): from engine.utils import enlist - from reconstruct.ast import udf + from reconstruct.ast import udf, projection if type(node) is dict: if 'literal' in node: @@ -169,7 +169,16 @@ class expr(ast_node): special_func = [*self.context.udf_map.keys(), *self.context.module_map.keys(), "maxs", "mins", "avgs", "sums", "deltas", "last", "first", "ratios", "pack", "truncate"] - if self.context.special_gb: + + if ( + self.context.special_gb + or + ( + type(self.root.parent) is projection + and + self.root.parent.force_use_spgb + ) + ): special_func = [*special_func, *self.ext_aggfuncs] if key in special_func and not self.is_special: diff --git a/reconstruct/storage.py b/reconstruct/storage.py index 2873747..47eab9a 100644 --- a/reconstruct/storage.py +++ b/reconstruct/storage.py @@ -226,6 +226,11 @@ class Context: self.queries.append('P' + proc_name) self.finalize_query() + def abandon_query(self): + self.sql = '' + self.ccode = '' + self.finalize_query() + def finalize_udf(self): if self.udf is not None: return (Context.udf_head diff --git a/tests/simple2 b/tests/simple2 new file mode 100644 index 0000000..d8f3d8c --- /dev/null +++ b/tests/simple2 @@ -0,0 +1,20 @@ +CREATE TABLE t(indiv INT, grp STRING, val INT) +INSERT INTO t VALUES(1, 'A', 1) +INSERT INTO t VALUES(1, 'A', 2) +INSERT INTO t VALUES(1, 'A', 3) +INSERT INTO t VALUES(1, 'A', 4) +INSERT INTO t VALUES(2, 'A', 2) +INSERT INTO t VALUES(2, 'A', 2) +INSERT INTO t VALUES(2, 'A', 4) +INSERT INTO t VALUES(2, 'A', 8) +INSERT INTO t VALUES(3, 'B', 10) +INSERT INTO t VALUES(3, 'B', 20) +INSERT INTO t VALUES(3, 'B', 30) +INSERT INTO t VALUES(3, 'B', 40) +INSERT INTO t VALUES(4, 'B', 20) +INSERT INTO t VALUES(4, 'B', 20) +INSERT INTO t VALUES(4, 'B', 40) +INSERT INTO t VALUES(4, 'B', 80) + + +SELECT * FROM t \ No newline at end of file diff --git a/tests/udf5.a b/tests/udf5.a new file mode 100644 index 0000000..89e1d0f --- /dev/null +++ b/tests/udf5.a @@ -0,0 +1,10 @@ +FUNCTION myCov(x, y) { + center_x := x - avg(x); + center_y := y - avg(y); + num := sum(center_x * center_y); + denom := sqrt(sum(center_x * center_x)) * sqrt(sum(center_y * center_y)); + num / denom + } + + +select myCov(1,2); \ No newline at end of file