/* Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
 * Use of this file is governed by the BSD 3-clause license that
 * can be found in the LICENSE.txt file in the project root.
 */

#include "tree/pattern/ParseTreePattern.h"
#include "tree/pattern/ParseTreeMatch.h"
#include "tree/TerminalNode.h"
#include "CommonTokenStream.h"
#include "ParserInterpreter.h"
#include "tree/pattern/TokenTagToken.h"
#include "ParserRuleContext.h"
#include "tree/pattern/RuleTagToken.h"
#include "tree/pattern/TagChunk.h"
#include "atn/ATN.h"
#include "Lexer.h"
#include "BailErrorStrategy.h"

#include "ListTokenSource.h"
#include "tree/pattern/TextChunk.h"
#include "ANTLRInputStream.h"
#include "support/Arrays.h"
#include "Exceptions.h"
#include "support/CPPUtils.h"

#include "tree/pattern/ParseTreePatternMatcher.h"

using namespace antlr4;
using namespace antlr4::tree;
using namespace antlr4::tree::pattern;
using namespace antlrcpp;

// Builds the exception from the message of the runtime exception that caused it.
ParseTreePatternMatcher::CannotInvokeStartRule::CannotInvokeStartRule(const RuntimeException &e) : RuntimeException(e.what()) {
}

ParseTreePatternMatcher::CannotInvokeStartRule::~CannotInvokeStartRule() {
}

ParseTreePatternMatcher::StartRuleDoesNotConsumeFullPattern::~StartRuleDoesNotConsumeFullPattern() {
}

// Stores the (non-owned) lexer and parser and installs the default delimiters
// "<", ">" and the escape string "\".
ParseTreePatternMatcher::ParseTreePatternMatcher(Lexer *lexer, Parser *parser) : _lexer(lexer), _parser(parser) {
  InitializeInstanceFields();
}

ParseTreePatternMatcher::~ParseTreePatternMatcher() {
}

// Replaces the tag delimiters and the escape string used by split().
// Throws IllegalArgumentException when start or stop is empty; escapeLeft is
// stored without validation.
void ParseTreePatternMatcher::setDelimiters(const std::string &start, const std::string &stop, const std::string &escapeLeft) {
  if (start.empty()) {
    throw IllegalArgumentException("start cannot be null or empty");
  }

  if (stop.empty()) {
    throw IllegalArgumentException("stop cannot be null or empty");
  }

  _start = start;
  _stop = stop;
  _escape = escapeLeft;
}

// Compiles the pattern text for the given rule and reports whether tree matches it.
bool ParseTreePatternMatcher::matches(ParseTree *tree, const std::string &pattern, int patternRuleIndex) {
  ParseTreePattern p = compile(pattern, patternRuleIndex);
  return matches(tree, p);
}

// Reports whether tree matches an already compiled pattern; the labels collected
// while matching are discarded.
bool ParseTreePatternMatcher::matches(ParseTree *tree, const ParseTreePattern &pattern) {
  std::map<std::string, std::vector<ParseTree *>> labels;
  ParseTree *mismatchedNode = matchImpl(tree, pattern.getPatternTree(), labels);
  return mismatchedNode == nullptr;
}

// Compiles the pattern text for the given rule and matches tree against it.
ParseTreeMatch ParseTreePatternMatcher::match(ParseTree *tree, const std::string &pattern, int patternRuleIndex) {
  ParseTreePattern p = compile(pattern, patternRuleIndex);
  return match(tree, p);
}

// Matches tree against a compiled pattern and returns the result object, which
// carries the collected labels and the first mismatched node (nullptr on success).
ParseTreeMatch ParseTreePatternMatcher::match(ParseTree *tree, const ParseTreePattern &pattern) {
  std::map<std::string, std::vector<ParseTree *>> labels;
  tree::ParseTree *mismatchedNode = matchImpl(tree, pattern.getPatternTree(), labels);
  return ParseTreeMatch(tree, pattern, labels, mismatchedNode);
}

