ext/isotree/ext.cpp in isotree-0.1.3 vs ext/isotree/ext.cpp in isotree-0.1.4

- old
+ new

@@ -3,21 +3,81 @@ // rice #include <rice/Array.hpp> #include <rice/Hash.hpp> #include <rice/Module.hpp> +#include <rice/Object.hpp> #include <rice/String.hpp> #include <rice/Symbol.hpp> using Rice::Array; using Rice::Hash; using Rice::Module; +using Rice::Object; using Rice::String; using Rice::Symbol; using Rice::define_class_under; using Rice::define_module; +template<> +NewCategAction from_ruby<NewCategAction>(Object x) +{ + auto value = x.to_s().str(); + if (value == "weighted") return Weighted; + if (value == "smallest") return Smallest; + if (value == "random") return Random; + throw std::runtime_error("Unknown new categ action: " + value); +} + +template<> +MissingAction from_ruby<MissingAction>(Object x) +{ + auto value = x.to_s().str(); + if (value == "divide") return Divide; + if (value == "impute") return Impute; + if (value == "fail") return Fail; + throw std::runtime_error("Unknown missing action: " + value); +} + +template<> +CategSplit from_ruby<CategSplit>(Object x) +{ + auto value = x.to_s().str(); + if (value == "subset") return SubSet; + if (value == "single_categ") return SingleCateg; + throw std::runtime_error("Unknown categ split: " + value); +} + +template<> +CoefType from_ruby<CoefType>(Object x) +{ + auto value = x.to_s().str(); + if (value == "uniform") return Uniform; + if (value == "normal") return Normal; + throw std::runtime_error("Unknown coef type: " + value); +} + +template<> +UseDepthImp from_ruby<UseDepthImp>(Object x) +{ + auto value = x.to_s().str(); + if (value == "lower") return Lower; + if (value == "higher") return Higher; + if (value == "same") return Same; + throw std::runtime_error("Unknown depth imp: " + value); +} + +template<> +WeighImpRows from_ruby<WeighImpRows>(Object x) +{ + auto value = x.to_s().str(); + if (value == "inverse") return Inverse; + if (value == "prop") return Prop; + if (value == "flat") return Flat; + throw std::runtime_error("Unknown weight imp rows: " + value); +} + extern "C" void Init_ext() { Module rb_mIsoTree = define_module("IsoTree"); @@ -52,45 +112,46 @@ double* Xc = NULL; sparse_ix* Xc_ind = NULL; sparse_ix* Xc_indptr = NULL; // options - CoefType coef_type = Normal; - double* sample_weights = NULL; - bool weight_as_sample = false; - size_t max_depth = 0; - bool limit_depth = true; - bool standardize_dist = false; - double* tmat = NULL; - double* output_depths = NULL; - bool standardize_depth = false; - double* col_weights = NULL; - MissingAction missing_action = Impute; - CategSplit cat_split_type = SubSet; - NewCategAction new_cat_action = Smallest; - Imputer *imputer = NULL; - UseDepthImp depth_imp = Higher; - WeighImpRows weigh_imp_rows = Inverse; - bool impute_at_fit = false; - - // Rice has limit of 14 arguments, so use hash for options + // Rice has limit of 14 arguments, so use hash size_t sample_size = options.get<size_t, Symbol>("sample_size"); size_t ndim = options.get<size_t, Symbol>("ndim"); size_t ntrees = options.get<size_t, Symbol>("ntrees"); size_t ntry = options.get<size_t, Symbol>("ntry"); double prob_pick_by_gain_avg = options.get<double, Symbol>("prob_pick_avg_gain"); double prob_split_by_gain_avg = options.get<double, Symbol>("prob_split_avg_gain"); double prob_pick_by_gain_pl = options.get<double, Symbol>("prob_pick_pooled_gain"); double prob_split_by_gain_pl = options.get<double, Symbol>("prob_split_pooled_gain"); double min_gain = options.get<double, Symbol>("min_gain"); + MissingAction missing_action = options.get<MissingAction, Symbol>("missing_action"); + CategSplit cat_split_type = options.get<CategSplit, Symbol>("categ_split_type"); + NewCategAction new_cat_action = options.get<NewCategAction, Symbol>("new_categ_action"); bool all_perm = options.get<bool, Symbol>("all_perm"); bool coef_by_prop = options.get<bool, Symbol>("coef_by_prop"); bool with_replacement = options.get<bool, Symbol>("sample_with_replacement"); bool penalize_range = options.get<bool, Symbol>("penalize_range"); bool weigh_by_kurt = options.get<bool, Symbol>("weigh_by_kurtosis"); + CoefType coef_type = options.get<CoefType, Symbol>("coefs"); size_t min_imp_obs = options.get<size_t, Symbol>("min_imp_obs"); + UseDepthImp depth_imp = options.get<UseDepthImp, Symbol>("depth_imp"); + WeighImpRows weigh_imp_rows = options.get<WeighImpRows, Symbol>("weigh_imp_rows"); uint64_t random_seed = options.get<uint64_t, Symbol>("random_seed"); int nthreads = options.get<int, Symbol>("nthreads"); + + // TODO options + double* sample_weights = NULL; + bool weight_as_sample = false; + size_t max_depth = 0; + bool limit_depth = true; + bool standardize_dist = false; + double* tmat = NULL; + double* output_depths = NULL; + bool standardize_depth = false; + double* col_weights = NULL; + Imputer *imputer = NULL; + bool impute_at_fit = false; fit_iforest( NULL, &iso, numeric_data,