make ext_engine: duckdb to work

master
Bill 2 years ago
parent 200dc71aad
commit 7c5440c4fb

@ -5,7 +5,7 @@ Defines =
CC = $(CXX) -xc CC = $(CXX) -xc
CXXFLAGS = --std=c++2a CXXFLAGS = --std=c++2a
ifeq ($(AQ_DEBUG), 1) ifeq ($(AQ_DEBUG), 1)
OPTFLAGS = -g3 #-static-libasan -fsanitize=address OPTFLAGS = -g3 #-static-libsan -fsanitize=address
LINKFLAGS = LINKFLAGS =
else else
OPTFLAGS = -Ofast -DNDEBUG -fno-stack-protector OPTFLAGS = -Ofast -DNDEBUG -fno-stack-protector
@ -17,12 +17,15 @@ _COMPILER = $(shell $(CXX) --version | grep -q clang && echo clang|| echo gcc)
COMPILER = $(strip $(_COMPILER)) COMPILER = $(strip $(_COMPILER))
LIBTOOL = ar rcs LIBTOOL = ar rcs
USELIB_FLAG = -Wl,--whole-archive,libaquery.a -Wl,-no-whole-archive USELIB_FLAG = -Wl,--whole-archive,libaquery.a -Wl,-no-whole-archive
LIBAQ_SRC = server/monetdb_conn.cpp server/libaquery.cpp LIBAQ_SRC = server/monetdb_conn.cpp server/duckdb_conn.cpp server/libaquery.cpp
LIBAQ_OBJ = monetdb_conn.o libaquery.o monetdb_ext.o LIBAQ_OBJ = monetdb_conn.o duckdb_conn.o libaquery.o monetdb_ext.o
SEMANTIC_INTERPOSITION = -fno-semantic-interposition SEMANTIC_INTERPOSITION = -fno-semantic-interposition
RANLIB = ranlib RANLIB = ranlib
_LINKER_BINARY = $(shell `$(CXX) -print-prog-name=ld` -v 2>&1 | grep -q LLVM && echo lld || echo ld) _LINKER_BINARY = $(shell `$(CXX) -print-prog-name=ld` -v 2>&1 | grep -q LLVM && echo lld || echo ld)
LINKER_BINARY = $(strip $(_LINKER_BINARY)) LINKER_BINARY = $(strip $(_LINKER_BINARY))
DuckDB_LIB = -Ldeps -lduckdb
DuckDB_INC = -Ideps
ifeq ($(LINKER_BINARY), ld) ifeq ($(LINKER_BINARY), ld)
LINKER_FLAGS = -Wl,--allow-multiple-definition LINKER_FLAGS = -Wl,--allow-multiple-definition
else else
@ -58,6 +61,7 @@ ifeq ($(OS),Windows_NT)
LIBAQ_OBJ += winhelper.o LIBAQ_OBJ += winhelper.o
MonetDB_LIB += msc-plugin/monetdbe.dll MonetDB_LIB += msc-plugin/monetdbe.dll
MonetDB_INC += -Imonetdb/msvc MonetDB_INC += -Imonetdb/msvc
LIBTOOL = gcc-ar rcs LIBTOOL = gcc-ar rcs
ifeq ($(COMPILER), clang) ifeq ($(COMPILER), clang)
FPIC = FPIC =
@ -96,8 +100,8 @@ ifeq ($(AQUERY_ITC_USE_SEMPH), 1)
Defines += -D__AQUERY_ITC_USE_SEMPH__ Defines += -D__AQUERY_ITC_USE_SEMPH__
endif endif
CXXFLAGS += $(OPTFLAGS) $(Defines) $(MonetDB_INC) CXXFLAGS += $(OPTFLAGS) $(Defines) $(MonetDB_INC) $(DuckDB_INC)
BINARYFLAGS = $(CXXFLAGS) $(LINKFLAGS) $(MonetDB_LIB) BINARYFLAGS = $(CXXFLAGS) $(LINKFLAGS) $(MonetDB_LIB) $(DuckDB_LIB)
SHAREDFLAGS += $(FPIC) $(BINARYFLAGS) SHAREDFLAGS += $(FPIC) $(BINARYFLAGS)
info: info:

@ -2,7 +2,7 @@
## GLOBAL CONFIGURATION FLAGS ## GLOBAL CONFIGURATION FLAGS
version_string = '0.7.5a' version_string = '0.7.6a'
add_path_to_ldpath = True add_path_to_ldpath = True
rebuild_backend = False rebuild_backend = False
run_backend = True run_backend = True
@ -12,6 +12,8 @@ msbuildroot = ''
os_platform = 'unknown' os_platform = 'unknown'
build_driver = 'Auto' build_driver = 'Auto'
compilation_output = True compilation_output = True
compile_use_gc = True
compile_use_threading = True
## END GLOBAL CONFIGURATION FLAGS ## END GLOBAL CONFIGURATION FLAGS

@ -75,14 +75,15 @@ class build_manager:
sourcefiles = [ sourcefiles = [
'build.py', 'Makefile', 'build.py', 'Makefile',
'server/server.cpp', 'server/libaquery.cpp', 'server/server.cpp', 'server/libaquery.cpp',
'server/monetdb_conn.cpp', 'server/threading.cpp', 'server/monetdb_conn.cpp', 'server/duckdb_conn.cpp',
'server/winhelper.cpp', 'server/monetdb_ext.c' 'server/threading.cpp', 'server/winhelper.cpp',
'server/monetdb_ext.c'
] ]
headerfiles = ['server/aggregations.h', 'server/hasher.h', 'server/io.h', headerfiles = ['server/aggregations.h', 'server/hasher.h', 'server/io.h',
'server/libaquery.h', 'server/monetdb_conn.h', 'server/pch.hpp', 'server/libaquery.h', 'server/monetdb_conn.h', 'server/duckdb_conn.h',
'server/table.h', 'server/threading.h', 'server/types.h', 'server/utils.h', 'server/pch.hpp', 'server/table.h', 'server/threading.h',
'server/winhelper.h', 'server/gc.h', 'server/vector_type.hpp', 'server/types.h', 'server/utils.h', 'server/winhelper.h',
'server/table_ext_monetdb.hpp' 'server/gc.h', 'server/vector_type.hpp', 'server/table_ext_monetdb.hpp'
] ]
class DriverBase: class DriverBase:

