Fixed bugs with sp group-bys and with inserting multiple values

dev
Bill 2 years ago
parent 6ad9ba6ea3
commit 31d823ac89

@ -1125,7 +1125,7 @@ namespace io{
} }
template<unsigned column_count, template<unsigned column_count,
char sep2 = -2, char sep2 = ';',
class trim_policy = trim_chars<' ', '\t'>, class trim_policy = trim_chars<' ', '\t'>,
class quote_policy = no_quote_escape<','>, class quote_policy = no_quote_escape<','>,
class overflow_policy = throw_on_overflow, class overflow_policy = throw_on_overflow,

@ -148,3 +148,13 @@ def clamp(val, minval, maxval):
def escape_qoutes(string : str): def escape_qoutes(string : str):
return re.sub(r'^\'', r'\'',re.sub(r'([^\\])\'', r'\1\'', string)) return re.sub(r'^\'', r'\'',re.sub(r'([^\\])\'', r'\1\'', string))
def get_innermost(sl):
    """Return the first leaf value of a nested dict/list structure.

    Descends through the first value of each non-empty dict and the first
    element of each non-empty list until something that is neither is
    reached, and returns that object.

    A dict whose ``'literal'`` entry is a string short-circuits the descent:
    the string is returned wrapped in single quotes.

    Empty containers and other falsy values are returned unchanged.
    """
    # isinstance (instead of an exact type match) so that dict/list/str
    # subclasses are unwrapped the same way as the builtin types.
    if sl and isinstance(sl, dict):
        literal = sl.get('literal')
        if isinstance(literal, str):
            # NOTE(review): single quotes embedded in the literal are not
            # escaped here -- confirm whether callers rely on that.
            return f"'{literal}'"
        return get_innermost(next(iter(sl.values()), None))
    if sl and isinstance(sl, list):
        return get_innermost(sl[0])
    return sl

@ -576,7 +576,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state = None):
state.stats.print(clear = False) state.stats.print(clear = False)
continue continue
trimed = ws.sub(' ', og_q).split(' ') trimed = ws.sub(' ', og_q).split(' ')
if trimed[0].lower().startswith('f'): if len(trimed) > 1 and trimed[0].lower().startswith('fi') or trimed[0].lower() == 'f':
fn = 'stock.a' if len(trimed) <= 1 or len(trimed[1]) == 0 \ fn = 'stock.a' if len(trimed) <= 1 or len(trimed[1]) == 0 \
else trimed[1] else trimed[1]
try: try:

@ -1,10 +1,12 @@
from binascii import Error
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Dict, List, Optional, Set, Tuple, Union
from engine.types import * from engine.types import *
from engine.utils import base62alp, base62uuid, enlist, get_legal_name from engine.utils import (base62alp, base62uuid, enlist, get_innermost,
get_legal_name)
from reconstruct.storage import ColRef, Context, TableInfo from reconstruct.storage import ColRef, Context, TableInfo
@ -58,6 +60,15 @@ class projection(ast_node):
name = 'projection' name = 'projection'
first_order = 'select' first_order = 'select'
def __init__(self,
             parent : Optional["ast_node"],
             node,
             context : Optional[Context] = None,
             force_use_spgb : bool = False):
    """Store the ``force_use_spgb`` flag, then run the base-class constructor.

    ``parent``, ``node`` and ``context`` are forwarded unchanged to the
    base-class ``__init__``.
    """
    # Assigned first so the attribute already exists while the base-class
    # constructor runs.
    self.force_use_spgb = force_use_spgb
    super().__init__(parent, node, context)
