#include "mex.h"
#include "matrix.h"
#include <stdint.h>
#include <math.h>
#include <string.h>
#include <time.h>  
#include <intrin.h>   
#include <stdio.h>

/* Spread the binary digits of j into base-4 digits: bit k of j becomes the
 * base-4 digit at position k, i.e. it is moved to bit 2k of the result.
 * Examples: 3 (0b11) -> 4 + 1 = 5, 5 (0b101) -> 16 + 1 = 17, 0 -> 0.
 *
 * Only bits k with 2k < width(mwSize) fit in the result; higher bits are
 * dropped, which is exactly what the previous shift-and-accumulate loop did
 * through unsigned wrap-around.
 *
 * Implemented with plain shifts so it neither depends on the LZCNT
 * instruction (MSVC's __lzcnt/__lzcnt64 execute as BSR and return wrong
 * counts on CPUs without LZCNT) nor assumes mwSize is 32-bit whenever
 * _WIN64 is undefined (it is 64-bit on 64-bit Linux/macOS).
 */
mwSize binary_to_quaternary_mwsize(mwSize j) {
    const unsigned half_width = (unsigned)(sizeof(mwSize) * 8 / 2);
    mwSize result = 0;

    for (unsigned k = 0; k < half_width && (j >> k) != 0; k++) {
        if ((j >> k) & 1) {
            result |= (mwSize)1 << (2 * k);
        }
    }

    return result;
}

/* Integer dot product of two bit vectors: the number of bit positions where
 * both i and j are 1, i.e. popcount(i & j).
 *
 * Counts bits portably over the full width of mwSize (clear-lowest-set-bit
 * loop). The POPCNT intrinsics used before need a CPU with POPCNT support,
 * and their non-_WIN64 branch truncated the operand to 32 bits even where
 * mwSize is 64-bit.
 */
int binary_vector_dot_mwsize(mwSize i, mwSize j) {
    unsigned long long common_ones = (unsigned long long)(i & j);
    int count = 0;

    while (common_ones != 0) {
        common_ones &= common_ones - 1;  /* clear the lowest set bit */
        count++;
    }

    return count;
}

/* Return 1 if n is a positive power of two (1, 2, 4, ...), otherwise 0. */
static int is_pow2(int n) {
    if (n <= 0) {
        return 0;
    }
    /* A power of two has a single set bit, so clearing the lowest set bit
     * leaves zero. */
    return !(n & (n - 1));
}

/* In-place, unnormalised fast Walsh-Hadamard transform of x[0..n-1].
 *
 * Each stage combines pairs (u, v) that are `half` apart into (u + v, u - v),
 * with half = 1, 2, 4, ... while half < n. Real-valued only. n is expected to
 * be a power of two; for other lengths the last stage reads and writes past
 * x[n-1].
 */
static void fht(double *x, int n) {
    int half = 1;
    while (half < n) {
        const int span = half * 2;
        for (int base = 0; base < n; base += span) {
            double *lo = x + base;
            double *hi = lo + half;
            for (int t = 0; t < half; t++) {
                const double u = lo[t];
                const double v = hi[t];
                lo[t] = u + v;
                hi[t] = u - v;
            }
        }
        half = span;
    }
}

/* Fold one diagonal offset d = i - j into the running range [*lo, *hi]. */
static void fold_offset(int d, int *lo, int *hi) {
    if (d < *lo) {
        *lo = d;
    }
    if (d > *hi) {
        *hi = d;
    }
}

/* Compute the diagonal extent of the nonzero pattern of a 2-D matrix.
 *
 * For every entry A(i, j) (0-based row i, column j) with |A(i, j)| > 1e-13,
 * the offset d = i - j is folded into
 *     *Lb = min(*Lb, d),   *Ub = max(*Ub, d),
 * starting from *Lb = m and *Ub = 0. So *Ub is never negative, and a matrix
 * with no such entry yields *Lb = m, *Ub = 0.
 *
 * d > 0 lies below the diagonal and d < 0 above it. The extreme offsets are
 * d = 1 - n (top-right corner, the smallest) and d = m - 1 (bottom-left
 * corner, the largest). The scan stops early once both are reached, since
 * no further entry can change the result.
 *
 * Full and sparse storage are both supported. For sparse arrays mxGetPr
 * holds only the stored values, so they are located through Ir/Jc instead
 * of dense j*m + i indexing.
 * Assumes a real double matrix.
 */
