/*
 * DomSearch.c -- select closest homologs of a sequence from a
 * multiple sequence alignment.
 *
 * Morgan N. Price, 2007-2008
 *
 *  Copyright (C) 2007-2008 The Regents of the University of California
 *  All rights reserved.
 * 
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with this program; if not, write to the Free Software Foundation, Inc.,
 *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 *  Disclaimer
 *
 *  NEITHER THE UNITED STATES NOR THE UNITED STATES DEPARTMENT OF ENERGY,
 *  NOR ANY OF THEIR EMPLOYEES, MAKES ANY WARRANTY, EXPRESS OR IMPLIED,
 *  OR ASSUMES ANY LEGAL LIABILITY OR RESPONSIBILITY FOR THE ACCURACY,
 *  COMPLETENESS, OR USEFULNESS OF ANY INFORMATION, APPARATUS, PRODUCT,
 *  OR PROCESS DISCLOSED, OR REPRESENTS THAT ITS USE WOULD NOT INFRINGE
 *  PRIVATELY OWNED RIGHTS.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <stdbool.h>
#include "Hash.h"

/* Command-line help text, printed when the argument count is wrong.
 * Every help line ends in '\n' so the sentences do not run together.
 * Keep this text free of '%' in case it is handed to a printf-style
 * function as the format string.
 */
static const char *usage =  "Usage: DomSearch locusId AlignedFile Threshold NLimit\n"
"   Given an alignment and a locusId, DomSearch finds the NLimit closest homologs\n"
"   of the locusId that have a bit score of Threshold or greater\n"
"   (Use NLimit 0 to have no limit, and a locusId of '-all' to do this for each locus)\n"
"   AlignedFile should be tab-delimited with 3 or more fields,\n"
"   domainId (ignored), locusId, alignment (upper-case amino acids with gaps)\n"
"   DomSearch uses the BLOSUM62 matrix with affine gap penalties (11 existence 1 extend)\n"
"   and reports scores (both bit scores and raw scores) based on the best local region\n"
"   Input sequences are padded with gaps at the end if the lengths vary\n";

/* Size in bytes of the buffer that holds one input line (fgets stores at
 * most BUFSIZE-1 characters plus the terminating NUL).
 * The expansion is parenthesized so that BUFSIZE behaves as a single value
 * inside larger expressions (e.g. x / BUFSIZE).
 */
#define BUFSIZE (1000*1000)
static char buf[BUFSIZE];

// One record per input line. Each input line has a domainId (ignored),
// a locusId, and an aligned subsequence; only the last two are kept here.
// begin/end are column indices into domseq and are computed per entry.
typedef struct {
  char *gene;                   /* heap copy (strdup) of the locusId column; freed at the end of main */
  unsigned char *domseq;        /* alignment row encoded via legalChars: one byte per column holding the
                                   residue's position in legalChars (gapChar for '-', unknownChar for
                                   anything not in legalChars); not NUL-terminated */
  int begin;			/* index of first non-gap in domseq (last column if the row is all gaps) */
  int end;			/* index of last non-gap (0 if the row is all gaps) */
} domEntry_t;

/* qsort() comparator: a and b point to int indices into the global scores
 * array; orders the indices by descending score.
 */
int CompareScores(const void *a, const void *b);

/* Score every entry of domEntries[0..nDom-1] against the entries whose gene
 * equals anchorId and print the resulting hits to stdout. When showAnchorFlag
 * is non-zero each output line is prefixed with anchorId.
 * (Parameter names match the definition.)
 */
void AnalyzeAnchor(domEntry_t *domEntries, int nDom,
		   char *anchorId, int showAnchorFlag);

/* Alphabet of the alignment. A character's position in this string is its
 * encoded value; characters not listed here are replaced with X.
 * The string is a literal and must never be written through, hence const.
 */
static const char *legalChars = "ARNDCQEGHILKMFPSTWYVBZX*-";
#define unknownChar 22 /* position of 'X' in legalChars */
#define gapChar 24     /* position of '-' in legalChars */

