// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/optimization.h>
#include <dlib/rand.h>

#include "tester.h"

namespace  
{
    using namespace test;
    using namespace dlib;
    using namespace std;

    // Logger for this translation unit, registered under the name
    // "test.find_max_factor_graph_viterbi".  do_test_() writes one LINFO line
    // to it for each tested configuration.
    logger dlog("test.find_max_factor_graph_viterbi");

// ----------------------------------------------------------------------------------------

    // Random number generator shared by every map_problem instance created in
    // this file.  It is default constructed and never reseeded here, so each
    // new map_problem continues the same random stream.
    dlib::rand rnd;

// ----------------------------------------------------------------------------------------

    /*
        A randomly generated chain-structured MAP problem used as input for
        find_max_factor_graph_viterbi() and for the brute force reference
        solver in this file.

        Template parameters:
            O            - value returned by order()
            NS           - value returned by num_states()
            num_nodes    - value returned by number_of_nodes()
            all_negative - if true, every factor score is negated after being
                           randomly generated

        data(n,k) holds the score of node n's factor for the k-th joint
        assignment of the states handed to factor_value().
    */
    template <
        unsigned long O,
        unsigned long NS,
        unsigned long num_nodes,
        bool all_negative 
        >
    class map_problem
    {
    public:
        unsigned long order() const { return O; }
        unsigned long num_states() const { return NS; }

        // Fills data with random scores drawn from the shared generator rnd.
        map_problem()
        {
            // A factor sees at most O+1 states, each with NS possible values,
            // so NS^(O+1) columns are needed.  The count is computed with
            // integer arithmetic so it can never be truncated by floating
            // point rounding the way a cast of std::pow()'s result could be.
            long num_columns = 1;
            for (unsigned long i = 0; i <= O; ++i)
                num_columns *= static_cast<long>(NS);

            data = randm(number_of_nodes(), num_columns, rnd);
            if (all_negative)
                data = -data;
        }

        unsigned long number_of_nodes (
        ) const
        {
            return num_nodes;
        }

        /*
            Returns the score of node node_id's factor for the given states.

            node_states must contain at least one element and at most O+1
            elements, each less than NS, so that the computed column index
            stays inside data.

            Every element of node_states contributes to the looked up column.
            The index is built the same way for any number of states:
                index = node_states(0) + node_states(1)*NS
                index = index*NS + node_states(i)     for each i >= 2
            For 1 to 4 states this is exactly the hand written indexing this
            class has always used.  The loop also covers 5 or more states, so
            order 4 problems really depend on all five states rather than
            silently ignoring the last one.
        */
        template <
            typename EXP 
            >
        double factor_value (
            unsigned long node_id,
            const matrix_exp<EXP>& node_states
        ) const
        {
            const long num = node_states.size();

            unsigned long index = node_states(0);
            if (num > 1)
                index += node_states(1)*NS;
            for (long i = 2; i < num; ++i)
                index = index*NS + node_states(i);

            return data(node_id, index);
        }

        matrix<double> data;
    };


// ----------------------------------------------------------------------------------------

    template <
        typename map_problem
        >
    void brute_force_find_max_factor_graph_viterbi (
        const map_problem& prob,
        std::vector<unsigned long>& map_assignment
    )
    {
        using namespace dlib::impl;
        const int order = prob.order();
        const int num_states = prob.num_states();

        map_assignment.resize(prob.number_of_nodes());
        double best_score = -std::numeric_limits<double>::infinity();
        matrix<unsigned long,1,0> node_states;
        node_states.set_size(prob.number_of_nodes());
        node_states = 0;
        do
        {
            double score = 0;
            for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
            {
                score += prob.factor_value(i, (colm(node_states,range(i,i-std::min<int>(order,i)))));
            }

            if (score > best_score)
            {
                for (unsigned long i = 0; i < map_assignment.size(); ++i)
                    map_assignment[i] = node_states(i);
                best_score = score;
            }

        } while(advance_state(node_states,num_states));

    }

// ----------------------------------------------------------------------------------------

    template <
        unsigned long order,
        unsigned long num_states,
        unsigned long num_nodes,
        bool all_negative
        >
    void do_test_()
    {
        dlog << LINFO << "order: "<< order 
                      << "  num_states:   " << num_states
                      << "  num_nodes:    " << num_nodes
                      << "  all_negative: " << all_negative;

        for (int i = 0; i < 25; ++i)
        {
            print_spinner();
            map_problem<order,num_states,num_nodes,all_negative> prob;
            std::vector<unsigned long> assign, assign2;
            brute_force_find_max_factor_graph_viterbi(prob, assign);
            find_max_factor_graph_viterbi(prob, assign2);

            DLIB_TEST_MSG(mat(assign) == mat(assign2),
                          trans(mat(assign))
                          << trans(mat(assign2))
                          );
        }
    }

    // Runs do_test_() for the given configuration with all_negative set to
    // false, i.e. with the factor scores left as they were generated.
    template <
        unsigned long order,
        unsigned long num_states,
        unsigned long num_nodes
        >
    void do_test()
    {
        const bool negate_scores = false;
        do_test_<order,num_states,num_nodes,negate_scores>();
    }

    // Runs do_test_() for the given configuration with all_negative set to
    // true, i.e. with every generated factor score negated.
    template <
        unsigned long order,
        unsigned long num_states,
        unsigned long num_nodes
        >
    void do_test_negative()
    {
        const bool negate_scores = true;
        do_test_<order,num_states,num_nodes,negate_scores>();
    }

// ----------------------------------------------------------------------------------------

    // Test case that runs the brute force comparison over a collection of
    // <order, num_states, num_nodes> configurations.
    class test_find_max_factor_graph_viterbi : public tester
    {
    public:
        // Passes the test's name and a one line description to the tester
        // base class.
        test_find_max_factor_graph_viterbi (
        ) :
            tester ("test_find_max_factor_graph_viterbi",
                    "Runs tests on the find_max_factor_graph_viterbi routine.")
        {}

        // Template arguments below are <order, num_states, num_nodes>.  The
        // calls all draw from the same random generator, so their order
        // determines which random problems each configuration sees.
        void perform_test (
        )
        {
            // Three states, very short chains (0, 1 and 2 nodes).
            do_test<1,3,0>();
            do_test<1,3,1>();
            do_test<1,3,2>();
            do_test<0,3,2>();
            do_test_negative<0,3,2>();

            // Three states, orders 0 through 4, including chains that are
            // shorter than the order.
            do_test<1,3,8>();
            do_test<2,3,7>();
            do_test_negative<2,3,7>();
            do_test<3,3,8>();
            do_test<4,3,8>();
            do_test_negative<4,3,8>();
            do_test<0,3,8>();
            do_test<4,3,1>();
            do_test<4,3,0>();

            // Two states, chains no longer than the order.
            do_test<3,2,1>();
            do_test<3,2,0>();
            do_test<3,2,2>();
            do_test<2,2,1>();
            do_test_negative<3,2,1>();
            do_test_negative<3,2,0>();
            do_test_negative<3,2,2>();
            do_test_negative<2,2,1>();

            // One empty three state chain, then two states on longer chains.
            do_test<0,3,0>();
            do_test<1,2,8>();
            do_test<2,2,7>();
            do_test<3,2,8>();
            do_test<0,2,8>();

            // Only one state per node.
            do_test<1,1,8>();
            do_test<2,1,8>();
            do_test<3,1,8>();
            do_test<0,1,8>();
        }
    // File scope instance of the test.  NOTE(review): presumably constructing
    // it is what registers the test with the framework -- confirm in tester.h.
    } a;

}