/* devoicecaptcha.c, jochem, http://vorm.net/captchas */

/* Truncated numeric constants: PI is used by hamming(), EXP as the base
 * passed to pow() in setup_mel(). NOTE(review): exp() from math.h would be
 * more precise, but these values feed the computed band features that the
 * training files were produced with -- verify before replacing them. */
#define PI 3.141592653
#define EXP 2.718281828

#include<string.h>
#include<stdlib.h>
#include<stdio.h>
#include<math.h>
#include<fftw3.h>

/* Per-format analysis settings, selected by sniff_wavfile(). */
typedef struct {
    int samplerate;       /* Hz; passed to setup_mel() as fs */
    int byterate;         /* bytes per sample read by read_blocks(): 1 (raw 0..255) or 2 (16-bit) */
    int winsize;          /* samples per analysis block and FFT length */
    int band_cnt;         /* band values computed per block */
    int word_length;      /* blocks kept around each detected word peak */
    int word_overlap;     /* added to word_length to form the peak-suppression width */
    int threshold_energy; /* minimum last-band value for a block to be taken as a word peak */
    int file_offset;      /* bytes skipped (fseek) before reading samples; reason for values not visible here */
    char trainfile[255];  /* trained-set file loaded by load_trained() */
} configtype;

static configtype msn = {
    .samplerate = 8000, .byterate = 1, .winsize = 256, .band_cnt = 6,
    .word_length = 10, .word_overlap = 8, .threshold_energy = 8000,
    .file_offset = 46, .trainfile = "msn.txt",
};
static configtype google = {
    .samplerate = 8000, .byterate = 2, .winsize = 512, .band_cnt = 12,
    .word_length = 8, .word_overlap = 14, .threshold_energy = 2000000,
    .file_offset = 25500, .trainfile = "google.txt",
};

/* Apply a Hamming window to data[0..winsz) in place:
 *   data[n] *= 0.54 - 0.46 * cos(2*PI*n / winsz)
 * The divisor is winsz (not winsz - 1). */
void hamming(double* data, int winsz)
{
    int n = 0;
    while (n < winsz) {
        const double weight = 0.54 - 0.46 * cos(2.0 * PI * n / winsz);
        data[n] = weight * data[n];
        n++;
    }
}

/* Precompute cnt 'semi-mel' band boundaries, expressed as FFT bin indices.
 *
 * The mel-like range 0 .. 1127.01048*ln(1 + fs/700) is split into cnt equal
 * steps. Step b (1-based) is mapped back to frequency and scaled to a bin:
 *   bands[b-1] = (int)(700*ws*(EXP^(mel_b/1127.01) - 1)) / fs
 * using integer division after the truncating cast.
 * NOTE(review): the two constants differ slightly (1127.01048 vs 1127.01).
 * For the msn/google settings the last boundary appears to come out as
 * exactly ws, which keeps sum_over_bands() within the FFT output. Check this
 * before "fixing" the constants. */
void setup_mel(int* bands, int cnt, int ws, int fs)
{
    const double max_mel = 1127.01048 * log(1.0 + fs / 700.0);

    for (int b = 0; b < cnt; b++) {
        const double mel = ((b + 1) * max_mel) / cnt;
        const double scaled = 700 * ws * (pow(EXP, mel / 1127.01) - 1);
        bands[b] = (int) scaled / fs;
    }
}

/* Sum transform output over the bands computed by setup_mel().
 *
 * data : transform output of one block; read_blocks() passes the FFTW R2HC
 *        result. Index 0 (the DC term) is never summed.
 * bands: cnt exclusive upper bin indices. Band 0 covers [1, bands[0]) and
 *        band i covers [bands[i-1], bands[i]).
 * sum  : receives cnt values.
 *
 * Each bin contributes abs((int) data[j]). The conversion truncates toward
 * zero; the profiles printed with --print (and so the training files) are
 * computed this way.
 * Normalization: band 0 is divided by bands[0], band i by the integer ratio
 * bands[i] / bands[i-1]. A zero divisor leaves the sum unnormalized instead
 * of dividing by zero (integer division by zero is undefined behaviour).
 */
void sum_over_bands(double* data, int* bands, double *sum, int cnt)
{
    int j = 1;
    for (int i = 0; i < cnt; i++) {
        sum[i] = 0;
        while (j < bands[i]) {
            sum[i] += (double) abs((int) data[j]);
            j++;
        }

        int divisor;
        if (i == 0)
            divisor = bands[0];
        else
            divisor = (bands[i - 1] != 0) ? bands[i] / bands[i - 1] : 0;
        if (divisor != 0)
            sum[i] /= divisor;
    }
}

