From c39ec80e1d7ace06b3f5ab86e713542fe5814e15 Mon Sep 17 00:00:00 2001 From: Bill Date: Fri, 23 Sep 2022 02:46:30 +0800 Subject: [PATCH] more on join awareness and decoupled expr --- Makefile | 2 +- reconstruct/ast.py | 15 ++++++----- reconstruct/new_expr.py | 60 ++++++++++++++++++++++++++++++++++++----- server/server.cpp | 13 --------- 4 files changed, 63 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index ee95e4c..4a16eb8 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ MonetDB_LIB = MonetDB_INC = Threading = CXXFLAGS = --std=c++1z -OPTFLAGS = -O3 -fno-semantic-interposition +OPTFLAGS = -O3 -fno-semantic-interposition -DNDEBUG LINKFLAGS = -flto SHAREDFLAGS = -shared FPIC = -fPIC diff --git a/reconstruct/ast.py b/reconstruct/ast.py index 8af334e..d6293c4 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -604,7 +604,8 @@ class join(ast_node): self.top_level = self.parent and isinstance(self.parent, projection) self.have_sep = False self.joint_cols : Dict[ColRef, Set]= dict() # columns that are joined with this column - + self.children : Set(join) = {} + self.join_conditions = [] # self.tmp_name = 'join_' + base62uuid(4) # self.datasource = TableInfo(self.tmp_name, [], self.context) def append(self, tbls, __alias = ''): @@ -628,7 +629,7 @@ class join(ast_node): if type(node) is list: for d in node: self.append(join(self, d)) - + elif type(node) is dict: alias = '' if 'value' in node: @@ -695,6 +696,10 @@ class join(ast_node): table.rec = rec return ret + # TODO: join condition awareness + def process_join_conditions(self): + pass + def consume(self, node): self.sql = '' for j in self.joins: @@ -702,11 +707,7 @@ class join(ast_node): self.sql += j[0] # using JOIN keyword else: self.sql += ', ' + j[0] # using comma - for col, jc in j.joint_cols: - if col in self.joint_cols: - self.joint_cols[col] |= jc - else: - self.joint_cols[col] = set(jc) + self.process_join_conditions() if node and self.sql and self.top_level: self.sql = ' FROM ' + self.sql diff --git a/reconstruct/new_expr.py b/reconstruct/new_expr.py index a5bc562..d12ef56 100644 --- a/reconstruct/new_expr.py +++ b/reconstruct/new_expr.py @@ -18,11 +18,34 @@ class expr_base(ast_node, metaclass = abc.ABCMeta): self.operators = {**builtin_operators, **self.udf_map, **user_module_func} self.narrow_funcs = ['sum', 'avg', 'count', 'min', 'max', 'last'] - def get_literal(self, node): - self.is_literal = True + def get_variable(self): + pass + def str_literal(self, node): + pass + def int_literal(self, node): + pass + def bool_literal(self, node): + pass + def get_literal(self, node): + if not self.get_variable(node): + self.is_literal = True + if type(node) is str: + self.str_literal(node) + elif type(node) is int: + self.int_literal(node) + elif type(node) is float: + self.float_literal(node) + elif type(node) is bool: + self.bool_literal(node) + def process_child_nodes(self): + if not hasattr(self, 'child_exprs'): + raise ValueError(f'Internal Error: process_child_nodes called without child_exprs.') + + def process_non_operator(self, key, value): pass + def produce(self, node): from reconstruct.ast import udf if node and type(node) is dict: @@ -35,24 +58,49 @@ class expr_base(ast_node, metaclass = abc.ABCMeta): if key in self.operators: self.child_exprs = [__class__(self, v) for v in val] self.process_child_nodes() + else: + self.process_non_operator(key, val) + else: + self.get_literal(node['literal']) + def consume(self, _): pass class c_expr(expr_base): - pass + def init(self, node): + super().init(node) + self.ccode = '' + + def emit(self, snippet : str): + self.ccode += snippet + + def eval(self): + return self.ccode class sql_expr(expr_base): - pass + def init(self): + super().init() + self.sql = '' + + def emit(self, snippet): + self.sql += snippet + + def eval(self): + return self.sql class udf_expr(c_expr): pass class proj_expr(c_expr, sql_expr): - pass + def init(self, node): + super(c_expr).init() + super(sql_expr).init() class orderby_expr(c_expr, sql_expr): - pass + def init(self, node): + super(c_expr).init() + super(sql_expr).init() class groupby_expr(orderby_expr): pass diff --git a/server/server.cpp b/server/server.cpp index dd63e3e..c92d25a 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -318,19 +318,6 @@ int test_main() if (cxt->alt_server == 0) cxt->alt_server = new Server(cxt); Server* server = reinterpret_cast(cxt->alt_server); - - - //TableInfo table("sibal"); - //int col0[] = { 1,2,3,4,5 }; - //float col1[] = { 5.f, 4.f, 3.f, 2.f, 1.f }; - //table.get_col<0>().initfrom(5, col0, "a"); - //table.get_col<1>().initfrom(5, col1, "b"); - //table.monetdb_append_table(server); - // - //server->exec("select * from sibal;"); - //auto aa = server->getCol(0); - //auto bb = server->getCol(1); - //printf("sibal: %p %p\n", aa, bb); const char* qs[]= { "CREATE TABLE test1(a INT, b INT, c INT, d INT);",