/*
 * Scan matrices for closed invertible additive groups.
 *
 */

#ifndef ALPHABET_H
#define ALPHABET_H

#include "matrix.h"
#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <vector>
#include <string>
#include <getopt.h>
#include <x86intrin.h>
#include <algorithm>

// =============================================================================
// Alphabet class
// =============================================================================

class Alphabet {
private:
    int m;                                                                      // Size (rows = columns) of every matrix
    int N;                                                                      // Capacity of the tuple, 2^m slots
    TupleR t;                                                                   // Tuple of matrices; only t[0..n-1] are valid
    int n;                                                                      // Number of valid elements
    inline bool findOrInsert(const MatrixR &X);                                 // Accept X if present, or append it if its row-0 slot is free; false on a slot clash
public:
    Alphabet(int m, TupleR seed = {});                                          // Initialize to a given tuple, or return a stub (0 and I only) if seed is inconsistent
    static Alphabet getRandom(int m, uint64_t maxAttempts = 0xFFFFFFFF);        // Construct a random alphabet, or give up after maxAttempts attempts
    bool complete(uint64_t maxAttempts = 0xFFFFFFFF);                           // Try to complete the alphabet by randomly searching for a primitive, try no more than maxAttempts times
    Alphabet upsample(bool andComplete = true);                                 // Generate a 2m alphabet that nests this, and complete it unless instructed not to
    void sortNest(int q0 = 1);                                                  // Sort alphabet to make it ready for nesting, q0 is atom bit resolution
    Alphabet sortAlpha();                                                       // Sort as powers of a primitive and return a new alphabet, leave this intact
    void shuffle(int i0 = 2);                                                   // Shuffle the elements from i0 on, default 2 leaves 0 and I in place
    int getm() const { return m; }                                              // Matrix size
    int getn() const { return n; }                                              // Number of valid elements
    inline bool operator+=(const MatrixR &cnddt);                               // Try to add cnddt to tuple, and return false if it fails
    inline bool operator==(const Alphabet &other) const;                        // Equality test (same set of symbols)
    operator bool() const { return n == N; }                                    // Validity test: true once every slot is filled
    operator TupleR() const {                                                   // Copy of the valid elements, in tuple order
        return TupleR(std::begin(t), std::begin(t) + n);
    }
    TupleR primitives() const;                                                  // Return sorted list of primitives
    TupleR generators(int mOut = 64);                                           // Construct generator matrices from alphabet
    MatrixR getCap(int i, int mOut = 64);                                       // Returns an IBlock of the triangular cap of the ith generator
};

// -----------------------------------------------------------------------------
// Create an alphabet framework and insert invariant elements
// -----------------------------------------------------------------------------

// Builds an m x m alphabet holding 0 (t[0]) and I (t[1]), then adds every seed
// element through operator+= (which also inserts whatever closure requires).
// If any seed element is rejected, the result is the stub {0, I} (n == 2).
// When the seed starts with 0 and I, its order is restored in the leading slots.
inline Alphabet::Alphabet(int m, TupleR seed): m(m), N(1 << m), n(2) {          // Initialize to 0 and I, and mark their slots
    t.resize(N);
    t[0] = MatrixR(m);
    t[1] = MatrixR::I(m);
    bool allValid = true;                                                       // Validity check for seed
    for (size_t i = 0; i < seed.size() && allValid; i++) {                      // Stop at the first rejected element
        allValid = (*this += seed[i]);
    }
    if (!allValid) { n = 2; return; }                                           // Return stub.
    if (seed.size() >= 2 && seed[0] == t[0] && seed[1] == t[1]) {               // Restore the order of seed (size check guards seed[1])
        // Every accepted seed element is among t[0..n-1], so a seed longer than n
        // must contain duplicates; never index beyond the valid elements.
        const size_t nOrdered = std::min(seed.size(), static_cast<size_t>(n));
        for (size_t i = 2; i < nOrdered; i++) {
            for (int j = 0; j < n && t[i] != seed[i]; j++) {
                if (t[j] == seed[i]) {
                    std::swap(t[i], t[j]);
                }
            }
        }
    }
}

// -----------------------------------------------------------------------------
// Construct a random alphabet, trying at most maxAttempts times
// -----------------------------------------------------------------------------

// Repeatedly seeds a fresh m x m alphabet with one random invertible matrix and
// returns the first one whose closure fills all 2^m slots.
// Prints an error and terminates the process (exit(1)) after maxAttempts misses.
inline Alphabet Alphabet::getRandom(int m, uint64_t maxAttempts) {
    for (uint64_t attempt = 0; attempt < maxAttempts; attempt++) {
        Alphabet alphabet{m, {MatrixR::createInvertible(m)}};                   // Closure of {0, I, candidate}
        if (alphabet) {                                                         // Complete, so the candidate generated everything
            //fprintf(stderr, "Hit a primitive in %llu attempts\n", static_cast<unsigned long long>(attempt));  // Comment out as needed
            return alphabet;
        }
    }
    fprintf(stderr, "Failed to create the required alphabet; Aborting!\n");
    exit(1);
}

