try:
  import numpy
  from pipeline_display import *
  from pipeline_product import *
  from numpy.polynomial import Polynomial
except ImportError:
  donothing=1


class PlotableReducedArc:
  """Image display of a reduced arc frame.

  The frame is wrapped in a PipelineProduct built from *fits* and drawn
  through an ImageDisplay.
  """

  def __init__(self, fits):
    self.arc = PipelineProduct(fits)
    self.arcdisp = ImageDisplay()
    self.loadFromFits()

  def loadFromFits(self):
    """Read the arc image from the underlying pipeline product."""
    self.arc.readImage()

  def plot(self, subplot, title, tooltip):
    """Draw the arc image on *subplot* with the given title and tooltip.

    Axis labels are pixel coordinates; the Z limits are set to (-100, 1000).
    """
    disp = self.arcdisp
    disp.setLabels('X [pix]', 'Y [pix]')
    disp.setZLimits((-100., 1000.))
    disp.display(subplot, title, tooltip, self.arc.image)
    
class PlotableFlat(object):
  """Image plot of a flat field, optionally overlaid with slit traces.

  The flat image is read from ``fits_flat``.  When ``fits_slittrace`` is not
  None, table extension 1 of that product is read as polynomial coefficient
  columns ``c0``, ``c1``, ...; rows come in pairs per slit, the first row of
  each pair giving the top trace and the second the bottom trace.
  """

  def __init__(self, fits_flat, fits_slittrace):
    self.flat      = PipelineProduct(fits_flat)
    self.slittrace = None
    if fits_slittrace is not None:
      self.slittrace = PipelineProduct(fits_slittrace)
    self.flatdisp  = ImageDisplay()
    self.loadFromFits()

  def loadFromFits(self):
    """Read the flat image and, if available, sample the slit traces.

    Sets ``nslits`` and the per-slit lists ``xpos_traces``,
    ``ypos_top_traces`` and ``ypos_bottom_traces``, each holding one point
    per image column.  Without a slit-trace product these stay empty and
    ``nslits`` is 0.
    """
    # Reading the flat image
    self.flat.readImage()

    # Defaults, so the attributes exist even when no trace table is given.
    self.nslits = 0
    self.xpos_traces = []
    self.ypos_top_traces = []
    self.ypos_bottom_traces = []
    if self.slittrace is None:
      return

    # Reading the polynomial traces: every column but one is a coefficient
    # column c0, c1, ... (so this is the number of coefficients per trace).
    ncoeffs = self.slittrace.getTableNcols(1) - 1
    # Two table rows (top, bottom) per slit.  Integer division is required:
    # on Python 3 '/' yields a float and range() below would raise TypeError.
    self.nslits = self.slittrace.getTableNrows(1) // 2
    coeff_columns = []
    for deg in range(ncoeffs):
      self.slittrace.readTableColumn(1, 'c%d' % deg)
      coeff_columns.append(self.slittrace.column)

    # Creating the points to plot based on the polynomial traces.  The
    # polynomials are evaluated at 0-based column indices and both x and y
    # are shifted by +1 (presumably to 1-based pixel coordinates — confirm
    # against ImageDisplay's axis extent).
    ncols = self.flat.image.shape[1]
    xpix = numpy.arange(ncols)
    xpos = list(range(1, ncols + 1))
    for slit in range(self.nslits):
      top_trace_pol = Polynomial([col[2 * slit] for col in coeff_columns])
      bottom_trace_pol = Polynomial([col[2 * slit + 1] for col in coeff_columns])
      # One vectorised evaluation per trace instead of one call per pixel.
      self.xpos_traces.append(list(xpos))
      self.ypos_top_traces.append((top_trace_pol(xpix) + 1).tolist())
      self.ypos_bottom_traces.append((bottom_trace_pol(xpix) + 1).tolist())

  def plot(self, subplot, title, tooltip):
    """Display the flat image and overlay the slit traces, if loaded.

    ``subplot`` is presumably a matplotlib Axes (it must provide
    ``autoscale`` and ``plot``).
    """
    self.flatdisp.setLabels('X [pix]', 'Y [pix]')
    self.flatdisp.display(subplot, title, tooltip, self.flat.image)

    if self.slittrace is not None:
      # Keep the view set by the image display; the trace lines must not
      # trigger rescaling of the axes.
      subplot.autoscale(enable=False)
      for xpos, ytop, ybottom in zip(self.xpos_traces,
                                     self.ypos_top_traces,
                                     self.ypos_bottom_traces):
        subplot.plot(xpos, ytop, linestyle='solid', color='red')
        subplot.plot(xpos, ybottom, linestyle='solid', color='darkred')

class PlotableNormFlat(PlotableFlat):
  """Flat-field plot variant for a normalised flat (values around 1)."""

  def __init__(self, fits_flat, fits_slittrace):
    super(PlotableNormFlat, self).__init__(fits_flat, fits_slittrace)

  def plot(self, subplot, title, tooltip):
    """Plot the normalised flat with Z limits (0.9, 1.1).

    Pixels with values above 5 are set to 0 before plotting.
    NOTE(review): this writes into ``self.flat.image`` in place, so the
    loaded image data is modified by every call.
    """
    self.flatdisp.setZLimits((0.9, 1.1))
    img = self.flat.image
    img[img > 5.] = 0
    super(PlotableNormFlat, self).plot(subplot, title, tooltip)