@ -229,6 +229,7 @@ class Context:
self.removing_scan = False self.removing_scan = False
def __init__(self): def __init__(self):
from prompt import PromptState
self.tables:list[TableInfo] = [] self.tables:list[TableInfo] = []
self.tables_byname = dict() self.tables_byname = dict()
self.ccols_byname = dict() self.ccols_byname = dict()
@ -252,6 +253,9 @@ class Context:
self.ds_stack = [] self.ds_stack = []
self.scans = [] self.scans = []
self.removing_scan = False self.removing_scan = False
self.force_compiled = True
self.system_state: Optional[PromptState] = None
def add_table(self, table_name, cols): def add_table(self, table_name, cols):
tbl = TableInfo(table_name, cols, self) tbl = TableInfo(table_name, cols, self)
self.tables.append(tbl) self.tables.append(tbl)

@ -31,7 +31,7 @@ class Types:
self.name = name self.name = name
self.cname = defval(cname, name.lower() + '_t') self.cname = defval(cname, name.lower() + '_t')
self.sqlname = defval(sqlname, name.upper()) self.sqlname = defval(sqlname, name.upper())
self.ctype_name = defval(ctype_name, f'types::{name.upper()}') self.ctype_name = defval(ctype_name, f'types::A{name.upper()}')
self.null_value = defval(null_value, 0) self.null_value = defval(null_value, 0)
self.cast_to_dict = defval(cast_to, dict()) self.cast_to_dict = defval(cast_to, dict())
self.cast_from_dict = defval(cast_from, dict()) self.cast_from_dict = defval(cast_from, dict())
@ -102,7 +102,7 @@ LongT = Types(4, name = 'int64', sqlname = 'BIGINT', fp_type = DoubleT)
BoolT = Types(0, name = 'bool', cname='bool', sqlname = 'BOOL', long_type=LongT, fp_type=FloatT) BoolT = Types(0, name = 'bool', cname='bool', sqlname = 'BOOL', long_type=LongT, fp_type=FloatT)
ByteT = Types(1, name = 'int8', sqlname = 'TINYINT', long_type=LongT, fp_type=FloatT) ByteT = Types(1, name = 'int8', sqlname = 'TINYINT', long_type=LongT, fp_type=FloatT)
ShortT = Types(2, name = 'int16', sqlname='SMALLINT', long_type=LongT, fp_type=FloatT) ShortT = Types(2, name = 'int16', sqlname='SMALLINT', long_type=LongT, fp_type=FloatT)
IntT = Types(3, name = 'int', cname = 'int', long_type=LongT, fp_type=FloatT) IntT = Types(3, name = 'int', cname = 'int', long_type=LongT, ctype_name = 'types::AINT32', fp_type=FloatT)
ULongT = Types(8, name = 'uint64', sqlname = 'UINT64', fp_type=DoubleT) ULongT = Types(8, name = 'uint64', sqlname = 'UINT64', fp_type=DoubleT)
UIntT = Types(7, name = 'uint32', sqlname = 'UINT32', long_type=ULongT, fp_type=FloatT) UIntT = Types(7, name = 'uint32', sqlname = 'UINT32', long_type=ULongT, fp_type=FloatT)
UShortT = Types(6, name = 'uint16', sqlname = 'UINT16', long_type=ULongT, fp_type=FloatT) UShortT = Types(6, name = 'uint16', sqlname = 'UINT16', long_type=ULongT, fp_type=FloatT)