/* From BLAST's BLOSUM62 -- everything is in "half-bit" (raw score) units */
int matrix[24][24] = {
  {4, -1, -2, -2, 0, -1, -1, 0, -2, -1, -1, -1, -1, -2, -1, 1, 0, -3, -2, 0, -2, -1, 0, -4},
  {-1, 5, 0, -2, -3, 1, 0, -2, 0, -3, -2, 2, -1, -3, -2, -1, -1, -3, -2, -3, -1, 0, -1, -4},
  {-2, 0, 6, 1, -3, 0, 0, 0, 1, -3, -3, 0, -2, -3, -2, 1, 0, -4, -2, -3, 3, 0, -1, -4},
  {-2, -2, 1, 6, -3, 0, 2, -1, -1, -3, -4, -1, -3, -3, -1, 0, -1, -4, -3, -3, 4, 1, -1, -4},
  {0, -3, -3, -3, 9, -3, -4, -3, -3, -1, -1, -3, -1, -2, -3, -1, -1, -2, -2, -1, -3, -3, -2, -4},
  {-1, 1, 0, 0, -3, 5, 2, -2, 0, -3, -2, 1, 0, -3, -1, 0, -1, -2, -1, -2, 0, 3, -1, -4},
  {-1, 0, 0, 2, -4, 2, 5, -2, 0, -3, -3, 1, -2, -3, -1, 0, -1, -3, -2, -2, 1, 4, -1, -4},
  {0, -2, 0, -1, -3, -2, -2, 6, -2, -4, -4, -2, -3, -3, -2, 0, -2, -2, -3, -3, -1, -2, -1, -4},
  {-2, 0, 1, -1, -3, 0, 0, -2, 8, -3, -3, -1, -2, -1, -2, -1, -2, -2, 2, -3, 0, 0, -1, -4},
  {-1, -3, -3, -3, -1, -3, -3, -4, -3, 4, 2, -3, 1, 0, -3, -2, -1, -3, -1, 3, -3, -3, -1, -4},
  {-1, -2, -3, -4, -1, -2, -3, -4, -3, 2, 4, -2, 2, 0, -3, -2, -1, -2, -1, 1, -4, -3, -1, -4},
  {-1, 2, 0, -1, -3, 1, 1, -2, -1, -3, -2, 5, -1, -3, -1, 0, -1, -3, -2, -2, 0, 1, -1, -4},
  {-1, -1, -2, -3, -1, 0, -2, -3, -2, 1, 2, -1, 5, 0, -2, -1, -1, -1, -1, 1, -3, -1, -1, -4},
  {-2, -3, -3, -3, -2, -3, -3, -3, -1, 0, 0, -3, 0, 6, -4, -2, -2, 1, 3, -1, -3, -3, -1, -4},
  {-1, -2, -2, -1, -3, -1, -1, -2, -2, -3, -3, -1, -2, -4, 7, -1, -1, -4, -3, -2, -2, -1, -2, -4},
  {1, -1, 1, 0, -1, 0, 0, 0, -1, -2, -2, 0, -1, -2, -1, 4, 1, -3, -2, -2, 0, 0, 0, -4},
  {0, -1, 0, -1, -1, -1, -1, -2, -2, -1, -1, -1, -1, -2, -1, 1, 5, -2, -2, 0, -1, -1, 0, -4},
  {-3, -3, -4, -4, -2, -2, -3, -2, -2, -3, -2, -3, -1, 1, -4, -3, -2, 11, 2, -3, -4, -3, -2, -4},
  {-2, -2, -2, -3, -2, -1, -2, -3, 2, -1, -1, -2, -1, 3, -3, -2, -2, 2, 7, -1, -3, -2, -1, -4},
  {0, -3, -3, -3, -1, -2, -2, -3, -3, 3, 1, -2, 1, -1, -2, -2, 0, -3, -1, 4, -3, -2, -1, -4},
  {-2, -1, 3, 4, -3, 0, 1, -1, 0, -3, -4, 0, -3, -3, -2, 0, -1, -4, -3, -3, 4, 1, -1, -4},
  {-1, 0, 0, 1, -3, 3, 4, -2, 0, -3, -3, 1, -1, -3, -1, 0, -1, -3, -2, -2, 1, 4, -1, -4},
  {0, -1, -1, -1, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, 0, 0, -2, -1, -1, -1, -1, -1, -4},
  {-4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, -4, 1}
};

/* Gap penalties. Because we have trimmed alignments these won't behave
   quite right, but I think it is better than nothing.
   Also note that we ignore gaps at the ends, but we force the
   non-gap ends to match (as they match the domain, they should
   align to each other). This may not be optimal.
*/
static int gapExtend = -1; /* raw score units; charged for each further column of an open gap */
static int gapOpen = -12;  /* full cost of 1-long gap: 11 (existence) + 1 (extend), in raw score units */

/* raw to bit scores (Karlin/Altschul parameters):
 *   bits = (raw * lambda - log(K)) / log(2)
 * logK and log2 hold natural logarithms (of K = 0.0410 and of 2).
 */
static double lambda = 0.267;
/*static double K = 0.0410;*/
static double logK = -3.194183;
static double log2 = 0.6931472;

