#!/usr/bin/env python
import csv
import random
import string

# -------------------------------------------------------------------------------
#  Constants
# -------------------------------------------------------------------------------

# Nucleotide alphabet used for every random sequence.
bases = list('ACGT')

# Quality characters chr(35)..chr(74), i.e. '#' through 'J'
# (presumably Phred+33 encoding -- confirm against the downstream consumer).
qual = list(map(chr, range(35, 75)))

# -------------------------------------------------------------------------------
#  Utility functions
# -------------------------------------------------------------------------------
def generate_wells():
    '''
    Return the 96 well names of an 8-row (A-H) x 12-column plate, row-major:
    'A01', 'A02', ..., 'A12', 'B01', ..., 'H12'.
    '''
    rows = [chr(code) for code in range(ord('A'), ord('I'))]
    # The return must come after *all* rows; returning inside the row loop
    # would yield only the 12 wells of row A.
    return ['{}{:02d}'.format(row, col) for row in rows for col in range(1, 13)]


def generate_brdn(n):
    '''Return the identifier 'BRDN' followed by ``n`` zero-padded to 10 digits.'''
    return 'BRDN{:010d}'.format(n)


def read_barcodes(filename):
    '''
    Return the first column of every non-empty row of the CSV file ``filename``.

    No header handling is done: a header row, if present, is returned as data.
    '''
    # 'rU' mode is deprecated in Python 3 and rejected by open() from 3.11 on;
    # plain text mode already performs universal-newline translation.
    with open(filename, 'r') as fh:
        reader = csv.reader(fh, delimiter=',')
        # csv.reader yields [] for blank lines; skip them instead of crashing
        # on row[0].
        return [row[0] for row in reader if row]
    

def random_string(alphabet, length):
    '''
    Return a string of ``length`` characters, each drawn independently and
    uniformly (with replacement) from ``alphabet``.

    A non-positive ``length`` yields the empty string.
    '''
    # str.join works on both Python 2 and 3; string.join was removed in Python 3.
    return ''.join(random.choice(alphabet) for _ in range(length))


def random_seq(length):
    '''Return ``length`` random bases drawn uniformly from ``bases``.'''
    return random_string(alphabet=bases, length=length)


def random_length_seq(min_length, max_length):
    '''Return a random base sequence whose length is uniform over [min_length, max_length], inclusive.'''
    return random_seq(random.randint(min_length, max_length))


def random_qual(length):
    '''Return a string of ``length`` quality characters drawn uniformly from ``qual``.'''
    return random_string(alphabet=qual, length=length)


def fastq_read_id(number):
    '''Return a FASTQ header line (leading '@') ending in '#' followed by ``number``.'''
    header_prefix = '@HWUSI-EAS100R:6:23:398:3989#'
    return ''.join((header_prefix, str(number)))


def gaussian_qual(length, mean, var):
    '''
    Return a list of ``length`` integer quality codes.

    Each code is drawn from random.gauss(mean, var), rounded to the nearest
    integer, and clamped to [35, 74]. Note that despite its name, ``var`` is
    passed to random.gauss as the *standard deviation*.
    '''
    floor_code, ceiling_code = 35, 74
    codes = []
    for _ in range(length):
        code = int(round(random.gauss(mean, var)))
        codes.append(min(ceiling_code, max(floor_code, code)))
    return codes


def random_barcodes(length, number):
    '''
    Return ``number`` distinct random barcodes of ``length`` bases, sorted.

    Raises:
        ValueError: if ``number`` exceeds the count of distinct barcodes of
            that length (len(bases) ** length), which would otherwise make
            the rejection loop below run forever.
    '''
    possible = len(bases) ** max(length, 0)
    if number > possible:
        raise ValueError('cannot draw %d distinct barcodes of length %d '
                         '(only %d exist)' % (number, length, possible))

    barcodes = set()
    # Draw until enough distinct barcodes are collected; duplicates are discarded.
    while len(barcodes) < number:
        barcodes.add(random_seq(length))
    return sorted(barcodes)


def maybe_mutate_seq(seq, quals):
    '''
    Return ``seq`` with each base whose quality code is below 36 replaced by
    'N' with probability 0.7; other bases are kept unchanged.

    ``quals`` is indexed position-by-position and must be at least as long
    as ``seq``.
    '''
    # random.random() is only consumed for low-quality positions (short-circuit),
    # and str.join replaces string.join, which does not exist in Python 3.
    return ''.join('N' if quals[i] < 36 and random.random() < 0.7 else base
                   for i, base in enumerate(seq))


