vendor/tomotopy/src/Utils/Dictionary.h in tomoto-0.1.2 vs vendor/tomotopy/src/Utils/Dictionary.h in tomoto-0.1.3

- old
+ new

@@ -9,13 +9,19 @@ #include "serializer.hpp" namespace tomoto { using Vid = uint32_t; + static constexpr Vid non_vocab_id = (Vid)-1; using Tid = uint16_t; using Float = float; + struct VidPair : public std::pair<Vid, Vid> + { + using std::pair<Vid, Vid>::pair; + }; + class Dictionary { protected: std::unordered_map<std::string, Vid> dict; std::vector<std::string> id2word; @@ -32,20 +38,20 @@ return it->second; } size_t size() const { return dict.size(); } - std::string toWord(Vid vid) const + const std::string& toWord(Vid vid) const { assert(vid < id2word.size()); return id2word[vid]; } Vid toWid(const std::string& word) const { auto it = dict.find(word); - if (it == dict.end()) return (Vid)-1; + if (it == dict.end()) return non_vocab_id; return it->second; } void serializerWrite(std::ostream& writer) const { @@ -73,8 +79,40 @@ { p.second = order[p.second]; id2word[p.second] = p.first; } } + + const std::vector<std::string>& getRaw() const + { + return id2word; + } + + Vid mapToNewDict(Vid v, const Dictionary& newDict) const + { + return newDict.toWid(toWord(v)); + } + + std::vector<Vid> mapToNewDict(const std::vector<Vid>& v, const Dictionary& newDict) const + { + std::vector<Vid> r(v.size()); + for (size_t i = 0; i < v.size(); ++i) + { + r[i] = mapToNewDict(v[i], newDict); + } + return r; + } }; } + +namespace std +{ + template<> + struct hash<tomoto::VidPair> + { + size_t operator()(const tomoto::VidPair& p) const + { + return hash<size_t>{}(p.first) ^ hash<size_t>{}(p.second); + } + }; +} \ No newline at end of file