fixed bugs wrt sp groupbys, insert multiple values

dev
Bill 2 years ago
parent 6ad9ba6ea3
commit 31d823ac89

@ -1125,7 +1125,7 @@ namespace io{
}
template<unsigned column_count,
char sep2 = -2,
char sep2 = ';',
class trim_policy = trim_chars<' ', '\t'>,
class quote_policy = no_quote_escape<','>,
class overflow_policy = throw_on_overflow,

@ -148,3 +148,13 @@ def clamp(val, minval, maxval):
def escape_qoutes(string : str):
return re.sub(r'^\'', r'\'',re.sub(r'([^\\])\'', r'\1\'', string))
def get_innermost(sl):
if sl and type(sl) is dict:
if 'literal' in sl and type(sl['literal']) is str:
return f"'{get_innermost(sl['literal'])}'"
return get_innermost(next(iter(sl.values()), None))
elif sl and type(sl) is list:
return get_innermost(sl[0])
else:
return sl

@ -576,7 +576,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state = None):
state.stats.print(clear = False)
continue
trimed = ws.sub(' ', og_q).split(' ')
if trimed[0].lower().startswith('f'):
if len(trimed) > 1 and trimed[0].lower().startswith('fi') or trimed[0].lower() == 'f':
fn = 'stock.a' if len(trimed) <= 1 or len(trimed[1]) == 0 \
else trimed[1]
try:

@ -1,10 +1,12 @@
from binascii import Error
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List, Optional, Set, Tuple, Union
from engine.types import *
from engine.utils import base62alp, base62uuid, enlist, get_legal_name
from engine.utils import (base62alp, base62uuid, enlist, get_innermost,
get_legal_name)
from reconstruct.storage import ColRef, Context, TableInfo
@ -58,6 +60,15 @@ class projection(ast_node):
name = 'projection'
first_order = 'select'
def __init__(self,
parent : Optional["ast_node"],
node,
context : Optional[Context] = None,
force_use_spgb : bool = False
):
self.force_use_spgb = force_use_spgb
super().__init__(parent, node, context)
def init(self, _):
# skip default init
pass
@ -104,7 +115,7 @@ class projection(ast_node):
if type(self.datasource) is join:
self.datasource.process_join_conditions()
if 'groupby' in node:
if 'groupby' in node: # if groupby clause contains special stuff
self.context.special_gb = groupby.check_special(self, node['groupby'])
def consume(self, node):
@ -163,6 +174,11 @@ class projection(ast_node):
this_type = [c.type for c in _datasource]
compound = [c.compound for c in _datasource]
proj_expr = [expr(self, c.name) for c in _datasource]
for pe in proj_expr:
if pe.is_ColExpr:
pe.cols_mentioned = {pe.raw_col}
else:
pe.cols_mentioned = set()
else:
y = lambda x:x
count = lambda : 'count(*)'
@ -208,8 +224,14 @@ class projection(ast_node):
self.out_table.add_cols(cols, new = False)
self.proj_map = proj_map
if 'groupby' in node:
self.group_node = groupby(self, node['groupby'])
if self.group_node.terminate:
self.context.abandon_query()
projection(self.parent, node, self.context, True)
return
if self.group_node.use_sp_gb:
self.has_postproc = True
else:
@ -588,6 +610,10 @@ class groupby(ast_node):
return True
return False
def init(self, _):
self.terminate = False
super().init(_)
def produce(self, node):
if not isinstance(self.parent, projection):
raise ValueError('groupby can only be used in projection')
@ -595,6 +621,7 @@ class groupby(ast_node):
node = enlist(node)
o_list = []
self.refs = set()
self.gb_cols = set()
self.dedicated_glist : List[Tuple[expr, Set[ColRef]]] = []
self.use_sp_gb = False
for g in node:
@ -612,7 +639,23 @@ class groupby(ast_node):
if 'sort' in g and f'{g["sort"]}'.lower() == 'desc':
g_str = g_str + ' ' + 'DESC'
o_list.append(g_str)
if g_expr.is_ColExpr:
self.gb_cols.add(g_expr.raw_col)
else:
self.gb_cols.add(g_expr.sql)
for projs in self.parent.proj_map.values():
if self.use_sp_gb:
break
if (projs[2].is_compound and
not ((projs[2].is_ColExpr and projs[2].raw_col in self.gb_cols) or
projs[2].sql in self.gb_cols)
):
if self.parent.force_use_spgb:
self.use_sp_gb = True
else:
self.terminate = True
return
if not self.use_sp_gb:
self.dedicated_gb = None
self.add(', '.join(o_list))
@ -917,6 +960,7 @@ class insert(ast_node):
name = 'insert'
first_order = name
def init(self, node):
if 'query' in node:
values = node['query']
complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
if any([kw in values for kw in complex_query_kw]):
@ -932,20 +976,44 @@ class insert(ast_node):
super().init(node)
def produce(self, node):
values = node['query']['select']
keys = []
if 'query' in node:
if 'select' in node['query']:
values = enlist(node['query']['select'])
if 'columns' in node:
keys = node['columns']
values = [v['value'] for v in values]
elif 'union_all' in node['query']:
values = [[v['select']['value']] for v in node['query']['union_all']]
if 'columns' in node:
keys = node['columns']
else:
values = enlist(node['values'])
_vals = []
for v in values:
if isinstance(v, dict):
keys = v.keys()
v = list(v.values())
_vals.append(v)
values = _vals
keys = f'({", ".join(keys)})' if keys else ''
tbl = node['insert']
self.sql = f'INSERT INTO {tbl} VALUES('
self.sql = f'INSERT INTO {tbl}{keys} VALUES'
# if len(values) != table.n_cols:
# raise ValueError("Column Mismatch")
values = [values] if isinstance(values, list) and not isinstance(values[0], list) else values
list_values = []
for i, s in enumerate(enlist(values)):
if 'value' in s:
list_values.append(f"{s['value']}")
else:
# subquery, dispatch to select astnode
pass
self.sql += ', '.join(list_values) + ')'
for l in values:
inner_list_values = []
for s in enlist(l):
if type(s) is dict and 'value' in s:
s = s['value']
inner_list_values.append(f"{get_innermost(s)}")
list_values.append(f"({', '.join(inner_list_values)})")
self.sql += ', '.join(list_values)
class load(ast_node):

