package net.shrine.monitor

import edu.harvard.i2b2.crc.datavo.i2b2message.PasswordType
import edu.harvard.i2b2.crc.datavo.i2b2message.RequestMessageType
import edu.harvard.i2b2.crc.datavo.i2b2message.SecurityType
import groovy.sql.Sql
import groovy.util.slurpersupport.GPathResult
import net.shrine.monitor.scanner.FlatFileIterator
import net.shrine.serializers.crc.CRCHttpClient
import net.shrine.serializers.crc.CRCSerializer
import net.shrine.serializers.crc.QueryDefBuilder
import org.apache.log4j.Logger
import org.spin.tools.PKITool

/**
 * Heartbeat monitor: runs each configured query definition against the
 * configured node URL, analyses which nodes reported, errored or stayed
 * silent, keeps a per-node failure counter in the NODE table, records every
 * failure in HEARTBEAT_FAILURE / NODE_FAILURE and sends at most one alert
 * e-mail per day once a node's failure count reaches the configured threshold.
 *
 * @author Bill Simons
 * @date Dec 17, 2010
 * @link http://cbmi.med.harvard.edu
 * @link http://chip.org
 *       <p/>
 *       NOTICE: This software comes with NO guarantees whatsoever and is
 *       licensed as Lgpl Open Source
 * @link http://www.gnu.org/licenses/lgpl.html
 */
class Heartbeat
{
  static Logger logger = Logger.getLogger(Heartbeat.class);
  def config
  Sql sql;

  /**
   * @param config parsed configuration; this class reads failureThreshold,
   *               mail.*, i2b2.*, sheriff.queryTopicId, nodeUrl and
   *               queryDefinitions.file from it
   * @param sql    open connection to the monitor database
   */
  Heartbeat(config, Sql sql)
  {
    this.config = config
    this.sql = sql
    bootstrap()
  }

  /**
   * Obtains the PKITool singleton before any query is sent.
   *
   * @param configFileLocation unused; optional so the no-argument call made by
   *                           the constructor is explicit
   */
  def bootstrap(configFileLocation = null)
  {
    PKITool.getInstance()
  }

  /**
   * Executes one heartbeat cycle: query, analyse, update failure counters and,
   * when anything went wrong, audit and possibly e-mail an alert.
   * Exceptions are logged rather than propagated.
   */
  def run()
  {
    def alertMessage = ""
    boolean normalExecution = true
    def alertingNodes = []
    try
    {
      queryDefinitions().each { queryDefinition ->
        def response = queryShrine(queryDefinition)

        def report = generateReport(response)
        if(isFailure(report))
        {
          normalExecution = false
          alertMessage += generateAlertMessage(report)
          alertMessage += '\n\n'
          alertingNodes.addAll report.alertingNodes
        }
      }
    }
    catch (Exception e)
    {
      // Keep the stack trace in the log; the alert text only carries a summary.
      logger.error("Heartbeat query failed", e)
      normalExecution = false
      // getMessage() may be null; fall back to the exception's toString().
      alertMessage += (e.getMessage() ?: e.toString())
    }

    // The same node can fail in several query definitions; count and audit it once per run.
    alertingNodes.unique()

    try
    {
      updateDbNodeFailureCount alertingNodes
      if(!normalExecution)
      {
        handleAlert alertMessage, alertingNodes
      }
    }
    catch (Exception e)
    {
      logger.error("Failed to complete", e)
    }

  }

  /**
   * @return true when some node reached the failure threshold and no
   *         notification has been recorded yet today
   */
  boolean shouldAlert()
  {
    thresholdExceeded() && !emailSentToday()
  }

