vendor/tomotopy/src/Utils/TruncMultiNormal.hpp in tomoto-0.1.3 vs vendor/tomotopy/src/Utils/TruncMultiNormal.hpp in tomoto-0.1.4

- old
+ new

@@ -12,60 +12,58 @@ _Out ret, const MultiNormalDistribution<_Ty>& multiNormal, const Eigen::Matrix<_Ty, -1, 1>& lowerBound, const Eigen::Matrix<_Ty, -1, 1>& upperBound, _Rng& rng, - size_t iteration) + size_t burnIn + ) { - constexpr _Ty epsilon = 1e-6; const size_t K = ret.size(); - Eigen::Matrix<_Ty, -1, 1> bias = Eigen::Matrix<_Ty, -1, 1>::Zero(K), lowers, uppers; - auto& l = multiNormal.getCovL(); - ret.setZero(); - - std::vector<size_t> ks(K); - std::iota(ks.begin(), ks.end(), 0); - for (size_t i = 0; i < iteration; ++i) + Eigen::Matrix<_Ty, -1, -1> l = multiNormal.getCovL(); + ret = (lowerBound + upperBound) / 2; + Eigen::Matrix<_Ty, -1, 1> z = l.template triangularView<Eigen::Lower>().solve(ret - multiNormal.mean), + a = lowerBound - multiNormal.mean, + b = upperBound - multiNormal.mean, + t, at, bt; + for (size_t i = 0; i < burnIn; ++i) { - // shuffle sampling orders except during initialization - if (i) std::shuffle(ks.begin(), ks.end(), rng); - for (size_t kx = 0; kx < K; ++kx) + for (size_t j = 0; j < K; ++j) { - size_t k = ks[kx]; - ret[k] = 0; - //bias = multiNormal.mean + l * ret; - //bias.tail(K - k) = multiNormal.mean.tail(K - k) + l.block(k, 0, K - k, K) * ret; - bias.tail(K - k) = multiNormal.mean.tail(K - k); - bias.tail(K - k).noalias() += l.block(k, 0, K - k, K) * ret; - lowers = (lowerBound - bias).tail(K - k).array() / l.col(k).tail(K - k).array(); - uppers = (upperBound - bias).tail(K - k).array() / l.col(k).tail(K - k).array(); - _Ty nLower = lowers[0], nUpper = uppers[0]; - if (l(k, k) < 0) std::swap(nLower, nUpper); - if (i) + auto lj = l.col(j); + z[j] = 0; + t = l * z; + _Ty lower_pos = -INFINITY, upper_pos = INFINITY, + lower_neg = -INFINITY, upper_neg = INFINITY; + at = ((a - t).array() / lj.array()).matrix(); + bt = ((b - t).array() / lj.array()).matrix(); + for (size_t k = 0; k < K; ++k) { - for (size_t j = 1; j < lowers.size(); ++j) + if (lj[k] > 0) { - if (l.col(k)(j + k) > epsilon) - { - if (lowers[j] > nLower) nLower = lowers[j]; - if (uppers[j] < nUpper) nUpper = uppers[j]; - } - else if (l.col(k)(j + k) < -epsilon) - { - if (uppers[j] > nLower) nLower = uppers[j]; - if (lowers[j] < nUpper) nUpper = lowers[j]; - } + lower_pos = std::max(lower_pos, at[k]); + upper_pos = std::min(upper_pos, bt[k]); } + else if (lj[k] < 0) + { + lower_neg = std::max(lower_neg, bt[k]); + upper_neg = std::min(upper_neg, at[k]); + } } - if (abs(nLower - nUpper) <= 1e-4) ret[k] = (nLower + nUpper) / 2; + lower_pos = std::max(lower_pos, lower_neg); + upper_pos = std::min(upper_pos, upper_neg); + // this is due to numerical instability + if (lower_pos >= upper_pos) + { + std::cerr << __FILE__ << "(" << __LINE__ << "): wrong truncation range [" << lower_pos << ", " << upper_pos << "]" << std::endl; + z[j] = (lower_pos + upper_pos) / 2; + } else { - ret[k] = rtnorm::rtnorm(rng, nLower, nUpper); + z[j] = rtnorm::rtnorm(rng, lower_pos, upper_pos); } } } - ret = l * ret; - ret += multiNormal.mean; + ret = (l * z) + multiNormal.mean; return ret; } template<typename _Ty, typename _Out, typename _Rng> _Out sampleFromTruncatedMultiNormalRejection( \ No newline at end of file