static void get_bandwidth(const mxArray *mat, int *Lb, int *Ub) {
    const double tol = 1e-13;  /* entries with |x| <= tol count as zero */
    const mwSize m = mxGetM(mat);
    const mwSize n = mxGetN(mat);
    int lo = (int)m;
    int hi = 0;

    if (m > 0 && n > 0) {
        const double *pr = mxGetPr(mat);

        if (mxIsSparse(mat)) {
            const mwIndex *ir = mxGetIr(mat);
            const mwIndex *jc = mxGetJc(mat);
            for (mwIndex j = 0; j < n; j++) {
                for (mwIndex k = jc[j]; k < jc[j + 1]; k++) {
                    if (fabs(pr[k]) > tol) {
                        fold_offset((int)ir[k] - (int)j, &lo, &hi);
                    }
                }
            }
        } else {
            const int lo_min = 1 - (int)n;  /* offset of A(0, n-1) */
            const int hi_max = (int)m - 1;  /* offset of A(m-1, 0) */

            /* Check the two corners first: if both are nonzero the result
             * is already final and the full scan below exits immediately. */
            if (fabs(pr[(n - 1) * m]) > tol) {
                fold_offset(lo_min, &lo, &hi);
            }
            if (fabs(pr[m - 1]) > tol) {
                fold_offset(hi_max, &lo, &hi);
            }

            for (mwIndex j = 0; j < n; j++) {
                if (lo == lo_min && hi == hi_max) {
                    break;  /* both extremes reached; nothing can change */
                }
                const double *col = pr + j * m;
                for (mwIndex i = 0; i < m; i++) {
                    if (fabs(col[i]) > tol) {
                        /* Signed subtraction: mwIndex is unsigned. */
                        fold_offset((int)i - (int)j, &lo, &hi);
                    }
                }
            }
        }
    }

    *Lb = lo;
    *Ub = hi;
}



/* Expand packed base-4 codes into one digit per byte.
 *
 * D holds m codes. Each code packs n base-4 digits, two bits per digit, with
 * digit 0 the most significant:
 *     digit j of D[i] = (D[i] >> 2*(n-1-j)) & 3.
 * The result is written column-major as an m x n array: M[j*m + i].
 *
 * A uint32_t holds at most 16 digits. Positions whose shift would be 32 or
 * more (n > 16) are written as 0, the value those digits have for any
 * 32-bit code. This replaces the previous fixed 16-entry shift table, which
 * overflowed the stack for n > 16, and avoids undefined oversized shifts.
 */
static void t2f(const uint32_t *D, mwSize m, mwSize n, uint8_t *M) {
    for (mwSize j = 0; j < n; j++) {
        const mwSize shift = 2 * (n - 1 - j);
        uint8_t *out = M + j * m;

        if (shift >= 32) {
            memset(out, 0, (size_t)m * sizeof *out);
            continue;
        }
        for (mwSize i = 0; i < m; i++) {
            out[i] = (uint8_t)((D[i] >> shift) & 0x3u);
        }
    }
}

/* Convert the span between two clock() readings into milliseconds. */
static double get_elapsed_ms(clock_t start, clock_t end) {
    const clock_t ticks = end - start;
    const double seconds = (double)ticks / CLOCKS_PER_SEC;
    return seconds * 1000.0;
}

// MEXں
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    // 
    if (nrhs != 2) {
        mexErrMsgIdAndTxt("sppaolidec_w1:inputCount", "Ҫ2: Kttype");
    }
    if (nlhs != 2) {
        mexErrMsgIdAndTxt("sppaolidec_w1:outputCount", "Ҫ2: aDD");
    }
    
    if (mxIsComplex(prhs[0])) {
        mexErrMsgIdAndTxt("sppaolidec_w1:complexInput", "ֻ֧ʵ");
    }
    
    const mxArray *Kt = prhs[0];
    mwSize N = mxGetM(Kt);
    mwSize N_cols = mxGetN(Kt);
    if (N_cols != N) {
        mexErrMsgIdAndTxt("sppaolidec_w1:notSquare", "Ƿ");
    }
    
    int type = (int)mxGetScalar(prhs[1]);
    if (type < 0 || type > 1) {
        mexErrMsgIdAndTxt("sppaolidec_w1:invalidType", "typeΪ01");
    }

    // n=log2(N)
    int n = 0;
    mwSize temp = N;
    while (temp > 1) {
        temp >>= 1;
        n++;
    }
    if ((1 << n) != N) {
        mexErrMsgIdAndTxt("sppaolidec_w1:notPowerOf2", "2ݴ");
    }
    
    // 1. ʱ
    // clock_t t_start, t_end;  
    int Lb, Ub;
    // t_start = clock();
    get_bandwidth(Kt, &Lb, &Ub);
    // t_end = clock();
    int h = (abs(Lb) > abs(Ub)) ? abs(Lb) : abs(Ub);
    // mexPrintf("[ʱ] ʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 3. NS3󣨼ʱ
    // t_start = clock();
    mwSize *NS3 = (mwSize *)mxMalloc(N * N * sizeof(mwSize));
    for (mwSize i = 0; i < N; i++) {
        for (mwSize j = 0; j < N; j++) {
            NS3[i * N + j] = binary_vector_dot_mwsize(i, j);
        }
    }
    // t_end = clock();
    // mexPrintf("[ʱ] NS3ɺʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 4. Kn4
    mwSize *Kn4 = (mwSize *)mxMalloc(N * sizeof(mwSize));
    for (mwSize s = 0; s < N; s++) {
        Kn4[s] = binary_to_quaternary_mwsize(s);
    }

    // 5. in_k飨ʱ
    // t_start = clock();
    mwSize bi = 0;
    mwSize *in_k = (mwSize *)mxMalloc(N * sizeof(mwSize));
    
    for (int i = -1; i <= (int)n - 1; i++) {
        mwSize pow2i = (i >= 0) ? (1 << i) : 1;
        mwSize pow2j = (i >= 0) ? (1 << i) : 0;
        mwSize start_j = (pow2i > h) ? (pow2j - h + 1) : 1;
        mwSize end_j = pow2i;

        if (end_j < start_j) continue;
        
        for (mwSize j = start_j; j <= end_j; j++) {
            mwSize k = j + pow2j - 1;
            if (k >= N) continue;
            in_k[bi] = k;
            bi++;
            if (bi > N) goto end_in_k_loop;
        }
    }
