// Eigen add-on functors, packet helpers and evaluator specializations.
//
// Contents, in order of appearance:
//   * to_int_packet / to_float_packet: traits exposing a nested `type`; the
//     generic versions typedef `type` to their own PacketType, and each SIMD
//     section below specializes them to Packet8i/Packet8f (AVX) or
//     Packet4i/Packet4f (SSE2, NEON).
//   * Per-ISA helpers, guarded by EIGEN_VECTORIZE_AVX / EIGEN_VECTORIZE_SSE2 /
//     EIGEN_VECTORIZE_NEON:
//       - p_to_f32(int packet): converts int32 lanes to float lanes
//         (_mm256_cvtepi32_ps / _mm_cvtepi32_ps / vcvtq_f32_s32).
//       - p_bool2float(float packet): AVX and SSE2 compare each lane against 0
//         and AND the resulting mask with a packet of 1, so lanes that compare
//         not-equal to 0 become 1.0f and the rest 0.0f.
//       - p_bool2float(int packet): forwards to the float overload after a
//         cast of the register (_mm256_castsi256_ps / _mm_castsi128_ps /
//         vreinterpretq_f32_s32), i.e. the integer lanes are not numerically
//         converted first.
//   * scalar_lgamma_subt_op: binary functor; the scalar path delegates to
//     tomoto::math::lgammaSubt(z, a) and the packet path to lgamma_subt(z, a).
//     Its functor_traits report Cost = HugeCost and PacketAccess = 1.
//   * scalar_cast_op specialization (result_type float, int32_t argument):
//     the scalar path calls cast(a); packetOp calls p_to_f32.
//   * scalar_bool2float: scalar path returns `a ? 1.f : 0.f`; packetOp calls
//     p_bool2float and returns to_float_packet<...>::type.
//   * Two unary_evaluator specializations (one holding a scalar_cast_op, one
//     holding a scalar_bool2float): coeff() applies the functor to the nested
//     evaluator's coefficient; packet() loads a packet from the nested
//     evaluator and passes it to the functor's packetOp. Flags keep
//     PacketAccessBit only when the functor's PacketAccess is set.
//   * Free functions lgamma_subt(array, scalar), lgamma_subt(array, array) and
//     bool2float(array): build and return the corresponding CwiseBinaryOp /
//     CwiseUnaryOp expression objects; the scalar overload wraps `scalar` in a
//     constant expression with the same rows()/cols() as x.
//
// NOTE(review): as it stands, this text looks like it lost its template
// parameter/argument lists and #include targets (e.g. `template struct`,
// bare `#include`), and preprocessor directives share physical lines with
// ordinary code. Presumably an extraction artifact -- restore from the
// upstream file before attempting to compile; TODO confirm.
//
// NOTE(review): the NEON p_bool2float(Packet4f) masks the lanes with integer 1
// and converts the result to float, whereas the AVX/SSE2 variants test
// `!= 0` and the scalar functor returns 1.f for any nonzero value. The
// variants agree for inputs that are exactly 0 or 1; behavior for other
// inputs differs -- confirm callers only pass 0/1.
#pragma once #include #include #include "math.h" namespace Eigen { namespace internal { template struct to_int_packet { typedef PacketType type; }; template struct to_float_packet { typedef PacketType type; }; } } #ifdef EIGEN_VECTORIZE_AVX #include #include "avx_gamma.h" namespace Eigen { namespace internal { template<> struct to_int_packet { typedef Packet8i type; }; template<> struct to_float_packet { typedef Packet8f type; }; EIGEN_STRONG_INLINE Packet8f p_to_f32(const Packet8i& a) { return _mm256_cvtepi32_ps(a); } EIGEN_STRONG_INLINE Packet8f p_bool2float(const Packet8f& a) { return _mm256_and_ps(_mm256_cmp_ps(a, _mm256_set1_ps(0), _CMP_NEQ_OQ), _mm256_set1_ps(1)); } EIGEN_STRONG_INLINE Packet8f p_bool2float(const Packet8i& a) { return p_bool2float(_mm256_castsi256_ps(a)); } } } #endif #ifdef EIGEN_VECTORIZE_SSE2 #include #include "sse_gamma.h" namespace Eigen { namespace internal { template<> struct to_int_packet { typedef Packet4i type; }; template<> struct to_float_packet { typedef Packet4f type; }; EIGEN_STRONG_INLINE Packet4f p_to_f32(const Packet4i& a) { return _mm_cvtepi32_ps(a); } EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4f& a) { return _mm_and_ps(_mm_cmpneq_ps(a, _mm_set1_ps(0)), _mm_set1_ps(1)); } EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4i& a) { return p_bool2float(_mm_castsi128_ps(a)); } } } #endif #ifdef EIGEN_VECTORIZE_NEON #include #include "neon_gamma.h" namespace Eigen { namespace internal { template<> struct to_int_packet { typedef Packet4i type; }; template<> struct to_float_packet { typedef Packet4f type; }; EIGEN_STRONG_INLINE Packet4f p_to_f32(const Packet4i& a) { return vcvtq_f32_s32(a); } EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4f& a) { return vcvtq_f32_s32(vandq_s32((Packet4i)a, vdupq_n_s32(1))); } EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4i& a) { return p_bool2float((Packet4f)vreinterpretq_f32_s32((int32x4_t)a)); } } } #endif namespace Eigen { namespace internal { template struct 
scalar_lgamma_subt_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_lgamma_subt_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (const Scalar& z, const Scalar2& a) const { return tomoto::math::lgammaSubt(z, a); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& z, const Packet& a) const { return lgamma_subt(z, a); } }; template struct functor_traits > { enum { Cost = HugeCost, PacketAccess = 1 }; }; template<> struct scalar_cast_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) typedef float result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const float operator() (const int32_t& a) const { return cast(a); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const typename to_int_packet::type>::type& a) const { return p_to_f32(a); } }; template<> struct functor_traits > { enum { Cost = NumTraits::AddCost, PacketAccess = 1 }; }; template struct unary_evaluator, ArgType>, IndexBased > : evaluator_base, ArgType> > { typedef CwiseUnaryOp, ArgType> XprType; enum { CoeffReadCost = evaluator::CoeffReadCost + functor_traits>::Cost, Flags = evaluator::Flags & (HereditaryBits | LinearAccessBit | (functor_traits>::PacketAccess ? 
PacketAccessBit : 0)), Alignment = evaluator::Alignment }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& op) : m_functor(op.functor()), m_argImpl(op.nestedExpression()) { EIGEN_INTERNAL_CHECK_COST_VALUE(NumTraits::AddCost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } typedef typename XprType::CoeffReturnType CoeffReturnType; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { return m_functor(m_argImpl.coeff(row, col)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_functor(m_argImpl.coeff(index)); } template EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { return m_functor.packetOp(m_argImpl.template packet::type>(row, col)); } template EIGEN_STRONG_INLINE PacketType packet(Index index) const { return m_functor.packetOp(m_argImpl.template packet::type>(index)); } protected: const scalar_cast_op m_functor; evaluator m_argImpl; }; struct scalar_bool2float { EIGEN_EMPTY_STRUCT_CTOR(scalar_bool2float) typedef float result_type; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const float operator() (const int32_t& a) const { return a ? 1.f : 0.f; } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename to_float_packet::type packetOp(const Packet& a) const { return p_bool2float(a); } }; template<> struct functor_traits { enum { Cost = NumTraits::AddCost, PacketAccess = 1 }; }; template struct unary_evaluator, IndexBased > : evaluator_base > { typedef CwiseUnaryOp XprType; enum { CoeffReadCost = evaluator::CoeffReadCost + functor_traits::Cost, Flags = evaluator::Flags & (HereditaryBits | LinearAccessBit | (functor_traits::PacketAccess ? 
PacketAccessBit : 0)), Alignment = evaluator::Alignment }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit unary_evaluator(const XprType& op) : m_functor(op.functor()), m_argImpl(op.nestedExpression()) { EIGEN_INTERNAL_CHECK_COST_VALUE(NumTraits::AddCost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); } typedef typename XprType::CoeffReturnType CoeffReturnType; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const { return m_functor(m_argImpl.coeff(row, col)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_functor(m_argImpl.coeff(index)); } template EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const { return m_functor.packetOp(m_argImpl.template packet::type>(row, col)); } template EIGEN_STRONG_INLINE PacketType packet(Index index) const { return m_functor.packetOp(m_argImpl.template packet::type>(index)); } protected: const scalar_bool2float m_functor; evaluator m_argImpl; }; } template EIGEN_DEVICE_FUNC inline const CwiseBinaryOp::Scalar, T >, const Derived, const typename internal::plain_constant_type::type> lgamma_subt(const Eigen::ArrayBase& x, const T& scalar) { return CwiseBinaryOp::Scalar, T >, const Derived, const typename internal::plain_constant_type::type>(x.derived(), typename internal::plain_constant_type::type(x.derived().rows(), x.derived().cols(), internal::scalar_constant_op(scalar)) ); } template inline const CwiseBinaryOp, const Derived, const Derived2> lgamma_subt(const Eigen::ArrayBase& x, const Eigen::ArrayBase& y) { return CwiseBinaryOp, const Derived, const Derived2>( x.derived(), y.derived() ); } template inline const CwiseUnaryOp bool2float(const Eigen::ArrayBase& x) { return CwiseUnaryOp( x.derived() ); } }