# pattern-match.rb
#
# Copyright (C) 2012-2013 Kazuki Tsujimoto, All rights reserved.

require 'pattern-match/version'

module PatternMatch
  # Mixed into Class and Regexp so that `Klass.(subpat, ...)` (i.e. #call)
  # builds a pattern. Object and Hash are special-cased: each takes exactly
  # one hash spec and builds a PatternObject / PatternHash from it; every
  # other receiver builds a PatternDeconstructor.
  module Deconstructable
    def call(*subpatterns)
      spec_pattern_class =
        if Object == self
          PatternObject
        elsif Hash == self
          PatternHash
        end
      return PatternDeconstructor.new(self, *subpatterns) unless spec_pattern_class
      raise MalformedPatternError unless subpatterns.length == 1
      spec_pattern_class.new(subpatterns[0])
    end
  end

  # Base class of every pattern node. A pattern is a tree: each node keeps its
  # children in @subpatterns and links to its parent and to its previous and
  # next siblings (a quantifier uses @prev to find the subpattern it repeats).
  class Pattern
    attr_accessor :parent, :next, :prev

    # Non-Pattern arguments are wrapped in PatternValue so they match by ===.
    def initialize(*subpatterns)
      @parent = nil
      @next = nil
      @prev = nil
      @subpatterns = subpatterns.map {|i| i.kind_of?(Pattern) ? i : PatternValue.new(i) }
      set_subpatterns_relation
    end

    # All pattern variables in this subtree, in left-to-right order.
    def vars
      @subpatterns.flat_map(&:vars)
    end

    # Hash of variable name => bound value for every variable in the subtree.
    def binding
      vars.each_with_object({}) {|v, h| h[v.name] = v.val }
    end

    # `pat1 & pat2`: both must match.
    def &(pattern)
      PatternAnd.new(self, pattern)
    end

    # `pat1 | pat2`: either may match.
    def |(pattern)
      PatternOr.new(self, pattern)
    end

    # `!pat`: matches when pat does not.
    def !@
      PatternNot.new(self)
    end

    # Lets `*pat` expand to `pat, ___` (pat repeated zero or more times).
    def to_a
      [self, PatternQuantifier.new(0)]
    end

    # True when this node, or one of its ancestors, is directly followed by a
    # quantifier.
    def quantified?
      @next.kind_of?(PatternQuantifier) || (root? ? false : @parent.quantified?)
    end

    def root?
      @parent.nil?
    end

    # Checks the tree for structural errors and raises MalformedPatternError
    # for duplicate variable names (checked once, at the root) or for more
    # than one quantifier among the children of a single node. Recurses into
    # the children.
    def validate
      if root?
        all_vars = vars
        # Names occurring more than once, each reported a single time even if
        # it appears three or more times.
        dup_names = (all_vars - all_vars.uniq(&:name)).map(&:name).uniq
        raise MalformedPatternError, "duplicate variables: #{dup_names.join(', ')}" unless dup_names.empty?
      end
      raise MalformedPatternError if @subpatterns.count {|i| i.kind_of?(PatternQuantifier) } > 1
      @subpatterns.each(&:validate)
    end

    private

    # Points every child at this node and chains siblings via prev/next.
    def set_subpatterns_relation
      @subpatterns.each do |i|
        i.parent = self
      end
      @subpatterns.each_cons(2) do |a, b|
        a.next = b
        b.prev = a
      end
    end

    # This node followed by its parent, grandparent, ... up to the root.
    def ancestors
      ary = []
      pat = self
      until pat.nil?
        ary << pat
        pat = pat.parent
      end
      ary
    end
  end

  # Pattern built by `Object.(key => subpattern, ...)`. Each key is converted
  # with #to_proc (a Symbol key calls that method on the value), and the result
  # of applying it to the matched value must match the paired subpattern.
  class PatternObject < Pattern
    # Raises MalformedPatternError when +spec+ is not hash-like or a key does
    # not respond to #to_proc.
    def initialize(spec)
      super(*spec.values)
      # Pair each accessor with the normalized subpattern from @subpatterns
      # (where plain values were wrapped in PatternValue), not with the raw
      # spec value. A raw value such as 3 has no #match, so it could never
      # match.
      @spec = spec.keys.map(&:to_proc).zip(@subpatterns)
    rescue
      raise MalformedPatternError
    end

    # Any StandardError raised while applying an accessor (e.g. NoMethodError
    # when the value lacks the method) or while matching its subpattern is
    # turned into PatternNotMatch.
    def match(val)
      @spec.all? {|k, pat| pat.match(k.(val)) rescue raise PatternNotMatch }
    end
  end

  # Pattern built by `Hash.(key => subpattern, ...)`. The value must be a Hash
  # that contains every key in the spec (extra keys are allowed), and each of
  # those entries must match its subpattern.
  class PatternHash < Pattern
    # Raises MalformedPatternError unless +spec+ is a Hash.
    def initialize(spec)
      raise MalformedPatternError unless spec.kind_of?(Hash)
      super(*spec.values)
      # [key, subpattern] pairs that use the normalized subpatterns (plain
      # values wrapped in PatternValue), so `Hash.(age: 30)` compares with ===
      # instead of calling a nonexistent/unrelated #match on the raw value.
      @spec = spec.keys.zip(@subpatterns)
    end

    # Raises PatternNotMatch when +val+ is not a Hash, lacks a spec key, or
    # matching an entry raises a StandardError.
    def match(val)
      raise PatternNotMatch unless val.kind_of?(Hash)
      raise PatternNotMatch unless @spec.all? {|k, _| val.key?(k) }
      @spec.all? {|k, pat| pat.match(val[k]) rescue raise PatternNotMatch }
    end
  end

  # Pattern built by `Klass.(subpat, ...)` (see Deconstructable#call). It asks
  # the deconstructor to turn the value into an array and matches that array
  # element-wise against the subpatterns. One subpattern may be followed by a
  # PatternQuantifier (`___`, `__N`), in which case that subpattern is repeated
  # to absorb the surplus elements.
  class PatternDeconstructor < Pattern
    # deconstructor - object responding to #deconstruct(val); in this file a
    #                 Class or a Regexp instance.
    # subpatterns   - element patterns; non-Pattern values are wrapped by
    #                 Pattern#initialize.
    def initialize(deconstructor, *subpatterns)
      super(*subpatterns)
      @deconstructor = deconstructor
    end

    # Returns true when every deconstructed element matches its pattern and
    # false on a length mismatch. #deconstruct may itself raise
    # PatternNotMatch (e.g. for a value of the wrong class); that propagates.
    def match(val)
      deconstructed_vals = @deconstructor.deconstruct(val)
      # Repetition count for the quantified subpattern: the quantifier and the
      # subpattern it repeats take up 2 entries of @subpatterns. Only
      # meaningful when a quantifier is present.
      k = deconstructed_vals.length - (@subpatterns.length - 2)
      quantifier = @subpatterns.find {|i| i.kind_of?(PatternQuantifier) }
      if quantifier
        return false unless quantifier.min_k <= k
      else
        return false unless @subpatterns.length == deconstructed_vals.length
      end
      # Expand the subpatterns so that they line up one-to-one with
      # deconstructed_vals: the quantified subpattern is dropped from its own
      # slot and inserted k times in place of the quantifier. flat_map runs to
      # completion first, so the variables inside it get a fresh collection
      # list (set_bind_to) before any element is matched.
      @subpatterns.flat_map do |pat|
        case
        when pat.next.kind_of?(PatternQuantifier)
          []
        when pat.kind_of?(PatternQuantifier)
          pat.prev.vars.each {|v| v.set_bind_to(pat) }
          Array.new(k, pat.prev)
        else
          [pat]
        end
      end.zip(deconstructed_vals).all? do |pat, v|
        pat.match(v)
      end
    end
  end

  # Repetition marker (`___` for min 0, `__N` for min N) placed right after
  # the subpattern it repeats inside a PatternDeconstructor. It is never
  # matched itself; the enclosing deconstructor expands it.
  class PatternQuantifier < Pattern
    # Minimum number of repetitions of the preceding subpattern.
    attr_reader :min_k

    def initialize(min_k)
      super()
      @min_k = min_k
    end

    # Reaching this means the expansion logic is broken.
    def match(val)
      raise PatternMatchError, 'must not happen'
    end

    # Besides the generic checks, a quantifier needs a preceding sibling to
    # repeat and must sit directly inside a deconstructor pattern.
    def validate
      super
      misplaced = @prev.nil? || !@parent.kind_of?(PatternDeconstructor)
      raise MalformedPatternError if misplaced
    end
  end

  # A named capture. Matching always succeeds and records the value. Under a
  # quantifier the values are collected into arrays instead, nested one level
  # deeper for each enclosing quantifier.
  class PatternVariable < Pattern
    attr_reader :name, :val

    def initialize(name)
      super()
      @name = name
      @val = nil
      # Array that quantified matches are appended to; set by #set_bind_to.
      @bind_to = nil
    end

    # Records +val+ and always reports success.
    def match(val)
      bind(val)
      true
    end

    def vars
      [self]
    end

    # Called by PatternDeconstructor#match whenever it expands +quantifier+:
    # prepares a fresh array for the repetitions about to be matched. The
    # first call creates @val itself. Later calls (from a quantifier nested
    # inside an outer one) append a new sub-array to the most recent list one
    # level up, at the depth given by #nest_level.
    def set_bind_to(quantifier)
      if @val
        outer = @val
        # Descend through the last element of each level to reach the list
        # that the new sub-array belongs to.
        (nest_level(quantifier) - 1).times do
          outer = outer[-1]
        end
        @bind_to = []
        outer << @bind_to
      else
        @val = @bind_to = []
      end
    end

    private

    # Appends to the current collection list when quantified; otherwise
    # assigns directly.
    def bind(val)
      if quantified?
        @bind_to << val
      else
        @val = val
      end
    end

    # Position of +quantifier+ among the quantifiers that follow this variable
    # or one of its ancestors, counted from the outermost (0). Raises
    # PatternMatchError if +quantifier+ is not one of them.
    def nest_level(quantifier)
      qs = ancestors.map {|i| i.next.kind_of?(PatternQuantifier) ? i.next : nil }.find_all {|i| i }.reverse
      qs.index(quantifier) || (raise PatternMatchError)
    end
  end

  # Leaf pattern wrapping a literal. It matches when
  # `expected.__send__(compare_by, candidate)` is truthy. The default
  # comparison is ===, so classes, ranges and regexps behave as in `case`.
  class PatternValue < Pattern
    def initialize(val, compare_by = :===)
      super()
      @compare_by = compare_by
      @val = val
    end

    def match(val)
      expected, operator = @val, @compare_by
      expected.__send__(operator, val)
    end
  end

  # `pat1 & pat2`: matches only when every subpattern matches the same value.
  # Evaluation stops at the first failure; PatternNotMatch propagates.
  class PatternAnd < Pattern
    def match(val)
      @subpatterns.none? {|pat| !pat.match(val) }
    end
  end

  # `pat1 | pat2`: matches when any subpattern matches. A subpattern that
  # raises PatternNotMatch counts as a non-match. Variables inside are
  # rejected by #validate.
  class PatternOr < Pattern
    # Returns true or false, like the other matchers.
    def match(val)
      @subpatterns.any? do |i|
        begin
          i.match(val)
        rescue PatternNotMatch
          false
        end
      end
    end

    def validate
      super
      raise MalformedPatternError unless vars.empty?
    end
  end

  # `!pat`: matches exactly when the wrapped pattern does not. If the inner
  # pattern raises PatternNotMatch, that counts as "did not match". Variables
  # inside are rejected by #validate.
  class PatternNot < Pattern
    def match(val)
      inner_matched =
        begin
          @subpatterns.first.match(val)
        rescue PatternNotMatch
          false
        end
      !inner_matched
    end

    def validate
      super
      raise MalformedPatternError unless vars.empty?
    end
  end

  # Evaluation context for the block passed to Kernel#match. That block is
  # instance_eval'd on an Env, so `with`, `guard`, `_`/`__`/`_l` and bare
  # identifiers (via method_missing) are resolved here.
  #
  # Env derives from BasicObject, which defines very few methods, so most bare
  # identifiers fall through to method_missing. Because BasicObject is not an
  # Object, top-level constants are written with a leading `::`.
  class Env < BasicObject
    # ctx - the object Kernel#match was called on; guards and clause bodies
    #       are evaluated with it as self.
    # val - the value being matched.
    def initialize(ctx, val)
      @ctx = ctx
      @val = val
    end

    private

    # Tries one clause: `with(pattern[, guard_proc]) { body }`.
    # A non-Pattern +pat_or_val+ is wrapped in PatternValue (=== comparison).
    # If the pattern matches and the guard (when given) is truthy, +block+
    # runs with the pattern's variables bound and its result is thrown to the
    # catch(:exit_match) in Kernel#match. Otherwise returns nil. A
    # PatternNotMatch raised anywhere in here (matching, guard or body) also
    # yields nil. MalformedPatternError from #validate propagates.
    def with(pat_or_val, guard_proc = nil, &block)
      pat = pat_or_val.kind_of?(Pattern) ? pat_or_val : PatternValue.new(pat_or_val)
      pat.validate
      if pat.match(@val) and (guard_proc ? with_tmpbinding(@ctx, pat.binding, &guard_proc) : true)
        ret = with_tmpbinding(@ctx, pat.binding, &block)
        ::Kernel.throw(:exit_match, ret)
      else
        nil
      end
    rescue PatternNotMatch
    end

    # `guard { cond }` just returns the block as a Proc, to be passed as
    # with's second argument.
    def guard(&block)
      block
    end

    # Bare identifiers in the match block: `___` is a quantifier with a
    # minimum of 0 repetitions, `__N` (e.g. `__2`) one with a minimum of N, and
    # any other name becomes a PatternVariable of that name. Arguments are
    # ignored.
    def method_missing(name, *)
      case name.to_s
      when '___'
        PatternQuantifier.new(0)
      when /\A__(\d+)\z/
        PatternQuantifier.new($1.to_i)
      else
        PatternVariable.new(name)
      end
    end

    # `_`       - wildcard: matches anything and binds nothing; `_[p, ...]`
    #             builds an Array deconstructor pattern.
    # `_(v)`    - PatternValue matching with `v === val`.
    # `_(v, m)` - PatternValue matching with `v.__send__(m, val)`.
    # More arguments raise MalformedPatternError. `__` and `_l` are aliases.
    def _(*vals)
      case vals.length
      when 0
        uscore = PatternVariable.new(:_)
        class << uscore
          def [](*args)
            Array.call(*args)
          end

          def match(val)
            true
          end

          def vars
            []
          end
        end
        uscore
      when 1
        PatternValue.new(vals[0])
      when 2
        PatternValue.new(vals[0], vals[1])
      else
        raise MalformedPatternError
      end
    end

    alias __ _
    alias _l _

    # Evaluates +block+ with +obj+ as self while each name in +binding+
    # (name => value) is callable on obj as a private method returning that
    # value. The methods live on obj's TmpBindingModule. A per-name stack lets
    # nested calls shadow and then restore outer bindings, and a method is
    # removed once its stack is empty again.
    def with_tmpbinding(obj, binding, &block)
      # self inside this instance_eval is the module, so @stacks,
      # define_method, private and remove_method all refer to it.
      tmpbinding_module(obj).instance_eval do
        begin
          binding.each do |name, val|
            stack = @stacks[name]
            # First active binding of this name: define the reader once; it
            # always returns the top of the stack.
            if stack.empty?
              define_method(name) { stack[-1] }
              private name
            end
            stack.push(val)
          end
          obj.instance_eval(&block)
        ensure
          binding.each do |name, _|
            if @stacks[name].tap(&:pop).empty?
              remove_method(name)
            end
          end
        end
      end
    end

    # Marker subclass of Module so the binding module can be found again among
    # an object's singleton-class ancestors.
    class TmpBindingModule < ::Module
    end

    # Returns the TmpBindingModule in obj's singleton-class ancestry, creating
    # it on first use. A new module gets an empty per-name stack table and is
    # prepended to the singleton class (included when Module#prepend is
    # unavailable). Prepending puts its methods ahead of the singleton class's
    # own methods in lookup order.
    def tmpbinding_module(obj)
      m = obj.singleton_class.ancestors.find {|i| i.kind_of?(TmpBindingModule) }
      unless m
        m = TmpBindingModule.new
        m.instance_eval do
          @stacks = ::Hash.new {|h, k| h[k] = [] }
        end
        obj.singleton_class.class_eval do
          if respond_to?(:prepend, true)
            prepend m
          else
            include m
          end
        end
      end
      m
    end
  end

  # Internal control-flow signal meaning "this value does not match". It
  # derives from Exception rather than StandardError, so plain `rescue`
  # clauses (such as the `rescue raise PatternNotMatch` modifiers in the
  # matchers) do not catch it.
  PatternNotMatch = Class.new(Exception)
  # Common base of the errors this library raises.
  PatternMatchError = Class.new(StandardError)
  # Raised by Kernel#match when no `with` clause matched.
  NoMatchingPatternError = Class.new(PatternMatchError)
  # Raised for structurally invalid patterns.
  MalformedPatternError = Class.new(PatternMatchError)

  # Hide Pattern, all of its subclasses and Env from the public API. This is
  # skipped on Rubies without Module#private_constant.
  if respond_to?(:private_constant)
    pattern_class_names = constants.select do |c|
      klass = const_get(c)
      klass.kind_of?(Class) && klass <= Pattern
    end
    private_constant(*pattern_class_names)
    private_constant :Env
  end