double thresholdBits = 0; /* minimum bit score to report; set from the Threshold argument */
int thresholdRaw = 0;     /* thresholdBits converted to raw score units (truncated to int) */
int nLimit; /* maximum number of hits to report; values <= 0 mean no limit */

int main(int argc, char *argv[]) {
  if (argc != 5) {
    fprintf(stderr, usage);
    exit(1);
  }

  char *anchorId = argv[1];
  if (strcmp(argv[1],"-all") == 0) {
    anchorId = NULL;
  }

  char *domSeqFile = argv[2];
  FILE *fpDom = fopen(domSeqFile,"rb");
  if (fpDom==NULL) {
    fprintf(stderr,"Cannot read from %s\n",domSeqFile);
    exit(1);
  }

  thresholdBits = atof(argv[3]);
  /* bits = raw * lambda/log(2) - log(K)/log(2)
   * raw = ( bits + log(K)/log(2) ) / ( lambda/log(2) )
   */
  thresholdRaw = (int) ( (thresholdBits + logK/log2)/(lambda/log2) );
  nLimit = atoi(argv[4]);

  int nDom = 0; /* position in the loci/seqs/score list */
  int i; /* general counter */

  char encoding[256];
  for (i = 0; i < 256; i++) { encoding[i] = unknownChar; }
  for (i = 0; legalChars[i] != '\0'; i++) {
    encoding[(unsigned int)legalChars[i]] = i;
  }

  int maxEntries = 50000;	/* increase as needed */
  domEntry_t *domEntries = (domEntry_t*)mymalloc(sizeof(domEntry_t)*maxEntries);

  while(fgets(buf, BUFSIZE, fpDom)) {
    char *domain = strtok(buf,"\t\r\n");
    char *locusStr = strtok(NULL,"\t\r\n");
    char *domseq = strtok(NULL,"\t\r\n");
    if (domseq == NULL || domseq[0] == '\0') {
      fprintf(stderr, "Error parsing input for %s %s -- no alignment column\n", domain, locusStr);
      exit(1);
    }

    /* encode sequence */
    int domLength = strlen(domseq);
    unsigned char *domEncoded = (unsigned char *)mymalloc(domLength);

    for (i = 0; i < domLength; i++) {
      if ((domseq[i] >= '0' && domseq[i] <= '9') || domseq[i] == ' ') {
	fprintf(stderr, "Illegal character %c in alignment %s\n",
		domseq[i], domseq);
	exit(1);
      }
      domEncoded[i] = encoding[(unsigned int) domseq[i]];
    }

    /* store begin & end */
    int begin = 0;
    while (domEncoded[begin] == gapChar && begin < domLength-1) {
      begin++;
    }

    int end = domLength - 1;
    while (domEncoded[end] == gapChar && end > 0) {
      end--;
    }

    /* store entry in domEntries */
    if (nDom>= maxEntries) {
      maxEntries *= 2;
      domEntries = (domEntry_t*)realloc(domEntries, sizeof(domEntry_t)*maxEntries);
      assert(domEntries != NULL);
    }
    domEntry_t *domEntry = &domEntries[nDom++];
    domEntry->gene = strdup(locusStr);
    assert(domEntry->gene);
    domEntry->domseq = domEncoded;
    domEntry->begin = begin;
    domEntry->end = end;
  }
  if (ferror(fpDom)) {
    fprintf(stderr,"Error reading from %s\n",domSeqFile);
    exit(1);
  }
  fclose(fpDom);
  fpDom = NULL;

  if (anchorId != NULL) {
    AnalyzeAnchor(domEntries, nDom, anchorId, /*showAnchorFlag*/0);
  } else {
    int iDom;
    for (iDom = 0; iDom < nDom; iDom++) {
      int used = 0;
      int iDom2;
      for (iDom2 = 0; iDom2 < iDom; iDom2++) {
	if (strcmp(domEntries[iDom2].gene,domEntries[iDom].gene) == 0) {
	  used = 1;
	  break;
	}
      }
      if (!used) {
	AnalyzeAnchor(domEntries, nDom, domEntries[iDom].gene, /*showAnchorFlag*/1);
      }
    }
  }

  /* Free space */
  for (i=0;i<nDom;i++) {
    free(domEntries[i].gene);
    free(domEntries[i].domseq);
  }
  free(domEntries);

  return(0);
}