def single_fastq(barcode, seq, stuffer, qual_mean, qual_stddev, number):
    '''
    Build one (read id, sequence, quality code list) FASTQ tuple.

    The read is barcode + stuffer + seq. Per-base qualities come from
    gaussian_qual(len(read), qual_mean, qual_stddev), and bases may then be
    replaced with 'N' by maybe_mutate_seq according to those qualities.
    '''
    full_read = ''.join((barcode, stuffer, seq))
    scores = gaussian_qual(len(full_read), qual_mean, qual_stddev)
    return fastq_read_id(number), maybe_mutate_seq(full_read, scores), scores


def print_fastq_record(file, fastq):
    '''
    Write one four-line FASTQ record to the open text stream ``file``.

    ``fastq`` is a (readid, seq, quality) tuple in which ``quality`` is a
    sequence of integer character codes (as produced by gaussian_qual); each
    code is written as chr(code).
    '''
    (readid, seq, quality) = fastq
    # str.join replaces string.join, which does not exist in Python 3.
    quality_line = ''.join(chr(q) for q in quality)
    file.write(readid + '\n')
    file.write(seq + '\n')
    file.write('+\n')
    file.write(quality_line + '\n')


def print_fastq_record2(file, fastq):
    '''
    Write one four-line FASTQ record whose quality is already an encoded string.

    ``fastq`` is a (readid, seq, quality) tuple of strings.
    '''
    readid, seq, quality = fastq
    file.write('\n'.join([readid, seq, '+', quality]) + '\n')


def write_csv(filename, data):
    '''
    Write each (barcode, payload) string pair in ``data`` as a "barcode,payload"
    line to ``filename``, overwriting it. No CSV quoting is applied.
    '''
    with open(filename, 'w') as out:
        out.writelines(','.join((barcode, payload)) + '\n'
                       for barcode, payload in data)


def print_scenario1_fastq_file(filename, data):
    '''
    Write one FASTQ record per (barcode, seq) pair in ``data`` to ``filename``.

    Each read is barcode + a stuffer shared by all reads (17 random bases +
    'CACCG') + seq, with Gaussian qualities (mean 55, sd 5). Record ids are
    numbered from 1.
    '''
    shared_stuffer = random_seq(17) + 'CACCG'
    with open(filename, 'w') as out:
        record_no = 0
        for barcode, seq in data:
            record_no += 1
            record = single_fastq(barcode, seq, shared_stuffer, 55, 5, record_no)
            print_fastq_record(out, record)


def print_scenario2_fastq_file(filename, data):
    '''
    Write one staggered FASTQ record per (barcode, seq) pair in ``data``.

    Each read is barcode + stagger + shared stuffer (10 random bases + 'CACCG')
    + seq + fill, where the stagger is 0-7 random bases and the fill is
    7 - len(stagger) random bases, so stagger and fill always total 7 bases.
    Qualities are Gaussian (mean 55, sd 5); record ids are numbered from 1.
    '''
    shared_stuffer = random_seq(10) + 'CACCG'
    with open(filename, 'w') as out:
        for record_no, (barcode, seq) in enumerate(data, 1):
            stagger = random_length_seq(0, 7)
            fill = random_seq(7 - len(stagger))
            record = single_fastq(barcode, seq + fill, stagger + shared_stuffer,
                                  55, 5, record_no)
            print_fastq_record(out, record)

            
def print_scenario3_fastq_file(filename1, filename2, data):
    '''
    Write a construct FASTQ (``filename1``) and a barcode FASTQ (``filename2``)
    with one record each per (barcode, seq) pair in ``data``; record ids match
    across the two files and are numbered from 1.

    Construct reads are a shared stuffer (24 random bases + 'CACCG') + seq;
    barcode reads are the barcode alone. Qualities are Gaussian (mean 55, sd 5).
    '''
    with open(filename1, 'w') as construct_out:
        shared_stuffer = random_seq(24) + 'CACCG'
        for record_no, (_, seq) in enumerate(data, 1):
            record = single_fastq('', seq, shared_stuffer, 55, 5, record_no)
            print_fastq_record(construct_out, record)

    with open(filename2, 'w') as barcode_out:
        for record_no, (barcode, _) in enumerate(data, 1):
            record = single_fastq(barcode, '', '', 55, 5, record_no)
            print_fastq_record(barcode_out, record)


