ext/isotree/ext.cpp in isotree-0.1.3 vs ext/isotree/ext.cpp in isotree-0.1.4
- old
+ new
@@ -3,21 +3,81 @@
// rice
#include <rice/Array.hpp>
#include <rice/Hash.hpp>
#include <rice/Module.hpp>
+#include <rice/Object.hpp>
#include <rice/String.hpp>
#include <rice/Symbol.hpp>
using Rice::Array;
using Rice::Hash;
using Rice::Module;
+using Rice::Object;
using Rice::String;
using Rice::Symbol;
using Rice::define_class_under;
using Rice::define_module;
+// Converts a Ruby object to a NewCategAction by comparing its to_s() text
+// (exact, case-sensitive match) against "weighted", "smallest" and "random".
+// Any other text raises std::runtime_error.
+template<>
+NewCategAction from_ruby<NewCategAction>(Object x)
+{
+ const auto name = x.to_s().str();
+ if (name == "weighted") {
+  return Weighted;
+ }
+ if (name == "smallest") {
+  return Smallest;
+ }
+ if (name == "random") {
+  return Random;
+ }
+ throw std::runtime_error("Unknown new categ action: " + name);
+}
+
+// Converts a Ruby object to a MissingAction by comparing its to_s() text
+// (exact, case-sensitive match) against "divide", "impute" and "fail".
+// Any other text raises std::runtime_error.
+template<>
+MissingAction from_ruby<MissingAction>(Object x)
+{
+ const auto name = x.to_s().str();
+ if (name == "divide") {
+  return Divide;
+ }
+ if (name == "impute") {
+  return Impute;
+ }
+ if (name == "fail") {
+  return Fail;
+ }
+ throw std::runtime_error("Unknown missing action: " + name);
+}
+
+// Converts a Ruby object to a CategSplit by comparing its to_s() text
+// (exact, case-sensitive match) against "subset" and "single_categ".
+// Any other text raises std::runtime_error.
+template<>
+CategSplit from_ruby<CategSplit>(Object x)
+{
+ const auto name = x.to_s().str();
+ if (name == "subset") {
+  return SubSet;
+ }
+ if (name == "single_categ") {
+  return SingleCateg;
+ }
+ throw std::runtime_error("Unknown categ split: " + name);
+}
+
+// Converts a Ruby object to a CoefType by comparing its to_s() text
+// (exact, case-sensitive match) against "uniform" and "normal".
+// Any other text raises std::runtime_error.
+template<>
+CoefType from_ruby<CoefType>(Object x)
+{
+ const auto name = x.to_s().str();
+ if (name == "uniform") {
+  return Uniform;
+ }
+ if (name == "normal") {
+  return Normal;
+ }
+ throw std::runtime_error("Unknown coef type: " + name);
+}
+
+// Converts a Ruby object to a UseDepthImp by comparing its to_s() text
+// (exact, case-sensitive match) against "lower", "higher" and "same".
+// Any other text raises std::runtime_error.
+template<>
+UseDepthImp from_ruby<UseDepthImp>(Object x)
+{
+ const auto name = x.to_s().str();
+ if (name == "lower") {
+  return Lower;
+ }
+ if (name == "higher") {
+  return Higher;
+ }
+ if (name == "same") {
+  return Same;
+ }
+ throw std::runtime_error("Unknown depth imp: " + name);
+}
+
+// Converts a Ruby object to a WeighImpRows by comparing its to_s() text
+// (exact, case-sensitive match) against "inverse", "prop" and "flat".
+// Any other text raises std::runtime_error.
+template<>
+WeighImpRows from_ruby<WeighImpRows>(Object x)
+{
+ const auto name = x.to_s().str();
+ if (name == "inverse") {
+  return Inverse;
+ }
+ if (name == "prop") {
+  return Prop;
+ }
+ if (name == "flat") {
+  return Flat;
+ }
+ throw std::runtime_error("Unknown weight imp rows: " + name);
+}
+
extern "C"
void Init_ext()
{
Module rb_mIsoTree = define_module("IsoTree");
@@ -52,45 +112,46 @@
double* Xc = NULL;
sparse_ix* Xc_ind = NULL;
sparse_ix* Xc_indptr = NULL;
// options
- CoefType coef_type = Normal;
- double* sample_weights = NULL;
- bool weight_as_sample = false;
- size_t max_depth = 0;
- bool limit_depth = true;
- bool standardize_dist = false;
- double* tmat = NULL;
- double* output_depths = NULL;
- bool standardize_depth = false;
- double* col_weights = NULL;
- MissingAction missing_action = Impute;
- CategSplit cat_split_type = SubSet;
- NewCategAction new_cat_action = Smallest;
- Imputer *imputer = NULL;
- UseDepthImp depth_imp = Higher;
- WeighImpRows weigh_imp_rows = Inverse;
- bool impute_at_fit = false;
-
- // Rice has limit of 14 arguments, so use hash for options
+ // Rice has a limit of 14 arguments, so the options are read from a hash
size_t sample_size = options.get<size_t, Symbol>("sample_size");
size_t ndim = options.get<size_t, Symbol>("ndim");
size_t ntrees = options.get<size_t, Symbol>("ntrees");
size_t ntry = options.get<size_t, Symbol>("ntry");
double prob_pick_by_gain_avg = options.get<double, Symbol>("prob_pick_avg_gain");
double prob_split_by_gain_avg = options.get<double, Symbol>("prob_split_avg_gain");
double prob_pick_by_gain_pl = options.get<double, Symbol>("prob_pick_pooled_gain");
double prob_split_by_gain_pl = options.get<double, Symbol>("prob_split_pooled_gain");
double min_gain = options.get<double, Symbol>("min_gain");
+ MissingAction missing_action = options.get<MissingAction, Symbol>("missing_action");
+ CategSplit cat_split_type = options.get<CategSplit, Symbol>("categ_split_type");
+ NewCategAction new_cat_action = options.get<NewCategAction, Symbol>("new_categ_action");
bool all_perm = options.get<bool, Symbol>("all_perm");
bool coef_by_prop = options.get<bool, Symbol>("coef_by_prop");
bool with_replacement = options.get<bool, Symbol>("sample_with_replacement");
bool penalize_range = options.get<bool, Symbol>("penalize_range");
bool weigh_by_kurt = options.get<bool, Symbol>("weigh_by_kurtosis");
+ CoefType coef_type = options.get<CoefType, Symbol>("coefs");
size_t min_imp_obs = options.get<size_t, Symbol>("min_imp_obs");
+ UseDepthImp depth_imp = options.get<UseDepthImp, Symbol>("depth_imp");
+ WeighImpRows weigh_imp_rows = options.get<WeighImpRows, Symbol>("weigh_imp_rows");
uint64_t random_seed = options.get<uint64_t, Symbol>("random_seed");
int nthreads = options.get<int, Symbol>("nthreads");
+
+ // TODO options
+ double* sample_weights = NULL;
+ bool weight_as_sample = false;
+ size_t max_depth = 0;
+ bool limit_depth = true;
+ bool standardize_dist = false;
+ double* tmat = NULL;
+ double* output_depths = NULL;
+ bool standardize_depth = false;
+ double* col_weights = NULL;
+ Imputer *imputer = NULL;
+ bool impute_at_fit = false;
fit_iforest(
NULL,
&iso,
numeric_data,