@ -94,7 +94,7 @@ class expr(ast_node):
def produce(self, node):
from engine.utils import enlist
from reconstruct.ast import udf
from reconstruct.ast import udf, projection
if type(node) is dict:
if 'literal' in node:
@ -169,7 +169,16 @@ class expr(ast_node):
special_func = [*self.context.udf_map.keys(), *self.context.module_map.keys(),
"maxs", "mins", "avgs", "sums", "deltas", "last", "first",
"ratios", "pack", "truncate"]
if self.context.special_gb:
if (
self.context.special_gb
or
(
type(self.root.parent) is projection
and
self.root.parent.force_use_spgb
)
):
special_func = [*special_func, *self.ext_aggfuncs]
if key in special_func and not self.is_special:

@ -226,6 +226,11 @@ class Context:
self.queries.append('P' + proc_name)
self.finalize_query()
def abandon_query(self):
self.sql = ''
self.ccode = ''
self.finalize_query()
def finalize_udf(self):
if self.udf is not None:
return (Context.udf_head

@ -0,0 +1,20 @@
CREATE TABLE t(indiv INT, grp STRING, val INT)
INSERT INTO t VALUES(1, 'A', 1)
INSERT INTO t VALUES(1, 'A', 2)
INSERT INTO t VALUES(1, 'A', 3)
INSERT INTO t VALUES(1, 'A', 4)
INSERT INTO t VALUES(2, 'A', 2)
INSERT INTO t VALUES(2, 'A', 2)
INSERT INTO t VALUES(2, 'A', 4)
INSERT INTO t VALUES(2, 'A', 8)
INSERT INTO t VALUES(3, 'B', 10)
INSERT INTO t VALUES(3, 'B', 20)
INSERT INTO t VALUES(3, 'B', 30)
INSERT INTO t VALUES(3, 'B', 40)
INSERT INTO t VALUES(4, 'B', 20)
INSERT INTO t VALUES(4, 'B', 20)
INSERT INTO t VALUES(4, 'B', 40)
INSERT INTO t VALUES(4, 'B', 80)
SELECT * FROM t

@ -0,0 +1,10 @@
FUNCTION myCov(x, y) {
center_x := x - avg(x);
center_y := y - avg(y);
num := sum(center_x * center_y);
denom := sqrt(sum(center_x * center_x)) * sqrt(sum(center_y * center_y));
num / denom
}
select myCov(1,2);
Loading…
Cancel
Save