@ -20,7 +20,7 @@ __AQEXPORT__(int) action(Context* cxt) {
if (fit_inc == nullptr) if (fit_inc == nullptr)
fit_inc = (decltype(fit_inc))(cxt->get_module_function("fit_inc")); fit_inc = (decltype(fit_inc))(cxt->get_module_function("fit_inc"));
auto server = static_cast<Server*>(cxt->alt_server); auto server = static_cast<DataSource*>(cxt->alt_server);
auto len = uint32_t(monetdbe_get_size(*((void**)server->server), "source")); auto len = uint32_t(monetdbe_get_size(*((void**)server->server), "source"));
auto x_1bN = ColRef<vector_type<double>>(len, monetdbe_get_col(*((void**)(server->server)), "source", 0)); auto x_1bN = ColRef<vector_type<double>>(len, monetdbe_get_col(*((void**)(server->server)), "source", 0));
auto y_6uX = ColRef<int64_t>(len, monetdbe_get_col(*((void**)(server->server)), "source", 1)); auto y_6uX = ColRef<int64_t>(len, monetdbe_get_col(*((void**)(server->server)), "source", 1));

@ -23,7 +23,7 @@ __AQEXPORT__(int) ld(Context* cxt) {
else else
++cnt; ++cnt;
char data_name[] = "data/electricity/electricity "; char data_name[] = "data/electricity/electricity ";
auto server = static_cast<Server*>(cxt->alt_server); auto server = static_cast<DataSource*>(cxt->alt_server);
const char* names_fZrv[] = {"x", "y"}; const char* names_fZrv[] = {"x", "y"};
auto tbl_6erF = new TableInfo<vector_type<double>,int64_t>("source", names_fZrv); auto tbl_6erF = new TableInfo<vector_type<double>,int64_t>("source", names_fZrv);
decltype(auto) c_31ju0e = tbl_6erF->get_col<0>(); decltype(auto) c_31ju0e = tbl_6erF->get_col<0>();

@ -16,7 +16,7 @@ __AQEXPORT__(void) __AQ_Init_GC__(Context* cxt) {
__AQEXPORT__(int) query(Context* cxt) { __AQEXPORT__(int) query(Context* cxt) {
using namespace std; using namespace std;
using namespace types; using namespace types;
auto server = static_cast<Server*>(cxt->alt_server); auto server = static_cast<DataSource*>(cxt->alt_server);
static uint32_t old_sz = 0; static uint32_t old_sz = 0;
constexpr static uint32_t min_delta = 200; constexpr static uint32_t min_delta = 200;
auto newsz = monetdbe_get_size(*(void**) server->server, "source"); auto newsz = monetdbe_get_size(*(void**) server->server, "source");

@ -1,4 +1,4 @@
from reconstruct.ast import Context, ast_node from engine.ast import Context, ast_node
saved_cxt = None saved_cxt = None

@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union
from common.types import * from common.types import *
from common.utils import (base62alp, base62uuid, enlist, from common.utils import (base62alp, base62uuid, enlist,
get_innermost, get_legal_name) get_innermost, get_legal_name)
from reconstruct.storage import ColRef, Context, TableInfo from engine.storage import ColRef, Context, TableInfo
class ast_node: class ast_node:
header = [] header = []
@ -51,7 +51,7 @@ class ast_node:
self.emit(self.sql+';\n') self.emit(self.sql+';\n')
self.context.sql_end() self.context.sql_end()
from reconstruct.expr import expr, fastscan from engine.expr import expr, fastscan
class SubqType(Enum): class SubqType(Enum):
WITH = auto() WITH = auto()
FROM = auto() FROM = auto()
@ -328,7 +328,7 @@ class projection(ast_node):
for v, idx in self.var_table.items(): for v, idx in self.var_table.items():
vname = get_legal_name(v) + '_' + base62uuid(3) vname = get_legal_name(v) + '_' + base62uuid(3)
self.pyname2cname[v] = vname self.pyname2cname[v] = vname
self.context.emitc(f'auto {vname} = ColRef<{typenames[idx].cname}>({length_name}, server->getCol({idx}));') self.context.emitc(f'auto {vname} = ColRef<{typenames[idx].cname}>({length_name}, server->getCol({idx}, {typenames[idx].ctype_name}));')
vid2cname[idx] = vname vid2cname[idx] = vname
# Create table into context # Create table into context
out_typenames = [None] * len(proj_map) out_typenames = [None] * len(proj_map)
@ -463,7 +463,7 @@ class select_into(ast_node):
raise Exception('No out_table found.') raise Exception('No out_table found.')
else: else:
self.context.headers.add('"./server/table_ext_monetdb.hpp"') self.context.headers.add('"./server/table_ext_monetdb.hpp"')
self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->alt_server, \"{node.lower()}\");' self.ccode = f'{self.parent.out_table.contextname_cpp}->monetdb_append_table(cxt->curr_server, \"{node.lower()}\");'
def produce_sql(self, node): def produce_sql(self, node):
self.context.sql = self.context.sql.replace( self.context.sql = self.context.sql.replace(
@ -1252,6 +1252,7 @@ class load(ast_node):
name="load" name="load"
first_order = name first_order = name
def init(self, node): def init(self, node):
from prompt import Backend_Type
self.module = False self.module = False
if node['load']['file_type'] == 'module': if node['load']['file_type'] == 'module':
self.produce = self.produce_module self.produce = self.produce_module
@ -1259,8 +1260,10 @@ class load(ast_node):
elif 'complex' in node['load']: elif 'complex' in node['load']:
self.produce = self.produce_cpp self.produce = self.produce_cpp
self.consume = lambda *_: None self.consume = lambda *_: None
elif self.context.dialect == 'MonetDB': elif self.context.system_state.cfg.backend_type == Backend_Type.BACKEND_MonetDB.value:
self.produce = self.produce_monetdb self.produce = self.produce_monetdb
elif self.context.system_state.cfg.backend_type == Backend_Type.BACKEND_DuckDB.value:
self.produce = self.produce_duckdb
else: else:
self.produce = self.produce_aq self.produce = self.produce_aq
if self.parent is None: if self.parent is None:
@ -1327,7 +1330,16 @@ class load(ast_node):
self.sql = f'{s1} \'{p}\' {s2} ' self.sql = f'{s1} \'{p}\' {s2} '
if 'term' in node: if 'term' in node:
self.sql += f' {s3} \'{node["term"]["literal"]}\'' self.sql += f' {s3} \'{node["term"]["literal"]}\''
def produce_duckdb(self, node):
node = node['load']
s1 = f'COPY {node["table"]} FROM '
import os
p = os.path.abspath(node['file']['literal']).replace('\\', '/')
s2 = f" DELIMITER '{node['term']['literal']}', " if 'term' in node else ''
self.sql = f'{s1} \'{p}\' ( {s2}HEADER )'
def produce_cpp(self, node): def produce_cpp(self, node):
self.context.has_dll = True self.context.has_dll = True
self.context.headers.add('"csv.h"') self.context.headers.add('"csv.h"')
@ -1374,7 +1386,7 @@ class load(ast_node):
self.context.emitc('}') self.context.emitc('}')
# self.context.emitc(f'print(*{self.out_table});') # self.context.emitc(f'print(*{self.out_table});')
self.context.emitc(f'{self.out_table}->monetdb_append_table(cxt->alt_server, "{table.table_name}");') self.context.emitc(f'{self.out_table}->monetdb_append_table(cxt->curr_server, "{table.table_name}");')
self.context.postproc_end(self.postproc_fname) self.context.postproc_end(self.postproc_fname)
@ -1424,7 +1436,11 @@ class outfile(ast_node):
file_pointer = 'fp_' + base62uuid(6) file_pointer = 'fp_' + base62uuid(6)
self.addc(f'FILE* {file_pointer} = fopen("{filename}", "wb");') self.addc(f'FILE* {file_pointer} = fopen("{filename}", "wb");')
self.addc(f'{self.parent.out_table.contextname_cpp}->printall("{sep}", "\\n", nullptr, {file_pointer});') self.addc(f'{self.parent.out_table.contextname_cpp}->printall("{sep}", "\\n", nullptr, {file_pointer});')
self.addc(f'fclose({file_pointer});') if self.context.use_gc:
self.addc(f'GC::gc_handle->reg({file_pointer}, 65536, [](void* fp){{fclose((FILE*)fp);}});')
else:
self.addc(f'fclose({file_pointer});')
self.context.ccode += self.ccode self.context.ccode += self.ccode
class udf(ast_node): class udf(ast_node):

@ -1,8 +1,8 @@
from typing import Optional, Set from typing import Optional, Set
from common.types import * from common.types import *
from reconstruct.ast import ast_node from engine.ast import ast_node
from reconstruct.storage import ColRef, Context from engine.storage import ColRef, Context
# TODO: Decouple expr and upgrade architecture # TODO: Decouple expr and upgrade architecture
# C_CODE : get ccode/sql code? # C_CODE : get ccode/sql code?
@ -31,7 +31,7 @@ class expr(ast_node):
return self._udf_decltypecall is not None return self._udf_decltypecall is not None
def __init__(self, parent, node, *, c_code = None, supress_undefined = False): def __init__(self, parent, node, *, c_code = None, supress_undefined = False):
from reconstruct.ast import projection, udf from engine.ast import projection, udf
# gen2 expr have multi-passes # gen2 expr have multi-passes
# first pass parse json into expr tree # first pass parse json into expr tree
@ -80,7 +80,7 @@ class expr(ast_node):
ast_node.__init__(self, parent, node, None) ast_node.__init__(self, parent, node, None)
def init(self, _): def init(self, _):
from reconstruct.ast import _tmp_join_union, projection from engine.ast import _tmp_join_union, projection
parent = self.parent parent = self.parent
self.is_compound = parent.is_compound if type(parent) is expr else False self.is_compound = parent.is_compound if type(parent) is expr else False
if type(parent) in [projection, expr, _tmp_join_union]: if type(parent) in [projection, expr, _tmp_join_union]:
@ -96,7 +96,7 @@ class expr(ast_node):
def produce(self, node): def produce(self, node):
from common.utils import enlist from common.utils import enlist
from reconstruct.ast import udf, projection from engine.ast import udf, projection
if type(node) is dict: if type(node) is dict:
if 'literal' in node: if 'literal' in node:
@ -349,7 +349,7 @@ class expr(ast_node):
self.sql = f'{{"CAST({node} AS DOUBLE)" if not c_code else "{node}f"}}' self.sql = f'{{"CAST({node} AS DOUBLE)" if not c_code else "{node}f"}}'
def finalize(self, override = False): def finalize(self, override = False):
from reconstruct.ast import udf from engine.ast import udf
if self.codebuf is None or override: if self.codebuf is None or override:
self.codebuf = '' self.codebuf = ''
for c in self.codlets: for c in self.codlets:

@ -1,7 +1,7 @@
import abc import abc
from reconstruct.ast import ast_node from engine.ast import ast_node
from typing import Optional from typing import Optional
from reconstruct.storage import Context, ColRef from engine.storage import Context, ColRef
from common.utils import enlist from common.utils import enlist
from common.types import builtin_func, user_module_func, builtin_operators from common.types import builtin_func, user_module_func, builtin_operators
@ -47,7 +47,7 @@ class expr_base(ast_node, metaclass = abc.ABCMeta):
pass pass
def produce(self, node): def produce(self, node):
from reconstruct.ast import udf from engine.ast import udf
if node and type(node) is dict: if node and type(node) is dict:
if 'litral' in node: if 'litral' in node:
self.get_literal(node['literal']) self.get_literal(node['literal'])

@ -1,4 +1,4 @@
from typing import Dict, List, Set from typing import Dict, List, Optional, Set
from common.types import * from common.types import *
from common.utils import CaseInsensitiveDict, base62uuid, enlist from common.utils import CaseInsensitiveDict, base62uuid, enlist
@ -64,7 +64,7 @@ class ColRef:
class TableInfo: class TableInfo:
def __init__(self, table_name, cols, cxt:'Context'): def __init__(self, table_name, cols, cxt:'Context'):
from reconstruct.ast import create_trigger from engine.ast import create_trigger
# statics # statics
self.table_name : str = table_name self.table_name : str = table_name
self.contextname_cpp : str = '' self.contextname_cpp : str = ''
@ -161,8 +161,10 @@ class Context:
self.has_dll = False self.has_dll = False
self.triggers_active.clear() self.triggers_active.clear()
def __init__(self): def __init__(self, state = None):
from prompt import PromptState
from .ast import create_trigger from .ast import create_trigger
from aquery_config import compile_use_gc
self.tables_byname : Dict[str, TableInfo] = dict() self.tables_byname : Dict[str, TableInfo] = dict()
self.col_byname = dict() self.col_byname = dict()
self.tables : Set[TableInfo] = set() self.tables : Set[TableInfo] = set()
@ -181,6 +183,10 @@ class Context:
self.triggers : Dict[str, create_trigger] = dict() self.triggers : Dict[str, create_trigger] = dict()
self.triggers_active = set() self.triggers_active = set()
self.stored_proceudres = dict() self.stored_proceudres = dict()
self.force_compiled = False
self.use_gc = compile_use_gc
self.system_state: Optional[PromptState] = state
# self.new() called everytime new query batch is started # self.new() called everytime new query batch is started
def get_scan_var(self): def get_scan_var(self):
@ -206,7 +212,7 @@ class Context:
function_head = ('(Context* cxt) {\n' + function_head = ('(Context* cxt) {\n' +
'\tusing namespace std;\n' + '\tusing namespace std;\n' +
'\tusing namespace types;\n' + '\tusing namespace types;\n' +
'\tauto server = static_cast<Server*>(cxt->alt_server);\n') '\tauto server = static_cast<DataSource*>(cxt->curr_server);\n')
udf_head = ('#pragma once\n' udf_head = ('#pragma once\n'
'#include \"./server/libaquery.h\"\n' '#include \"./server/libaquery.h\"\n'
@ -265,7 +271,7 @@ class Context:
'O' + limit + sep + end) 'O' + limit + sep + end)
def remove_trigger(self, name : str): def remove_trigger(self, name : str):
from reconstruct.ast import create_trigger from engine.ast import create_trigger
val = self.triggers.pop(name, None) val = self.triggers.pop(name, None)
if val.type == create_trigger.Type.Callback: if val.type == create_trigger.Type.Callback:
val.table.triggers.remove(val) val.table.triggers.remove(val)

@ -17,7 +17,7 @@ __AQEXPORT__(void) __AQ_Init_GC__(Context* cxt) {
__AQEXPORT__(int) dll_2Cxoox(Context* cxt) { __AQEXPORT__(int) dll_2Cxoox(Context* cxt) {
using namespace std; using namespace std;
using namespace types; using namespace types;
auto server = static_cast<Server*>(cxt->alt_server); auto server = static_cast<DataSource*>(cxt->alt_server);
auto len_4ycjiV = server->cnt; auto len_4ycjiV = server->cnt;
auto mont_8AE = ColRef<const char*>(len_4ycjiV, server->getCol(0)); auto mont_8AE = ColRef<const char*>(len_4ycjiV, server->getCol(0));
auto sales_2RB = ColRef<int>(len_4ycjiV, server->getCol(1)); auto sales_2RB = ColRef<int>(len_4ycjiV, server->getCol(1));

@ -345,6 +345,7 @@
<ItemGroup> <ItemGroup>
<ClInclude Include="..\csv.h" /> <ClInclude Include="..\csv.h" />
<ClInclude Include="..\server\aggregations.h" /> <ClInclude Include="..\server\aggregations.h" />
<ClInclude Include="..\server\DataSource_conn.h" />
<ClInclude Include="..\server\duckdb_conn.h" /> <ClInclude Include="..\server\duckdb_conn.h" />
<ClInclude Include="..\server\gc.h" /> <ClInclude Include="..\server\gc.h" />
<ClInclude Include="..\server\hasher.h" /> <ClInclude Include="..\server\hasher.h" />

@ -80,7 +80,6 @@ if __name__ == '__main__':
if check_param(['-h', '--help'], True): if check_param(['-h', '--help'], True):
print(help_message) print(help_message)
exit() exit()
import atexit import atexit
@ -95,7 +94,7 @@ import sys
import threading import threading
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, List, Optional from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
from mo_parsing import ParseException from mo_parsing import ParseException
@ -104,10 +103,10 @@ import aquery_parser as parser
import common import common
import common.ddl import common.ddl
import common.projection import common.projection
import reconstruct as xengine import engine as xengine
from build import build_manager from build import build_manager
from common.utils import add_dll_dir, base62uuid, nullstream, ws from common.utils import add_dll_dir, base62uuid, nullstream, ws
from enum import auto
## CLASSES BEGIN ## CLASSES BEGIN
class RunType(enum.Enum): class RunType(enum.Enum):
@ -115,9 +114,19 @@ class RunType(enum.Enum):
IPC = 1 IPC = 1
class Backend_Type(enum.Enum): class Backend_Type(enum.Enum):
BACKEND_AQuery = 0 BACKEND_AQuery = 0
BACKEND_MonetDB = 1 BACKEND_MonetDB = 1
BACKEND_MariaDB = 2 BACKEND_MariaDB = 2
BACKEND_DuckDB = 3
BACKEND_SQLite = 4
BACKEND_TOTAL = 5
backend_strings = {
'aquery': Backend_Type.BACKEND_AQuery,
'monetdb': Backend_Type.BACKEND_MonetDB,
'mariadb': Backend_Type.BACKEND_MariaDB,
'duckdb': Backend_Type.BACKEND_DuckDB,
'sqlite': Backend_Type.BACKEND_SQLite,
}
class StoredProcedure(ctypes.Structure): class StoredProcedure(ctypes.Structure):
_fields_ = [ _fields_ = [
@ -242,7 +251,7 @@ class PromptState():
set_ready = lambda: None set_ready = lambda: None
get_ready = lambda: None get_ready = lambda: None
server_status = lambda: False server_status = lambda: False
cfg : Config = None cfg : Optional[Config] = None
shm : str = '' shm : str = ''
server : subprocess.Popen = None server : subprocess.Popen = None
basecmd : List[str] = None basecmd : List[str] = None
@ -257,6 +266,26 @@ class PromptState():
currstats : Optional[QueryStats] = None currstats : Optional[QueryStats] = None
buildmgr : Optional[build_manager]= None buildmgr : Optional[build_manager]= None
current_procedure : Optional[str] = None current_procedure : Optional[str] = None
_force_compiled : bool = False
_cxt : Optional[Union[xengine.Context, common.Context]] = None
@property
def force_compiled(self):
return self._force_compiled
@force_compiled.setter
def force_compiled(self, new_val):
self.cxt.force_compiled = new_val
self._force_compiled = new_val
@property
def cxt(self):
return self._cxt
@cxt.setter
def cxt(self, cxt):
cxt.force_compiled = self.force_compiled
self._cxt = cxt
self._cxt.system_state = self
## CLASSES END ## CLASSES END
## FUNCTIONS BEGIN ## FUNCTIONS BEGIN
@ -412,7 +441,8 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
q = '' q = ''
payload = None payload = None
keep = True keep = True
cxt = common.initialize()
state.cxt = cxt = xengine.initialize()
parser.parse('SELECT "**** WELCOME TO AQUERY++! ****";') parser.parse('SELECT "**** WELCOME TO AQUERY++! ****";')
# state.currstats = QueryStats() # state.currstats = QueryStats()
@ -442,7 +472,7 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
continue continue
if False and q == 'exec': # generate build and run (AQuery Engine) if False and q == 'exec': # generate build and run (AQuery Engine)
state.cfg.backend_type = Backend_Type.BACKEND_AQuery.value state.cfg.backend_type = Backend_Type.BACKEND_AQuery.value
cxt = common.exec(state.stmts, cxt, keep) state.cxt = cxt = common.exec(state.stmts, cxt, keep)
if state.buildmgr.build_dll() == 0: if state.buildmgr.build_dll() == 0:
state.set_ready() state.set_ready()
continue continue
@ -466,8 +496,8 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
print(prompt_help) print(prompt_help)
continue continue
elif q.startswith('xexec') or q.startswith('exec'): # generate build and run (MonetDB Engine) elif q.startswith('xexec') or q.startswith('exec'): # generate build and run (MonetDB Engine)
state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value #state.cfg.backend_type = Backend_Type.BACKEND_MonetDB.value
cxt = xengine.exec(state.stmts, cxt, keep, parser.parse) state.cxt = cxt = xengine.exec(state.stmts, cxt, keep, parser.parse)
this_udf = cxt.finalize_udf() this_udf = cxt.finalize_udf()
if this_udf: if this_udf:
@ -659,6 +689,20 @@ def prompt(running = lambda:True, next = lambda:input('> '), state : Optional[Pr
else: else:
print(procedure_help) print(procedure_help)
continue continue
elif q.startswith('force'):
splits = q.split()
if len(splits > 1) and splits[1] == 'compiled':
state.force_compiled = True
cxt.force_compiled = True
continue
elif q.startswith('backend'):
splits = q.split()
if len(splits) > 1 and splits[1] in backend_strings:
state.cfg.backend_type = backend_strings[splits[1]].value
else:
cxt.Error('Not a valid backend type.')
print('External Engine is set to', Backend_Type(state.cfg.backend_type).name)
continue
trimed = ws.sub(' ', og_q).split(' ') trimed = ws.sub(' ', og_q).split(' ')
if len(trimed) > 1 and trimed[0].lower().startswith('fi') or trimed[0].lower() == 'f': if len(trimed) > 1 and trimed[0].lower().startswith('fi') or trimed[0].lower() == 'f':
fn = 'stock.a' if len(trimed) <= 1 or len(trimed[1]) == 0 \ fn = 'stock.a' if len(trimed) <= 1 or len(trimed[1]) == 0 \

@ -2,24 +2,32 @@
#define __DATASOURCE_CONN_H__ #define __DATASOURCE_CONN_H__
struct Context; struct Context;
#ifndef __AQQueryResult__
#define __AQQueryResult__ 1
struct AQQueryResult { struct AQQueryResult {
void* res; void* res;
unsigned ref; unsigned ref;
}; };
enum DataSourceType { #endif
Invalid,
MonetDB, #ifndef __AQBACKEND_TYPE__
MariaDB, #define __AQBACKEND_TYPE__ 1
DuckDB, enum Backend_Type {
SQLite BACKEND_AQuery,
BACKEND_MonetDB,
BACKEND_MariaDB,
BACKEND_DuckDB,
BACKEND_SQLite,
BACKEND_TOTAL
}; };
#endif
struct DataSource { struct DataSource {
void* server = nullptr; void* server = nullptr;
Context* cxt = nullptr; Context* cxt = nullptr;
bool status = false; bool status = false;
char* query = nullptr; char* query = nullptr;
DataSourceType type = Invalid; Backend_Type DataSourceType = BACKEND_AQuery;
void* res = nullptr; void* res = nullptr;
void* ret_col = nullptr; void* ret_col = nullptr;
@ -29,7 +37,7 @@ struct DataSource {
void* handle; void* handle;
DataSource() = default; DataSource() = default;
explicit DataSource(Context* cxt = nullptr) = delete; explicit DataSource(Context* cxt) = delete;
virtual void connect(Context* cxt) = 0; virtual void connect(Context* cxt) = 0;
virtual void exec(const char* q) = 0; virtual void exec(const char* q) = 0;
@ -38,6 +46,10 @@ struct DataSource {
virtual void close() = 0; virtual void close() = 0;
virtual bool haserror() = 0; virtual bool haserror() = 0;
// virtual void print_results(const char* sep = " ", const char* end = "\n"); // virtual void print_results(const char* sep = " ", const char* end = "\n");
virtual ~DataSource() = 0; virtual ~DataSource() {};
}; };
// TODO: replace with super class
//typedef DataSource* (*create_server_t)(Context* cxt);
typedef void* (*create_server_t)(Context* cxt);
void* CreateNULLServer(Context*);
#endif //__DATASOURCE_CONN_H__ #endif //__DATASOURCE_CONN_H__

@ -11,7 +11,8 @@ void DuckdbServer::connect(Context* cxt) {
static_cast<duckdb_database*>(malloc(sizeof(duckdb_database))); static_cast<duckdb_database*>(malloc(sizeof(duckdb_database)));
this->handle = db_handle; this->handle = db_handle;
bool status = duckdb_open(nullptr, db_handle); bool status = duckdb_open(nullptr, db_handle);
duckdb_connection* conn_handle; duckdb_connection* conn_handle =
static_cast<duckdb_connection*>(malloc(sizeof(duckdb_connection)));;
status = status || duckdb_connect(*db_handle, conn_handle); status = status || duckdb_connect(*db_handle, conn_handle);
this->server = conn_handle; this->server = conn_handle;
if (status != 0) { if (status != 0) {
@ -20,6 +21,7 @@ void DuckdbServer::connect(Context* cxt) {
} }
DuckdbServer::DuckdbServer(Context* cxt) { DuckdbServer::DuckdbServer(Context* cxt) {
this->DataSourceType = BACKEND_DuckDB;
this->cxt = cxt; this->cxt = cxt;
connect(cxt); connect(cxt);
} }

@ -3,7 +3,7 @@
#include "DataSource_conn.h" #include "DataSource_conn.h"
struct DuckdbServer : DataSource { struct DuckdbServer : DataSource {
explicit DuckdbServer(Context* cxt = nullptr); explicit DuckdbServer(Context* cxt);
void connect(Context* cxt); void connect(Context* cxt);
void exec(const char* q); void exec(const char* q);
void* getCol(int col_idx, int type); void* getCol(int col_idx, int type);

@ -633,3 +633,5 @@ get_procedure(Context* cxt, const char* name) {
}; };
return res->second; return res->second;
} }
void* CreateNULLServer(Context*) { return nullptr; }

@ -63,11 +63,17 @@ enum Log_level {
LOG_SILENT LOG_SILENT
}; };
#ifndef __AQBACKEND_TYPE__
#define __AQBACKEND_TYPE__ 1
enum Backend_Type { enum Backend_Type {
BACKEND_AQuery, BACKEND_AQuery,
BACKEND_MonetDB, BACKEND_MonetDB,
BACKEND_MariaDB BACKEND_MariaDB,
BACKEND_DuckDB,
BACKEND_SQLite,
BACKEND_TOTAL
}; };
#endif
struct QueryStats{ struct QueryStats{
long long monet_time; long long monet_time;
@ -81,10 +87,14 @@ struct Config{
int buffer_sizes[]; int buffer_sizes[];
}; };
#ifndef __AQQueryResult__
#define __AQQueryResult__ 1
struct AQQueryResult { struct AQQueryResult {
void* res; void* res;
uint32_t ref; unsigned ref;
}; };
#endif
struct Session{ struct Session{
struct Statistic{ struct Statistic{
@ -114,7 +124,8 @@ struct Context {
int n_buffers, *sz_bufs; int n_buffers, *sz_bufs;
void **buffers; void **buffers;
void* alt_server = nullptr; void* curr_server;
void* alt_server[BACKEND_TOTAL] = {nullptr};
Log_level log_level = LOG_INFO; Log_level log_level = LOG_INFO;
Session current; Session current;

@ -9,7 +9,7 @@ inline size_t my_strlen(const char* str){
return ret; return ret;
} }
void Server::connect( void MariadbServer::connect(
Context* cxt, const char* host, const char* user, const char* passwd, Context* cxt, const char* host, const char* user, const char* passwd,
const char* db_name, const unsigned int port, const char* db_name, const unsigned int port,
const char* unix_socket, const unsigned long client_flag const char* unix_socket, const unsigned long client_flag
@ -35,12 +35,12 @@ void Server::connect(
this->status = true; this->status = true;
} }
void Server::exec(const char*q){ void MariadbServer::exec(const char*q){
auto res = mysql_real_query(server, q, my_strlen(q)); auto res = mysql_real_query(server, q, my_strlen(q));
if(res) printf("Execution Error: %d, %s\n", res, mysql_error(server)); if(res) printf("Execution Error: %d, %s\n", res, mysql_error(server));
} }
void Server::close(){ void MariadbServer::close(){
if(this->status && this->server){ if(this->status && this->server){
mysql_close(server); mysql_close(server);
server = 0; server = 0;

@ -5,7 +5,7 @@
#endif #endif
struct Context; struct Context;
struct Server{ struct MariadbServer{
MYSQL *server = nullptr; MYSQL *server = nullptr;
Context *cxt = nullptr; Context *cxt = nullptr;
bool status = false; bool status = false;
@ -20,5 +20,5 @@ struct Server{
); );
void exec(const char* q); void exec(const char* q);
void close(); void close();
~Server(); ~MariadbServer();
}; };

@ -71,16 +71,17 @@ namespace types{
}; };
} }
Server::Server(Context* cxt){ MonetdbServer::MonetdbServer(Context* cxt) {
this->DataSourceType = BACKEND_MonetDB;
if (cxt){ if (cxt){
connect(cxt); connect(cxt);
} }
} }
void Server::connect(Context *cxt){ void MonetdbServer::connect(Context *cxt){
auto server = static_cast<monetdbe_database*>(this->server); auto server = static_cast<monetdbe_database*>(this->server);
if (cxt){ if (cxt){
cxt->alt_server = this; cxt->alt_server[DataSourceType] = this;
this->cxt = cxt; this->cxt = cxt;
} }
else{ else{
@ -89,7 +90,7 @@ void Server::connect(Context *cxt){
} }
if (server){ if (server){
printf("Error: Server %p already connected. Restart? (Y/n). \n", server); printf("Error: MonetdbServer %p already connected. Restart? (Y/n). \n", server);
char c[50]; char c[50];
std::cin.getline(c, 49); std::cin.getline(c, 49);
for(int i = 0; i < 50; ++i) { for(int i = 0; i < 50; ++i) {
@ -122,7 +123,7 @@ void Server::connect(Context *cxt){
} }
} }
void Server::exec(const char* q){ void MonetdbServer::exec(const char* q){
auto server = static_cast<monetdbe_database*>(this->server); auto server = static_cast<monetdbe_database*>(this->server);
auto _res = static_cast<monetdbe_result*>(this->res); auto _res = static_cast<monetdbe_result*>(this->res);
monetdbe_cnt _cnt = 0; monetdbe_cnt _cnt = 0;
@ -137,7 +138,7 @@ void Server::exec(const char* q){
} }
} }
bool Server::haserror(){ bool MonetdbServer::haserror(){
if (last_error){ if (last_error){
puts(last_error); puts(last_error);
last_error = nullptr; last_error = nullptr;
@ -149,7 +150,7 @@ bool Server::haserror(){
} }
void Server::print_results(const char* sep, const char* end){ void MonetdbServer::print_results(const char* sep, const char* end){
if (!haserror()){ if (!haserror()){
auto _res = static_cast<monetdbe_result*> (res); auto _res = static_cast<monetdbe_result*> (res);
@ -190,7 +191,7 @@ void Server::print_results(const char* sep, const char* end){
} }
} }
void Server::close(){ void MonetdbServer::close(){
if(this->server){ if(this->server){
auto server = static_cast<monetdbe_database*>(this->server); auto server = static_cast<monetdbe_database*>(this->server);
monetdbe_close(*server); monetdbe_close(*server);
@ -199,7 +200,7 @@ void Server::close(){
} }
} }
void* Server::getCol(int col_idx){ void* MonetdbServer::getCol(int col_idx, int){
if(res){ if(res){
auto _res = static_cast<monetdbe_result*>(this->res); auto _res = static_cast<monetdbe_result*>(this->res);
auto err_msg = monetdbe_result_fetch(_res, auto err_msg = monetdbe_result_fetch(_res,
@ -224,7 +225,7 @@ void* Server::getCol(int col_idx){
#define AQ_MONETDB_FETCH(X) case monetdbe_##X: \ #define AQ_MONETDB_FETCH(X) case monetdbe_##X: \
return (long long)((X *)(_ret_col->data))[0]; return (long long)((X *)(_ret_col->data))[0];
long long Server::getFirstElement() { long long MonetdbServer::getFirstElement() {
if(!this->haserror() && res) { if(!this->haserror() && res) {
auto _res = static_cast<monetdbe_result*>(this->res); auto _res = static_cast<monetdbe_result*>(this->res);
auto err_msg = monetdbe_result_fetch(_res, auto err_msg = monetdbe_result_fetch(_res,
@ -266,11 +267,11 @@ long long Server::getFirstElement() {
return 0; return 0;
} }
Server::~Server(){ MonetdbServer::~MonetdbServer(){
close(); close();
} }
bool Server::havehge() { bool MonetdbServer::havehge() {
#if defined(_MONETDBE_LIB_) and defined(HAVE_HGE) #if defined(_MONETDBE_LIB_) and defined(HAVE_HGE)
// puts("true"); // puts("true");
return HAVE_HGE; return HAVE_HGE;
@ -299,7 +300,7 @@ constexpr prt_fn_t monetdbe_prtfns[] = {
constexpr uint32_t output_buffer_size = 65536; constexpr uint32_t output_buffer_size = 65536;
void print_monetdb_results(void* _srv, const char* sep = " ", const char* end = "\n", void print_monetdb_results(void* _srv, const char* sep = " ", const char* end = "\n",
uint32_t limit = std::numeric_limits<uint32_t>::max()) { uint32_t limit = std::numeric_limits<uint32_t>::max()) {
auto srv = static_cast<Server *>(_srv); auto srv = static_cast<MonetdbServer *>(_srv);
if (!srv->haserror() && srv->cnt && limit) { if (!srv->haserror() && srv->cnt && limit) {
char buffer[output_buffer_size]; char buffer[output_buffer_size];
auto _res = static_cast<monetdbe_result*> (srv->res); auto _res = static_cast<monetdbe_result*> (srv->res);
@ -360,7 +361,7 @@ cleanup:
int ExecuteStoredProcedureEx(const StoredProcedure *p, Context* cxt){ int ExecuteStoredProcedureEx(const StoredProcedure *p, Context* cxt){
auto server = static_cast<Server*>(cxt->alt_server); auto server = static_cast<MonetdbServer*>(cxt->alt_server[BACKEND_MonetDB]);
int ret = 0; int ret = 0;
bool return_from_procedure = false; bool return_from_procedure = false;
void* handle = nullptr; void* handle = nullptr;

@ -1,31 +1,19 @@
#ifndef __MONETDB_CONN_H__ #ifndef __MONETDB_CONN_H__
#define __MONETDB_CONN_H__ #define __MONETDB_CONN_H__
#include "DataSource_conn.h"
struct Context; struct MonetdbServer : DataSource {
explicit MonetdbServer(Context* cxt);
struct Server{ void connect(Context* cxt) override;
void *server = nullptr; void exec(const char* q) override;
Context *cxt = nullptr; void *getCol(int col_idx, int) override;
bool status = false;
char* query = nullptr;
int type = 1;
void* res = nullptr;
void* ret_col = nullptr;
long long cnt = 0;
char* last_error = nullptr;
explicit Server(Context* cxt = nullptr);
void connect(Context* cxt);
void exec(const char* q);
void *getCol(int col_idx);
long long getFirstElement(); long long getFirstElement();
void close(); void close() override;
bool haserror(); bool haserror() override;
static bool havehge(); static bool havehge();
void print_results(const char* sep = " ", const char* end = "\n"); void print_results(const char* sep = " ", const char* end = "\n");
friend void print_monetdb_results(void* _srv, const char* sep, const char* end, int limit); friend void print_monetdb_results(void* _srv, const char* sep, const char* end, int limit);
~Server(); ~MonetdbServer() override;
}; };
struct monetdbe_table_data{ struct monetdbe_table_data{

@ -7,6 +7,16 @@
#include "libaquery.h" #include "libaquery.h"
#include "monetdb_conn.h" #include "monetdb_conn.h"
#include "duckdb_conn.h"
constexpr create_server_t get_server[] = {
CreateNULLServer,
[](Context* cxt) -> void*{ return new MonetdbServer(cxt); },
CreateNULLServer,
[](Context* cxt) -> void*{ return new DuckdbServer(cxt); },
CreateNULLServer,
};
#pragma region misc #pragma region misc
#ifdef THREADING #ifdef THREADING
#include "threading.h" #include "threading.h"
@ -89,7 +99,7 @@ extern "C" int __DLLEXPORT__ binary_info() {
__AQEXPORT__(bool) __AQEXPORT__(bool)
have_hge() { have_hge() {
#if defined(__MONETDB_CONN_H__) #if defined(__MONETDB_CONN_H__)
return Server::havehge(); return MonetdbServer::havehge();
#else #else
return false; return false;
#endif #endif
@ -205,13 +215,20 @@ int dll_main(int argc, char** argv, Context* cxt){
cxt->cfg = cfg; cxt->cfg = cfg;
cxt->n_buffers = cfg->n_buffers; cxt->n_buffers = cfg->n_buffers;
cxt->sz_bufs = buf_szs; cxt->sz_bufs = buf_szs;
if (cfg->backend_type == BACKEND_MonetDB && cxt->alt_server == nullptr)
{
auto alt_server = new Server(cxt); const auto& update_backend = [&cxt, &cfg](){
alt_server->exec("SELECT '**** WELCOME TO AQUERY++! ****';"); auto& curr_server = cxt->alt_server[cfg->backend_type];
puts(*(const char**)(alt_server->getCol(0))); if (curr_server == nullptr) {
cxt->alt_server = alt_server; curr_server = get_server[cfg->backend_type](cxt);
} cxt->alt_server[cfg->backend_type] = curr_server;
static_cast<DataSource*>(curr_server)->exec("SELECT '**** WELCOME TO AQUERY++! ****';");
puts(*(const char**)(static_cast<DataSource*>(curr_server)->getCol(0, types::Types<const char*>::getType())));
}
cxt->curr_server = curr_server;
};
update_backend();
while(cfg->running){ while(cfg->running){
ENGINE_ACQUIRE(); ENGINE_ACQUIRE();
if (cfg->new_query) { if (cfg->new_query) {
@ -221,10 +238,11 @@ start:
void *handle = nullptr; void *handle = nullptr;
void *user_module_handle = nullptr; void *user_module_handle = nullptr;
if (cfg->backend_type == BACKEND_MonetDB){ if (cfg->backend_type == BACKEND_MonetDB||
if (cxt->alt_server == nullptr) cfg->backend_type == BACKEND_DuckDB
cxt->alt_server = new Server(cxt); ) {
Server* server = reinterpret_cast<Server*>(cxt->alt_server); update_backend();
auto server = reinterpret_cast<DataSource*>(cxt->curr_server);
if(n_recv > 0){ if(n_recv > 0){
if (cfg->backend_type == BACKEND_AQuery || cfg->has_dll) { if (cfg->backend_type == BACKEND_AQuery || cfg->has_dll) {
const char* proc_name = "./dll.so"; const char* proc_name = "./dll.so";

@ -67,7 +67,7 @@ void TableInfo<Ts ...>::monetdb_append_table(void* srv, const char* alt_name) {
auto last_comma = create_table_str.find_last_of(','); auto last_comma = create_table_str.find_last_of(',');
if (last_comma != static_cast<decltype(last_comma)>(-1)) { if (last_comma != static_cast<decltype(last_comma)>(-1)) {
create_table_str[last_comma] = ')'; create_table_str[last_comma] = ')';
Server* server = (Server*)srv; MonetdbServer* server = (MonetdbServer*)srv;
// puts("create table..."); // puts("create table...");
// puts(create_table_str.c_str()); // puts(create_table_str.c_str());
server->exec(create_table_str.c_str()); server->exec(create_table_str.c_str());

@ -169,7 +169,7 @@ public:
return distinct_copy(); return distinct_copy();
} }
// TODO: think of situations where this is a temp!! (copy on write!!!) // TODO: think of situations where this is a temp!! (copy on write!!!)
template <bool _grow = true> template <bool _grow = true, bool _resize = false>
inline void grow(uint32_t sz = 0) { inline void grow(uint32_t sz = 0) {
if constexpr (_grow) if constexpr (_grow)
sz = this->size; sz = this->size;
@ -192,6 +192,8 @@ public:
n_container = (_Ty*)malloc(new_capacity * sizeof(_Ty)); n_container = (_Ty*)malloc(new_capacity * sizeof(_Ty));
memcpy(n_container, container, sizeof(_Ty) * size); memcpy(n_container, container, sizeof(_Ty) * size);
} }
if constexpr(_resize)
size = sz;
memset(n_container + size, 0, sizeof(_Ty) * (new_capacity - size)); memset(n_container + size, 0, sizeof(_Ty) * (new_capacity - size));
// if (capacity) // if (capacity)
// free(container); // free(container);
@ -200,8 +202,7 @@ public:
} }
} }
inline void resize(const uint32_t sz){ inline void resize(const uint32_t sz){
size = sz; grow<false, true>(sz);
grow<false>(sz);
} }
inline void reserve(const uint32_t sz){ inline void reserve(const uint32_t sz){
grow<false>(sz); grow<false>(sz);

Loading…
Cancel
Save