lib/svmkit/dataset.rb in svmkit-0.7.0 vs lib/svmkit/dataset.rb in svmkit-0.7.1

- old
+ new

@@ -1,7 +1,9 @@
 # frozen_string_literal: true
 
+require 'csv'
+
 module SVMKit
   # Module for loading and saving a dataset file.
   module Dataset
     class << self
       # Load a dataset with the libsvm file format into Numo::NArray.
@@ -14,15 +16,15 @@
       # and (n_samples) vector for labels or target values.
       def load_libsvm_file(filename, zero_based: false)
         ftvecs = []
         labels = []
         n_features = 0
-        File.read(filename).split("\n").each do |line|
+        CSV.foreach(filename, col_sep: "\s", headers: false) do |line|
           label, ftvec, max_idx = parse_libsvm_line(line, zero_based)
           labels.push(label)
           ftvecs.push(ftvec)
-          n_features = [n_features, max_idx].max
+          n_features = max_idx if n_features < max_idx
         end
         [convert_to_matrix(ftvecs, n_features), Numo::NArray.asarray(labels)]
       end
 
       # Dump the dataset with the libsvm file format.
@@ -46,19 +48,20 @@
       end
 
       private
 
       def parse_libsvm_line(line, zero_based)
-        tokens = line.split
-        label = parse_label(tokens.shift)
-        ftvec = tokens.map do |el|
+        label = parse_label(line.shift)
+        adj_idx = zero_based == false ? 1 : 0
+        max_idx = -1
+        ftvec = []
+        while (el = line.shift)
           idx, val = el.split(':')
-          idx = idx.to_i - (zero_based == false ? 1 : 0)
+          idx = idx.to_i - adj_idx
           val = val.to_i.to_s == val ? val.to_i : val.to_f
-          [idx, val]
+          max_idx = idx if max_idx < idx
+          ftvec.push([idx, val])
         end
-        max_idx = ftvec.map { |el| el[0] }.max
-        max_idx ||= 0
         [label, ftvec, max_idx]
       end
       def parse_label(label)
         lbl_arr = label.split(',').map { |lbl| lbl.to_i.to_s == lbl ? lbl.to_i : lbl.to_f }