lib/xlearn/dmatrix.rb in xlearn-0.1.0 vs lib/xlearn/dmatrix.rb in xlearn-0.1.1

- old
+ new

@@ -3,18 +3,34 @@ include Utils def initialize(data, label: nil) @handle = ::FFI::MemoryPointer.new(:pointer) - nrow = data.count - ncol = data.first.count + if matrix?(data) + nrow = data.row_count + ncol = data.column_count + flat_data = data.to_a.flatten + elsif daru?(data) + nrow, ncol = data.shape + flat_data = data.map_rows(&:to_a).flatten + elsif narray?(data) + nrow, ncol = data.shape + # TODO convert to SFloat and pass pointer + # for better performance + flat_data = data.flatten.to_a + else + nrow = data.count + ncol = data.first.count + flat_data = data.flatten + end - c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol) - c_data.put_array_of_float(0, data.flatten) + c_data = ::FFI::MemoryPointer.new(:float, flat_data.size) + c_data.put_array_of_float(0, flat_data) if label - c_label = ::FFI::MemoryPointer.new(:float, nrow) + label = label.to_a + c_label = ::FFI::MemoryPointer.new(:float, label.size) c_label.put_array_of_float(0, label) end # TODO support this field_map = nil @@ -28,8 +44,22 @@ end def self.finalize(pointer) # must use proc instead of stabby lambda proc { FFI.XlearnDataFree(pointer) } + end + + private + + def matrix?(data) + defined?(Matrix) && data.is_a?(Matrix) + end + + def daru?(data) + defined?(Daru::DataFrame) && data.is_a?(Daru::DataFrame) + end + + def narray?(data) + defined?(Numo::NArray) && data.is_a?(Numo::NArray) end end end