package net.shrine.adapter.query;

import edu.harvard.i2b2.crc.datavo.i2b2message.RequestMessageType;
import edu.harvard.i2b2.crc.datavo.i2b2message.ResponseMessageType;
import edu.harvard.i2b2.crc.datavo.setfinder.query.StatusType;
import net.shrine.adapter.dao.AdapterDAO;
import net.shrine.dao.DAOException;
import net.shrine.filters.LogFilter;
import net.shrine.protocol.*;
import net.shrine.serializers.QueryResultIDList;
import net.shrine.serializers.ShrineHeader;
import net.shrine.serializers.ShrineJAXBUtils;
import net.shrine.serializers.ShrineMessage;
import net.shrine.serializers.ShrineMessageConstants;
import net.shrine.serializers.crc.CRCHttpClient;
import net.shrine.serializers.crc.CRCRequestType;
import net.shrine.serializers.crc.CRCSerializer;
import net.shrine.serializers.hive.HiveCommonSerializer;
import net.shrine.translators.Translator;
import net.shrine.translators.Translators;
import org.apache.log4j.Logger;
import org.apache.log4j.MDC;
import org.spin.node.QueryContext;
import org.spin.node.actions.AbstractQueryAction;
import org.spin.node.actions.QueryException;
import org.spin.query.message.serializer.SerializationException;
import org.spin.tools.NetworkTime;
import org.spin.tools.crypto.signature.Identity;
import org.springframework.transaction.annotation.Transactional;

import javax.xml.bind.JAXBException;
import java.util.List;

import static org.spin.tools.Util.guardNotNull;

/**
 * Default implementation of SPIN/SHRINE query plumbing.
 * <p>
 * {@link #perform} drives one query through a fixed pipeline of overridable hooks:
 * <ol>
 *   <li>{@link #beforeQuery} - pre-query validation (no-op by default);</li>
 *   <li>{@link #translate} - applies the request translator and replaces network query IDs with local IDs;</li>
 *   <li>{@link #callCRC} - sends the translated request to the CRC at {@code crcEndpointURL};</li>
 *   <li>{@link #afterQuery} - applies the response translator and maps local IDs back to network IDs.</li>
 * </ol>
 * The resulting response (or an i2b2 error response, if any step throws) is returned serialized as XML.
 */
@Transactional
public class AdapterQuery extends AbstractQueryAction<ShrineMessage<RequestMessageType>> {
    private static final Logger log = Logger.getLogger(AdapterQuery.class);

    /** URL of the CRC endpoint that translated requests are sent to. */
    final String crcEndpointURL;

    /** Applied to (a copy of) each incoming request before it is sent to the CRC. */
    final Translator<RequestMessageType> requestTranslator;

    /** Applied to each CRC response before local IDs are mapped back to network IDs. */
    final Translator<ResponseMessageType> responseTranslator;

    /** Used to look up local IDs for network IDs (and, via LocalToNetworkIDMapper, the reverse). */
    final AdapterDAO dao;

    /**
     * Creates a query action that performs no request/response translation.
     *
     * @param crcEndpointURL URL of the CRC to call; must not be null
     * @param dao            ID-mapping DAO; must not be null
     */
    protected AdapterQuery(final String crcEndpointURL, final AdapterDAO dao) {
        this(crcEndpointURL, dao, Translators.<RequestMessageType>nullTranslator(), Translators.<ResponseMessageType>nullTranslator());
    }

    /**
     * Creates a query action with explicit request and response translators.
     *
     * @param crcEndpointURL     URL of the CRC to call; must not be null
     * @param dao                ID-mapping DAO; must not be null
     * @param requestTranslator  translator applied to outgoing requests; must not be null
     * @param responseTranslator translator applied to CRC responses; must not be null
     */
    protected AdapterQuery(final String crcEndpointURL, final AdapterDAO dao, final Translator<RequestMessageType> requestTranslator, final Translator<ResponseMessageType> responseTranslator) {
        super();

        guardNotNull(crcEndpointURL);
        guardNotNull(requestTranslator);
        guardNotNull(responseTranslator);
        guardNotNull(dao);

        this.crcEndpointURL = crcEndpointURL;
        this.requestTranslator = requestTranslator;
        this.responseTranslator = responseTranslator;
        this.dao = dao;
    }

