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 2 years ago committed by root
parent d5382c36e9
commit 259d9ef566

@ -5,14 +5,13 @@
struct minEval{ struct minEval{
double value; double value;
double values; int* values;
double eval; double eval;
long left; // how many on its left long left; // how many on its left
double* record; double* record;
long max; long max;
long** count; long** count;
long* sorted; // sorted d
}; };
minEval giniSparse(double** data, long* result, long* d, long size, long col, long classes, long* totalT){ minEval giniSparse(double** data, long* result, long* d, long size, long col, long classes, long* totalT){

@ -4,9 +4,6 @@
#include "../server/table.h" #include "../server/table.h"
DecisionTree* dt = nullptr; DecisionTree* dt = nullptr;
long pt = 0;
double** data = nullptr;
long* result = nullptr;
__AQEXPORT__(bool) newtree(int height, long f, ColRef<int> sparse, double forget, long maxf, long noclasses, Evaluation e, long r, long rb){ __AQEXPORT__(bool) newtree(int height, long f, ColRef<int> sparse, double forget, long maxf, long noclasses, Evaluation e, long r, long rb){
if(sparse.size!=f)return 0; if(sparse.size!=f)return 0;
@ -19,36 +16,29 @@ __AQEXPORT__(bool) newtree(int height, long f, ColRef<int> sparse, double forget
return 1; return 1;
} }
__AQEXPORT__(bool) additem(ColRef<double>X, long y, long size){ __AQEXPORT__(bool) fit(ColRef<ColRef<double>> X, ColRef<int> y){
long j = 0; if(X.size != y.size)return 0;
if(size>0){ double** data = (double**)malloc(X.size*sizeof(double*));
free(data); long* result = (long*)malloc(y.size*sizeof(long));
free(result); for(long i=0; i<X.size; i++){
pt = 0; data[i] = X.container[i].container;
data=(double**)malloc(size*sizeof(double*)); result[i] = y.container[i];
result=(long*)malloc(size*sizeof(long));
} }
data[pt] = (double*)malloc(X.size*sizeof(double)); dt->fit(data, result, X.size);
for(j=0; j<X.size; j++){
data[pt][j]=X.container[j];
}
result[pt] = y;
pt ++;
return 1;
}
__AQEXPORT__(bool) fit(){
if(pt<=0)return 0;
dt->fit(data, result, pt);
return 1; return 1;
} }
__AQEXPORT__(ColRef_storage) predict(){ __AQEXPORT__(ColRef_storage) predict(ColRef<ColRef<double>> X){
int* result = (int*)malloc(pt*sizeof(int)); double** data = (double**)malloc(X.size*sizeof(double*));
for(long i=0; i<pt; i++){ int* result = (int*)malloc(X.size*sizeof(int));
for(long i=0; i<X.size; i++){
data[i] = X.container[i].container;
}
for(long i=0; i<X.size; i++){
result[i]=dt->Test(data[i], dt->DTree); result[i]=dt->Test(data[i], dt->DTree);
} }
return ColRef_storage(new ColRef_storage(result, pt, 0, "prediction", 0), 1, 0, "prediction", 0); return ColRef_storage(new ColRef_storage(result, X.size, 0, "prediction", 0), 1, 0, "prediction", 0);
} }

@ -1,21 +1,21 @@
LOAD MODULE FROM "./libirf.so" LOAD MODULE FROM "./libirf.so"
FUNCTIONS ( FUNCTIONS (
newtree(height:int, f:int64, sparse:vecint, forget:double, maxf:int64, noclasses:int64, e:int, r:int64, rb:int64) -> bool, newtree(height:int, f:int64, sparse:vecint, forget:double, maxf:int64, noclasses:int64, e:int, r:int64, rb:int64) -> bool,
additem(X:vecdouble, y:int64, size:int64) -> bool, fit(X:vecvecdouble, y:vecint) -> bool,
fit() -> bool, predict(X:vecvecdouble) -> vecint
predict() -> vecint
); );
create table tb(x int);
create table tb2(x double, y double, z double); create table source(x1 double, x2 double, x3 double, x4 double, x5 int);
insert into tb values (0); load data infile "data/benchmark" into table source fields terminated by ",";
insert into tb values (0);
insert into tb values (0); create table sparse(x int);
select newtree(5, 3, tb.x, 0, 3, 2, 0, 100, 1) from tb; insert into sparse values (1);
insert into tb2 values (1, 0, 1); insert into sparse values (1);
insert into tb2 values (0, 1, 1); insert into sparse values (1);
insert into tb2 values (1, 1, 1); insert into sparse values (1);
select additem(tb2.x, 1, 3) from tb2;
select additem(tb2.y, 0, -1) from tb2; select newtree(6, 4, sparse.x, 0, 4, 2, 0, 400, 2147483647) from sparse;
select additem(tb2.z, 1, -1) from tb2;
select fit(); select fit(pack(x1, x2, x3, x4), x5) from source;
select predict();
select predict(pack(x1, x2, x3, x4)) from source;

Loading…
Cancel
Save