From 34a9fe105c7f7a7e3c8d044eb4a2f82ad45f5e0e Mon Sep 17 00:00:00 2001 From: Bill Date: Wed, 21 Sep 2022 17:23:50 +0800 Subject: [PATCH] bug fix on select into --- engine/types.py | 17 +++-- engine/utils.py | 48 ++++++++++++- reconstruct/ast.py | 135 +++++++++++++++++++++++------------ reconstruct/expr.py | 26 ++++--- reconstruct/storage.py | 22 +++++- sdk/Makefile | 2 +- sdk/aquery.h | 4 +- server/aggregations.h | 4 +- server/table.h | 81 +++++++++++++++++---- server/table_ext_monetdb.hpp | 25 +++++-- server/types.h | 2 +- tests/best_profit.a | 52 ++++++++------ 12 files changed, 306 insertions(+), 112 deletions(-) diff --git a/engine/types.py b/engine/types.py index de80c7d..3083795 100644 --- a/engine/types.py +++ b/engine/types.py @@ -232,9 +232,9 @@ def ext (fx): # operator call behavior -def binary_op_behavior(op:OperatorBase, c_code, x, y): +def binary_op_behavior(op:OperatorBase, c_code, *xs): name = op.cname if c_code else op.sqlname - return f'({x} {name} {y})' + return f'({f" {name} ".join(xs)})' def unary_op_behavior(op:OperatorBase, c_code, x): name = op.cname if c_code else op.sqlname @@ -248,10 +248,16 @@ def count_behavior(op:OperatorBase, c_code, x, distinct = False): if not c_code: return f'{op.sqlname}({"distinct " if distinct else ""}{x})' elif distinct: - return '({x}).distinct_size()' + return f'({x}).distinct_size()' else: return '{count()}' - + +def distinct_behavior(op:OperatorBase, c_code, x): + if not c_code: + return f'{op.sqlname}({x})' + else: + return f'({x}).distinct()' + def windowed_fn_behavor(op: OperatorBase, c_code, *x): if not c_code: return f'{op.sqlname}({", ".join([f"{xx}" for xx in x])})' @@ -277,6 +283,7 @@ oplte = OperatorBase('lte', 2, logical, cname = '<=', sqlname = '<=', call = bin opneq = OperatorBase('neq', 2, logical, cname = '!=', sqlname = '!=', call = binary_op_behavior) opeq = OperatorBase('eq', 2, logical, cname = '==', sqlname = '=', call = binary_op_behavior) opnot = OperatorBase('not', 1, logical, cname = '!', sqlname = 'NOT', call = unary_op_behavior) +opdistinct = OperatorBase('distinct', 1, as_is, cname = '.distinct()', sqlname = 'distinct', call = distinct_behavior) # functional fnmax = OperatorBase('max', 1, as_is, cname = 'max', sqlname = 'MAX', call = fn_behavior) fnmin = OperatorBase('min', 1, as_is, cname = 'min', sqlname = 'MIN', call = fn_behavior) @@ -315,7 +322,7 @@ builtin_binary_arith = _op_make_dict(opadd, opdiv, opmul, opsub, opmod) builtin_binary_logical = _op_make_dict(opand, opor, opxor, opgt, oplt, opge, oplte, opneq, opeq) builtin_unary_logical = _op_make_dict(opnot) builtin_unary_arith = _op_make_dict(opneg) -builtin_unary_special = _op_make_dict(spnull) +builtin_unary_special = _op_make_dict(spnull, opdistinct) builtin_cstdlib = _op_make_dict(fnsqrt, fnlog, fnsin, fncos, fntan, fnpow) builtin_func = _op_make_dict(fnmax, fnmin, fnsum, fnavg, fnmaxs, fnmins, fndeltas, fnlast, fnsums, fnavgs, fncnt) user_module_func = {} diff --git a/engine/utils.py b/engine/utils.py index ac1ce24..1a8b403 100644 --- a/engine/utils.py +++ b/engine/utils.py @@ -1,3 +1,5 @@ +from collections import OrderedDict +from collections.abc import MutableMapping, Mapping import uuid lower_alp = 'abcdefghijklmnopqrstuvwxyz' @@ -7,6 +9,50 @@ base62alp = nums + lower_alp + upper_alp reserved_monet = ['month'] + +class CaseInsensitiveDict(MutableMapping): + def __init__(self, data=None, **kwargs): + self._store = OrderedDict() + if data is None: + data = {} + self.update(data, **kwargs) + + def __setitem__(self, key, value): + # Use the lowercased key for lookups, but store the actual + # key alongside the value. + self._store[key.lower()] = (key, value) + + def __getitem__(self, key): + return self._store[key.lower()][1] + + def __delitem__(self, key): + del self._store[key.lower()] + + def __iter__(self): + return (casedkey for casedkey, mappedvalue in self._store.values()) + + def __len__(self): + return len(self._store) + + def lower_items(self): + """Like iteritems(), but with all lowercase keys.""" + return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()) + + def __eq__(self, other): + if isinstance(other, Mapping): + other = CaseInsensitiveDict(other) + else: + return NotImplemented + # Compare insensitively + return dict(self.lower_items()) == dict(other.lower_items()) + + # Copy is required + def copy(self): + return CaseInsensitiveDict(self._store.values()) + + def __repr__(self): + return str(dict(self.items())) + def base62uuid(crop=8): _id = uuid.uuid4().int ret = '' @@ -60,7 +106,7 @@ def defval(val, default): return default if val is None else val # escape must be readonly -from typing import Set +from typing import Mapping, Set def remove_last(pattern : str, string : str, escape : Set[str] = set()) -> str: idx = string.rfind(pattern) if idx == -1: diff --git a/reconstruct/ast.py b/reconstruct/ast.py index a66da39..3a43c81 100644 --- a/reconstruct/ast.py +++ b/reconstruct/ast.py @@ -61,9 +61,16 @@ class projection(ast_node): pass def produce(self, node): - p = node['select'] - self.projections = p if type(p) is list else [p] self.add('SELECT') + self.has_postproc = False + if 'select' in node: + p = node['select'] + self.distinct = False + elif 'select_distinct' in node: + p = node['select_distinct'] + self.distinct = True + + self.projections = p if type(p) is list else [p] if self.parent is None: self.context.sql_begin() self.postproc_fname = 'dll_' + base62uuid(6) @@ -75,8 +82,8 @@ class projection(ast_node): if 'from' in node: from_clause = node['from']['table_source'] self.datasource = join(self, from_clause) - if 'assumptions' in from_clause: - self.assumptions = enlist(from_clause['assumptions']) + if 'assumptions' in node['from']: + self.assumptions = enlist(node['from']['assumptions']) if self.datasource is not None: self.datasource_changed = True @@ -109,61 +116,82 @@ class projection(ast_node): proj_map : Dict[int, List[Union[Types, int, str, expr]]]= dict() self.var_table = dict() # self.sp_refs = set() - for i, proj in enumerate(self.projections): + i = 0 + for proj in self.projections: compound = False self.datasource.rec = set() name = '' this_type = AnyT - if type(proj) is dict: + if type(proj) is dict or proj == '*': if 'value' in proj: e = proj['value'] - proj_expr = expr(self, e) - this_type = proj_expr.type - name = proj_expr.sql - compound = True # compound column - proj_expr.cols_mentioned = self.datasource.rec - alias = '' - if 'name' in proj: # renaming column by AS keyword - alias = proj['name'] - - if not proj_expr.is_special: + elif proj == '*': + e = '*' + else: + print('unknown projection', proj) + proj_expr = expr(self, e) + sql_expr = expr(self, e, c_code=False) + this_type = proj_expr.type + name = proj_expr.sql + compound = True # compound column + proj_expr.cols_mentioned = self.datasource.rec + alias = '' + if 'name' in proj: # renaming column by AS keyword + alias = proj['name'] + + if not proj_expr.is_special: + if proj_expr.node == '*': + name = [c.get_full_name() for c in self.datasource.rec] + else: y = lambda x:x - name = eval('f\'' + name + '\'') + count = lambda : 'count(*)' + name = enlist(sql_expr.eval(False, y, count=count)) + for n in name: offset = len(col_exprs) - if name not in self.var_table: - self.var_table[name] = offset + if n not in self.var_table: + self.var_table[n] = offset if proj_expr.is_ColExpr and type(proj_expr.raw_col) is ColRef: - for n in (proj_expr.raw_col.table.alias): - self.var_table[f'{n}.'+name] = offset + for _alias in (proj_expr.raw_col.table.alias): + self.var_table[f'{_alias}.'+n] = offset proj_map[i] = [this_type, offset, proj_expr] - col_expr = name + ' AS ' + alias if alias else name + col_expr = n + ' AS ' + alias if alias else n if alias: self.var_table[alias] = offset col_exprs.append((col_expr, proj_expr.type)) - else: - self.context.headers.add('"./server/aggregations.h"') - if self.datasource.rec is not None: - self.col_ext = self.col_ext.union(self.datasource.rec) - proj_map[i] = [this_type, proj_expr.sql, proj_expr] - - disp_name = get_legal_name(alias if alias else name) - + i += 1 + else: + self.context.headers.add('"./server/aggregations.h"') + self.has_postproc = True + if self.datasource.rec is not None: + self.col_ext = self.col_ext.union(self.datasource.rec) + proj_map[i] = [this_type, proj_expr.sql, proj_expr] + i += 1 + name = enlist(name) + disp_name = [get_legal_name(alias if alias else n) for n in name] + elif type(proj) is str: col = self.datasource.get_col(proj) this_type = col.type + disp_name = proj + print('Unknown behavior:', proj, 'is str') # name = col.name self.datasource.rec = None # TODO: Type deduction in Python - cols.append(ColRef(this_type, self.out_table, None, disp_name, i, compound=compound)) + for n in disp_name: + cols.append(ColRef(this_type, self.out_table, None, n, len(cols), compound=compound)) self.out_table.add_cols(cols, new = False) if 'groupby' in node: self.group_node = groupby(self, node['groupby']) + if self.group_node.use_sp_gb: + self.has_postproc = True else: self.group_node = None - + + if not self.has_postproc and self.distinct: + self.add('DISTINCT') self.col_ext = [c for c in self.col_ext if c.name not in self.var_table] # remove duplicates in self.var_table col_ext_names = [c.name for c in self.col_ext] self.add(', '.join([c[0] for c in col_exprs] + col_ext_names)) @@ -249,7 +277,7 @@ class projection(ast_node): self.group_node and (self.group_node.use_sp_gb and val[2].cols_mentioned.intersection( - self.datasource.all_cols.difference(self.group_node.refs)) + self.datasource.all_cols().difference(self.group_node.refs)) ) and val[2].is_compound # compound val not in key # or # (not self.group_node and val[2].is_compound) @@ -282,25 +310,37 @@ class projection(ast_node): # for funcs evaluate f_i(x, ...) self.context.emitc(f'{self.out_table.contextname_cpp}->get_col<{key}>() = {val[1]};') # print out col_is - self.context.emitc(f'print(*{self.out_table.contextname_cpp});') + if 'into' not in node: + self.context.emitc(f'print(*{self.out_table.contextname_cpp});') if self.outfile: self.outfile.finalize() if 'into' in node: self.context.emitc(select_into(self, node['into']).ccode) - + if not self.distinct: + self.finalize() + + def finalize(self): self.context.emitc(f'puts("done.");') if self.parent is None: self.context.sql_end() self.context.postproc_end(self.postproc_fname) - - +class select_distinct(projection): + first_order = 'select_distinct' + def consume(self, node): + super().consume(node) + if self.has_postproc: + self.context.emitc( + f'{self.out_table.table_name}->distinct();' + ) + self.finalize() + class select_into(ast_node): def init(self, node): - if type(self.parent) is projection: + if isinstance(self.parent, projection): if self.context.has_dll: # has postproc put back to monetdb self.produce = self.produce_cpp @@ -308,8 +348,8 @@ class select_into(ast_node): self.produce = self.produce_sql else: raise ValueError('parent must be projection') + def produce_cpp(self, node): - assert(type(self.parent) is projection) if not hasattr(self.parent, 'out_table'): raise Exception('No out_table found.') else: @@ -508,7 +548,7 @@ class groupby(ast_node): return False def produce(self, node): - if type(self.parent) is not projection: + if not isinstance(self.parent, projection): raise ValueError('groupby can only be used in projection') node = enlist(node) @@ -554,7 +594,7 @@ class join(ast_node): self.tables : List[TableInfo] = [] self.tables_dir = dict() self.rec = None - self.top_level = self.parent and type(self.parent) is projection + self.top_level = self.parent and isinstance(self.parent, projection) self.have_sep = False # self.tmp_name = 'join_' + base62uuid(4) # self.datasource = TableInfo(self.tmp_name, [], self.context) @@ -636,9 +676,16 @@ class join(ast_node): datasource.rec = None return ret - @property +# @property def all_cols(self): - return set([c for t in self.tables for c in t.columns]) + ret = set() + for table in self.tables: + rec = table.rec + table.rec = self.rec + ret.update(table.all_cols()) + table.rec = rec + return ret + def consume(self, node): self.sql = '' for j in self.joins: @@ -787,7 +834,7 @@ class outfile(ast_node): self.sql = sql if sql else '' def init(self, _): - assert(type(self.parent) is projection) + assert(isinstance(self.parent, projection)) if not self.parent.use_postproc: if self.context.dialect == 'MonetDB': self.produce = self.produce_monetdb diff --git a/reconstruct/expr.py b/reconstruct/expr.py index ce5ea4f..504ab8e 100644 --- a/reconstruct/expr.py +++ b/reconstruct/expr.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Set from reconstruct.ast import ast_node from reconstruct.storage import ColRef, Context from engine.types import * @@ -199,11 +199,8 @@ class expr(ast_node): self.udf_decltypecall = ex_vname.sql else: print(f'Undefined expr: {key}{val}') - if 'distinct' in val and key != count: - if self.c_code: - self.sql = 'distinct ' + self.sql - elif self.is_compound: - self.sql = '(' + self.sql + ').distinct()' + + if type(node) is str: if self.is_udfexpr: curr_udf : udf = self.root.udf @@ -235,8 +232,13 @@ class expr(ast_node): # get the column from the datasource in SQL context else: if self.datasource is not None: - self.raw_col = self.datasource.parse_col_names(node) - self.raw_col = self.raw_col if type(self.raw_col) is ColRef else None + if (node == '*' and + not (type(self.parent) is expr + and 'count' in self.parent.node)): + self.datasource.all_cols() + else: + self.raw_col = self.datasource.parse_col_names(node) + self.raw_col = self.raw_col if type(self.raw_col) is ColRef else None if self.raw_col is not None: self.is_ColExpr = True table_name = '' @@ -259,10 +261,16 @@ class expr(ast_node): self.is_compound = True self.opname = self.raw_col else: - self.sql = '\'' + node + '\'' + self.sql = '\'' + node + '\'' if node != '*' else '*' self.type = StrT self.opname = self.sql if self.c_code and self.datasource is not None: + if (type(self.parent) is expr and + 'distinct' in self.parent.node and + not self.is_special): + # this node is executed by monetdb + # gb condition, not special + self.sql = f'distinct({self.sql})' self.sql = f'{{y(\"{self.sql}\")}}' elif type(node) is bool: self.type = BoolT diff --git a/reconstruct/storage.py b/reconstruct/storage.py index 87940ab..d7e568d 100644 --- a/reconstruct/storage.py +++ b/reconstruct/storage.py @@ -1,5 +1,5 @@ from engine.types import * -from engine.utils import base62uuid, enlist +from engine.utils import CaseInsensitiveDict, base62uuid, enlist from typing import List, Dict, Set class ColRef: @@ -20,6 +20,18 @@ class ColRef: # e.g. order by, group by, filter by expressions self.__arr__ = (_ty, cobj, table, name, id) + + def get_full_name(self): + table_name = self.table.table_name + it_alias = iter(self.table.alias) + alias = next(it_alias, table_name) + try: + while alias == table_name: + alias = next(it_alias) + except StopIteration: + alias = table_name + return f'{alias}.{self.name}' + def __getitem__(self, key): if type(key) is str: return getattr(self, key) @@ -35,7 +47,7 @@ class TableInfo: self.table_name : str = table_name self.contextname_cpp : str = '' self.alias : Set[str] = set([table_name]) - self.columns_byname : Dict[str, ColRef] = dict() # column_name, type + self.columns_byname : Dict[str, ColRef] = CaseInsensitiveDict() # column_name, type self.columns : List[ColRef] = [] self.cxt = cxt # keep track of temp vars @@ -85,7 +97,11 @@ class TableInfo: raise ValueError(f'Table name/alias not defined{parsedColExpr[0]}') else: return datasource.parse_col_names(parsedColExpr[1]) - + + def all_cols(self): + if type(self.rec) is set: + self.rec.update(self.columns) + return set(self.columns) class Context: def new(self): diff --git a/sdk/Makefile b/sdk/Makefile index 668d667..166386b 100644 --- a/sdk/Makefile +++ b/sdk/Makefile @@ -1,4 +1,4 @@ example: - g++-9 -shared -fPIC example.cpp aquery_mem.cpp -fno-semantic-interposition -Ofast -march=native -flto --std=c++1z -o ../test.so + $(CXX) -shared -fPIC example.cpp aquery_mem.cpp -fno-semantic-interposition -Ofast -march=native -flto --std=c++1z -o ../test.so all: example diff --git a/sdk/aquery.h b/sdk/aquery.h index 87c860b..3ef5bb6 100644 --- a/sdk/aquery.h +++ b/sdk/aquery.h @@ -75,7 +75,7 @@ extern void register_memory(void* ptr, deallocator_t deallocator); __AQEXPORT__(void) init_session(Context* cxt); #define __AQ_NO_SESSION__ __AQEXPORT__(void) init_session(Context*) {} -void* memcpy(void*, void*, unsigned long long); +void* memcpy(void*, const void*, unsigned long long); struct ColRef_storage { void* container; unsigned int capacity, size; @@ -86,4 +86,4 @@ struct ColRef_storage { memcpy(this, &vt, sizeof(ColRef_storage)); } }; -#endif \ No newline at end of file +#endif diff --git a/server/aggregations.h b/server/aggregations.h index 2add603..514e4a6 100644 --- a/server/aggregations.h +++ b/server/aggregations.h @@ -117,9 +117,9 @@ decayed_t> sums(const VT& arr) { return ret; } template class VT> -decayed_t> avgs(const VT& arr) { +decayed_t>> avgs(const VT& arr) { const uint32_t& len = arr.size; - typedef types::GetFPType FPType; + typedef types::GetFPType> FPType; decayed_t ret(len); uint32_t i = 0; types::GetLongType s; diff --git a/server/table.h b/server/table.h index eab5f9d..1e07968 100644 --- a/server/table.h +++ b/server/table.h @@ -26,6 +26,21 @@ namespace types { struct Coercion; } #endif +template