//===-- High Precision Decimal ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LIBC_SRC_SUPPORT_HIGH_PRECISION_DECIMAL_H
#define LIBC_SRC_SUPPORT_HIGH_PRECISION_DECIMAL_H

#include "src/__support/ctype_utils.h"
#include "src/__support/str_conv_utils.h"
#include <stdint.h>

namespace __llvm_libc {
namespace internal {

// One row of LEFT_SHIFT_DIGIT_TABLE, indexed by a left-shift amount i.
struct LShiftTableEntry {
  // Number of decimal digits a left shift by i bits adds, assuming the current
  // digits compare at least as large as powerOfFive (otherwise one fewer).
  uint32_t newDigits;
  // NUL-terminated decimal digits of 5^i ("" for i == 0).
  const char *powerOfFive;
};

// This is based on the HPD data structure described as part of the Simple
// Decimal Conversion algorithm by Nigel Tao, described at this link:
// https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html
//
// The value represented is 0.D * 10^decimalPoint, where D is the digit string
// digits[0..numDigits) (one decimal digit per byte). Trailing zero digits are
// always trimmed, and a value of zero has numDigits == 0 and decimalPoint == 0.
// `truncated` records that nonzero digits were dropped because they did not
// fit in the buffer, so the true value is slightly larger than represented.
class HighPrecsisionDecimal {

  // This precomputed table speeds up left shifts by having the number of new
  // digits that will be added by multiplying 5^i by 2^i. If the number is less
  // than 5^i then it will add one fewer digit. There are 61 entries (shift
  // amounts 0 through 60) since 60 is the max shift amount.
  // This table was generated by the script at
  // libc/utils/mathtools/GenerateHPDConstants.py
  static constexpr LShiftTableEntry LEFT_SHIFT_DIGIT_TABLE[] = {
      {0, ""},
      {1, "5"},
      {1, "25"},
      {1, "125"},
      {2, "625"},
      {2, "3125"},
      {2, "15625"},
      {3, "78125"},
      {3, "390625"},
      {3, "1953125"},
      {4, "9765625"},
      {4, "48828125"},
      {4, "244140625"},
      {4, "1220703125"},
      {5, "6103515625"},
      {5, "30517578125"},
      {5, "152587890625"},
      {6, "762939453125"},
      {6, "3814697265625"},
      {6, "19073486328125"},
      {7, "95367431640625"},
      {7, "476837158203125"},
      {7, "2384185791015625"},
      {7, "11920928955078125"},
      {8, "59604644775390625"},
      {8, "298023223876953125"},
      {8, "1490116119384765625"},
      {9, "7450580596923828125"},
      {9, "37252902984619140625"},
      {9, "186264514923095703125"},
      {10, "931322574615478515625"},
      {10, "4656612873077392578125"},
      {10, "23283064365386962890625"},
      {10, "116415321826934814453125"},
      {11, "582076609134674072265625"},
      {11, "2910383045673370361328125"},
      {11, "14551915228366851806640625"},
      {12, "72759576141834259033203125"},
      {12, "363797880709171295166015625"},
      {12, "1818989403545856475830078125"},
      {13, "9094947017729282379150390625"},
      {13, "45474735088646411895751953125"},
      {13, "227373675443232059478759765625"},
      {13, "1136868377216160297393798828125"},
      {14, "5684341886080801486968994140625"},
      {14, "28421709430404007434844970703125"},
      {14, "142108547152020037174224853515625"},
      {15, "710542735760100185871124267578125"},
      {15, "3552713678800500929355621337890625"},
      {15, "17763568394002504646778106689453125"},
      {16, "88817841970012523233890533447265625"},
      {16, "444089209850062616169452667236328125"},
      {16, "2220446049250313080847263336181640625"},
      {16, "11102230246251565404236316680908203125"},
      {17, "55511151231257827021181583404541015625"},
      {17, "277555756156289135105907917022705078125"},
      {17, "1387778780781445675529539585113525390625"},
      {18, "6938893903907228377647697925567626953125"},
      {18, "34694469519536141888238489627838134765625"},
      {18, "173472347597680709441192448139190673828125"},
      {19, "867361737988403547205962240695953369140625"},
  };

