lib/lightgbm/dataset.rb in lightgbm-0.1.2 vs lib/lightgbm/dataset.rb in lightgbm-0.1.3

- old
+ new

@@ -18,13 +18,29 @@ elsif used_indices used_row_indices = ::FFI::MemoryPointer.new(:int32, used_indices.count) used_row_indices.put_array_of_int32(0, used_indices) check_result FFI.LGBM_DatasetGetSubset(reference, used_row_indices, used_indices.count, parameters, @handle) else - c_data = ::FFI::MemoryPointer.new(:float, data.count * data.first.count) - c_data.put_array_of_float(0, data.flatten) - check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, data.count, data.first.count, 1, parameters, reference, @handle) + if matrix?(data) + nrow = data.row_count + ncol = data.column_count + flat_data = data.to_a.flatten + elsif daru?(data) + nrow, ncol = data.shape + flat_data = data.each_vector.map(&:to_a).flatten + elsif narray?(data) + nrow, ncol = data.shape + flat_data = data.flatten.to_a + else + nrow = data.count + ncol = data.first.count + flat_data = data.flatten + end + + c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol) + c_data.put_array_of_float(0, flat_data) + check_result FFI.LGBM_DatasetCreateFromMat(c_data, 0, nrow, ncol, 1, parameters, reference, @handle) end # causes "Stack consistency error" # ObjectSpace.define_finalizer(self, self.class.finalize(handle_pointer)) set_field("label", label) if label @@ -87,12 +103,25 @@ check_result FFI.LGBM_DatasetGetField(handle_pointer, field_name, out_len, out_ptr, out_type) out_ptr.read_pointer.read_array_of_float(num_data) end def set_field(field_name, data) + data = data.to_a unless data.is_a?(Array) c_data = ::FFI::MemoryPointer.new(:float, data.count) c_data.put_array_of_float(0, data) check_result FFI.LGBM_DatasetSetField(handle_pointer, field_name, c_data, data.count, 0) + end + + def matrix?(data) + defined?(Matrix) && data.is_a?(Matrix) + end + + def daru?(data) + defined?(Daru::DataFrame) && data.is_a?(Daru::DataFrame) + end + + def narray?(data) + defined?(Numo::NArray) && data.is_a?(Numo::NArray) end include Utils end end