lib/torch.rb in torch-rb-0.2.0 vs lib/torch.rb in torch-rb-0.2.1

- old
+ new

@@ -315,16 +315,15 @@ def device(str) Device.new(str) end def save(obj, f) - raise NotImplementedYet unless obj.is_a?(Tensor) - File.binwrite(f, _save(obj)) + File.binwrite(f, _save(to_ivalue(obj))) end def load(f) - raise NotImplementedYet + to_ruby(_load(File.binread(f))) end # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html --- def arange(start, finish = nil, step = 1, **options) @@ -444,9 +443,97 @@ def zeros_like(input, **options) zeros(input.size, **like_options(input, options)) end private + + def to_ivalue(obj) + case obj + when String + IValue.from_string(obj) + when Integer + IValue.from_int(obj) + when Tensor + IValue.from_tensor(obj) + when Float + IValue.from_double(obj) + when Hash + dict = {} + obj.each do |k, v| + dict[to_ivalue(k)] = to_ivalue(v) + end + IValue.from_dict(dict) + when true, false + IValue.from_bool(obj) + when nil + IValue.new + else + raise Error, "Unknown type: #{obj.class.name}" + end + end + + def to_ruby(ivalue) + if ivalue.bool? + ivalue.to_bool + elsif ivalue.double? + ivalue.to_double + elsif ivalue.int? + ivalue.to_int + elsif ivalue.none? + nil + elsif ivalue.string? + ivalue.to_string_ref + elsif ivalue.tensor? + ivalue.to_tensor + elsif ivalue.generic_dict? + dict = {} + ivalue.to_generic_dict.each do |k, v| + dict[to_ruby(k)] = to_ruby(v) + end + dict + else + type = + if ivalue.capsule? + "Capsule" + elsif ivalue.custom_class? + "CustomClass" + elsif ivalue.tuple? + "Tuple" + elsif ivalue.future? + "Future" + elsif ivalue.r_ref? + "RRef" + elsif ivalue.int_list? + "IntList" + elsif ivalue.double_list? + "DoubleList" + elsif ivalue.bool_list? + "BoolList" + elsif ivalue.tensor_list? + "TensorList" + elsif ivalue.list? + "List" + elsif ivalue.object? + "Object" + elsif ivalue.module? + "Module" + elsif ivalue.py_object? + "PyObject" + elsif ivalue.scalar? + "Scalar" + elsif ivalue.device? + "Device" + # elsif ivalue.generator? + # "Generator" + elsif ivalue.ptr_type? + "PtrType" + else + "Unknown" + end + + raise Error, "Unsupported type: #{type}" + end + end def tensor_size(size) size.flatten end