  // The maximum amount we can shift is the number of bits used in the
  // accumulator, minus the number of bits needed to represent the base (in this
  // case 4). For a 64 bit accumulator that is 60. sizeof yields bytes, hence
  // the multiplication by 8.
  static constexpr uint32_t MAX_SHIFT_AMOUNT = sizeof(uint64_t) * 8 - 4;

  // leftShift indexes LEFT_SHIFT_DIGIT_TABLE with amounts up to
  // MAX_SHIFT_AMOUNT, so the table must cover every one of them.
  static_assert(sizeof(LEFT_SHIFT_DIGIT_TABLE) /
                        sizeof(LEFT_SHIFT_DIGIT_TABLE[0]) ==
                    MAX_SHIFT_AMOUNT + 1,
                "LEFT_SHIFT_DIGIT_TABLE must cover 0..MAX_SHIFT_AMOUNT");

  // 800 is an arbitrary number of digits, but should be
  // large enough for any practical number.
  static constexpr uint32_t MAX_NUM_DIGITS = 800;

  uint32_t numDigits = 0;
  int32_t decimalPoint = 0;
  bool truncated = false;
  uint8_t digits[MAX_NUM_DIGITS];

private:
  // Decides whether cutting the number off just before digit index
  // roundToDigit should round up (round-half-to-even). A negative index refers
  // to an implicit leading zero, and an index at or past numDigits refers to an
  // implicit trailing zero. In both cases the rounding digit is 0, so the
  // answer is false.
  bool shouldRoundUp(int32_t roundToDigit) {
    if (roundToDigit < 0 ||
        static_cast<uint32_t>(roundToDigit) >= this->numDigits) {
      return false;
    }
    const uint32_t index = static_cast<uint32_t>(roundToDigit);

    // If we're right in the middle and there are no extra digits
    if (this->digits[index] == 5 && index + 1 == this->numDigits) {

      // Round up if we've truncated (since that means the result is slightly
      // higher than what's represented.)
      if (this->truncated) {
        return true;
      }

      // Exactly halfway with no stored digit before it (e.g. ".5"): the kept
      // part is 0, which is already even, so don't round up. This also avoids
      // reading digits[-1].
      if (index == 0) {
        return false;
      }

      // If this exactly halfway, round to even.
      return this->digits[index - 1] % 2 != 0;
    }
    // If there are digits after roundToDigit, they must be non-zero since we
    // trim trailing zeroes after all operations that change digits.
    return this->digits[index] >= 5;
  }

  // Takes an amount to left shift and returns the number of new digits needed
  // to store the result based on LEFT_SHIFT_DIGIT_TABLE. The table count is
  // reduced by one when the current digits compare lexicographically less than
  // 5^lShiftAmount (including when they are a strict prefix of it).
  uint32_t getNumNewDigits(uint32_t lShiftAmount) {
    const char *powerOfFive = LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].powerOfFive;
    uint32_t newDigits = LEFT_SHIFT_DIGIT_TABLE[lShiftAmount].newDigits;
    uint32_t digitIndex = 0;
    while (powerOfFive[digitIndex] != 0) {
      if (digitIndex >= this->numDigits) {
        return newDigits - 1;
      }
      if (this->digits[digitIndex] != powerOfFive[digitIndex] - '0') {
        return newDigits -
               ((this->digits[digitIndex] < powerOfFive[digitIndex] - '0') ? 1
                                                                           : 0);
      }
      ++digitIndex;
    }
    return newDigits;
  }

  // Trim all trailing 0s. A number with no digits left is zero, whose
  // canonical decimal point is 0.
  void trimTrailingZeroes() {
    while (this->numDigits > 0 && this->digits[this->numDigits - 1] == 0) {
      --this->numDigits;
    }
    if (this->numDigits == 0) {
      this->decimalPoint = 0;
    }
  }

