/*
 * 2D spectral plots
 */

//#define USE_DOUBLE

#include "sz.h"
#include "spectrum.h"
#include <vector>
#include <string>
#include <getopt.h>

// =============================================================================
// Global parameters
// =============================================================================

// Scrambling switches for the constructed 2D point sets; both start disabled,
// are switched on by the -O / -X command-line flags, and are consulted by
// plotSpectra() before each set is converted to floating point.
bool doOwenScramble{false};                                                     // Apply owenScramble() to every set
bool doXorScramble{false};                                                      // Apply xorScramble() to every set

// =============================================================================
// Plot pairwise spectra averaged over a list of sequences, typically
//  instances of the same family of sequences.
// =============================================================================

void plotSpectra(
    const std::vector<Sequence> &seqs,                                          // A list of instances of the sequence
    std::vector<int> dimList,                                                   // List of pairs as a single array
    int N,                                                                      // Number of points in the generated set
    String path = "",                                                           // Path to output the plots, and/or name prefix
    int width = 0                                                               // Spectrum width; 0 selects 10 * sqrt(N)
) {
    // dimList conventions:
    //   empty       -> all pairs of dimensions below seqs[0].getDims()
    //   one value D -> all pairs (dim1, dim2) with 0 <= dim1 < dim2 < D
    //   otherwise   -> consecutive entries form (dimx, dimy) pairs; an
    //                  unpaired trailing entry is dropped with a warning
    // For each pair, N points are taken from every sequence instance,
    // optionally Owen/XOR scrambled, and the n sets are passed together to
    // powerSpectrum(); the result is written to "<path>DDxDD.png".
    clock_t t0 = clock();
    const int n = int(seqs.size());                                             // Number of sets
    if (!n || N <= 0) { return; }                                               // Nothing to do with nothing!
    const int nDims = int(seqs[0].getDims());
    if (dimList.empty()) { dimList.push_back(nDims); }                          // All pairs
    if (dimList.size() == 1) {                                                  // A single number passed to indicate highest dim
        const int highest = dimList[0];
        dimList.clear();                                                        // highest <= 1 (or negative) yields no pairs
        for (int dim1 = 0; dim1 < highest - 1; dim1++) {
            for (int dim2 = dim1 + 1; dim2 < highest; dim2++) {
                dimList.push_back(dim1);
                dimList.push_back(dim2);
            }
        }
    }
    if (dimList.size() % 2) {                                                   // An unpaired entry would make dimList[k + 1] read out of bounds
        fprintf(stderr, "Ignoring unpaired dimension %d\n", dimList.back());
        dimList.pop_back();
    }
    if (!width) { width = int(10 * sqrt(N)); }
    const size_t setsLength = size_t(n) * size_t(N);                             // size_t: the byte count may not fit in 32 bits
    const size_t dataSize = setsLength * sizeof(PointFloat);
    std::vector<PointFloat> pflt(setsLength);                                   // Buffer for sets of 2D points to compute the spectrum
    Points p(N);
    PointFloat *p_gpu = NULL;
    cudaError_t err = cudaMalloc(&p_gpu, dataSize);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaMalloc failed: %s\n", cudaGetErrorString(err));
        return;
    }
    xmax = 1;
    for (size_t k = 0; k + 1 < dimList.size(); k += 2) {
        const int dim1 = dimList[k    ];
        const int dim2 = dimList[k + 1];
        if (dim1 < 0 || dim2 < 0 || dim1 >= nDims || dim2 >= nDims) {
            fprintf(
                stderr, "Skipping pair %d x %d: dimensions must be in [0, %d)\n",
                dim1, dim2, nDims
            );
            continue;
        }
        // ---------------------------------------------------------------------
        // Construct the point sets
        // ---------------------------------------------------------------------
        size_t i = 0;
        for (int seqIndex = 0; seqIndex < n; seqIndex++) {
            const Sequence &seq = seqs[seqIndex];
            const MatrixR &Cx = seq[dim1], &Cy = seq[dim2];
            for (int seqNo = 0; seqNo < N; seqNo++) {
                p[seqNo] = {uint32_t(Cx[seqNo]), uint32_t(Cy[seqNo])};
            }
            if (doOwenScramble) { owenScramble(p); }
            if (doXorScramble) { xorScramble(p); }
            for (int seqNo = 0; seqNo < N; seqNo++) {
                pflt[i++] = {fxd2flt( p[seqNo].x ), fxd2flt( p[seqNo].y )};
            }
        }
        // ---------------------------------------------------------------------
        // Compute the spectrum
        // ---------------------------------------------------------------------
        err = cudaMemcpy(p_gpu, pflt.data(), dataSize, cudaMemcpyHostToDevice);
        if (err != cudaSuccess) {
            fprintf(stderr, "cudaMemcpy failed: %s\n", cudaGetErrorString(err));
            break;                                                              // Fall through to cudaFree below
        }
        std::vector<Float> spectrum = powerSpectrum(
            p_gpu, NULL, n, N, width
        );
        // ---------------------------------------------------------------------
        // Save files
        // ---------------------------------------------------------------------
        char serialNo[40];
        snprintf(serialNo, sizeof(serialNo), "%02dx%02d", dim1, dim2);
        String fileName = path + serialNo;
        plotSpectrum(spectrum, width, (fileName + ".png").c_str());
        //plotRadialPower(
        //    spectrum, width, (fileName + ".tex").c_str(), N, 0.5, true, true
        //);
    }
    cudaFree(p_gpu);
    clock_t t1 = clock();
    Float totalTime = (Float)(t1 - t0) / CLOCKS_PER_SEC;                        // CPU time of this process, not GPU wall time
    fprintf(
        stderr, "Total time = %.6fs\n", totalTime
    );
}


