/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <atomic>
#include <chrono>
#include <cstdint>
#include <istream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include "args.h"
#include "fasttext.h"

namespace fasttext {

/**
 * Search strategy used by autotune to propose argument sets.
 *
 * Built from an initial Args and an RNG seed. Member functions are defined
 * outside this header. NOTE(review): the semantics sketched below are
 * inferred from names; confirm against the implementation file.
 */
class AutotuneStrategy {
 public:
  explicit AutotuneStrategy(
      const Args& args,
      std::minstd_rand::result_type seed);

  // Presumably returns the next candidate Args to try, given `elapsed`
  // (units to be confirmed in the implementation).
  Args ask(double elapsed);

  // Presumably records `args` as the best configuration seen so far.
  void updateBest(const Args& args);

 private:
  // Presumably locates `val` inside `choices`; the not-found behavior is
  // defined in the implementation file.
  int getIndex(int val, const std::vector<int>& choices);

  // Data members: declaration order is significant (it fixes the order in
  // which the constructor initializes them), so keep it unchanged.
  Args bestArgs_;
  int maxDuration_;
  std::minstd_rand rng_;
  int trials_;
  int bestMinnIndex_;
  int bestDsubExponent_;
  int bestNonzeroBucket_;
  int originalBucket_;
  std::vector<int> minnChoices_;
};

/**
 * Automatic hyper-parameter tuning for a FastText model.
 *
 * Holds the model being tuned, an AutotuneStrategy, best-score and trial
 * bookkeeping, and a background std::thread (timer_). Default construction,
 * copying and moving are all deleted; an instance is built only from a
 * FastText pointer.
 *
 * NOTE(review): only continueTraining_ is atomic. If timer_ reads or writes
 * elapsed_, trials_ or bestScore_ concurrently with train(), those accesses
 * are unsynchronized. Confirm against the implementation file.
 */
class Autotune {
 protected:
  std::shared_ptr<FastText> fastText_;
  double elapsed_;
  double bestScore_;
  int32_t trials_;
  int32_t sizeConstraintFailed_;
  // Shared stop flag; atomic because more than one thread may touch it.
  std::atomic<bool> continueTraining_;
  std::unique_ptr<AutotuneStrategy> strategy_;
  // Background thread. Must not be joinable when destroyed; see ~Autotune().
  std::thread timer_;

  bool keepTraining(double maxDuration) const;
  void printInfo(double maxDuration);
  void timer(
      const std::chrono::steady_clock::time_point& start,
      double maxDuration);
  void abort();
  void startTimer(const Args& args);
  double getMetricScore(
      Meter& meter,
      const metric_name& metricName,
      const double metricValue,
      const std::string& metricLabel) const;
  void printArgs(const Args& args, const Args& autotuneArgs);
  void printSkippedArgs(const Args& autotuneArgs);
  bool quantize(Args& args, const Args& autotuneArgs);
  int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
      const;

  // Exception carrying the message "Autotune timed out.".
  class TimeoutError : public std::runtime_error {
   public:
    TimeoutError() : std::runtime_error("Autotune timed out.") {}
  };

 public:
  Autotune() = delete;
  explicit Autotune(const std::shared_ptr<FastText>& fastText);
  Autotune(const Autotune&) = delete;
  Autotune(Autotune&&) = delete;
  Autotune& operator=(const Autotune&) = delete;
  Autotune& operator=(Autotune&&) = delete;

  /**
   * Stops and joins the background timer thread if it is still attached.
   *
   * Destroying a joinable std::thread calls std::terminate(). A defaulted
   * destructor would therefore abort the process whenever this object dies
   * with timer_ still running, for example if train() exits via an
   * exception. continueTraining_ is cleared first so the timer loop can
   * finish. NOTE(review): this assumes timer() polls continueTraining_
   * (e.g. via keepTraining()); otherwise join() waits out the remaining
   * duration. Confirm in the implementation file.
   */
  ~Autotune() noexcept {
    continueTraining_ = false;
    if (timer_.joinable()) {
      timer_.join();
    }
  }

  void train(const Args& args);
};

} // namespace fasttext