lib/cerebrum/cerebrum.rb in cerebrum-0.1.2 vs lib/cerebrum/cerebrum.rb in cerebrum-0.1.3

- old
+ new

@@ -1,7 +1,8 @@ require_relative "data_scrubber" require_relative "cerebrum_helper" +require 'json' class Cerebrum include CerebrumHelper include DataScrubber @@ -23,12 +24,12 @@ adjust_weights(learning_rate) mean_squared_error(@errors[@layers]) end def train(training_set, options = Hash.new) - @input_lookup_table ||= get_input_lookup_table(training_set) - @output_lookup_table ||= get_output_lookup_table(training_set) + @input_lookup_table ||= get_input_lookup_table(training_set) + @output_lookup_table ||= get_output_lookup_table(training_set) training_set = scrub_dataset(training_set) iterations = options[:iterations] || 20000 error_threshold = options[:error_threshold] || 0.005 log = options[:log] || false @@ -58,9 +59,47 @@ def run(input) input = to_vector_given_features(input, @input_lookup_table) if @input_lookup_table output = run_input(input) @output_lookup_table ? to_features_given_vector(output, @output_lookup_table) : output + end + + def save_state + { + biases: @biases, + binary_thresh: @binary_thresh, + changes: @changes, + deltas: @deltas, + errors: @errors, + hidden_layers: @hidden_layers, + input_lookup_table: @input_lookup_table, + layer_sizes: @layer_sizes, + layers: @layers, + learning_rate: @learning_rate, + momentum: @momentum, + output_lookup_table: @output_lookup_table, + outputs: @outputs, + weights: @weights + }.to_json + end + + def load_state(saved_state) + state = JSON.parse(saved_state, symbolize_names: true) + + @biases = state[:biases] + @binary_thresh = state[:binary_thresh] + @changes = state[:changes] + @deltas = state[:deltas] + @errors = state[:errors] + @hidden_layers = state[:hidden_layers] + @input_lookup_table = state[:input_lookup_table] + @layer_sizes = state[:layer_sizes] + @layers = state[:layers] + @learning_rate = state[:learning_rate] + @momentum = state[:momentum] + @output_lookup_table = state[:output_lookup_table] + @outputs = state[:outputs] + @weights = state[:weights] end private def construct_network(layer_sizes)