lib/xgb/dmatrix.rb in xgb-0.1.0 vs lib/xgb/dmatrix.rb in xgb-0.1.1

- old
+ new

@@ -1,37 +1,118 @@ module Xgb class DMatrix - attr_reader :data, :label, :weight + attr_reader :data def initialize(data, label: nil, weight: nil, missing: Float::NAN) @data = data - @label = label - @weight = weight - c_data = ::FFI::MemoryPointer.new(:float, data.count * data.first.count) - c_data.put_array_of_float(0, data.flatten) @handle = ::FFI::MemoryPointer.new(:pointer) - check_result FFI.XGDMatrixCreateFromMat(c_data, data.count, data.first.count, missing, @handle) + if data + if matrix?(data) + nrow = data.row_count + ncol = data.column_count + flat_data = data.to_a.flatten + elsif daru?(data) + nrow, ncol = data.shape + flat_data = data.each_vector.map(&:to_a).flatten + elsif narray?(data) + nrow, ncol = data.shape + flat_data = data.flatten.to_a + else + nrow = data.count + ncol = data.first.count + flat_data = data.flatten + end + + c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol) + c_data.put_array_of_float(0, flat_data) + check_result FFI.XGDMatrixCreateFromMat(c_data, nrow, ncol, missing, @handle) + end + set_float_info("label", label) if label + set_float_info("weight", weight) if weight end + def label + float_info("label") + end + + def weight + float_info("weight") + end + + def num_row + out = ::FFI::MemoryPointer.new(:ulong) + check_result FFI.XGDMatrixNumRow(handle_pointer, out) + out.read_ulong + end + def num_col - out = ::FFI::MemoryPointer.new(:long) - FFI.XGDMatrixNumCol(handle_pointer, out) - out.read_long + out = ::FFI::MemoryPointer.new(:ulong) + check_result FFI.XGDMatrixNumCol(handle_pointer, out) + out.read_ulong end + def slice(rindex) + res = DMatrix.new(nil) + idxset = ::FFI::MemoryPointer.new(:int, rindex.count) + idxset.put_array_of_int(0, rindex) + check_result FFI.XGDMatrixSliceDMatrix(handle_pointer, idxset, rindex.size, res.handle) + res + end + + def save_binary(fname, silent: true) + check_result FFI.XGDMatrixSaveBinary(handle_pointer, fname, silent ? 1 : 0) + end + + def handle + @handle + end + def handle_pointer @handle.read_pointer end private def set_float_info(field, data) - c_data = ::FFI::MemoryPointer.new(:float, data.count) + data = + if matrix?(data) + data.to_a[0] + elsif daru_vector?(data) || narray?(data) + data.to_a + else + data + end + + c_data = ::FFI::MemoryPointer.new(:float, data.size) c_data.put_array_of_float(0, data) check_result FFI.XGDMatrixSetFloatInfo(handle_pointer, field.to_s, c_data, data.size) + end + + def float_info(field) + num_row ||= num_row() + out_len = ::FFI::MemoryPointer.new(:int) + out_dptr = ::FFI::MemoryPointer.new(:float, num_row) + check_result FFI.XGDMatrixGetFloatInfo(handle_pointer, field, out_len, out_dptr) + out_dptr.read_pointer.read_array_of_float(num_row) + end + + def matrix?(data) + defined?(Matrix) && data.is_a?(Matrix) + end + + def daru?(data) + defined?(Daru::DataFrame) && data.is_a?(Daru::DataFrame) + end + + def daru_vector?(data) + defined?(Daru::Vector) && data.is_a?(Daru::Vector) + end + + def narray?(data) + defined?(Numo::NArray) && data.is_a?(Numo::NArray) end include Utils end end