/* qsort comparator for ints, ascending order.
 * Returns (x > y) - (x < y) rather than x - y, because the subtraction
 * overflows (undefined behaviour) for widely separated values such as
 * INT_MIN and INT_MAX. */
int compare (const void * a, const void * b)
{
    const int x = *(const int *) a;
    const int y = *(const int *) b;
    return (x > y) - (x < y);
}

/* Naive peak detection on one band of the parameter matrix, max first.
 *
 * values   : row-major matrix of values_cnt rows by band_cnt columns (as
 *            filled by read_blocks); only column band_idx is examined.
 * peaks    : in/out pointer to a heap array (or NULL). It is realloc'ed to
 *            hold the detected row indices, sorted ascending.
 * width    : after a peak is accepted, rows max_loc - width/2 up to (but
 *            excluding) max_loc + width/2 are suppressed, and so is the peak
 *            row itself.
 * threshold: peaks are accepted while the largest remaining value is
 *            >= threshold. Only strictly positive values are candidates.
 *
 * The search repeatedly takes the largest remaining value (first occurrence
 * on ties) and stops when nothing qualifies. Returns the number of peaks.
 * Exits with status 1 if memory allocation fails.
 */
int detect_peaks(double* values, int values_cnt, int** peaks, int width,
                 int threshold, int band_idx, int band_cnt)
{
    int detected = 0;

    /* Nothing to scan; also avoids malloc(0), which may return NULL */
    if (values_cnt <= 0)
        return 0;

    /* Copy only band_idx data so suppression does not modify values */
    double* band = malloc(values_cnt * sizeof *band);
    if (band == NULL) {
        perror("Out of mem\n");
        exit(1);
    }
    for (int i = 0; i < values_cnt; i++)
        band[i] = values[i * band_cnt + band_idx];

    for (;;) {
        /* Keep the running maximum as a double. Truncating it to int would
         * let a later, smaller value in the same integer interval win. */
        double max = 0.0;
        int max_loc = -1;
        for (int i = 0; i < values_cnt; i++) {
            if (band[i] > max) {
                max = band[i];
                max_loc = i;
            }
        }
        /* max_loc < 0: no positive value remains. This matters when
         * threshold <= 0, where the loop would otherwise never end. */
        if (max_loc < 0 || max < threshold)
            break;

        int* grown = realloc(*peaks, sizeof(int) * (detected + 1));
        if (grown == NULL) {
            perror("Out of mem\n");
            exit(1);
        }
        *peaks = grown;
        (*peaks)[detected++] = max_loc;

        for (int j = max_loc - width / 2; j < max_loc + width / 2; j++)
            if (j >= 0 && j < values_cnt)
                band[j] = 0;
        /* With width < 2 the range above is empty. Clearing the peak itself
         * guarantees progress. */
        band[max_loc] = 0;
    }
    qsort(*peaks, detected, sizeof(int), compare);
    free(band);
    return detected;
}

/* Print usage information to stdout (s is the program name) and terminate
 * with exit status 1. Never returns. */
void show_usage(char *s)
{
    printf("Start with %s name.wav\n"
    "to guess the words in name.wav.\n\n"
    "Start with %s name.wav --print\n"
    "to print a profile for name.wav that can be used as "
    "training data\n", s, s);
    exit(1);
}

/* Lazily pick the configuration for a WAV file without parsing its header.
 *
 * Reads the first 50 bytes of file s and looks only at byte 34. In a
 * canonical RIFF/WAVE header, this is the low byte of the BitsPerSample
 * field. A value of 8 selects msn and 16 selects google; *cfg is set to
 * point at the chosen configuration.
 * Exits with status 1 if the file cannot be opened, is shorter than 50
 * bytes, or has any other value at byte 34.
 */
void sniff_wavfile(char *s, configtype **cfg)
{
    FILE *fp = fopen(s, "rb");
    if (fp == NULL) {
        perror("Could not open input file\n");
        exit(1);
    }
    unsigned char buf[50];
    size_t got = fread(buf, sizeof(buf), 1, fp);
    fclose(fp);

    /* errno carries no useful information for a short file or an unknown
     * format, so these errors are reported without perror(). */
    if (got != 1) {
        fprintf(stderr, "File too small\n");
        exit(1);
    }
    switch (buf[34]) {
    case 8:
        *cfg = &msn;
        break;
    case 16:
        *cfg = &google;
        break;
    default:
        fprintf(stderr, "Unknown input file\n");
        exit(1);
    }
}