  /**
   * @return true when EMAIL_NOTIFICATION.LAST_NOTIFICATION is at or after
   *         midnight today; false when no notification time is stored
   */
  boolean emailSentToday()
  {
    def row = sql.firstRow('select LAST_NOTIFICATION from EMAIL_NOTIFICATION')
    def lastNotification = row?.last_notification
    if(lastNotification == null)
    {
      // Nothing recorded yet, so nothing can have been sent today.
      return false
    }

    def startOfToday = new Date()
    startOfToday.clearTime()

    !startOfToday.after(lastNotification)
  }

  /** Stamps EMAIL_NOTIFICATION.LAST_NOTIFICATION with the database's current timestamp. */
  def updateLastNotificationTime()
  {
    sql.executeUpdate "update EMAIL_NOTIFICATION set LAST_NOTIFICATION=CURRENT_TIMESTAMP()"
  }

  /**
   * @return true when at least one NODE row has FAILURE_COUNT greater than or
   *         equal to config.failureThreshold
   */
  boolean thresholdExceeded()
  {
    // The GString is passed to Sql as-is, so the threshold is sent as a bound parameter.
    sql.firstRow("select 1 from NODE where FAILURE_COUNT >= ${config.failureThreshold}") != null
  }

  /**
   * Increments FAILURE_COUNT for every alerting node and resets it to zero for
   * all other nodes; with no alerting nodes every counter is reset.
   *
   * @param alertingNodes names of the nodes that failed this run; may be null or empty
   */
  def updateDbNodeFailureCount(alertingNodes)
  {
    if(!alertingNodes)
    {
      sql.executeUpdate "update NODE set FAILURE_COUNT = 0"
    }
    else
    {
      // Bind the names as parameters instead of quoting them into the SQL text.
      List nodeNames = alertingNodes.collect { it.toString() }
      String placeholders = nodeNames.collect { '?' }.join(',')
      sql.withTransaction {
        sql.executeUpdate("update NODE set FAILURE_COUNT = FAILURE_COUNT + 1 where NAME in (${placeholders})".toString(), nodeNames)
        sql.executeUpdate("update NODE set FAILURE_COUNT = 0 where NAME not in (${placeholders})".toString(), nodeNames)
      }
    }
  }

  /**
   * @param report map produced by generateReport
   * @return true when the query did not complete, any node is alerting, or the
   *         number of node results differs from the number of NODE rows
   */
  boolean isFailure(report)
  {
    !report.queryCompleted || (report.alertingNodes.size() > 0) || (report.nodeResults.size() != expectedNumberOfNodes())
  }

  /** @return the number of rows in the NODE table */
  int expectedNumberOfNodes()
  {
    def row = sql.firstRow("select count(*) as NODE_NUMBER from NODE")
    row.node_number
  }

  /**
   * Summarises a parsed response.
   *
   * @param response parsed response as returned by parseResponse
   * @return map with keys queryCompleted, numberOfResults, nodeResults (list of
   *         [name, status, description] maps, excluding the "aggregated" entry),
   *         nonReportingNodes, errorNodes and alertingNodes (errorNodes
   *         followed by nonReportingNodes)
   */
  def generateReport(response)
  {
    def analysis = [:]
    analysis.queryCompleted = isQueryCompleted(response)
    analysis.numberOfResults = determineNumberOfResults(response)
    analysis.nodeResults = []
    response?.message_body?.response?.query_result_instance?.each {
      if("aggregated" != it.description.toString())
      {
        analysis.nodeResults << [name: it.description.toString(), status: it.query_status_type?.name.toString(), description: it.query_status_type?.description.toString()]
      }
    }

    analysis.nonReportingNodes = identifyNonReportingNodes(response)
    analysis.errorNodes = identifyErrorNodes(response)
    analysis.alertingNodes = []
    analysis.alertingNodes.addAll(analysis.errorNodes)
    analysis.alertingNodes.addAll(analysis.nonReportingNodes)

    analysis
  }

