// Copyright 2021 gRPC authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef GRPC_SRC_CORE_LIB_AVL_AVL_H #define GRPC_SRC_CORE_LIB_AVL_AVL_H #include #include // IWYU pragma: keep #include #include #include #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" #include "src/core/util/useful.h" namespace grpc_core { template class AVL { public: AVL() {} AVL Add(K key, V value) const { return AVL(AddKey(root_, std::move(key), std::move(value))); } template AVL Remove(const SomethingLikeK& key) const { return AVL(RemoveKey(root_, key)); } template const V* Lookup(const SomethingLikeK& key) const { NodePtr n = Get(root_, key); return n != nullptr ? &n->kv.second : nullptr; } const std::pair* LookupBelow(const K& key) const { NodePtr n = GetBelow(root_, *key); return n != nullptr ? &n->kv : nullptr; } bool Empty() const { return root_ == nullptr; } template void ForEach(F&& f) const { ForEachImpl(root_.get(), std::forward(f)); } bool SameIdentity(const AVL& avl) const { return root_ == avl.root_; } friend int QsortCompare(const AVL& left, const AVL& right) { if (left.root_.get() == right.root_.get()) return 0; Iterator a(left.root_); Iterator b(right.root_); for (;;) { Node* p = a.current(); Node* q = b.current(); if (p != q) { if (p == nullptr) return -1; if (q == nullptr) return 1; const int kv = QsortCompare(p->kv, q->kv); if (kv != 0) return kv; } else if (p == nullptr) { return 0; } a.MoveNext(); b.MoveNext(); } } bool operator==(const AVL& other) const { return QsortCompare(*this, other) == 0; } bool operator<(const AVL& other) const { return QsortCompare(*this, other) < 0; } size_t Height() const { if (root_ == nullptr) return 0; return root_->height; } private: struct Node; typedef RefCountedPtr NodePtr; struct Node : public RefCounted { Node(K k, V v, NodePtr l, NodePtr r, long h) : kv(std::move(k), std::move(v)), left(std::move(l)), right(std::move(r)), height(h) {} const std::pair kv; const NodePtr left; const NodePtr right; const long height; }; NodePtr root_; class IteratorStack { public: void Push(Node* n) { nodes_[depth_] = n; ++depth_; } Node* Pop() { --depth_; return nodes_[depth_]; } Node* Back() const { return nodes_[depth_ - 1]; } bool Empty() const { return depth_ == 0; } private: size_t depth_{0}; // 32 is the maximum depth we can accept, and corresponds to ~4billion nodes // - which ought to suffice our use cases. Node* nodes_[32]; }; class Iterator { public: explicit Iterator(const NodePtr& root) { auto* n = root.get(); while (n != nullptr) { stack_.Push(n); n = n->left.get(); } } Node* current() const { return stack_.Empty() ? nullptr : stack_.Back(); } void MoveNext() { auto* n = stack_.Pop(); if (n->right != nullptr) { n = n->right.get(); while (n != nullptr) { stack_.Push(n); n = n->left.get(); } } } private: IteratorStack stack_; }; explicit AVL(NodePtr root) : root_(std::move(root)) {} template static void ForEachImpl(const Node* n, F&& f) { if (n == nullptr) return; ForEachImpl(n->left.get(), std::forward(f)); f(const_cast(n->kv.first), const_cast(n->kv.second)); ForEachImpl(n->right.get(), std::forward(f)); } static long Height(const NodePtr& n) { return n != nullptr ? n->height : 0; } static NodePtr MakeNode(K key, V value, const NodePtr& left, const NodePtr& right) { return MakeRefCounted(std::move(key), std::move(value), left, right, 1 + std::max(Height(left), Height(right))); } template static NodePtr Get(const NodePtr& node, const SomethingLikeK& key) { if (node == nullptr) { return nullptr; } if (node->kv.first > key) { return Get(node->left, key); } else if (node->kv.first < key) { return Get(node->right, key); } else { return node; } } static NodePtr GetBelow(const NodePtr& node, const K& key) { if (!node) return nullptr; if (node->kv.first > key) { return GetBelow(node->left, key); } else if (node->kv.first < key) { NodePtr n = GetBelow(node->right, key); if (n == nullptr) n = node; return n; } else { return node; } } static NodePtr RotateLeft(K key, V value, const NodePtr& left, const NodePtr& right) { return MakeNode( right->kv.first, right->kv.second, MakeNode(std::move(key), std::move(value), left, right->left), right->right); } static NodePtr RotateRight(K key, V value, const NodePtr& left, const NodePtr& right) { return MakeNode( left->kv.first, left->kv.second, left->left, MakeNode(std::move(key), std::move(value), left->right, right)); } static NodePtr RotateLeftRight(K key, V value, const NodePtr& left, const NodePtr& right) { // rotate_right(..., rotate_left(left), right) return MakeNode( left->right->kv.first, left->right->kv.second, MakeNode(left->kv.first, left->kv.second, left->left, left->right->left), MakeNode(std::move(key), std::move(value), left->right->right, right)); } static NodePtr RotateRightLeft(K key, V value, const NodePtr& left, const NodePtr& right) { // rotate_left(..., left, rotate_right(right)) return MakeNode( right->left->kv.first, right->left->kv.second, MakeNode(std::move(key), std::move(value), left, right->left->left), MakeNode(right->kv.first, right->kv.second, right->left->right, right->right)); } static NodePtr Rebalance(K key, V value, const NodePtr& left, const NodePtr& right) { switch (Height(left) - Height(right)) { case 2: if (Height(left->left) - Height(left->right) == -1) { return RotateLeftRight(std::move(key), std::move(value), left, right); } else { return RotateRight(std::move(key), std::move(value), left, right); } case -2: if (Height(right->left) - Height(right->right) == 1) { return RotateRightLeft(std::move(key), std::move(value), left, right); } else { return RotateLeft(std::move(key), std::move(value), left, right); } default: return MakeNode(key, value, left, right); } } static NodePtr AddKey(const NodePtr& node, K key, V value) { if (node == nullptr) { return MakeNode(std::move(key), std::move(value), nullptr, nullptr); } if (node->kv.first < key) { return Rebalance(node->kv.first, node->kv.second, node->left, AddKey(node->right, std::move(key), std::move(value))); } if (key < node->kv.first) { return Rebalance(node->kv.first, node->kv.second, AddKey(node->left, std::move(key), std::move(value)), node->right); } return MakeNode(std::move(key), std::move(value), node->left, node->right); } static NodePtr InOrderHead(NodePtr node) { while (node->left != nullptr) { node = node->left; } return node; } static NodePtr InOrderTail(NodePtr node) { while (node->right != nullptr) { node = node->right; } return node; } template static NodePtr RemoveKey(const NodePtr& node, const SomethingLikeK& key) { if (node == nullptr) { return nullptr; } if (key < node->kv.first) { return Rebalance(node->kv.first, node->kv.second, RemoveKey(node->left, key), node->right); } else if (node->kv.first < key) { return Rebalance(node->kv.first, node->kv.second, node->left, RemoveKey(node->right, key)); } else { if (node->left == nullptr) { return node->right; } else if (node->right == nullptr) { return node->left; } else if (node->left->height < node->right->height) { NodePtr h = InOrderHead(node->right); return Rebalance(h->kv.first, h->kv.second, node->left, RemoveKey(node->right, h->kv.first)); } else { NodePtr h = InOrderTail(node->left); return Rebalance(h->kv.first, h->kv.second, RemoveKey(node->left, h->kv.first), node->right); } } abort(); } }; } // namespace grpc_core #endif // GRPC_SRC_CORE_LIB_AVL_AVL_H