import oracle.jdbc.driver
from Normalization import initializeNPanel
import InfectionEfficiency
import MainHandle
import PlateSet
import java.awt.Color as Color
import java.sql.Types as Types
import java.sql
import time

# Shared application-state object exported by MainHandle. loadPlateSetFromDb
# reads its inputs (connection, selected batch IDs, UI widgets) from fields of
# this object and publishes its results back to it.
masterObject = getattr(MainHandle, 'masterObject')

# SQL used by loadPlateSetFromDb. Strings containing '%s' are filled with the
# ', '-separated batch ID list before being prepared; the others take JDBC '?'
# parameters. The '%f'/'%d' literals are SQL text (score format strings), not
# Python format directives, and those strings are never %-formatted.
_PLATE_QUERY = '''select distinct i.plateid, s.screenid, s.screenname, i.infectionbatchid,
        ib.infectionbatchname, p.screenerplatename, p.barcode, p.plateformat,
        ib.selection, i.selectionstatus, i.replicate, i.virusplateid
        from infection i, infectionbatch ib, screen s, plate p, assayreadout ar
        where i.infectionbatchid in (%s)
        and i.infectionbatchid=ib.infectionbatchid
        and ib.screenid=s.screenid
        and p.plateid=i.plateid
        and i.infectionid = ar.infectionid'''

_CONDITION_QUERY = '''
select infectioncondition, assaydescription,
decode(infectioncondition,
null,
assaydescription,
infectioncondition || decode(assaydescription, null, '', '; ' || assaydescription))
from assayreadout ar
join infection i using (infectionid)
join assay a using (assayid)
where ar.assayreadoutid = ?
'''

_FEATURE_LIST_QUERY = """
select
replace(assayFeatureName, ':', '_') || ':' || decode(scoreType, 'DECIMAL', '%f', 'INTEGER', '%d')
from assayreadout ar
join assay a using (assayid)
join assayfeature af using (assayid)
where ar.assayReadoutId = ?
order by af.assayfeatureindex
"""

_DEFAULT_FEATURE_QUERY = """
select replace(assayFeatureName, ':', '_') || ':' || decode(scoreType, 'DECIMAL', '%f', 'INTEGER', '%d')
from assayreadout ar
join assay a using (assayid)
join assayfeature af using (assayid)
where ar.assayReadoutId = ?
and af.isdefaultfeature = 1
"""

_DATA_ARRAY_QUERY = """
select
row_,
col,
numericscore,
decode(scoreType, 'DECIMAL', '%f', 'INTEGER', '%d') as featureType
from assayreadout ar
join assay a using (assayid)
join assayfeature af using (assayid)
join screenscorecolumn ssc using (assayfeatureid)
join screenscore ss using (screenscorecolumnid)
join screenscorerow ssr using (screenscorerowid)
join infection i using (infectionid)
join infectionwell iw on (i.plateid = iw.infectionplateid)
where ar.assayReadoutId = ?
and iw.infectionwellid = ssr.infectionwellid
order by row_, col, af.assayfeatureindex
"""

_ASSAY_COUNT_QUERY = "select count(*) from infection i join assayreadout ar using (infectionid) where i.infectionbatchid in (%s)"

_ASSAY_READOUT_QUERY = '''select i.plateid, ar.assayid, ar.assayReadoutId,
        decode((select max(iscellviabilityfeature)
                from assayfeature
                where assayid = a.assayid),1,'YES','NO') as cellViability
        from infection i
        join assayreadout ar using (infectionid)
        join assay a on ar.assayid = a.assayid
        where i.infectionbatchid in (%s)'''

_VP_NAME_QUERY = 'select p.plateName, p.plateformat from plate p where p.plateid=?'

# VP_INFO is essentially a materialized view of the per-well hairpin information
# of each virus plate. The rows are read by column position, so if columns are
# added/removed/moved in VP_INFO, _loadOneVirusPlate must be updated to match.
_VP_INFO_QUERY = 'select * from vp_info vp where vp.virusplateid=?'