/* Read one sample of `byterate` bytes from fp into *out.
 * byterate 1: one byte, stored as its raw value 0..255.
 * byterate 2: signed 16-bit little-endian, decoded byte by byte so the result
 *             does not depend on host byte order.
 * Returns 1 on success, or 0 on EOF, read error, short read or an
 * unsupported byterate. */
static int read_sample(FILE *fp, int byterate, double *out)
{
    if (byterate == 1) {
        int c = fgetc(fp);
        if (c == EOF)
            return 0;
        *out = (double) c;
        return 1;
    }
    if (byterate == 2) {
        unsigned char b[2];
        if (fread(b, 1, sizeof b, fp) < sizeof b)
            return 0;
        long v = (long) b[0] | ((long) b[1] << 8);
        if (v >= 32768)
            v -= 65536;
        *out = (double) v;
        return 1;
    }
    return 0;
}

/* Read the audio of file `in` and compute band parameters per block.
 *
 * Reading starts cfg->file_offset bytes into the file. Samples are grouped
 * into consecutive blocks of cfg->winsize. Each complete block is
 * Hamming-windowed, transformed with an FFTW R2HC plan and summed per band
 * (sum_over_bands). The cfg->band_cnt results per block are appended to
 * *params, giving a row-major array of blocks_read x band_cnt doubles.
 * *params must be NULL or a heap pointer; it is realloc'ed.
 * A trailing partial block is discarded (todo: zero-pad and keep it).
 * Returns the number of complete blocks. Exits with status 1 on errors.
 */
int read_blocks(char* in, double** params, configtype *cfg)
{
    int samples_read = 0;
    int blocks_read = 0;

    /* Any other value would never consume input and so would loop forever */
    if (cfg->byterate != 1 && cfg->byterate != 2) {
        fprintf(stderr, "Unsupported byterate %d\n", cfg->byterate);
        exit(1);
    }

    FILE *fp = fopen(in, "rb");
    if (fp == NULL) {
        perror("Could not open input file\n");
        exit(1);
    }
    if (fseek(fp, cfg->file_offset, SEEK_SET) != 0) {
        perror("Could not seek in input file");
        exit(1);
    }

    /* Setup fftw and mel */
    double* audio_t = malloc(cfg->winsize * sizeof *audio_t);
    double* audio_f = malloc(cfg->winsize * sizeof *audio_f);
    int* bands = malloc(cfg->band_cnt * sizeof *bands);
    if (audio_t == NULL || audio_f == NULL || bands == NULL) {
        perror("Out of mem\n");
        exit(1);
    }
    /* FFTW_ESTIMATE planning does not overwrite the arrays, so planning
     * before they hold data is safe. The plan is reused for every block. */
    fftw_plan p = fftw_plan_r2r_1d(cfg->winsize, audio_t, audio_f, FFTW_R2HC,
                                   FFTW_ESTIMATE);
    setup_mel(bands, cfg->band_cnt, cfg->winsize, cfg->samplerate);

    /* Loop over file. The loop tests the read result rather than feof(),
     * so an EOF return is never stored as a sample and a read error ends
     * the loop. */
    double sample;
    while (read_sample(fp, cfg->byterate, &sample)) {
        audio_t[samples_read++ % cfg->winsize] = sample;
        if (samples_read % cfg->winsize != 0)
            continue;

        double* grown = realloc(*params, (size_t) (blocks_read + 1) *
                                         cfg->band_cnt * sizeof(double));
        if (grown == NULL) {
            perror("Out of mem\n");
            exit(1);
        }
        *params = grown;
        /* hamming() windows audio_t in place. Every entry is overwritten by
         * the next winsize samples before the next transform. */
        hamming(audio_t, cfg->winsize);
        fftw_execute(p);
        sum_over_bands(audio_f, bands, *params + blocks_read * cfg->band_cnt,
                       cfg->band_cnt);
        blocks_read++;
    }

    /* Free/close */
    fclose(fp);
    free(audio_t);
    free(audio_f);
    free(bands);
    fftw_destroy_plan(p);
    return blocks_read;
}

