lib/lightgbm/dataset.rb in lightgbm-0.1.3 vs lib/lightgbm/dataset.rb in lightgbm-0.1.4
- old
+ new
@@ -1,26 +1,26 @@
module LightGBM
class Dataset
attr_reader :data, :params
- def initialize(data, label: nil, weight: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto")
+ def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto")
@data = data
# TODO stringify params
params ||= {}
params["categorical_feature"] ||= categorical_feature.join(",") if categorical_feature != "auto"
set_verbosity(params)
@handle = ::FFI::MemoryPointer.new(:pointer)
parameters = params_str(params)
reference = reference.handle_pointer if reference
- if data.is_a?(String)
- check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle)
- elsif used_indices
+ if used_indices
used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count)
used_row_indices.put_array_of_int32(0, used_indices)
check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle)
+ elsif data.is_a?(String)
+ check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle)
else
if matrix?(data)
nrow = data.row_count
ncol = data.column_count
flat_data = data.to_a.flatten
@@ -38,25 +38,37 @@
c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol)
c_data.put_array_of_float(0, flat_data)
check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, nrow, ncol, 1, parameters, reference, @handle)
end
- # causes "Stack consistency error"
- # ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
+ ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer)) unless used_indices
- set_field("label", label) if label
- set_field("weight", weight) if weight
+ self.label = label if label
+ self.weight = weight if weight
+ self.group = group if group
end
# Returns whatever the private `field` helper reads for the "label" field.
def label
field("label")
end
# Returns whatever the private `field` helper reads for the "weight" field.
def weight
field("weight")
end
+ def label=(label) # writes `label` into the "label" field via set_field (default type)
+ set_field("label", label)
+ end
+
+ def weight=(weight) # writes `weight` into the "weight" field via set_field (default type)
+ set_field("weight", weight)
+ end
+
+ def group=(group) # writes `group` into the "group" field via set_field
+ set_field("group", group, type: :int32) # explicitly requests the int32 path instead of the default type
+ end
+
# Queries LGBM_DatasetGetNumData for this dataset's handle and returns the
# integer it writes into the output buffer.
def num_data
  count_ptr = ::FFI::MemoryPointer.new(:int)
  status = FFI.LGBM_DatasetGetNumData(handle_pointer, count_ptr)
  check_result(status)
  count_ptr.read_int
end
@@ -69,13 +81,14 @@
# Passes this dataset's handle and `filename` to LGBM_DatasetSaveBinary.
# The call's result goes to check_result (defined elsewhere; presumably raises on failure — confirm).
def save_binary(filename)
check_result FFI.LGBM_DatasetSaveBinary(handle_pointer, filename)
end
- def dump_text(filename)
- check_result FFI.LGBM_DatasetDumpText(handle_pointer, filename)
- end
+ # not released yet
+ # def dump_text(filename)
+ # check_result FFI.LGBM_DatasetDumpText(handle_pointer, filename)
+ # end
def subset(used_indices, params: nil)
# categorical_feature passed via params
params ||= self.params
Dataset.new(nil,
@@ -83,18 +96,19 @@
reference: self,
used_indices: used_indices
)
end
- def self.finalize(pointer)
- -> { FFI.LGBM_DatasetFree(pointer) }
- end
-
# Reads the pointer value currently stored in the @handle memory slot.
# NOTE(review): presumably the native dataset handle filled in by the LGBM_Dataset* create calls — confirm.
def handle_pointer
@handle.read_pointer
end
+ def self.finalize(pointer) # builds the callable that frees `pointer` with LGBM_DatasetFree
+ # must use proc instead of stabby lambda (likely because the finalizer is invoked with the object id, which a zero-argument lambda would reject on arity — confirm)
+ proc { FFI.LGBM_DatasetFree(pointer) }
+ end
+
private
def field(field_name)
num_data = self.num_data
out_len = ::FFI::MemoryPointer.new(:int)
@@ -102,14 +116,20 @@
out_type = ::FFI::MemoryPointer.new(:int)
check_result FFI.LGBM_DatasetGetField(handle_pointer, field_name, out_len, out_ptr, out_type)
out_ptr.read_pointer.read_array_of_float(num_data)
end
- def set_field(field_name, data)
+ def set_field(field_name, data, type: :float) # copies `data` into a native buffer and passes it to LGBM_DatasetSetField; `type: :int32` selects the int32 path, any other value the float path
data = data.to_a unless data.is_a?(Array) # non-Array input is converted with to_a
- c_data = ::FFI::MemoryPointer.new(:float, data.count)
- c_data.put_array_of_float(0, data)
- check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 0)
+ if type == :int32
+ c_data = ::FFI::MemoryPointer.new(:int32, data.count)
+ c_data.put_array_of_int32(0, data)
+ check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 2) # NOTE(review): trailing 2 is presumably the C API dtype code for int32 — confirm
+ else
+ c_data = ::FFI::MemoryPointer.new(:float, data.count)
+ c_data.put_array_of_float(0, data)
+ check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 0) # NOTE(review): trailing 0 is presumably the C API dtype code for float32 — confirm
+ end
end
# Tells whether `data` is an instance of Matrix.
# Yields nil (not false) when the Matrix constant has not been loaded at all,
# so the check never raises NameError.
def matrix?(data)
  return nil unless defined?(Matrix)

  data.is_a?(Matrix)
end