def print_scenario4_fastq_file(filename1, filename2, data):
    '''
    Write a staggered construct FASTQ (``filename1``) and a barcode FASTQ
    (``filename2``) with one record each per (barcode, seq) pair in ``data``;
    record ids match across the two files and are numbered from 1.

    Construct reads are stagger + shared stuffer (18 random bases + 'CACCG')
    + seq + fill, where the stagger is 0-7 random bases and the fill is
    7 - len(stagger) random bases. Barcode reads are the barcode alone.
    Qualities are Gaussian (mean 55, sd 5).
    '''
    with open(filename1, 'w') as construct_out:
        shared_stuffer = random_seq(18) + 'CACCG'
        for record_no, (_, seq) in enumerate(data, 1):
            stagger = random_length_seq(0, 7)
            fill = random_seq(7 - len(stagger))
            record = single_fastq('', seq + fill, stagger + shared_stuffer,
                                  55, 5, record_no)
            print_fastq_record(construct_out, record)

    with open(filename2, 'w') as barcode_out:
        for record_no, (barcode, _) in enumerate(data, 1):
            record = single_fastq(barcode, '', '', 55, 5, record_no)
            print_fastq_record(barcode_out, record)


def print_template_setup(cond, ref, fq1, fq2, matrix, min_stagger, max_stagger, ref_barcode_length, stuffer_length, max_trailing_bases):
    '''
    Generate a synthetic two-FASTQ fixture together with its expected count matrix.

    Writes five files:
      cond   -- CSV of (sample barcode, well) pairs: 6 random 8-base barcodes,
                each assigned a random well (wells may repeat).
      ref    -- CSV of ('r1;r2', BRDN id) pairs built from two pools of random
                ``ref_barcode_length``-base barcodes, sorted.
      fq1    -- FASTQ whose reads are the sample barcodes, with random qualities.
      fq2    -- FASTQ whose reads have the layout
                  stagger + 'CACCG' + r1 + stuffer + 'TTACA' + r2 + trailing
                where the stagger is min_stagger..max_stagger random bases and
                the stuffer is ``stuffer_length`` random bases.
      matrix -- tab-separated expected counts: one row per reference entry and
                one column per distinct well id, in order of first appearance
                in the sorted conditions list.

    2000 reads are simulated; fq1 and fq2 records share read ids. Every fq2
    read is padded with random trailing bases up to the fixed length
      min_stagger + 5 + ref_barcode_length + stuffer_length + 5
        + ref_barcode_length + max_trailing_bases
    so a read with stagger s carries max_trailing_bases - (s - min_stagger)
    trailing bases (never fewer than zero).

    Raises:
        ValueError: if the reference barcode pool cannot be split into two
            non-empty halves.
    '''
    upstream_const = 'CACCG'
    mid_const = 'TTACA'

    # Fixed fq2 read length. It must include both 5-base constant sequences;
    # without them the padding below is negative and reads outgrow their
    # quality strings.
    read_length = (min_stagger + len(upstream_const) + ref_barcode_length
                   + stuffer_length + len(mid_const) + ref_barcode_length
                   + max_trailing_bases)

    def exactly_once_in_order(data):
        '''Return the distinct second elements of ``data`` pairs, first occurrence order.'''
        seen = set()
        ids = []
        for (_, x) in data:
            if x not in seen:
                seen.add(x)
                ids.append(x)
        return ids

    # generate random sample barcodes
    barcodes = random_barcodes(8, 6)
    wells = generate_wells()

    # generate conditions (sample barcode -> random well) and print them
    conditions = sorted((b, random.choice(wells)) for b in barcodes)
    write_csv(cond, conditions)

    # generate the reference and print it
    rs = random_barcodes(ref_barcode_length, 32)
    # NOTE(review): the pool is split at ref_barcode_length rather than at
    # len(rs) // 2, so the halves are equal only when ref_barcode_length == 16
    # -- confirm this is intended.
    reference1 = rs[:ref_barcode_length]
    reference2 = rs[ref_barcode_length:]
    if not reference1 or not reference2:
        raise ValueError('reference barcode pool of %d cannot be split at %d'
                         % (len(rs), ref_barcode_length))
    pairs = [(r1, r2) for r1 in reference1 for r2 in reference2]
    reference = sorted((r1 + ';' + r2, generate_brdn(n))
                       for n, (r1, r2) in enumerate(pairs, 1))
    write_csv(ref, reference)

    col_ids = exactly_once_in_order(conditions)

    # now generate scores and fastq files
    scores = {}
    with open(fq1, 'w') as fastq1, open(fq2, 'w') as fastq2:
        for i in range(0, 2000):
            (col_barcode, col_id) = random.choice(conditions)
            r1 = random.choice(reference1)
            r2 = random.choice(reference2)
            row_barcode = r1 + ';' + r2

            # count this (well id, construct) pair for the expected matrix
            scores[(col_id, row_barcode)] = scores.get((col_id, row_barcode), 0) + 1

            readid = fastq_read_id(i)

            # sample read: just the sample barcode
            col_qual = random_qual(len(col_barcode))
            print_fastq_record2(fastq1, (readid, col_barcode, col_qual))

            # construct read: stagger, constants, both reference barcodes, stuffer
            seq = (random_length_seq(min_stagger, max_stagger) + upstream_const
                   + r1 + random_seq(stuffer_length) + mid_const + r2)
            seq += random_seq(max(0, read_length - len(seq)))
            # size the quality from the actual read so sequence and quality
            # lengths always agree
            row_qual = random_qual(len(seq))
            print_fastq_record2(fastq2, (readid, seq, row_qual))

    # use the dict to write a scores file
    with open(matrix, 'w') as out:
        header = ['Construct Barcode', 'Construct IDs'] + col_ids
        out.write('\t'.join(header) + '\n')

        for (rbc, rid) in reference:
            row = [rbc, rid]
            sc = [str(scores.get((cid, rbc), 0)) for cid in col_ids]
            out.write('\t'.join(row + sc) + '\n')


