import oracle.jdbc.driver
from Normalization import initializeNPanel
import InfectionEfficiency
import MainHandle
import PlateSet
import java.awt.Color as Color
import java.sql
import time

# get the masterObject for global variables: the shared application state used below
# (DB connection 'conn', UI widgets such as 'loadingProgressBar', and the loaded data sets).
# This alias is bound once at import time; if MainHandle ever rebinds masterObject, this
# module keeps referring to the object that existed when it was imported.
masterObject = MainHandle.masterObject

def loadPlateSetFromDb(event):
	"""Load all data for the selected infection batches and refresh the UI.

	Action performed when the "Load Plate Set Data" button is pressed. The batch
	selection callback has already stored the selected batch IDs in the global
	masterObject field 'batchIdList'.

	Steps:
	  1. query infection-plate metadata and assay (read) IDs for those batches,
	  2. read each assay's raw data with the GET_ASSAYREAD PL/SQL procedure,
	  3. load the VP_INFO well annotations of every virus plate that was used,
	  4. group replicate plates for the infection-efficiency (IE) computation,
	  5. fill the virus plate list box and publish the results on masterObject.

	The progress bar is hidden and the button colour restored even if loading
	fails part-way; any exception is then propagated to the caller.

	event -- the button's action event (not used).
	"""
	loadingProgressBar = masterObject.getField('loadingProgressBar')
	loadButton = masterObject.getField('loadPlateSetButton')
	origForeground = loadButton.getForeground()

	# make the plate loading progress bar visible, and indicate that infection plates are being loaded
	loadingProgressBar.setLabelText('Loading infection plates:')
	loadingProgressBar.setVisibility(1)

	# colour the "Load Plate Set Data" button red while data is loading
	loadButton.setForeground(Color(0.5, 0, 0))
	loadButton.repaint()

	try:
		conn = masterObject.getField('conn')

		# plate info queries are keyed by the selected batch IDs (exact duplicates removed)
		batchIds = _uniqueStrings(masterObject.getField('batchIdList'))
		if not batchIds:
			print('No infection batches selected; nothing to load.')
			return

		# maps (for convenience), published on masterObject at the end:
		vplateId2VPlateName = {}   # virus plate ID -> virus plate name
		batchId2BatchName = {}     # batch ID -> batch name
		screenId2ScreenName = {}   # screen ID -> screen name

		plateSet, foundBatchIds = _queryInfectionPlates(conn, batchIds, screenId2ScreenName, batchId2BatchName)
		print('Infection batches:  %s' % ', '.join(foundBatchIds))

		_queryReadIds(conn, batchIds, plateSet)

		# vpCounts: virus plate ID -> PlateSet.VirusPlateInfo
		vpCounts = _loadAssayData(conn, plateSet, loadingProgressBar)

		vPlateData, cloneMap = _loadVirusPlates(conn, vpCounts, vplateId2VPlateName, loadingProgressBar)
		masterObject.setField('cloneMap', cloneMap)
		print('Total number of unique hairpins: %d' % len(cloneMap))

		ieData = _computeIeData(vpCounts, plateSet, vPlateData, loadingProgressBar)

		# done with IE computation; indicate that the data is being analyzed (normalized)
		# AD: this doesn't look like it is actually used
		loadingProgressBar.setLabelText('Analyzing assay data:')
		loadingProgressBar.setProgress(0)

		# point each infection plate in the plate set at its virus plate
		for plateId in plateSet.keys():
			plateSet[plateId].updateVirusPlate(vPlateData)

		# reset the vpListBox table, with its selection listeners detached while it is rebuilt
		vpListBox = masterObject.getField('vpListBox')
		vpTableModel = vpListBox.table.getSelectionModel()
		listeners = vpTableModel.getListSelectionListeners()
		for listener in listeners:
			vpTableModel.removeListSelectionListener(listener)
		try:
			# clear the virus plate list box if it has any data to start with
			if masterObject.getField('vpListBoxTable').getRowCount() > 0:
				vpListBox.clearList()

			vPlateStringList, vPlateIdList, vpName2IDmap = _fillVpListBox(ieData, vPlateData)

			# set the globally-accessible fields in the master object
			masterObject.setField('vpListBoxRows', vPlateStringList)
			masterObject.setField('vpListvpIdList', vPlateIdList)
			masterObject.setField('vpName2IDmap', vpName2IDmap)

			masterObject.setField('plateSet', plateSet)
			masterObject.setField('vPlateSet', vPlateData)
			masterObject.setField('ieData', ieData)

			masterObject.setField('vplateId2VPlateName', vplateId2VPlateName)
			masterObject.setField('batchId2BatchName', batchId2BatchName)
			masterObject.setField('screenId2ScreenName', screenId2ScreenName)
		finally:
			# re-attach the listeners even if the update failed, so the table keeps working
			for listener in listeners:
				vpTableModel.addListSelectionListener(listener)

		# set up data points for IE plotting:
		InfectionEfficiency.initializeIEPlot()
		initializeNPanel()
	finally:
		loadingProgressBar.setProgress(0)         # reset the progress bar
		loadingProgressBar.setVisibility(0)       # ...and hide it
		loadButton.setForeground(origForeground)  # reset the "Load Plate Set Data" button colour
		loadButton.repaint()                      # ...and repaint it