// Tokenizes the pattern text and parses it with a ParserInterpreter, starting at
// rule patternRuleIndex and using the parser's ATN with bypass alternatives.
// Throws StartRuleDoesNotConsumeFullPattern when tokens remain after the parse.
// Exceptions raised while parsing are rethrown or wrapped as described below.
ParseTreePattern ParseTreePatternMatcher::compile(const std::string &pattern, int patternRuleIndex) {
  ListTokenSource tokenSrc(tokenize(pattern));
  CommonTokenStream tokens(&tokenSrc);

  ParserInterpreter parserInterp(_parser->getGrammarFileName(), _parser->getVocabulary(),
                                 _parser->getRuleNames(), _parser->getATNWithBypassAlts(), &tokens);

  ParserRuleContext *tree = nullptr;
  try {
    parserInterp.setErrorHandler(std::make_shared<BailErrorStrategy>());
    tree = parserInterp.parse(patternRuleIndex);
  } catch (const ParseCancellationException &e) {
#if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026
    // rethrow_if_nested is not available before VS 2015.
    static_cast<void>(e);
    throw; // Rethrow the active exception itself instead of a copy.
#else
    std::rethrow_if_nested(e); // Unwrap the nested exception.
#endif
  } catch (RecognitionException &) {
    // A bare throw keeps the dynamic type; throwing the caught reference by
    // name would copy (and slice) it to RecognitionException.
    throw;
#if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026
  } catch (std::exception &) {
    // throw_with_nested is not available before VS 2015.
    throw;
#else
  } catch (std::exception & /*e*/) {
    std::throw_with_nested(RuntimeException("Cannot invoke start rule")); // Wrap any other exception.
#endif
  }

  // Make sure tree pattern compilation checks for a complete parse
  if (tokens.LA(1) != Token::EOF) {
    throw StartRuleDoesNotConsumeFullPattern();
  }

  return ParseTreePattern(this, pattern, patternRuleIndex, tree);
}

Lexer* ParseTreePatternMatcher::getLexer() {
  return _lexer;
}

Parser* ParseTreePatternMatcher::getParser() {
  return _parser;
}

// Recursively walks tree and patternTree in parallel.
// Returns nullptr when they match, otherwise the first node of tree that does
// not match. Nodes matched by a token tag or rule tag are appended to labels
// under the tag's name and, when present, its label.
// Throws IllegalArgumentException when either tree is nullptr.
ParseTree* ParseTreePatternMatcher::matchImpl(ParseTree *tree, ParseTree *patternTree,
                                              std::map<std::string, std::vector<ParseTree *>> &labels) {
  if (tree == nullptr) {
    throw IllegalArgumentException("tree cannot be null");
  }

  if (patternTree == nullptr) {
    throw IllegalArgumentException("patternTree cannot be null");
  }

  // x and <ID>, x and y, or x and x; or could be mismatched types
  if (is<TerminalNode *>(tree) && is<TerminalNode *>(patternTree)) {
    TerminalNode *t1 = dynamic_cast<TerminalNode *>(tree);
    TerminalNode *t2 = dynamic_cast<TerminalNode *>(patternTree);

    // tokens of different types never match
    if (t1->getSymbol()->getType() != t2->getSymbol()->getType()) {
      return t1;
    }

    if (is<TokenTagToken *>(t2->getSymbol())) { // x and <ID>
      TokenTagToken *tokenTagToken = dynamic_cast<TokenTagToken *>(t2->getSymbol());

      // track label->list-of-nodes for both token name and label (if any)
      labels[tokenTagToken->getTokenName()].push_back(tree);
      if (tokenTagToken->getLabel() != "") {
        labels[tokenTagToken->getLabel()].push_back(tree);
      }
      return nullptr;
    }

    if (t1->getText() == t2->getText()) {
      return nullptr; // x and x
    }
    return t1; // x and y
  }

  if (is<ParserRuleContext *>(tree) && is<ParserRuleContext *>(patternTree)) {
    ParserRuleContext *r1 = dynamic_cast<ParserRuleContext *>(tree);
    ParserRuleContext *r2 = dynamic_cast<ParserRuleContext *>(patternTree);

    // (expr ...) and <expr>
    RuleTagToken *ruleTagToken = getRuleTagToken(r2);
    if (ruleTagToken != nullptr) {
      if (r1->getRuleIndex() != r2->getRuleIndex()) {
        return r1;
      }

      // track label->list-of-nodes for both rule name and label (if any)
      labels[ruleTagToken->getRuleName()].push_back(tree);
      if (ruleTagToken->getLabel() != "") {
        labels[ruleTagToken->getLabel()].push_back(tree);
      }
      return nullptr;
    }

    // (expr ...) and (expr ...)
    if (r1->children.size() != r2->children.size()) {
      return r1;
    }

    const size_t n = r1->children.size();
    for (size_t i = 0; i < n; i++) {
      ParseTree *childMatch = matchImpl(r1->children[i], patternTree->children[i], labels);
      if (childMatch != nullptr) {
        return childMatch;
      }
    }

    return nullptr;
  }

  // if nodes aren't both tokens or both rule nodes, can't match
  return tree;
}