def init(self, _): def init(self, _):
# skip default init # skip default init
pass pass
@ -104,7 +115,7 @@ class projection(ast_node):
if type(self.datasource) is join: if type(self.datasource) is join:
self.datasource.process_join_conditions() self.datasource.process_join_conditions()
if 'groupby' in node: if 'groupby' in node: # if groupby clause contains special stuff
self.context.special_gb = groupby.check_special(self, node['groupby']) self.context.special_gb = groupby.check_special(self, node['groupby'])
def consume(self, node): def consume(self, node):
@ -163,6 +174,11 @@ class projection(ast_node):
this_type = [c.type for c in _datasource] this_type = [c.type for c in _datasource]
compound = [c.compound for c in _datasource] compound = [c.compound for c in _datasource]
proj_expr = [expr(self, c.name) for c in _datasource] proj_expr = [expr(self, c.name) for c in _datasource]
for pe in proj_expr:
if pe.is_ColExpr:
pe.cols_mentioned = {pe.raw_col}
else:
pe.cols_mentioned = set()
else: else:
y = lambda x:x y = lambda x:x
count = lambda : 'count(*)' count = lambda : 'count(*)'
@ -208,8 +224,14 @@ class projection(ast_node):
self.out_table.add_cols(cols, new = False) self.out_table.add_cols(cols, new = False)
self.proj_map = proj_map
if 'groupby' in node: if 'groupby' in node:
self.group_node = groupby(self, node['groupby']) self.group_node = groupby(self, node['groupby'])
if self.group_node.terminate:
self.context.abandon_query()
projection(self.parent, node, self.context, True)
return
if self.group_node.use_sp_gb: if self.group_node.use_sp_gb:
self.has_postproc = True self.has_postproc = True
else: else:
@ -588,6 +610,10 @@ class groupby(ast_node):
return True return True
return False return False
def init(self, _):
    """Clear the ``terminate`` flag, then delegate to the base-class ``init``."""
    # Start un-terminated; the flag is set elsewhere in this class when
    # processing has to stop early.
    self.terminate = False
    super().init(_)
def produce(self, node): def produce(self, node):
if not isinstance(self.parent, projection): if not isinstance(self.parent, projection):
raise ValueError('groupby can only be used in projection') raise ValueError('groupby can only be used in projection')
@ -595,6 +621,7 @@ class groupby(ast_node):
node = enlist(node) node = enlist(node)
o_list = [] o_list = []
self.refs = set() self.refs = set()
self.gb_cols = set()
self.dedicated_glist : List[Tuple[expr, Set[ColRef]]] = [] self.dedicated_glist : List[Tuple[expr, Set[ColRef]]] = []
self.use_sp_gb = False self.use_sp_gb = False
for g in node: for g in node:
@ -612,7 +639,23 @@ class groupby(ast_node):
if 'sort' in g and f'{g["sort"]}'.lower() == 'desc': if 'sort' in g and f'{g["sort"]}'.lower() == 'desc':
g_str = g_str + ' ' + 'DESC' g_str = g_str + ' ' + 'DESC'
o_list.append(g_str) o_list.append(g_str)
if g_expr.is_ColExpr:
self.gb_cols.add(g_expr.raw_col)
else:
self.gb_cols.add(g_expr.sql)
for projs in self.parent.proj_map.values():
if self.use_sp_gb:
break
if (projs[2].is_compound and
not ((projs[2].is_ColExpr and projs[2].raw_col in self.gb_cols) or
projs[2].sql in self.gb_cols)
):
if self.parent.force_use_spgb:
self.use_sp_gb = True
else:
self.terminate = True
return
if not self.use_sp_gb: if not self.use_sp_gb:
self.dedicated_gb = None self.dedicated_gb = None
self.add(', '.join(o_list)) self.add(', '.join(o_list))
@ -917,35 +960,60 @@ class insert(ast_node):
name = 'insert' name = 'insert'
first_order = name first_order = name
def init(self, node): def init(self, node):
values = node['query'] if 'query' in node:
complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit'] values = node['query']
if any([kw in values for kw in complex_query_kw]): complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
values['into'] = node['insert'] if any([kw in values for kw in complex_query_kw]):
proj_cls = (select_distinct values['into'] = node['insert']
if 'select_distinct' in values proj_cls = (select_distinct
else projection) if 'select_distinct' in values
proj_cls(None, values, self.context) else projection)
self.produce = lambda*_:None proj_cls(None, values, self.context)
self.spawn = lambda*_:None self.produce = lambda*_:None
self.consume = lambda*_:None self.spawn = lambda*_:None
self.consume = lambda*_:None
else: else:
super().init(node) super().init(node)
def produce(self, node): def produce(self, node):
values = node['query']['select'] keys = []
if 'query' in node:
if 'select' in node['query']:
values = enlist(node['query']['select'])
if 'columns' in node:
keys = node['columns']
values = [v['value'] for v in values]
elif 'union_all' in node['query']:
values = [[v['select']['value']] for v in node['query']['union_all']]
if 'columns' in node:
keys = node['columns']
else:
values = enlist(node['values'])
_vals = []
for v in values:
if isinstance(v, dict):
keys = v.keys()
v = list(v.values())
_vals.append(v)
values = _vals
keys = f'({", ".join(keys)})' if keys else ''
tbl = node['insert'] tbl = node['insert']
self.sql = f'INSERT INTO {tbl} VALUES(' self.sql = f'INSERT INTO {tbl}{keys} VALUES'
# if len(values) != table.n_cols: # if len(values) != table.n_cols:
# raise ValueError("Column Mismatch") # raise ValueError("Column Mismatch")
values = [values] if isinstance(values, list) and not isinstance(values[0], list) else values
list_values = [] list_values = []
for i, s in enumerate(enlist(values)): for l in values:
if 'value' in s: inner_list_values = []
list_values.append(f"{s['value']}") for s in enlist(l):
else: if type(s) is dict and 'value' in s:
# subquery, dispatch to select astnode s = s['value']
pass inner_list_values.append(f"{get_innermost(s)}")
self.sql += ', '.join(list_values) + ')' list_values.append(f"({', '.join(inner_list_values)})")
self.sql += ', '.join(list_values)
class load(ast_node): class load(ast_node):

