From 522e9e267b50da4535889c34d64eb7861b9b3919 Mon Sep 17 00:00:00 2001 From: Bill Date: Thu, 22 Sep 2022 23:50:08 +0800 Subject: [PATCH] WIP: new decoupled expr implementation --- reconstruct/ast.py | 23 +++++++++++--- reconstruct/new_expr.py | 66 +++++++++++++++++++++++++++++++++++++++++ reconstruct/storage.py | 4 +-- server/types.h | 3 ++ 4 files changed, 90 insertions(+), 6 deletions(-) create mode 100644 reconstruct/new_expr.py diff --git a/reconstruct/ast.py b/reconstruct/ast.py index 3a43c81..8af334e 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -589,13 +589,22 @@ class groupby(ast_node): class join(ast_node): name = 'join' + + def get_joint_cols(self, cols : List[ColRef]): + joint_cols = set() + for col in cols: + joint_cols |= self.joint_cols[col] + return joint_cols + def init(self, _): - self.joins:list = [] + self.joins : List[join] = [] self.tables : List[TableInfo] = [] self.tables_dir = dict() self.rec = None self.top_level = self.parent and isinstance(self.parent, projection) self.have_sep = False + self.joint_cols : Dict[ColRef, Set]= dict() # columns that are joined with this column + # self.tmp_name = 'join_' + base62uuid(4) # self.datasource = TableInfo(self.tmp_name, [], self.context) def append(self, tbls, __alias = ''): @@ -690,13 +699,19 @@ class join(ast_node): self.sql = '' for j in self.joins: if not self.sql or j[1]: - self.sql += j[0] + self.sql += j[0] # using JOIN keyword else: - self.sql += ', ' + j[0] + self.sql += ', ' + j[0] # using comma + for col, jc in j.joint_cols: + if col in self.joint_cols: + self.joint_cols[col] |= jc + else: + self.joint_cols[col] = set(jc) + if node and self.sql and self.top_level: self.sql = ' FROM ' + self.sql return super().consume(node) - + def __str__(self): return self.sql def __repr__(self): diff --git a/reconstruct/new_expr.py b/reconstruct/new_expr.py new file mode 100644 index 0000000..a5bc562 --- /dev/null +++ b/reconstruct/new_expr.py @@ -0,0 +1,66 @@ +import abc +from reconstruct.ast import ast_node +from typing import Optional +from reconstruct.storage import Context, ColRef +from engine.utils import enlist +from engine.types import builtin_func, user_module_func, builtin_operators + + +class expr_base(ast_node, metaclass = abc.ABCMeta): + def __init__(self, parent: Optional["ast_node"], node, context: Optional[Context] = None): + self.node = node + super().__init__(parent, node, context) + + def init(self, node): + self.is_literal = False + self.udf_map = self.context.udf_map + self.func_maps = {**builtin_func, **self.udf_map, **user_module_func} + self.operators = {**builtin_operators, **self.udf_map, **user_module_func} + self.narrow_funcs = ['sum', 'avg', 'count', 'min', 'max', 'last'] + + def get_literal(self, node): + self.is_literal = True + + def process_child_nodes(self): + pass + def produce(self, node): + from reconstruct.ast import udf + if node and type(node) is dict: + if 'litral' in node: + self.get_literal(node['literal']) + else: + if len(node) > 1: + raise ValueError(f'Parse Error: more than 1 entry in {node}.') + key, val = next(iter(node.items())) + if key in self.operators: + self.child_exprs = [__class__(self, v) for v in val] + self.process_child_nodes() + def consume(self, _): + pass + + +class c_expr(expr_base): + pass + +class sql_expr(expr_base): + pass + +class udf_expr(c_expr): + pass + +class proj_expr(c_expr, sql_expr): + pass + +class orderby_expr(c_expr, sql_expr): + pass + +class groupby_expr(orderby_expr): + pass + +class from_expr(sql_expr): + pass + +class where_expr(sql_expr): + pass + + diff --git a/reconstruct/storage.py b/reconstruct/storage.py index d7e568d..c43131c 100644 --- a/reconstruct/storage.py +++ b/reconstruct/storage.py @@ -47,7 +47,7 @@ class TableInfo: self.table_name : str = table_name self.contextname_cpp : str = '' self.alias : Set[str] = set([table_name]) - self.columns_byname : Dict[str, ColRef] = CaseInsensitiveDict() # column_name, type + self.columns_byname : CaseInsensitiveDict[str, ColRef] = CaseInsensitiveDict() # column_name, type self.columns : List[ColRef] = [] self.cxt = cxt # keep track of temp vars @@ -117,7 +117,7 @@ class Context: self.queries = [] self.module_init_loc = 0 self.special_gb = False - + def __init__(self): self.tables_byname = dict() self.col_byname = dict() diff --git a/server/types.h b/server/types.h index 43d2d1a..d4a5656 100644 --- a/server/types.h +++ b/server/types.h @@ -189,6 +189,9 @@ struct astring_view { other_str++; } return !(*this_str || *other_str); + } + bool operator >(const astring_view&r) const{ + } operator const char* () const { return reinterpret_cast(str);