    /**
     * Pre-query validation hook; a no-op by default. Throwing aborts the query and makes
     * {@link #perform} return an error response.
     *
     * @param queryContext the SPIN query context
     * @param request      the original (untranslated) request
     * @throws Exception to reject the query
     */
    protected void beforeQuery(final QueryContext queryContext, final ShrineMessage<RequestMessageType> request) throws Exception {
        //NOOP by default
    }

    /**
     * Translates the query from network form to local form: applies the request translator,
     * then replaces network query IDs with their local equivalents.
     * <p>
     * The translator's return value is ignored, so the request is presumably modified in place;
     * {@link #perform} therefore passes a copy.
     *
     * @param request the request to translate
     * @return the translated request, suitable to be sent to the CRC
     * @throws Exception if translation or ID lookup fails
     */
    protected RequestMessageType translate(final RequestMessageType request) throws Exception {
        requestTranslator.translate(request);

        //Network IDs => Local IDs
        mapNetworkIDsToLocalIDs(request);

        return request;
    }

    /**
     * Looks up the local master, instance and result IDs that correspond to the network IDs in
     * {@code request} (each may be null when the request type does not carry that ID) and hands
     * them to {@link LocalIDInserter}.
     */
    void mapNetworkIDsToLocalIDs(final RequestMessageType request) throws Exception {
        final String localQueryMasterID = getLocalMasterID(request);

        final String localQueryInstanceID = getLocalInstanceID(request);

        final String localResultID = getLocalResultIDs(request);

        LocalIDInserter.forRequest(request).addIDs(localQueryMasterID, localQueryInstanceID, localResultID);
    }

    /**
     * @return the local query instance ID for an instance request, or null if the request is
     *         not an instance request or carries no instance ID
     */
    private String getLocalInstanceID(final RequestMessageType request) throws DAOException, SerializationException {
        guardNotNull(request);

        if(CRCSerializer.getRequestType(request) != CRCRequestType.InstanceRequestType) {
            return null;
        }

        final Long networkInstanceID = stringToLong(CRCSerializer.getInstanceRequest(request).getQueryInstanceId());

        if(networkInstanceID == null) {
            log.warn("Instance ID from request is null, using null for local instance ID");

            return null;
        }

        return dao.findLocalInstanceID(networkInstanceID);
    }

    /**
     * Parses a decimal string into a {@link Long}.
     *
     * @param s the string to parse; may be null
     * @return null if {@code s} is null, otherwise the parsed value
     * @throws NumberFormatException if {@code s} is not a valid long
     */
    protected static final Long stringToLong(final String s) {
        if(s == null) {
            return null;
        }
        return Long.valueOf(s);
    }

    /**
     * @return the local query master ID for a master, master-delete or master-rename request,
     *         or null if the request is of another type or carries no master ID
     */
    private String getLocalMasterID(final RequestMessageType request) throws DAOException, SerializationException {
        guardNotNull(request);

        final CRCRequestType requestType = CRCSerializer.getRequestType(request);

        final String networkMasterIDString;

        switch(requestType) {
            case MasterRequestType:
                networkMasterIDString = CRCSerializer.getMasterRequest(request).getQueryMasterId();
                break;
            case MasterDeleteRequestType:
                networkMasterIDString = CRCSerializer.getMasterDeleteRequest(request).getQueryMasterId();
                break;
            case MasterRenameRequestType:
                networkMasterIDString = CRCSerializer.getMasterRenameRequest(request).getQueryMasterId();
                break;
            default:
                // Not a master-scoped request: there is no master ID to map, so don't warn about it
                // (consistent with getLocalInstanceID / getLocalResultIDs).
                return null;
        }

        final Long networkMasterID = stringToLong(networkMasterIDString);

        if(networkMasterID == null) {
            log.warn("Master ID from request is null, using null for local master ID");
            return null;
        }

        return dao.findLocalMasterID(networkMasterID);
    }

