#include "mex.h"
#include "matrix.h"
#include <stdint.h>
#include <math.h>
#include <string.h>
#include <time.h>  // ׼Cʱͷļ
#include <intrin.h>   // VSӲλ
#include <stdio.h>

//mwSizejĶתΪָĽ
mwSize binary_to_quaternary_mwsize(mwSize j) {
    if (j == 0) {
        return 0; // 0תΪ0
    }

    // 1ٶλλƽ̨λѡָ
    unsigned long highest_bit;
#ifdef _WIN64
    // 64λƽ̨mwSize64λʹ64λǰ
    unsigned long leading_zeros = __lzcnt64(j);
    highest_bit = 63 - leading_zeros; // 64λΪ63
#else
    // 32λƽ̨mwSize32λʹ32λǰ
    unsigned long leading_zeros = __lzcnt(j);
    highest_bit = 31 - leading_zeros; // 32λΪ31
#endif

    // 2λλĽƽλ˷
    mwSize result = 0;
    for (unsigned long k = highest_bit; k != (unsigned long)-1; --k) {
        // ȡkλ01
        mwSize bit = (j >> k) & 1;
        // ȼ result = result * 4 + bit2λ4
        result = (result << 2) + bit;
    }

    return result;
}

// Ķڻ
int binary_vector_dot_mwsize(mwSize i, mwSize j) {
    // 1. λӦλΪ1λ
    mwSize common_ones = i & j;
    
    // 2. ƽ̨λѡӦӲٺ
#ifdef _WIN64
    // 64λMATLABmwSize64λʹ64λλָ
    return _mm_popcnt_u64((uint64_t)common_ones);
#else
    // 32λMATLABmwSize32λʹ32λλָ
    return _mm_popcnt_u32((uint32_t)common_ones);
#endif
}

// nǷΪ2ݴ
static int is_pow2(int n) {
    return (n > 0) && ((n & (n - 1)) == 0);
}

// Hadamard任ԭز
static void fht(double *x, int n) {
    for (int m = 1; m < n; m *= 2) {
        for (int i = 0; i < n; i += 2 * m) {
            for (int j = 0; j < m; j++) {
                double a = x[i + j];
                double b = x[i + j + m];
                x[i + j] = a + b;
                x[i + j + m] = a - b;
            }
        }
    }
}

// static void get_bandwidth(const mxArray *mat, int *Lb, int *Ub) {
//     mwSize m = mxGetM(mat);
//     mwSize n = mxGetN(mat);
//     *Lb = m;  // ʼΪܵĸֵ
//     *Ub = 0;  // ʼΪСֵܵ
// 
//     if (mxIsSparse(mat)) {
//         mwIndex *irs = mxGetIr(mat);
//         mwIndex *jcs = mxGetJc(mat);
//         mwIndex nnz = mxGetNzmax(mat);
// 
//         for (mwIndex j = 0; j < n; j++) {
//             for (mwIndex k = jcs[j]; k < jcs[j+1]; k++) {
//                 mwIndex i = irs[k];
//                 int di = i - j;
//                 if (di < *Lb) *Lb = di;
//                 if (di > *Ub) *Ub = di;
//             }
//         }
//     } else {
//         const double *pr = mxGetPr(mat);
//         for (mwIndex j = 0; j < n; j++) {
//             for (mwIndex i = 0; i < m; i++) {
//                 if (fabs(pr[j*m + i]) > 1e-13) {
//                     int di = i - j;
//                     if (di < *Lb) *Lb = di;
//                     if (di > *Ub) *Ub = di;
//                 }
//             }
//         }
//     }
// }


static void get_bandwidth(const mxArray *mat, int *Lb, int *Ub) {
    mwSize m = mxGetM(mat);  // ȡ
    mwSize n = mxGetN(mat);  // ȡ

    *Ub = 0;  // ʼϴΪ0

    // ֱֻܾӻȡָ
    const double *pr = mxGetPr(mat);
    
    for (mwIndex j = 0; j < n; j++) {
        // һеǰǷзԪ
        if (fabs(pr[j*m + (m-1)]) > 1e-13) {
            *Ub = m - 1 - j;
            *Lb = *Ub;  // Գƾ´ϴ
            return;
        }

        // 鵱ǰезԪأϴ
        for (mwIndex i = 0; i < m; i++) {
            if (fabs(pr[j*m + i]) > 1e-13) {
                int di = i - j;
                if (di > *Ub) {
                    *Ub = di;
                }
            }
        }
    }

    // ڶԳƾ´ϴ
    *Lb = *Ub;
}