def _joinUniqueIds(ids):
    """Return ids as a ', '-separated string with exact duplicates removed.

    IDs keep their first-seen order and are converted with '%s', so the result
    can be spliced into an SQL "in (...)" list. Whole IDs are compared (not
    substrings), so '12' is kept even when '123' was seen first.
    Returns '' when ids is None or contains no IDs.
    """
    if ids is None:
        return ''
    seen = {}
    unique = []
    for anId in ids:
        text = '%s' % anId
        if not seen.get(text, 0):
            seen[text] = 1
            unique.append(text)
    return ', '.join(unique)


def _progressIncrement(nSteps):
    """Return the progress-bar fraction for one of nSteps steps (0.0 if nSteps <= 0)."""
    if nSteps <= 0:
        return 0.0
    return 1.0 / nSteps


def _closeStatements(statements):
    """Close every JDBC statement in statements, skipping None entries.

    Closing a Statement also closes its current ResultSet, which releases the
    database cursors opened while loading. An SQLException from one close() is
    reported and does not stop the remaining statements from being closed.
    """
    for stmt in statements:
        if stmt is None:
            continue
        try:
            stmt.close()
        except java.sql.SQLException:
            print('warning: failed to close a database statement')


def _loadInfectionPlates(conn, batchList, statements, plateSet,
                         batchId2BatchName, screenId2ScreenName):
    """Load per-plate metadata for the infection batches in batchList.

    batchList is the ', '-separated batch ID string for the SQL "in (...)"
    clause. Only plates with at least one assay readout are returned (the query
    joins assayreadout). Creates a PlateSet.Plate per plate in plateSet, fills
    batchId2BatchName and screenId2ScreenName in place, and returns the distinct
    batch IDs in first-seen order. The prepared statement is appended to
    statements so that the caller can close it.
    """
    plateQry = conn.prepareStatement(_PLATE_QUERY % batchList)
    statements.append(plateQry)
    rs = plateQry.executeQuery()

    seenBatchIds = {}
    batchIds = []
    while rs.next():
        plateId = rs.getString(1)
        plateSet.setdefault(plateId, PlateSet.Plate(plateId))
        screenId = rs.getString(2)
        screenName = rs.getString(3)
        screenId2ScreenName[screenId] = screenName
        batchId = rs.getString(4)
        if not seenBatchIds.get(batchId, 0):
            seenBatchIds[batchId] = 1
            batchIds.append(batchId)
        batchName = rs.getString(5)
        batchId2BatchName[batchId] = batchName
        plateName = rs.getString(6).replace(':', '_')   # for plate name display
        barcode = rs.getString(7)            # each plate has a unique barcode
        plateFormat = rs.getString(8)        # 96- or 384-well plate
        selection = rs.getString(9)          # selection reagent (generally puromycin)
        selectionStatus = rs.getString(10)   # whether the selection reagent was applied
        replicate = rs.getString(11)         # whether to use the plate for infection efficiency
        virusPlateId = rs.getString(12)      # virus plate used to infect this plate
        plateSet[plateId].addPlateQry(screenId, screenName, batchId, batchName,
                                      plateName, barcode, plateFormat,
                                      selection, selectionStatus, replicate, virusPlateId)
    return batchIds


def _readFeatureList(featureListQuery, assayReadoutId):
    """Return the 'name:format' feature descriptors of one readout, ','-joined in feature-index order."""
    featureListQuery.setString(1, assayReadoutId)
    rs = featureListQuery.executeQuery()
    fields = []
    while rs.next():
        fields.append(rs.getString(1))
    return ','.join(fields)


def _readDataArray(dataArrayQuery, assayReadoutId):
    """Return {row: {col: {'score': [...], 'type': [...]}}} for one assay readout.

    Scores and their format strings ('%f'/'%d') are appended in the query's
    order (row, col, assayfeatureindex).
    """
    dataArrayQuery.setString(1, assayReadoutId)
    rs = dataArrayQuery.executeQuery()
    dataArray = {}
    while rs.next():
        row = rs.getString(1)
        col = rs.getString(2)
        score = rs.getString(3)
        featureType = rs.getString(4)
        rowCells = dataArray.setdefault(row, {})
        cell = rowCells.get(col)
        if cell is None:
            cell = {'score': [], 'type': []}
            rowCells[col] = cell
        cell['score'].append(score)
        cell['type'].append(featureType)
    return dataArray


