lib/prophet/stan_backend.rb in prophet-rb-0.4.0 vs lib/prophet/stan_backend.rb in prophet-rb-0.4.1

- old
+ new

@@ -11,10 +11,15 @@ CmdStan::Model.new(exe_file: model_file) end def fit(stan_init, stan_data, **kwargs) stan_init, stan_data = prepare_data(stan_init, stan_data) + + if !kwargs[:inits] && kwargs[:init] + kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0] + end + kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS" iterations = 10000 stan_fit = nil begin @@ -47,10 +52,14 @@ end def sampling(stan_init, stan_data, samples, **kwargs) stan_init, stan_data = prepare_data(stan_init, stan_data) + if !kwargs[:inits] && kwargs[:init] + kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0] + end + kwargs[:chains] ||= 4 kwargs[:warmup_iters] ||= samples / 2 stan_fit = @model.sample( data: stan_data, @@ -126,10 +135,10 @@ stan_data["t"] = stan_data["t"].to_a stan_data["cap"] = stan_data["cap"].to_a stan_data["t_change"] = stan_data["t_change"].to_a stan_data["s_a"] = stan_data["s_a"].to_a stan_data["s_m"] = stan_data["s_m"].to_a - stan_data["X"] = stan_data["X"].to_numo.to_a + stan_data["X"] = stan_data["X"].respond_to?(:to_numo) ? stan_data["X"].to_numo.to_a : stan_data["X"].to_a stan_init["delta"] = stan_init["delta"].to_a stan_init["beta"] = stan_init["beta"].to_a [stan_init, stan_data] end