lib/svmkit/dataset.rb in svmkit-0.7.0 vs lib/svmkit/dataset.rb in svmkit-0.7.1
- old
+ new
@@ -1,7 +1,9 @@
# frozen_string_literal: true
+require 'csv'
+
module SVMKit
# Module for loading and saving a dataset file.
module Dataset
class << self
# Load a dataset with the libsvm file format into Numo::NArray.
@@ -14,15 +16,15 @@
# and (n_samples) vector for labels or target values.
def load_libsvm_file(filename, zero_based: false)
  ftvecs = []
  labels = []
  n_features = 0
  # Stream the file one line at a time and tokenize on runs of whitespace.
  # Parsing the file as space-separated CSV is avoided on purpose: CSV turns
  # every extra, leading or trailing space into a nil field (which the line
  # parser would treat as end-of-line), does not accept tabs as separators,
  # and raises on a stray double quote.
  File.foreach(filename) do |line|
    tokens = line.split
    # A blank line carries no label; skip it instead of failing on a nil label.
    next if tokens.empty?

    label, ftvec, max_idx = parse_libsvm_line(tokens, zero_based)
    labels.push(label)
    ftvecs.push(ftvec)
    # Track the largest feature index seen across all lines.
    n_features = max_idx if n_features < max_idx
  end
  [convert_to_matrix(ftvecs, n_features), Numo::NArray.asarray(labels)]
end
# Dump the dataset with the libsvm file format.
@@ -46,19 +48,20 @@
end
private
# Parse the tokens of a single libsvm-format line.
#
# @param line [Array<String, nil>] Tokens of one line: the label token followed by
#   "index:value" feature tokens. nil or empty tokens (as produced by a delimiter-based
#   reader for repeated separators) are ignored. The array is not modified.
# @param zero_based [Boolean] false if feature indices in the file start at 1,
#   in which case they are shifted down by one.
# @return [Array] [label, ftvec, max_idx] where ftvec is an array of [index, value]
#   pairs and max_idx is the largest (shifted) index on the line, or -1 if the line
#   has no features.
def parse_libsvm_line(line, zero_based)
  # Drop nil/empty tokens up front so a gap between tokens cannot cut the
  # feature vector short.
  tokens = line.reject { |tok| tok.nil? || tok.empty? }
  label = parse_label(tokens.first)
  adj_idx = zero_based == false ? 1 : 0
  max_idx = -1
  ftvec = []
  tokens.drop(1).each do |el|
    idx, val = el.split(':')
    idx = idx.to_i - adj_idx
    # Keep integer-looking values as Integer; everything else becomes Float.
    val = val.to_i.to_s == val ? val.to_i : val.to_f
    max_idx = idx if max_idx < idx
    ftvec.push([idx, val])
  end
  [label, ftvec, max_idx]
end
def parse_label(label)
lbl_arr = label.split(',').map { |lbl| lbl.to_i.to_s == lbl ? lbl.to_i : lbl.to_f }