// Returns the RuleTagToken when t has exactly one child, that child is a
// terminal node, and its symbol is a RuleTagToken; otherwise nullptr.
RuleTagToken* ParseTreePatternMatcher::getRuleTagToken(ParseTree *t) {
  if (t->children.size() == 1 && is<TerminalNode *>(t->children[0])) {
    TerminalNode *c = dynamic_cast<TerminalNode *>(t->children[0]);
    if (is<RuleTagToken *>(c->getSymbol())) {
      return dynamic_cast<RuleTagToken *>(c->getSymbol());
    }
  }
  return nullptr;
}

// Converts the pattern text into a token list: tag chunks become TokenTagToken
// or RuleTagToken objects, text chunks are run through the lexer.
// Throws IllegalArgumentException for an unknown token or rule name and for a
// tag that starts with neither an upper-case nor a lower-case letter.
std::vector<std::unique_ptr<Token>> ParseTreePatternMatcher::tokenize(const std::string &pattern) {
  // split pattern into chunks: sea (raw input) and islands (<ID>, <expr>)
  std::vector<Chunk> chunks = split(pattern);

  // NOTE(review): the chunks are held by value as Chunk. If TagChunk/TextChunk
  // derive from Chunk, storing them this way slices them and the type test
  // below cannot observe the derived type -- confirm against the header. A fix
  // needs a change to split()'s return type, which is out of scope here.

  // create token stream from text and tags
  std::vector<std::unique_ptr<Token>> tokens;
  for (auto &chunk : chunks) {
    if (is<TagChunk *>(&chunk)) {
      TagChunk &tagChunk = static_cast<TagChunk &>(chunk);
      const std::string tag = tagChunk.getTag();

      // <cctype> classification is undefined for negative char values, so
      // classify the first character as unsigned char.
      const unsigned char first = static_cast<unsigned char>(tag[0]);

      // add special rule token or conjure up new token from name
      if (isupper(first)) {
        size_t ttype = _parser->getTokenType(tag);
        if (ttype == Token::INVALID_TYPE) {
          throw IllegalArgumentException("Unknown token " + tag + " in pattern: " + pattern);
        }
        tokens.emplace_back(new TokenTagToken(tag, static_cast<int>(ttype), tagChunk.getLabel()));
      } else if (islower(first)) {
        size_t ruleIndex = _parser->getRuleIndex(tag);
        if (ruleIndex == INVALID_INDEX) {
          throw IllegalArgumentException("Unknown rule " + tag + " in pattern: " + pattern);
        }
        size_t ruleImaginaryTokenType = _parser->getATNWithBypassAlts().ruleToTokenType[ruleIndex];
        tokens.emplace_back(new RuleTagToken(tag, ruleImaginaryTokenType, tagChunk.getLabel()));
      } else {
        throw IllegalArgumentException("invalid tag: " + tag + " in pattern: " + pattern);
      }
    } else {
      TextChunk &textChunk = static_cast<TextChunk &>(chunk);
      ANTLRInputStream input(textChunk.getText());
      _lexer->setInputStream(&input);
      std::unique_ptr<Token> t(_lexer->nextToken());
      while (t->getType() != Token::EOF) {
        tokens.push_back(std::move(t));
        t = _lexer->nextToken();
      }
      // input is a local; detach it from the lexer before it goes out of scope
      _lexer->setInputStream(nullptr);
    }
  }

  return tokens;
}

