vendor/isotree/src/helpers_iforest.cpp in isotree-0.1.4 vs vendor/isotree/src/helpers_iforest.cpp in isotree-0.1.5
- old
+ new
@@ -20,11 +20,11 @@
* [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014.
* [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019).
* [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019).
*
* BSD 2-Clause License
-* Copyright (c) 2019, David Cortes
+* Copyright (c) 2020, David Cortes
* All rights reserved.
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
@@ -89,11 +89,11 @@
/* for use in regular model */
void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params, IsoTree &tree)
{
if (tree.col_type == Numeric)
{
- if (input_data.Xc == NULL)
+ if (input_data.Xc_indptr == NULL)
get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * tree.col_num,
workspace.st, workspace.end, model_params.missing_action,
workspace.xmin, workspace.xmax, workspace.unsplittable);
else
get_range(workspace.ix_arr.data(), workspace.st, workspace.end, tree.col_num,
@@ -112,11 +112,11 @@
/* for use in extended model */
void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params)
{
if (workspace.col_type == Numeric)
{
- if (input_data.Xc == NULL)
+ if (input_data.Xc_indptr == NULL)
get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * workspace.col_chosen,
workspace.st, workspace.end, model_params.missing_action,
workspace.xmin, workspace.xmax, workspace.unsplittable);
else
get_range(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen,
@@ -279,14 +279,23 @@
recursion_state.col_sampler = workspace.col_sampler;
/* for the extended model, it's not necessary to copy everything */
if (!workspace.comb_val.size())
{
- /* TODO: here only need to copy the left half, as the right one is untouched */
- recursion_state.ix_arr = workspace.ix_arr;
- recursion_state.weights_map = workspace.weights_map;
- recursion_state.weights_arr = workspace.weights_arr;
+ recursion_state.ix_arr = std::vector<size_t>(workspace.ix_arr.begin() + workspace.st_NA,
+ workspace.ix_arr.begin() + workspace.end + 1);
+ size_t tot = workspace.end - workspace.st_NA + 1;
+ if (workspace.weights_arr.size() || workspace.weights_map.size())
+ recursion_state.weights_arr = std::unique_ptr<double[]>(new double[tot]);
+ if (workspace.weights_arr.size())
+ for (size_t ix = 0; ix < tot; ix++)
+ recursion_state.weights_arr[ix] = workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]];
+ else if (workspace.weights_map.size())
+ for (size_t ix = 0; ix < tot; ix++)
+ recursion_state.weights_arr[ix] = workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]];
+
+
}
}
void restore_recursion_state(WorkerMemory &workspace, RecursionState &recursion_state)
@@ -299,11 +308,17 @@
workspace.cols_possible = std::move(recursion_state.cols_possible);
workspace.col_sampler = std::move(recursion_state.col_sampler);
if (!workspace.comb_val.size())
{
- /* TODO: here only need to copy the left half, as the right one is untouched */
- workspace.ix_arr = std::move(recursion_state.ix_arr);
- workspace.weights_map = std::move(recursion_state.weights_map);
- workspace.weights_arr = std::move(recursion_state.weights_arr);
+ std::copy(recursion_state.ix_arr.begin(),
+ recursion_state.ix_arr.end(),
+ workspace.ix_arr.begin() + recursion_state.st_NA);
+ size_t tot = workspace.end - workspace.st_NA + 1;
+ if (workspace.weights_arr.size())
+ for (size_t ix = 0; ix < tot; ix++)
+ workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]] = recursion_state.weights_arr[ix];
+ else if (workspace.weights_map.size())
+ for (size_t ix = 0; ix < tot; ix++)
+ workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]] = recursion_state.weights_arr[ix];
}
}