vendor/isotree/src/helpers_iforest.cpp in isotree-0.1.4 vs vendor/isotree/src/helpers_iforest.cpp in isotree-0.1.5

- old
+ new

@@ -20,11 +20,11 @@ * [7] Quinlan, J. Ross. C4. 5: programs for machine learning. Elsevier, 2014. * [8] Cortes, David. "Distance approximation using Isolation Forests." arXiv preprint arXiv:1910.12362 (2019). * [9] Cortes, David. "Imputing missing values with unsupervised random trees." arXiv preprint arXiv:1911.06646 (2019). * * BSD 2-Clause License -* Copyright (c) 2019, David Cortes +* Copyright (c) 2020, David Cortes * All rights reserved. * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. @@ -89,11 +89,11 @@ /* for use in regular model */ void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params, IsoTree &tree) { if (tree.col_type == Numeric) { - if (input_data.Xc == NULL) + if (input_data.Xc_indptr == NULL) get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * tree.col_num, workspace.st, workspace.end, model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable); else get_range(workspace.ix_arr.data(), workspace.st, workspace.end, tree.col_num, @@ -112,11 +112,11 @@ /* for use in extended model */ void get_split_range(WorkerMemory &workspace, InputData &input_data, ModelParams &model_params) { if (workspace.col_type == Numeric) { - if (input_data.Xc == NULL) + if (input_data.Xc_indptr == NULL) get_range(workspace.ix_arr.data(), input_data.numeric_data + input_data.nrows * workspace.col_chosen, workspace.st, workspace.end, model_params.missing_action, workspace.xmin, workspace.xmax, workspace.unsplittable); else get_range(workspace.ix_arr.data(), workspace.st, workspace.end, workspace.col_chosen, @@ -279,14 +279,23 @@ recursion_state.col_sampler = workspace.col_sampler; /* for the extended model, it's not necessary to copy everything */ if (!workspace.comb_val.size()) { - /* TODO: here only need to copy the left half, as the right one is untouched */ - recursion_state.ix_arr = workspace.ix_arr; - recursion_state.weights_map = workspace.weights_map; - recursion_state.weights_arr = workspace.weights_arr; + recursion_state.ix_arr = std::vector<size_t>(workspace.ix_arr.begin() + workspace.st_NA, + workspace.ix_arr.begin() + workspace.end + 1); + size_t tot = workspace.end - workspace.st_NA + 1; + if (workspace.weights_arr.size() || workspace.weights_map.size()) + recursion_state.weights_arr = std::unique_ptr<double[]>(new double[tot]); + if (workspace.weights_arr.size()) + for (size_t ix = 0; ix < tot; ix++) + recursion_state.weights_arr[ix] = workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]]; + else if (workspace.weights_map.size()) + for (size_t ix = 0; ix < tot; ix++) + recursion_state.weights_arr[ix] = workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]]; + + } } void restore_recursion_state(WorkerMemory &workspace, RecursionState &recursion_state) @@ -299,11 +308,17 @@ workspace.cols_possible = std::move(recursion_state.cols_possible); workspace.col_sampler = std::move(recursion_state.col_sampler); if (!workspace.comb_val.size()) { - /* TODO: here only need to copy the left half, as the right one is untouched */ - workspace.ix_arr = std::move(recursion_state.ix_arr); - workspace.weights_map = std::move(recursion_state.weights_map); - workspace.weights_arr = std::move(recursion_state.weights_arr); + std::copy(recursion_state.ix_arr.begin(), + recursion_state.ix_arr.end(), + workspace.ix_arr.begin() + recursion_state.st_NA); + size_t tot = workspace.end - workspace.st_NA + 1; + if (workspace.weights_arr.size()) + for (size_t ix = 0; ix < tot; ix++) + workspace.weights_arr[workspace.ix_arr[ix + workspace.st_NA]] = recursion_state.weights_arr[ix]; + else if (workspace.weights_map.size()) + for (size_t ix = 0; ix < tot; ix++) + workspace.weights_map[workspace.ix_arr[ix + workspace.st_NA]] = recursion_state.weights_arr[ix]; } }