lib/lightgbm/dataset.rb in lightgbm-0.1.2 vs lib/lightgbm/dataset.rb in lightgbm-0.1.3
- old
+ new
@@ -18,13 +18,29 @@
elsif used_indices
used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count)
used_row_indices.put_array_of_int32(0, used_indices)
check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle)
else
- c_data = ::FFI::MemoryPointer.new(:float, data.count * data.first.count)
- c_data.put_array_of_float(0, data.flatten)
- check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, data.count, data.first.count, 1, parameters, reference, @handle)
+ if matrix?(data)
+ nrow = data.row_count
+ ncol = data.column_count
+ flat_data = data.to_a.flatten
+ elsif daru?(data)
+ nrow, ncol = data.shape
+ flat_data = data.each_vector.map(&:to_a).flatten
+ elsif narray?(data)
+ nrow, ncol = data.shape
+ flat_data = data.flatten.to_a
+ else
+ nrow = data.count
+ ncol = data.first.count
+ flat_data = data.flatten
+ end
+
+ c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol)
+ c_data.put_array_of_float(0, flat_data)
+ check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, nrow, ncol, 1, parameters, reference, @handle)
end
# causes "Stack consistency error"
# ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer))
set_field("label", label) if label
@@ -87,12 +103,25 @@
check_result FFI.LGBM_DatasetGetField(handle_pointer, field_name, out_len, out_ptr, out_type)
out_ptr.read_pointer.read_array_of_float(num_data)
end
def set_field(field_name, data)
+ data = data.to_a unless data.is_a?(Array)
c_data = ::FFI::MemoryPointer.new(:float, data.count)
c_data.put_array_of_float(0, data)
check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 0)
+ end
+
+ def matrix?(data)
+ defined?(Matrix) && data.is_a?(Matrix)
+ end
+
+ def daru?(data)
+ defined?(Daru::DataFrame) && data.is_a?(Daru::DataFrame)
+ end
+
+ def narray?(data)
+ defined?(Numo::NArray) && data.is_a?(Numo::NArray)
end
include Utils
end
end