lib/xlearn/dmatrix.rb in xlearn-0.1.0 vs lib/xlearn/dmatrix.rb in xlearn-0.1.1
- old
+ new
@@ -3,18 +3,34 @@
include Utils
def initialize(data, label: nil)
@handle = ::FFI::MemoryPointer.new(:pointer)
- nrow = data.count
- ncol = data.first.count
+ if matrix?(data)
+ nrow = data.row_count
+ ncol = data.column_count
+ flat_data = data.to_a.flatten
+ elsif daru?(data)
+ nrow, ncol = data.shape
+ flat_data = data.map_rows(&:to_a).flatten
+ elsif narray?(data)
+ nrow, ncol = data.shape
+ # TODO convert to SFloat and pass pointer
+ # for better performance
+ flat_data = data.flatten.to_a
+ else
+ nrow = data.count
+ ncol = data.first.count
+ flat_data = data.flatten
+ end
- c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol)
- c_data.put_array_of_float(0, data.flatten)
+ c_data = ::FFI::MemoryPointer.new(:float, flat_data.size)
+ c_data.put_array_of_float(0, flat_data)
if label
- c_label = ::FFI::MemoryPointer.new(:float, nrow)
+ label = label.to_a
+ c_label = ::FFI::MemoryPointer.new(:float, label.size)
c_label.put_array_of_float(0, label)
end
# TODO support this
field_map = nil
@@ -28,8 +44,22 @@
end
def self.finalize(pointer)
# must use proc instead of stabby lambda
proc { FFI.XlearnDataFree(pointer) }
+ end
+
+ private
+
+ def matrix?(data)
+ defined?(Matrix) && data.is_a?(Matrix)
+ end
+
+ def daru?(data)
+ defined?(Daru::DataFrame) && data.is_a?(Daru::DataFrame)
+ end
+
+ def narray?(data)
+ defined?(Numo::NArray) && data.is_a?(Numo::NArray)
end
end
end