lib/easy_ml/predict.rb in easy_ml-0.2.0.pre.rc41 vs lib/easy_ml/predict.rb in easy_ml-0.2.0.pre.rc43

- old
+ new

@@ -8,26 +8,41 @@ def initialize @models = {} end - def self.predict(model_name, df) + def self.predict(model_name, df, serialize: false) if df.is_a?(Hash) df = Polars::DataFrame.new(df) end - raw_input = df.to_hashes&.first + raw_input = df.to_hashes df = instance.normalize(model_name, df) + normalized_input = df.to_hashes preds = instance.predict(model_name, df) current_version = instance.get_model(model_name) - EasyML::Prediction.create!( - model: current_version.model, - model_history: current_version, - prediction_type: current_version.model.task, - prediction_value: preds.first, - raw_input: raw_input, - normalized_input: df.to_hashes&.first, - ) + output = preds.zip(raw_input, normalized_input).map do |pred, raw, norm| + EasyML::Prediction.create!( + model: current_version.model, + model_history: current_version, + prediction_type: current_version.model.task, + prediction_value: pred, + raw_input: raw, + normalized_input: norm, + ) + end + + output = if output.is_a?(Array) && output.count == 1 + output.first + else + output + end + + if serialize + EasyML::PredictionSerializer.new(output).serializable_hash + else + output + end end def self.train(model_name, tuner: nil, evaluator: nil) instance.train(model_name, tuner: tuner, evaluator: evaluator) end