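/*
 * Benchmark: speed-up of SZ generator-matrix/vector multiplication using the
 * SPZ diagonal decomposition versus pbrt-style column-wise multiplication.
 *
 * Reported speed-up: 24-29x for dim < 16 (4x4 blocks), 15-19x for dim < 256 (8x8 blocks).
 */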


#define GENERATOR_BITs 64
#include "sz.h"

// =============================================================================
// Global parameters
// =============================================================================

bool printMatrices{false};                                                      // Set by the -v option; makes benchmark() print the decomposition


// =============================================================================
// Diagonal matrix multiplication building blocks
// =============================================================================

// -----------------------------------------------------------------------------
// Bit reversal, needed to reflect Z diagonals for JZJ
// -----------------------------------------------------------------------------

inline uint64_t bitReverse(uint64_t x) {
    // Reverse all 64 bits by swapping ever larger groups: 1, 2, 4, 8, 16 bits,
    // then the two 32-bit halves (which needs no mask).
    auto swapGroups = [](uint64_t v, uint64_t lowMask, int width) {
        return ((v & lowMask) << width) | ((v >> width) & lowMask);
    };
    uint64_t r = swapGroups(x, 0x5555555555555555ull, 1);
    r = swapGroups(r, 0x3333333333333333ull, 2);
    r = swapGroups(r, 0x0f0f0f0f0f0f0f0full, 4);
    r = swapGroups(r, 0x00ff00ff00ff00ffull, 8);
    r = swapGroups(r, 0x0000ffff0000ffffull, 16);
    return (r << 32) | (r >> 32);
}

// -----------------------------------------------------------------------------
// Multiplication by 4x4-block Pascal matrix with bit-reversal to give PJ
// -----------------------------------------------------------------------------

// Applies the P factor, with the bit reversal J folded in, for 4x4-block
// matrices (mulFast4 calls it between the Z and S diagonal products).
// The first four stages act on groups of 64, 32, 16 and 8 bits. In every
// group, with (hi, lo) its upper and lower halves, a stage produces
// (hi ^ lo, hi): a GF(2) Pascal-style butterfly combined with a half swap.
inline uint64_t PJ4(uint64_t i) {
    uint64_t j;
    j= i & 0xFFFFFFFF00000000; i= j^ (j>>32)^ ((i<<32)                     );  // 64-bit group: (hi32, lo32) -> (hi ^ lo, hi)
    j= i & 0xFFFF0000FFFF0000; i= j^ (j>>16)^ ((i<<16) & 0xFFFF0000FFFF0000);  // 32-bit groups
    j= i & 0xFF00FF00FF00FF00; i= j^ (j>> 8)^ ((i<< 8) & 0xFF00FF00FF00FF00);  // 16-bit groups
    j= i & 0xF0F0F0F0F0F0F0F0; i= j^ (j>> 4)^ ((i<< 4) & 0xF0F0F0F0F0F0F0F0);  // 8-bit groups
    // -------------------------------------------------------------------------
    // Last two stages are just bit reversal: plain swaps of 2-bit and 1-bit
    // halves (no XOR), the same steps bitReverse() uses.
    // -------------------------------------------------------------------------
    i = ((i & 0x3333333333333333) << 2)|((i >> 2) & 0x3333333333333333);
    i = ((i & 0x5555555555555555) << 1)|((i >> 1) & 0x5555555555555555);
    return i;
}

// -----------------------------------------------------------------------------
// Multiplication by 8x8-block Pascal matrix with bit-reversal to give PJ
// -----------------------------------------------------------------------------

// Applies the P factor, with the bit reversal J folded in, for 8x8-block
// matrices (mulFast8 calls it between the Z and S diagonal products).
// The first three stages act on groups of 64, 32 and 16 bits. In every
// group, with (hi, lo) its upper and lower halves, a stage produces
// (hi ^ lo, hi): a GF(2) Pascal-style butterfly combined with a half swap.
inline uint64_t PJ8(uint64_t i) {                                               // 8x8 block Pascal with bit reversal
    uint64_t j;
    j= i & 0xFFFFFFFF00000000; i= j^ (j>>32)^ ((i<<32)                     );  // 64-bit group: (hi32, lo32) -> (hi ^ lo, hi)
    j= i & 0xFFFF0000FFFF0000; i= j^ (j>>16)^ ((i<<16) & 0xFFFF0000FFFF0000);  // 32-bit groups
    j= i & 0xFF00FF00FF00FF00; i= j^ (j>> 8)^ ((i<< 8) & 0xFF00FF00FF00FF00);  // 16-bit groups
    // -------------------------------------------------------------------------
    // Last three stages are just bit reversal: plain swaps of 4-, 2- and 1-bit
    // halves (no XOR), the same steps bitReverse() uses.
    // -------------------------------------------------------------------------
    i = ((i & 0x0F0F0F0F0F0F0F0F) << 4)|((i >> 4) & 0x0F0F0F0F0F0F0F0F);
    i = ((i & 0x3333333333333333) << 2)|((i >> 2) & 0x3333333333333333);
    i = ((i & 0x5555555555555555) << 1)|((i >> 1) & 0x5555555555555555);
    return i;
}

// -----------------------------------------------------------------------------
// Diagonal-wise matrix-vector multiplication for 4x4 and 8x8 block-diagonal matrices
// -----------------------------------------------------------------------------

inline uint64_t MxV4(uint64_t *d, uint64_t x) {
    // Multiply x by a matrix given as 7 diagonal masks d[0..6] (GF(2): AND, then XOR).
    // d[k] pairs with x shifted by (k - 3): right for k < 3, left for k > 3.
    const uint64_t fromHigher = ((x >> 3) & d[0]) ^ ((x >> 2) & d[1]) ^ ((x >> 1) & d[2]);
    const uint64_t center     =  x & d[3];
    const uint64_t fromLower  = ((x << 1) & d[4]) ^ ((x << 2) & d[5]) ^ ((x << 3) & d[6]);
    return fromHigher ^ center ^ fromLower;
}

inline uint64_t MxV8(uint64_t *d, uint64_t x) {
    // Multiply x by a matrix given as 15 diagonal masks d[0..14] (GF(2): AND, then XOR).
    // d[k] pairs with x shifted by (k - 7): right for k < 7, left for k > 7.
    const uint64_t fromHigher =
        ((x >> 7) & d[0]) ^ ((x >> 6) & d[1]) ^ ((x >> 5) & d[2]) ^ ((x >> 4) & d[3]) ^
        ((x >> 3) & d[4]) ^ ((x >> 2) & d[5]) ^ ((x >> 1) & d[6]);
    const uint64_t center = x & d[7];
    const uint64_t fromLower =
        ((x << 1) & d[8])  ^ ((x << 2) & d[9])  ^ ((x << 3) & d[10]) ^ ((x << 4) & d[11]) ^
        ((x << 5) & d[12]) ^ ((x << 6) & d[13]) ^ ((x << 7) & d[14]);
    return fromHigher ^ center ^ fromLower;
}

// -----------------------------------------------------------------------------
// Complete multiplication by an SPZ-decomposed matrix
// -----------------------------------------------------------------------------

inline uint64_t mulFast4(uint64_t *d, uint64_t seqNo) {
    // d is laid out as szDiagonals(spz, 4) produces it: 7 (reflected) Z
    // diagonals followed by 7 S diagonals. Apply Z, then PJ, then S.
    uint64_t *zDiagonals = d;
    uint64_t *sDiagonals = d + 7;
    const uint64_t afterZ = MxV4(zDiagonals, seqNo);
    const uint64_t afterP = PJ4(afterZ);
    return MxV4(sDiagonals, afterP);
}

inline uint64_t mulFast8(uint64_t *d, uint64_t seqNo) {
    // d is laid out as szDiagonals(spz, 8) produces it: 15 (reflected) Z
    // diagonals followed by 15 S diagonals. Apply Z, then PJ, then S.
    uint64_t *zDiagonals = d;
    uint64_t *sDiagonals = d + 15;
    const uint64_t afterZ = MxV8(zDiagonals, seqNo);
    const uint64_t afterP = PJ8(afterZ);
    return MxV8(sDiagonals, afterP);
}

template <int m>
inline uint64_t mulFast(uint64_t* d, uint64_t seqNo) {
    // Compile-time dispatch on the block size; only 4 and 8 have kernels.
    static_assert(m == 4 || m == 8, "Only m=4 or m=8 are supported.");
    if constexpr (m == 4) {
        return mulFast4(d, seqNo);
    } else {
        return mulFast8(d, seqNo);
    }
}

// =============================================================================
// SPZ decomposition of an SZ generator matrix
// =============================================================================

// -----------------------------------------------------------------------------
// Decomposition
// -----------------------------------------------------------------------------

// Decomposes generator matrix A into the triple {S, P, Z}:
//   S0 = MatrixR::IBlock(m, top-left mBlock x mBlock block of A)
//        (presumably that block repeated along the diagonal -- confirm IBlock)
//   P0 = S0^-1 * A, so A == S0 * P0
//   Z  = sum of the blocks P0.getBlock(0, i, mBlock), each put at (i, i)
//        (presumably P0's first block row moved onto the diagonal)
//   P  = Z * P0 * Z^-1,   S = S0 * Z^-1
// Algebraically S * P * Z == S0 * P0 == A; benchmark() verifies this at runtime.
TupleR SPZ(const MatrixR &A, int mBlock) {
    int m = A.getm();
    MatrixR S = MatrixR::IBlock(m, A.getBlock(0, 0, mBlock));
    MatrixR P = S.inverse() * A;
    // NOTE(review): assumes MatrixR(m) is the zero matrix.
    MatrixR Z(m), tmp(m);
    for (int i = 0; i < m; i += mBlock) {
        // NOTE(review): tmp is reused across iterations. This is only correct if
        // putBlock() yields a matrix holding just the new block (i.e. blocks do
        // not accumulate in tmp); confirm putBlock semantics in sz.h.
        Z = Z + tmp.putBlock(i, i, P.getBlock(0, i, mBlock));
    }
    // NOTE(review): Z.inverse() is computed twice; it could be hoisted.
    P = Z * P * Z.inverse();
    S = S * Z.inverse();
    return {S, P, Z};
}

// -----------------------------------------------------------------------------
// Extract diagonals from decomposition
// -----------------------------------------------------------------------------

std::vector<uint64_t> szDiagonals(const TupleR &spz, int mBlock) {
    // Returns 2 * (2 * mBlock - 1) masks: Z diagonals first, then S diagonals,
    // in the order MxV4/MxV8 consume them.
    const int band = mBlock - 1;                                                // Offset of the outermost diagonal
    std::vector<uint64_t> diagonals;
    diagonals.reserve(2 * (2 * band + 1));
    // -------------------------------------------------------------------------
    // Z is applied first, so its diagonals come first. Offsets run from +band
    // down to -band (right-hand J) and each mask is bit-reversed (left-hand J),
    // giving the diagonals of JZJ; the actual vector reversal lives in P.
    // -------------------------------------------------------------------------
    for (int k = band; k >= -band; --k) {
        diagonals.push_back(bitReverse(spz[2].getDiagonal(k)));
    }
    // -------------------------------------------------------------------------
    // S diagonals follow in natural order, -band .. +band, unmodified.
    // -------------------------------------------------------------------------
    for (int k = -band; k <= band; ++k) {
        diagonals.push_back(spz[0].getDiagonal(k));
    }
    return diagonals;
}

// -----------------------------------------------------------------------------
// Print the table of SPZ diagonals for dimensions 1..15 (4x4 blocks), if needed
// -----------------------------------------------------------------------------

void printDiagonalsTable(const TupleR &G, FILE *file = stdout) {
    // -------------------------------------------------------------------------
    // Prints the SPZ diagonals that szDiagonals() produces for G[1..15] with
    // 4x4 blocks, as comma-separated hex literals, three per line.
    // Dimension 0 gets zero placeholders, so entry dim * 14 + k is diagonal k
    // of dimension dim.
    // -------------------------------------------------------------------------
    const int mBlock = 4;
    const int perDim = 2 * (2 * mBlock - 1);                                    // Z then S diagonals per dimension (14)
    const int dims = 16;                                                        // Placeholder dim 0 plus dims 1..15
    const int N = dims * perDim;
    std::vector<uint64_t> table(N, 0);                                          // Dimension 0 stays all zero
    for (int dim = 1; dim < dims; dim++) {
        std::vector<uint64_t> diagonals = szDiagonals(SPZ(G[dim], mBlock), mBlock);
        for (int k = 0; k < perDim; k++) {
            table[dim * perDim + k] = diagonals[k];
        }
    }
    for (int i = 0; i < N; i++) {
        if (i % 3 == 0) fprintf(file, "\n   ");
        // %llx expects unsigned long long, which is not necessarily the type of uint64_t
        fprintf(file, " 0x%016llx", static_cast<unsigned long long>(table[i]));
        if (i < N - 1) fprintf(file, ",");
    }
    fprintf(file, "\n");
}

// =============================================================================
// Column matrix class implementing pbrt-style multiplication for comparison
// =============================================================================

class MatrixC : public MatrixCommon<MatrixC> {
public:
    // -------------------------------------------------------------------------
    // Construction: transpose a row matrix into column storage
    // -------------------------------------------------------------------------
    MatrixC(const MatrixR &rowMatrix);

    // -------------------------------------------------------------------------
    // Accessors
    // -------------------------------------------------------------------------
    int getm() const { return m; }                                              // Matrix size
    uint64_t getCol(int j) const { return cols[j]; }                            // Column j as a 64-bit word
    uint64_t getBit(int i, int j) const {                                       // Bit at row i, column j (both 0-based)
        return (cols[j] >> (63 - i)) & 1;
    }

    // -------------------------------------------------------------------------
    // Sample generation: XOR of the columns selected by the set bits of seqNo
    // -------------------------------------------------------------------------
    inline uint64_t operator[](uint64_t seqNo) const;                           // seqNo in natural order, not bit-reversed

private:
    int m;                                                                      // Matrix size
    // NOTE(review): cols is volatile, so every read in operator[] is a mandatory
    // memory access that cannot be kept in registers. This likely slows the
    // pbrt-style path timed by benchmark(); confirm it is intended before
    // quoting the measured speed-up.
    volatile uint64_t cols[64];                                                 // cols[0] is the leftmost column; row 0 is the MSB
};

// -----------------------------------------------------------------------------
// Constructor
// -----------------------------------------------------------------------------

MatrixC::MatrixC(const MatrixR &rowMatrix) {
    // Transpose rowMatrix: bit (i, j) goes to bit (63 - i) of cols[j].
    m = rowMatrix.getm();
    for (int j = 0; j < m; j++) {
        uint64_t col = 0;                                                       // Build locally; store to the volatile array once
        for (int i = 0; i < m; i++) {
            // Widen before shifting: a shift by up to 63 is only well defined
            // on a 64-bit operand, whatever type getBit() returns.
            col |= static_cast<uint64_t>(rowMatrix.getBit(i, j)) << (63 - i);
        }
        cols[j] = col;
    }
    for (int j = m; j < 64; j++) cols[j] = 0;                                   // Clear remaining columns
}

// -----------------------------------------------------------------------------
// Drawing samples via linear vector multiplication
// -----------------------------------------------------------------------------

inline uint64_t MatrixC::operator[](uint64_t seqNo) const {
    // XOR together column k for every set bit k of seqNo (bit 0 -> cols[0]).
    uint64_t sample = 0;
    int col = 0;
    while (seqNo != 0) {
        if ((seqNo & 1) != 0) {
            sample ^= cols[col];                                                // Branch per bit, as pbrt does (no multiply trick)
        }
        seqNo >>= 1;
        ++col;
    }
    return sample;
}

// =============================================================================
// Speed test comparing diagonal vs. column-wise multiplication
// =============================================================================

template <int m>
void benchmark(MatrixR R, int log2N) {
    // -------------------------------------------------------------------------
    // Times generating 2^log2N samples from generator matrix R in two ways:
    // pbrt-style column XOR (MatrixC::operator[]) and the SPZ diagonal form
    // (mulFast<m>, m = block size 4 or 8). Then it cross-checks one random index.
    // -------------------------------------------------------------------------
    if (log2N < 0 || log2N > 63) {                                              // 1llu << log2N is undefined outside [0, 63]
        printf("Invalid log2 sample count %d (must be in [0, 63]); skipping\n", log2N);
        return;
    }
    uint64_t N = 1llu << log2N;
    MatrixC C(R);                                                               // Convert to column matrix
    TupleR spz = SPZ(R, m);
    MatrixR product = spz[0] * spz[1] * spz[2];
    if (product != R) {
        printf("SPZ decomposition failed; aborting test\n");
        return;
    }
    std::vector<uint64_t> d = szDiagonals(spz, m);                              // Extract diagonals
    // -------------------------------------------------------------------------
    // Display matrices if needed
    // -------------------------------------------------------------------------
    if (printMatrices) {
        MatrixR::printf(                                                        // Display for visual check
            {R, spz[0], spz[1], spz[2], product, product - R}, stdout,
            "Input, S, P, Z, product, difference from input:"
        );
        C.printf(stdout, "Converted column matrix C:");                         // Print out the matrix for visual checking
    }
    // -------------------------------------------------------------------------
    // Generate the samples. Results are stored and one is read back below,
    // which keeps the computation from being discarded.
    // NOTE: two N-element uint64_t vectors are allocated, i.e. 16 * 2^log2N
    // bytes in total (64 GiB at log2N = 32).
    // -------------------------------------------------------------------------
    std::vector<uint64_t> list1(N), list2(N);
    clock_t clock0 = clock();
    for (uint64_t i = 0; i < N; i++) list1[i] ^= C[i];
    clock_t clock1 = clock();
    for (uint64_t i = 0; i < N; i++) list2[i] ^= mulFast<m>(&d[0], i);
    clock_t clock2 = clock();

    double tPBRT = (double)(clock1 - clock0) / CLOCKS_PER_SEC;
    double tSPZ = (double)(clock2 - clock1) / CLOCKS_PER_SEC;
    // -------------------------------------------------------------------------
    // Perform a random validation test
    // -------------------------------------------------------------------------
    uint64_t seqNo(rnd() % N);
    uint64_t smplPBRT(list1[seqNo]), smpleSPZ(list2[seqNo]);
    // -------------------------------------------------------------------------
    // Print out result. %llu/%llx expect unsigned long long, which is not
    // necessarily the type of uint64_t, so cast explicitly.
    // -------------------------------------------------------------------------
    printf(
        "Generated 2^%2d samples; "
        "times PBRT / SPZ = %10.4f / %10.4f = %5.2fx boost; "
        "random validation: sample[%20llu] = %016llx %s %016llx\n",
        log2N,
        tPBRT, tSPZ, tPBRT / tSPZ,
        static_cast<unsigned long long>(seqNo),
        static_cast<unsigned long long>(smplPBRT),
        smpleSPZ == smplPBRT ? "==" : "!=",
        static_cast<unsigned long long>(smpleSPZ)
    );
}

// =============================================================================
// Main
// =============================================================================

// Usage text. It is used as a printf format string with argv[0] as its single %s argument.
const char *USAGE_MESSAGE = "Usage: %s [options]\n"
"Options:\n"
" -v                Verbose: print matrix decomposition\n"
" -m <min>          Log2 of minimum point count, default is 20\n"
" -M <max>          Log2 of maximum point count, default is 32\n"
;

int main(int argc,char **argv) {
    int opt;
    int mMin(20), mMax(32);                                                     // Log2 of the smallest / largest sample count
    while ((opt = getopt(argc, argv, "vm:M:")) != -1) {
        switch (opt) {
            case 'v': printMatrices = true; break;
            case 'm': mMin = atoi(optarg); break;
            case 'M': mMax = atoi(optarg); break;
            default: fprintf(stderr, USAGE_MESSAGE, argv[0]); exit(1);
        }
    }
    // -------------------------------------------------------------------------
    // The program takes no positional arguments. benchmark() computes
    // 1llu << log2N, so the exponents must satisfy 0 <= mMin <= mMax <= 63.
    // -------------------------------------------------------------------------
    if (optind < argc || mMin < 0 || mMax > 63 || mMin > mMax) {
        fprintf(stderr, USAGE_MESSAGE, argv[0]);
        exit(1);
    }
    TupleR G = SZ::makeNested(1, 3, true, true);                               // Ensembled randomized

    printf("Benchmarking 4x4-block matrices:\n");
    for (int log2N = mMin; log2N <= mMax; log2N++) {
        benchmark<4>(G[7], log2N);                                              // Any matrix in {1..15} will do the job
    }
    printf("Benchmarking 8x8-block matrices:\n");
    for (int log2N = mMin; log2N <= mMax; log2N++) {
        benchmark<8>(G[17], log2N);                                             // Any matrix in {16..255} will do the job
    }
    return 0;
}