class PlotableSpatialMap:
  """Image display of a spatial map product."""

  def __init__(self, fits_spatialmap):
    self.spatialmap = PipelineProduct(fits_spatialmap)
    self.spatialmapdisp = ImageDisplay()
    self.loadFromFits()

  def loadFromFits(self):
    """Read the spatial map image from the pipeline product."""
    self.spatialmap.readImage()

  def plot(self, subplot, title, tooltip):
    """Draw the spatial map on *subplot* with Z limits (0, 100)."""
    disp = self.spatialmapdisp
    disp.setLabels('X', 'Y')
    disp.setZLimits((0., 100))
    disp.display(subplot, title, tooltip, self.spatialmap.image)

class PlotableMappedScience:
  """Image display of a mapped science frame."""

  def __init__(self, fits_mappedscience):
    self.mappedscience = PipelineProduct(fits_mappedscience)
    self.mappedsciencedisp = ImageDisplay()
    self.loadFromFits()

  def loadFromFits(self):
    """Read the mapped science image from the pipeline product."""
    self.mappedscience.readImage()

  def plot(self, subplot, title, tooltip):
    """Draw the mapped science image on *subplot* with Z limits (0, 0.9)."""
    disp = self.mappedsciencedisp
    disp.setLabels('X', 'Y')
    disp.setZLimits((0., 0.9))
    disp.display(subplot, title, tooltip, self.mappedscience.image)

class PlotableDispResiduals:
  """Scatter plots of dispersion-solution residuals from a table product.

  Table extension 1 is read for a ``wavelength`` column and residual
  columns named ``r0``, ``r10``, ``r20``, ...
  """

  def __init__(self, fits_dispresiduals):
    self.dispresiduals = PipelineProduct(fits_dispresiduals)
    self.resdisplay = ScatterDisplay()
    self.loadFromFits()

  def loadFromFits(self):
    """Read the wavelengths and the residual column of every sampled row.

    Fills ``residuals`` (one column per sampled row) and the flattened
    lists ``allwave``, ``allresiduals`` and ``allypos``, which concatenate
    the data of all sampled rows for the scatter plots.
    """
    table = self.dispresiduals
    table.readTableColumn(1, 'wavelength')
    self.wave = table.column
    nwave = table.getTableNrows(1)
    ncolumns = table.getTableNcols(1)
    # Besides 'wavelength', columns are assumed to come in groups of three
    # per sampled row (TODO confirm the exact column layout).
    nselectedrows = (ncolumns - 1) // 3
    # TODO: residuals are currently computed every 10 rows; this step is
    # hard-coded in the pipeline. It would be better just to detect the
    # columns whose name starts with 'r'.
    row_step = 10
    self.residuals = []
    self.allwave = []
    self.allypos = []
    self.allresiduals = []
    for row in range(0, nselectedrows * row_step, row_step):
      table.readTableColumn(1, 'r%d' % row)
      residual_column = table.column
      self.residuals.append(residual_column)
      self.allwave.extend(self.wave)
      self.allresiduals.extend(residual_column)
      self.allypos.extend([float(row)] * nwave)

  def plotResVsWave(self, subplot, title, tooltip):
    """Scatter-plot all residuals against wavelength."""
    disp = self.resdisplay
    disp.setLabels('Wavelength [Ang]', 'Residual [pix]')
    disp.display(subplot, title, tooltip, self.allwave, self.allresiduals)

  def plotResVsY(self, subplot, title, tooltip):
    """Scatter-plot all residuals against the row (Y) they were sampled at."""
    disp = self.resdisplay
    disp.setLabels('Ypos [pix]', 'Residual [pix]')
    disp.display(subplot, title, tooltip, self.allypos, self.allresiduals)

class PlotableDetectedLines:
  """Scatter plots of the detected arc lines stored in a table product."""

  def __init__(self, fits_detectedlines):
    self.detectedlines = PipelineProduct(fits_detectedlines)
    self.xydisplay = ScatterDisplay()
    self.resdisplay = ScatterDisplay()
    self.loadFromFits()

  def loadFromFits(self):
    """Read line positions, identified wavelengths and X residuals.

    The rectified positions (``xpos_rectified``/``ypos_rectified``) are
    preferred; if reading them raises KeyError (presumably because the
    columns are absent), ``xpos``/``ypos`` are read instead.
    """
    table = self.detectedlines

    def read(name):
      table.readTableColumn(1, name)
      return table.column

    try:
      self.x_pix = read('xpos_rectified')
      self.y_pix = read('ypos_rectified')
    except KeyError:
      self.x_pix = read('xpos')
      self.y_pix = read('ypos')

    self.wave = read('wave_ident')
    self.res_xpos = read('res_xpos')

  def plotXVsY(self, subplot, title, tooltip):
    """Plot all detected lines in black, then the identified ones in green.

    A line counts as identified when its ``wave`` value is finite.
    """
    disp = self.xydisplay
    disp.setLabels('Xpos [pix]', 'Ypos [pix]')
    disp.setColor('black')
    disp.display(subplot, title, tooltip, self.x_pix, self.y_pix)
    disp.setColor('green')
    identified = numpy.isfinite(self.wave)
    disp.display(subplot, title, tooltip,
                 self.x_pix[identified], self.y_pix[identified])

  def plotResVsWave(self, subplot, title, tooltip):
    """Plot X residuals against wavelength, skipping non-finite residuals."""
    disp = self.resdisplay
    disp.setLabels('Wavelength [Ang]', 'Residual [pix]')
    disp.setColor('black')
    has_residual = numpy.isfinite(self.res_xpos)
    disp.display(subplot, title, tooltip,
                 self.wave[has_residual], self.res_xpos[has_residual])