def _uniqueStrings(values):
	"""Return the values formatted with '%s', exact duplicates removed, order preserved.

	None is treated as an empty selection.
	"""
	seen = {}
	result = []
	for value in values or []:
		text = '%s' % value
		if text not in seen:
			seen[text] = 1
			result.append(text)
	return result


def _progressStep(count):
	"""Return the progress-bar fraction per item for count items (0.0 when count is 0)."""
	if count:
		return 1.0 / count
	return 0.0


def _closeQuietly(resource):
	"""Close a JDBC ResultSet/Statement if it is not None.

	Errors raised while closing are deliberately ignored: this is best-effort
	cleanup and must not mask the outcome of the work that used the resource.
	"""
	if resource is None:
		return
	try:
		resource.close()
	except:
		pass


def _prepareWithInList(conn, sqlTemplate, values):
	"""Prepare sqlTemplate with its single '%s' replaced by one '?' per value, and bind the values.

	Binding (instead of pasting quoted IDs into the SQL text) keeps IDs that contain
	quotes from breaking the query.
	"""
	stmt = conn.prepareStatement(sqlTemplate % ','.join(['?'] * len(values)))
	try:
		index = 1
		for value in values:
			stmt.setString(index, value)
			index += 1
	except:
		_closeQuietly(stmt)
		raise
	return stmt


def _queryInfectionPlates(conn, batchIds, screenId2ScreenName, batchId2BatchName):
	"""Create a PlateSet.Plate for every infection plate in the given batches.

	Fills screenId2ScreenName and batchId2BatchName in place. Returns
	(plateSet, foundBatchIds): plate ID -> Plate, and the batch IDs that actually
	returned plates, in first-seen order.
	"""
	sql = ('select distinct i.plateid, s.screenid, s.screenname, i.infectionbatchid, '
		'ib.infectionbatchname, p.screenerplatename, p.barcode, p.plateformat, ib.selection, '
		'i.selectionstatus, i.replicate, i.virusplateid '
		'from infection i, infectionbatch ib, screen s, plate p, plate2assay p2a, well2value w2v '
		'where i.infectionbatchid in (%s) and i.infectionbatchid=ib.infectionbatchid '
		'and ib.screenid=s.screenid and p.plateid=i.plateid and p2a.plateid=p.plateid '
		'and p2a.assayid=w2v.readid')

	plateSet = {}        # plate ID -> PlateSet.Plate with all of the infection plate info
	foundBatchIds = []   # unique batch IDs seen in the result
	stmt = None
	rs = None
	try:
		stmt = _prepareWithInList(conn, sql, batchIds)
		rs = stmt.executeQuery()
		while rs.next():
			plateId = rs.getString(1)
			if plateId not in plateSet:
				plateSet[plateId] = PlateSet.Plate(plateId)
			screenId = rs.getString(2)
			screenName = rs.getString(3)
			screenId2ScreenName[screenId] = screenName
			batchId = rs.getString(4)
			if batchId not in foundBatchIds:
				foundBatchIds.append(batchId)
			batchName = rs.getString(5)
			batchId2BatchName[batchId] = batchName
			plateName = rs.getString(6).replace(':', '_')   # ':' replaced for plate name display
			barcode = rs.getString(7)           # each plate has a unique barcode
			plateFormat = rs.getString(8)       # 96- or 384-well plate
			selection = rs.getString(9)         # reagent used for selection (generally puromycin)
			selectionStatus = rs.getString(10)  # whether the selection reagent was applied to this plate
			replicate = rs.getString(11)        # should this plate be used for infection efficiency computation?
			virusPlateId = rs.getString(12)     # virus plate used to infect this infection plate
			plateSet[plateId].addPlateQry(screenId, screenName, batchId, batchName,
				plateName, barcode, plateFormat,
				selection, selectionStatus, replicate, virusPlateId)
	finally:
		_closeQuietly(rs)
		_closeQuietly(stmt)
	return plateSet, foundBatchIds