def _loadAssayData(conn, batchList, statements, plateSet, vpCounts, loadingProgressBar):
    """Load the raw data of every assay readout of the infections in batchList.

    For each readout this:
      - registers the assay on its Plate, creating the Plate if the plate query
        did not return it (the two queries use different joins);
      - attaches the condition string, the feature list, the default feature
        and the well-by-well score array;
      - tallies the readout in vpCounts (virus plate ID ->
        PlateSet.VirusPlateInfo).
    It advances loadingProgressBar once per readout and prints timing statistics.
    """
    condQry = conn.prepareStatement(_CONDITION_QUERY)
    statements.append(condQry)
    featureListQuery = conn.prepareStatement(_FEATURE_LIST_QUERY)
    statements.append(featureListQuery)
    defaultFeatureQuery = conn.prepareStatement(_DEFAULT_FEATURE_QUERY)
    statements.append(defaultFeatureQuery)
    dataArrayQuery = conn.prepareStatement(_DATA_ARRAY_QUERY)
    statements.append(dataArrayQuery)

    # One progress step per assay readout (guarded against a zero count).
    nAssayQuery = conn.prepareStatement(_ASSAY_COUNT_QUERY % batchList)
    statements.append(nAssayQuery)
    rs = nAssayQuery.executeQuery()
    rs.next()
    plateInc = _progressIncrement(rs.getInt(1))
    platePct = 0
    t0 = time.time()   # measure the plate loading time

    readQry = conn.prepareStatement(_ASSAY_READOUT_QUERY % batchList)
    statements.append(readQry)
    rs = readQry.executeQuery()
    while rs.next():
        plateId = rs.getString(1)
        readId = rs.getString(2)          # assay ID; keys the per-assay data on the Plate
        assayReadoutId = rs.getString(3)
        cellViability = rs.getString(4)   # 'YES' when max(iscellviabilityfeature) of the assay is 1
        plateSet.setdefault(plateId, PlateSet.Plate(plateId))
        iPlate = plateSet[plateId]
        iPlate.addReadId(readId, cellViability)

        # dataType is the assay description; condition joins the infection
        # condition and assay description with '; '.
        condQry.setString(1, assayReadoutId)
        rsInner = condQry.executeQuery()
        if rsInner.next():
            dataType = rsInner.getString(2)
            condition = rsInner.getString(3)
        else:
            dataType = None
            condition = None
            print('no condition found for assay readout %s' % assayReadoutId)
        iPlate.addCondition(readId, condition)

        dataFields = _readFeatureList(featureListQuery, assayReadoutId)

        defaultFeatureQuery.setString(1, assayReadoutId)
        rsInner = defaultFeatureQuery.executeQuery()
        if rsInner.next():
            defField = rsInner.getString(1)
        else:
            # Previously getString() on an empty result aborted the whole load.
            defField = None
            print('no default feature defined for assay readout %s' % assayReadoutId)

        dataArray = _readDataArray(dataArrayQuery, assayReadoutId)
        iPlate.setAssayRawData(readId, dataType, dataFields, defField, dataArray)

        # Count the infection plates that used this virus plate.
        virusPlateId = iPlate.get_virusPlateId()
        vpCounts.setdefault(virusPlateId, PlateSet.VirusPlateInfo(virusPlateId))
        vpCounts[virusPlateId].updateCount('%s:%s:%s' % (iPlate.screenName, iPlate.batchName,
                                                         iPlate.get_condition(readId)),
                                           iPlate.selectionStatus, readId, plateId, iPlate.get_use(readId))

        platePct += plateInc
        loadingProgressBar.setProgress(platePct)

    elapsed = time.time() - t0
    nPlates = len(plateSet.keys())
    if elapsed != 0:
        print('Plate data loaded: %6.4f sec to load %d plates (%6.4f plates/sec)' % (elapsed, nPlates, nPlates / elapsed))
    else:
        print('Plate data loaded: %d plates loaded infinitely fast!' % nPlates)


