You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
5.7 KiB
128 lines
5.7 KiB
# Code generation for data declaration and manipulation statements (CREATE TABLE, INSERT, LOAD, OUTFILE).
|
|
|
|
from common.orderby import orderby
|
|
from common.ast import ColRef, TableInfo, ast_node, Context, include
|
|
from common.scan import scan
|
|
from common.utils import base62uuid
|
|
|
|
class create_table(ast_node):
    """Generate C++ that declares a table and registers it in ``cxt->tables``.

    ``produce`` accepts either a parsed ``create_table`` statement (a mapping
    with ``name`` and ``columns``) or an already-built ``TableInfo``.

    Without ``cexprs`` every column is only initialised (an empty table).
    With ``cexprs``, ``cexprs[i]`` is called with an index/lineage expression
    (a string, or ``None``). Its result becomes the right-hand side of the
    statement that fills column ``i``.

    A truthy ``lineage`` routes the fill through one of two mechanisms:
    an ``order_by`` result when no scan is active, or a ``bind``-ed lineage
    vector added to the innermost scan.
    """
    name = 'create_table'

    def __init__(self, parent: "ast_node", node, context: Context = None, cexprs = None, lineage = False):
        self.cexprs = cexprs
        self.lineage = lineage
        super().__init__(parent, node, context)

    def produce(self, node):
        # Resolve the target table: reuse a prebuilt TableInfo, or register a
        # new one from the parsed statement.
        if type(node) is TableInfo:
            tbl = node
        else:
            spec = node[self.name]
            tbl = self.context.add_table(spec['name'], spec['columns'])

        tname = tbl.table_name
        template_args = ','.join(col.type for col in tbl.columns)

        # Declare the table in the generated C++ and publish it by name.
        self.emit(f'auto {tname} = new TableInfo<{template_args}>("{tname}", {tbl.n_cols});')
        self.emit(f'cxt->tables.insert({{"{tname}", {tname}}});')
        self.context.tables_in_context[tbl] = tname
        tbl.cxt_name = tname
        # NOTE(review): presumably refer_all() relies on tbl.cxt_name; keep it after the assignment.
        tbl.refer_all()

        if self.cexprs is None:
            # Plain CREATE TABLE: only initialise each column.
            for col in tbl.columns:
                self.emit(f'{col.cxt_name}.init("{col.name}");')
            return

        if len(self.context.scans) > 0:
            # A scan is active: attach the fill code to the innermost scan.
            scanner: scan = self.context.scans[-1]
            if self.lineage:
                lineage_var = 'lineage_' + base62uuid(6)
                counter_var = 'counter_' + base62uuid(6)
                # NOTE(review): this path reads self.datasource, while the
                # no-scan path reads self.parent.datasource; confirm both are intended.
                scanner.add(f'auto {lineage_var} = {self.datasource.cxt_name}->bind({tbl.cxt_name});', "init")
                scanner.add(f'auto {counter_var} = 0;', "init")
                scanner.add(f"{lineage_var}.emplace_back({counter_var}++);", "front")
                # NOTE(review): this value is not passed to cexprs here (they
                # receive scanner.it_var); presumably the caller reads it back. Confirm.
                self.lineage = f"{lineage_var}.rid"
            for idx, col in enumerate(tbl.columns):
                scanner.add(f'{col.cxt_name}.init("{col.name}");', "init")
                scanner.add(f"{col.cxt_name} = {self.cexprs[idx](scanner.it_var)};")
            return

        # No active scan: fill each column 1:1, optionally through an ordering.
        if self.lineage:
            order = 'order_' + base62uuid(6)
            self.emit(f'auto {order} = {self.parent.datasource.cxt_name}->order_by<{orderby(self.parent, self.parent.assumptions).result()}>();')
            self.lineage = '*' + order
        else:
            self.lineage = None
        for idx, col in enumerate(tbl.columns):
            self.emit(f'{col.cxt_name}.init("{col.name}");')
            self.emit(f"{col.cxt_name} = {self.cexprs[idx](self.lineage)};")
        # Reset the single-use lineage and ordering assumptions.
        self.lineage = None
        self.parent.assumptions = None
