lib/svmkit/dataset.rb in svmkit-0.3.1 vs lib/svmkit/dataset.rb in svmkit-0.3.2
- old
+ new
@@ -31,26 +31,27 @@
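- # @param labels [Numo::NArray] (shape: [n_samples]) matrix consisting of labels or target values.
+ # @param labels [Numo::NArray] (shape: [n_samples] or [n_samples, n_labels]) vector or matrix consisting of labels or target values.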
# @param filename [String] A path to the output libsvm file.
# @param zero_based [Boolean] Whether the column index starts from 0 (true) or 1 (false).
def dump_libsvm_file(data, labels, filename, zero_based: false)
n_samples = [data.shape[0], labels.shape[0]].min
+ # A second axis on labels means every sample carries several labels/target values.
+ multi_label = labels.shape.size > 1
label_type = detect_dtype(labels)
value_type = detect_dtype(data)
File.open(filename, 'w') do |file|
n_samples.times do |n|
- file.puts(dump_libsvm_line(labels[n], data[n, true],
+ # Multi-label rows are passed on as a plain Array; single labels stay scalar.
+ label = multi_label ? labels[n, true].to_a : labels[n]
+ file.puts(dump_libsvm_line(label, data[n, true],
label_type, value_type, zero_based))
end
end
end
private
def parse_libsvm_line(line, zero_based)
tokens = line.split
- label = tokens.shift
- label = label.to_i.to_s == label ? label.to_i : label.to_f
+ label = parse_label(tokens.shift)
ftvec = tokens.map do |el|
idx, val = el.split(':')
idx = idx.to_i - (zero_based == false ? 1 : 0)
val = val.to_i.to_s == val ? val.to_i : val.to_f
[idx, val]
@@ -58,10 +59,15 @@
max_idx = ftvec.map { |el| el[0] }.max
max_idx ||= 0
[label, ftvec, max_idx]
end
+ # Parses the leading token of a libsvm line into a label.
+ # A comma-separated token (e.g. "1,3,5") yields an Array of numbers; a single value yields a scalar.
+ # A value that round-trips through Integer formatting becomes an Integer, anything else a Float.
+ def parse_label(label)
+ # A blank line has no label token (nil). Keep the pre-multi-label result of 0.0
+ # instead of raising NoMethodError from nil.split.
+ return 0.0 if label.nil?
+ lbl_arr = label.split(',').map { |lbl| lbl.to_i.to_s == lbl ? lbl.to_i : lbl.to_f }
+ lbl_arr.size > 1 ? lbl_arr : lbl_arr[0]
+ end
+
def convert_to_matrix(data, n_features)
mat = []
data.each do |ft|
vec = Array.new(n_features) { 0 }
ft.each { |el| vec[el[0]] = el[1] }
@@ -78,15 +84,23 @@
type = '%.10g' if ['Numo::SFloat', 'Numo::DFloat'].include?(arr_type_str)
type
end
+ # Builds one libsvm-format line: the formatted label followed by " index:value"
+ # pairs for every feature whose value is non-zero.
def dump_libsvm_line(label, ftvec, label_type, value_type, zero_based)
- line = format(label_type.to_s, label)
+ line = dump_label(label, label_type.to_s)
ftvec.to_a.each_with_index do |val, n|
idx = n + (zero_based == false ? 1 : 0)
- line += format(" %d:#{value_type}", idx, val) if val != 0.0
+ # Append in place: += would copy the whole line once per non-zero feature.
+ line << format(" %d:#{value_type}", idx, val) if val != 0.0
end
line
+ end
+
+ # Formats a label for output. An Array of labels is rendered as comma-separated
+ # values, each formatted with label_type_str; anything else is formatted directly.
+ def dump_label(label, label_type_str)
+ return format(label_type_str, label) unless label.is_a?(Array)
+ label.map { |item| format(label_type_str, item) }.join(',')
end
end
end
end