// -----------------------------------------------------------------------------
// Try to complete by randomly searching for a primitive
// -----------------------------------------------------------------------------

bool Alphabet::complete(uint64_t maxAttempts) {
    for (uint64_t attempt = 0; attempt < maxAttempts; attempt++) {
        Alphabet alphabet = *this;
        alphabet += MatrixR::createInvertible(m);
        if (alphabet) {
            //fprintf(stderr, "Hit a primitive in %u attempts\n", attempt);       // Comment out as needed
            *this = alphabet;
            return true;
        }
    }
    fprintf(stderr, "Alphabet completion failed\n");                            // Comment out as needed
    return false;
}


// -----------------------------------------------------------------------------
// Try to insert a candidate if it passes initial screening of an empty slot
// Please note that any row may be used for indexing the slots
// -----------------------------------------------------------------------------

bool Alphabet::findOrInsert(const MatrixR &X) {
    if (X.getm() != m) { return false; }                                        // Unequally dimension matrix: no way!
    for (int i = 0; i < n; i++) {
        if (X == t[i]) return true;                                             // Found, so insertion is acceptable
        if (X.getRow(0) == t[i].getRow(0)) return false;                        // Not found but slot taken by a different symbol
    }
    // -------------------------------------------------------------------------
    // Not found and slot empty:
    // -------------------------------------------------------------------------
    t[n++] = X;
    return true;
}

// -----------------------------------------------------------------------------
// Try to add candidate to tuple, and return false if it fails
// -----------------------------------------------------------------------------

// Adds cnddt and everything its closure requires (inverses, sums, products).
// Returns true if the enlarged set stays consistent (distinct row-0 slots,
// every processed element invertible); on failure n is restored and the
// valid part of the alphabet is unchanged.
inline bool Alphabet::operator+=(const MatrixR &cnddt) {
    if (cnddt.getm() != m) { return false; }                                    // Unequally dimensioned matrix: no way!
    if (n == N) {                                                               // Full occupancy, then cnddt must already be there to be accepted
        for (int i = 0; i < n; i++) {
            if (cnddt == t[i]) { return true; }                                 // Consider addition successful if the element is already there
        }
        return false;                                                           // Otherwise, fail
    }
    int nFallBack(n);                                                           // Save number of elements for fall back
    bool pass = findOrInsert(cnddt);                                            // Place on top of list if not occupied
    for (int i = nFallBack; (i < n) && pass; i++) {                             // Iterate through newly added matrices to validate them
        MatrixR X = t[i];                                                       // Retrieve a matrix
        // The square used to be re-tested for every j (as "t[i] * X", i.e. X*X);
        // testing it once here inserts elements in the same order, since at
        // j == 0 the sum X + t[0] and product X * t[0] are X and 0 (t[0] is 0).
        pass = (
            X.isInvertible()                                                    // Minimum requirement after screening, more tests apply to derivatives
            && findOrInsert(X.inverse())                                        // First derivative is its inverse
            && findOrInsert(X * X)                                              // Square
        );
        // Left products t[j] * X need no test of their own: with binary matrices
        // t[j]*X = (X+t[j])^2 + X^2 + t[j]^2 + X*t[j], and every term on the
        // right is a sum, square or right product that is tested.
        for (int j = 0; (j < i) && pass; j++) {                                 // Iterate through all preceding entries
            pass = (
                findOrInsert(X + t[j])                                          // Sum to each preceding element
                && findOrInsert(X * t[j])                                       // Right multiplication by each preceding element
            );
        }
    }
    if (!pass) {                                                                // If it fails at any step
        n = nFallBack;                                                          // Just ignore the inserted elements and return counter to initial set
    }
    return pass;
}

// -----------------------------------------------------------------------------
// Return sorted list of primitives
// -----------------------------------------------------------------------------

// Returns the elements of multiplicative order N-1 (powers reach t[1] = I only
// after N-1 steps), sorted. Empty for incomplete alphabets.
inline TupleR Alphabet::primitives() const {
    if (n != N) return {};                                                      // We only deal with complete alphabets
    if (N == 2) return {t[1]};                                                  // Identity is the primitive for alphabet<1>
    TupleR result;
    for (int i = 2; i < n; i++) {
        int order(1);                                                           // Multiplicative order of t[i]
        // The order < N bound stops the walk if the powers never return to t[1]
        // (no valid element has order N or more), instead of looping forever.
        for (MatrixR Power = t[i]; Power != t[1] && order < N; Power = t[i] * Power) {
            order++;
        }
        if (order == N - 1) {
            result.push_back(t[i]);
        }
    }
    std::sort(result.begin(), result.end());
    return result;
}