def _loadOneVirusPlate(vPlateId, vNameQry, vidQry, updateVtype, vpQry,
                       cloneMap, vplateId2VPlateName):
    """Load one virus plate's name, format and per-well VP_INFO rows.

    Returns a PlateSet.VirusPlate, or None (after printing a message) when the
    plate is not in the plate table. It refreshes VP_INFO through
    UPDATE_VP_INFO before reading it, and also calls VP_INFO_VTYPE when rows
    were added. It records the display name in vplateId2VPlateName and adds
    each well's clone to cloneMap; the first occurrence of a clone ID wins.
    """
    vPlate = PlateSet.VirusPlate(vPlateId)

    vNameQry.setString(1, vPlateId)
    rs = vNameQry.executeQuery()
    vPlateName = None
    nWells = None
    while rs.next():
        vPlateName = rs.getString(1)
        nWells = int(rs.getString(2))
        print('virus plate: %s, nWells: %d' % (vPlateName, nWells))
    if vPlateName is None:
        print('unable to find virus plate %s. skipping' % vPlateId)
        return None

    vPlateName = vPlateName.replace(':', '_')   # for virus plate name display
    vPlate.set_plateName(vPlateName)
    vplateId2VPlateName[vPlateId] = vPlateName

    # Make sure the VP_INFO table is up to date for this virus plate.
    vidQry.registerOutParameter(1, Types.INTEGER)
    vidQry.setString(2, vPlateId)
    vidQry.execute()
    if vidQry.getInt(1):
        print('Added virusplate %s to VP_INFO table' % vPlateId)
        updateVtype.registerOutParameter(1, Types.INTEGER)
        updateVtype.setString(2, vPlateId)
        updateVtype.execute()
        nUntyped = updateVtype.getInt(1)
        if nUntyped > 0:
            print(' Note: %d wells with vtype=null for virus plate %s' % (nUntyped, vPlateId))

    # VP_INFO columns are read by position (see _VP_INFO_QUERY).
    vpQry.setString(1, vPlateId)
    rs = vpQry.executeQuery()
    while rs.next():
        row = rs.getString(2)
        col = rs.getString(3)
        virusId = rs.getString(4)
        virusName = rs.getString(5)
        sourceContig = rs.getString(6)      # genomic contig of the target gene
        sourceStart = rs.getString(7)       # genomic start of the target gene
        sourceEnd = rs.getString(8)         # genomic end of the target gene
        sourceStrand = rs.getString(9)      # genomic strand of the target gene
        symbol = rs.getString(10)           # NCBI symbol of the target gene
        prefName = rs.getString(11)         # preferred name of the target gene
        targetSeq = rs.getString(12)        # hairpin sequence
        sourcePlateName = rs.getString(13)
        sourcePlateRow = rs.getString(14)
        sourcePlateCol = rs.getString(15)
        quad = rs.getString(16)             # quadrant of the source virus plate
        geneId = rs.getString(17)           # NCBI gene ID
        cloneId = rs.getString(18)          # TRC clone ID
        cloneName = rs.getString(19)        # TRC clone name
        taxon = rs.getString(20)
        vtype = rs.getInt(21)               # -2=pgw, -1=control, 0=EMPTY, 1=gene targeting hairpin
        vPlate.addWell(row, col, virusId, virusName,
                       sourceContig, sourceStart, sourceEnd, sourceStrand,
                       symbol, prefName, targetSeq,
                       sourcePlateName, sourcePlateRow, sourcePlateCol, quad,
                       geneId, cloneId, cloneName, taxon, vtype, nWells)
        cloneMap.setdefault(cloneId, PlateSet.CloneInfo(sourceContig, sourceStart, sourceEnd, sourceStrand,
                                                        symbol, prefName, targetSeq, geneId, cloneName, taxon, vtype))
    return vPlate