def print_papi_setup(cond, ref, fq1, fq2, matrix):
    '''
    Write the "papi" fixture files via print_template_setup: stagger of 12-15
    bases, 16-base reference barcodes, a 9-base stuffer and at most 8 trailing
    bases.
    '''
    print_template_setup(cond, ref, fq1, fq2, matrix,
                         min_stagger=12, max_stagger=15,
                         ref_barcode_length=16, stuffer_length=9,
                         max_trailing_bases=8)


def print_long_template_setup(cond, ref, fq1, fq2, matrix):
    '''
    Write the "long-template" fixture files via print_template_setup: stagger
    of 12-15 bases, 20-base reference barcodes, a 189-base stuffer and at most
    8 trailing bases.
    '''
    print_template_setup(cond, ref, fq1, fq2, matrix,
                         min_stagger=12, max_stagger=15,
                         ref_barcode_length=20, stuffer_length=189,
                         max_trailing_bases=8)


# -------------------------------------------------------------------------------
#  Main
# -------------------------------------------------------------------------------

if __name__ == '__main__':
    import argparse

    def load_sample_data():
        '''Pair 1000 random (Conditions.csv barcode, Reference.csv entry) picks.'''
        barcodes = read_barcodes('Conditions.csv')
        seqs = read_barcodes('Reference.csv')
        return [(random.choice(barcodes), random.choice(seqs)) for _ in range(0, 1000)]

    # Scenario name -> (needs sample data from Conditions.csv/Reference.csv, writer).
    scenarios = {
        'scenario1': (True, lambda data: print_scenario1_fastq_file(
            'scenario1/scenario1.fastq', data)),
        'scenario2': (True, lambda data: print_scenario2_fastq_file(
            'scenario2/scenario2.fastq', data)),
        'scenario3': (True, lambda data: print_scenario3_fastq_file(
            'scenario3/scenario3.1.fastq',
            'scenario3/scenario3.barcode_1.fastq',
            data)),
        'scenario4': (True, lambda data: print_scenario4_fastq_file(
            'bits/scenario4.1.fastq',
            'bits/scenario4.barcode_1.fastq',
            data)),
        'papi': (False, lambda data: print_papi_setup(
            'papi/conditions.csv',
            'papi/reference.csv',
            'papi/papi.barcode_1.fastq',
            'papi/papi.fastq',
            'papi/expected-scores.txt')),
        'long-template': (False, lambda data: print_long_template_setup(
            'long-template/conditions.csv',
            'long-template/reference.csv',
            'long-template/long-template.barcode_1.fastq',
            'long-template/long-template.fastq',
            'long-template/expected-scores.txt')),
    }

    parser = argparse.ArgumentParser(
        description='Generate synthetic FASTQ/CSV test fixtures.')
    # No `choices=` here: argparse validates an empty nargs='*' positional
    # against choices on some Python versions; names are checked manually below.
    parser.add_argument(
        'scenarios', nargs='*', metavar='SCENARIO',
        help='scenarios to generate (default: long-template); one of: '
             + ', '.join(sorted(scenarios)))
    args = parser.parse_args()

    selected = args.scenarios or ['long-template']
    unknown = [name for name in selected if name not in scenarios]
    if unknown:
        parser.error('unknown scenario(s): ' + ', '.join(unknown))

    # Only read Conditions.csv / Reference.csv when a selected scenario uses them.
    data = None
    if any(scenarios[name][0] for name in selected):
        data = load_sample_data()

    for name in selected:
        scenarios[name][1](data)
