lib/prophet/stan_backend.rb in prophet-rb-0.4.0 vs lib/prophet/stan_backend.rb in prophet-rb-0.4.1
- old
+ new
@@ -11,10 +11,15 @@
CmdStan::Model.new(exe_file: model_file)
end
def fit(stan_init, stan_data, **kwargs)
stan_init, stan_data = prepare_data(stan_init, stan_data)
+
+ if !kwargs[:inits] && kwargs[:init]
+ kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
+ end
+
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
iterations = 10000
stan_fit = nil
begin
@@ -47,10 +52,14 @@
end
def sampling(stan_init, stan_data, samples, **kwargs)
stan_init, stan_data = prepare_data(stan_init, stan_data)
+ if !kwargs[:inits] && kwargs[:init]
+ kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
+ end
+
kwargs[:chains] ||= 4
kwargs[:warmup_iters] ||= samples / 2
stan_fit = @model.sample(
data: stan_data,
@@ -126,10 +135,10 @@
stan_data["t"] = stan_data["t"].to_a
stan_data["cap"] = stan_data["cap"].to_a
stan_data["t_change"] = stan_data["t_change"].to_a
stan_data["s_a"] = stan_data["s_a"].to_a
stan_data["s_m"] = stan_data["s_m"].to_a
- stan_data["X"] = stan_data["X"].to_numo.to_a
+ stan_data["X"] = stan_data["X"].respond_to?(:to_numo) ? stan_data["X"].to_numo.to_a : stan_data["X"].to_a
stan_init["delta"] = stan_init["delta"].to_a
stan_init["beta"] = stan_init["beta"].to_a
[stan_init, stan_data]
end