@ -94,7 +94,7 @@ class expr(ast_node):
def produce(self, node): def produce(self, node):
from engine.utils import enlist from engine.utils import enlist
from reconstruct.ast import udf from reconstruct.ast import udf, projection
if type(node) is dict: if type(node) is dict:
if 'literal' in node: if 'literal' in node:
@ -169,7 +169,16 @@ class expr(ast_node):
special_func = [*self.context.udf_map.keys(), *self.context.module_map.keys(), special_func = [*self.context.udf_map.keys(), *self.context.module_map.keys(),
"maxs", "mins", "avgs", "sums", "deltas", "last", "first", "maxs", "mins", "avgs", "sums", "deltas", "last", "first",
"ratios", "pack", "truncate"] "ratios", "pack", "truncate"]
if self.context.special_gb:
if (
self.context.special_gb
or
(
type(self.root.parent) is projection
and
self.root.parent.force_use_spgb
)
):
special_func = [*special_func, *self.ext_aggfuncs] special_func = [*special_func, *self.ext_aggfuncs]
if key in special_func and not self.is_special: if key in special_func and not self.is_special:

@ -226,6 +226,11 @@ class Context:
self.queries.append('P' + proc_name) self.queries.append('P' + proc_name)
self.finalize_query() self.finalize_query()
def abandon_query(self):
    """Blank out the ``sql`` and ``ccode`` buffers, then call ``finalize_query``."""
    # Chained assignment sets ``sql`` first and ``ccode`` second.
    self.sql = self.ccode = ''
    self.finalize_query()
def finalize_udf(self): def finalize_udf(self):
if self.udf is not None: if self.udf is not None:
return (Context.udf_head return (Context.udf_head

@ -0,0 +1,20 @@
CREATE TABLE t(indiv INT, grp STRING, val INT)
INSERT INTO t VALUES(1, 'A', 1)
INSERT INTO t VALUES(1, 'A', 2)
INSERT INTO t VALUES(1, 'A', 3)
INSERT INTO t VALUES(1, 'A', 4)
INSERT INTO t VALUES(2, 'A', 2)
INSERT INTO t VALUES(2, 'A', 2)
INSERT INTO t VALUES(2, 'A', 4)
INSERT INTO t VALUES(2, 'A', 8)
INSERT INTO t VALUES(3, 'B', 10)
INSERT INTO t VALUES(3, 'B', 20)
INSERT INTO t VALUES(3, 'B', 30)
INSERT INTO t VALUES(3, 'B', 40)
INSERT INTO t VALUES(4, 'B', 20)
INSERT INTO t VALUES(4, 'B', 20)
INSERT INTO t VALUES(4, 'B', 40)
INSERT INTO t VALUES(4, 'B', 80)
SELECT * FROM t

@ -0,0 +1,10 @@
FUNCTION myCov(x, y) {
center_x := x - avg(x);
center_y := y - avg(y);
num := sum(center_x * center_y);
denom := sqrt(sum(center_x * center_x)) * sqrt(sum(center_y * center_y));
num / denom
}
select myCov(1,2);
Loading…
Cancel
Save