// Splits the pattern into alternating text and tag chunks using the current
// start/stop delimiters; delimiters preceded by the escape string are skipped.
// A tag of the form "label:name" yields a TagChunk with that label.
// Throws IllegalArgumentException for unbalanced or out-of-order delimiters.
std::vector<Chunk> ParseTreePatternMatcher::split(const std::string &pattern) {
  size_t p = 0;
  size_t n = pattern.length();
  std::vector<Chunk> chunks;

  // True when needle occurs in pattern exactly at offset pos. Anchored compare
  // avoids scanning the rest of the pattern at every position.
  const auto occursAt = [&pattern](size_t pos, const std::string &needle) {
    return pattern.compare(pos, needle.length(), needle) == 0;
  };
  const std::string escapedStart = _escape + _start;
  const std::string escapedStop = _escape + _stop;

  // find all start and stop indexes first, then collect
  std::vector<size_t> starts;
  std::vector<size_t> stops;
  while (p < n) {
    if (occursAt(p, escapedStart)) {
      p += escapedStart.length();
    } else if (occursAt(p, escapedStop)) {
      p += escapedStop.length();
    } else if (occursAt(p, _start)) {
      starts.push_back(p);
      p += _start.length();
    } else if (occursAt(p, _stop)) {
      stops.push_back(p);
      p += _stop.length();
    } else {
      p++;
    }
  }

  if (starts.size() > stops.size()) {
    throw IllegalArgumentException("unterminated tag in pattern: " + pattern);
  }

  if (starts.size() < stops.size()) {
    throw IllegalArgumentException("missing start tag in pattern: " + pattern);
  }

  size_t ntags = starts.size();
  for (size_t i = 0; i < ntags; i++) {
    if (starts[i] >= stops[i]) {
      throw IllegalArgumentException("tag delimiters out of order in pattern: " + pattern);
    }
  }

  // collect into chunks now
  if (ntags == 0) {
    std::string text = pattern.substr(0, n);
    chunks.push_back(TextChunk(text));
  }

  if (ntags > 0 && starts[0] > 0) { // copy text up to first tag into chunks
    std::string text = pattern.substr(0, starts[0]);
    chunks.push_back(TextChunk(text));
  }

  for (size_t i = 0; i < ntags; i++) {
    // copy inside of <tag>
    std::string tag = pattern.substr(starts[i] + _start.length(), stops[i] - (starts[i] + _start.length()));
    std::string ruleOrToken = tag;
    std::string label = "";
    size_t colon = tag.find(':');
    if (colon != std::string::npos) {
      label = tag.substr(0, colon);
      ruleOrToken = tag.substr(colon + 1, tag.length() - (colon + 1));
    }
    chunks.push_back(TagChunk(label, ruleOrToken));
    if (i + 1 < ntags) {
      // copy from end of <tag> to start of next
      std::string text = pattern.substr(stops[i] + _stop.length(), starts[i + 1] - (stops[i] + _stop.length()));
      chunks.push_back(TextChunk(text));
    }
  }

  if (ntags > 0) {
    size_t afterLastTag = stops[ntags - 1] + _stop.length();
    if (afterLastTag < n) { // copy text from end of last tag to end
      std::string text = pattern.substr(afterLastTag, n - afterLastTag);
      chunks.push_back(TextChunk(text));
    }
  }

  // strip out all backslashes from text chunks but not tags
  for (size_t i = 0; i < chunks.size(); i++) {
    Chunk &c = chunks[i];
    if (is<TextChunk *>(&c)) {
      TextChunk &tc = static_cast<TextChunk &>(c);
      std::string unescaped = tc.getText();
      unescaped.erase(std::remove(unescaped.begin(), unescaped.end(), '\\'), unescaped.end());
      if (unescaped.length() < tc.getText().length()) {
        chunks[i] = TextChunk(unescaped);
      }
    }
  }

  return chunks;
}

// Installs the default tag delimiters and escape string.
void ParseTreePatternMatcher::InitializeInstanceFields() {
  _start = "<";
  _stop = ">";
  _escape = "\\";
}