// Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_FIND_MAX_PaRSE_CKY_Hh_ #define DLIB_FIND_MAX_PaRSE_CKY_Hh_ #include "find_max_parse_cky_abstract.h" #include <vector> #include <string> #include <sstream> #include "../serialize.h" #include "../array2d.h" namespace dlib { // ----------------------------------------------------------------------------------------- template <typename T> struct constituent { unsigned long begin, end, k; T left_tag; T right_tag; }; template <typename T> void serialize( const constituent<T>& item, std::ostream& out ) { serialize(item.begin, out); serialize(item.end, out); serialize(item.k, out); serialize(item.left_tag, out); serialize(item.right_tag, out); } template <typename T> void deserialize( constituent<T>& item, std::istream& in ) { deserialize(item.begin, in); deserialize(item.end, in); deserialize(item.k, in); deserialize(item.left_tag, in); deserialize(item.right_tag, in); } // ----------------------------------------------------------------------------------------- const unsigned long END_OF_TREE = 0xFFFFFFFF; // ----------------------------------------------------------------------------------------- template <typename T> struct parse_tree_element { constituent<T> c; T tag; // id for the constituent corresponding to this level of the tree unsigned long left; unsigned long right; double score; }; template <typename T> void serialize ( const parse_tree_element<T>& item, std::ostream& out ) { serialize(item.c, out); serialize(item.tag, out); serialize(item.left, out); serialize(item.right, out); serialize(item.score, out); } template <typename T> void deserialize ( parse_tree_element<T>& item, std::istream& in ) { deserialize(item.c, in); deserialize(item.tag, in); deserialize(item.left, in); deserialize(item.right, in); deserialize(item.score, in); } // ----------------------------------------------------------------------------------------- namespace impl { template <typename T> unsigned long fill_parse_tree( std::vector<parse_tree_element<T> >& parse_tree, const T& tag, const array2d<std::map<T, parse_tree_element<T> > >& back, long r, long c ) /*! requires - back[r][c].size() == 0 || back[r][c].count(tag) != 0 !*/ { // base case of the recursion if (back[r][c].size() == 0) { return END_OF_TREE; } const unsigned long idx = parse_tree.size(); const parse_tree_element<T>& item = back[r][c].find(tag)->second; parse_tree.push_back(item); const long k = item.c.k; const unsigned long idx_left = fill_parse_tree(parse_tree, item.c.left_tag, back, r, k-1); const unsigned long idx_right = fill_parse_tree(parse_tree, item.c.right_tag, back, k, c); parse_tree[idx].left = idx_left; parse_tree[idx].right = idx_right; return idx; } } template <typename T, typename production_rule_function> void find_max_parse_cky ( const std::vector<T>& sequence, const production_rule_function& production_rules, std::vector<parse_tree_element<T> >& parse_tree ) { parse_tree.clear(); if (sequence.size() == 0) return; array2d<std::map<T,double> > table(sequence.size(), sequence.size()); array2d<std::map<T,parse_tree_element<T> > > back(sequence.size(), sequence.size()); typedef typename std::map<T,double>::iterator itr; typedef typename std::map<T,parse_tree_element<T> >::iterator itr_b; for (long r = 0; r < table.nr(); ++r) table[r][r][sequence[r]] = 0; std::vector<std::pair<T,double> > possible_tags; for (long r = table.nr()-2; r >= 0; --r) { for (long c = r+1; c < table.nc(); ++c) { for (long k = r; k < c; ++k) { for (itr i = table[k+1][c].begin(); i != table[k+1][c].end(); ++i) { for (itr j = table[r][k].begin(); j != table[r][k].end(); ++j) { constituent<T> con; con.begin = r; con.end = c+1; con.k = k+1; con.left_tag = j->first; con.right_tag = i->first; possible_tags.clear(); production_rules(sequence, con, possible_tags); for (unsigned long m = 0; m < possible_tags.size(); ++m) { const double score = possible_tags[m].second + i->second + j->second; itr match = table[r][c].find(possible_tags[m].first); if (match == table[r][c].end() || score > match->second) { table[r][c][possible_tags[m].first] = score; parse_tree_element<T> item; item.c = con; item.score = score; item.tag = possible_tags[m].first; item.left = END_OF_TREE; item.right = END_OF_TREE; back[r][c][possible_tags[m].first] = item; } } } } } } } // now use back pointers to build the parse trees const long r = 0; const long c = back.nc()-1; if (back[r][c].size() != 0) { // find the max scoring element in back[r][c] itr_b max_i = back[r][c].begin(); itr_b i = max_i; ++i; for (; i != back[r][c].end(); ++i) { if (i->second.score > max_i->second.score) max_i = i; } parse_tree.reserve(c); impl::fill_parse_tree(parse_tree, max_i->second.tag, back, r, c); } } // ----------------------------------------------------------------------------------------- class parse_tree_to_string_error : public error { public: parse_tree_to_string_error(const std::string& str): error(str) {} }; namespace impl { template <bool enabled, typename T> typename enable_if_c<enabled>::type conditional_print( const T& item, std::ostream& out ) { out << item << " "; } template <bool enabled, typename T> typename disable_if_c<enabled>::type conditional_print( const T& , std::ostream& ) { } template <bool print_tag, bool skip_tag, typename T, typename U > void print_parse_tree_helper ( const std::vector<parse_tree_element<T> >& tree, const std::vector<U>& words, unsigned long i, const T& tag_to_skip, std::ostream& out ) { if (!skip_tag || tree[i].tag != tag_to_skip) out << "["; bool left_recurse = false; // Only print if we are supposed to. Doing it this funny way avoids compiler // errors in parse_tree_to_string() for the case where tag isn't // printable. if (!skip_tag || tree[i].tag != tag_to_skip) conditional_print<print_tag>(tree[i].tag, out); if (tree[i].left < tree.size()) { left_recurse = true; print_parse_tree_helper<print_tag,skip_tag>(tree, words, tree[i].left, tag_to_skip, out); } else { if ((tree[i].c.begin) < words.size()) { out << words[tree[i].c.begin] << " "; } else { std::ostringstream sout; sout << "Parse tree refers to element " << tree[i].c.begin << " of sequence which is only of size " << words.size() << "."; throw parse_tree_to_string_error(sout.str()); } } if (left_recurse == true) out << " "; if (tree[i].right < tree.size()) { print_parse_tree_helper<print_tag,skip_tag>(tree, words, tree[i].right, tag_to_skip, out); } else { if (tree[i].c.k < words.size()) { out << words[tree[i].c.k]; } else { std::ostringstream sout; sout << "Parse tree refers to element " << tree[i].c.k << " of sequence which is only of size " << words.size() << "."; throw parse_tree_to_string_error(sout.str()); } } if (!skip_tag || tree[i].tag != tag_to_skip) out << "]"; } } // ----------------------------------------------------------------------------------------- template <typename T, typename U> std::string parse_tree_to_string ( const std::vector<parse_tree_element<T> >& tree, const std::vector<U>& words, const unsigned long root_idx = 0 ) { if (root_idx >= tree.size()) return ""; std::ostringstream sout; impl::print_parse_tree_helper<false,false>(tree, words, root_idx, tree[root_idx].tag, sout); return sout.str(); } // ----------------------------------------------------------------------------------------- template <typename T, typename U> std::string parse_tree_to_string_tagged ( const std::vector<parse_tree_element<T> >& tree, const std::vector<U>& words, const unsigned long root_idx = 0 ) { if (root_idx >= tree.size()) return ""; std::ostringstream sout; impl::print_parse_tree_helper<true,false>(tree, words, root_idx, tree[root_idx].tag, sout); return sout.str(); } // ----------------------------------------------------------------------------------------- template <typename T, typename U> std::string parse_trees_to_string ( const std::vector<parse_tree_element<T> >& tree, const std::vector<U>& words, const T& tag_to_skip ) { if (tree.size() == 0) return ""; std::ostringstream sout; impl::print_parse_tree_helper<false,true>(tree, words, 0, tag_to_skip, sout); return sout.str(); } // ----------------------------------------------------------------------------------------- template <typename T, typename U> std::string parse_trees_to_string_tagged ( const std::vector<parse_tree_element<T> >& tree, const std::vector<U>& words, const T& tag_to_skip ) { if (tree.size() == 0) return ""; std::ostringstream sout; impl::print_parse_tree_helper<true,true>(tree, words, 0, tag_to_skip, sout); return sout.str(); } // ----------------------------------------------------------------------------------------- namespace impl { template <typename T> void helper_find_trees_without_tag ( const std::vector<parse_tree_element<T> >& tree, const T& tag, std::vector<unsigned long>& tree_roots, unsigned long idx ) { if (idx < tree.size()) { if (tree[idx].tag != tag) { tree_roots.push_back(idx); } else { helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].left); helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].right); } } } } template <typename T> void find_trees_not_rooted_with_tag ( const std::vector<parse_tree_element<T> >& tree, const T& tag, std::vector<unsigned long>& tree_roots ) { tree_roots.clear(); impl::helper_find_trees_without_tag(tree, tag, tree_roots, 0); } // ----------------------------------------------------------------------------------------- } #endif // DLIB_FIND_MAX_PaRSE_CKY_Hh_