// -----------------------------------------------------------------------------
// Equality test
// -----------------------------------------------------------------------------

// Two alphabets are equal when they hold the same set of symbols (order ignored).
// Complete alphabets compare by their smallest primitive, whose powers span the
// whole set. Incomplete alphabets have no primitives, so their valid elements
// are compared directly (indexing primitives()[0] there was undefined behaviour).
inline bool Alphabet::operator==(const Alphabet &other) const {
    if (m != other.m || n != other.n) {
        return false;
    }
    if (n == N) {
        return primitives()[0] == other.primitives()[0];
    }
    for (int i = 0; i < n; i++) {                                               // Same size, so inclusion means equality
        bool found = false;
        for (int j = 0; j < other.n && !found; j++) {
            found = (t[i] == other.t[j]);
        }
        if (!found) {
            return false;
        }
    }
    return true;
}

// -----------------------------------------------------------------------------
// Generate a 2m alphabet that nests this
// -----------------------------------------------------------------------------

// Builds a 2m x 2m alphabet whose first N slots hold block-diag(a^2, a^2) for
// every symbol a of this alphabet. Unless andComplete is false, random
// candidates with t[i0] | t[i1] in the top and t[i2] | 0 in the bottom quadrants
// are tried until one completes it. The result is then sortNest(m)-ordered.
// Terminates the process (exit(1)) if this alphabet is incomplete or every attempt fails.
inline Alphabet Alphabet::upsample(bool andComplete) {
    if (n != N) {                                                               // Slots t[n..N-1] would be stale garbage
        fprintf(stderr, "Upsampling %d-to-%d needs a complete alphabet\n", m, 2 * m);
        exit(1);
    }
    Alphabet base(2 * m);                                                       // Stub holding 0 and I
    base.n = N;                                                                 // First N slots take the nested symbols
    for (int i = 2; i < N; i++) {
        MatrixR aSq = t[i] * t[i];
        base.t[i] = MatrixR(2 * m).putBlock(0, 0, aSq).putBlock(m, m, aSq);     // Nested symbol in top-left and bottom-right
    }
    if (!andComplete) { return base; }                                          // That's it
    int mask = msk(m);
    for (int attempt = 0; attempt < 1000000; attempt++) {                       // Supposedly a million attempts would be enough
        uint64_t r = rnd();
        uint64_t i0(r & mask), i1((r >> m) & mask), i2((r >> (2 * m)) & mask);  // Three random indices over the nested alphabet
        if (!i0 || !i1 || !i2) continue;                                        // A 0 in one of the three quadrants won't work
        MatrixR cnddt = (                                                       // Create a candidate primitive
            MatrixR(2 * m)
            .putBlock(0, 0, t[i0])                                              // Top-left quadrant
            .putBlock(0, m, t[i1])                                              // Top-right quadrant
            .putBlock(m, 0, t[i2])                                              // Bottom-left quadrant
        );
        base += cnddt;
        if (base) {                                                             // Completed successfully?
            // fprintf(stderr, "Upsampling %dto%d succeeded in %d attempts\n", m, 2 * m, attempt);
            base.sortNest(m);
            return base;
        }
        // operator+= only appends past the nested symbols, so resetting the count
        // undoes the attempt without copying all 2^(2m) slots each time.
        base.n = N;
    }
    fprintf(stderr, "Upsampling %d-to-%d failed\n", m, 2 * m);
    exit(1);
}


// -----------------------------------------------------------------------------
// Sort to make ready for nesting
// -----------------------------------------------------------------------------

// For each band size 2^q (q = q0 .. m-1), reorders slots so that
// t[2^q + i] == t[2^q] + t[i] for 0 < i < 2^q. Each prefix t[0 .. 2^(q+1)-1]
// then extends the previous one by one additive coset. No-op on incomplete alphabets.
inline void Alphabet::sortNest(int q0) {
    if (n != N) { return; }                                                     // Sorry, we only accept complete alphabets
    for (int q = q0; q < m; q++) {
        const int bandSize = 1 << q;
        MatrixR keyElement = t[bandSize];                                       // First element in second band
        for (int i = 1; i < bandSize; i++) {                                    // Iterate through elements of first band
            const int slotIndex = bandSize + i;                                 // Matching slot in the second band
            MatrixR slotOwner = keyElement + t[i];                              // Slot should be occupied by this element
            for (int j = slotIndex; j < n; j++) {                               // Only scan slots not yet settled
                if (t[j] == slotOwner) {
                    std::swap(t[j], t[slotIndex]);
                    break;
                }
            }
        }
    }
}

// -----------------------------------------------------------------------------
// Sort as powers of a primitive to enable fair comparison
// -----------------------------------------------------------------------------

