lib/bayesnet/factor.rb in bayesnet-0.1.0 vs lib/bayesnet/factor.rb in bayesnet-0.6.0
- old
+ new
@@ -1,91 +1,176 @@
# frozen_string_literal: true
module Bayesnet
- # Factor if a function of sevaral variables (A, B, ...) each defined on values from finite set
+ # Factor if a function of several variables (A, B, ...), where
+ # every variable cold take values from some finite set
class Factor
+ # +++ Factor DSL +++
+ #
+ # Factor DSL entry point:
def self.build(&block)
factor = new
factor.instance_eval(&block)
factor
end
- # Specifies variable name together with its values
- def scope(var_name_to_values)
- @scope.merge!(var_name_to_values)
+ # Factor DSL
+ # Defining variable with list of its possible values looks like:
+ # ```
+ # Bayesnet::Factor.build do
+ # scope weather: %i[sunny cloudy]
+ # scope mood: %i[bad good]
+ # ...
+ # ```
+ # ^ this code defines to variables `weather` and `mood`, where
+ # `weather` could be :sunny or :cloudy, and
+ # `mood` could be :bad or :good
+ def scope(var_name_to_values = nil)
+ if var_name_to_values
+ @scope.merge!(var_name_to_values)
+ else
+ @scope
+ end
end
- # Specifies value for a scope context. Value is the last element in `context_and_val`
+ # Factor DSL
+ # Specifies factor value for some set of variable values, i.e.
+ # ```
+ # Bayesnet::Factor.build do
+ # scope weather: %i[sunny cloudy]
+ # scope mood: %i[bad good]
+ # val :sunny, :bad, 0.1
+ # ...
+ # ```
+ # ^ this code says the value of factor for [weather == :sunny, mood == :bad] is 0.1
def val(*context_and_val)
context_and_val = context_and_val[0] if context_and_val.size == 1 && context_and_val[0].is_a?(Array)
@vals[context_and_val[0..-2]] = context_and_val[-1]
end
+ # --- Factor DSL ---
+ # List of variable names
def var_names
@scope.keys
end
+ # accessor factor value, i.e
+ # ```
+ # factor = Bayesnet::Factor.build do
+ # scope weather: %i[sunny cloudy]
+ # scope mood: %i[bad good]
+ # val :sunny, :bad, 0.1
+ # ...
+ # end
+ # factor[:sunny, :bad] # 0.1
+ # ```
def [](*context)
key = if context.size == 1 && context[0].is_a?(Hash)
context[0].slice(*var_names).values
else
context
end
@vals[key]
end
- def self.from_distribution(var_distribution)
- self.class.new(var_distribution.keys, var_distribution.values.map(&:to_a))
- end
-
+ # returns all combinations of values of `var_names`
def contextes(*var_names)
return [] if var_names.empty?
@scope[var_names[0]].product(*var_names[1..].map { |var_name| @scope[var_name] })
end
+ # returns all possible values
def values
@vals.values
end
+ # returns new normalized factor, i.e. where sum of all values is 1.0
def normalize
vals = @vals.clone
norm_factor = vals.map(&:last).sum * 1.0
vals.each { |k, _v| vals[k] /= norm_factor }
self.class.new(@scope.clone, vals)
end
+ # Returns factor built as follows:
+ # 1. Original factor gets filtered out by variables having values compatible with `context`
+ # 2. Returned factor does not have any variables from `context` (because they have
+ # same values, after step 1)
+ # The `context` argument supposed to be an evidence, somewhat like
+ # `{weather: :sunny}`
def reduce_to(context)
- # TODO: use Hash#except when Ruby 2.6 support no longer needed
- context_keys_set = context.keys.to_set
- scope = @scope.reject { |k, _| context_keys_set.include?(k) }
+ limited_context = context.slice(*scope.keys)
+ return self.class.new(@scope, @vals) if limited_context.empty?
+ limited_scope = @scope.slice(*(@scope.keys - limited_context.keys))
- context_vals = context.values
- indices = context.keys.map { |k| index_by_var_name[k] }
+ context_vals = limited_context.values
+ indices = limited_context.keys.map { |k| index_by_var_name[k] }
vals = @vals.select { |k, _v| indices.map { |i| k[i] } == context_vals }
vals.transform_keys! { |k| delete_by_indices(k, indices) }
- self.class.new(scope, vals)
+ self.class.new(limited_scope, vals)
end
- def delete_by_indices(array, indices)
- result = array.dup
- indices.map { |i| result[i] = nil }
- result.compact
- end
-
- # groups by `var_names` having same context and sum out values.
+ # Returns new context defined over `var_names`, all other variables
+ # get eliminated. For every combination of `var_names`'s values
+ # the value of new factor is defined by summing up values in original factor
+ # having compatible value
def marginalize(var_names)
scope = @scope.slice(*var_names)
indices = scope.keys.map { |k| index_by_var_name[k] }
vals = @vals.group_by { |context, _val| indices.map { |i| context[i] } }
vals.transform_values! { |v| v.map(&:last).sum }
self.class.new(scope, vals)
end
+ def eliminate(var_name)
+ keep_var_names = var_names
+ keep_var_names.delete(var_name)
+ marginalize(keep_var_names)
+ end
+
+ def select(subcontext)
+ @vals.select do |context, _|
+ var_names.zip(context).slice(subcontext.keys) == subcontext
+ end
+ end
+
+ def *(other)
+ common_scope = @scope.keys & other.scope.keys
+ new_scope = scope.merge(other.scope)
+ new_vals = {}
+ group1 = group_by_scope_values(common_scope)
+ group2 = other.group_by_scope_values(common_scope)
+ group1.each do |scope, vals1|
+ combo = vals1.product(group2[scope])
+ combo.each do |(val1, val2)|
+ # values in scope must match variables order in new_scope, i.e.
+ # they must match `new_scope.var_names`
+ # The code bellow ensures it by merging two hashes in the same
+ # wasy as `new_scope`` is constructed above
+ val_by_name1 = var_names.zip(val1.first).to_h
+ val_by_name2 = other.var_names.zip(val2.first).to_h
+ new_vals[val_by_name1.merge(val_by_name2).values] = val1.last*val2.last
+ end
+ end
+ Factor.new(new_scope, new_vals)
+ end
+
+ def group_by_scope_values(scope_keys)
+ indices = scope_keys.map { |k| index_by_var_name[k] }
+ @vals.group_by { |context, _val| indices.map { |i| context[i] } }
+ end
+
private
+
+ def delete_by_indices(array, indices)
+ result = array.dup
+ indices.map { |i| result[i] = nil }
+ result.compact
+ end
def initialize(scope = {}, vals = {})
@scope = scope
@vals = vals
end