You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
330 lines
11 KiB
330 lines
11 KiB
from typing import Dict, List, Optional, Set
|
|
|
|
from common.types import *
|
|
from common.utils import CaseInsensitiveDict, base62uuid, enlist
|
|
|
|
|
|
class ColRef:
    """Reference to one column of a table.

    Besides the named attributes, the constructor arguments
    ``(_ty, cobj, table, name, id)`` are kept as a positional record in
    ``__arr__`` so the object can be indexed like a tuple.  String keys
    are forwarded to attribute access.
    """

    def __init__(self, _ty, cobj, table: 'TableInfo', name, id, compound=False, _ty_args=None):
        """Create a column reference.

        Args:
            _ty: column type, either a type name (``str``, decoded through
                ``Types.decode``) or a ``Types`` object.  Anything else
                leaves ``self.type`` at ``AnyT``.
            cobj: the object this column was created from (stored as-is).
            table: owning table.
            name: column name.
            id: position of the column in its table.
            compound: whether this is a compound field (list as a field).
            _ty_args: optional type arguments; only used when ``_ty`` is a
                string, in which case the decoded type is called with
                ``enlist(_ty_args)``.
        """
        self.type : Types = AnyT
        if type(_ty) is str:
            self.type = Types.decode(_ty)
            if _ty_args:
                self.type = self.type(enlist(_ty_args))
        elif type(_ty) is Types:
            self.type = _ty
        self.cobj = cobj
        self.table = table
        self.name = name
        self.alias = set()
        self.id = id  # position in table
        self.compound = compound  # compound field (list as a field)
        self.cxt_name = ''
        # e.g. order by, group by, filter by expressions

        # Positional snapshot of the constructor arguments (see __getitem__).
        self.__arr__ = (_ty, cobj, table, name, id)

    @staticmethod
    def _preferred_alias(aliases, canonical):
        """Return the first alias that differs from ``canonical``.

        Falls back to ``canonical`` itself when ``aliases`` is empty or
        only contains the canonical name.
        """
        return next((a for a in aliases if a != canonical), canonical)

    def get_name(self):
        """Return an alias of this column if one exists, else its name."""
        return self._preferred_alias(self.alias, self.name)

    def get_full_name(self):
        """Return ``'<table alias or name>.<column alias or name>'``."""
        table_alias = self._preferred_alias(self.table.alias, self.table.table_name)
        return f'{table_alias}.{self.get_name()}'

    def rename(self, name):
        """Rename the column and re-register it in the owning table.

        The old name is dropped from this column's aliases and from the
        table's ``columns_byname`` map.  Returns ``self``.
        """
        self.alias.discard(self.name)
        self.table.columns_byname.pop(self.name, None)
        self.name = name
        self.table.columns_byname[name] = self

        return self

    def __getitem__(self, key):
        """Look up an attribute (``str`` key) or a positional field."""
        if type(key) is str:
            return getattr(self, key)
        return self.__arr__[key]

    def __setitem__(self, key, value):
        """Set an attribute (``str`` key) or a positional field.

        ``__arr__`` is a tuple, so a positional update rebuilds it rather
        than assigning in place (which would raise ``TypeError``).
        """
        if type(key) is str:
            setattr(self, key, value)
            return
        fields = list(self.__arr__)
        fields[key] = value
        self.__arr__ = tuple(fields)
