lib/torch.rb in torch-rb-0.2.0 vs lib/torch.rb in torch-rb-0.2.1
- old
+ new
@@ -315,16 +315,15 @@
def device(str)
Device.new(str)
end
def save(obj, f)
- raise NotImplementedYet unless obj.is_a?(Tensor)
- File.binwrite(f, _save(obj))
+ File.binwrite(f, _save(to_ivalue(obj)))
end
def load(f)
- raise NotImplementedYet
+ to_ruby(_load(File.binread(f)))
end
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
def arange(start, finish = nil, step = 1, **options)
@@ -444,9 +443,97 @@
def zeros_like(input, **options)
zeros(input.size, **like_options(input, options))
end
private
+
+ def to_ivalue(obj)
+ case obj
+ when String
+ IValue.from_string(obj)
+ when Integer
+ IValue.from_int(obj)
+ when Tensor
+ IValue.from_tensor(obj)
+ when Float
+ IValue.from_double(obj)
+ when Hash
+ dict = {}
+ obj.each do |k, v|
+ dict[to_ivalue(k)] = to_ivalue(v)
+ end
+ IValue.from_dict(dict)
+ when true, false
+ IValue.from_bool(obj)
+ when nil
+ IValue.new
+ else
+ raise Error, "Unknown type: #{obj.class.name}"
+ end
+ end
+
+ def to_ruby(ivalue)
+ if ivalue.bool?
+ ivalue.to_bool
+ elsif ivalue.double?
+ ivalue.to_double
+ elsif ivalue.int?
+ ivalue.to_int
+ elsif ivalue.none?
+ nil
+ elsif ivalue.string?
+ ivalue.to_string_ref
+ elsif ivalue.tensor?
+ ivalue.to_tensor
+ elsif ivalue.generic_dict?
+ dict = {}
+ ivalue.to_generic_dict.each do |k, v|
+ dict[to_ruby(k)] = to_ruby(v)
+ end
+ dict
+ else
+ type =
+ if ivalue.capsule?
+ "Capsule"
+ elsif ivalue.custom_class?
+ "CustomClass"
+ elsif ivalue.tuple?
+ "Tuple"
+ elsif ivalue.future?
+ "Future"
+ elsif ivalue.r_ref?
+ "RRef"
+ elsif ivalue.int_list?
+ "IntList"
+ elsif ivalue.double_list?
+ "DoubleList"
+ elsif ivalue.bool_list?
+ "BoolList"
+ elsif ivalue.tensor_list?
+ "TensorList"
+ elsif ivalue.list?
+ "List"
+ elsif ivalue.object?
+ "Object"
+ elsif ivalue.module?
+ "Module"
+ elsif ivalue.py_object?
+ "PyObject"
+ elsif ivalue.scalar?
+ "Scalar"
+ elsif ivalue.device?
+ "Device"
+ # elsif ivalue.generator?
+ # "Generator"
+ elsif ivalue.ptr_type?
+ "PtrType"
+ else
+ "Unknown"
+ end
+
+ raise Error, "Unsupported type: #{type}"
+ end
+ end
def tensor_size(size)
size.flatten
end