// =============================================================================
// Main
// =============================================================================

// Printed (with argv[0] substituted for %s) on bad or missing arguments.
// Every option accepted by getopt() in main() must be listed here.
const char *USAGE_MESSAGE = "Usage: %s <q> [options] [[dimx dimy] ..]\n"
"Options:\n"
" -P <path>         <Output path>/<fileName prefix>; default none\n"
" -n <instances>    Default 1\n"
" -N <pointcount>   Default 1024\n"
" -O                Owen scramble the sets\n"
" -X                XOR scramble the sets\n"
" -R                Randomize the matrices; default no\n"
" -V                Verify the construction, default no\n"
" -E                Ensemble\n"
"Dimensions: no dims plots all pairs; a single dim D plots all pairs below D.\n"
;

int main(int argc,char **argv) {
    // Parse options, then the mandatory <q> and optional dimension list, build
    // n sequence instances and plot their pairwise spectra.
    int opt;
    String path;
    int N = 1024;                                                               // Points per set (-N)
    int n = 1;                                                                  // Number of sequence instances (-n)
    bool randomize = false;
    bool validate = false;
    bool ensemble = false;
    while ((opt = getopt(argc, argv, "P:n:N:OXRVE")) != -1) {
        switch (opt) {
            case 'P': path = optarg; break;
            case 'n': n = atoi(optarg); break;
            case 'N': N = atoi(optarg); break;
            case 'O': doOwenScramble = true; break;
            case 'X': doXorScramble = true; break;
            case 'R': randomize = true; break;
            case 'E': ensemble = true; break;
            case 'V': validate = true; break;
            default: fprintf(stderr, USAGE_MESSAGE, argv[0]); exit(1);
        }
    }
    if (optind > argc - 1) {                                                    // The positional <q> is mandatory
        fprintf(stderr, USAGE_MESSAGE, argv[0]);
        exit(1);
    }
    if (n <= 0 || N <= 0) {                                                     // atoi() yields 0 on malformed input
        fprintf(stderr, "Error: -n and -N must be positive integers\n");
        fprintf(stderr, USAGE_MESSAGE, argv[0]);
        exit(1);
    }
    int q = atoi(argv[optind]);
    std::vector<int> dimsPairs;
    for (int argi = optind + 1; argi < argc; argi++) {
        dimsPairs.push_back(atoi(argv[argi]));
        //fprintf(stderr, "adding dim %2d to list\n", dimsPairs.back());
    }
    // Zero dims and one dim have special meaning to plotSpectra(); otherwise
    // the list must consist of complete (dimx, dimy) pairs.
    if (dimsPairs.size() > 1 && dimsPairs.size() % 2) {
        fprintf(stderr, "Error: dimensions must be given in (dimx, dimy) pairs\n");
        fprintf(stderr, USAGE_MESSAGE, argv[0]);
        exit(1);
    }
    // q, randomize, validate and ensemble only feed the alternative generators
    // commented out below; they stay parsed so those can be re-enabled.
    (void)q; (void)randomize; (void)validate; (void)ensemble;
    std::vector<Sequence> seqs;
    //SZ sz = SZ::makeNested(1, 2, ensemble, randomize);
    Sequence szTeaser = Sequence(16, SZTeaserFigColMatrices);
    for (int i = 0; i < n; i++) {
        //seqs.push_back(SZ(q, randomize, validate));
        //seqs.push_back(SZ::makeNested(1, 2, ensemble, randomize));
        //seqs.push_back(sz);
        seqs.push_back(szTeaser);
        //seqs.push_back(Sequence::Sobol());
    }
    //plotSpectra(seqs, {16}, N, path);
    plotSpectra(seqs, dimsPairs, N, path);
//     SZ sz(q);
//     Points p(N);
//     for (int i = 0; i < N; i++) {
//         p[i] = {uint32_t(sz[0][i]), uint32_t(sz[1][i])};
//     }
//     owenScramble(p);
//     printTxt(p, path.c_str());
    return 0;
}