// ʮתĽ
static void t2f(const uint32_t *D, mwSize m, mwSize n, uint8_t *M) {
    uint8_t shifts[16];
    for (mwSize j = 0; j < n; j++) {
        shifts[j] = 2 * (n - 1 - j);
    }
    
    for (mwSize j = 0; j < n; j++) {
        uint8_t shift = shifts[j];
        for (mwSize i = 0; i < m; i++) {
            M[j * m + i] = (uint8_t)((D[i] >> shift) & 0x3);
        }
    }
}

// ʱ룩ĸ
static double get_elapsed_ms(clock_t start, clock_t end) {
    return ((double)(end - start) / CLOCKS_PER_SEC) * 1000.0;
}

// MEXں
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    // 
    if (nrhs != 2) {
        mexErrMsgIdAndTxt("sppaolidec_w1:inputCount", "Ҫ2: Kttype");
    }
    if (nlhs != 2) {
        mexErrMsgIdAndTxt("sppaolidec_w1:outputCount", "Ҫ2: aDD");
    }
    
    if (mxIsComplex(prhs[0])) {
        mexErrMsgIdAndTxt("sppaolidec_w1:complexInput", "ֻ֧ʵ");
    }
    
    const mxArray *Kt = prhs[0];
    mwSize N = mxGetM(Kt);
    mwSize N_cols = mxGetN(Kt);
    if (N_cols != N) {
        mexErrMsgIdAndTxt("sppaolidec_w1:notSquare", "Ƿ");
    }
    
    int type = (int)mxGetScalar(prhs[1]);
    if (type < 0 || type > 1) {
        mexErrMsgIdAndTxt("sppaolidec_w1:invalidType", "typeΪ01");
    }

    // n=log2(N)
    int n = 0;
    mwSize temp = N;
    while (temp > 1) {
        temp >>= 1;
        n++;
    }
    if ((1 << n) != N) {
        mexErrMsgIdAndTxt("sppaolidec_w1:notPowerOf2", "2ݴ");
    }
    
    // 1. ʱ
    // clock_t t_start, t_end;  // ʹñ׼Cclock_t
    int Lb, Ub;
    // t_start = clock();
    get_bandwidth(Kt, &Lb, &Ub);
    // t_end = clock();
    int h = (abs(Lb) > abs(Ub)) ? abs(Lb) : abs(Ub);
    // mexPrintf("[ʱ] ʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 2. K
    // uint8_t *K = (uint8_t *)mxMalloc(N * n * sizeof(uint8_t));
    // for (mwSize s = 0; s < N; s++) {
    //     for (int i = 0; i < n; i++) {
    //         int bit_pos = n - 1 - i;
    //         K[s * n + i] = (s >> bit_pos) & 0x1;   //ȡ
    //     }
    // }

    // 3. NS3󣨼ʱ
    // t_start = clock();
    mwSize *NS3 = (mwSize *)mxMalloc(N * N * sizeof(mwSize));
    for (mwSize i = 0; i < N; i++) {
        for (mwSize j = 0; j < N; j++) {
            // int sum = 0;
            // for (int k = 0; k < n; k++) {
            //     sum += K[i * n + k] * K[j * n + k];
            // }
            NS3[i * N + j] = binary_vector_dot_mwsize(i, j);
        }
    }
    // t_end = clock();
    // mexPrintf("[ʱ] NS3ɺʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 4. n4Kn4
    // uint32_t *n4 = (uint32_t *)mxMalloc(n * sizeof(uint32_t));
    // for (int i = 0; i < n; i++) {
    //     n4[i] = 1U << (2*(n-1-i));
    // }

    mwSize *Kn4 = (mwSize *)mxMalloc(N * sizeof(mwSize));
    for (mwSize s = 0; s < N; s++) {
        Kn4[s] = binary_to_quaternary_mwsize(s);
        // Kn4[s] = 0;
        // for (int i = 0; i < n; i++) {
        //     Kn4[s] += K[s * n + i] * n4[i];
        // }
    }

    // 5. in_k飨ʱ
    // t_start = clock();
    mwSize bi = 0;
    mwSize *in_k = (mwSize *)mxMalloc(N * sizeof(mwSize));
    
    for (int i = -1; i <= (int)n - 1; i++) {
        mwSize pow2i = (i >= 0) ? (1 << i) : 1;
        mwSize pow2j = (i >= 0) ? (1 << i) : 0;
        mwSize start_j = (pow2i > h) ? (pow2j - h + 1) : 1;
        mwSize end_j = pow2i;

        if (end_j < start_j) continue;
        
        for (mwSize j = start_j; j <= end_j; j++) {
            mwSize k = j + pow2j - 1;
            if (k >= N) continue;
            in_k[bi] = k;
            bi++;
            if (bi > N) goto end_in_k_loop;
        }
    }