int *scores = NULL; /* public so that CompareScores can see it */
void AnalyzeAnchor(domEntry_t *domEntries, int nDom, char *anchorId, int showAnchorFlag) {
  static const int NOSCORE = -100000;
  int iDom1;
  scores = (int *)mymalloc(sizeof(int)*nDom);

  for (iDom1 = 0; iDom1 < nDom; iDom1++) { scores[iDom1] = NOSCORE; }
  for (iDom1 = 0; iDom1 < nDom; iDom1++) {
    if (strcmp(domEntries[iDom1].gene, anchorId) == 0) {
      unsigned char *domseq1 = domEntries[iDom1].domseq;
      int begin1 = domEntries[iDom1].begin;
      int end1 = domEntries[iDom1].end;

      int iDom2;
      for (iDom2 = 0; iDom2 < nDom; iDom2++) {
	unsigned char *domseq2 = domEntries[iDom2].domseq;
	int begin2 = domEntries[iDom2].begin;
	int end2 = domEntries[iDom2].end;

	int aBegin = begin1 < begin2 ? begin2 : begin1;
	int aEnd = end2 < end1 ? end2 : end1;
	int extending = 0; /* 1 if last position was gapped */
	int maxScoreSoFar = 0;

	int alignScore = 0;
	int i;
	for (i = aBegin; i <= aEnd; i++) {
	  if (domseq1[i] != gapChar && domseq2[i] != gapChar) {
	    alignScore += matrix[domseq1[i]][domseq2[i]];
	    extending = 0;
	  } else if (domseq1[i] != gapChar  || domseq2[i] != gapChar) {
	    if (extending) {
	      alignScore += gapExtend;
	    } else {
	      extending = 1;
	      alignScore += gapOpen;
	    }
	  }
	  if(alignScore < 0) { alignScore = 0; }
	  if(alignScore > maxScoreSoFar) { maxScoreSoFar = alignScore; }
	}
	alignScore = maxScoreSoFar;

	if (alignScore >= thresholdRaw && alignScore > scores[iDom2]) {
	  scores[iDom2] = alignScore;
	}
      } /* end loop over iDom2 */
    } /* end if domain in anchor locus */
  } /* end loop over iDom1 */

  /* Make a new list of scores and sort */
  int nWithScore = 0;
  int i;
  for (i=0; i < nDom; i++) {
    if (scores[i] > NOSCORE) {
      nWithScore++;
    }
  }
  if (nWithScore > 0) {
    int *NewLoci = (int *)mymalloc(sizeof(int) * nWithScore);
    nWithScore = 0;
    for (i=0; i < nDom; i++) {
      if (scores[i] > NOSCORE) {
	NewLoci[nWithScore++] = i;
      }
    }
    qsort(NewLoci, nWithScore, sizeof(int), CompareScores);

    char **hitnames = (char **)mymalloc(nWithScore*sizeof(char*));
    for (i = 0; i < nWithScore; i++) {
      hitnames[i] = domEntries[NewLoci[i]].gene;
    }
    hashstrings_t *hithash = MakeHashtable(hitnames, nWithScore);
    bool *hitYet = (bool*)mymalloc(hithash->nBuckets * sizeof(int));
    for (i =0; i < hithash->nBuckets; i++) {
      hitYet[i] = false;
    }

    int nWritten = 0;
    for (i=0; i < nWithScore && (nLimit <= 0 || nWritten < nLimit); i++) {
      double bits = (scores[NewLoci[i]] * lambda - logK)/log2;
      if (bits >= thresholdBits) {
	char *gene = domEntries[NewLoci[i]].gene;

	/* Is this duplicate with a previous entry and hence ignored b/c it is a worse
	   hit to the same gene? */
	hashiterator_t hi = FindMatch(hithash,gene);
	assert(GetHashString(hithash,hi) != NULL);
	if (!hitYet[hi]) {
	  hitYet[hi] = true;
	  nWritten++;
	  if (showAnchorFlag) {
	    printf("%s\t%s\t%.1f\t%d\n", anchorId, domEntries[NewLoci[i]].gene,
		   bits, scores[NewLoci[i]]);
	  } else {
	    printf("%s\t%.1f\t%d\n", domEntries[NewLoci[i]].gene, bits, scores[NewLoci[i]]);
	  }
	}
      }
    }
    DeleteHashtable(hithash);
    free(hitnames);
    free(hitYet);
    free(NewLoci);
  }
  free(scores);
  scores = NULL;
}

/* qsort() comparator for descending order: a and b point to int indices into
   the global scores array. Yields 1 when a's score is the lower one, -1 when
   it is the higher one, and 0 when the two scores are equal. */
int CompareScores(const void *a, const void *b) {
  const int lhsScore = scores[*(const int *)a];
  const int rhsScore = scores[*(const int *)b];
  return (lhsScore < rhsScore) - (lhsScore > rhsScore);
}
