
#pragma once

#include "cherrypi.h"
#include "utils.h"
#include "common/rand.h"

#include <fmt/printf.h>

#include <cctype>
#include <cstring>

namespace cherrypi {


/// Multi-armed bandit over named string choices whose statistics persist
/// across games against the same opponent.
///
/// Every choice name has a weight in `weights`:
/// - `read()` replays the opponent's history file. It folds each recorded game
///   result into the weights of the names chosen in that game.
/// - `choose()` samples among candidate names, biased toward higher weights.
/// - `write()` appends this game's result to `tokenlist` and saves the whole
///   token history.
///
/// NOTE(review): the keys of `weights` are string_views into the strings owned
/// by `stringContainer` (a std::list, so element addresses are stable). A
/// copied Bandit would keep views into the *source* object's list, so avoid
/// copying instances.
struct Bandit {

  /// Weight per choice name. Keys view strings owned by `stringContainer`.
  std::unordered_map<std::string_view, float> weights;
  /// Owns the character storage for every key of `weights`.
  std::list<std::string> stringContainer;

  /// Returns a mutable reference to the weight of `n`.
  /// The first time `n` is seen, a weight of 0 is inserted and a log line is
  /// emitted. The reference stays valid across later insertions, because
  /// unordered_map never relocates its elements.
  float& weight(std::string_view n) {
    auto i = weights.find(n);
    if (i != weights.end()) return i->second;
    VLOG(0) << " weight " << n << " was not found";
    stringContainer.emplace_back(n);
    return weights[stringContainer.back()];
  }

  /// Appends the two lowercase hex digits of byte `ch` to `out`.
  static void appendHexByte(std::string& out, char ch) {
    constexpr char kDigits[] = "0123456789abcdef";
    const auto byte = static_cast<unsigned char>(ch);
    out += kDigits[byte >> 4];
    out += kDigits[byte & 0xf];
  }

  /// Value of hex digit `ch` (either case), or -1 if `ch` is not a hex digit.
  static int hexValue(char ch) {
    if (ch >= '0' && ch <= '9') return ch - '0';
    if (ch >= 'a' && ch <= 'f') return ch - 'a' + 10;
    if (ch >= 'A' && ch <= 'F') return ch - 'A' + 10;
    return -1;
  }

  /// std::isspace for a plain char. The cast avoids UB for negative chars.
  static bool isSpace(char ch) {
    return std::isspace(static_cast<unsigned char>(ch)) != 0;
  }