def _loadVirusPlates(conn, vpCounts, statements, loadingProgressBar, vplateId2VPlateName):
    """Load VP_INFO well data for every virus plate referenced in vpCounts.

    Returns (vPlateData, cloneMap):
      - vPlateData maps virus plate ID -> PlateSet.VirusPlate for the plates
        that were found;
      - cloneMap maps clone ID -> PlateSet.CloneInfo.
    Advances loadingProgressBar once per virus plate, including skipped plates.
    """
    vNameQry = conn.prepareStatement(_VP_NAME_QUERY)
    statements.append(vNameQry)
    vidQry = conn.prepareCall("{? = call UPDATE_VP_INFO(?)}")
    statements.append(vidQry)
    updateVtype = conn.prepareCall("{? = call VP_INFO_VTYPE(?)}")
    statements.append(updateVtype)
    vpQry = conn.prepareStatement(_VP_INFO_QUERY)
    statements.append(vpQry)

    loadingProgressBar.setLabelText('Loading virus plate data:')
    loadingProgressBar.setVisibility(1)
    plateInc = _progressIncrement(len(vpCounts.keys()))
    platePct = 0
    loadingProgressBar.setProgress(platePct)

    vPlateData = {}
    cloneMap = {}   # "clones" are hairpin clones
    for vPlateId in vpCounts.keys():
        vPlate = _loadOneVirusPlate(vPlateId, vNameQry, vidQry, updateVtype, vpQry,
                                    cloneMap, vplateId2VPlateName)
        if vPlate is not None:
            vPlateData[vPlateId] = vPlate
        platePct += plateInc
        loadingProgressBar.setProgress(platePct)
    return vPlateData, cloneMap


def _computeIeData(vpCounts, plateSet, vPlateData, loadingProgressBar):
    """Fold every VirusPlateInfo.updateIeData over an initially empty dict and return it.

    The result describes how selection+/selection- replicate plates are grouped
    for the infection-efficiency calculation.
    """
    loadingProgressBar.setLabelText('Computing IE:')
    loadingProgressBar.setVisibility(1)
    plateInc = _progressIncrement(len(vpCounts.keys()))
    platePct = 0
    loadingProgressBar.setProgress(platePct)
    ieData = {}
    for vpId in vpCounts.keys():
        ieData = vpCounts[vpId].updateIeData(ieData, plateSet, vPlateData)
        platePct += plateInc
        loadingProgressBar.setProgress(platePct)
    return ieData


def _fillVirusPlateList(ieData, vPlateData):
    """Add one Infection Efficiency table row per (virus plate, condition) in ieData.

    Each row is [virus plate name, screen, batch, condition, #selection+ plates,
    #selection- plates, '', ''] and is added to the 'vpListBox' component of the
    'vpListPanel' field. Returns (rows, parallel list of virus plate IDs,
    virus plate name -> ID map).
    """
    listBox = masterObject.getField('vpListPanel').getComp('vpListBox')
    vPlateStringList = []
    vPlateIdList = []
    vpName2IDmap = {}
    for vPlateId in ieData.keys():
        vPlate = vPlateData.get(vPlateId)
        if vPlate is None:
            # Virus plates missing from the plate table are not in vPlateData;
            # skip them instead of failing with a KeyError.
            print('no virus plate data for %s; not listed' % vPlateId)
            continue
        vpName = vPlate.plateName
        for cond in ieData[vPlateId].keys():
            group = ieData[vPlateId][cond]
            vpName2IDmap[vpName] = vPlateId
            # NOTE(review): assumes cond is the 'screen:batch:condition' key; confirm in PlateSet.
            condition = cond.split(':')[2].replace('None', '')
            vPlateString = [vpName, group.screenName, group.batchName, condition,
                            len(group.selPlusPlateList), len(group.selMinusPlateList), '', '']
            vPlateStringList.append(vPlateString)
            vPlateIdList.append(vPlateId)
            listBox.addRow(vPlateString, vPlateId)
    return vPlateStringList, vPlateIdList, vpName2IDmap


