#pragma once

#ifdef isfinite
#undef isfinite
#endif

#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <rice/Array.hpp>
#include <rice/Object.hpp>

using namespace Rice;

using torch::Device;
using torch::Scalar;
using torch::ScalarType;
using torch::Tensor;
using torch::QScheme;
using torch::Generator;
using torch::TensorOptions;
using torch::Layout;
using torch::MemoryFormat;
using torch::IntArrayRef;
using torch::TensorList;
using torch::Storage;

// Opens a C++ `try` block. Must always be paired with END_HANDLE_TH_ERRORS,
// which supplies the matching closing brace and the catch clauses.
#define HANDLE_TH_ERRORS                                             \
  try {

// Closes the `try` opened by HANDLE_TH_ERRORS and converts C++ exceptions
// into Ruby exceptions with rb_raise:
//   - torch::Error    -> RuntimeError, message from what_without_backtrace()
//   - Rice::Exception -> re-raised as its own Ruby class (ex.class_of())
//   - std::exception  -> RuntimeError, message from what()
// Handlers are tried in order, so the more specific types must stay ahead of
// std::exception (presumably both derive from it). Exceptions matching none
// of these types propagate unchanged.
// The "%s" format keeps any '%' in an exception message from being
// interpreted as a format directive.
// NOTE(review): rb_raise does not return (Ruby unwinds via longjmp), so the
// catch handler never completes normally and the caught C++ exception object
// is presumably never destroyed -- confirm whether that leak matters.
#define END_HANDLE_TH_ERRORS                                         \
  } catch (const torch::Error& ex) {                                 \
    rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace());   \
  } catch (const Rice::Exception& ex) {                              \
    rb_raise(ex.class_of(), "%s", ex.what());                        \
  } catch (const std::exception& ex) {                               \
    rb_raise(rb_eRuntimeError, "%s", ex.what());                     \
  }

// Expands to `return Qnil;`, i.e. returns Ruby nil from the enclosing function.
#define RETURN_NIL                                                   \
  return Qnil;

// Builds a std::vector<int64_t> from a Ruby Array, converting every element
// with from_ruby<int64_t>. Any element conversion failure propagates to the
// caller; the partially built vector is discarded.
template<>
inline
std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
{
  Array items(x);
  std::vector<int64_t> out;
  out.reserve(items.size());
  for (long idx = 0; idx < items.size(); ++idx) {
    out.push_back(from_ruby<int64_t>(items[idx]));
  }
  return out;
}

// Builds a std::vector<Tensor> from a Ruby Array, converting every element
// with from_ruby<Tensor>. Any element conversion failure propagates to the
// caller; the partially built vector is discarded.
template<>
inline
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
{
  Array items(x);
  std::vector<Tensor> out;
  out.reserve(items.size());
  for (long idx = 0; idx < items.size(); ++idx) {
    out.push_back(from_ruby<Tensor>(items[idx]));
  }
  return out;
}

// Holds a Ruby string naming a fan mode ("fan_in" or "fan_out") and converts
// it on demand into torch::nn::init::FanModeType. The string is captured at
// construction; it is validated only when the conversion operator runs.
class FanModeType {
  std::string s;
  public:
    FanModeType(Object o) : s(String(o).str()) {}

    // Maps the stored string to the matching torch constant.
    // Throws std::runtime_error naming the offending value otherwise.
    // Does not modify the object, so it is usable on const instances too.
    operator torch::nn::init::FanModeType() const {
      if (s == "fan_in") {
        return torch::kFanIn;
      } else if (s == "fan_out") {
        return torch::kFanOut;
      } else {
        throw std::runtime_error("Unsupported fan mode type: " + s);
      }
    }
};

// Wraps the Ruby value in a FanModeType; the string itself is checked later,
// when the wrapper is converted to torch's fan-mode type.
template<>
inline
FanModeType from_ruby<FanModeType>(Object x)
{
  FanModeType mode(x);
  return mode;
}

// Holds a Ruby string naming a nonlinearity ("linear", "conv1d", "conv2d",
// "conv3d", "conv_transpose1d/2d/3d", "sigmoid", "tanh", "relu",
// "leaky_relu") and converts it on demand into
// torch::nn::init::NonlinearityType. Any other string makes the conversion
// throw std::runtime_error.
class NonlinearityType {
  std::string s;
  public:
    NonlinearityType(Object o) : s(String(o).str()) {}