|
|
|
|
class insert(ast_node):
    """Generate C++ that appends one row of values to an existing table.

    Only the ``INSERT INTO t VALUES (...)`` form is supported, with exactly
    one value per column. Values are appended positionally to the table's
    columns.

    Raises:
        ValueError: the number of values differs from ``table.n_cols``.
        NotImplementedError: a value is not a plain ``value`` entry (e.g. a
            subquery). No code is emitted in that case, so no partially
            appended row is left in the generated program.
    """
    name = 'insert'

    @staticmethod
    def _cpp_value(value):
        """Return the C++ source text for one parsed INSERT value.

        String literals are expected as ``{'literal': str}`` (the shape this
        module also reads for file names and separators). They are emitted as
        an escaped C++ string literal. Any other value is formatted verbatim.
        """
        if isinstance(value, dict) and isinstance(value.get('literal'), str):
            escaped = (value['literal']
                       .replace('\\', '\\\\')
                       .replace('"', '\\"')
                       .replace('\n', '\\n')
                       .replace('\r', '\\r'))
            return f'"{escaped}"'
        return f'{value}'

    def produce(self, node):
        ct = node[self.name]
        table: TableInfo = self.context.tables_byname[ct]

        values = node['query']['select']
        if len(values) != table.n_cols:
            raise ValueError("Column Mismatch")
        # Validate everything before emitting. Skipping an unsupported value
        # would append to only some columns and leave the table ragged.
        for s in values:
            if 'value' not in s:
                # TODO: dispatch subqueries to the select ast_node.
                raise NotImplementedError("INSERT with a subquery is not supported yet")
        table.refer_all()
        for col, s in zip(table.columns, values):
            self.emit(f"{col.cxt_name}.emplace_back({self._cpp_value(s['value'])});")
|
|
|
|
class c(ast_node):
    """Pass the node's raw code snippet straight through to the generated output."""
    name = 'c'

    def produce(self, node):
        snippet = node[self.name]
        self.emit(snippet)
|
|
|
|
class load(ast_node):
    """Generate C++ that loads a CSV file into an existing table.

    The emitted code uses ``io::CSVReader`` (header ``csv.h``). It calls
    ``read_header`` with ``io::ignore_extra_column`` and the table's column
    names, declares one temporary per column, and appends each row read by
    ``read_row`` to the table's columns.
    """
    name = "load"

    def produce(self, node):
        self.context.headers.add('"csv.h"')
        spec = node[self.name]
        table: TableInfo = self.context.tables_byname[spec['table']]
        table.refer_all()

        reader = 'csv_reader_' + base62uuid(6)
        columns = table.columns
        types = [col.type for col in columns]
        temps = ['tmp_' + base62uuid(8) for _ in columns]
        header = ','.join(f'"{col.name}"' for col in columns)
        # NOTE(review): the path goes into a C++ string literal unescaped;
        # quotes or backslashes in it may need escaping. Confirm.
        path = spec["file"]["literal"]

        self.emit(f'io::CSVReader<{len(types)}> {reader}("{path}");')
        self.emit(f'{reader}.read_header(io::ignore_extra_column, {header});')
        for ctype, tmp in zip(types, temps):
            self.emit(f'{ctype} {tmp};')
        self.emit(f'while({reader}.read_row({",".join(temps)})) {{ \n')
        for col, tmp in zip(columns, temps):
            self.emit(f'{col.cxt_name}.emplace_back({tmp});')
        self.emit('}')
|
|
|
|
|
|
class outfile(ast_node):
    """Generate C++ that writes the parent's ``out_table`` to a file.

    ``node`` is either ``{'loc': {'literal': path}, ...}`` or a bare
    ``{'literal': path}``. An optional ``{'term': {'literal': sep}}`` sets the
    separator passed to ``printall``; it defaults to ``','``.

    The generated code opens the file with ``fopen(..., "wb")``. It only
    prints and closes when the open succeeded; otherwise it reports the
    failure with ``perror``.
    """
    name = "_outfile"

    def produce(self, node):
        out_table: TableInfo = self.parent.out_table
        filename = node['loc']['literal'] if 'loc' in node else node['literal']
        sep = ',' if 'term' not in node else node['term']['literal']
        file_pointer = 'fp_' + base62uuid(6)
        self.emit(f'FILE* {file_pointer} = fopen("{filename}", "wb");')
        # fopen returns NULL on failure, and fclose(NULL) is undefined
        # behaviour, so guard both the write and the close.
        self.emit(f'if ({file_pointer}) {{')
        self.emit(f'{out_table.cxt_name}->printall("{sep}", "\\n", nullptr, {file_pointer});')
        self.emit(f'fclose({file_pointer});')
        self.emit('} else {')
        self.emit('perror("outfile: cannot open output file");')
        self.emit('}')
|
|
|
|
|
|
# Pass this module object to common.ast.include().
# NOTE(review): presumably include() registers the ast_node subclasses defined
# in this module (create_table, insert, c, load, outfile) for dispatch. Confirm.
import sys

include(sys.modules[__name__])