vendor/tomotopy/src/Utils/TruncMultiNormal.hpp in tomoto-0.1.3 vs vendor/tomotopy/src/Utils/TruncMultiNormal.hpp in tomoto-0.1.4
- old
+ new
@@ -12,60 +12,58 @@
_Out ret,
const MultiNormalDistribution<_Ty>& multiNormal,
const Eigen::Matrix<_Ty, -1, 1>& lowerBound,
const Eigen::Matrix<_Ty, -1, 1>& upperBound,
_Rng& rng,
- size_t iteration)
+ size_t burnIn
+ )
{
- constexpr _Ty epsilon = 1e-6;
const size_t K = ret.size();
- Eigen::Matrix<_Ty, -1, 1> bias = Eigen::Matrix<_Ty, -1, 1>::Zero(K), lowers, uppers;
- auto& l = multiNormal.getCovL();
- ret.setZero();
-
- std::vector<size_t> ks(K);
- std::iota(ks.begin(), ks.end(), 0);
- for (size_t i = 0; i < iteration; ++i)
+ Eigen::Matrix<_Ty, -1, -1> l = multiNormal.getCovL();
+ ret = (lowerBound + upperBound) / 2;
+ Eigen::Matrix<_Ty, -1, 1> z = l.template triangularView<Eigen::Lower>().solve(ret - multiNormal.mean),
+ a = lowerBound - multiNormal.mean,
+ b = upperBound - multiNormal.mean,
+ t, at, bt;
+ for (size_t i = 0; i < burnIn; ++i)
{
- // shuffle sampling orders except during initialization
- if (i) std::shuffle(ks.begin(), ks.end(), rng);
- for (size_t kx = 0; kx < K; ++kx)
+ for (size_t j = 0; j < K; ++j)
{
- size_t k = ks[kx];
- ret[k] = 0;
- //bias = multiNormal.mean + l * ret;
- //bias.tail(K - k) = multiNormal.mean.tail(K - k) + l.block(k, 0, K - k, K) * ret;
- bias.tail(K - k) = multiNormal.mean.tail(K - k);
- bias.tail(K - k).noalias() += l.block(k, 0, K - k, K) * ret;
- lowers = (lowerBound - bias).tail(K - k).array() / l.col(k).tail(K - k).array();
- uppers = (upperBound - bias).tail(K - k).array() / l.col(k).tail(K - k).array();
- _Ty nLower = lowers[0], nUpper = uppers[0];
- if (l(k, k) < 0) std::swap(nLower, nUpper);
- if (i)
+ auto lj = l.col(j);
+ z[j] = 0;
+ t = l * z;
+ _Ty lower_pos = -INFINITY, upper_pos = INFINITY,
+ lower_neg = -INFINITY, upper_neg = INFINITY;
+ at = ((a - t).array() / lj.array()).matrix();
+ bt = ((b - t).array() / lj.array()).matrix();
+ for (size_t k = 0; k < K; ++k)
{
- for (size_t j = 1; j < lowers.size(); ++j)
+ if (lj[k] > 0)
{
- if (l.col(k)(j + k) > epsilon)
- {
- if (lowers[j] > nLower) nLower = lowers[j];
- if (uppers[j] < nUpper) nUpper = uppers[j];
- }
- else if (l.col(k)(j + k) < -epsilon)
- {
- if (uppers[j] > nLower) nLower = uppers[j];
- if (lowers[j] < nUpper) nUpper = lowers[j];
- }
+ lower_pos = std::max(lower_pos, at[k]);
+ upper_pos = std::min(upper_pos, bt[k]);
}
+ else if (lj[k] < 0)
+ {
+ lower_neg = std::max(lower_neg, bt[k]);
+ upper_neg = std::min(upper_neg, at[k]);
+ }
}
- if (abs(nLower - nUpper) <= 1e-4) ret[k] = (nLower + nUpper) / 2;
+ lower_pos = std::max(lower_pos, lower_neg);
+ upper_pos = std::min(upper_pos, upper_neg);
+ // this is due to numerical instability
+ if (lower_pos >= upper_pos)
+ {
+ std::cerr << __FILE__ << "(" << __LINE__ << "): wrong truncation range [" << lower_pos << ", " << upper_pos << "]" << std::endl;
+ z[j] = (lower_pos + upper_pos) / 2;
+ }
else
{
- ret[k] = rtnorm::rtnorm(rng, nLower, nUpper);
+ z[j] = rtnorm::rtnorm(rng, lower_pos, upper_pos);
}
}
}
- ret = l * ret;
- ret += multiNormal.mean;
+ ret = (l * z) + multiNormal.mean;
return ret;
}
template<typename _Ty, typename _Out, typename _Rng>
_Out sampleFromTruncatedMultiNormalRejection(
\ No newline at end of file