/** * Copyright (c) 2016-present, Facebook, Inc. * All rights reserved. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include #include #include "matrix.h" #include "real.h" #include "utils.h" #include "vector.h" namespace fasttext { class Loss; class Model { protected: std::shared_ptr wi_; std::shared_ptr wo_; std::shared_ptr loss_; bool normalizeGradient_; public: Model( std::shared_ptr wi, std::shared_ptr wo, std::shared_ptr loss, bool normalizeGradient); Model(const Model& model) = delete; Model(Model&& model) = delete; Model& operator=(const Model& other) = delete; Model& operator=(Model&& other) = delete; class State { private: real lossValue_; int64_t nexamples_; public: Vector hidden; Vector output; Vector grad; std::minstd_rand rng; State(int32_t hiddenSize, int32_t outputSize, int32_t seed); real getLoss() const; void incrementNExamples(real loss); }; void predict( const std::vector& input, int32_t k, real threshold, Predictions& heap, State& state) const; void update( const std::vector& input, const std::vector& targets, int32_t targetIndex, real lr, State& state); void computeHidden(const std::vector& input, State& state) const; real std_log(real) const; static const int32_t kUnlimitedPredictions = -1; static const int32_t kAllLabelsAsTarget = -1; }; } // namespace fasttext