def _queryReadIds(conn, batchIds, plateSet):
	"""Attach every assay (read ID) of the given batches to its Plate in plateSet.

	A read ID identifies the data of one assay performed on a plate. Several
	independent assays can be performed on one plate (e.g. at several timepoints,
	or at different fluorescence wavelengths).
	"""
	sql = ('select distinct p2a.plateid, p2a.assayid, p2a.cellViability '
		'from plate2assay p2a, well2value w2v, infection i '
		'where p2a.infectionbatchid in (%s) and p2a.assayid=w2v.readid '
		'and i.plateid=p2a.plateid and i.infectionbatchid=p2a.infectionbatchid')
	stmt = None
	rs = None
	try:
		stmt = _prepareWithInList(conn, sql, batchIds)
		rs = stmt.executeQuery()
		while rs.next():
			plateId = rs.getString(1)
			readId = rs.getString(2)
			cellViability = rs.getString(3)   # whether this assay measures cell viability (used by the IE computation)
			if plateId not in plateSet:
				# a plate not returned by the plate query (This shouldn't happen???)
				plateSet[plateId] = PlateSet.Plate(plateId)
			plateSet[plateId].addReadId(readId, cellViability)
	finally:
		_closeQuietly(rs)
		_closeQuietly(stmt)


def _loadAssayData(conn, plateSet, loadingProgressBar):
	"""Read the conditions and raw data of every assay of every plate in plateSet.

	Advances the progress bar once per assay and prints the loading statistics.
	Returns vpCounts: virus plate ID -> PlateSet.VirusPlateInfo, updated with one
	updateCount() call per assay of each infection plate infected by that virus plate.
	"""
	vpCounts = {}
	condStmt = None
	readCall = None
	try:
		condStmt = conn.prepareStatement('select c.conditiondescription from plate2assay p2a, plate2condition c '
			'where c.readid=p2a.assayid and p2a.assayid=?')

		# GET_ASSAYREAD is a PL/SQL procedure that reads the data of one assay and returns it as an array object
		readCall = conn.prepareCall("{call GET_ASSAYREAD(?,?,?,?,?,?)}")
		readCall.registerOutParameter(2, java.sql.Types.VARCHAR)   # output: plateID
		readCall.registerOutParameter(3, java.sql.Types.VARCHAR)   # output: assay data type
		readCall.registerOutParameter(4, java.sql.Types.VARCHAR)   # output: data fields for this data type
		readCall.registerOutParameter(5, java.sql.Types.VARCHAR)   # output: default data field
		readCall.registerOutParameter(6, oracle.jdbc.driver.OracleTypes.ARRAY, 'PLATEARRAYCLOB')  # output: assay data array

		nAssay = 0
		for plateId in plateSet.keys():
			nAssay += len(plateSet[plateId].get_readIdList())
		step = _progressStep(nAssay)
		pct = 0
		t0 = time.time()   # measure the plate loading time

		for plateId in plateSet.keys():
			iPlate = plateSet[plateId]
			for readId in iPlate.get_readIdList():
				_addConditions(condStmt, iPlate, readId)
				aPlateId, dataType, dataFields, defField, dataArray = _readAssay(readCall, readId)

				# keep a count of the infection plates infected by each virus plate
				virusPlateId = iPlate.get_virusPlateId()
				if virusPlateId not in vpCounts:
					vpCounts[virusPlateId] = PlateSet.VirusPlateInfo(virusPlateId)
				condKey = '%s:%s:%s' % (iPlate.screenName, iPlate.batchName, iPlate.get_condition(readId))
				vpCounts[virusPlateId].updateCount(condKey, iPlate.selectionStatus, readId,
					aPlateId, iPlate.get_use(readId))

				# store the raw (unnormalized) assay data for this read ID in the Plate object
				iPlate.setAssayRawData(readId, dataType, dataFields, defField, dataArray)

				pct += step
				loadingProgressBar.setProgress(pct)

		elapsed = time.time() - t0
		nPlates = len(plateSet)
		if elapsed != 0:
			print('Plate data loaded: %6.4f sec to load %d plates (%6.4f plates/sec)' % (elapsed, nPlates, nPlates / elapsed))
		else:
			print('Plate data loaded: %d plates loaded infinitely fast!' % nPlates)
	finally:
		_closeQuietly(readCall)
		_closeQuietly(condStmt)
	return vpCounts


