Very basic code gen

dev
Bill Sun 3 years ago
parent 49a3fc0a78
commit ee2dc88f06

1
.gitignore vendored

@ -14,3 +14,4 @@ vendor/
.DS_Store
.eggs
.vscode
out.k

@ -5,7 +5,10 @@ Frontend built on top of [mo-sql-parsing](https://github.com/klahnakoski/mo-sql-
## Roadmap
- [x] SQL Parser -> AQuery Parser
- [ ] Data acquisition/output from/to csv file (By Jan. 21)
- -> AQuery-K9 Compiler
- Simple query (By Jan. 21)
- [ ] Nested queries (Jan. 28)
- [ ] -> Optimizing Compiler
# Descriptions from mo-sql-parsing:

@ -1,5 +1,5 @@
from engine.ast import Context, ast_node
import engine.ddl
import engine.ddl, engine.projection
def initialize():
return Context()
@ -8,6 +8,5 @@ def generate(ast, cxt):
for k in ast.keys():
if k in ast_node.types.keys():
root = ast_node.types[k](None, ast, cxt)
__all__ = ["generate"]
__all__ = ["initialize", "generate"]

@ -1,36 +1,53 @@
from typing import List
import uuid
class TableInfo:
def __init__(self, table_name, cols, cxt:'Context'):
# statics
self.table_name = table_name
self.columns = dict() # column_name, type
self.columns_byname = dict() # column_name, type
self.columns = []
for c in cols:
self.columns[c['name']] = ((list(c['type'].keys()))[0], c)
k9name = self.table_name + c['name']
if k9name in cxt.k9cols_byname: # duplicate names?
root = cxt.k9cols_byname[k9name]
k9name = k9name + root[1]
root[1] += 1
cxt.k9cols[c] = k9name
cxt.k9cols_byname[k9name] = (c, 1)
k9name = k9name + root[3]
root[3] += 1
# column: (k9name, type, original col_object, dup_count)
col_object = (k9name, (list(c['type'].keys()))[0], c, 1)
cxt.k9cols_byname[k9name] = col_object
self.columns_byname[c['name']] = col_object
self.columns.append(col_object)
# runtime
self.n_cols = 0 # number of cols
self.n_rows = 0 # number of cols
self.order = [] # assumptions
cxt.tables_byname[self.table_name] = self # construct reverse map
def get_k9colname(self, cxt:'Context', col_name):
return cxt.k9cols[self.columns[col_name][1]] # well, this is gnarly.. will change later
@property
def n_cols(self):
return len(self.columns)
def get_k9colname(self, col_name):
return self.columns_byname[col_name][0]
def parse_tablenames(self, str):
# TODO: deal with alias
return self.get_k9colname(str)
class Context:
def __init__(self):
self.tables:List[TableInfo] = []
self.tables_byname = dict()
self.k9cols = dict()
self.k9cols_byname = dict()
self.udf_map = dict()
self.k9code = ''
def add_table(self, table_name, cols):
@ -38,9 +55,14 @@ class Context:
self.tables.append(tbl)
return tbl
def gen_tmptable(self):
    """Return a new temporary table name: 'tmp' followed by a random base-62 suffix."""
    from engine.utils import base62uuid
    suffix = base62uuid()
    return f'tmp{suffix}'
def emit(self, codelet):
self.k9code += codelet + '\n'
def emit_nonewline(self, codelet):
self.k9code += codelet
def __str__(self):
return self.k9code
@ -48,18 +70,35 @@ class ast_node:
types = dict()
def __init__(self, parent:"ast_node", node, context:Context = None):
self.context = parent.context if context is None else context
self.init(node)
self.produce(node)
self.enumerate(node)
self.spawn(node)
self.consume(node)
def emit(self, code):
self.context.emit(code)
def emit_no_ln(self, code):
self.context.emit_nonewline(code)
name = 'null'
# Each ast node goes through 3 stages (after the `init' setup hook):
# `produce' generates info for child nodes
# `spawn' populates child nodes
# `consume' consumes info from child nodes and finalizes codegen
# Simple operators may not need all of these stages.
def init(self, _):
pass
def produce(self, _):
pass
def enumerate(self, _):
def spawn(self, _):
pass
def consume(self, _):
pass
# include classes in module as first order operators
def include(objs):
    """Register every ast_node subclass found among the members of `objs`.

    Each matching class is stored in `ast_node.types` under its `name`
    class attribute, replacing any entry already stored under that name.
    """
    import inspect
    node_classes = (
        member
        for _, member in inspect.getmembers(objs, inspect.isclass)
        if issubclass(member, ast_node)
    )
    ast_node.types.update({klass.name: klass for klass in node_classes})

@ -1,4 +1,7 @@
from engine.ast import TableInfo, ast_node
# code-gen for data decl languages
from engine.ast import TableInfo, ast_node, include
class create_table(ast_node):
name = 'create_table'
def produce(self, node):
@ -6,17 +9,24 @@ class create_table(ast_node):
tbl = self.context.add_table(ct['name'], ct['columns'])
# create tables in k9
for c in ct['columns']:
self.emit(f"{tbl.get_k9colname((list(c['name'].keys())))[0]}:()")
self.emit(f"{tbl.get_k9colname(c['name'])}:()")
class insert_into(ast_node):
class insert(ast_node):
name = 'insert'
def produce(self, node):
ct = node[self.name]
table:TableInfo = self.context.tables_byname[ct]
import sys, inspect
for name, cls in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(cls) and issubclass(cls, ast_node):
ast_node.types[name] = cls
values = node['query']['select']
if len(values) != table.n_cols:
raise ValueError("Column Mismatch")
for i, s in enumerate(values):
if 'value' in s:
k9name = table.columns[i][0]
self.emit(f"{k9name}:{k9name},{s['value']}")
else:
# subquery, dispatch to select astnode
pass
import sys
include(sys.modules[__name__])

@ -0,0 +1,46 @@
from engine.ast import ast_node
class expr(ast_node):
    """Translate one expression subtree of the parsed query into k9 text.

    The generated text is accumulated in `self.k9expr`; `str(node)` returns it.
    Handled node shapes:
      * dict keyed by a known function name  -> ``func(<arg>)``
      * dict keyed by a binary operator      -> ``(<lhs><op><rhs>)``
      * str                                  -> column name, resolved through
                                                the inherited datasource
      * int / float                          -> numeric literal, emitted as-is
    Anything else is only printed and contributes nothing to the expression.
    """
    name='expr'
    builtin_func_maps = {
        'max': 'max',
        'min': 'min',
    }
    binary_ops = {'sub':'-', 'plus':'+'}
    unary_ops = []

    def __init__(self, parent, node):
        # imported locally, presumably to avoid a circular import with engine.projection
        from engine.projection import projection
        # column names are resolved against the table of the enclosing
        # projection, which nested expressions inherit from their parent
        if type(parent) in [projection, expr]:
            self.datasource = parent.datasource
        else:
            self.datasource = None
        self.udf_map = parent.context.udf_map
        self.k9expr = ''
        # builtin mappings take precedence over UDFs of the same name
        self.func_maps = {**self.udf_map, **self.builtin_func_maps}
        ast_node.__init__(self, parent, node, None)

    def produce(self, node):
        """Build `self.k9expr` from `node`, recursing into operands."""
        if type(node) is dict:
            for key, val in node.items():
                if key in self.func_maps:
                    self.k9expr += f"{self.func_maps[key]}("
                    # if type(val) in [dict, str]:
                    self.k9expr += expr(self, val).k9expr
                    self.k9expr+=')'
                elif key in self.binary_ops:
                    l = expr(self, val[0]).k9expr
                    r = expr(self, val[1]).k9expr
                    self.k9expr += f'({l}{self.binary_ops[key]}{r})'
                    print(f'binary{key}')
                elif key in self.unary_ops:
                    print(f'unary{key}')
                else:
                    print(key)
        elif type(node) is str:
            self.k9expr = self.datasource.parse_tablenames(node)
        elif type(node) in (int, float):
            # numeric literal operand: previously fell through every branch
            # and left an empty operand, e.g. `(price-)`
            self.k9expr = str(node)

    def __str__(self):
        return self.k9expr

@ -0,0 +1,6 @@
from engine.ast import ast_node
class join(ast_node):
    """AST node type named 'join'; it overrides none of the ast_node stages."""
    name = 'join'

@ -0,0 +1,72 @@
from engine.ast import TableInfo, ast_node, Context, include
from engine.join import join
from engine.expr import expr
from engine.utils import base62uuid
class projection(ast_node):
    """Code generator for a `select` node.

    `spawn` resolves the table named in the `from` clause into
    `self.datasource`; `consume` then emits one k9 assignment listing every
    projected column or expression, followed by a line holding just the
    result variable when `disp` is true.
    """
    name='select'

    def __init__(self, parent:ast_node, node, context:Context = None, outname = None, disp = True):
        # disp: whether consume() emits the result variable on its own line
        # outname: output table name; init() generates a temporary one when None
        self.disp = disp
        self.outname = outname
        ast_node.__init__(self, parent, node, context)

    def init(self, _):
        if self.outname is None:
            self.outname = self.context.gen_tmptable()

    def produce(self, node):
        p = node['select']
        # Normalise to a list. The check must look at `p`: the previous
        # `type(projection) == list` inspected the class object, was always
        # false, and wrapped multi-column select lists in a second list.
        self.projections = p if isinstance(p, list) else [p]
        print(node)

    def spawn(self, node):
        """Resolve the `from` clause into `self.datasource`.

        Raises ValueError when a `from` clause is present but no table could
        be resolved from it, and KeyError when the named table is unknown.
        """
        self.datasource = None
        if 'from' in node:
            from_clause = node['from']
            if type(from_clause) is list:
                # from joins
                # NOTE(review): no datasource is set on this path, so the
                # check at the end of this method raises
                join(self, from_clause)
            elif type(from_clause) is dict:
                if 'value' in from_clause:
                    value = from_clause['value']
                    if type(value) is dict:
                        if 'select' in value:
                            # from subquery
                            # NOTE(review): `from_clause` itself has no 'select'
                            # key here (it is under 'value'); passing `value`
                            # looks intended - confirm
                            projection(self, from_clause, disp = False)
                        else:
                            # TODO: from func over table
                            print(f"from func over table{node}")
                    elif type(value) is str:
                        self.datasource = self.context.tables_byname[value]
                if 'assumptions' in from_clause:
                    ascending = from_clause['assumptions']['ord'] == 'asc'
                    order_op = '^' if ascending else '|^'
                    # TODO: generate view of table by order
            elif type(from_clause) is str:
                self.datasource = self.context.tables_byname[from_clause]

            if self.datasource is None:
                raise ValueError('spawn error: from clause')

    def consume(self, node):
        """Emit `<tmp>:(<proj>;<proj>;...)` and, if `disp`, the variable name."""
        disp_varname = 'disptmp' + base62uuid()
        self.emit_no_ln(f'{disp_varname}:(')
        for proj in self.projections:
            if type(proj) is dict:
                if 'value' in proj:
                    e = proj['value']
                    if type(e) is str:
                        # plain column reference
                        self.emit_no_ln(f"{self.datasource.parse_tablenames(proj['value'])};")
                    elif type(e) is dict:
                        # computed expression
                        self.emit_no_ln(f"{expr(self, e).k9expr};")
        self.emit(')')
        if self.disp:
            self.emit(disp_varname)
import sys
include(sys.modules[__name__])

@ -0,0 +1,12 @@
import uuid
def base62uuid(crop=8):
    """Return a random identifier made of base-62 digits.

    A fresh UUID4 is converted to base 62 (digits, then lowercase, then
    uppercase letters) and the leading `crop` characters are returned.
    """
    digits = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
    remaining = uuid.uuid4().int
    encoded = []
    while remaining:
        remaining, idx = divmod(remaining, 62)
        encoded.append(digits[idx])
    # digits were collected least-significant first
    text = ''.join(reversed(encoded))
    if not text:
        return '0'
    return text[:crop]

@ -16,8 +16,14 @@ print(res)
while test_parser:
try:
q = input()
if q == 'break':
break
if q == 'exec':
cxt = engine.initialize()
for s in stmts['stmts']:
engine.generate(s, cxt)
print(cxt.k9code)
with open('out.k', 'wb') as outfile:
outfile.write(cxt.k9code)
continue
trimed = ws.sub(' ', q.lower()).split(' ')
if trimed[0] == 'file':
fn = 'q.sql' if len(trimed) <= 1 or len(trimed[1]) == 0 \
@ -32,6 +38,3 @@ while test_parser:
except Exception as e:
print(type(e), e)
cxt = engine.initialize()
for s in stmts['stmts']:
engine.generate(s, cxt)

@ -17,6 +17,9 @@ INSERT INTO stocks VALUES(14,5)
INSERT INTO stocks VALUES(15,2)
INSERT INTO stocks VALUES(16,5)
SELECT max(price-mins(price))
SELECT max(price-min(timestamp)) FROM stocks
/*SELECT max(price-mins(price))
FROM stocks
ASSUMING ASC timestamp
ASSUMING ASC timestamp*/

Loading…
Cancel
Save