lib/lightgbm/dataset.rb in lightgbm-0.1.7 vs lib/lightgbm/dataset.rb in lightgbm-0.1.8
- old
+ new
@@ -2,93 +2,71 @@
class Dataset
attr_reader :data, :params
# Creates a Dataset from data plus optional label/weight/group fields,
# params, a reference dataset, row indices, categorical features and
# feature names.
#
# Reading this diff: lines marked `-` are the 0.1.7 body, which built the
# native handle inline; lines marked `+` are the 0.1.8 body, which stores
# every argument in an instance variable and then calls the private
# construct method.
def initialize(data, label: nil, weight: nil, group: nil, params: nil, reference: nil, used_indices: nil, categorical_feature: "auto", feature_names: nil)
@data = data
# 0.1.8: remember every argument so the native dataset can be built again later.
+ @label = label
+ @weight = weight
+ @group = group
+ @params = params
+ @reference = reference
+ @used_indices = used_indices
+ @categorical_feature = categorical_feature
+ @feature_names = feature_names
# 0.1.7 (removed): inline construction of the native handle.
- # TODO stringify params
- params ||= {}
- if categorical_feature != "auto" && categorical_feature.any?
- params["categorical_feature"] ||= categorical_feature.join(",")
- end
- set_verbosity(params)
-
- @handle = ::FFI::MemoryPointer.new(:pointer)
- parameters = params_str(params)
- reference = reference.handle_pointer if reference
- if used_indices
- used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count)
- used_row_indices.write_array_of_int32(used_indices)
- check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle)
- elsif data.is_a?(String)
- check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle)
- else
- if matrix?(data)
- nrow = data.row_count
- ncol = data.column_count
- flat_data = data.to_a.flatten
- elsif daru?(data)
- nrow, ncol = data.shape
- flat_data = data.map_rows(&:to_a).flatten
- elsif narray?(data)
- nrow, ncol = data.shape
- flat_data = data.flatten.to_a
- else
- nrow = data.count
- ncol = data.first.count
- flat_data = data.flatten
- end
-
- handle_missing(flat_data)
- c_data = ::FFI::MemoryPointer.new(:double, nrow * ncol)
- c_data.write_array_of_double(flat_data)
- check_result FFI.LGBM_DatasetCreateFromMat(c_data, 1, nrow, ncol, 1, parameters, reference, @handle)
- end
- ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer)) unless used_indices
-
- self.label = label if label
- self.weight = weight if weight
- self.group = group if group
- self.feature_names = feature_names if feature_names
# 0.1.8: the whole native build is a single call.
+ construct
end
# Reads the "label" field of this dataset through the field helper.
def label
  field_name = "label"
  field(field_name)
end
# Reads the "weight" field of this dataset through the field helper.
def weight
  field_name = "weight"
  field(field_name)
end
# 0.1.7 label setter (removed lines): forwards label to set_field under
# the "label" key and stores nothing on the Ruby side.
- def label=(label)
- set_field("label", label)
- end
-
# Returns the feature names held by the native dataset as an array of strings.
def feature_names
  # must preallocate space
  # NOTE(review): room for 1000 names of 255 bytes each; what happens beyond
  # those limits depends on the native call - confirm against the C API.
  slot_count = 1000
  slot_bytes = 255

  count_ptr = ::FFI::MemoryPointer.new(:int)
  names_ptr = ::FFI::MemoryPointer.new(:pointer, slot_count)
  buffers = Array.new(slot_count) { ::FFI::MemoryPointer.new(:char, slot_bytes) }
  names_ptr.write_array_of_pointer(buffers)

  status = FFI.LGBM_DatasetGetFeatureNames(handle_pointer, names_ptr, count_ptr)
  check_result(status)

  # Only the first count_ptr entries were reported as filled in.
  buffers.slice(0, count_ptr.read_int).map(&:read_string)
end
# 0.1.8 label setter (added lines): keeps label in @label, then forwards it
# to set_field under the "label" key. construct reads @label to re-apply it.
+ def label=(label)
+ @label = label
+ set_field("label", label)
+ end
+
# Sets the per-row weights: forwards weight to set_field under the "weight"
# key. The `+` line (0.1.8) also keeps the value in @weight, which construct
# reads to re-apply it.
def weight=(weight)
+ @weight = weight
set_field("weight", weight)
end
# Sets the group field: forwards group to set_field under the "group" key,
# asking for the :int32 element type. The `+` line (0.1.8) also keeps the
# value in @group, which construct reads to re-apply it.
def group=(group)
+ @group = group
set_field("group", group, type: :int32)
end
# Stores the feature names on this object and pushes them to the native
# dataset through LGBM_DatasetSetFeatureNames.
#
# feature_names - array of strings, one per feature.
def feature_names=(feature_names)
  @feature_names = feature_names
  c_feature_names = ::FFI::MemoryPointer.new(:pointer, feature_names.size)
  # Keep the per-string pointers referenced for the whole call: the pointer
  # array only stores raw addresses, so without this local the string
  # buffers could be garbage collected (and their memory released) before
  # the native function reads them.
  name_ptrs = feature_names.map { |v| ::FFI::MemoryPointer.from_string(v) }
  c_feature_names.write_array_of_pointer(name_ptrs)
  check_result FFI.LGBM_DatasetSetFeatureNames(handle_pointer, c_feature_names, feature_names.size)