/* Load a trained set from a text file.
 *
 * File format, as parsed here:
 *   - A line containing ':' labels a trained word. The integer at its start
 *     (e.g. "3:") is appended to *words. Such lines not starting with an
 *     integer are ignored.
 *   - A blank (whitespace-only) line is skipped. One example is the empty
 *     line print_params() prints after each word.
 *   - Any other line is one profile row. Its first n_bands whitespace-
 *     separated integers are appended to *profiles, and extra columns are
 *     ignored. Fewer than n_bands values is a fatal error.
 * Lines longer than 254 characters are split by fgets() and parsed as
 * separate lines.
 * *profiles and *words must be NULL or heap pointers; both are realloc'ed.
 *
 * Returns the number of word labels read. NOTE(review): the number of
 * profile rows is neither returned nor checked. closed_match() assumes
 * word_length rows per label.
 */
int load_trained(char* filename, int** profiles, int** words, int n_bands)
{
    static const char seps[] = " \t\r\n";
    char buf[255];
    int n_digits = 0;
    int lines = 0;

    FILE* trained = fopen(filename, "r");
    if (trained == NULL) {
        perror("Could not open trained set\n");
        exit(1);
    }
    while (fgets(buf, sizeof buf, trained) != NULL) {
        if (strchr(buf, ':') != NULL) {
            int digit;
            if (sscanf(buf, "%d", &digit) == 1) {
                int* grown = realloc(*words, (size_t) (n_digits + 1) * sizeof(int));
                if (grown == NULL) {
                    perror("Out of mem\n");
                    exit(1);
                }
                *words = grown;
                (*words)[n_digits++] = digit;
            }
            continue;
        }

        if (buf[strspn(buf, seps)] == '\0')
            continue;

        int* grown = realloc(*profiles,
                             (size_t) (lines + 1) * n_bands * sizeof(int));
        if (grown == NULL) {
            perror("Out of mem\n");
            exit(1);
        }
        *profiles = grown;
        int* row = *profiles + (size_t) lines * n_bands;
        lines++;

        /* Newlines and CRs are separators too, so a row holding exactly
         * n_bands values is accepted with or without trailing whitespace. */
        int cnt = 0;
        for (char* t = strtok(buf, seps); t != NULL && cnt < n_bands;
             t = strtok(NULL, seps))
            row[cnt++] = (int) strtol(t, NULL, 10);
        if (cnt < n_bands) {
            fprintf(stderr, "Not enough columns in input\n");
            exit(1);
        }
    }
    fclose(trained);
    return n_digits;
}

/* Gather the parameter rows around each detected word peak.
 *
 * all  : row-major n x cfg->band_cnt parameters (from read_blocks).
 * peaks: word_cnt peak row indices.
 * p    : output buffer with room for word_cnt * word_length * band_cnt ints.
 *        Word w occupies p[w * word_length * band_cnt ...].
 *
 * For each word, exactly cfg->word_length rows are copied, starting at row
 * peak - word_length/2, with each value truncated to int. Rows before the
 * start of `all` are taken from row 0 and rows past its end from row n-1.
 * If n <= 0 there is no row to copy, and zeros are written.
 */
void separate_word_data(int* p, double* all, int* peaks, int n, int word_cnt,
                        configtype* cfg)
{
    const int bands = cfg->band_cnt;
    const int len = cfg->word_length;
    int* out = p;

    for (int w = 0; w < word_cnt; w++) {
        const int first = peaks[w] - len / 2;
        /* Iterate a fixed word_length rows so every word fills exactly the
         * slot the caller allocated, also for odd word_length. */
        for (int r = 0; r < len; r++) {
            const double* src = NULL;
            if (n > 0) {
                int row = first + r;
                if (row < 0)
                    row = 0;          /* word starts before the first block */
                else if (row >= n)
                    row = n - 1;      /* word ends after the last block */
                src = all + (size_t) row * bands;
            }
            for (int j = 0; j < bands; j++)
                *out++ = src != NULL ? (int) src[j] : 0;
        }
    }
}

/* Print the per-word profiles of the separated digits to stdout.
 * params is a row-major maxx x maxy x maxz int array (as filled by
 * separate_word_data). For each of the maxx words, maxy lines are printed,
 * each holding maxz values as "%d " with a trailing space. An empty line
 * follows each word. */
void print_params(int* params, int maxx, int maxy, int maxz)
{
    const int* cell = params;

    for (int word = 0; word < maxx; word++) {
        for (int row = 0; row < maxy; row++) {
            for (int col = 0; col < maxz; col++)
                printf("%d ", *cell++);
            printf("\n");
        }
        printf("\n");
    }
}