|
|
|
|
class TableInfo:
    """Metadata for one table: its columns, aliases and triggers.

    A table registers itself in the owning context's ``tables_byname``
    map and ``tables`` set on construction.
    """

    def __init__(self, table_name, cols, cxt: 'Context'):
        """Create the table, add its columns and register it in ``cxt``.

        Args:
            table_name: canonical table name (also its first alias).
            cols: one column description or a list of them (see
                :meth:`add_col`).
            cxt: owning context.
        """
        from engine.ast import create_trigger
        # statics
        self.table_name : str = table_name
        self.contextname_cpp : str = ''
        self.alias : Set[str] = {table_name}
        self.columns_byname : CaseInsensitiveDict[str, ColRef] = CaseInsensitiveDict()  # column_name, type
        self.columns : List[ColRef] = []
        self.triggers : Set[create_trigger] = set()
        self.cxt = cxt
        self.cached = False
        # keep track of temp vars
        self.rec = None
        self.add_cols(cols)
        # runtime
        self.order = []  # assumptions

        cxt.tables_byname[self.table_name] = self  # construct reverse map
        cxt.tables.add(self)

    def add_cols(self, cols, new=True):
        """Add every column in ``cols`` (a single item is wrapped by ``enlist``)."""
        for c in enlist(cols):
            self.add_col(c, new)

    def add_col(self, c, new=True):
        """Append one column to the table.

        ``c`` must support ``c['type']`` and ``c['name']``.  A ``dict``
        type is read as ``{type_name: type_args}``.  When ``new`` is false
        and ``c`` already is a ``ColRef`` it is adopted as-is (its
        ``table`` is re-pointed here); otherwise a fresh ``ColRef`` is
        built.
        """
        _ty = c['type']
        _ty_args = None
        if type(_ty) is dict:
            # Parameterised type: the first key is the type name, its value the args.
            _ty_val = list(_ty)[0]
            _ty_args = _ty[_ty_val]
            _ty = _ty_val
        if new or type(c) is not ColRef:
            col_object = ColRef(_ty, c, self, c['name'], len(self.columns), _ty_args=_ty_args)
        else:
            col_object = c
            c.table = self
        self.columns_byname[c['name']] = col_object
        self.columns.append(col_object)

    def add_alias(self, alias):
        """Register ``alias`` for this table; print an error if it is taken."""
        if alias in self.cxt.tables_byname:
            print("Error: table alias already exists")
            return
        self.cxt.tables_byname[alias] = self
        self.alias.add(alias)

    def rename(self, name):
        """Change the canonical table name; print an error if ``name`` is taken."""
        if name in self.cxt.tables_byname:
            print(f"Error: table name {name} already exists")
            return

        self.cxt.tables_byname.pop(self.table_name, None)
        self.alias.discard(self.table_name)
        self.table_name = name
        self.cxt.tables_byname[name] = self
        self.alias.add(name)

    def parse_col_names(self, colExpr) -> 'ColRef':
        """Resolve ``'col'`` or ``'table.col'`` to a column.

        An unqualified name is looked up in this table and, when
        ``self.rec`` is a set, the column is also recorded there.  A
        qualified name is delegated to the table registered under the
        prefix in the context.

        Raises:
            ValueError: the table prefix is not a known name or alias.
                (Previously this surfaced as a ``KeyError`` because the
                lookup subscripted a plain dict, making the check dead.)
        """
        parts = colExpr.split('.')
        if len(parts) <= 1:
            col = self.columns_byname[colExpr]
            if type(self.rec) is set:
                self.rec.add(col)
            return col

        # .get() so a missing table reaches the explicit error below.
        datasource = self.cxt.tables_byname.get(parts[0])
        if datasource is None:
            raise ValueError(f'Table name/alias not defined: {parts[0]}')
        return datasource.parse_col_names(parts[1])

    def all_cols(self, ordered=False):
        """Return all columns as a ``set`` (or ``OrderedSet`` if ``ordered``).

        When ``self.rec`` is a set, every column is also recorded there.
        """
        if ordered:
            # Third-party dependency; only needed for the ordered variant.
            from ordered_set import OrderedSet
            _ret_set_t = OrderedSet
        else:
            _ret_set_t = set
        if type(self.rec) is set:
            self.rec.update(self.columns)
        return _ret_set_t(self.columns)

    @property
    def single_table(self):
        """Always ``True`` for a plain table."""
        return True
