lib/cerebrum/cerebrum.rb in cerebrum-0.1.2 vs lib/cerebrum/cerebrum.rb in cerebrum-0.1.3
- old
+ new
@@ -1,7 +1,8 @@
require_relative "data_scrubber"
require_relative "cerebrum_helper"
+require 'json'
class Cerebrum
include CerebrumHelper
include DataScrubber
@@ -23,12 +24,12 @@
adjust_weights(learning_rate)
mean_squared_error(@errors[@layers])
end
def train(training_set, options = Hash.new)
- @input_lookup_table ||= get_input_lookup_table(training_set)
- @output_lookup_table ||= get_output_lookup_table(training_set)
+ @input_lookup_table ||= get_input_lookup_table(training_set)
+ @output_lookup_table ||= get_output_lookup_table(training_set)
training_set = scrub_dataset(training_set)
iterations = options[:iterations] || 20000
error_threshold = options[:error_threshold] || 0.005
log = options[:log] || false
@@ -58,9 +59,47 @@
def run(input)
input = to_vector_given_features(input, @input_lookup_table) if @input_lookup_table
output = run_input(input)
@output_lookup_table ? to_features_given_vector(output, @output_lookup_table) : output
+ end
+
+ def save_state
+ {
+ biases: @biases,
+ binary_thresh: @binary_thresh,
+ changes: @changes,
+ deltas: @deltas,
+ errors: @errors,
+ hidden_layers: @hidden_layers,
+ input_lookup_table: @input_lookup_table,
+ layer_sizes: @layer_sizes,
+ layers: @layers,
+ learning_rate: @learning_rate,
+ momentum: @momentum,
+ output_lookup_table: @output_lookup_table,
+ outputs: @outputs,
+ weights: @weights
+ }.to_json
+ end
+
+ def load_state(saved_state)
+ state = JSON.parse(saved_state, symbolize_names: true)
+
+ @biases = state[:biases]
+ @binary_thresh = state[:binary_thresh]
+ @changes = state[:changes]
+ @deltas = state[:deltas]
+ @errors = state[:errors]
+ @hidden_layers = state[:hidden_layers]
+ @input_lookup_table = state[:input_lookup_table]
+ @layer_sizes = state[:layer_sizes]
+ @layers = state[:layers]
+ @learning_rate = state[:learning_rate]
+ @momentum = state[:momentum]
+ @output_lookup_table = state[:output_lookup_table]
+ @outputs = state[:outputs]
+ @weights = state[:weights]
end
private
def construct_network(layer_sizes)