lib/easy_ml/predict.rb in easy_ml-0.2.0.pre.rc41 vs lib/easy_ml/predict.rb in easy_ml-0.2.0.pre.rc43
- old
+ new
@@ -8,26 +8,41 @@
def initialize
@models = {}
end
- def self.predict(model_name, df)
+ def self.predict(model_name, df, serialize: false)
if df.is_a?(Hash)
df = Polars::DataFrame.new(df)
end
- raw_input = df.to_hashes&.first
+ raw_input = df.to_hashes
df = instance.normalize(model_name, df)
+ normalized_input = df.to_hashes
preds = instance.predict(model_name, df)
current_version = instance.get_model(model_name)
- EasyML::Prediction.create!(
- model: current_version.model,
- model_history: current_version,
- prediction_type: current_version.model.task,
- prediction_value: preds.first,
- raw_input: raw_input,
- normalized_input: df.to_hashes&.first,
- )
+ output = preds.zip(raw_input, normalized_input).map do |pred, raw, norm|
+ EasyML::Prediction.create!(
+ model: current_version.model,
+ model_history: current_version,
+ prediction_type: current_version.model.task,
+ prediction_value: pred,
+ raw_input: raw,
+ normalized_input: norm,
+ )
+ end
+
+ output = if output.is_a?(Array) && output.count == 1
+ output.first
+ else
+ output
+ end
+
+ if serialize
+ EasyML::PredictionSerializer.new(output).serializable_hash
+ else
+ output
+ end
end
def self.train(model_name, tuner: nil, evaluator: nil)
instance.train(model_name, tuner: tuner, evaluator: evaluator)
end