end_in_k_loop:
    in_k = (mwSize *)mxRealloc(in_k, bi * sizeof(mwSize));
    // t_end = clock();
    // mexPrintf("[ʱ] in_kɺʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 6. DSD󣨼ʱ
    // t_start = clock();
    double *DSD = (double *)mxMalloc(N * bi * sizeof(double));
    const double *Kt_pr = mxGetPr(Kt);
    
    for (mwSize idx = 0; idx < bi; idx++) {
        for (mwSize s = 0; s < N; s++) {
            mwSize k = in_k[idx];
            mwSize id2 = (k ^ s);
            mwSize qd2 = id2 * N + s;
            DSD[idx * N + s] = Kt_pr[qd2];
        }
    }
    // t_end = clock();
    // mexPrintf("[ʱ] DSD󹹽ʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 7. FHT任ϵ㣨ʱ
    // t_start = clock();
    // ؼ޸1Ϊaʵ鲿ڴ棨鲿ֱӸֵɣ
    double *a_re = (double *)mxMalloc(bi * N * sizeof(double));  // aʵ
    double *a_im = (double *)mxMalloc(bi * N * sizeof(double));  // a鲿ʼΪ0չΪֵ
    uint32_t *DD = (uint32_t *)mxMalloc(bi * N * sizeof(uint32_t));
    
    // DD
    for (mwSize idx = 0; idx < bi; idx++) {
        mwSize k = in_k[idx];
        for (mwSize s = 0; s < N; s++) {
            DD[idx * N + s] = Kn4[s] + 2 * Kn4[k];
        }
    }
    
    // FHTaʵԭ߼鲿Ϊ0ɸʵ޸ģ
    for (mwSize idx = 0; idx < bi; idx++) {
        double *row = &DSD[idx * N];
        fht(row, N);  // ԭʵFHT踴FHTչú
        
        for (mwSize s = 0; s < N; s++) {
            int exponent = (int)NS3[in_k[idx] * N + s];
            int sign = 1 - 2 * ((exponent >> 1) & 1);
            int y = (exponent & 1); 
            a_re[idx * N + s] =  row[s] * (sign * (1-y)) / N;  // ʵ
            a_im[idx * N + s] = row[s] * (sign * y) / N;      // 鲿ʼΪ0滻Ϊʵʼ߼
        }
    }
    // t_end = clock();
    // mexPrintf("[ʱ] FHT任ϵʱ: %.3f\n", get_elapsed_ms(t_start, t_end));

    // 8. ɸѡԪز临a
    // ؼ޸2жϣģƽ⿪ȶӦ1e-13
    mwSize valid_count = 0;
    for (mwSize i = 0; i < bi * N; i++) {
        if (a_re[i]*a_re[i] + a_im[i]*a_im[i] >= 1e-26) {  // ģ1e-13Ϊ
            valid_count++;
        }
    }

    // ؼ޸3mxCOMPLEX־
    plhs[0] = mxCreateDoubleMatrix(1, valid_count, mxCOMPLEX);
    double *a_out_re = mxGetPr(plhs[0]);  // aʵָ
    double *a_out_im = mxGetPi(plhs[0]);  // a鲿ָ루MATLABԶڴ棩
    
    mxArray *DD_out;
    if (type == 0) {
        DD_out = mxCreateNumericMatrix(n, valid_count, mxUINT8_CLASS, mxREAL);
    } else {
        DD_out = mxCreateNumericMatrix(1, valid_count, mxUINT32_CLASS, mxREAL);
    }
    plhs[1] = DD_out;
    
    mwSize ptr = 0;
    for (mwSize i = 0; i < bi * N; i++) {
        if (a_re[i]*a_re[i] + a_im[i]*a_im[i] >= 1e-26) {
            // ʵ鲿
            a_out_re[ptr] = a_re[i];
            a_out_im[ptr] = a_im[i];
            
            if (type == 0) {
                uint8_t *dd_ptr = (uint8_t *)mxGetData(DD_out);
                t2f(&DD[i], 1, n, &dd_ptr[ptr * n]);
            } else {
                uint32_t *dd_ptr = (uint32_t *)mxGetData(DD_out);
                dd_ptr[ptr] = DD[i];
            }
            ptr++;
        }
    }

    // ͷڴ棨a_rea_im
    mxFree(NS3);
    mxFree(Kn4);
    mxFree(in_k);
    mxFree(DSD);
    mxFree(a_re);  // ͷʵڴ
    mxFree(a_im);  // ͷ鲿ڴ
    mxFree(DD);
}