end
# 0.1.8 addition: swaps the reference dataset. When the new reference
# differs from the stored one (compared with !=), the current native handle
# is freed and the dataset is built again through construct; an equal
# reference is a no-op.
# NOTE(review): free_handle runs before construct, so if construct raises
# the old handle is already gone - confirm callers can tolerate that.
+ # TODO only update reference if not in chain
+ def reference=(reference)
+ if reference != @reference
+ @reference = reference
+ free_handle
+ construct
+ end
+ end
+
# Returns the integer that LGBM_DatasetGetNumData writes for this handle.
def num_data
  count_ptr = ::FFI::MemoryPointer.new(:int)
  status = FFI.LGBM_DatasetGetNumData(handle_pointer, count_ptr)
  check_result(status)
  count_ptr.read_int
end
@@ -121,9 +99,64 @@
# must use proc instead of stabby lambda
proc { FFI.LGBM_DatasetFree(pointer) }
end
private
+
# 0.1.8 addition: builds the native dataset from the values stored by
# initialize and stores the resulting handle in @handle. Depending on the
# inputs it takes a row subset of the reference, loads from a file path,
# or copies in-memory data; afterwards it re-applies label, weight, group
# and feature names when they are set.
+ def construct
+ data = @data
+ used_indices = @used_indices
+
+ # TODO stringify params
# NOTE(review): when @params is set, the writes below go into that same
# Hash object, so the caller's hash gains the extra keys.
+ params = @params || {}
# ||= leaves an existing "categorical_feature" entry untouched.
+ if @categorical_feature != "auto" && @categorical_feature.any?
+ params["categorical_feature"] ||= @categorical_feature.join(",")
+ end
+ set_verbosity(params)
+
# Pointer-sized slot passed as the last argument to each creation call below.
+ @handle = ::FFI::MemoryPointer.new(:pointer)
+ parameters = params_str(params)
# Local stays nil when no reference dataset was given.
+ reference = @reference.handle_pointer if @reference
+ if used_indices
# Subset: copy the row indices into an int32 buffer for the native call.
+ used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count)
+ used_row_indices.write_array_of_int32(used_indices)
+ check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle)
+ elsif data.is_a?(String)
# A String is handed to LGBM_DatasetCreateFromFile as-is.
+ check_result FFI.LGBM_DatasetCreateFromFile(data, parameters, reference, @handle)
+ else
# In-memory data: work out rows/columns and flatten to one array, with a
# branch per container type and a plain nested-array fallback.
+ if matrix?(data)
+ nrow = data.row_count
+ ncol = data.column_count
+ flat_data = data.to_a.flatten
+ elsif daru?(data)
+ nrow, ncol = data.shape
+ flat_data = data.map_rows(&:to_a).flatten
+ elsif narray?(data)
+ nrow, ncol = data.shape
+ flat_data = data.flatten.to_a
+ else
+ nrow = data.count
+ ncol = data.first.count
+ flat_data = data.flatten
+ end
+
+ handle_missing(flat_data)
+ c_data = ::FFI::MemoryPointer.new(:double, nrow * ncol)
+ c_data.write_array_of_double(flat_data)
# The two literal 1s are presumably the float64 data-type code and the
# row-major flag - confirm against the LightGBM C API.
+ check_result FFI.LGBM_DatasetCreateFromMat(c_data, 1, nrow, ncol, 1, parameters, reference, @handle)
+ end
# Register a finalizer for the new handle, except for datasets built from
# used_indices.
+ ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer)) unless used_indices
+
# Re-apply the stored fields through the public setters.
+ self.label = @label if @label
+ self.weight = @weight if @weight
+ self.group = @group if @group
+ self.feature_names = @feature_names if @feature_names
+ end
+
# 0.1.8 addition: frees the native dataset behind handle_pointer and then
# removes the finalizers registered on this object.
# NOTE(review): the result of LGBM_DatasetFree is not passed to
# check_result, unlike the other FFI calls in this file - confirm intended.
+ def free_handle
+ FFI.LGBM_DatasetFree(handle_pointer)
+ ObjectSpace.undefine_finalizer(self)
+ end
# Asks LGBM_DatasetDumpText to dump this dataset to filename and runs the
# returned status through check_result.
def dump_text(filename)
  status = FFI.LGBM_DatasetDumpText(handle_pointer, filename)
  check_result(status)
end