    /**
     * @return the local query result ID for a result request, or null if the request is not a
     *         result request or carries no result ID
     */
    private String getLocalResultIDs(final RequestMessageType request) throws DAOException, SerializationException {
        guardNotNull(request);

        if(CRCSerializer.getRequestType(request) != CRCRequestType.ResultRequestType) {
            return null;
        }

        final Long networkResultID = stringToLong(CRCSerializer.getResultRequest(request).getQueryResultInstanceId());

        if(networkResultID == null) {
            log.warn("Result ID from request is null, using null for local result ID");

            return null;
        }

        return dao.findLocalResultID(networkResultID);
    }

    /**
     * Makes the call to the local CRC.
     *
     * @param identity the caller's identity (unused by the default implementation)
     * @param request  the translated request
     * @return the CRC response
     * @throws Exception if the CRC call fails
     */
    protected ResponseMessageType callCRC(final Identity identity, final RequestMessageType request) throws Exception {
        // Serializing the whole request is only needed for the debug message, so skip it
        // (and any failure it might raise) when debug logging is disabled.
        if(log.isDebugEnabled()) {
            log.debug(String.format("Calling CRC with request: %s", CRCSerializer.toXMLString(request)));
        }

        return CRCHttpClient.sendRequestToCRC(crcEndpointURL, request);
    }

    /**
     * Post-query hook: applies the response translator and maps local IDs in the response
     * back to network IDs.
     *
     * @param queryContext  the SPIN query context
     * @param shrineMessage the original (untranslated) request message
     * @param response      the CRC response
     * @param duration      CRC call duration in milliseconds (unused by the default implementation)
     * @return the response to send back over the network
     * @throws Exception if translation or ID mapping fails
     */
    protected ResponseMessageType afterQuery(final QueryContext queryContext, final ShrineMessage<RequestMessageType> shrineMessage, final ResponseMessageType response, final long duration) throws Exception {
        responseTranslator.translate(response);

        mapLocalIDsToNetworkIDs(shrineMessage, response);

        return response;
    }

    /**
     * Runs the full query pipeline for {@code shrineMessage} and returns the serialized
     * response. Failures are never propagated; they are logged and turned into an i2b2 error
     * response (or, as a last resort, a plain "Query Failed" string).
     * <p>
     * NOTE(review): because every exception is caught here, {@code rollbackFor = Exception.class}
     * never actually triggers a rollback; any DAO writes made before a failure are committed.
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public final String perform(final QueryContext queryContext, final ShrineMessage<RequestMessageType> shrineMessage) throws QueryException {
        // Remember the request id the current thread was already tagged with (if any), so it
        // can be restored afterwards instead of leaking this request's id onto whatever the
        // (possibly pooled) thread processes next.
        final Object previousRequestId = MDC.get(LogFilter.GRID);

        try {
            // read in and set the requestId for logging
            readAndSetRequestId(shrineMessage);

            return performQuery(queryContext, shrineMessage);
        }
        finally {
            restoreRequestId(previousRequestId);
        }
    }

    /** Executes the hook pipeline for {@link #perform}, converting any failure into an error response. */
    private String performQuery(final QueryContext queryContext, final ShrineMessage<RequestMessageType> shrineMessage) throws QueryException {
        final RequestMessageType request = shrineMessage.getPayload();

        try {
            beforeQuery(queryContext, shrineMessage);

            // create a copy of the request, to retain the original within shrineMessage
            final RequestMessageType copy = ShrineJAXBUtils.copy(request);

            final RequestMessageType translated = translate(copy);

            final long start = new NetworkTime().toMilliSeconds();
            final ResponseMessageType crcResponse = callCRC(queryContext.getQueryInfo().getIdentity(), translated);
            final long duration = new NetworkTime().toMilliSeconds() - start;

            final ResponseMessageType shrineResponse = afterQuery(queryContext, shrineMessage, crcResponse, duration);

            return HiveCommonSerializer.toXMLString(shrineResponse);
        }
        catch(final Exception e) {
            try {
                log.error("Query Failed", e);

                final ResponseMessageType error = HiveCommonSerializer.getTemplateResponseMessageTypeError(request, e.getMessage());

                return HiveCommonSerializer.toXMLString(error);
            }
            catch(final Exception ee) {
                log.error("Failed to build error response", ee);

                return "Query Failed: " + e.getMessage();
            }
        }
    }

