/*
 * Discrepancy assessment
 *
 */

//#define USE_DOUBLE

#include "sz.h"
#include <vector>
#include <string>
#include <getopt.h>

// =============================================================================
// Global parameters
// =============================================================================

// Set by the -O command-line flag; when true, discrepancy<dims>() Owen-scrambles
// the coordinates of every dimension before measuring.
bool doOwenScramble{false};

// =============================================================================
// Compute discrepancy of a given sequence
// =============================================================================

// Star discrepancy of the first N points of a 2D point set (x from C0, y from C1),
// evaluated over every anchored box [0, XRef) x [0, YRef) whose corner coordinates
// are taken from the point set itself.
//
// N must be a positive power of two (m = log2(N)); each coordinate is quantised
// to its top m bits (value in [0, N)). Exits the process otherwise.
// For each box: in = points strictly inside, on = extra points on its closed
// upper edges; the discrepancy is max(|in/N - vol|, |(in+on)/N - vol|), kept in
// integers by scaling with N^2. All products are 64-bit so N >= 2^16 cannot
// overflow. Cost: O(N^3).
double discrepancy2D(int N, const MatrixR &C0, const MatrixR &C1) {
    if (N <= 0 || (N & (N - 1)) != 0) {                                    // __builtin_ctz(0) is UB; other N mis-quantise
        fprintf(stderr, "N must be a positive power of two\n");
        exit(1);
    }
    const int m = __builtin_ctz(N);
    const int shift = BITs - m;

    // Quantise every coordinate once instead of re-shifting inside the O(N^3) loop.
    std::vector<int64_t> xs(N), ys(N);
    for (int k = 0; k < N; k++) {
        xs[k] = C0[k] >> shift;
        ys[k] = C1[k] >> shift;
    }

    const int64_t n = N;
    int64_t dmax = 0;
    for (int i = 0; i < N; i++) {                                           // Box height taken from point i
        const int64_t YRef = ys[i];
        for (int j = 0; j < N; j++) {                                       // Box width taken from point j
            const int64_t XRef = xs[j];
            const int64_t vol = YRef * XRef;                                // Box area in units of 1/N^2
            int64_t in = 0, on = 0;
            for (int k = 0; k < N; k++) {
                const int64_t X = xs[k], Y = ys[k];
                if (X < XRef && Y < YRef) {
                    ++in;
                }
                if ((X == XRef && Y <= YRef) || (Y == YRef && X <= XRef)) { // On the closed upper edges
                    ++on;
                }
            }
            const int64_t d = std::max(
                std::abs(in * n - vol), std::abs((in + on) * n - vol)
            );
            dmax = std::max(dmax, d);
        }
    }
    return dmax / (double(N) * double(N));
}

/*template<int dims>
double discrepancy(
    int N, const std::vector<MatrixR> &C
) {
    int m = __builtin_ctz(N);
    if (dims * m > 32) {
        fprintf(stderr, "Too large space\n");
        exit(1);
    }
    int dmax = 0;
    int shift = BITs - m;
    uint64_t gridIndexMax = 1llu << (dims * m);
    const int mask = N - 1;
    const int scl = 1 << (m * (dims - 1));
    int boundary[dims];
    for (uint64_t gridIndex = 0; gridIndex < gridIndexMax; gridIndex++) {       // Iterate through grid points
        int vol = 1;                                                            // Bounded volume
        for (int dim = 0; dim < dims; dim++) {                                  // Iterate through dimensions to define boundary
            boundary[dim] = C[dim][(gridIndex >> (dim * m)) & mask] >> shift;   // Translate gridIndex in this dimension to an actual coordinate of a point
            vol *= boundary[dim];                                               // Include dimension in volume
        }
        int in(0), on(0);                                                       // Number of points in or on edges
        for (int i = 0; i < N; i++) {                                           // Iterate through points to test their inclusion wrt box
            int isOn(0), isOut(0);
            for (int dim = 0; dim < dims; dim++) {
                int X = C[dim][i] >> shift;
                isOn |= (X == boundary[dim]) & 1;
                isOut |= (X > boundary[dim]) & 1;
            }
            in += (isOut | isOn) ^ 1;                                           // Not on or outside any edge
            on += isOn & (isOut ^ 1);                                           // On one an edge but not outside another
        }
        int d = std::max(
            std::abs(in * scl - vol), std::abs((in+on) * scl - vol)
        );
            dmax = std::max(dmax, d);
    }
    return dmax / double(gridIndexMax);
}//*/