  /**
   * @param response parsed response; may be null
   * @return descriptions of all non-"aggregated" result instances whose status
   *         name is not "FINISHED"; empty when there is nothing to inspect
   */
  List identifyErrorNodes(response)
  {
    response?.message_body?.response?.query_result_instance?.findAll {
      it.description != "aggregated" && it.query_status_type?.name != "FINISHED"
    }?.collect {
      it.description.toString()
    } ?: []
  }

  /**
   * @param response parsed response; may be null
   * @return number of result instances other than the "aggregated" one; 0 for a null response
   */
  int determineNumberOfResults(response)
  {
    response?.message_body?.response?.query_result_instance?.findAll {
      it.description != "aggregated"
    }?.size() ?: 0
  }

  /**
   * Renders a report as text: a "reporting" section with one line per node
   * result, followed by a "did not report" section with one line per
   * non-reporting node.
   *
   * @param report map produced by generateReport
   * @return the alert text with the '|' margins stripped
   */
  String generateAlertMessage(report)
  {
    // Query the node count once instead of once per placeholder.
    int expectedNodes = expectedNumberOfNodes()
    // The closures write lines that start with '|' so stripMargin() cleans them
    // together with the literal lines of the template.
    """\
    |${report.numberOfResults} of ${expectedNodes} nodes reporting:
    ${ out ->
      report.nodeResults.each { result ->
        out << "|${result.name} reported status ${result.status}"
        if("ERROR" == result.status && result.description != null)
        {
          out << ": ${result.description}\n"
        }
        else
        {
          out << "\n"
        }
      }
    }
    |${expectedNodes - report.numberOfResults} of ${expectedNodes} nodes did not report:
    ${ out ->
      report.nonReportingNodes.each { result ->
        out << "|${result}"
        out << "\n"
      }
    }""".stripMargin().toString()
  }

  /**
   * Parses the response XML and masks the password in its message header.
   *
   * @param originalResponse response serialized as XML
   * @return the parsed GPathResult
   */
  def parseResponse(String originalResponse)
  {
    GPathResult responseGPath = new XmlSlurper().parseText(originalResponse)
    obscurePassword(responseGPath)
    responseGPath
  }

  /** Replaces message_header.security.password in the parsed response with '*****'. */
  def obscurePassword(GPathResult responseGPath)
  {
    responseGPath.message_header.security.password = '*****'
  }

  /**
   * Logs and audits the alert, then e-mails it when shouldAlert() allows and
   * records the notification time.
   *
   * @param alertMessage  text describing the failure
   * @param alertingNodes names of the failing nodes; null or empty yields a
   *                      "General" subject
   */
  def handleAlert(alertMessage, alertingNodes)
  {
    logger.warn alertMessage
    auditAlert alertMessage, alertingNodes

    if(shouldAlert())
    {
      // Groovy truth covers both null and empty, so join() is never called on null.
      String nodeSummary = alertingNodes ? alertingNodes.join(',') : "General"
      String subject = "SHRINE system alert - ${nodeSummary}"
      def ant = new AntBuilder()
      ant.mail(mailhost: config.mail.host, mailport: config.mail.port,
              subject: subject) {
        from(address: config.mail.from)
        to(address: config.mail.recipient)
        message("Network in non-functional state on ${new Date()}:\r\n${alertMessage}")
      }

      updateLastNotificationTime()
    }
  }

  /**
   * Stores the alert in HEARTBEAT_FAILURE and links it to each alerting node
   * through NODE_FAILURE, all within one transaction.
   *
   * @param alertMessage  text stored in HEARTBEAT_FAILURE.MESSAGE
   * @param alertingNodes names of the failing nodes; may be null or empty
   */
  def auditAlert(alertMessage, alertingNodes)
  {
    sql.withTransaction {
      def keys = sql.executeInsert("insert into HEARTBEAT_FAILURE (MESSAGE) values(${alertMessage})")
      def failureId = keys[0][0]
      alertingNodes?.each { nodeName ->
        def node = sql.firstRow("select ID from NODE where NAME=${nodeName}")
        if(node == null)
        {
          // A name from the response that has no NODE row must not abort the whole audit.
          logger.warn "Alerting node '${nodeName}' has no row in NODE; no NODE_FAILURE record written"
          return
        }
        sql.executeInsert("insert into NODE_FAILURE (FAILURE_ID, NODE_ID) values(${failureId}, ${node.id})")
      }
    }
  }