  // Perform a digitwise binary non-rounding right shift on this value by
  // shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
  // overflow.
  void rightShift(uint32_t shiftAmount) {
    // Zero stays zero. Without this check the warm up loop below would never
    // see a nonzero digit and would spin forever.
    if (this->numDigits == 0) {
      return;
    }

    uint32_t readIndex = 0;
    uint32_t writeIndex = 0;

    uint64_t accumulator = 0;

    const uint64_t shiftMask = (uint64_t(1) << shiftAmount) - 1;

    // Warm Up phase: we don't have enough digits to start writing, so just
    // read them into the accumulator.
    while (accumulator >> shiftAmount == 0) {
      uint64_t readDigit = 0;
      // If there are still digits to read, read the next one, else the digit is
      // assumed to be 0.
      if (readIndex < this->numDigits) {
        readDigit = this->digits[readIndex];
      }
      accumulator = accumulator * 10 + readDigit;
      ++readIndex;
    }

    // Shift the decimal point by the number of digits it took to fill the
    // accumulator (minus the one that will be written back out).
    this->decimalPoint -= static_cast<int32_t>(readIndex) - 1;

    // Middle phase: we have enough digits to write, as well as more digits to
    // read. Keep reading until we run out of digits.
    while (readIndex < this->numDigits) {
      uint64_t readDigit = this->digits[readIndex];
      uint64_t writeDigit = accumulator >> shiftAmount;
      accumulator &= shiftMask;
      this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
      accumulator = accumulator * 10 + readDigit;
      ++readIndex;
      ++writeIndex;
    }

    // Cool Down phase: All of the readable digits have been read, so just write
    // the remainder, while treating any more digits as 0.
    while (accumulator > 0) {
      uint64_t writeDigit = accumulator >> shiftAmount;
      accumulator &= shiftMask;
      if (writeIndex < MAX_NUM_DIGITS) {
        this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
        ++writeIndex;
      } else if (writeDigit > 0) {
        this->truncated = true;
      }
      accumulator = accumulator * 10;
    }
    this->numDigits = writeIndex;
    this->trimTrailingZeroes();
  }

  // Perform a digitwise binary non-rounding left shift on this value by
  // shiftAmount. The shiftAmount can't be more than MAX_SHIFT_AMOUNT to prevent
  // overflow.
  void leftShift(uint32_t shiftAmount) {
    // Zero stays zero. Without this check getNumNewDigits could report a new
    // digit that is never written, leaving an uninitialized digits[0] counted
    // in numDigits.
    if (this->numDigits == 0) {
      return;
    }

    uint32_t newDigits = this->getNumNewDigits(shiftAmount);

    int32_t readIndex = static_cast<int32_t>(this->numDigits) - 1;
    uint32_t writeIndex = this->numDigits + newDigits;

    uint64_t accumulator = 0;

    // No Warm Up phase. Since we're putting digits in at the top and taking
    // digits from the bottom we don't have to wait for the accumulator to fill.

    // Middle phase: while we have more digits to read, keep reading as well as
    // writing.
    while (readIndex >= 0) {
      accumulator += static_cast<uint64_t>(this->digits[readIndex])
                     << shiftAmount;
      uint64_t nextAccumulator = accumulator / 10;
      uint64_t writeDigit = accumulator - (10 * nextAccumulator);
      --writeIndex;
      if (writeIndex < MAX_NUM_DIGITS) {
        this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
      } else if (writeDigit != 0) {
        this->truncated = true;
      }
      accumulator = nextAccumulator;
      --readIndex;
    }

    // Cool Down phase: there are no more digits to read, so just write the
    // remaining digits in the accumulator.
    while (accumulator > 0) {
      uint64_t nextAccumulator = accumulator / 10;
      uint64_t writeDigit = accumulator - (10 * nextAccumulator);
      --writeIndex;
      if (writeIndex < MAX_NUM_DIGITS) {
        this->digits[writeIndex] = static_cast<uint8_t>(writeDigit);
      } else if (writeDigit != 0) {
        this->truncated = true;
      }
      accumulator = nextAccumulator;
    }

    this->numDigits += newDigits;
    if (this->numDigits > MAX_NUM_DIGITS) {
      this->numDigits = MAX_NUM_DIGITS;
    }
    this->decimalPoint += static_cast<int32_t>(newDigits);
    this->trimTrailingZeroes();
  }

public:
  // numString is assumed to be a string of numeric characters. It doesn't
  // handle leading spaces. Parsing stops at the first character that is not a
  // digit or the first '.', followed by an optional exponent ('e'/'E', an
  // optional sign, and digits) whose magnitude is clamped to 100000.
  HighPrecsisionDecimal(const char *__restrict numString) {
    bool sawDot = false;
    // Counts every significant digit read, including ones dropped because the
    // buffer was full, so the decimal point is right even for inputs with more
    // than MAX_NUM_DIGITS integer digits.
    uint32_t totalDigits = 0;
    while (isdigit(*numString) || *numString == '.') {
      if (*numString == '.') {
        if (sawDot) {
          break;
        }
        this->decimalPoint = static_cast<int32_t>(totalDigits);
        sawDot = true;
      } else {
        if (*numString == '0' && this->numDigits == 0) {
          // Leading zeros are not stored. After the dot each one moves the
          // value one decimal place right. Any decrement made before the dot
          // is overwritten when the dot (or the end of input) is reached.
          --this->decimalPoint;
          ++numString;
          continue;
        }
        ++totalDigits;
        if (this->numDigits < MAX_NUM_DIGITS) {
          this->digits[this->numDigits] =
              static_cast<uint8_t>(*numString - '0');
          ++this->numDigits;
        } else if (*numString != '0') {
          this->truncated = true;
        }
      }
      ++numString;
    }

    if (!sawDot) {
      this->decimalPoint = static_cast<int32_t>(totalDigits);
    }

    if ((*numString | 32) == 'e') {
      ++numString;
      if (isdigit(*numString) || *numString == '+' || *numString == '-') {
        int32_t addToExp = strtointeger<int32_t>(numString, nullptr, 10);
        if (addToExp > 100000) {
          addToExp = 100000;
        } else if (addToExp < -100000) {
          addToExp = -100000;
        }
        this->decimalPoint += addToExp;
      }
    }

    this->trimTrailingZeroes();
  }