/* Find the trained profile closest to word number word_idx.
 *
 * word_data and trained_data both hold consecutive profiles of
 * cfg->word_length rows x cfg->band_cnt ints. There are trained_sz trained
 * profiles.
 * For profile i the distance is the sum over bands k of
 *     word_length * sum_rows |trained - word| / sum_rows trained
 * using integer division. Profiles whose trained column sum is 0 in any band
 * are skipped.
 * Returns the index of the smallest distance (first one on ties). Returns 0
 * when no profile qualifies, including trained_sz <= 0, so callers must
 * ensure trained_sz > 0 before indexing with the result.
 * Exits with status 1 if memory allocation fails.
 */
int closed_match(int* word_data, int word_idx, int* trained_data,
                  int trained_sz, configtype *cfg)
{
    const int bands = cfg->band_cnt;
    const int len = cfg->word_length;
    const int* word = word_data + (size_t) word_idx * len * bands;
    int pos = 0;
    int found = 0;
    long long best = 0;

    /* 64-bit accumulators: the summed differences of large band energies
     * can exceed INT_MAX. */
    long long* t_band = malloc(sizeof *t_band * bands);
    long long* t_mean = malloc(sizeof *t_mean * bands);
    if (t_mean == NULL || t_band == NULL) {
        perror("Out of mem\n");
        exit(1);
    }
    for (int i = 0; i < trained_sz; i++) {
        const int* profile = trained_data + (size_t) i * len * bands;
        memset(t_band, 0, sizeof *t_band * bands);
        memset(t_mean, 0, sizeof *t_mean * bands);
        for (int j = 0; j < len; j++) {
            for (int k = 0; k < bands; k++) {
                long long orig = profile[j * bands + k];
                long long test = word[j * bands + k];
                t_band[k] += llabs(orig - test);
                t_mean[k] += orig;
            }
        }
        long long sum = 0;
        int borked = 0;
        for (int l = 0; l < bands; l++) {
            if (t_mean[l] != 0)
                sum += (t_band[l] * len) / t_mean[l];
            else
                borked = 1;
        }
        if (!borked && (!found || sum < best)) {
            best = sum;
            pos = i;
            found = 1;
        }
    }
    free(t_mean);
    free(t_band);
    return pos;
}

/* Entry point.
 *   prog name.wav        : detect words in name.wav and print the guessed
 *                          labels, matched against cfg->trainfile.
 *   prog name.wav <arg>  : print the per-word profiles (training data).
 *                          Documented as --print; the value itself is not
 *                          checked.
 */
int main(int argc, char** argv)
{
    if (argc != 2 && argc != 3)
        show_usage(argv[0]);

    /* Sniff filetype and set config for that type */
    configtype *cfg;
    sniff_wavfile(argv[1], &cfg);

    /* Read file and give params per winsize blocks */
    double* params = NULL;
    int blocks_read = read_blocks(argv[1], &params, cfg);

    /* Calculate positions of separate words */
    int* peak_locations = NULL;
    int word_cnt = detect_peaks(params, blocks_read, &peak_locations,
                                cfg->word_length + cfg->word_overlap,
                                cfg->threshold_energy, cfg->band_cnt - 1,
                                cfg->band_cnt);

    /* Only interested in data near peaks, put in words */
    size_t word_ints = (size_t) word_cnt * cfg->word_length * cfg->band_cnt;
    int* words = malloc(word_ints * sizeof *words);
    /* malloc(0) may return NULL when no words were detected; not an error */
    if (words == NULL && word_ints != 0) {
        perror("Out of mem\n");
        exit(1);
    }
    separate_word_data(words, params, peak_locations, blocks_read, word_cnt, cfg);
    free(params);
    free(peak_locations);

    /* Print profile for training or lookup in trained set */
    if (argc == 2) {
        printf("Use training data from: %s\n", cfg->trainfile);
        printf("Detected %d words\n", word_cnt);
        printf("Guess is: ");

        int* trained_params = NULL;
        int* trained_ids = NULL;
        int t_size = load_trained(cfg->trainfile, &trained_params, &trained_ids,
                                  cfg->band_cnt);

        /* With no trained labels there is nothing to match against, and
         * indexing trained_ids would be out of bounds. */
        if (t_size > 0) {
            for (int m = 0; m < word_cnt; m++) {
                int match_idx = closed_match(words, m, trained_params, t_size, cfg);
                printf("%d", trained_ids[match_idx]);
            }
        }
        printf("\n");

        free(trained_params);
        free(trained_ids);
    } else {
        print_params(words, word_cnt, cfg->word_length, cfg->band_cnt);
    }

    free(words);
    return 0;
}