  /**
   * @param response parsed response; may be null
   * @return true only when the response status condition text is 'DONE' and
   *         the query instance status name is 'COMPLETED'
   */
  boolean isQueryCompleted(response)
  {
    def condition = response?.message_body?.response?.status?.condition?.text()
    if('DONE' != condition)
    {
      return false
    }
    def queryInstanceStatus = response?.message_body?.response?.query_instance?.query_status_type?.name?.text()

    'COMPLETED' == queryInstanceStatus
  }

  /**
   * Sends one query definition to config.nodeUrl using the configured i2b2
   * credentials and returns the parsed, password-masked response.
   *
   * @param queryDefinition query definition to run; its name is overwritten
   *                        with "Heartbeat " plus the current time in millis
   * @return parsed response as returned by parseResponse
   */
  def queryShrine(queryDefinition)
  {
    queryDefinition.setQueryName("Heartbeat " + System.currentTimeMillis());
    PasswordType pwd = new PasswordType()
    pwd.value = config.i2b2.password
    SecurityType security = new SecurityType(config.i2b2.domain, config.i2b2.username, pwd)
    RequestMessageType request = QueryDefBuilder.getRequestPSM(queryDefinition, security, 'SHRINE')
    request.getMessageHeader().setProjectId config.i2b2.projectId;
    request.getRequestHeader().setResultWaittimeMs 240000

    CRCSerializer.setQueryTopicId(request, config.sheriff.queryTopicId)

    def responseObject = CRCHttpClient.sendRequestToCRC(config.nodeUrl, request)

    String originalResponse = CRCSerializer.toXMLString(responseObject)
    parseResponse(originalResponse)
  }

  /** @return iterator over the query definitions stored in config.queryDefinitions.file */
  Iterator queryDefinitions()
  {
    new FlatFileIterator(new File(config.queryDefinitions.file))
  }

  /**
   * @param response parsed response; may be null
   * @return names of all NODE rows that have no non-"aggregated" result
   *         instance in the response; every NODE name when nothing reported
   */
  List identifyNonReportingNodes(response)
  {
    List reportingNodes = response?.message_body?.response?.query_result_instance?.description?.findAll {
      it != "aggregated"
    }?.collect { it.toString() } ?: []

    if(reportingNodes.isEmpty())
    {
      return sql.rows("select NAME from NODE").collect {
        it.NAME
      }
    }

    // Bind the reporting names as parameters instead of quoting them into the SQL text.
    String placeholders = reportingNodes.collect { '?' }.join(',')
    return sql.rows("select NAME from NODE where NAME not in (${placeholders})".toString(), reportingNodes).collect {
      it.NAME
    }
  }

  /**
   * Command-line entry point.
   *
   * @param args args[0] is the path of the configuration file
   * @throws IllegalArgumentException when no configuration file is given
   */
  public static void main(String[] args)
  {
    if(!args)
    {
      throw new IllegalArgumentException('Usage: Heartbeat <config file>')
    }
    // File.toURL() is deprecated because it does not escape special characters; go through toURI().
    def config = new ConfigSlurper().parse(new File(args[0]).toURI().toURL())
    def sql = Sql.newInstance(config.dbUrl, config.dbUser, config.dbPasswd ?: "", config.dbDriver)
    try
    {
      def heartbeat = new Heartbeat(config, sql)
      heartbeat.run()
    }
    finally
    {
      // Release the database connection even when the heartbeat fails.
      sql.close()
    }
    println 'Heartbeat finished'
  }
}