// Returns a copy whose slots from 2 on are the successive powers of the
// smallest primitive. Incomplete alphabets are returned as an unchanged copy.
inline Alphabet Alphabet::sortAlpha() {
    Alphabet result = *this;
    if (n == N) {                                                               // Alphabet complete?
        MatrixR alpha = primitives()[0];                                        // Smallest primitive gives a canonical order
        for (int i = 2; i < N; i++) {
            result.t[i] = alpha * result.t[i - 1];                              // alpha^(i-1), assuming t[1] is I
        }
    }
    return result;
}

// -----------------------------------------------------------------------------
// Construct generator matrices from alphabet
// -----------------------------------------------------------------------------

// For each valid symbol t[i], builds an mOut x mOut matrix C by stacking column
// blocks: the first block holds t[1] at the top, and each next block (offset by m)
// is the previous one multiplied by L = I.shift(m, 0) + IBlock(mOut, t[i]).
// NOTE(review): the exact shape of L depends on MatrixR::shift/IBlock semantics (matrix.h).
inline TupleR Alphabet::generators(int mOut){
    TupleR result;
    MatrixR shiftPart = MatrixR::I(mOut).shift(m, 0);                           // Loop-invariant part of L, built once
    for (int i = 0; i < n; i++) {                                               // Iterate through symbols
        MatrixR L = shiftPart + MatrixR::IBlock(mOut, t[i]);                    // Lower-triangular for recurrence formula
        MatrixR col = MatrixR(mOut).putBlock(0, 0, t[1]);                       // Symbol on first element
        MatrixR C(mOut);                                                        // Start with an empty matrix
        for (int j = 0; j < mOut; j += m) {
            C = C + col.shift(0, j);
            col = L * col;
        }
        result.push_back(C);
    }
    return result;
}

// -----------------------------------------------------------------------------
// Shuffle the elements from i0 on (the default i0 = 2 leaves 0 and I in place)
// -----------------------------------------------------------------------------

// Fisher-Yates shuffle of the valid slots t[i0 .. n-1]; slots below i0 stay put.
inline void Alphabet::shuffle(int i0) {
    for (int i = n - 1; i > i0; i--) {                                          // Walk down through the movable slots
        int r = i0 + rnd() % (i + 1 - i0);                                      // Partner drawn from [i0, i], the slot itself included
        std::swap(t[i], t[r]);
    }
}

// -----------------------------------------------------------------------------
// Return an IBlock of the triangular cap of the ith generator
// -----------------------------------------------------------------------------

// Builds the 2m x 2m block matrix [[I, t[i]], [0, I]] and passes it to
// MatrixR::IBlock(mOut, .). The caller must ensure i < getn().
// NOTE(review): how IBlock expands it to mOut x mOut is defined in matrix.h.
inline MatrixR Alphabet::getCap(int i, int mOut){
    MatrixR cap = MatrixR::I(2 * m).putBlock(0, m, t[i]);                       // t[i] in the top-right quadrant
    return MatrixR::IBlock(mOut, cap);
}

// // =============================================================================
// // Print arithmetic table
// // =============================================================================
//
// template <int m>
// void printTable(
//     TupleR t,
//     int tableType = 0,                                                          // 0: addition, 1: multiplication, 2: inversion
//     FILE *file = stdout
// ) {
//     const char *title[] = {"Addition:", "Multiplication:", "Inversion:"};
//     //printf("%s:\n", title[tableType]);
//     int n = t.size();
//     MatrixR::printf(t, file, title[tableType], 6);
//     std::vector<int> table(n * n);
//     for (int i = 0; i < n; i++) {
//         TupleR row;
//         row.push_back(t[i]);                                                    // Key column
//         for (int j = 0; j < n; j++) {
//             MatrixR result;
//             switch(tableType) {
//                 case 0: result = t[i] + t[j]; break;
//                 case 1: result = t[i] * t[j]; break;
//                 case 2: result = (t[i] * t[j] == t[1] ? t[1] : t[0]); break;    // Assuming t[0] is 0 and t[1] is I
//             }
//             row.push_back(result);                                              // Insert actual matrix in row
//             table[i * n + j] = findOrAbort(result, t);                          // Insert id in sympolic table
//         }
//         MatrixR::printf(row, file, "\n");                                  // Display resultant row
//     }
//     fprintf(file, "\nSymbolic Table:\n  ");
//     for (int i = 0; i < n; i++) printf("%3X", i);
//     fprintf(file, "\n");
//     for (int i = 0; i < n; i++) {
//         fprintf(file, "%2X", i);
//         for (int j = 0; j < n; j++) {
//             fprintf(file, "%3X", table[i * n + j]);
//         }
//         fprintf(file, "\n");
//     }
//     fprintf(file, "\n\n");
// }
//
//

#endif                                                                          // #ifndef ALPHABET_H
