from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List, Optional, Set, Tuple, Union

from common.types import *
from common.utils import (base62alp, base62uuid, enlist,
                          get_innermost, get_legal_name,
                          Backend_Type, backend_strings)
from engine.storage import ColRef, Context, TableInfo

class ast_node:
    header = []
    types = dict()
    first_order = False

    def __init__(self, parent: Optional["ast_node"], node, context: Optional[Context] = None):
        self.context = parent.context if context is None else context
        self.parent = parent
        self.sql = ''
        self.ccode = ''
        if hasattr(parent, 'datasource'):
            self.datasource = parent.datasource
        else:
            self.datasource = None
        self.init(node)
        self.produce(node)
        self.spawn(node)
        self.consume(node)

    def emit(self, code):
        self.context.emit(code)

    def add(self, code, sp=' '):
        self.sql += code + sp

    def addc(self, code):
        self.ccode += code + '\n'

    name = 'null'

    def init(self, _):
        if self.parent is None:
            self.context.sql_begin()
        self.add(self.__class__.name.upper())

    def produce(self, _):
        pass

    def spawn(self, _):
        pass

    def consume(self, _):
        if self.parent is None:
            self.emit(self.sql + ';\n')
            self.context.sql_end()
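
# Illustrative sketch (not part of the compiler; hypothetical subclass): every
# ast_node is driven by the four-phase constructor above, in the order
# init -> produce -> spawn -> consume. A minimal node that emits a constant
# statement would look like:
#
#     class show_tables(ast_node):
#         name = 'show tables'   # init() upper-cases this into self.sql
#         # produce/spawn default to no-ops; consume() emits 'SHOW TABLES ;'
#         # when the node is a root (parent is None).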

from engine.expr import expr, fastscan


class SubqType(Enum):
    WITH = auto()
    FROM = auto()
    PROJECTION = auto()
    FILTER = auto()
    GROUPBY = auto()
    ORDERBY = auto()
    NONE = auto()

class projection(ast_node):
    name = 'projection'
    first_order = 'select'

    def __init__(self,
                 parent: Optional["ast_node"],
                 node,
                 context: Optional[Context] = None,
                 force_use_spgb: bool = False,
                 subq_type: SubqType = SubqType.NONE
                 ):
        self.force_use_spgb = (
            force_use_spgb |
            (context.system_state.cfg.backend_type ==
             Backend_Type.BACKEND_AQuery.value)
        )
        self.subq_type = subq_type
        super().__init__(parent, node, context)

    def init(self, _):
        # skip default init
        pass

    def produce(self, node):
        self.add('SELECT')
        self.has_preproc = not self.context.use_cached_tables
        self.has_postproc = 'into' in node
        if 'select' in node:
            p = node['select']
            self.distinct = False
        elif 'select_distinct' in node:
            p = node['select_distinct']
            self.distinct = True
        else:
            raise NotImplementedError('AST node is not a projection node')
        if 'with' in node:
            with_table = node['with']['name']
            with_table_name = tuple(with_table.keys())[0]
            with_table_cols = tuple(with_table.values())[0]
            self.with_clause = projection(self, node['with']['value'], subq_type=SubqType.WITH)
            self.with_clause.out_table.add_alias(with_table_name)
            for new_name, col in zip(with_table_cols, self.with_clause.out_table.columns):
                col.rename(new_name)
            # self.with_clause.out_table.contextname_cpp  (no-op placeholder)
            # in monetdb, in cxt
        else:
            self.with_clause = None
        self.limit = None
        if 'limit' in node:
            self.limit = node['limit']
        self.projections = p if type(p) is list else [p]
        if self.parent is None:
            self.context.sql_begin()
            self.postproc_fname = 'dll_' + base62uuid(6)
            self.context.postproc_begin(self.postproc_fname)
    def spawn(self, node):
        self.datasource = join(self, [], self.context)  # datasource is Join instead of TableInfo
        self.assumptions = []
        if 'from' in node:
            from_clause = node['from']['table_source']
            self.datasource = join(self, from_clause)
            if 'assumptions' in node['from']:
                self.assumptions = enlist(node['from']['assumptions'])
        if self.datasource is not None:
            self.has_preproc = (
                self.has_preproc and
                not self.datasource.no_join and
                self.datasource.tables[0].cached
            )
            self.datasource_changed = True
            self.prev_datasource = self.context.datasource
            self.context.datasource = self.datasource
        else:
            self.has_preproc = False
        if 'where' in node:
            self.where = filter(self, node['where'])
        else:
            self.where = None
        if type(self.datasource) is join:
            self.datasource.process_join_conditions()
        self.context.special_gb = self.force_use_spgb
        if 'groupby' in node:  # if groupby clause contains special stuff
            self.context.special_gb |= groupby.check_special(self, node['groupby'])
    def consume(self, node):
        # deal with projections
        out_table_varname = 'out_' + base62uuid(6)
        if 'into' in node:
            out_table_name = node['into']
        else:
            out_table_name = out_table_varname
        self.out_table: TableInfo = TableInfo(out_table_name, [], self.context)
        self.out_table.contextname_cpp = out_table_varname
        cols = []
        self.col_ext: Set[ColRef] = set()
        col_exprs: List[Tuple[str, Types]] = []
        proj_map: Dict[int, List[Union[Types, int, str, expr]]] = dict()
        self.var_table = dict()
        # self.sp_refs = set()
        i = 0
        for proj in self.projections:
            compound = False
            self.datasource.rec = set()
            name = ''
            this_type = AnyT
            if type(proj) is dict or proj == '*':
                if 'value' in proj:
                    e = proj['value']
                elif proj == '*':
                    e = '*'
                else:
                    print('unknown projection', proj)
                proj_expr = expr(self, e)
                sql_expr = expr(self, e, c_code=False)
                this_type = proj_expr.type
                name = proj_expr.sql
                compound = [proj_expr.is_compound > 1]  # compound column
                proj_expr.cols_mentioned = self.datasource.rec
                alias = ''
                if 'name' in proj:  # renaming column by AS keyword
                    alias = proj['name']
                if not proj_expr.is_special:
                    if str(proj_expr.node).strip().endswith('*'):
                        _datasource = self.datasource
                        if '.' in proj_expr.node:
                            tbl = proj_expr.node.split('.')[0]
                            if tbl in self.datasource.tables_dir:
                                _datasource = self.datasource.tables_dir[tbl]
                        _datasource = _datasource.all_cols(ordered=True, stripped=True)
                        name = [(c.get_name()
                                 if self.datasource.single_table
                                 else c.get_full_name()
                                 ) for c in _datasource]
                        this_type = [c.type for c in _datasource]
                        compound = [c.compound for c in _datasource]
                        proj_expr = [expr(self, c.name) for c in _datasource]
                        for pe in proj_expr:
                            if pe.is_ColExpr:
                                pe.cols_mentioned = {pe.raw_col}
                            else:
                                pe.cols_mentioned = set()
                    else:
                        y = lambda x: x
                        count = lambda: 'count(*)'
                        name = enlist(sql_expr.eval(False, y, count=count))
                        this_type = enlist(this_type)
                        proj_expr = enlist(proj_expr)
                    for t, n, pexpr, cp in zip(this_type, name, proj_expr, compound):
                        if cp:
                            # Force postprocessing for compound columns
                            t = VectorT(t)
                            self.has_postproc = True
                        offset = len(col_exprs)
                        if n not in self.var_table:
                            self.var_table[n] = offset
                            if pexpr.is_ColExpr and type(pexpr.raw_col) is ColRef:
                                for _alias in (pexpr.raw_col.table.alias):
                                    self.var_table[f'{_alias}.' + n] = offset
                        proj_map[i] = [t, offset, pexpr]
                        col_expr = n + ' AS ' + alias if alias else n
                        if alias:
                            self.var_table[alias] = offset
                        col_exprs.append((col_expr, t))
                        i += 1
                else:
                    self.context.headers.add('"./server/aggregations.h"')
                    self.has_postproc = True
                    if self.datasource.rec is not None:
                        self.col_ext = self.col_ext.union(self.datasource.rec)
                    proj_map[i] = [this_type, proj_expr.sql, proj_expr]
                    i += 1
                name = enlist(name)
                disp_name = [get_legal_name(alias if alias else n) for n in name]
                this_type = enlist(this_type)
            elif type(proj) is str:
                col = self.datasource.get_cols(proj)
                this_type = col.type
                disp_name = proj
                print('Unknown behavior:', proj, 'is str')
                # name = col.name
            self.datasource.rec = None
            # TODO: Type deduction in Python
            for t, n, c in zip(this_type, disp_name, compound):
                cols.append(ColRef(t, None, self.out_table, n, len(cols), compound=c))
        self.out_table.add_cols(cols, new=False)
        self.proj_map = proj_map
        if 'groupby' in node:
            self.group_node = groupby(self, node['groupby'])
            if self.group_node.terminate:
                self.context.abandon_query()
                projection(self.parent, node, self.context, True, subq_type=self.subq_type)
                return
            if self.group_node.use_sp_gb:
                self.has_postproc = True
        else:
            self.group_node = None

        if not self.has_postproc and self.distinct:
            self.add('DISTINCT')
        self.col_ext = [c for c in self.col_ext if c.name not in self.var_table]  # remove duplicates in self.var_table
        col_ext_names = [c.name for c in self.col_ext]
        self.add(', '.join([c[0] for c in col_exprs] + col_ext_names))

        _base_offset = len(col_exprs)
        for i, col in enumerate(self.col_ext):
            if col.name not in self.var_table:
                offset = i + _base_offset
                self.var_table[col.name] = offset
                for n in (col.table.alias):
                    self.var_table[f'{n}.' + col.name] = offset

        # monetdb doesn't support select into table
        # if 'into' in node:
        #     self.into_stub = f'{{INTOSTUB{base62uuid(20)}}}'
        #     self.add(self.into_stub, '')
        def finalize(astnode: ast_node):
            if astnode is not None:
                self.add(astnode.sql)

        finalize(self.datasource)
        finalize(self.where)
        if self.group_node and not self.group_node.use_sp_gb:
            self.add(self.group_node.sql)

        if self.col_ext or (self.group_node and self.group_node.use_sp_gb):
            self.has_postproc = True
        if self.group_node and self.group_node.use_sp_gb:
            # dedicated_glist drives the special (C++) group-by path;
            # nothing needs to be appended to the SQL here.
            ...

        o = self.assumptions
        if 'orderby' in node:
            o.extend(enlist(node['orderby']))
        if o:
            self.add(orderby(self, o).sql)

        if 'outfile' in node:
            self.outfile = outfile(self, node['outfile'], sql=self.sql)
            if not self.has_postproc:
                self.sql = self.outfile.sql
        else:
            self.outfile = None

        # reset special_gb in case of subquery.
        self.context.special_gb = False
        if self.parent is None:
            self.emit(self.sql + ';\n')
        else:
            # TODO: subquery, name create tmp-table from subquery w/ alias as name
            pass
        # cpp module codegen
        # extract typed-columns from result-set
        vid2cname = [0] * len(self.var_table)
        self.pyname2cname = dict()
        typenames = [c[1] for c in col_exprs] + [c.type for c in self.col_ext]
        length_name = 'len_' + base62uuid(6)
        self.context.emitc(f'auto {length_name} = server->cnt;')

        self.input_table_name = 'placeholder'
        self.input_table_type = 'auto'
        if not self.has_preproc:
            obj_input_table: TableInfo = self.datasource.tables[0]
            self.input_table_name = (
                'tbl_' +
                obj_input_table.table_name +
                '_' + base62uuid(4)
            )
            col_types = [c.type.cname for c in obj_input_table.columns]
            self.input_table_type = f'TableInfo<{", ".join(col_types)}>'
            self.context.emitc(f'{self.input_table_type} *{self.input_table_name}'
                               f' = static_cast<{self.input_table_type} *>('
                               f'cxt->tables["{obj_input_table.table_name}"]);')
        # if not self.has_preproc:
        #     col_idxs = {col.name : i for i, col in enumerate(self.datasource.all_cols())}
        for v, idx in self.var_table.items():
            vname = get_legal_name(v) + '_' + base62uuid(3)
            self.pyname2cname[v] = vname
            if not self.has_preproc:
                # TODO: verify/ensure that idx is the same as cid in the table.
                self.context.emitc(f'decltype(auto) {vname} = '
                                   f'{self.input_table_name}->get_col<{self.datasource.parse_col_names(v).id}>();')
            else:
                self.context.emitc(f'auto {vname} = ColRef<{typenames[idx].cname}>('
                                   f'{length_name}, '
                                   f'server->getCol({idx}, {typenames[idx].ctype_name})'
                                   f');')
            vid2cname[idx] = vname
        # Create table into context
        out_typenames = [None] * len(proj_map)

        def get_proj_name(proj_name):
            if '*' in proj_name:
                lst_names = self.datasource.get_cols(proj_name)
                return ', '.join([self.pyname2cname[n.name] for n in lst_names])
            else:
                return self.pyname2cname[proj_name]

        gb_tovec = [False] * len(proj_map)
        for i, (key, val) in enumerate(proj_map.items()):
            if type(val[1]) is str:
                x = True
                y = get_proj_name
                count = lambda: '0'
                if vid2cname:
                    count = lambda: f'{vid2cname[0]}.size'
                val[1] = val[2].eval(x, y, count=count)
                if callable(val[1]):
                    val[1] = val[1](False)
            if val[0] == LazyT:
                decltypestring = val[2].eval(y=y, gettype=True, c_code=True)(True)
                decltypestring = f'value_type<decays<decltype({decltypestring})>>'
                out_typenames[key] = decltypestring
            else:
                out_typenames[key] = val[0].cname
            elemental_ret_udf = (
                type(val[2].udf_called) is udf and  # should bulkret also be colref?
                val[2].udf_called.return_pattern == udf.ReturnPattern.elemental_return
            )
            folding_vector_groups = (
                self.group_node and
                (
                    self.group_node.use_sp_gb and
                    val[2].cols_mentioned.intersection(
                        self.datasource.all_cols().difference(
                            self.datasource.get_joint_cols(self.group_node.refs)
                        )
                    )
                ) and
                val[2].is_compound  # compound val not in key
            )
            if elemental_ret_udf or folding_vector_groups:
                out_typenames[key] = f'vector_type<{out_typenames[key]}>'
                self.out_table.columns[key].compound = True
                if self.group_node is not None and self.group_node.use_sp_gb:
                    gb_tovec[i] = True
        outtable_col_nameslist = ', '.join([f'"{c.name}"' for c in self.out_table.columns])
        self.outtable_col_names = 'names_' + base62uuid(4)
        self.context.emitc(f'const char* {self.outtable_col_names}[] = {{{outtable_col_nameslist}}};')
        # out_typenames = [v[0].cname for v in proj_map.values()]
        self.context.emitc(f'auto {self.out_table.contextname_cpp} = new TableInfo<{", ".join(out_typenames)}>('
                           f'"{self.out_table.table_name}", {self.outtable_col_names});')
        # TODO: Inject custom group by code here and flag them in proj_map
        # Type of UDFs? Complex UDFs, ones with static vars?
        if self.group_node is not None and self.group_node.use_sp_gb:
            gb_vartable: Dict[str, Union[str, int]] = deepcopy(self.pyname2cname)
            gb_cexprs: List[str] = []
            gb_colnames: List[str] = []
            gb_types: List[Types] = []
            for key, val in proj_map.items():
                col_name = 'col_' + base62uuid(6)
                self.context.emitc(f'decltype(auto) {col_name} = {self.out_table.contextname_cpp}->get_col<{key}>();')
                gb_cexprs.append((col_name, val[2]))
                gb_colnames.append(col_name)
                gb_types.append(val[0])
            self.group_node.finalize(gb_cexprs, gb_vartable, gb_colnames, gb_types, gb_tovec)
        else:
            for i, (key, val) in enumerate(proj_map.items()):
                if type(val[1]) is int:
                    self.context.emitc(
                        f'{self.out_table.contextname_cpp}->get_col<{key}>().initfrom({vid2cname[val[1]]}, "{cols[i].name}");'
                    )
                else:
                    # for funcs evaluate f_i(x, ...)
                    self.context.emitc(f'{self.out_table.contextname_cpp}->get_col<{key}>().initfrom({val[1]}, "{cols[i].name}");')
        # print out col_is
        if 'into' not in node and self.subq_type == SubqType.NONE:
            if self.limit is None:
                self.context.emitc(f'print(*{self.out_table.contextname_cpp});')
            else:
                self.context.emitc(f'{self.out_table.contextname_cpp}->printall("", "\\n", nullptr, nullptr, {self.limit});')
        if self.outfile and self.has_postproc:
            self.outfile.finalize()
        if not self.distinct:
            self.finalize(node)
    def deal_with_into(self, node):
        if 'into' in node:
            self.context.emitc(select_into(self, node['into']).ccode)

    def finalize(self, node):
        self.deal_with_into(node)
        self.context.emitc('puts("done.");')
        ## self.context.emitc('printf("done: %lld\\n", (chrono::high_resolution_clock::now() - timer).count());timer = chrono::high_resolution_clock::now();')

        if not self.has_preproc:
            self.context.sql = ''
        if self.parent is None:
            self.context.sql_end()
            if self.has_postproc:
                self.context.has_dll = True
                self.context.postproc_end(self.postproc_fname)
            else:
                self.context.ccode = ''
        if self.limit != 0 and not self.outfile:
            self.context.direct_output(self.limit)

class select_distinct(projection):
    first_order = 'select_distinct'

    def consume(self, node):
        super().consume(node)
        if self.has_postproc:
            self.context.emitc(
                f'{self.out_table.contextname_cpp}->distinct();'
            )
        self.finalize(node)

class select_into(ast_node):
    def init(self, _):
        if isinstance(self.parent, projection):
            # if self.parent.has_postproc:
            #     # has postproc put back to monetdb
            self.produce = self.produce_cpp
            # else:
            #     self.produce = self.produce_sql
        else:
            raise ValueError('parent must be projection')

    def produce_cpp(self, node):
        if not hasattr(self.parent, 'out_table'):
            raise Exception('No out_table found.')
        else:
            self.context.headers.add('"./server/table_ext_monetdb.hpp"')
            self.ccode = (f'{self.parent.out_table.contextname_cpp}->monetdb_append_table('
                          f'cxt->curr_server, "{node.lower()}");')

    def produce_sql(self, node):
        self.context.sql = self.context.sql.replace(
            self.parent.into_stub, f' INTO {node}', 1)

class orderby(ast_node):
    name = 'order by'

    def produce(self, node):
        if node is None:
            self.sql = ''
            return
        node = enlist(node)
        o_list = []
        for o in node:
            o_str = expr(self, o['value']).sql
            if 'sort' in o and f'{o["sort"]}'.lower() == 'desc':
                o_str += ' ' + 'DESC'
            o_list.append(o_str)
        self.add(', '.join(o_list))
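
# Illustrative sketch (assumed parser output for ORDER BY): given the node list
#     [{'value': 'price', 'sort': 'DESC'}, {'value': 'ts'}]
# produce() above renders 'price DESC, ts' into self.sql.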

class scan(ast_node):
    class Position(Enum):
        init = auto()
        front = auto()
        body = auto()
        back = auto()
        fin = auto()
        # TODO: use this for positions for scanner

    class LoopStyle(Enum):
        forloop = auto()
        foreach = auto()

    class LoopType(Enum):
        continous = auto()
        descrete = auto()

    name = 'scan'

    def __init__(self, parent: "ast_node",
                 node,
                 loop_style=LoopStyle.forloop,
                 context: Context = None,
                 const=False,
                 it_name=None,
                 it_type=LoopType.continous
                 ):
        self.it_var = it_name
        self.it_type = it_type
        self.const = "const " if const else ""
        self.loop_style = loop_style
        super().__init__(parent, node, context)

    def init(self, _):
        self.it_var = self.context.get_scan_var() if not self.it_var else self.it_var
        self.datasource = self.context.datasource
        self.initializers = ''
        self.start = ''
        self.front = ''
        self.body = []
        self.end = ''
        self.parent.context.scans.append(self)

    def produce(self, node):
        self.start += self.context.omp_simd
        if self.loop_style == scan.LoopStyle.foreach:
            self.colref = node
            self.start += f'for ({self.const}auto& {self.it_var} : {node}) {{\n'
        else:
            self.start += f"for (uint32_t {self.it_var} = 0; {self.it_var} < {node}; ++{self.it_var}) {{\n"

    def add(self, stmt, position=Position.body):
        if position == scan.Position.body:
            self.body.append(stmt)
        elif position == scan.Position.init:
            self.initializers += stmt + '\n'
        elif position == scan.Position.front:
            self.front += stmt + '\n'
        elif position == scan.Position.fin:
            self.end += stmt + '\n'
        elif position == scan.Position.back:
            self.end = stmt + '\n' + self.end
        else:
            raise ValueError(f'Unknown position {position}')

    def finalize(self):
        scan_assembly = ''
        if self.it_type == scan.LoopType.descrete:
            scan_assembly = (
                self.initializers +
                '\n'.join([(
                    self.start +
                    self.front +
                    b +
                    '\n}\n'
                ) for b in self.body])
                + self.end
            )
        else:
            scan_assembly = (
                self.initializers +
                self.start +
                self.front +
                '\n'.join(self.body) +
                '\n}\n' +
                self.end
            )
        self.context.remove_scan(self, scan_assembly)
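
# Illustrative sketch (hypothetical identifiers): a continuous scan assembled as
#     s = scan(parent, 'len_A1', it_name='i')
#     s.add('dst[i] = src[i];')
#     s.finalize()
# hands the context roughly this C++ fragment:
#     for (uint32_t i = 0; i < len_A1; ++i) {
#     dst[i] = src[i];
#     }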

class groupby_c(ast_node):
    name = '_groupby'

    def init(self, node: List[Tuple[expr, Set[ColRef]]]):
        self.proj: projection = self.parent
        self.glist: List[Tuple[expr, Set[ColRef]]] = node
        self.vecs: str = 'vecs_' + base62uuid(3)
        return super().init(node)

    def produce(self, node: List[Tuple[expr, Set[ColRef]]]):
        self.context.headers.add('"./server/hasher.h"')
        # self.context.headers.add('unordered_map')
        self.group = 'g' + base62uuid(7)
        self.group_size = 'sz_' + self.group
        self.group_type = 'record_type' + base62uuid(7)
        self.datasource = self.proj.datasource
        self.scanner = None
        self.datasource.rec = set()

        g_contents = ''
        g_contents_list = []
        first_col = ''
        scanner_itname = self.context.get_scan_var()
        for g in self.glist:
            e = expr(self, g[0].node, c_code=True)
            g_str = e.eval(c_code=True, y=lambda c: self.proj.pyname2cname[c])
            # if v is compound expr, create tmp cols
            if not e.is_ColExpr:
                self.context.headers.add('"./server/aggregations.h"')
                tmpcol = 't' + base62uuid(7)
                self.context.emitc(f'auto {tmpcol} = {g_str};')
                e = tmpcol
            else:
                e = g_str
            g_contents_list.append(e)
        first_col = g_contents_list[0]
        self.total_sz = 'len_' + base62uuid(4)
        self.context.emitc(f'uint32_t {self.total_sz} = {first_col}.size;')
        g_contents_decltype = [f'decays<decltype({c})>::value_t' for c in g_contents_list]
        g_contents = ', '.join(
            # [f'{c}[{scanner_itname}]' for c in g_contents_list]
            g_contents_list
        )
        ## self.context.emitc('printf("init_time: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();')
        self.context.emitc(f'typedef record<{", ".join(g_contents_decltype)}> {self.group_type};')
        # self.context.emitc(f'AQHashTable<{self.group_type}, '
        #                    f'transTypes<{self.group_type}, hasher>> {self.group} {{{self.total_sz}}};')
        self.n_grps = len(self.glist)
        self.context.emitc(f'auto {self.group} = '
                           f'HashTableFactory<{self.group_type}, transTypes<{self.group_type}, hasher>>::'
                           f'get<{", ".join([f"decays<decltype({c})>" for c in g_contents_list])}>({g_contents});')
        # self.scanner = scan(self, self.total_sz, it_name=scanner_itname)
        # self.scanner.add(f'{self.group}.hashtable_push(forward_as_tuple({g_contents}), {self.scanner.it_var});')
        # self.context.emitc(f'{self.group}.hashtable_push_all<{", ".join([f"decays<decltype({c})>" for c in g_contents_list])}>({g_contents}, {self.total_sz});')
    def consume(self, _):
        # self.scanner.finalize()
        ## self.context.emitc('printf("ht_construct: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();')
        self.context.emitc(f'auto {self.group_size} = {self.group}.size;')
        self.context.emitc(f'auto {self.vecs} = {self.group}.values;')  # {self.group}.ht_postproc({self.total_sz});
        ## self.context.emitc('printf("ht_postproc: %lld\\n", (chrono::high_resolution_clock::now() - timer).count()); timer = chrono::high_resolution_clock::now();')

    # def deal_with_assumptions(self, assumption: assumption, out: TableInfo):
    #     gscanner = scan(self, self.group)
    #     val_var = 'val_' + base62uuid(7)
    #     gscanner.add(f'auto &{val_var} = {gscanner.it_var}.second;')
    #     gscanner.add(f'{self.datasource.cxt_name}->order_by<{assumption.result()}>(&{val_var});')
    #     gscanner.finalize()
    def finalize(self, cexprs: List[Tuple[str, expr]], var_table: Dict[str, Union[str, int]],
                 col_names: List[str], col_types: List[Types], col_tovec: List[bool]):
        tovec_columns = set()
        for i, c in enumerate(col_names):
            if col_tovec[i]:  # and type(col_types[i]) is VectorT:
                self.context.emitc(f'{c}.resize({self.group_size});')
                typename: Types = col_types[i]  # .inner_type
                self.context.emitc(f'auto buf_{c} = static_cast<{typename.cname} *>(calloc({self.total_sz}, sizeof({typename.cname})));')
                tovec_columns.add(c)
            else:
                self.context.emitc(f'{c}.resize({self.group_size});')
        # self.arr_len = 'arrlen_' + base62uuid(3)
        self.arr_values = f'{self.group}.keys'  # 'arrvals_' + base62uuid(3)
        # self.context.emitc(f'auto {self.arr_len} = {self.group_size};')
        # self.context.emitc(f'auto {self.arr_values} = {self.group}.values();')

        if len(tovec_columns):  # do this in separate loops.
            preproc_scanner = scan(self, self.group_size)
            preproc_scanner_it = preproc_scanner.it_var
            for c in tovec_columns:
                preproc_scanner.add(f'{c}[{preproc_scanner_it}].init_from'
                                    f'({self.vecs}[{preproc_scanner_it}].size, '
                                    f'{"buf_" + c} + {self.group}.offsets'
                                    f'[{preproc_scanner_it}]);'
                                    )
            preproc_scanner.finalize()

        self.context.emitc('GC::scratch_space = GC::gc_handle ? &(GC::gc_handle->scratch) : nullptr;')
        # gscanner = scan(self, self.group, loop_style=scan.LoopStyle.foreach)
        gscanner = scan(self, self.group_size)
        key_var = 'key_' + base62uuid(7)
        val_var = 'val_' + base62uuid(7)
        # gscanner.add(f'auto &{key_var} = {gscanner.it_var}.first;', position = 'front')
        # gscanner.add(f'auto &{val_var} = {gscanner.it_var}.second;', position = 'front')
        gscanner.add(f'auto &{key_var} = {self.arr_values}[{gscanner.it_var}];', position=scan.Position.front)
        gscanner.add(f'auto &{val_var} = {self.vecs}[{gscanner.it_var}];', position=scan.Position.front)

        len_var = None

        def define_len_var():
            nonlocal len_var
            if len_var is None:
                len_var = 'len_' + base62uuid(7)
                gscanner.add(f'auto &{len_var} = {val_var}.size;', position=scan.Position.front)

        def get_key_idx(varname: str):
            ex = expr(self, varname)
            joint_cols = set()
            if ex.is_ColExpr and ex.raw_col:
                joint_cols = self.datasource.get_joint_cols([ex.raw_col])
            for i, g in enumerate(self.glist):
                if (varname == g[0].eval()) or (g[0].is_ColExpr and g[0].raw_col and
                                                g[0].raw_col in joint_cols):
                    return i
            return var_table[varname]

        def get_var_names(varname: str):
            var = get_key_idx(varname)
            if type(var) is str:
                return f'{var}[{val_var}]'
            else:
                return f'get<{var}>({key_var})'

        def get_var_names_ex(varex: expr):
            sql_code = varex.eval(c_code=False)
            if sql_code in var_table:
                return get_var_names(sql_code)
            else:
                return varex.eval(c_code=True, y=get_var_names,
                                  materialize_builtin=materialize_builtin,
                                  count=lambda: f'{val_var}.size')

        for i, ce in enumerate(cexprs):
            ex = ce[1]
            materialize_builtin = {}
            if type(ex.udf_called) is udf:
                if '_builtin_len' in ex.udf_called.builtin_used:
                    define_len_var()
                    materialize_builtin['_builtin_len'] = len_var
                if '_builtin_ret' in ex.udf_called.builtin_used:
                    define_len_var()
                    gscanner.add(f'{ce[0]}[{gscanner.it_var}] = {len_var};\n')
                    materialize_builtin['_builtin_ret'] = f'{ce[0]}.back()'
                gscanner.add(f'{ex.eval(c_code=True, y=get_var_names, materialize_builtin=materialize_builtin)};\n')
                continue
            if col_tovec[i]:
                if ex.remake_binary(f'{ce[0]}[{gscanner.it_var}]'):
                    gscanner.add(f'{get_var_names_ex(ex)};\n')
                else:
                    gscanner.add(f'{ce[0]}[{gscanner.it_var}] = {get_var_names_ex(ex)};\n')
            else:
                gscanner.add(f'{ce[0]}[{gscanner.it_var}] = ({get_var_names_ex(ex)});\n')

        gscanner.add('GC::scratch_space->release();')
        ## self.context.emitc('printf("ht_initfrom: %lld\\n", (chrono::high_resolution_clock::now() - timer).count());timer = chrono::high_resolution_clock::now();')
        gscanner.finalize()
        self.context.emitc('GC::scratch_space = nullptr;')
        ## self.context.emitc('printf("agg: %lld\\n", (chrono::high_resolution_clock::now() - timer).count());timer = chrono::high_resolution_clock::now();')

        self.datasource.groupinfo = None
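
# Illustrative sketch (hypothetical generated names): for a special group-by
# such as SELECT sum(b) FROM t GROUP BY a, groupby_c roughly emits
#     auto gXXXXXXX = HashTableFactory<record_typeXXXXXXX, ...>::get<...>(a);
#     auto sz_gXXXXXXX = gXXXXXXX.size;
#     auto vecs_XXX = gXXXXXXX.values;
#     col_XXXXXX.resize(sz_gXXXXXXX);
# followed by a scan over sz_gXXXXXXX that evaluates each aggregate against
# key_XXXXXXX (the group key) and val_XXXXXXX (the per-group row indices).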

class groupby(ast_node):
    name = 'group by'

    @staticmethod
    def check_special(parent: projection, node):
        node = enlist(node)
        for g in node:
            if ('value' in g and
                    expr(parent, g['value']).is_special):
                return True
        return False

    def init(self, _):
        self.terminate = False
        super().init(_)

    def produce(self, node):
        if not isinstance(self.parent, projection):
            raise ValueError('groupby can only be used in projection')

        node = enlist(node)
        o_list = []
        self.refs = set()
        self.gb_cols = set()
        # dedicated_glist -> cols populated for special group by
        self.dedicated_glist: List[Tuple[expr, Set[ColRef]]] = []
        self.use_sp_gb = self.parent.force_use_spgb
        for g in node:
            self.datasource.rec = set()
            g_expr = expr(self, g['value'])
            self.use_sp_gb = self.use_sp_gb or g_expr.is_special
            refs: Set[ColRef] = self.datasource.rec
            self.datasource.rec = None
            if self.parent.col_ext:
                this_sp_ref = refs.difference(self.parent.col_ext)
                self.use_sp_gb = self.use_sp_gb or len(this_sp_ref) > 0
            self.refs.update(refs)
            self.dedicated_glist.append((g_expr, refs))
            g_str = g_expr.eval(c_code=False)
            if 'sort' in g and f'{g["sort"]}'.lower() == 'desc':
                g_str = g_str + ' ' + 'DESC'
            o_list.append(g_str)
            if g_expr.is_ColExpr:
                self.gb_cols.add(g_expr.raw_col)
            else:
                self.gb_cols.add(g_expr.sql)
        for projs in self.parent.proj_map.values():
            if self.use_sp_gb:
                break
            if (projs[2].is_compound and
                not ((projs[2].is_ColExpr and projs[2].raw_col in self.gb_cols) or
                     projs[2].sql in self.gb_cols)
                ) and (not self.parent.force_use_spgb):
                self.use_sp_gb = True
                break
        if self.use_sp_gb and not self.parent.force_use_spgb:
            self.terminate = True
            return

        if not self.use_sp_gb:
            self.dedicated_gb = None
            self.add(', '.join(o_list))
        else:
            for l in self.dedicated_glist:
                # l_exist = l[1].difference(self.parent.col_ext)
                # for l in l_exist:
                #     self.parent.var_table.
                self.parent.col_ext.update(l[1])

    def finalize(self, cexprs: List[Tuple[str, expr]], var_table: Dict[str, Union[str, int]],
                 col_names: List[str], col_types: List[Types], col_tovec: List[bool]):
        if self.use_sp_gb:
            self.dedicated_gb = groupby_c(self.parent, self.dedicated_glist)
            self.dedicated_gb.finalize(cexprs, var_table, col_names, col_types, col_tovec)

# TODO: add support for CallExpr for tables.
class join(ast_node):
    name = 'join'

    def get_joint_cols(self, cols: List[ColRef]):
        joint_cols = set(cols)
        for col in cols:
            joint_cols |= self.joint_cols.get(col, set())
        return joint_cols

    def strip_joint_cols(self, cols: Set[ColRef]):
        stripped = type(cols)(cols)
        for c in cols:
            if c in stripped:
                jc = self.get_joint_cols([c])
                for j in jc:
                    if j != c and j in stripped:
                        stripped.remove(j)
        return stripped

    @property
    def no_join(self):
        return len(self.tables) == 1

    def init(self, _):
        self.joins: List[Tuple[str, bool]] = []
        self.tables: List[TableInfo] = []
        self.tables_dir = dict()
        self.rec = None
        self.top_level = self.parent and isinstance(self.parent, projection)
        self.have_sep = False
        self.joint_cols: Dict[ColRef, Set] = dict()  # columns that are joined with this column
        self.children: Set["join"] = set()
        self.join_conditions = []
        # self.tmp_name = 'join_' + base62uuid(4)
        # self.datasource = TableInfo(self.tmp_name, [], self.context)

    def append(self, tbls, __alias=''):
        alias = lambda t: t + ' ' + __alias if len(__alias) else t
        # sub joins
        if type(tbls) is join:
            self.joins.append((alias(tbls.__str__()), tbls.have_sep))
            self.tables += tbls.tables
            self.tables_dir = {**self.tables_dir, **tbls.tables_dir}
            self.join_conditions += tbls.join_conditions
        # simple table
        elif type(tbls) is TableInfo:
            self.joins.append((alias(tbls.table_name), False))
            self.tables.append(tbls)
            self.tables_dir[tbls.table_name] = tbls
            for a in tbls.alias:
                self.tables_dir[a] = tbls
        # subquery
        elif type(tbls) is projection:
            self.joins.append((alias(tbls.finalize()), False))
    def produce(self, node):
        if type(node) is list:
            for d in node:
                self.append(join(self, d))
        elif type(node) is dict:
            alias = ''
            if 'value' in node:
                table_name = node['value']
                tbl = None
                if 'name' in node:
                    alias = node['name']
                if type(table_name) is dict:
                    if 'select' in table_name:
                        # TODO: subquery, create and register TableInfo in projection
                        tbl = projection(self, table_name).finalize()
                else:
                    tbl = self.context.tables_byname[table_name]
                    if 'name' in node:
                        tbl.add_alias(node['name'])
                self.append(tbl, alias)
            else:
                keys: List[str] = list(node.keys())
                if keys[0].lower().endswith('join'):
                    self.have_sep = True
                    j = join(self, node[keys[0]])
                    self.tables += j.tables
                    self.tables_dir = {**self.tables_dir, **j.tables_dir}
                    self.join_conditions += j.join_conditions
                    _tbl_union = _tmp_join_union(self.context, self.parent, self)
                    tablename = f'{keys[0]} {j}'
                    if len(keys) > 1:
                        jcond = node[keys[1]]
                        sqoute = '\''
                        if type(jcond) is list:
                            _ex = [expr(_tbl_union, j) for j in jcond]
                        else:
                            _ex = expr(_tbl_union, jcond)
                        if keys[1].lower() == 'on':  # postpone join condition evaluation after consume
                            self.join_conditions += _ex.join_conditions
                            tablename += f" ON {_ex.eval().replace(sqoute, '')}"
                        elif keys[1].lower() == 'using':
                            _ex = enlist(_ex)
                            lst_jconds = []
                            for _e in _ex:
                                if _e.is_ColExpr:
                                    cl = _e.raw_col
                                    cr = j.get_cols(_e.raw_col.name)
                                    self.join_conditions.append((cl, cr))
                                    lst_jconds += [f'{cl.get_full_name()} = {cr.get_full_name()}']
                            tablename += f' ON {" and ".join(lst_jconds)}'
                    if keys[0].lower().startswith('natural'):
                        ltbls: List[TableInfo] = []
                        if isinstance(self.parent, join):
                            ltbls = self.parent.tables
                        elif isinstance(self.parent, TableInfo):
                            ltbls = [self.parent]
                        for tl in ltbls:
                            for cl in tl.columns:
                                cr = j.get_cols(cl.name)
                                if cr:
                                    self.join_conditions.append((cl, cr))
                    self.joins.append((tablename, self.have_sep))
        elif type(node) is str:
            if node in self.context.tables_byname:
                self.append(self.context.tables_byname[node])
            else:
                print(f'Error: table {node} not found.')
    def get_cols(self, colExpr: str) -> Optional[ColRef]:
        if '*' in colExpr:
            if colExpr == '*':
                return self.all_cols(ordered=True, stripped=True)
            elif colExpr.endswith('.*'):
                tbl = colExpr.split('.')
                if len(tbl) > 2:
                    raise KeyError(f'Invalid expression: {colExpr}')
                if tbl[0] in self.tables_dir:
                    tbl: TableInfo = self.tables_dir[tbl[0]]
                    return tbl.all_cols(ordered=True)
                else:
                    raise KeyError(f'Invalid table name: {colExpr}')
        for t in self.tables:
            if colExpr in t.columns_byname:
                col = t.columns_byname[colExpr]
                if type(self.rec) is set:
                    self.rec.add(col)
                return col
        return None

    def parse_col_names(self, colExpr: str) -> ColRef:
        parsedColExpr = colExpr.split('.')
        if len(parsedColExpr) <= 1:
            return self.get_cols(colExpr)
        else:
            datasource = self.tables_dir[parsedColExpr[0]]
            if datasource is None:
                raise ValueError(f'Table name/alias not defined {parsedColExpr[0]}')
            else:
                datasource.rec = self.rec
                ret = datasource.parse_col_names(parsedColExpr[1])
                datasource.rec = None
                return ret
    @property
    def single_table(self):
        return len(self.tables) == 1

    # @property
    def all_cols(self, ordered=False, stripped=True):
        from ordered_set import OrderedSet
        ret = OrderedSet() if ordered else set()
        for table in self.tables:
            rec = table.rec
            table.rec = self.rec
            ret.update(table.all_cols(ordered=ordered))
            table.rec = rec
        if stripped:
            return self.strip_joint_cols(ret)
        return ret

    # TODO: join condition awareness
    def process_join_conditions(self):
        # This is done after both from
        # and where clause are processed
        for j in self.join_conditions:
            l = j[0]
            r = j[1]
            for k in (0, 1):
                if j[k] not in self.joint_cols:
                    self.joint_cols[j[k]] = set([l, r])
            jr = self.joint_cols[r]
            jl = self.joint_cols[l]
            if jl != jr:
                jl |= jr
                for c, jc in self.joint_cols.items():
                    if jc == jr:
                        self.joint_cols[c] = jl
        # print(self.join_conditions)
        # print(self.joint_cols)

    def consume(self, node):
        self.sql = ''
        for j in self.joins:
            if not self.sql or j[1]:
                self.sql += j[0]  # using JOIN keyword
            else:
                self.sql += ', ' + j[0]  # using comma
        if node and self.sql and self.top_level:
            self.sql = ' FROM ' + self.sql
        return super().consume(node)

    def __str__(self):
        return self.sql

    def __repr__(self):
        return self.__str__()
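
# Illustrative sketch (hypothetical columns): after process_join_conditions()
# has seen the condition (a.id, b.id), both columns share one merged set:
#     joint_cols[a.id] == joint_cols[b.id] == {a.id, b.id}
# so get_joint_cols([a.id]) returns both sides, and strip_joint_cols() keeps
# only one representative of each joined pair.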

class _tmp_join_union(join):
    name = '__tmp_join_union'

    def __init__(self, context, l: join, r: join):
        self.tables = l.tables + r.tables
        self.tables_dir = {**l.tables_dir, **r.tables_dir}
        self.datasource = self
        self.context = context
        self.parent = self
        self.rec = None


class filter(ast_node):
    name = 'where'

    def produce(self, node):
        filter_expr = expr(self, node)
        self.add(filter_expr.sql)
        if self.datasource is not None:
            self.datasource.join_conditions += filter_expr.join_conditions

class union_all(ast_node):
    name = 'union_all'
    first_order = name
    sql_name = 'UNION ALL'

    def produce(self, node):
        queries = node[self.name]
        self.generated_queries: List[Optional[projection]] = [None] * len(queries)
        is_standard = True
        for i, q in enumerate(queries):
            if 'select' in q:
                self.generated_queries[i] = projection(self, q)
                is_standard &= not self.generated_queries[i].has_postproc
        if is_standard:
            self.sql = f' {self.sql_name} '.join([q.sql for q in self.generated_queries])
        else:
            raise NotImplementedError(f'{self.sql_name} only supports standard SQL for now')

    def consume(self, node):
        if 'into' in node:
            outtable = TableInfo(node['into'], [], self.context)
            lst_cols = [None] * len(self.generated_queries[0].out_table.columns)
            for i, c in enumerate(self.generated_queries[0].out_table.columns):
                lst_cols[i] = ColRef(c.type, None, outtable, c.name, i, c.compound)
            outtable.add_cols(lst_cols, new=False)
            col_names = [c.name for c in outtable.columns]
            col_names = '(' + ', '.join(col_names) + ')'
            self.sql = f'CREATE TABLE {node["into"]} {col_names} AS {self.sql}'
        super().consume(node)
        if 'into' not in node:
            self.context.direct_output()
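
# Illustrative sketch (hypothetical query): for
#     SELECT a FROM t1 UNION ALL SELECT a FROM t2 INTO t3
# produce() joins the two standard SELECT bodies with ' UNION ALL ', and
# consume() then wraps the result as
#     CREATE TABLE t3 (a) AS <q1> UNION ALL <q2>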

class except_clause(union_all):
    name = 'except'
    first_order = name
    sql_name = 'EXCEPT'


class create_table(ast_node):
    name = 'create_table'
    first_order = name
    allowed_subq = {
        'select_distinct': select_distinct,
        'select': projection,
        'union_all': union_all,
        'except': except_clause
    }

    def init(self, node):
        node = node[self.name]
        if 'query' in node:
            if 'name' not in node:
                raise ValueError("Table name not specified")
            projection_node = node['query']
            projection_node['into'] = node['name']
            proj_cls = projection
            for k in create_table.allowed_subq.keys():
                if k in projection_node:
                    proj_cls = create_table.allowed_subq[k]
                    break
            proj_cls(None, projection_node, self.context)
            self.produce = lambda *_: None
            self.spawn = lambda *_: None
            self.consume = lambda *_: None
            return
        if self.parent is None:
            self.context.sql_begin()
        self.sql = 'CREATE TABLE '

    def produce(self, node):
        ct = node[self.name]
        tbl = self.context.add_table(ct['name'], ct['columns'])
        self.sql = f'CREATE TABLE IF NOT EXISTS {tbl.table_name} ('
        columns = []
        for c in tbl.columns:
            columns.append(f'{c.name} {c.type.sqlname}')
        self.sql += ', '.join(columns)
        self.sql += ')'
        if self.context.use_columnstore:
            self.sql += ' engine=ColumnStore'

class create_trigger(ast_node):
    name = 'create_trigger'
    first_order = name

    class Type(Enum):
        Interval = auto()
        Callback = auto()

    def init(self, _):
        # overload init to prevent automatic sql generation
        pass

    def consume(self, _):
        # overload consume to prevent automatic sqlend action
        pass

    def produce(self, node):
        from common.utils import send_to_server, get_storedproc
        node = node['create_trigger']
        self.trigger_name = node['name']
        self.action_name = node['action']
        self.action = get_storedproc(self.action_name)
        if self.trigger_name in self.context.triggers:
            raise ValueError(f'trigger {self.trigger_name} exists')
        elif not self.action:
            raise ValueError(f'Stored Procedure {self.action_name} does not exist')
        if 'interval' in node:  # executed periodically from server
            self.type = self.Type.Interval
            self.interval = node['interval']
            self.context.queries.append(f'TI{self.trigger_name}\0{self.action_name}\0{self.interval}')
        else:  # executed from sql backend
            self.type = self.Type.Callback
            self.query_name = node['query']
            self.table_name = node['table']
            self.procedure = get_storedproc(self.query_name)
            if self.procedure and self.table_name in self.context.tables_byname:
                self.table = self.context.tables_byname[self.table_name]
                self.table.triggers.add(self)
                self.context.queries.append(
                    f'TC{self.trigger_name}\0{self.table_name}\0'
                    f'{self.query_name}\0{self.action_name}'
                )
            else:
                return
        self.context.triggers[self.trigger_name] = self

    # manually execute trigger
    def register(self):
        if self.type != self.Type.Callback:
            self.context.triggers.pop(self.trigger_name)
            raise ValueError(f'Trigger {self.trigger_name} is not a callback based trigger')
        self.context.triggers_active.add(self)

    def execute(self):
        from common.utils import send_to_server
        self.context.queries.append(f'TA{self.query_name}\0{self.action_name}')

    def remove(self):
        from common.utils import send_to_server
        self.context.queries.append(f'TR{self.trigger_name}')
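
# Illustrative sketch (hypothetical names): triggers are forwarded to the
# execution engine as opcode strings on context.queries, with '\0' separating
# the fields:
#     interval trigger: 'TI' + name + '\0' + action + '\0' + interval
#     callback trigger: 'TC' + name + '\0' + table + '\0' + query + '\0' + action
# execute() and remove() enqueue the corresponding 'TA' and 'TR' commands.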

class drop_trigger(ast_node):
    name = 'drop_trigger'
    first_order = name

    def produce(self, node):
        ...


class drop(ast_node):
    name = 'drop'
    first_order = name

    def produce(self, node):
        node = node['drop']
        tbl_name = node['table']
        if tbl_name in self.context.tables_byname:
            tbl_obj: TableInfo = self.context.tables_byname[tbl_name]
            for a in tbl_obj.alias:
                self.context.tables_byname.pop(a, None)
            # TODO: delete in postproc engine
            self.context.tables_byname.pop(tbl_name, None)
            self.context.tables.discard(tbl_obj)
            self.sql += 'TABLE IF EXISTS ' + tbl_name
            return
        elif 'if_exists' not in node or not node['if_exists']:
            print(f'Error: table {tbl_name} not found.')
        self.sql = ''

class insert(ast_node):
    name = 'insert'
    first_order = name

    def init(self, node):
        if 'query' in node:
            values = node['query']
            complex_query_kw = ['from', 'where', 'groupby', 'having', 'orderby', 'limit']
            if any([kw in values for kw in complex_query_kw]):
                values['into'] = node['insert']
                proj_cls = (
                    select_distinct
                    if 'select_distinct' in values
                    else projection
                )
                proj_cls(None, values, self.context)
                self.produce = lambda *_: None
                self.spawn = lambda *_: None
                self.consume = lambda *_: None
            else:
                super().init(node)
        else:
            super().init(node)

    def produce(self, node):
        keys = []
        if 'query' in node:
            if 'select' in node['query']:
                values = enlist(node['query']['select'])
                if 'columns' in node:
                    keys = node['columns']
                values = [v['value'] for v in values]
            elif 'union_all' in node['query']:
                values = [[v['select']['value']] for v in node['query']['union_all']]
                if 'columns' in node:
                    keys = node['columns']
        else:
            values = enlist(node['values'])
            _vals = []
            for v in values:
                if isinstance(v, dict):
                    keys = v.keys()
                    v = list(v.values())
                v = [f"'{vv}'" if type(vv) is str else vv for vv in v]
                _vals.append(v)
            values = _vals

        keys = f'({", ".join(keys)})' if keys else ''
        tbl = node['insert']
        if tbl not in self.context.tables_byname:
            print(f'Warning: {tbl} not registered in aquery compiler.')
        tbl_obj = self.context.tables_byname[tbl]
        for t in tbl_obj.triggers:
            t.register()
        self.sql = f'INSERT INTO {tbl}{keys} VALUES '
        # if len(values) != table.n_cols:
        #     raise ValueError("Column Mismatch")
        values = [values] if isinstance(values, list) and not isinstance(values[0], list) else values
        list_values = []
        for l in values:
            inner_list_values = []
            for s in enlist(l):
                if type(s) is dict and 'value' in s:
                    s = s['value']
                inner_list_values.append(f'{get_innermost(s)}')
            list_values.append(f"({', '.join(inner_list_values)})")
        self.sql += ', '.join(list_values)
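
# Illustrative sketch (hypothetical statement): a literal row {'a': 1, 'b': 'x'}
# against table t renders roughly
#     INSERT INTO t(a, b) VALUES (1, 'x')
# whereas an INSERT ... SELECT with FROM/WHERE/GROUP BY is rerouted in init()
# through projection with node['into'] set, bypassing produce() entirely.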

class delete_from(ast_node):
    name = 'delete'
    first_order = name

    def init(self, node):
        super().init(node)

    def produce(self, node):
        tbl = node['delete']
        self.sql = f'DELETE FROM {tbl} '
        if 'where' in node:
            self.sql += filter(self, node['where']).sql

class load(ast_node):
    name = "load"
    first_order = name

    def init(self, node):
        from common.utils import Backend_Type
        self.module = False
        if node['load']['file_type'] == 'module':
            self.produce = self.produce_module
            self.module = True
        elif 'complex' in node['load']:
            self.produce = self.produce_cpp
            self.consume = lambda *_: None
        elif self.context.system_state.cfg.backend_type == Backend_Type.BACKEND_MonetDB.value:
            self.produce = self.produce_monetdb
        elif self.context.system_state.cfg.backend_type == Backend_Type.BACKEND_DuckDB.value:
            self.produce = self.produce_duckdb
        else:
            self.produce = self.produce_aq
        if self.parent is None:
            self.context.sql_begin()

    def produce_module(self, node):
        # create command for exec engine -> done
        # create c++ stub
        # create dummy udf obj for parsing
        # def decode_type(ty : str) -> str:
        #     ret = ''
        #     back = ''
        #     while(ty.startswith('vec')):
        #         ret += 'ColRef<'
        #         back += '>'
        #         ty = ty[3:]
        #     ret += ty
        #     return ret + back
        node = node['load']
        file = node['file']['literal']
        self.context.queries.append(f'M{file}')
        self.module_name = file
        self.functions = {}
        if 'funcs' in node:
            for f in enlist(node['funcs']):
                fname = f['fname']
                self.context.queries.append(f'F{fname}')
                ret_type = VoidT
                if 'ret_type' in f:
                    ret_type = Types.decode(f['ret_type'], vector_type='vector_type')
                nargs = 0
                arglist = ''
                if 'vars' in f:
                    arglist = []
                    for v in enlist(f['vars']):
                        arglist.append(f'{Types.decode(v["type"]).cname} {v["arg"]}')
                    nargs = len(arglist)
                    arglist = ', '.join(arglist)
                # create c++ stub
                cpp_stub = f'{"vectortype_cstorage" if isinstance(ret_type, VectorT) else ret_type.cname} (*{fname})({arglist}) = nullptr;'
                self.context.module_stubs += cpp_stub + '\n'
                self.context.module_map[fname] = cpp_stub
                # registration for parser
                self.functions[fname] = user_module_function(fname, nargs, ret_type, self.context)
        self.context.module_init_loc = len(self.context.queries)
    def produce_aq(self, node):
        node = node['load']
        s1 = 'LOAD DATA INFILE '
        s2 = 'INTO TABLE '
        s3 = 'FIELDS TERMINATED BY '
        self.sql = f'{s1}"{node["file"]["literal"]}" {s2}{node["table"]}'
        if 'term' in node:
            self.sql += f' {s3}"{node["term"]["literal"]}"'

    def produce_monetdb(self, node):
        node = node['load']
        s1 = f'COPY OFFSET 2 INTO {node["table"]} FROM '
        s2 = ' ON SERVER '
        s3 = ' USING DELIMITERS '
        import os
        p = os.path.abspath(node['file']['literal']).replace('\\', '/')
        self.sql = f'{s1}\'{p}\'{s2}'
        if 'term' in node:
            self.sql += f'{s3}\'{node["term"]["literal"]}\''

    def produce_duckdb(self, node):
        node = node['load']
        s1 = f'COPY {node["table"]} FROM '
        import os
        p = os.path.abspath(node['file']['literal']).replace('\\', '/')
        s2 = f"DELIMITER '{node['term']['literal']}', " if 'term' in node else ''
        self.sql = f'{s1}\'{p}\' ({s2}HEADER)'
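
    # Illustrative sketch (hypothetical file/table names): LOAD DATA INFILE
    # 'd.csv' INTO TABLE t FIELDS TERMINATED BY ',' renders per backend,
    # roughly, as
    #     MonetDB: COPY OFFSET 2 INTO t FROM '/abs/d.csv' ON SERVER USING DELIMITERS ','
    #     DuckDB:  COPY t FROM '/abs/d.csv' (DELIMITER ',', HEADER)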
    def produce_cpp(self, node):
        self.context.has_dll = True
        self.context.headers.add('"csv.h"')
        node = node['load']
        self.postproc_fname = 'ld_' + base62uuid(5)
        self.context.postproc_begin(self.postproc_fname)
        table: TableInfo = self.context.tables_byname[node['table']]
        # self.sql = f"SELECT {', '.join([c.name for c in table.columns])} FROM {table.table_name};"
        # self.emit(self.sql + ';\n')
        # self.context.sql_end()
        # length_name = 'len_' + base62uuid(6)
        # self.context.emitc(f'auto {length_name} = server->cnt;')
        out_typenames = [t.type.cname for t in table.columns]
        outtable_col_nameslist = ', '.join([f'"{c.name}"' for c in table.columns])
        self.outtable_col_names = 'names_' + base62uuid(4)
        self.context.emitc(f'const char* {self.outtable_col_names}[] = {{{outtable_col_nameslist}}};')
        self.out_table = 'tbl_' + base62uuid(4)
        self.context.emitc(f'auto {self.out_table} = new TableInfo<{", ".join(out_typenames)}>('
                           f'"{table.table_name}", {self.outtable_col_names});')
        for i, c in enumerate(table.columns):
            c.cxt_name = 'c_' + base62uuid(6)
            self.context.emitc(f'decltype(auto) {c.cxt_name} = {self.out_table}->get_col<{i}>();')
            self.context.emitc(f'{c.cxt_name}.init("{table.columns[i].name}");')
            # self.context.emitc(f'{c.cxt_name}.initfrom({length_name}, server->getCol({i}), "{table.columns[i].name}");')
        csv_reader_name = 'csv_reader_' + base62uuid(6)
        col_types = [c.type.cname for c in table.columns]
        col_tmp_names = ['tmp_' + base62uuid(8) for _ in range(len(table.columns))]
        # col_names = ','.join([f'"{c.name}"' for c in table.columns])
        term_field = ',' if 'term' not in node else node['term']['literal']
        term_ele = ';' if 'ele' not in node else node['ele']['literal']
        self.context.emitc(f'AQCSVReader<{len(col_types)}, \'{term_field.strip()[0]}\', \'{term_ele.strip()[0]}\'> '
                           f'{csv_reader_name}("{node["file"]["literal"]}");')
        # self.context.emitc(f'{csv_reader_name}.read_header(io::ignore_extra_column, {col_names});')
        self.context.emitc(f'{csv_reader_name}.next_line();')
        for t, n in zip(col_types, col_tmp_names):
            self.context.emitc(f'{t} {n};')
        self.context.emitc(f'while({csv_reader_name}.read_row({", ".join(col_tmp_names)})) {{ \n')
        for i, c in enumerate(table.columns):
            # self.context.emitc(f'print({col_tmp_names[i]});')
            self.context.emitc(f'{c.cxt_name}.emplace_back({col_tmp_names[i]});')
        self.context.emitc('}')
        # self.context.emitc(f'print(*{self.out_table});')
        self.context.emitc(f'{self.out_table}->monetdb_append_table(cxt->curr_server, "{table.table_name}");')
        self.context.postproc_end(self.postproc_fname)

class outfile(ast_node):
    name = "_outfile"

    def __init__(self, parent, node, context=None, *, sql=None):
        self.node = node
        self.sql = sql if sql else ''
        super().__init__(parent, node, context)

    def init(self, _):
        assert isinstance(self.parent, projection)
        if not self.parent.has_postproc:
            if self.context.dialect == 'MonetDB':
                self.produce = self.produce_monetdb
            else:
                self.produce = self.produce_aq
        return super().init(_)

    def produce_aq(self, node):
        filename = node['loc']['literal'] if 'loc' in node else node['literal']
        self.sql += f'INTO OUTFILE "{filename}"'
        if 'term' in node:
            self.sql += f' FIELDS TERMINATED BY "{node["term"]["literal"]}"'

    def produce_monetdb(self, node):
        filename = node['loc']['literal'] if 'loc' in node else node['literal']
        import os
        p = os.path.abspath('.').replace('\\', '/') + '/' + filename
        if os.path.exists(p):
            print(f'Warning: file {p} exists and will be overwritten')
            try:
                os.remove(p)
            except OSError:
                print(f'Error: file {p} exists and cannot be removed')
        self.sql = f'COPY {self.parent.sql} INTO \'{p}\''
        d = ','
        e = '\\n'
        if 'term' in node:
            d = node['term']['literal']
        self.sql += f' delimiters \'{d}\', \'{e}\''

    def finalize(self):
        filename = self.node['loc']['literal'] if 'loc' in self.node else self.node['literal']
        sep = ',' if 'term' not in self.node else self.node['term']['literal']
        file_pointer = 'fp_' + base62uuid(6)
        self.addc(f'FILE* {file_pointer} = fopen("{filename}", "wb");')
        self.addc(f'{self.parent.out_table.contextname_cpp}->printall("{sep}", "\\n", nullptr, {file_pointer});')
        # if self.context.use_gc:
        #     self.addc(f'GC::gc_handle->reg({file_pointer}, 65536, fclose_gc);')
        # else:
        self.addc(f'fclose({file_pointer});')
        self.context.ccode += self.ccode
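
# Illustrative sketch (hypothetical generated names): on the postprocessing
# path, finalize() above appends roughly
#     FILE* fp_XXXXXX = fopen("result.csv", "wb");
#     out_YYYYYY->printall(",", "\n", nullptr, fp_XXXXXX);
#     fclose(fp_XXXXXX);
# to the generated C++ module.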

class udf(ast_node):
    name = 'udf'
    first_order = name

    @staticmethod
    def try_init_udf(context: Context):
        if context.udf is None:
            context.udf = '/*UDF Start*/\n'
            context.headers.add('"./udf.hpp"')

    @dataclass
    class builtin_var:
        enabled: bool = False
        _type: Types = AnyT
        all = ('_builtin_len', '_builtin_ret')

    def decltypecall(self, c_code=False, *args):
        from common.types import fn_behavior

        class dummy:
            def __init__(self, name):
                self.cname = name + '_gettype'
                self.sqlname = self.cname

        return fn_behavior(dummy(self.cname), c_code, *args)

    def __call__(self, c_code=False, *args):
        from common.types import fn_behavior
        builtin_args = [f'{{{n}()}}' for n, v in self.builtin.items() if v.enabled]
        return fn_behavior(self, c_code, *args, *builtin_args)

    def return_type(self, *_: Types):
        return LazyT

    def init(self, _):
        self.builtin: Dict[str, udf.builtin_var] = {
            '_builtin_len': udf.builtin_var(False, UIntT),
            '_builtin_ret': udf.builtin_var(False, Types(
                255, name='generic_ref', cname='auto&'
            ))
        }
        self.ccode = ''
        self.var_table = {}
        self.args = []
        udf.try_init_udf(self.context)
        self.vecs = set()
        self.code_list = []
        self.builtin_used = None

    def add(self, *code):
        ccode = ''
        for c in code:
            if type(c) is str:
                ccode += c
            else:
                self.code_list.append(ccode)
                self.code_list.append(c)
                ccode = ''
        if ccode:
            self.code_list.append(ccode)
    def produce(self, node):
        from common.utils import check_legal_name, get_legal_name
        node = node[self.name]
        # register udf
        self.agg = 'Agg' in node
        self.cname = get_legal_name(node['fname'])
        self.sqlname = self.cname
        self.context.udf_map[self.cname] = self
        if self.agg:
            self.context.udf_agg_map[self.cname] = self
        self.add(f'auto {self.cname} = [](')
    def get_block(self, ind, node):
        if 'stmt' in node:
            old_ind = ind
            ind += '\t'
            next_stmt = enlist(node['stmt'])
            if len(next_stmt) > 1:
                self.add(' {\n')
                self.get_stmt(ind, next_stmt)
                self.add(f'{old_ind}}}\n')
            else:
                self.get_stmt(ind, next_stmt)
    def get_cname(self, x: str):
        return self.var_table[x]
    def get_assignment(self, ind, node, *, types='auto', sep=';\n'):
        var_ex = expr(self, node['var'], c_code=True, supress_undefined=True)
        ex = expr(self, node['expr'], c_code=True)
        var = var_ex.eval(y=self.get_cname)
        if var in self.var_table or hasattr(var_ex, 'builtin_var'):
            op = '='
            if 'op' in node and node['op'] != ':=':
                op = node['op']
            e = ex.eval(y=self.get_cname)
            def assign_behavior(decltypestr=False):
                nonlocal ind, var, op, e, sep
                v = var(decltypestr) if callable(var) else var
                _e = e(decltypestr) if callable(e) else e
                if v == '_builtin_ret':
                    return f'{ind}return {_e}{sep}'
                elif '_builtin_ret' not in _e:
                    return f'{ind}{v} {op} {_e}{sep}'
                else:
                    return ''
            self.add(assign_behavior)
        else:
            cvar = get_legal_name(var)
            self.var_table[var] = cvar
            self.add(f'{ind}{types} {cvar} = ', ex.eval(y=self.get_cname), sep)
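    # Sketch of the C++ that get_assignment emits (hypothetical UDF
    # statements): a first assignment `x := a + b` declares the variable,
    #   auto x = a + b;
    # a later reassignment of a known variable goes through assign_behavior,
    #   x += a;
    # and assigning to _builtin_ret becomes a C++ return statement instead.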
    def get_stmt(self, ind, node):
        node = enlist(node)
        for n in node:
            if 'if' in n:
                _ifnode = n['if']
                self.add(f'{ind}if(', expr(self, _ifnode["cond"]).eval(y=self.get_cname), ')')
                if 'stmt' in _ifnode:
                    self.get_block(ind, _ifnode)
                else:
                    self.add('\n')
                    self.get_stmt(ind + '\t', _ifnode)
                if 'elif' in _ifnode:
                    for e in _ifnode['elif']:
                        self.add(f'{ind}else if(', expr(self, e["cond"]).eval(y=self.get_cname), ')')
                        self.get_block(ind, e)
                if 'else' in _ifnode:
                    self.add(f'{ind}else ')
                    self.get_block(ind, _ifnode['else'])
            elif 'for' in n:
                _fornode = n['for']
                defs = _fornode['defs']
                self.add(f'{ind}for({"auto " if len(enlist(defs["op"])) != 0 else ";"}')
                # NB: `expr` inside this helper shadows the expr compiler class
                def get_inline_assignments(node, end=';'):
                    var = enlist(node['var'])
                    op = enlist(node['op'])
                    expr = enlist(node['expr'])
                    len_node = len(enlist(op))
                    for i, (v, o, e) in enumerate(zip(var, op, expr)):
                        self.get_assignment('', {'var': v, 'op': o, 'expr': e},
                                            types='', sep=', ' if i != len_node - 1 else end)
                get_inline_assignments(defs)
                self.add(expr(self, _fornode["cond"]).eval(y=self.get_cname), '; ')
                get_inline_assignments(_fornode['tail'], ')')
                if 'stmt' in _fornode:
                    self.get_block(ind, _fornode)
                else:
                    self.add('\n')
                    self.get_stmt(ind + '\t', _fornode)
            elif 'assignment' in n:
                assign = n['assignment']
                self.get_assignment(ind, assign)
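    # Sketch of the lowering for a hypothetical UDF loop
    #   for (i := 0, i < n, i := i + 1) x := x + i
    # the branches above emit C++ along the lines of
    #   for(auto i = 0; i < n; i = i + 1)
    #       x = x + i;
    # (a single-statement body gets no braces; see get_block above).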
    def consume(self, node):
        from common.utils import check_legal_name, get_legal_name
        node = node[self.name]
        if 'params' in node:
            for args in node['params']:
                cname = get_legal_name(args)
                self.var_table[args] = cname
                self.args.append(cname)
        front = [*self.code_list, ', '.join([f'const auto& {a}' for a in self.args])]
        self.code_list = []
        self.with_storage = False
        self.with_statics = False
        self.static_decl: Optional[List[str]] = None
        ind = '\t'
        if 'static_decl' in node:
            self.add(') {\n')
            curr = node['static_decl']
            self.with_statics = True
            if 'var' in curr and 'expr' in curr:
                if len(curr['var']) != len(curr['expr']):
                    print("Error: every static variable must be initialized.")
                self.static_decl = []
                for v, e in zip(curr['var'], curr['expr']):
                    cname = get_legal_name(v)
                    self.var_table[v] = cname
                    self.static_decl.append(f'{cname} = ' +
                                            expr(self, e, c_code=True).eval(self.get_cname))
                self.add(f'{ind}static auto {"; static auto ".join(self.static_decl)};\n')
                self.add(f'{ind}auto reset = [=]() {{ {"; ".join(self.static_decl)}; }};\n')
                self.add(f'{ind}auto call = []({", ".join([f"decltype({a}) {a}" for a in self.args])}')
                ind = '\t\t'
                front = [*front, *self.code_list]
                self.code_list = []
        if 'stmt' in node:
            self.get_stmt(ind, node['stmt'])
        # first scan to determine vec types
        # if self.agg:
        #     for assign in node['assignment']:
        #         var = fastscan(assign['var'])
        #         ex = fastscan(assign['expr'])
        #         self.vecs.union(var.vec_vars)
        #         self.vecs.union(var.requested_lens)
        #         self.vecs.union(ex.vec_vars)
        #         self.vecs.union(ex.requested_lens)
        #     if len(self.vecs) != 0:
        #         self.idx_var = 'idx_' + base62uuid(5)
        #         self.ccode += f'{ind}auto {self.idx_var} = 0;\n'
        ret = node['ret']
        # return_call resolves lazily: it walks self.return_call (assigned
        # just below) and concatenates its string and callable parts on demand
        def return_call(decltypestr=False):
            if decltypestr:
                return ''
            ret = ''
            for r in self.return_call:
                if callable(r):
                    ret += r(False)
                else:
                    ret += r
            return ret
        self.return_call = (f'{ind}return ', expr(self, ret, c_code=True).eval(self.get_cname), ';\n')
        self.add(return_call)
        if self.with_statics:
            self.add('\t};\n')
            self.add('\treturn std::make_pair(reset, call);\n')
        self.add('};\n')
        # print(self.ccode)
        self.builtin_args = [(name, var._type.cname) for name, var in self.builtin.items() if var.enabled]
        # self.context.udf += front + builtin_argstr + self.ccode + '\n'
        self.finalize(front)
    def finalize(self, front):
        builtin_argstr = ', ' if len(self.builtin_args) and len(self.args) else ''
        builtin_argstr += ', '.join([f'{t} {n}' for (n, t) in self.builtin_args])
        self.builtin_used = [b for b, v in self.builtin.items() if v.enabled]
        ccode = ''
        def process_recursion(l, decltypestr=False):
            nonlocal ccode
            for c in l:
                if type(c) is str:
                    ccode += c
                elif callable(c):
                    ccode += c(decltypestr)  # a deferred callback fragment
                else:
                    raise ValueError(f'Illegal operation in udf code generation: {c}')
        process_recursion(front)
        ccode += builtin_argstr + ') {\n'
        process_recursion(self.code_list)
        self.ccode += ccode + '\n'
        ccode = ''
        if self.return_pattern == udf.ReturnPattern.elemental_return:
            ccode += f'auto {self.cname}_gettype = []('
            process_recursion(front[1:], True)
            ccode += ') {\n\tuint32_t _builtin_len = 0;\n'
            process_recursion(self.code_list, True)
            self.ccode += ccode + '\n'
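    # Sketch of the overall artifact finalize() assembles for a simple
    # non-static UDF (names hypothetical):
    #   auto myfun = [](const auto& a, const auto& b) {
    #       auto t = a + b;
    #       return t;
    #   };
    # When _builtin_ret was used (elemental return), a parallel
    # `myfun_gettype` lambda is emitted as well, so callers can deduce the
    # result type via decltype without executing the UDF body.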
    class ReturnPattern(Enum):
        bulk_return = auto()
        elemental_return = auto()
    @property
    def return_pattern(self):
        if '_builtin_ret' in self.builtin_used:
            return udf.ReturnPattern.elemental_return
        else:
            return udf.ReturnPattern.bulk_return
class passthru_sql(ast_node):
    name = 'sql'
    first_order = name
    import re
    # escapestr = r'''(?:((?:[^;"']|"[^"]*"|'[^']*')+)|(?:--[^\r\n]*[\r|\n])+)'''
    # escape_comment = fr'''(?:{escapestr}|{escapestr}*-{escapestr}*)'''
    seprator = re.compile(r'''((?:[^;"']|"[^"]*"|'[^']*')+)''')
    def __init__(self, _, node, context: Context):
        sqls = passthru_sql.seprator.split(node['sql'])
        parsed = None
        try:
            if callable(context.parser):
                parsed = context.parser(node['sql'])
        except BaseException:
            parsed = None
        for sql in sqls:
            sq = sql.strip(' \t\n\r;')
            if sq:
                context.queries.append('Q' + sq + ';')
                lq = sq.lower()
                if lq.startswith('select'):
                    context.direct_output()
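# Doctest-style sketch of the statement splitter (safe to try in a REPL):
#   >>> passthru_sql.seprator.split("SELECT 'a;b'; DROP TABLE t;")
#   ['', "SELECT 'a;b'", ';', ' DROP TABLE t', ';']
# Quoted semicolons survive intact; the bare ';' fragments strip down to
# empty strings in __init__ and are skipped.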
class user_module_function(OperatorBase):
    def __init__(self, name, nargs, ret_type, context: Context):
        super().__init__(name, nargs, lambda *_: ret_type, call=fn_behavior)
        user_module_func[name] = self
        # builtin_operators[name] = self
        udf.try_init_udf(context)
class cache(ast_node):
    name = 'cache'
    first_order = name
    def init(self, node):
        self.context.has_payload = False
        source = node['cache']['source']
        # lazy = node['cache']['lazy']
        lazy = 0
        try:
            tbl: TableInfo = self.context.tables_byname[source]
        except KeyError:
            raise ValueError(f'Cannot find table {source}.')
        from common.utils import encode_integral
        tbl.cached = True
        schema_string = encode_integral(len(tbl.columns))
        for t in tbl.columns:
            schema_string += t.name + '\0' + \
                encode_integral(aquery_types[t.type.ctype_name])
        # TODO: 1. deal with the situation where the table has aliases
        # TODO: 2. when the table is already cached, use write-back/lazy/eager to deal with changes
        from common.utils import send_to_server
        send_to_server(f'C{source}\0{"l" if lazy else "e"}\0{schema_string}\0')
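        # Sketch of the wire format sent to the cache server (field values
        # hypothetical; <...> denotes encode_integral() output):
        #   'C' <table> '\0' ('l'|'e') '\0' <n_cols> { <col name> '\0' <type id> } '\0'
        # e.g. eagerly caching t(a INT, b FLOAT) sends roughly
        #   Ct\0e\0<2>a\0<INT>b\0<FLOAT>\0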
def include(objs):
    import inspect
    for _, cls in inspect.getmembers(objs):
        if inspect.isclass(cls) and issubclass(cls, ast_node) and type(cls.first_order) is str:
            ast_node.types[cls.first_order] = cls
import sys
include(sys.modules[__name__])
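# Sketch of the dispatch that include() enables (hypothetical driver loop):
# ast_node.types maps each class's first_order key to the class itself, so a
# parsed query dict can be routed by its top-level key, e.g.
#   for key, _ in parsed.items():
#       if key in ast_node.types:
#           ast_node.types[key](None, parsed, context)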