diff --git a/reconstruct/ast.py b/reconstruct/ast.py index ae8e1e8..e376487 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -1,4 +1,3 @@ -from binascii import Error from copy import deepcopy from dataclasses import dataclass from enum import Enum, auto @@ -90,6 +89,9 @@ class projection(ast_node): elif 'select_distinct' in node: p = node['select_distinct'] self.distinct = True + else: + raise NotImplementedError('AST node is not a projection node') + if 'with' in node: with_table = node['with']['name'] with_table_name = tuple(with_table.keys())[0] @@ -946,10 +948,41 @@ class filter(ast_node): self.add(filter_expr.sql) if self.datasource is not None: self.datasource.join_conditions += filter_expr.join_conditions - + +class union_all(ast_node): + name = 'union_all' + first_order = name + sql_name = 'UNION ALL' + def produce(self, node): + queries = node[self.name] + generated_queries : List[Optional[projection]] = [None] * len(queries) + is_standard = True + for i, q in enumerate(queries): + if 'select' in q: + generated_queries[i] = projection(self, q) + is_standard &= not generated_queries[i].has_postproc + if is_standard: + self.sql = f' {self.sql_name} '.join([q.sql for q in generated_queries]) + else: + raise NotImplementedError(f"{self.sql_name} only support standard sql for now") + def consume(self, node): + super().consume(node) + self.context.direct_output() + +class except_clause(union_all): + name = 'except' + first_order = name + sql_name = 'EXCEPT' + class create_table(ast_node): name = 'create_table' first_order = name + allowed_subq = { + 'select_distinct': select_distinct, + 'select': projection, + 'union_all': union_all, + 'except': except_clause + } def init(self, node): node = node[self.name] if 'query' in node: @@ -957,9 +990,11 @@ class create_table(ast_node): raise ValueError("Table name not specified") projection_node = node['query'] projection_node['into'] = node['name'] - proj_cls = (select_distinct - if 'select_distinct' in projection_node - else projection) + proj_cls = projection + for k in create_table.allowed_subq.keys(): + if k in projection_node: + proj_cls = create_table.allowed_subq[k] + break proj_cls(None, projection_node, self.context) self.produce = lambda *_: None self.spawn = lambda *_: None @@ -1073,31 +1108,6 @@ class delete_from(ast_node): self.sql = f'DELETE FROM {tbl} ' if 'where' in node: self.sql += filter(self, node['where']).sql - -class union_all(ast_node): - name = 'union_all' - first_order = name - sql_name = 'UNION ALL' - def produce(self, node): - queries = node[self.name] - generated_queries : List[Optional[projection]] = [None] * len(queries) - is_standard = True - for i, q in enumerate(queries): - if 'select' in q: - generated_queries[i] = projection(self, q) - is_standard &= not generated_queries[i].has_postproc - if is_standard: - self.sql = f' {self.sql_name} '.join([q.sql for q in generated_queries]) - else: - raise NotImplementedError(f"{self.sql_name} only support standard sql for now") - def consume(self, node): - super().consume(node) - self.context.direct_output() - -class except_clause(union_all): - name = 'except' - first_order = name - sql_name = 'EXCEPT' class load(ast_node): name="load" diff --git a/server/server.cpp b/server/server.cpp index 3b47f57..913ee4c 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -360,7 +360,7 @@ int dll_main(int argc, char** argv, Context* cxt){ if(!server->haserror()){ uint32_t limit; memcpy(&limit, n_recvd[i] + 1, sizeof(uint32_t)); - printf("Limit: %x\n", limit); + // printf("Limit: %x\n", limit); if (limit == 0) continue; timer.reset(); diff --git a/server/table.h b/server/table.h index af26ae7..1646088 100644 --- a/server/table.h +++ b/server/table.h @@ -77,6 +77,7 @@ public: } template