    /** Puts the message's REQUEST_ID header (if present) into the log4j MDC under {@code LogFilter.GRID}. */
    private void readAndSetRequestId(ShrineMessage<RequestMessageType> msg) {
        Object rid = msg.getHeader().getHeader(ShrineMessageConstants.REQUEST_ID);
        if(rid != null) {
            MDC.put(LogFilter.GRID, rid.toString());
        }
    }

    /** Restores the MDC request-id entry to the value it had before {@link #perform} ran. */
    private static void restoreRequestId(final Object previousRequestId) {
        if(previousRequestId == null) {
            MDC.remove(LogFilter.GRID);
        }
        else {
            MDC.put(LogFilter.GRID, previousRequestId);
        }
    }

    /** Rewrites local IDs in {@code crcResponse} to network IDs, using {@link LocalToNetworkIDMapper}. */
    void mapLocalIDsToNetworkIDs(final ShrineMessage<RequestMessageType> shrineMessage, final ResponseMessageType crcResponse) throws Exception {
        LocalToNetworkIDMapper.forResponse(shrineMessage, crcResponse, dao).map();
    }

    /**
     * Parses a {@code BroadcastMessage} XML document into a {@link ShrineMessage} whose header
     * carries the request ID, the optional master/instance IDs (null when absent) and the
     * result-ID list, and whose payload is the i2b2 request.
     *
     * @param xml the broadcast message XML
     * @return the unmarshalled message
     * @throws SerializationException if the embedded i2b2 request cannot be unmarshalled
     */
    @Override
    public ShrineMessage<RequestMessageType> unmarshal(final String xml) throws SerializationException {
        try {

            BroadcastMessage message = (BroadcastMessage) BroadcastMessage$.MODULE$.fromXml(xml);

            RequestMessageType request = HiveCommonSerializer.getRequest(message.request().toI2b2().toString());
            ShrineHeader header = new ShrineHeader();
            header.putHeader(ShrineMessageConstants.REQUEST_ID, message.requestId());
            header.putHeader(ShrineMessageConstants.QUERY_MASTER_ID, message.masterId().isDefined() ? message.masterId().get() : null);
            header.putHeader(ShrineMessageConstants.QUERY_INSTANCE_ID, message.instanceId().isDefined() ? message.instanceId().get() : null);
            QueryResultIDList resultIds = new QueryResultIDList(message.getResultIdsJava());
            header.putHeader(ShrineMessageConstants.QUERY_RESULT_IDS, resultIds);

            ShrineMessage<RequestMessageType> shrineMessage = new ShrineMessage<RequestMessageType>(header, request);

            return shrineMessage;
        }
        catch(final JAXBException e) {
            throw new SerializationException(e);
        }
    }

    /**
     * Builds a status containing a single condition whose type and value are both "DONE".
     * <p>
     * TODO: factor out this code and the version from AggregationStrategy into a common utils class.
     * I duplicated this method because I'm in a hurry, and this module doesn't have access to the
     * BroadcasterAggregator module where AggregationStrategy lives.  Sorry. :( -Clint
     */
    protected static StatusType makeDoneStatus() {
        final StatusType status = new StatusType();

        final StatusType.Condition done = new StatusType.Condition();

        done.setType("DONE");
        done.setValue("DONE");

        status.getCondition().add(done);

        return status;
    }
}