template<int dims>
double discrepancy(
    int N, const std::vector<MatrixR> &C
) {
    int m = __builtin_ctz(N);
    if (dims * m > 32) {
        fprintf(stderr, "Too large space\n");
        exit(1);
    }
    int64_t dmax = 0;
    int shift = BITs - m;
    int64_t gridIndexMax = 1llu << (dims * m);
    const int mask = N - 1;
    const int64_t scl = 1 << (m * (dims - 1));
    int boundary[dims];
    std::vector<uint32_t> X(dims * N);
    for (int dim = 0; dim < dims; dim++) {
        for (int i = 0; i < N; i++) {
            X[dim * N + i] = C[dim][i];
        }
        if (doOwenScramble) { owenScramble(&X[dim * N], N); };
    }
    for (int i = 0; i < X.size(); i++) X[i] >>= shift;
    for (int64_t gridIndex = 0; gridIndex < gridIndexMax; gridIndex++) {       // Iterate through grid points
        int64_t vol = 1;                                                       // Bounded volume
        for (int dim = 0; dim < dims; dim++) {                                  // Iterate through dimensions to define boundary
            boundary[dim] = X[dim * N + ((gridIndex >> (dim * m)) & mask)];     // Translate gridIndex in this dimension to an actual coordinate of a point
            vol *= boundary[dim];                                               // Include dimension in volume
        }
        int64_t in(0), on(0);                                                       // Number of points in or on edges
        for (int i = 0; i < N; i++) {                                           // Iterate through points to test their inclusion wrt box
            int isOn(0), isOut(0);
            for (int dim = 0; dim < dims; dim++) {
                int Xi = X[dim * N + i];
                isOn |= (Xi == boundary[dim]) & 1;
                isOut |= (Xi > boundary[dim]) & 1;
            }
            in += (isOut | isOn) ^ 1;                                           // Not on or outside any edge
            on += isOn & (isOut ^ 1);                                           // On one an edge but not outside another
        }
        int64_t d = std::max(
            std::abs(in * scl - vol), std::abs((in+on) * scl - vol)
        );
        dmax = std::max(dmax, d);
    }
    return dmax / double(gridIndexMax);
}

// Runtime dispatch to the compile-time-dimension discrepancy<dims>() for
// dims in [2, 8]. Any other value previously fell off the end of this
// non-void function (undefined behaviour); it now reports and exits.
inline double getDiscrepancy(int dims, int N, const std::vector<MatrixR> &C) {
    switch (dims) {
        case 2: return discrepancy<2>(N, C);
        case 3: return discrepancy<3>(N, C);
        case 4: return discrepancy<4>(N, C);
        case 5: return discrepancy<5>(N, C);
        case 6: return discrepancy<6>(N, C);
        case 7: return discrepancy<7>(N, C);
        case 8: return discrepancy<8>(N, C);
        default:
            fprintf(stderr, "Unsupported number of dimensions: %d (expected 2-8)\n", dims);
            exit(1);
    }
}

// =============================================================================
// Main
// =============================================================================

// printf-style usage text; the single %s is replaced by the program name (argv[0]).
// Lists every option accepted by main()'s getopt string "N:n:D:q:RVEO".
const char *USAGE_MESSAGE = "Usage: %s [options]\n"
"Options:\n"
" -N <nPoints>      Must be a power of two, default 256\n"
" -D <dimensions>   Between 2 and 8, default 2\n"
" -R                Randomize the matrices; default no\n"
" -V                Verify the construction, default no\n"
" -E                Ensemble, default is nest only\n"
" -O                Owen-scramble the points before measuring, default no\n"
" -q <q>            Accepted but currently ignored\n"
" -n <repetitions>  Default 1\n"
;

int main(int argc,char **argv) {
    clock_t t0 = clock();
    int opt;
    int N = 256;
    int dims = 2;
    bool randomize = false;
    bool validate = false;
    bool ensemble = false;
    int n = 1;
    int q = 2;
    while ((opt = getopt(argc, argv, "N:n:D:q:RVEO")) != -1) {
        switch (opt) {
            case 'N': N = atoi(optarg); break;
            case 'D': dims = atoi(optarg); break;
            case 'R': randomize = true; break;
            case 'V': validate = true; break;
            case 'E': ensemble = true; break;
            case 'q': q = atoi(optarg); break;
            case 'O': doOwenScramble = true; break;
            case 'n': n = atoi(optarg); break;
            default: fprintf(stderr, USAGE_MESSAGE, argv[0]); exit(1);
        }
    }
    if (optind > argc - 0) {
        fprintf(stderr, USAGE_MESSAGE, argv[0]);
        exit(1);
    }
    printf("Using %d points in %d dims for Sobol, Sz4D, SzN\n", N, dims);
    for (int i = 0; i < n; i++) {
        Sequence Sobol = Sequence::Sobol();
        SZ sz4D(2, randomize, validate);
        SZ szN = SZ::makeNested(1, 2, ensemble, randomize);
        printf(
            "%f %f %f\n",
            getDiscrepancy(dims, N, Sobol),
            getDiscrepancy(dims, N, sz4D),
            getDiscrepancy(dims, N, szN)
        );
    }
    clock_t t1 = clock();
    Float totalTime = (Float)(t1 - t0) / CLOCKS_PER_SEC;
    fprintf(
        stderr, "Total time = %.6fs\n", totalTime
    );
}
