|
|
|
@ -1,10 +1,12 @@
|
|
|
|
|
from binascii import Error
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from enum import Enum, auto
|
|
|
|
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
|
|
|
|
|
|
|
|
|
from engine.types import *
|
|
|
|
|
from engine.utils import base62alp, base62uuid, enlist, get_legal_name
|
|
|
|
|
from engine.utils import (base62alp, base62uuid, enlist, get_innermost,
|
|
|
|
|
get_legal_name)
|
|
|
|
|
from reconstruct.storage import ColRef, Context, TableInfo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -58,6 +60,15 @@ class projection(ast_node):
|
|
|
|
|
name = 'projection'
|
|
|
|
|
first_order = 'select'
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
parent : Optional["ast_node"],
|
|
|
|
|
node,
|
|
|
|
|
context : Optional[Context] = None,
|
|
|
|
|
force_use_spgb : bool = False
|
|
|
|
|
):
|
|
|
|
|
self.force_use_spgb = force_use_spgb
|
|
|
|
|
super().__init__(parent, node, context)
|
|
|
|
|
|
|
|
|
|
def init(self, _):
|
|
|
|
|
# skip default init
|
|
|
|
|
pass
|
|
|
|
@ -104,7 +115,7 @@ class projection(ast_node):
|
|
|
|
|
if type(self.datasource) is join:
|
|
|
|
|
self.datasource.process_join_conditions()
|
|
|
|
|
|
|
|
|
|
if 'groupby' in node:
|
|
|
|
|
if 'groupby' in node: # if groupby clause contains special stuff
|
|
|
|
|
self.context.special_gb = groupby.check_special(self, node['groupby'])
|
|
|
|
|
|
|
|
|
|
def consume(self, node):
|
|
|
|
@ -163,6 +174,11 @@ class projection(ast_node):
|
|
|
|
|
this_type = [c.type for c in _datasource]
|
|
|
|
|
compound = [c.compound for c in _datasource]
|
|
|
|
|
proj_expr = [expr(self, c.name) for c in _datasource]
|
|
|
|
|
for pe in proj_expr:
|
|
|
|
|
if pe.is_ColExpr:
|
|
|
|
|
pe.cols_mentioned = {pe.raw_col}
|
|
|
|
|
else:
|
|
|
|
|
pe.cols_mentioned = set()
|
|
|
|
|
else:
|
|
|
|
|
y = lambda x:x
|
|
|
|
|
count = lambda : 'count(*)'
|
|
|
|
@ -208,8 +224,14 @@ class projection(ast_node):
|
|
|
|
|
|
|
|
|
|
self.out_table.add_cols(cols, new = False)
|
|
|
|
|
|
|
|
|
|
self.proj_map = proj_map
|
|
|
|
|
|
|
|
|
|
if 'groupby' in node:
|
|
|
|
|
self.group_node = groupby(self, node['groupby'])
|
|
|
|
|
if self.group_node.terminate:
|
|
|
|
|
self.context.abandon_query()
|
|
|
|
|
projection(self.parent, node, self.context, True)
|
|
|
|
|
return
|
|
|
|
|
if self.group_node.use_sp_gb:
|
|
|
|
|
self.has_postproc = True
|
|
|
|
|
else:
|
|
|
|
@ -588,6 +610,10 @@ class groupby(ast_node):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def init(self, _):
|
|
|
|
|
self.terminate = False
|
|
|
|
|
super().init(_)
|
|
|
|
|
|
|
|
|
|
def produce(self, node):
|
|
|
|
|
if not isinstance(self.parent, projection):
|
|
|
|
|
raise ValueError('groupby can only be used in projection')
|
|
|
|
@ -595,6 +621,7 @@ class groupby(ast_node):
|
|
|
|
|
node = enlist(node)
|
|
|
|
|
o_list = []
|
|
|
|
|
self.refs = set()
|
|
|
|
|
self.gb_cols = set()
|
|
|
|
|
self.dedicated_glist : List[Tuple[expr, Set[ColRef]]] = []
|
|
|
|
|
self.use_sp_gb = False
|
|
|
|
|
for g in node:
|
|
|
|
@ -612,7 +639,23 @@ class groupby(ast_node):
|
|
|
|
|
if 'sort' in g and f'{g["sort"]}'.lower() == 'desc':
|
|
|
|
|
g_str = g_str + ' ' + 'DESC'
|
|
|
|
|
o_list.append(g_str)
|
|
|
|
|
if g_expr.is_ColExpr:
|
|
|
|
|
self.gb_cols.add(g_expr.raw_col)
|
|
|
|
|
else:
|
|
|
|
|
self.gb_cols.add(g_expr.sql)
|
|
|
|
|
|
|
|
|
|
for projs in self.parent.proj_map.values():
|
|
|
|
|
if self.use_sp_gb:
|
|
|
|
|
break
|
|
|
|
|
if (projs[2].is_compound and
|
|
|
|
|
not ((projs[2].is_ColExpr and projs[2].raw_col in self.gb_cols) or
|
|
|
|
|
projs[2].sql in self.gb_cols)
|
|
|
|
|
):
|
|
|
|
|
if self.parent.force_use_spgb:
|
|
|
|
|
self.use_sp_gb = True
|
|
|
|
|
else:
|
|
|
|
|
self.terminate = True
|
|
|
|
|
return
|
|
|
|
|
if not self.use_sp_gb:
|
|
|
|
|
self.dedicated_gb = None
|
|
|
|
|
self.add(', '.join(o_list))
|
|
|
|
@ -917,6 +960,7 @@ class insert(ast_node):
|
|
|
|
|
name = 'insert'
|
|
|
|
|
first_order = name
|
|
|
|
|
def init(self, node):
|
|
|
|
|
if 'query' in node:
|
|
|
|
|
values = node['query']
|
|
|
|
|
complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
|
|
|
|
|
if any([kw in values for kw in complex_query_kw]):
|
|
|
|
@ -932,20 +976,44 @@ class insert(ast_node):
|
|
|
|
|
super().init(node)
|
|
|
|
|
|
|
|
|
|
def produce(self, node):
|
|
|
|
|
values = node['query']['select']
|
|
|
|
|
keys = []
|
|
|
|
|
if 'query' in node:
|
|
|
|
|
if 'select' in node['query']:
|
|
|
|
|
values = enlist(node['query']['select'])
|
|
|
|
|
if 'columns' in node:
|
|
|
|
|
keys = node['columns']
|
|
|
|
|
values = [v['value'] for v in values]
|
|
|
|
|
|
|
|
|
|
elif 'union_all' in node['query']:
|
|
|
|
|
values = [[v['select']['value']] for v in node['query']['union_all']]
|
|
|
|
|
if 'columns' in node:
|
|
|
|
|
keys = node['columns']
|
|
|
|
|
else:
|
|
|
|
|
values = enlist(node['values'])
|
|
|
|
|
_vals = []
|
|
|
|
|
for v in values:
|
|
|
|
|
if isinstance(v, dict):
|
|
|
|
|
keys = v.keys()
|
|
|
|
|
v = list(v.values())
|
|
|
|
|
_vals.append(v)
|
|
|
|
|
values = _vals
|
|
|
|
|
|
|
|
|
|
keys = f'({", ".join(keys)})' if keys else ''
|
|
|
|
|
tbl = node['insert']
|
|
|
|
|
self.sql = f'INSERT INTO {tbl} VALUES('
|
|
|
|
|
self.sql = f'INSERT INTO {tbl}{keys} VALUES'
|
|
|
|
|
# if len(values) != table.n_cols:
|
|
|
|
|
# raise ValueError("Column Mismatch")
|
|
|
|
|
|
|
|
|
|
values = [values] if isinstance(values, list) and not isinstance(values[0], list) else values
|
|
|
|
|
list_values = []
|
|
|
|
|
for i, s in enumerate(enlist(values)):
|
|
|
|
|
if 'value' in s:
|
|
|
|
|
list_values.append(f"{s['value']}")
|
|
|
|
|
else:
|
|
|
|
|
# subquery, dispatch to select astnode
|
|
|
|
|
pass
|
|
|
|
|
self.sql += ', '.join(list_values) + ')'
|
|
|
|
|
for l in values:
|
|
|
|
|
inner_list_values = []
|
|
|
|
|
for s in enlist(l):
|
|
|
|
|
if type(s) is dict and 'value' in s:
|
|
|
|
|
s = s['value']
|
|
|
|
|
inner_list_values.append(f"{get_innermost(s)}")
|
|
|
|
|
list_values.append(f"({', '.join(inner_list_values)})")
|
|
|
|
|
|
|
|
|
|
self.sql += ', '.join(list_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class load(ast_node):
|
|
|
|
|