  /// Makes `str` safe to use as a file name.
  /// ASCII letters, digits, '-' and '.' are kept as they are. Every other byte
  /// (including '_') becomes '_' followed by its two lowercase hex digits, so
  /// distinct inputs map to distinct names.
  std::string escapeFilename(std::string_view str) {
    std::string r;
    r.reserve(str.size());
    for (char ch : str) {
      const bool keep = ch == '-' || ch == '.' || (ch >= 'A' && ch <= 'Z') ||
          (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9');
      if (keep) {
        r += ch;
      } else {
        r += '_';
        appendHexByte(r, ch);
      }
    }
    return r;
  }

  /// Serializes one history token for `write()`.
  ///
  /// A token made only of printable ASCII other than '"' and '\\', and
  /// containing no space, is returned unchanged. Any other token is wrapped in
  /// double quotes. Inside the quotes, spaces stay literal, while '"', '\\' and
  /// every byte outside printable ASCII are written as "\xHH" (exactly two hex
  /// digits), which is what `read()` decodes.
  std::string escape(std::string_view str) {
    bool quote = false;
    std::string r = "\"";
    for (char ch : str) {
      const bool plain = ch >= 32 && ch < 127 && ch != '"' && ch != '\\';
      if (plain) {
        if (ch == ' ') quote = true;
        r += ch;
      } else {
        quote = true;
        r += "\\x";
        appendHexByte(r, ch);
      }
    }
    if (!quote) return std::string(str);
    r += '"';
    return r;
  }

  /// Every token read from the history or chosen during this run, in order.
  /// `write()` saves all of them.
  std::vector<std::string> tokenlist;

  /// Path of the per-opponent history file inside directory `dir`
  /// (`dir` must end with a '/').
  std::string historyFilename(std::string_view dir, const std::string& enemyName) {
    std::string path(dir);
    path += "dragon-vs-";
    path += escapeFilename(enemyName + ".txt");
    return path;
  }

  /// Appends "won"/"lost" to `tokenlist` and writes the whole list to
  /// bwapi-data/write/dragon-vs-<escaped enemy name>.txt.
  /// Tokens are separated by spaces; each "won"/"lost" ends a line.
  /// Best effort: if the file cannot be opened, nothing is recorded.
  void write(const State* state, bool won) {
    const std::string filename = historyFilename(
        "bwapi-data/write/",
        state->board()->get<std::string>(Blackboard::kEnemyNameKey, "unknown"));
    FILE* f = fopen(filename.c_str(), "wb");
    if (!f) return;
    tokenlist.push_back(won ? "won" : "lost");
    for (auto& v : tokenlist) {
      const std::string s = escape(v);
      fwrite(s.data(), s.size(), 1, f);
      const bool endsGame = v == "won" || v == "lost";
      fputc(endsGame ? '\n' : ' ', f);
    }
    fclose(f);
  }

  /// Replays bwapi-data/read/dragon-vs-<escaped enemy name>.txt, if it exists.
  /// Each token is appended to `tokenlist`. Tokens seen since the previous
  /// result are rewarded when "won" (+1) or "lost" (-1) appears:
  /// weight = 0.75 * weight + 0.25 * reward.
  /// Trailing tokens with no result after them are kept in `tokenlist` but not
  /// rewarded.
  void read(const State* state) {
    const std::string filename = historyFilename(
        "bwapi-data/read/",
        state->board()->get<std::string>(Blackboard::kEnemyNameKey, "unknown"));
    FILE* f = fopen(filename.c_str(), "rb");
    if (!f) return;
    std::string contents;
    char buf[0x1000];
    size_t got = 0;
    while ((got = fread(buf, 1, sizeof(buf), f)) > 0) {
      contents.append(buf, got);
    }
    fclose(f);

    const char* p = contents.data();
    const char* const end = p + contents.size();
    // Reads the next token into `out`.
    // Returns false once only whitespace remains, so trailing newlines never
    // produce a spurious empty token.
    auto nextToken = [&](std::string& out) {
      out.clear();
      while (p != end && isSpace(*p)) ++p;
      if (p == end) return false;
      if (*p != '"') {
        const char* start = p;
        while (p != end && !isSpace(*p)) ++p;
        out.assign(start, p);
        return true;
      }
      ++p;  // opening quote
      while (p != end && *p != '"') {
        if (*p != '\\') {
          out += *p++;
          continue;
        }
        ++p;  // drop the backslash
        // Decode exactly two hex digits. strtoul would greedily consume any
        // hex characters that follow the escape.
        if (end - p >= 3 && *p == 'x') {
          const int hi = hexValue(p[1]);
          const int lo = hexValue(p[2]);
          if (hi >= 0 && lo >= 0) {
            out += static_cast<char>((hi << 4) | lo);
            p += 3;
          }
        }
        // Any other (malformed) escape: the backslash is dropped and the
        // following characters are taken literally.
      }
      if (p != end) ++p;  // closing quote
      return true;
    };

    std::vector<std::string> q;  // names chosen in the game being replayed
    auto reward = [&](float val) {
      for (auto& v : q) {
        float& w = weight(v);
        w = w * 0.75f + val * 0.25f;
        VLOG(0) << fmt::sprintf("weight for '%s' is now %g\n", v.c_str(), w);
      }
      q.clear();
    };

    std::string n;
    while (nextToken(n)) {
      if (n.empty()) continue;  // only from a stray `""`; escape() never writes one
      tokenlist.push_back(n);
      VLOG(0) << " got token '" << n << "'";
      if (n == "won") reward(1.0f);
      else if (n == "lost") reward(-1.0f);
      else q.push_back(n);
    }
  }

  /// Convenience overload: `choose("a", "b", ...)`.
  template<typename... T>
  std::string_view choose(T... choices) {
    return choose({choices...});
  }

  /// Picks one of `choices` and records it in `tokenlist`.
  ///
  /// For each candidate with weight w, draws u ~ Uniform(0, exp(6 * w)) and
  /// keeps the candidate with the largest draw. Higher weights are favoured,
  /// but every candidate keeps a non-zero chance.
  /// The returned view points into this Bandit's own storage, so it stays
  /// valid even when the caller's strings do not. An empty `choices` returns
  /// an empty view and records nothing.
  std::string_view choose(std::initializer_list<std::string_view> choices) {
    if (choices.size() == 0) return {};
    std::string_view best;
    float bestValue = -kfInfty;
    for (std::string_view v : choices) {
      const float& w = weight(v);
      const float sample = common::Rand::sample(
          std::uniform_real_distribution<float>(0.0f, std::exp(w * 6.0f)));
      VLOG(0) << " value for '" << v << "' is " << sample << " (weight " << w << ")";
      if (sample > bestValue) {
        best = v;
        bestValue = sample;
      }
    }
    tokenlist.emplace_back(best);
    // weight() guaranteed an entry exists; its key is owned by stringContainer.
    auto it = weights.find(best);
    return it != weights.end() ? it->first : best;
  }

};

}