def _addConditions(condStmt, iPlate, readId):
	"""Attach every condition description recorded for readId to iPlate.

	A condition describes the infection condition, the assay, or both; one assay
	can have multiple conditions.
	"""
	condStmt.setString(1, readId)
	rs = condStmt.executeQuery()
	try:
		while rs.next():
			iPlate.addCondition(readId, rs.getString(1))
	finally:
		_closeQuietly(rs)


def _readAssay(readCall, readId):
	"""Run GET_ASSAYREAD for one assay.

	Returns (plateId, dataType, dataFields, defField, dataArray). A failure is
	printed and loading continues: outputs that could not be read are None and
	dataArray is [], so nothing from a previous assay is reused.
	NOTE(review): confirm VirusPlateInfo.updateCount / Plate.setAssayRawData accept None here.
	"""
	import sys
	out = [None, None, None, None, []]
	readCall.setString(1, readId)   # input parameter: the read ID
	try:
		readCall.executeUpdate()                  # execute the PL/SQL procedure in the database
		out[0] = readCall.getString(2)            # plate ID
		out[1] = readCall.getString(3)            # data type
		out[2] = readCall.getString(4)            # data fields
		out[3] = readCall.getString(5)            # default data field
		out[4] = readCall.getArray(6).getArray()  # array of assay data from the DB
	except:
		# Deliberately broad (Python and Java errors alike): one bad assay must not abort
		# the whole load. Parameter 1 is IN-only, so the DB message is taken from the exception.
		print('GET_ASSAYREAD failure on readId %s (384 well plate)' % readId)
		print('message from DB:  %s' % (sys.exc_info()[1],))
	return tuple(out)


def _loadVirusPlates(conn, vpCounts, vplateId2VPlateName, loadingProgressBar):
	"""Load the well information of every virus plate referenced in vpCounts.

	Virus plates not found in the PLATE table are reported and skipped. Fills
	vplateId2VPlateName in place and returns (vPlateData, cloneMap): virus plate
	ID -> PlateSet.VirusPlate, and hairpin clone ID -> PlateSet.CloneInfo.
	"""
	# set the progress bar to indicate that the virus plate info is being loaded
	loadingProgressBar.setLabelText('Loading virus plate data:')
	loadingProgressBar.setVisibility(1)
	step = _progressStep(len(vpCounts))
	pct = 0
	loadingProgressBar.setProgress(pct)

	vPlateData = {}
	cloneMap = {}   # "clones" refer to hairpin clones
	nameStmt = None
	wellStmt = None
	try:
		nameStmt = conn.prepareStatement('select p.plateName, p.plateformat from plate p where p.plateid=?')
		# the VP_INFO table is essentially a materialized view of all of the relevant information about
		# each hairpin on the virus plate (hairpin ID, gene symbol, genomic location, etc.)
		# AD: the '*' in this query should be replaced with the list of all columns in the VP_INFO table.
		#     If columns are added/subtracted/moved in VP_INFO, _addVirusWells must be changed to match.
		wellStmt = conn.prepareStatement('select * from vp_info vp where vp.virusplateid=?')

		for vPlateId in vpCounts.keys():
			vPlateName, nWells = _lookupVirusPlate(nameStmt, vPlateId)
			if vPlateName is None:
				print('unable to find virus plate %s. skipping' % vPlateId)
			else:
				vPlateName = vPlateName.replace(':', '_')   # for virus plate name display
				vPlate = PlateSet.VirusPlate(vPlateId)
				vPlate.set_plateName(vPlateName)
				vPlateData[vPlateId] = vPlate
				vplateId2VPlateName[vPlateId] = vPlateName
				_addVirusWells(wellStmt, vPlateId, vPlate, nWells, cloneMap)
			# advance for skipped plates too, so the bar still reaches the end
			pct += step
			loadingProgressBar.setProgress(pct)
	finally:
		_closeQuietly(wellStmt)
		_closeQuietly(nameStmt)
	return vPlateData, cloneMap


def _lookupVirusPlate(nameStmt, vPlateId):
	"""Return (plateName, nWells) of a virus plate, or (None, None) if it is not in PLATE."""
	nameStmt.setString(1, vPlateId)
	rs = nameStmt.executeQuery()
	vPlateName = None
	nWells = None
	try:
		while rs.next():
			vPlateName = rs.getString(1)
			nWells = int(rs.getString(2))   # plate format column, used as the number of wells
			print('virus plate: %s, nWells: %d' % (vPlateName, nWells))
	finally:
		_closeQuietly(rs)
	return vPlateName, nWells