end

module Kernel
  private

  # Pattern-matches a value against the `with` clauses in +block+.
  #
  #   match(val) { with(pat) { ... } ... }
  #     evaluates the clauses against +val+ and returns the result of the
  #     first clause that matches.
  #   match { ... }
  #     returns a Proc that performs the same match on its argument.
  #
  # Raises ArgumentError when no block is given or more than one value is
  # passed. Raises PatternMatch::NoMatchingPatternError when no clause matches.
  def match(*vals, &block)
    # Fail here; otherwise the missing block only surfaces later as an
    # unrelated arity error from instance_eval.
    raise ArgumentError, 'no block given' unless block
    do_match = Proc.new do |val|
      # Env is a private constant, hence const_get.
      env = PatternMatch.const_get(:Env).new(self, val)
      catch(:exit_match) do
        env.instance_eval(&block)
        raise ::PatternMatch::NoMatchingPatternError
      end
    end
    case vals.length
    when 0
      do_match
    when 1
      do_match.(vals[0])
    else
      raise ArgumentError, "wrong number of arguments (#{vals.length} for 0..1)"
    end
  end
end

class Class
  include PatternMatch::Deconstructable

  # Hook called by PatternDeconstructor: turns +val+ into the array of values
  # that the subpatterns are matched against. Classes that support
  # `Klass.(...)` patterns override it.
  def deconstruct(val)
    raise NotImplementedError, "need to define `#{__method__}'"
  end

  private

  # Guard for #deconstruct implementations: aborts the match unless +val+ is
  # an instance of the receiving class (or of a subclass).
  def accept_self_instance_only(val)
    return if val.kind_of?(self)
    raise PatternMatch::PatternNotMatch
  end