def _loadAndPublish(statements, loadingProgressBar):
    """Run every loading stage and publish the results as masterObject fields.

    Statements opened by the stages are appended to statements for the caller
    to close. Returns early, without touching masterObject, when no batches are
    selected.
    """
    conn = masterObject.getField('conn')
    batchList = _joinUniqueIds(masterObject.getField('batchIdList'))
    if not batchList:
        print('No infection batches selected; nothing to load.')
        return

    vplateId2VPlateName = {}   # virus plate ID -> virus plate name
    batchId2BatchName = {}     # batch ID -> batch name
    screenId2ScreenName = {}   # screen ID -> screen name
    plateSet = {}              # plate ID -> PlateSet.Plate
    vpCounts = {}              # virus plate ID -> PlateSet.VirusPlateInfo

    batchIdList = _loadInfectionPlates(conn, batchList, statements, plateSet,
                                       batchId2BatchName, screenId2ScreenName)
    print('Infection batches:  %s' % _joinUniqueIds(batchIdList))

    _loadAssayData(conn, batchList, statements, plateSet, vpCounts, loadingProgressBar)

    vPlateData, cloneMap = _loadVirusPlates(conn, vpCounts, statements, loadingProgressBar,
                                            vplateId2VPlateName)
    masterObject.setField('cloneMap', cloneMap)
    print('Total number of unique hairpins: %d' % len(cloneMap.keys()))

    ieData = _computeIeData(vpCounts, plateSet, vPlateData, loadingProgressBar)

    loadingProgressBar.setLabelText('Analyzing assay data:')
    loadingProgressBar.setProgress(0)
    for plateId in plateSet.keys():
        # point each infection plate at its virus plate
        plateSet[plateId].updateVirusPlate(vPlateData)

    # Detach the table's selection listeners while it is rebuilt; they are
    # reattached in the finally even if rebuilding fails.
    vpListBox = masterObject.getField('vpListBox')
    vpTableModel = vpListBox.table.getSelectionModel()
    listeners = vpTableModel.getListSelectionListeners()
    for listener in listeners:
        vpTableModel.removeListSelectionListener(listener)
    try:
        if masterObject.getField('vpListBoxTable').getRowCount() > 0:
            vpListBox.clearList()
        vPlateStringList, vPlateIdList, vpName2IDmap = _fillVirusPlateList(ieData, vPlateData)

        masterObject.setField('vpListBoxRows', vPlateStringList)
        masterObject.setField('vpListvpIdList', vPlateIdList)
        masterObject.setField('vpName2IDmap', vpName2IDmap)

        masterObject.setField('plateSet', plateSet)
        masterObject.setField('vPlateSet', vPlateData)
        masterObject.setField('ieData', ieData)

        masterObject.setField('vplateId2VPlateName', vplateId2VPlateName)
        masterObject.setField('batchId2BatchName', batchId2BatchName)
        masterObject.setField('screenId2ScreenName', screenId2ScreenName)
    finally:
        for listener in listeners:
            vpTableModel.addListSelectionListener(listener)

    # set up data points for IE plotting and the normalization panel
    InfectionEfficiency.initializeIEPlot()
    initializeNPanel()


def loadPlateSetFromDb(event):
    """Action for the "Load Plate Set Data" button: load the selected batches.

    event is the button's action event and is not used. The batch IDs stored
    earlier in the masterObject field 'batchIdList' select what is loaded. The
    function:
      1. loads infection plate metadata and raw assay data;
      2. refreshes and loads virus plate (VP_INFO) data and the global clone map;
      3. computes infection-efficiency groupings;
      4. rebuilds the virus plate table on the Infection Efficiency tab;
      5. publishes everything as masterObject fields.
    Every JDBC statement opened here is closed, and the progress bar and button
    colour are restored, even if loading fails part-way.
    """
    loadingProgressBar = masterObject.getField('loadingProgressBar')
    loadingProgressBar.setLabelText('Loading infection plates:')
    loadingProgressBar.setVisibility(1)

    # Darken the button text while data is loading.
    loadButton = masterObject.getField('loadPlateSetButton')
    fgOrig = loadButton.getForeground()
    loadButton.setForeground(Color(0.5, 0, 0))
    loadButton.repaint()

    statements = []   # every JDBC statement opened while loading
    try:
        try:
            _loadAndPublish(statements, loadingProgressBar)
        finally:
            _closeStatements(statements)
    finally:
        loadingProgressBar.setProgress(0)      # reset the progress bar
        loadingProgressBar.setVisibility(0)    # ...and hide it
        loadButton.setForeground(fgOrig)       # restore the button colour
        loadButton.repaint()
