// AQuery/sdk/irf.cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>

#include "DecisionTree.h"
#include "RF.h"
// __AQ_NO_SESSION__
#include "../server/table.h"
#include "aquery.h"
#include "./server/gc.h"
/// Exported hook that wires this module's GC statics to the given context:
/// stores the context's `gc` pointer (cast to `GC*`) in `GC::gc_handle` and
/// resets `GC::scratch_space` to null.
///
/// @param cxt  context whose `gc` member supplies the collector handle; must not be null.
__AQEXPORT__(void) __AQ_Init_GC__(Context* cxt) {
    auto* const handle = static_cast<GC*>(cxt->gc);
    GC::gc_handle = handle;
    GC::scratch_space = nullptr;
}
DecisionTree *dt = nullptr;
2 years ago
RandomForest *rf = nullptr;
2 years ago
__AQEXPORT__(bool)
2 years ago
newtree(int height, long f, ColRef<int> X, double forget, long maxf, long noclasses, Evaluation e, long r, long rb)
{
2 years ago
if (X.size != f)
return false;
2 years ago
int *X_cpy = (int *)malloc(f * sizeof(int));
memcpy(X_cpy, X.container, f);
2 years ago
if (maxf < 0)
maxf = f;
2 years ago
dt = new DecisionTree(f, X_cpy, forget, maxf, noclasses, e);
2 years ago
rf = new RandomForest(height, f, X_cpy, forget, noclasses, e);
return true;
2 years ago
}
// Disabled earlier draft of a batch `fit`, kept for reference only:
// size_t pt = 0;
// __AQEXPORT__(bool) fit(ColRef<ColRef<double>> X, ColRef<int> y){
// if(X.size != y.size)return 0;
// double** data = (double**)malloc(X.size*sizeof(double*));
// long* result = (long*)malloc(y.size*sizeof(long));
// for(long i=0; i<X.size; i++){
// data[i] = X.container[i].container;
// result[i] = y.container[i];
// }
// data[pt] = (double*)malloc(X.size*sizeof(double));
// for(uint32_t j=0; j<X.size; j++){
// data[pt][j]=X.container[j];
// }
// result[pt] = y;
// pt ++;
// return 1;
// }
/// Feeds a batch of rows to the module-level random forest.
///
/// Builds a temporary table of row pointers (one per element of `v`), passes it to
/// `rf->fit` together with the labels in `res`, then releases the table. Every call
/// passes the whole of `v`; no offset is carried over between calls.
///
/// @param v    rows to train on; each element's `container` is used as one row.
/// @param res  labels; must hold at least `v.size` entries.
/// @return false when no forest has been created yet (`rf` is null), when `res` has
///         fewer entries than `v`, or when the pointer table cannot be allocated;
///         true after `rf->fit` has been called.
__AQEXPORT__(bool)
fit_inc(vector_type<vector_type<double>> v, vector_type<long> res)
{
    if (rf == nullptr || res.size < v.size)
        return false;

    double **data = (double **)malloc(v.size * sizeof(double *));
    // malloc may legitimately return null for a zero-byte request.
    if (data == nullptr && v.size != 0)
        return false;

    for (uint32_t i = 0; i < v.size; ++i)
        data[i] = v.container[i].container;

    rf->fit(data, res.container, v.size);
    free(data);
    return true;
}
/// Trains the module-level random forest on a batch of rows.
///
/// Builds a temporary table of row pointers (one per element of `v`), passes it to
/// `rf->fit` together with the labels in `res`, then releases the table.
///
/// @param v    rows to train on; each element's `container` is used as one row.
/// @param res  labels; must hold at least `v.size` entries.
/// @return false when no forest has been created yet (`rf` is null), when `res` has
///         fewer entries than `v`, or when the pointer table cannot be allocated;
///         true after `rf->fit` has been called.
__AQEXPORT__(bool)
fit(vector_type<vector_type<double>> v, vector_type<long> res)
{
    if (rf == nullptr || res.size < v.size)
        return false;

    double **data = (double **)malloc(v.size * sizeof(double *));
    // malloc may legitimately return null for a zero-byte request.
    if (data == nullptr && v.size != 0)
        return false;

    for (uint32_t i = 0; i < v.size; ++i)
        data[i] = v.container[i].container;

    // dt->fit(data, res.container, v.size);
    rf->fit(data, res.container, v.size);

    // Release the pointer table (not the rows it points at), as fit_inc does.
    // NOTE(review): assumes rf->fit does not keep `data` after returning -- confirm.
    free(data);
    return true;
}
/// Runs every row of `v` through the module-level random forest.
///
/// Allocates an int array with one slot per row and fills slot i with
/// `rf->Test(v[i].container)` converted to int. The array is wrapped in a
/// heap-allocated `vector_type<int>` (size = `v.size`, capacity = 0), which is in turn
/// returned inside a `vectortype_cstorage` with size 1 and capacity 0.
///
/// NOTE(review): neither allocation nor `rf` is checked for null here, and nothing
/// visible in this file frees the two buffers -- presumably the caller takes them over;
/// confirm against the host side.
///
/// @param v  rows to classify.
/// @return storage descriptor pointing at the freshly allocated result vector.
__AQEXPORT__(vectortype_cstorage)
predict(vector_type<vector_type<double>> v)
{
    int *result = (int *)malloc(v.size * sizeof(int));
    for (uint32_t i = 0; i < v.size; i++) {
        // Single-tree variant, currently disabled:
        // result[i] = dt->Test(v.container[i].container, dt->DTree);
        result[i] = int(rf->Test(v[i].container));
    }

    auto container = (vector_type<int> *)malloc(sizeof(vector_type<int>));
    container->size = v.size;
    container->capacity = 0;
    container->container = result;

    auto ret = vectortype_cstorage{.container = container, .size = 1, .capacity = 0};
    return ret;
}