  // Binary shift left (shiftAmount > 0) or right (shiftAmount < 0). Large
  // shifts are split into steps of at most MAX_SHIFT_AMOUNT bits so the 64 bit
  // accumulator in leftShift/rightShift cannot overflow.
  void shift(int shiftAmount) {
    constexpr int maxStep = static_cast<int>(MAX_SHIFT_AMOUNT);
    if (shiftAmount > 0) {
      // Left
      while (shiftAmount > maxStep) {
        this->leftShift(MAX_SHIFT_AMOUNT);
        shiftAmount -= maxStep;
      }
      this->leftShift(static_cast<uint32_t>(shiftAmount));
    } else if (shiftAmount < 0) {
      // Right. The loop leaves -maxStep <= shiftAmount < 0, so negating it
      // below cannot overflow.
      while (shiftAmount < -maxStep) {
        this->rightShift(MAX_SHIFT_AMOUNT);
        shiftAmount += maxStep;
      }
      this->rightShift(static_cast<uint32_t>(-shiftAmount));
    }
  }

  // Round the number represented to the closest value of unsigned int type T
  // (ties to even). This is done ignoring overflow.
  template <class T> T roundToIntegerType() {
    T result = 0;
    uint32_t curDigit = 0;

    while (static_cast<int32_t>(curDigit) < this->decimalPoint &&
           curDigit < this->numDigits) {
      result = result * 10 + (this->digits[curDigit]);
      ++curDigit;
    }

    // If there are implicit 0s at the end of the number, include those.
    while (static_cast<int32_t>(curDigit) < this->decimalPoint) {
      result *= 10;
      ++curDigit;
    }
    if (this->shouldRoundUp(this->decimalPoint)) {
      ++result;
    }
    return result;
  }

  // Extra functions for testing.

  uint8_t *getDigits() { return this->digits; }
  uint32_t getNumDigits() { return this->numDigits; }
  int32_t getDecimalPoint() { return this->decimalPoint; }
  void setTruncated(bool trunc) { this->truncated = trunc; }
};

} // namespace internal
} // namespace __llvm_libc

#endif // LIBC_SRC_SUPPORT_HIGH_PRECISION_DECIMAL_H
