WIP: new decoupled expr implementation

dev
Bill 2 years ago
parent ce2dd054e3
commit 522e9e267b

@ -589,13 +589,22 @@ class groupby(ast_node):
class join(ast_node): class join(ast_node):
name = 'join' name = 'join'
def get_joint_cols(self, cols : List[ColRef]):
joint_cols = set()
for col in cols:
joint_cols |= self.joint_cols[col]
return joint_cols
def init(self, _): def init(self, _):
self.joins:list = [] self.joins : List[join] = []
self.tables : List[TableInfo] = [] self.tables : List[TableInfo] = []
self.tables_dir = dict() self.tables_dir = dict()
self.rec = None self.rec = None
self.top_level = self.parent and isinstance(self.parent, projection) self.top_level = self.parent and isinstance(self.parent, projection)
self.have_sep = False self.have_sep = False
self.joint_cols : Dict[ColRef, Set]= dict() # columns that are joined with this column
# self.tmp_name = 'join_' + base62uuid(4) # self.tmp_name = 'join_' + base62uuid(4)
# self.datasource = TableInfo(self.tmp_name, [], self.context) # self.datasource = TableInfo(self.tmp_name, [], self.context)
def append(self, tbls, __alias = ''): def append(self, tbls, __alias = ''):
@ -690,13 +699,19 @@ class join(ast_node):
self.sql = '' self.sql = ''
for j in self.joins: for j in self.joins:
if not self.sql or j[1]: if not self.sql or j[1]:
self.sql += j[0] self.sql += j[0] # using JOIN keyword
else: else:
self.sql += ', ' + j[0] self.sql += ', ' + j[0] # using comma
for col, jc in j.joint_cols:
if col in self.joint_cols:
self.joint_cols[col] |= jc
else:
self.joint_cols[col] = set(jc)
if node and self.sql and self.top_level: if node and self.sql and self.top_level:
self.sql = ' FROM ' + self.sql self.sql = ' FROM ' + self.sql
return super().consume(node) return super().consume(node)
def __str__(self): def __str__(self):
return self.sql return self.sql
def __repr__(self): def __repr__(self):

@ -0,0 +1,66 @@
import abc
from reconstruct.ast import ast_node
from typing import Optional
from reconstruct.storage import Context, ColRef
from engine.utils import enlist
from engine.types import builtin_func, user_module_func, builtin_operators
class expr_base(ast_node, metaclass = abc.ABCMeta):
def __init__(self, parent: Optional["ast_node"], node, context: Optional[Context] = None):
self.node = node
super().__init__(parent, node, context)
def init(self, node):
self.is_literal = False
self.udf_map = self.context.udf_map
self.func_maps = {**builtin_func, **self.udf_map, **user_module_func}
self.operators = {**builtin_operators, **self.udf_map, **user_module_func}
self.narrow_funcs = ['sum', 'avg', 'count', 'min', 'max', 'last']
def get_literal(self, node):
self.is_literal = True
def process_child_nodes(self):
pass
def produce(self, node):
from reconstruct.ast import udf
if node and type(node) is dict:
if 'litral' in node:
self.get_literal(node['literal'])
else:
if len(node) > 1:
raise ValueError(f'Parse Error: more than 1 entry in {node}.')
key, val = next(iter(node.items()))
if key in self.operators:
self.child_exprs = [__class__(self, v) for v in val]
self.process_child_nodes()
def consume(self, _):
pass
class c_expr(expr_base):
pass
class sql_expr(expr_base):
pass
class udf_expr(c_expr):
pass
class proj_expr(c_expr, sql_expr):
pass
class orderby_expr(c_expr, sql_expr):
pass
class groupby_expr(orderby_expr):
pass
class from_expr(sql_expr):
pass
class where_expr(sql_expr):
pass

@ -47,7 +47,7 @@ class TableInfo:
self.table_name : str = table_name self.table_name : str = table_name
self.contextname_cpp : str = '' self.contextname_cpp : str = ''
self.alias : Set[str] = set([table_name]) self.alias : Set[str] = set([table_name])
self.columns_byname : Dict[str, ColRef] = CaseInsensitiveDict() # column_name, type self.columns_byname : CaseInsensitiveDict[str, ColRef] = CaseInsensitiveDict() # column_name, type
self.columns : List[ColRef] = [] self.columns : List[ColRef] = []
self.cxt = cxt self.cxt = cxt
# keep track of temp vars # keep track of temp vars
@ -117,7 +117,7 @@ class Context:
self.queries = [] self.queries = []
self.module_init_loc = 0 self.module_init_loc = 0
self.special_gb = False self.special_gb = False
def __init__(self): def __init__(self):
self.tables_byname = dict() self.tables_byname = dict()
self.col_byname = dict() self.col_byname = dict()

@ -189,6 +189,9 @@ struct astring_view {
other_str++; other_str++;
} }
return !(*this_str || *other_str); return !(*this_str || *other_str);
}
bool operator >(const astring_view&r) const{
} }
operator const char* () const { operator const char* () const {
return reinterpret_cast<const char*>(str); return reinterpret_cast<const char*>(str);

Loading…
Cancel
Save