|
|
|
|
class Context:
    """Compilation context shared by a batch of queries.

    Holds the known tables, registered UDFs/modules/triggers, and the
    text being generated: SQL in ``sql``, C++ in ``ccode``/``procs`` and
    the ordered list of emitted items in ``queries`` (each entry carries a
    one-letter prefix: ``'Q'`` for SQL, ``'P'`` for a procedure name,
    ``'O'`` for direct output).
    """

    def new(self):
        """Reset all per-batch state.  Requires ``__init__`` to have run."""
        self.headers = {'\"./server/monetdb_conn.h\"'}
        self.ccode = ''

        self.sql = ''
        self.finalized = False
        self.udf = None
        self.module_stubs = ''
        self.scans = []
        self.procs = []
        self.queries = []
        self.module_init_loc = 0
        self.special_gb = False
        self.has_dll = False
        self.triggers_active.clear()

    def __init__(self, state=None):
        """Create an empty context.

        Args:
            state: optional prompt state, stored in ``system_state``.
        """
        from prompt import PromptState
        from .ast import create_trigger
        from aquery_config import compile_use_gc
        self.tables_byname : Dict[str, TableInfo] = dict()
        self.col_byname = dict()
        self.tables : Set[TableInfo] = set()
        self.cols = []
        self.datasource = None
        self.module_map = {}
        self.udf_map = dict()
        self.udf_agg_map = dict()
        self.use_columnstore = False
        self.print = print
        self.dialect = 'MonetDB'
        self.is_msvc = False
        self.have_hge = False
        self.Error = lambda *args: print(*args)
        self.Info = lambda *_: None
        self.triggers : Dict[str, create_trigger] = dict()
        self.triggers_active = set()
        self.stored_proceudres = dict()  # (sic) attribute name kept for callers
        self.force_compiled = False
        self.use_gc = compile_use_gc
        self.system_state: Optional[PromptState] = state
        self.use_cached_tables = True
        # self.new() called everytime new query batch is started

    def get_scan_var(self):
        """Return an iterator variable name not used by any active scan."""
        it_var = 'i' + base62uuid(2)
        scan_vars = {s.it_var for s in self.scans}
        while it_var in scan_vars:
            it_var = 'i' + base62uuid(6)
        return it_var

    def emit(self, sql:str):
        """Append ``sql`` plus a trailing space to the current SQL text."""
        self.sql += sql + ' '

    def emitc(self, c:str):
        """Append ``c`` plus a newline to the current C++ text."""
        self.ccode += c + '\n'

    def add_table(self, table_name, cols):
        """Create a ``TableInfo`` bound to this context and return it."""
        tbl = TableInfo(table_name, cols, self)
        # TableInfo.__init__ already registers itself; adding again is a no-op on a set.
        self.tables.add(tbl)
        return tbl

    def remove_scan(self, scan, str_scan):
        """Emit the scan's code and drop it from the active scans."""
        self.emitc(str_scan)
        self.scans.remove(scan)

    function_deco = '__AQEXPORT__(int) '
    function_head = (
        '(Context* cxt) {\n'
        '\tusing namespace std;\n'
        '\tusing namespace types;\n'
        '\tauto server = static_cast<DataSource*>(cxt->curr_server);\n'
        '\tauto timer = chrono::high_resolution_clock::now();\n'
    )

    udf_head = (
        '#pragma once\n'
        '#include \"./server/libaquery.h\"\n'
        '#include \"./server/aggregations.h\"\n\n'
    )

    def get_init_func(self):
        """Return the C++ user-module init function, or ``''`` if no modules.

        Side effect: when modules exist, a ``'P__builtin_init_user_module'``
        entry is inserted into ``queries`` at ``module_init_loc``.
        """
        if not self.module_map:
            return ''
        ret = '__AQEXPORT__(void) __builtin_init_user_module(Context* cxt){\n'
        for fname in self.module_map:
            ret += f'{fname} = (decltype({fname}))(cxt->get_module_function("{fname}"));\n'
        self.queries.insert(self.module_init_loc, 'P__builtin_init_user_module')
        return ret + '}\n'

    def finalize_query(self):
        """Drop every table alias, keeping only canonical table names."""
        # clear aliases
        for t in self.tables:
            for a in t.alias:
                if a != t.table_name:
                    self.tables_byname.pop(a, None)
            t.alias.clear()
            t.alias.add(t.table_name)

    def sql_begin(self):
        """Start a new SQL statement."""
        self.sql = ''

    def sql_end(self):
        """Queue the current SQL (prefixed ``'Q'``) unless it is empty.

        A statement consisting only of semicolons/whitespace, or of the
        bare word ``select``, is discarded.  The SQL buffer is cleared
        either way.
        """
        # eliminate empty queries
        s = self.sql.strip()
        while s and s[-1] == ';':
            s = s[:-1].strip()
        if s and s.lower() != 'select':
            self.queries.append('Q' + self.sql)
        self.sql = ''

    def postproc_begin(self, proc_name: str):
        """Start the C++ body of post-processing function ``proc_name``."""
        self.ccode = self.function_deco + proc_name + self.function_head

    def postproc_end(self, proc_name: str):
        """Close the current function, queue it as ``'P' + proc_name``."""
        self.procs.append(self.ccode + 'return 0;\n}')
        self.ccode = ''
        self.queries.append('P' + proc_name)
        self.finalize_query()

    def abandon_query(self):
        """Discard pending SQL and C++ text and clear table aliases."""
        self.sql = ''
        self.ccode = ''
        self.finalize_query()

    def direct_output(self, limit=-1, sep=' ', end='\n'):
        """Queue a direct-output entry (``'O' + encoded limit + sep + end``).

        A ``limit`` that is not an ``int`` or lies outside
        ``[0, 2**32 - 1]`` is replaced by ``2**32 - 1``.
        """
        from common.utils import encode_integral
        if type(limit) is not int or limit > 2**32 - 1 or limit < 0:
            limit = 2**32 - 1
        limit = encode_integral(limit)
        self.queries.append('O' + limit + sep + end)

    def remove_trigger(self, name : str):
        """Unregister trigger ``name``; do nothing if it is not registered.

        Callback triggers are also detached from their table.
        """
        val = self.triggers.pop(name, None)
        if val is None:
            # Unknown trigger: previously this crashed on ``None.type``.
            return
        from engine.ast import create_trigger
        if val.type == create_trigger.Type.Callback:
            val.table.triggers.remove(val)
        val.remove()

    def post_exec_triggers(self):
        """Execute every active trigger, then clear the active set."""
        for t in self.triggers_active:
            t.execute()
        self.triggers_active.clear()

    def abandon_postproc(self):
        """Discard pending C++ text and clear table aliases."""
        self.ccode = ''
        self.finalize_query()

    def finalize_udf(self):
        """Assemble and return the UDF source, or ``None`` if there is none.

        Appends the code of every registered UDF to ``self.udf`` and
        rebuilds ``module_stubs`` from ``module_map`` as a side effect.
        """
        if not self.udf:
            return None
        self.udf += '\n'.join([u.ccode for u in self.udf_map.values()])
        self.module_stubs = '\n'.join(self.module_map.values())
        return (
            Context.udf_head
            + self.module_stubs
            + self.get_init_func()
            + self.udf
        )

    def finalize(self):
        """Build and return the final C++ source.

        Unless ``self.finalized`` is set, the contents of ``header.cxx``,
        one ``#include`` per entry of ``self.headers`` (quoted entries are
        emitted verbatim, others wrapped in ``<>``), optional Windows
        ``#undef`` lines and all queued procedures are appended to
        ``ccode``; ``headers`` is then emptied.
        """
        from aquery_config import build_driver, os_platform
        if not self.finalized:
            headers = ''
            # if build_driver == 'MSBuild':
            #     headers ='#include \"./server/pch.hpp\"\n'
            with open('header.cxx', 'r') as header:
                headers += header.read()
            for h in self.headers:
                if h[0] != '"':
                    headers += '#include <' + h + '>\n'
                else:
                    headers += '#include ' + h + '\n'
            if os_platform == 'win':
                headers += '#undef max\n'
                headers += '#undef min\n'

            self.ccode += headers + '\n'.join(self.procs)
            self.headers = set()
        return self.ccode
|