end_in_k_loop:
    in_k = (mwSize *)mxRealloc(in_k, bi * sizeof(mwSize));
    // t_end = clock();
    // mexPrintf("[ʱ] in_kɺʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 6. DSD󣨼ʱ
    // t_start = clock();
    double *DSD = (double *)mxMalloc(N * bi * sizeof(double));
    const double *Kt_pr = mxGetPr(Kt);
    
    for (mwSize idx = 0; idx < bi; idx++) {
        for (mwSize s = 0; s < N; s++) {

            mwSize k = in_k[idx];
            mwSize id2 = (k ^ s);
            mwSize qd2 = id2 * N + s;
            DSD[idx * N + s] = Kt_pr[qd2];
        }

    }
    // t_end = clock();
    // mexPrintf("[ʱ] DSD󹹽ʱ: %.3f\n", get_elapsed_ms(t_start, t_end));
// 
    // 7. FHT任ϵ㣨ʱ
    // t_start = clock();
    double *a = (double *)mxMalloc(bi * N * sizeof(double));
    uint32_t *DD = (uint32_t *)mxMalloc(bi * N * sizeof(uint32_t));
    
    // DD
    for (mwSize idx = 0; idx < bi; idx++) {
        mwSize k = in_k[idx];
        for (mwSize s = 0; s < N; s++) {
            DD[idx * N + s] = Kn4[s] + 2 * Kn4[k];
        }
    }
    
    // FHTa
    for (mwSize idx = 0; idx < bi; idx++) {
        double *row = &DSD[idx * N];
        fht(row, N);
        
        for (mwSize s = 0; s < N; s++) {
            int exponent = (int)NS3[in_k[idx] * N + s];
            int sign = 1 - 2 * ((exponent >> 1) & 1);
            a[idx * N + s] = row[s] * sign / N;
        }
    }
    // t_end = clock();
    // mexPrintf("[ʱ] FHT任ϵʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 8. ɸѡԪز
    mwSize valid_count = 0;
    for (mwSize i = 0; i < bi * N; i++) {
        if (fabs(a[i]) >= 1e-13) valid_count++;
    }

    plhs[0] = mxCreateDoubleMatrix(1, valid_count, mxREAL);
    double *a_out = mxGetPr(plhs[0]);
    
    mxArray *DD_out;
    if (type == 0) {
        DD_out = mxCreateNumericMatrix(n, valid_count, mxUINT8_CLASS, mxREAL);
    } else {
        DD_out = mxCreateNumericMatrix(1, valid_count, mxUINT32_CLASS, mxREAL);
    }
    plhs[1] = DD_out;
    
    mwSize ptr = 0;
    for (mwSize i = 0; i < bi * N; i++) {
        if (fabs(a[i]) >= 1e-13) {
            a_out[ptr] = a[i];
            
            if (type == 0) {
                uint8_t *dd_ptr = (uint8_t *)mxGetData(DD_out);
                t2f(&DD[i], 1, n, &dd_ptr[ptr * n]);
            } else {
                uint32_t *dd_ptr = (uint32_t *)mxGetData(DD_out);
                dd_ptr[ptr] = DD[i];
            }
            ptr++;
        }
    }

    // ͷڴ
    // mxFree(K);
    mxFree(NS3);
    // mxFree(n4);
    mxFree(Kn4);
    mxFree(in_k);
    mxFree(DSD);
    mxFree(a);
    mxFree(DD);
}