    operator torch::nn::init::NonlinearityType() {
      // One guard per accepted spelling; the first match returns.
      if (s == "linear") return torch::kLinear;
      if (s == "conv1d") return torch::kConv1D;
      if (s == "conv2d") return torch::kConv2D;
      if (s == "conv3d") return torch::kConv3D;
      if (s == "conv_transpose1d") return torch::kConvTranspose1D;
      if (s == "conv_transpose2d") return torch::kConvTranspose2D;
      if (s == "conv_transpose3d") return torch::kConvTranspose3D;
      if (s == "sigmoid") return torch::kSigmoid;
      if (s == "tanh") return torch::kTanh;
      if (s == "relu") return torch::kReLU;
      if (s == "leaky_relu") return torch::kLeakyReLU;
      throw std::runtime_error("Unsupported nonlinearity type: " + s);
    }
};

// Wraps the Ruby value in a NonlinearityType; the name is checked later,
// when the wrapper is converted to torch's nonlinearity type.
template<>
inline
NonlinearityType from_ruby<NonlinearityType>(Object x)
{
  NonlinearityType kind(x);
  return kind;
}

// Wrapper around a torch::Tensor that may be absent: a Ruby nil leaves the
// wrapped tensor default-constructed, anything else is converted with
// from_ruby<torch::Tensor>. Converts back to torch::Tensor implicitly.
// NOTE(review): presumably used for optional tensor arguments -- confirm with
// the binding code.
class OptionalTensor {
  torch::Tensor value;
  public:
    // nil -> keep the default-constructed tensor; otherwise convert.
    OptionalTensor(Object o) {
      if (!o.is_nil()) {
        value = from_ruby<torch::Tensor>(o);
      }
    }

    // Sink parameter: moving avoids an extra reference-count increment and
    // decrement on the tensor's implementation.
    OptionalTensor(torch::Tensor o) : value(std::move(o)) {}

    // Returns a copy (shared handle) of the wrapped tensor.
    operator torch::Tensor() const {
      return value;
    }
};

// Converts a Ruby number into a torch::Scalar. A Fixnum becomes an integral
// Scalar (via from_ruby<int64_t>); every other value goes through
// from_ruby<double>.
// NOTE(review): Ruby Integers too large for a Fixnum (T_BIGNUM) therefore take
// the double path and may lose precision; rejecting non-numeric input is left
// to from_ruby<double> -- confirm this is intended.
template<>
inline
Scalar from_ruby<Scalar>(Object x)
{
  switch (x.rb_type()) {
    case T_FIXNUM:
      return torch::Scalar(from_ruby<int64_t>(x));
    default:
      return torch::Scalar(from_ruby<double>(x));
  }
}

// Builds an OptionalTensor from a Ruby value (nil is allowed).
template<>
inline
OptionalTensor from_ruby<OptionalTensor>(Object x)
{
  OptionalTensor wrapped(x);
  return wrapped;
}

// Ruby nil -> torch::nullopt; any other value is converted with
// from_ruby<torch::ScalarType> and wrapped in an engaged optional.
template<>
inline
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
{
  if (!x.is_nil()) {
    return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
  }
  return torch::nullopt;
}

// Ruby nil -> empty optional; any other value is converted with
// from_ruby<int64_t>. Only the selected branch of the conditional is
// evaluated, so nil is never passed to from_ruby<int64_t>.
template<>
inline
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
{
  using Result = torch::optional<int64_t>;
  return x.is_nil() ? Result() : Result(from_ruby<int64_t>(x));
}

// Ruby nil -> torch::nullopt; any other value is converted with
// from_ruby<double> and wrapped in an engaged optional.
template<>
inline
torch::optional<double> from_ruby<torch::optional<double>>(Object x)
{
  if (!x.is_nil()) {
    return torch::optional<double>{from_ruby<double>(x)};
  }
  return torch::nullopt;
}

// Ruby nil -> empty optional; any other value is converted with
// from_ruby<bool>. Only the selected branch of the conditional is evaluated,
// so nil never reaches from_ruby<bool> (and never becomes `false`).
template<>
inline
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
{
  using Result = torch::optional<bool>;
  return x.is_nil() ? Result() : Result(from_ruby<bool>(x));
}

// Ruby nil -> torch::nullopt; any other value is converted with
// from_ruby<Scalar> and wrapped in an engaged optional.
template<>
inline
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
{
  if (x.is_nil()) {
    return torch::nullopt;
  }
  Scalar converted = from_ruby<Scalar>(x);
  return torch::optional<Scalar>{converted};
}