lib/lightgbm/dataset.rb in lightgbm-0.1.3 vs lib/lightgbm/dataset.rb in lightgbm-0.1.4

- old
+ new

@@ -1,26 +1,26 @@ module LightGBM class Dataset attr_reader :data, :params - def initialize(data, label: nil, weight: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto") + def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto") @data = data # TODO stringify params params ||= {} params["categorical_feature"] ||= categorical_feature.join(",") if categorical_feature != "auto" set_verbosity(params) @handle = ::FFI::MemoryPointer.new(:pointer) parameters = params_str(params) reference = reference.handle_pointer if reference - if data.is_a?(String) - check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle) - elsif used_indices + if used_indices used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count) used_row_indices.put_array_of_int32(0, used_indices) check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle) + elsif data.is_a?(String) + check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle) else if matrix?(data) nrow = data.row_count ncol = data.column_count flat_data = data.to_a.flatten @@ -38,25 +38,37 @@ c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol) c_data.put_array_of_float(0, flat_data) check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, nrow, ncol, 1, parameters, reference, @handle) end - # causes "Stack consistency error" - # ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer)) + ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer)) unless used_indices - set_field("label", label) if label - set_field("weight", weight) if weight + self.label = label if label + self.weight = weight if weight + self.group = group if group end def label field("label") end def weight field("weight") end + def label=(label) + set_field("label", label) + end + + def weight=(weight) + set_field("weight", weight) + end + + def group=(group) + set_field("group", group, type: :int32) + end + def num_data out = ::FFI::MemoryPointer.new(:int) check_result FFI.LGBM_DatasetGetNumData(handle_pointer, out) out.read_int end @@ -69,13 +81,14 @@ def save_binary(filename) check_result FFI.LGBM_DatasetSaveBinary(handle_pointer, filename) end - def dump_text(filename) - check_result FFI.LGBM_DatasetDumpText(handle_pointer, filename) - end + # not released yet + # def dump_text(filename) + # check_result FFI.LGBM_DatasetDumpText(handle_pointer, filename) + # end def subset(used_indices, params: nil) # categorical_feature passed via params params ||= self.params Dataset.new(nil, @@ -83,18 +96,19 @@ reference: self, used_indices: used_indices ) end - def self.finalize(pointer) - -> { FFI.LGBM_DatasetFree(pointer) } - end - def handle_pointer @handle.read_pointer end + def self.finalize(pointer) + # must use proc instead of stabby lambda + proc { FFI.LGBM_DatasetFree(pointer) } + end + private def field(field_name) num_data = self.num_data out_len = ::FFI::MemoryPointer.new(:int) @@ -102,14 +116,20 @@ out_type = ::FFI::MemoryPointer.new(:int) check_result FFI.LGBM_DatasetGetField(handle_pointer, field_name, out_len, out_ptr, out_type) out_ptr.read_pointer.read_array_of_float(num_data) end - def set_field(field_name, data) + def set_field(field_name, data, type: :float) data = data.to_a unless data.is_a?(Array) - c_data = ::FFI::MemoryPointer.new(:float, data.count) - c_data.put_array_of_float(0, data) - check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 0) + if type == :int32 + c_data = ::FFI::MemoryPointer.new(:int32, data.count) + c_data.put_array_of_int32(0, data) + check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 2) + else + c_data = ::FFI::MemoryPointer.new(:float, data.count) + c_data.put_array_of_float(0, data) + check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 0) + end end def matrix?(data) defined?(Matrix) && data.is_a?(Matrix) end