end

# `Array.(p1, p2, ...)` matches an Array element by element.
def Array.deconstruct(val)
  accept_self_instance_only(val)
  val
end

# Struct classes deconstruct an instance into its member values.
def Struct.deconstruct(val)
  accept_self_instance_only(val)
  val.values
end

# `Complex.(re, im)` matches the real and imaginary parts.
def Complex.deconstruct(val)
  accept_self_instance_only(val)
  [val.real, val.imaginary]
end

# `Rational.(num, den)` matches the numerator and denominator.
def Rational.deconstruct(val)
  accept_self_instance_only(val)
  [val.numerator, val.denominator]
end

# `MatchData.(...)` matches the capture groups, or the whole match when the
# regexp had no groups.
def MatchData.deconstruct(val)
  accept_self_instance_only(val)
  groups = val.captures
  groups.empty? ? [val[0]] : groups
end

class Regexp
  include PatternMatch::Deconstructable

  # Deconstructor for `/re/.(p1, ...)` patterns. +val+ is converted with #to_s
  # and must match the regexp as a whole. Returns the capture groups, or the
  # whole match when the regexp has none. Raises PatternNotMatch otherwise.
  def deconstruct(val)
    body = source
    # In extended (/x) mode a trailing `# comment` would swallow the closing
    # group and anchor; there a newline ends the comment and is otherwise
    # ignored.
    body += "\n" if (options & Regexp::EXTENDED).nonzero?
    # The non-capturing group makes the anchors bind to the whole regexp.
    # Without it, /a|b/ would become \Aa|b\z and accept any string that merely
    # starts with "a" or ends with "b". The group adds no captures.
    m = Regexp.new("\\A(?:#{body})\\z", options).match(val.to_s)
    raise PatternMatch::PatternNotMatch unless m
    m.captures.empty? ? [m[0]] : m.captures
  end
end

class Symbol
  # Partially applies a message send: `:meth.(a, b)` returns a (non-lambda)
  # proc that calls `obj.meth(a, b)` on the object passed to it, e.g.
  # `[1, 2].map(&:+.(10))` #=> [11, 12].
  def call(*args)
    message = self
    proc {|receiver| receiver.__send__(message, *args) }
  end
end