def _addVirusWells(wellStmt, vPlateId, vPlate, nWells, cloneMap):
	"""Add every VP_INFO row of vPlateId as a well of vPlate; register new clones in cloneMap.

	Column positions follow the VP_INFO table layout (column 1 is not read here).
	"""
	wellStmt.setString(1, vPlateId)
	rs = wellStmt.executeQuery()
	try:
		while rs.next():
			row = rs.getString(2)                # plate row
			col = rs.getString(3)                # plate column
			virusId = rs.getString(4)            # virus ID
			virusName = rs.getString(5)          # virus name
			sourceContig = rs.getString(6)       # genomic contig of the target gene
			sourceStart = rs.getString(7)        # genomic start location of the target gene
			sourceEnd = rs.getString(8)          # genomic end location of the target gene
			sourceStrand = rs.getString(9)       # genomic strand of the target gene
			symbol = rs.getString(10)            # NCBI symbol of target gene
			prefName = rs.getString(11)          # preferred name of target gene
			targetSeq = rs.getString(12)         # hairpin sequence
			sourcePlateName = rs.getString(13)   # source plate name for this virus
			sourcePlateRow = rs.getString(14)    # source plate row
			sourcePlateCol = rs.getString(15)    # source plate column
			quad = rs.getString(16)              # quadrant of the source virus plate
			geneId = rs.getString(17)            # NCBI gene ID
			cloneId = rs.getString(18)           # TRC clone ID
			cloneName = rs.getString(19)         # TRC clone name
			taxon = rs.getString(20)             # target gene taxon
			vtype = rs.getInt(21)                # virus type (-2=pgw, -1=control, 0=EMPTY, 1=gene targeting hairpin)

			vPlate.addWell(row, col, virusId, virusName,
				sourceContig, sourceStart, sourceEnd, sourceStrand,
				symbol, prefName, targetSeq,
				sourcePlateName, sourcePlateRow, sourcePlateCol, quad,
				geneId, cloneId, cloneName, taxon, vtype, nWells)
			if cloneId not in cloneMap:
				cloneMap[cloneId] = PlateSet.CloneInfo(sourceContig, sourceStart, sourceEnd, sourceStrand,
					symbol, prefName, targetSeq, geneId, cloneName, taxon, vtype)
	finally:
		_closeQuietly(rs)


def _computeIeData(vpCounts, plateSet, vPlateData, loadingProgressBar):
	"""Build ieData by folding each virus plate's VirusPlateInfo.updateIeData over it.

	ieData holds how replicates of selection+ (puro+) and selection- (puro-) plates
	are grouped for the infection efficiency calculation.
	"""
	loadingProgressBar.setLabelText('Computing IE:')
	loadingProgressBar.setVisibility(1)
	step = _progressStep(len(vpCounts))
	pct = 0
	loadingProgressBar.setProgress(pct)
	ieData = {}
	for vpId in vpCounts.keys():
		ieData = vpCounts[vpId].updateIeData(ieData, plateSet, vPlateData)
		pct += step
		loadingProgressBar.setProgress(pct)
	return ieData


def _fillVpListBox(ieData, vPlateData):
	"""Add one row per (virus plate, condition) in ieData to the Infection Efficiency tab's list box.

	Each row is [virus plate name, screen, batch, condition, #puro+ plates,
	#puro- plates, '', '']. Returns (rows, vPlateIds, vpName2IDmap): the rows, a
	parallel list of virus plate IDs, and a virus plate name -> ID map.
	"""
	listBox = masterObject.getField('vpListPanel').getComp('vpListBox')
	rows = []
	vPlateIds = []
	vpName2IDmap = {}
	for vPlateId in ieData.keys():
		groups = ieData[vPlateId]
		vPlate = vPlateData.get(vPlateId)
		if vPlate is None:
			# virus plate was not found in the DB while loading; it has no name to list
			if groups:
				print('no virus plate data for %s; not listed' % vPlateId)
			continue
		for cond in groups.keys():
			group = groups[cond]
			vpName = vPlate.plateName
			vpName2IDmap[vpName] = vPlateId
			condition = cond.split(':')[2].replace('None', '')   # infection condition
			row = [vpName, group.screenName, group.batchName, condition,
				len(group.selPlusPlateList), len(group.selMinusPlateList), '', '']
			rows.append(row)
			vPlateIds.append(vPlateId)
			listBox.addRow(row, vPlateId)
	return rows, vPlateIds, vpName2IDmap


