lib/lightgbm/dataset.rb in lightgbm-0.1.7 vs lib/lightgbm/dataset.rb in lightgbm-0.1.8

- old
+ new

# Hunk context (@@ -2,93 +2,71 @@): this method lives inside `class Dataset`,
# which declares `attr_reader :data, :params`.

# Remembers every constructor argument on the instance and then builds the
# native dataset by calling `construct`. Keeping the arguments around is what
# lets the dataset be rebuilt later from the same inputs.
#
# @param data the raw data handed to `construct` (a String is treated as a
#   file path there; other inputs are flattened into a matrix)
# @param label optional labels, applied after construction
# @param weight optional weights, applied after construction
# @param group optional group sizes, applied after construction
# @param params [Hash, nil] dataset parameters
# @param reference optional dataset whose handle is passed to the native call
# @param used_indices optional row indices; when given, a subset of
#   `reference` is created instead of a new dataset
# @param categorical_feature "auto", or a list joined into the
#   "categorical_feature" parameter
# @param feature_names optional feature names, applied after construction
def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto", feature_names: nil)
  # Inputs that decide how the native dataset is created.
  @data = data
  @params = params
  @reference = reference
  @used_indices = used_indices
  @categorical_feature = categorical_feature

  # Fields pushed onto the dataset once it exists.
  @label = label
  @weight = weight
  @group = group
  @feature_names = feature_names

  construct
end
# Reads the "label" field through the `field` helper (defined outside this excerpt).
def label
  field("label")
end

# Reads the "weight" field through the `field` helper (defined outside this excerpt).
def weight
  field("weight")
end

# Fetches the feature names stored on the native dataset.
#
# The output buffers are allocated up front: 1000 string slots of 255 bytes
# each. NOTE(review): a dataset with more names, or with longer names,
# presumably does not fit in these buffers - confirm against the C API.
#
# @return [Array<String>] the first `num_feature_names` strings filled in by
#   the native call
def feature_names
  # must preallocate space
  num_feature_names = ::FFI::MemoryPointer.new(:int)
  out_strs = ::FFI::MemoryPointer.new(:pointer, 1000)
  str_ptrs = 1000.times.map { ::FFI::MemoryPointer.new(:char, 255) }
  out_strs.write_array_of_pointer(str_ptrs)
  check_result FFI.LGBM_DatasetGetFeatureNames(handle_pointer, out_strs, num_feature_names)
  str_ptrs[0, num_feature_names.read_int].map(&:read_string)
end

# Remembers the labels (so a rebuild can re-apply them) and sets the "label" field.
def label=(label)
  @label = label
  set_field("label", label)
end

# Remembers the weights (so a rebuild can re-apply them) and sets the "weight" field.
def weight=(weight)
  @weight = weight
  set_field("weight", weight)
end

# Remembers the group sizes (so a rebuild can re-apply them) and sets the
# "group" field as int32 values.
def group=(group)
  @group = group
  set_field("group", group, type: :int32)
end

# Remembers the feature names and sets them on the native dataset.
#
# @param feature_names [Array<String>] one name per feature
def feature_names=(feature_names)
  @feature_names = feature_names
  # Hold the string buffers in a local so they stay referenced from Ruby
  # until the native call below has returned; otherwise the only thing
  # pointing at them would be raw addresses inside `c_feature_names`.
  str_ptrs = feature_names.map { |v| ::FFI::MemoryPointer.from_string(v) }
  c_feature_names = ::FFI::MemoryPointer.new(:pointer, feature_names.size)
  c_feature_names.write_array_of_pointer(str_ptrs)
  check_result FFI.LGBM_DatasetSetFeatureNames(handle_pointer, c_feature_names, feature_names.size)
end

# Switches the reference dataset. When the reference actually changes, the
# current native handle is released and the dataset is built again.
#
# TODO only update reference if not in chain
def reference=(reference)
  return if reference == @reference

  @reference = reference
  free_handle
  construct
end

# @return [Integer] the row count reported by the native dataset
def num_data
  out = ::FFI::MemoryPointer.new(:int)
  check_result FFI.LGBM_DatasetGetNumData(handle_pointer, out)
  out.read_int
end

# --- diff hunk boundary (@@ -121,9 +99,64 @@) ---
# The lines between `num_data` and this point are not part of the excerpt.
# The hunk opens on the tail of a method whose `def` line is not visible, so
# that tail is kept verbatim as a comment instead of being guessed at:
#
#     # must use proc instead of stabby lambda
#     proc { FFI.LGBM_DatasetFree(pointer) }
#   end

private

# Creates the native dataset from the state stored on the instance, then
# re-applies label, weight, group and feature names when they are set.
#
# Creation path, in order of precedence:
# * `@used_indices` given  - subset of the reference dataset
# * `@data` is a String    - loaded from that file path
# * anything else          - flattened into a row-major double matrix
def construct
  data = @data
  used_indices = @used_indices

  # TODO stringify params
  # Work on a shallow copy: this method adds "categorical_feature" below and
  # passes the hash to `set_verbosity`, and neither should write into the
  # hash object the caller passed to the constructor.
  params = (@params || {}).dup
  if @categorical_feature != "auto" && @categorical_feature.any?
    # An explicit "categorical_feature" entry in params takes precedence.
    params["categorical_feature"] ||= @categorical_feature.join(",")
  end
  set_verbosity(params)

  @handle = ::FFI::MemoryPointer.new(:pointer)
  parameters = params_str(params)
  reference = @reference.handle_pointer if @reference
  if used_indices
    used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count)
    used_row_indices.write_array_of_int32(used_indices)
    check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle)
  elsif data.is_a?(String)
    check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle)
  else
    if matrix?(data)
      nrow = data.row_count
      ncol = data.column_count
      flat_data = data.to_a.flatten
    elsif daru?(data)
      nrow, ncol = data.shape
      flat_data = data.map_rows(&:to_a).flatten
    elsif narray?(data)
      nrow, ncol = data.shape
      flat_data = data.flatten.to_a
    else
      # Fallback: nested enumerable, one inner collection per row.
      nrow = data.count
      ncol = data.first.count
      flat_data = data.flatten
    end

    handle_missing(flat_data)
    c_data = ::FFI::MemoryPointer.new(:double, nrow * ncol)
    c_data.write_array_of_double(flat_data)
    check_result FFI.LGBM_DatasetCreateFromMat(c_data, 1, nrow, ncol, 1, parameters, reference, @handle)
  end
  # No finalizer is registered for subsets created from `used_indices`.
  ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer)) unless used_indices

  self.label = @label if @label
  self.weight = @weight if @weight
  self.group = @group if @group
  self.feature_names = @feature_names if @feature_names
end

# Releases the native dataset immediately and removes the finalizer, so the
# handle is not freed a second time when this object is collected.
def free_handle
  FFI.LGBM_DatasetFree(handle_pointer)
  ObjectSpace.undefine_finalizer(self)
end

# Dumps the dataset as text to `filename` through the native library.
def dump_text(filename)
  check_result FFI.LGBM_DatasetDumpText(handle_pointer, filename)
end