
# might be needed for chunkScore
from nltk.token import Token, FrozenToken, CharSpanLocation, SubtokenContextPointer 
from sets import Set
###

from nltk.tokenreader.tagged import ChunkedTaggedTokenReader
from nltk.tree import Tree
from nltk.chktype import chktype
#from nltk.parser.chunk import *
import string
import sys

#RULES_FILE = sys.argv[1]
#NUM_FROM = int(sys.argv[2])
#NUM_TO = int(sys.argv[3])

def removePunct(thetree):
   """
   Return a copy of a chunk structure with punctuation leaves removed.

   A leaf counts as punctuation when its C{'TEXT'} property starts
   with C{'yy'}.  Subtrees are rebuilt recursively, so punctuation is
   dropped at every depth.

   @param thetree: a C{Tree}, or a leaf token.
   @return: a new C{Tree} with the same node label and the filtered
       children; anything that is not a C{Tree} is returned unchanged
       (including a punctuation leaf passed in directly).
   """
   # Leaves have nothing to filter
   if not isinstance(thetree, Tree):
      return thetree

   pruned = Tree(thetree.node, [])
   for child in thetree:
      # Subtrees are always kept; leaves only when they are not punctuation
      if isinstance(child, Tree) or not child['TEXT'].startswith('yy'):
         pruned.append(removePunct(child))
   return pruned

class ChunkScore: # more efficient variant of nltk's ChunkScore
                  # scoring ignores punctuation: score() strips it from
                  # both trees with removePunct() before comparing chunks
    """
    A utility class for scoring chunk parsers.  C{ChunkScore} can
    evaluate a chunk parser's output, based on a number of statistics
    (precision, recall, f-measure, missed chunks, incorrect chunks).
    It can also combine the scores from the parsing of multiple texts;
    this makes it significantly easier to evaluate a chunk parser that
    operates one sentence at a time.

    Texts are evaluated with the C{score} method.  The results of
    evaluation can be accessed via a number of accessor methods, such
    as C{precision} and C{f_measure}.  A typical use of the
    C{ChunkScore} class is::

        >>> chunkscore = ChunkScore()
        >>> for correct in correct_sentences:
        ...     guess = chunkparser.parse(correct.leaves())
        ...     chunkscore.score(correct, guess)
        >>> print 'F Measure:', chunkscore.f_measure()
        F Measure: 0.823

    Chunks are accumulated in two sets (C{_correct} and C{_guessed}).
    The true positive / false positive / false negative sets and their
    sizes are recomputed from those two sets every time a statistic is
    requested.

    @ivar kwargs: Keyword arguments C{max_tp_examples},
        C{max_fp_examples} and C{max_fn_examples} (default 100 each).
        They are stored on the instance, but no method of this class
        consults them, so they limit neither the example lists
        returned by the accessors nor the numerical metrics.

    @type _correct: C{Set} of C{tuple}
    @ivar _correct: Every chunk seen in a gold-standard tree.
    @type _guessed: C{Set} of C{tuple}
    @ivar _guessed: Every chunk seen in a guessed tree.

    @type _tp: C{Set} of C{tuple}
    @ivar _tp: True positives (guessed and correct)
    @type _fp: C{Set} of C{tuple}
    @ivar _fp: False positives (guessed but not correct)
    @type _fn: C{Set} of C{tuple}
    @ivar _fn: False negatives (correct but not guessed)

    @type _tp_num: C{int}
    @ivar _tp_num: Number of true positives
    @type _fp_num: C{int}
    @ivar _fp_num: Number of false positives
    @type _fn_num: C{int}
    @ivar _fn_num: Number of false negatives.
    """
    def __init__(self, **kwargs):
        self._correct = Set()
        self._guessed = Set()
        self._tp = Set()
        self._fp = Set()
        self._fn = Set()
        self._max_tp = kwargs.get('max_tp_examples', 100)
        self._max_fp = kwargs.get('max_fp_examples', 100)
        self._max_fn = kwargs.get('max_fn_examples', 100)
        self._tp_num = 0
        self._fp_num = 0
        self._fn_num = 0

    def _childtuple(self, t):
        """
        @return: the children of chunk C{t}, frozen with
            C{FrozenToken} and packed in a tuple so that the chunk
            can be stored in a set.
        """
        return tuple(t.freeze(FrozenToken))

    def _update_counts(self):
        """
        Recompute the true positive, false positive and false negative
        sets (and their sizes) from C{_guessed} and C{_correct}.
        """
        self._tp = self._guessed & self._correct
        self._fn = self._correct - self._guessed
        self._fp = self._guessed - self._correct
        self._tp_num = len(self._tp)
        self._fp_num = len(self._fp)
        self._fn_num = len(self._fn)

    def score(self, correct, guessed):
        """
        Given a correctly chunked text, score another chunked text.
        Merge the results with all previous scorings.  Note that when
        the score() function is used repeatedly, each token I{must}
        have a unique location.  For sentence-at-a-time chunking, it
        is recommended that you use locations like C{@12w@3s} (the
        word at index 12 of the sentence at index 3).

        Punctuation is stripped from both trees with C{removePunct}
        before their top-level chunks are recorded.

        @type correct: chunk structure
        @param correct: The known-correct ("gold standard") chunked
            sentence.
        @type guessed: chunk structure
        @param guessed: The chunked sentence to be scored.
        """
        assert chktype(1, correct, Tree)
        assert chktype(2, guessed, Tree)

        guessed = removePunct(guessed)
        correct = removePunct(correct)

        # Only top-level subtrees count as chunks; loose tokens are skipped
        self._correct |= Set([self._childtuple(t) for t in correct
                               if isinstance(t, Tree)])

        self._guessed |= Set([self._childtuple(t) for t in guessed
                               if isinstance(t, Tree)])

    def precision(self):
        """
        @return: the overall precision for all texts that have been
            scored by this C{ChunkScore}; 0 when nothing was guessed.
        @rtype: C{float}
        """
        self._update_counts()
        div = self._tp_num + self._fp_num
        if div == 0: return 0
        else: return float(self._tp_num) / div

    def recall(self):
        """
        @return: the overall recall for all texts that have been
            scored by this C{ChunkScore}; 0 when there are no correct
            chunks.
        @rtype: C{float}
        """
        self._update_counts()
        div = self._tp_num + self._fn_num
        if div == 0: return 0
        else: return float(self._tp_num) / div

    def f_measure(self, alpha=0.5):
        """
        @return: the overall F measure for all texts that have been
            scored by this C{ChunkScore}; 0 when either precision or
            recall is 0.
        @rtype: C{float}

        @param alpha: the relative weighting of precision and recall.
            Larger alpha biases the score towards the precision value,
            while smaller alpha biases the score towards the recall
            value.  C{alpha} should have a value in the range [0,1].
        @type alpha: C{float}
        """
        p = self.precision()
        r = self.recall()
        # Returning early also protects the divisions below
        if p == 0 or r == 0:
            return 0
        return 1/(alpha/p + (1-alpha)/r)

    def missed(self):
        """
        @rtype: C{list} of C{tuple}
        @return: the chunks which were included in the correct chunk
            structures, but not in the guessed chunk structures.
            Each chunk is the tuple built by C{_childtuple}.
        """
        self._update_counts()
        return list(self._fn)

    def incorrect(self):
        """
        @rtype: C{list} of C{tuple}
        @return: the chunks which were included in the guessed chunk
            structures, but not in the correct chunk structures.
            Each chunk is the tuple built by C{_childtuple}.
        """
        self._update_counts()
        return list(self._fp)

    def correct(self):
        """
        @rtype: C{list} of C{tuple}
        @return: the chunks which were included in the correct chunk
            structures.  Each chunk is the tuple built by
            C{_childtuple}.
        """
        return list(self._correct)

    def guessed(self):
        """
        @rtype: C{list} of C{tuple}
        @return: the chunks which were included in the guessed chunk
            structures.  Each chunk is the tuple built by
            C{_childtuple}.
        """
        return list(self._guessed)

    def __len__(self):
        """
        @rtype: C{int}
        @return: the number of distinct correct chunks seen so far
            (true positives plus false negatives).
        """
        self._update_counts()
        return self._tp_num + self._fn_num

    def __repr__(self):
        """
        @rtype: C{String}
        @return: a concise representation of this C{ChunkScoring}.
        """
        return '<ChunkScoring of '+repr(len(self))+' chunks>'

    def __str__(self):
        """
        @rtype: C{String}
        @return: a verbose representation of this C{ChunkScoring}.
            This representation includes the precision, recall, and
            f-measure scores.  For other information about the score,
            use the accessor methods (e.g., C{missed()} and
            C{incorrect()}).
        """
        return ("ChunkParser score:\n" +
                ("    Precision: %5.1f%%\n" % (self.precision()*100)) +
                ("    Recall:    %5.1f%%\n" % (self.recall()*100))+
                ("    F-Measure: %5.1f%%\n" % (self.f_measure()*100)))

    def _chunk_toks(self, text):
        """
        @return: The list of chunks (C{Tree} children) contained in
            C{text}.
        """
        # The original tested against AbstractTree, a name this module
        # never imports; Tree is the chunk type used everywhere else here.
        return [tok for tok in text if isinstance(tok, Tree)]

class CorpusReader:
   """
   Reads a corpus file that holds one chunked, tagged item per line.

   The whole file is loaded into C{raw_data} at construction time;
   individual lines are parsed on demand by C{read}, using a
   C{ChunkedTaggedTokenReader} configured for C{'NP'} chunks.
   """
   def __init__(self, filename):
      """
      @param filename: path of the corpus file to load.
      """
      self._name = "MyCorpusReader"
      corpusFile = open(filename)
      try:
         self.raw_data = corpusFile.readlines()
      finally:
         # close the handle even when reading fails
         corpusFile.close()
      self.reader = ChunkedTaggedTokenReader(chunk_node='NP', SUBTOKENS='WORDS')

   def items(self, group=None):
      """
      @param group: accepted but ignored.
      @return: the item ids C{1 .. len(raw_data)-1}.
      """
      # NOTE(review): ids start at 1, so line 0 is never listed even
      # though read(0) is valid - confirm whether that is intended.
      return range(1,len(self.raw_data))

   def read(self, item, *reader_args, **reader_kwargs):
      """
      Parse one line of the corpus.

      @param item: index into C{raw_data} of the line to parse.
      @param reader_args: extra positional arguments for the reader.
      @param reader_kwargs: extra keyword arguments for the reader.
      @return: whatever the reader's C{read_token} returns for the line.
      """
      source = '%s/%s' % (self._name, item)
      text = self.raw_data[item]
      return self.reader.read_token(text, add_locs=True, source=source,
                              *reader_args, **reader_kwargs)

class CorpusData:
   """
   An in-memory slice of a corpus: the items C{start .. last-1} read
   through a corpus reader.  Iterating over the object yields the
   loaded items.
   """
   def __init__(self, corpusReader, start = 0, last=15):
      """
      @param corpusReader: object providing C{items()} and C{read(item)}.
      @param start: index of the first item to load; clamped to the
          number of items the reader reports.
      @param last: index one past the last item to load; C{-1} means
          "up to the number of items the reader reports".
      """
      print("Loading data")
      numItems = len(corpusReader.items())
      if last == -1: last = numItems
      if start > numItems: start = numItems
      self.data = [corpusReader.read(item) for item in range(start,last)]
      print("Done")

   # __iter__ and __str__ are real methods: the original stored a bound
   # method and a plain string on the instance, so str(obj) tried to
   # call a string and failed.
   def __iter__(self):
      """@return: an iterator over the loaded items."""
      return iter(self.data)

   def __str__(self):
      """@return: a short description including the number of items."""
      return "Corpus Data (Sequence with "+str(len(self.data))+" items)"

   def allNPs(self):
      """
      @return: a flat list with the leaves of every top-level C{Tree}
          found in the C{'TREE'} property of every loaded item.
      """
      result = []
      for item in self.data:
         for node in item['TREE']:
            if isinstance(node, Tree):
               result.append(node.leaves())

      return result

   def allNPs_as_strings(self, INCLUDE_TEXT=False):
      """
      @param INCLUDE_TEXT: accepted but ignored.
      @return: one string per NP, made of the C{'TAG'} values of its
          leaves separated by single spaces.
      """
      return [" ".join([item['TAG'] for item in np]) for np in self.allNPs()]
         
class NPsFormatter:
   """
   Helpers that pull the NP chunks out of corpus items and render them.
   """
   def getNPs(self, item):
      """
      @return: the leaves of every top-level C{Tree} in C{item['TREE']},
          one list per chunk.
      """
      return [node.leaves() for node in item['TREE'] if isinstance(node, Tree)]

   def allNPs(self, corpusData):
      """
      @return: one C{getNPs} result per item of C{corpusData}.
      """
      return [self.getNPs(item) for item in corpusData]

   def as_tags_string(self, np):
      """
      @return: always the empty string; the rendering of C{np} as a
          sequence of C{<TAG>} markers is commented out below.
      """
      #return string.join(["<"+item['TAG']+">" for item in np])
      return ""

   def as_tags_strings(self, listOfNPs):
      """
      @return: C{as_tags_string} applied to every element of C{listOfNPs}.
      """
      rendered = []
      for np in listOfNPs:
         rendered.append(self.as_tags_string(np))
      return rendered

def count_each(list_of_strings):
   """
   Count how many times each element occurs.

   @param list_of_strings: an iterable of hashable values.
   @return: a dictionary mapping each distinct value to its number of
       occurrences.
   """
   # local names no longer shadow the dict/str builtins
   counts = {}
   for s in list_of_strings:
      counts[s] = counts.get(s, 0) + 1
   return counts


class MistakesTracker:
   """
   Collects the (correct, guess) pairs on which a chunker was wrong.
   """
   def __init__(self):
      self.mistakes = []

   def track(self, correct, guess):
      """
      Record the pair when the two trees differ once punctuation has
      been stripped from both with C{removePunct}.  The trees are
      stored as given, with punctuation intact.
      """
      cleanCorrect = removePunct(correct)
      cleanGuess = removePunct(guess)
      if cleanCorrect != cleanGuess:
         self.mistakes.append(Token(CORRECT=correct, GUESS=guess))

   def asList(self):
      """@return: the list of recorded mistakes (not a copy)."""
      return self.mistakes
         
class ChunkRulesTester:
   """
   Runs a rule matcher over a corpus and records how well it chunks.

   After C{evaluate}, C{score} holds the C{ChunkScore}, C{mistakes}
   the C{MistakesTracker} and C{history} one C{Token} with C{CORRECT}
   and C{GUESS} properties per corpus item.
   """
   def __init__(self, corpus):
      """
      @param corpus: iterable of items that expose the gold-standard
          tree under the C{'TREE'} key.
      """
      self.corpus = corpus
      self.matcher = None
      self.mistakes = MistakesTracker()
      self.history = []
      self.score = None
      # defined up front so the attribute exists before setLearningData
      self.trees = None

   def setLearningData(self, trees):
      """Store C{trees} on the tester and return them."""
      self.trees = trees
      return trees

   def setMatcher(self, matcher):
      """Set the default matcher used by C{evaluate} and return it."""
      self.matcher = matcher
      return matcher

   def getMatcher(self): return self.matcher
   def getScore(self): return self.score

   def evaluate(self, matcher=None, fixFunc=lambda x: x):
      """
      Chunk every corpus item and score the result against its
      gold-standard tree.  One C{*} is written to stdout per item as a
      progress indicator.

      @param matcher: matcher to use; C{None} selects the one
          installed with C{setMatcher}.
      @param fixFunc: function applied to each chunked tree before it
          is scored; the default leaves the tree unchanged.
      @return: the C{ChunkScore} of this run (also stored in C{score}).
      """
      if matcher is None: matcher = self.matcher
      self.history = []
      self.mistakes = MistakesTracker()
      chunkscore = ChunkScore()
      chunkparser = MyChunker(matcher)
      for item in self.corpus:
         sys.stdout.write("*")
         gold = item['TREE']
         chunked = fixFunc(chunkparser.chunk(gold.leaves()))
         chunkscore.score(gold, chunked)
         self.mistakes.track(gold, chunked)
         self.history.append(Token(CORRECT=gold, GUESS=chunked))
      self.score = chunkscore
      return self.score

class Model:
   """
   Bundles a corpus reader, the loaded corpus slice, a rules tester
   and an NP formatter.
   """
   def __init__(self, corpusFileName, dataFirst = 0, dataLast = -1):
      """
      @param corpusFileName: path handed to C{CorpusReader}.
      @param dataFirst: C{start} argument handed to C{CorpusData}.
      @param dataLast: C{last} argument handed to C{CorpusData}.
      """
      reader = CorpusReader(corpusFileName)
      data = CorpusData(reader, dataFirst, dataLast)
      self.reader = reader
      self.corpusData = data
      self.tester = ChunkRulesTester(data)
      self.rules = []
      self.formatter = NPsFormatter()


class DecTreeLeaf:
   """
   A node of the tag tree built by C{Matcher}: it carries a tag, its
   depth below the root, its children keyed by tag, and a flag telling
   whether the path from the root to this node is a complete rule.
   """
   def __init__(self, tag, depth=0):
      self.tag = tag
      self.next = {}
      self.depth = depth
      self.rule = False

   def addNext(self, tag):
      """Create a child for C{tag}, replacing any existing child for it."""
      self.next[tag] = DecTreeLeaf(tag, self.depth+1)
   def hasNext(self, tag):
      """@return: true when a child exists for C{tag}."""
      return tag in self.next
   def getNext(self, tag):
      """@return: the child for C{tag}; raises C{KeyError} when absent."""
      return self.next[tag]

   def isRule(self):
      return self.rule
   def setRule(self, bool):
      self.rule = bool

   def getDepth(self): return self.depth
           
         
class Matcher:
   """
   Stores chunk rules (sequences of tags) in a tree of C{DecTreeLeaf}
   nodes and finds the longest rule that matches the tags of a token
   sequence.
   """
   def __init__(self):
      root = DecTreeLeaf('ROOOOT',0)
      # the root is flagged as a rule, which makes "no match" come out as 0
      root.setRule(True)
      self.decisionTree = root
      self.rules = []

   def getRulesList(self):
      """@return: the rule strings given to C{addStrRule}, in order."""
      return self.rules

   def addStrRule(self, rule):
      """
      Add a rule written as C{<TAG1><TAG2>...}: the string is recorded
      in the rules list and its tags are handed to C{addRule}.
      """
      self.rules.append(rule)
      tags = rule.replace("><","|").replace("<","").replace(">","")
      self.addRule(tags.split("|"))

   def addRule(self, rule):
      """
      Add a rule to the tree.

      @param rule: a list of tag names.
      """
      node = self.decisionTree
      for tag in rule:
         if not node.hasNext(tag):
            node.addNext(tag)
         node = node.getNext(tag)
      node.setRule(True)

   def match(self, tokens, frm = 0):
      """
      Walk the tree along the C{'TAG'} values of C{tokens[frm:]}.

      @return: the depth of the deepest rule node reached, i.e. the
          length of the longest rule matching a prefix of
          C{tokens[frm:]}; the root's depth when no rule matches.
      """
      node = self.decisionTree
      longest = node.getDepth()
      for tok in tokens[frm:]:
         tag = tok['TAG']
         if not node.hasNext(tag):
            break
         node = node.getNext(tag)
         if node.isRule():
            longest = node.getDepth()
      return longest

class MyChunker:
   """
   Chunker that groups tokens into C{NP} subtrees wherever its matcher
   reports a match.
   """
   def __init__(self, matcher):
      self.matcher = matcher

   def chunk(self, leaves):
      """
      @param leaves: the sequence of tokens to chunk.
      @return: a C{Tree} labelled C{'S'}.  Scanning left to right,
          every position where the matcher reports a positive length
          becomes an C{'NP'} subtree of that many tokens; every other
          token is appended as a bare leaf.
      """
      result = Tree('S',[])
      total = len(leaves)
      pos = 0
      while pos < total:
         matched = self.matcher.match(leaves, pos)
         if matched <= 0:
            # nothing starts here: keep the token and move on by one
            result.append(leaves[pos])
            pos += 1
            continue
         result.append(Tree("NP",leaves[pos:pos+matched]))
         pos += matched
      return result

#reads rules into a set of text strings
def readChunkRulesFromFile(filename):
   """
   Read chunk rules from a text file.

   Only lines of the form C{chunk: <rule>} contribute a rule.  Blank
   lines, lines starting with C{#}, lines without a C{': '} separator
   and lines of any other type are skipped.

   @param filename: path of the rules file.
   @return: a C{Set} with the stripped text of every chunk rule.
   """
   def processLine(line):
      # Filter before splitting: the original unpacked the split first,
      # so a blank line or a bare comment raised ValueError.
      if not line.strip() or line.startswith('#'):
         return None
      if ': ' not in line:
         return None
      kind, rule = line.split(': ',1)
      if kind == 'chunk':
         return rule.strip()
      return None

   rulesFile = open(filename,'r')
   try:
      lines = rulesFile.readlines()
   finally:
      rulesFile.close()

   rules = Set()
   for line in lines:
      r = processLine(line)
      if r:
         rules.add(r)
   return rules

def makeMatcher(setOfRules):
   """
   Build a C{Matcher} holding every rule of C{setOfRules}.

   @param setOfRules: iterable of rule strings as accepted by
       C{Matcher.addStrRule}.
   @return: the populated C{Matcher}.
   """
   # snapshot the rules before the matcher is created, as the original did
   pending = list(setOfRules)
   matcher = Matcher()
   for ruleText in pending:
      matcher.addStrRule(ruleText)
   return matcher

def getResultsAsTextLines(list_of_items):
   """
   Render chunks as text, one chunk per line.

   @param list_of_items: sequence of chunks; each chunk is a sequence
       of tokens exposing C{'TEXT'} and C{'TAG'}.
   @return: the chunks joined by newlines, every token written as
       C{TEXT/TAG} followed by a space.
   """
   def printNicely(item):
      # The properties are read by name: the original indexed
      # values()[0] and values()[1], which relies on dictionary order.
      return "".join([token['TEXT'] + "/" + token['TAG'] + " " for token in item])

   return "\n".join([printNicely(item) for item in list_of_items])
         
def getScore(score):
   """
   Format the main figures of a score object as tab-separated lines.

   @param score: object providing C{precision()}, C{recall()},
       C{missed()}, C{correct()} and C{incorrect()}.
   @return: five newline-terminated lines: precision, recall and the
       number of missed, correct and incorrect chunks.
   """
   lines = [
      "Precision:\t" + str(score.precision()),
      "Recall:\t" + str(score.recall()),
      "Missed:\t" + str(len(score.missed())),
      "Correct:\t" + str(len(score.correct())),
      "Incorrect:\t" + str(len(score.incorrect())),
   ]
   return "".join([line + "\n" for line in lines])

def annotate(missed, guessed):
   """
   Mark every occurrence of a guessed chunk inside the missed text.

   @param missed: text to annotate.
   @param guessed: newline-separated guessed chunks.
   @return: C{missed} with each occurrence of each non-empty guess
       line wrapped as C{ { guess } }.
   """
   for guess in guessed.split("\n"):
      # replace("") would insert the marker between all characters
      if not guess:
         continue
      missed = missed.replace(guess, " { "+guess+" } ")
   return missed

def writeResultsToFiles(tester, ext = ''):
   """
   Dump the outcome of the tester's last evaluation under C{data/}.

   Writes C{score.txt}, C{correct.txt}, C{incorrect.txt},
   C{missed.txt}, C{guessed.txt}, C{mistakes.txt} and
   C{mistakes-brief.txt}, each with C{ext} appended to its name.

   @param tester: object exposing C{score} (a score object) and
       C{mistakes} (with an C{asList()} method).
   @param ext: suffix appended to every file name.
   """
   def writeText(path, text):
      # close each file explicitly instead of leaving it to the
      # garbage collector
      out = open(path, "w")
      try:
         out.write(text)
      finally:
         out.close()

   score = tester.score
   mistakes = tester.mistakes.asList()

   correct = getResultsAsTextLines(score.correct())
   incorrect = getResultsAsTextLines( score.incorrect())
   missed = getResultsAsTextLines( score.missed())
   guessed = getResultsAsTextLines( score.guessed())

   writeText("data/score.txt" + ext, getScore(score))
   writeText("data/correct.txt" + ext, correct)
   writeText("data/incorrect.txt" + ext, incorrect)
   writeText("data/missed.txt" + ext, missed)
   writeText("data/guessed.txt" + ext, guessed)
#   writeText("data/missed-annotated.txt" + ext, annotate(missed, guessed))
   writeText("data/mistakes.txt" + ext, mistakesAsStringFull(mistakes))
   writeText("data/mistakes-brief.txt" + ext, mistakesAsStringBrief(mistakes))
   
def mistakesAsStringBrief(mistakesList):
   """
   For each mistake, render the correct sentence followed by the bad NPs.

   @param mistakesList: sequence of mistakes exposing C{'CORRECT'}
       and C{'GUESS'} trees.
   @return: two lines per mistake: C{+ } and the correct tree, then
       C{- } and every guessed chunk that is not among the correct
       chunks, each wrapped in braces.  Punctuation is stripped from
       both trees first.
   """
   def render(tree):
      # subtrees go in square brackets, leaves are written as TEXT/TAG
      parts = []
      for child in tree:
         if isinstance(child, Tree):
            parts.append(' [' + render(child) + '] ')
         else:
            parts.append(child['TEXT'] + "/" + child['TAG'] + " ")
      return "".join(parts)

   lines = []
   for mistake in mistakesList:
      goldTree = removePunct(mistake['CORRECT'])
      lines.append("+ " + render(goldTree) + "\n")
      guessTree = removePunct(mistake['GUESS'])
      goldNPs = [sub for sub in goldTree if isinstance(sub,Tree)]
      wrongNPs = [sub for sub in guessTree if isinstance(sub,Tree) and sub not in goldNPs]
      lines.append("- " + "".join([" { " + render(np) + " } " for np in wrongNPs]) + "\n")
   return "".join(lines)
   
def mistakesAsStringFull(mistakesList):
   """
   For each mistake, render the correct sentence and then the guess.

   @param mistakesList: sequence of mistakes exposing C{'CORRECT'}
       and C{'GUESS'} trees.
   @return: two lines per mistake: C{+ } with the correct tree and
       C{- } with the guessed tree, both rendered in full.
   """
   def render(tree):
      # subtrees go in square brackets, leaves are written as TEXT/TAG
      parts = []
      for child in tree:
         if isinstance(child, Tree):
            parts.append(' [' + render(child) + '] ')
         else:
            parts.append(child['TEXT'] + "/" + child['TAG'] + " ")
      return "".join(parts)

   lines = []
   for mistake in mistakesList:
      goldLine = "+ " + render(mistake['CORRECT'])
      guessLine = "- " + render(mistake['GUESS'])
      lines.append(goldLine + "\n" + guessLine + "\n")
   return "".join(lines)

def scoreRules(history, rulesSet):
   """
   Give each rule a score a-la Cardie/Pierce article.

   For every event the guessed NPs are compared with the correct NPs:
     - a guess identical to a correct NP earns its rule +1;
     - among the guesses that overlap a correct NP without being equal
       to it, the first one costs its rule -1, the others are dropped
       without a penalty;
     - a guess that overlaps no correct NP at all costs its rule -1.
   The rule of an NP is the sequence of its tags written as
   C{<TAG1><TAG2>...}.

   @param history: sequence of events exposing C{'CORRECT'} and
       C{'GUESS'} chunk structures.
   @param rulesSet: iterable of rule strings; rules that were never
       scored are reported with a score of 0.
   @return: dictionary mapping rule string to its score.
   """
   ruleScores = {}

   def equals(np1,np2):
      return list(np1) == list(np2)

   def overlap(np1,np2):
      """ returns true if nps overlap """
      leaves2 = np2.leaves()
      for tok in np1.leaves():
         for tok2 in leaves2:
            if tok == tok2:
               return True
      return False

   def getRule(np):
      return "".join(["<"+tok['TAG']+">" for tok in np])

   def addScore(np, num):
      rule = getRule(np)
      ruleScores[rule] = ruleScores.get(rule, 0) + num

   def nps(chunked):
      # the top-level NP subtrees of the given chunk structure
      return [el for el in chunked if isinstance(el,Tree) and el.node == 'NP']

   for event in history:
      correct = nps(event['CORRECT'])
      # Guesses not yet attributed to a correct NP.  The original aliased
      # this list and removed from it while iterating, which skipped
      # guesses; it also read an unbound (or stale) name when an event
      # had no correct NPs.
      remaining = nps(event['GUESS'])
      for real in correct:
         penalized = False
         unmatched = []
         for guess in remaining:
            if equals(guess,real):
               addScore(guess,1)
            elif overlap(real, guess):
               # only the first overlapping guess is charged
               if not penalized:
                  addScore(guess,-1)
                  penalized = True
            else:
               unmatched.append(guess)
         remaining = unmatched
      for guess in remaining:
         addScore(guess,-1)

   # done once, after all events, so it also covers an empty history
   for r in rulesSet:
      if r not in ruleScores:
         ruleScores[r] = 0

   return ruleScores
            
def makeIncrementalPruningFunc(HOW_MUCH):
   """
   Build a pruning function that selects the worst-scoring rules.

   @param HOW_MUCH: how many rules to select on each call.
   @return: a function that takes a rule-to-score dictionary and
       returns a C{Set} with the C{HOW_MUCH} lowest-scoring rules,
       printing C{$ score : rule} for each rule it selects.
   """
   def pruneFunc(ruleScores):
      res = Set()
      # key-based sort: one lookup per rule instead of a Python-level
      # comparison callback per pair; ties keep their original order
      ranked = sorted(ruleScores.keys(), key=lambda r: ruleScores[r])
      for r in ranked[:HOW_MUCH]:
         res.add(r)
         print("$ %s : %s" % (ruleScores[r], r))
      return res

   return pruneFunc

def makeThresholdPruningFunc(MIN):
   """
   Build a pruning function that selects every rule scoring below a
   threshold.

   @param MIN: rules with a score strictly below this value are selected.
   @return: a function that takes a rule-to-score dictionary and
       returns the C{Set} of selected rules.
   """
   def pruneFunc(ruleScores):
      return Set([r for r in ruleScores.keys() if ruleScores[r] < MIN])

   return pruneFunc
      
def extractRules(model):
   """
   Derive one chunk rule per distinct tag sequence among the corpus NPs.

   The rules are also written, one per line, to
   C{data/extractedRules.txt}.

   @param model: object whose C{corpusData.allNPs()} returns the NPs
       as lists of tokens exposing C{'TAG'}.
   @return: list of strings of the form C{chunk: <TAG1><TAG2>...};
       NPs without tokens produce no rule.
   """
   def tagsString(lst):
      return "".join(['<' + token['TAG'] + '>' for token in lst])

   # occurrences per tag sequence; only the keys are used below
   npsCounter = {}
   for np in model.corpusData.allNPs():
      key = tagsString(np)
      npsCounter[key] = npsCounter.get(key, 0) + 1

   rules = ["chunk: " + r for r in npsCounter.keys() if r]
   out = open("data/extractedRules.txt","w")
   try:
      out.write("\n".join(rules))
   finally:
      out.close()

   return rules
      
   
def doGenRuleSet(model, RULES_FILE, pruneFunction, whenToStop=1):
   """
   Iteratively prune a rule set: evaluate the rules, score each rule,
   remove the ones selected by C{pruneFunction}, and repeat until the
   stop condition holds.

   After every round that does not stop the loop, the pruned rule set
   is written to C{data/ruleset.txt.prune} and the evaluation results
   are written with the C{.prune} extension.

   @param model: object whose C{tester} evaluates rule sets.
   @param RULES_FILE: path of the file holding the initial rules.
   @param pruneFunction: a function that takes ruleScores and returns
       the set of rules to remove.
   @param whenToStop: 1 when precision drops
                      2 when recall drops
                      3 when rules are stable (no rule got pruned)
                      (any other value is treated as 1)
   @return: the rule set as it stands after the last pruning step.
   """
   if whenToStop not in [1,2,3]: whenToStop = 1

   # NOTE(review): the first round is compared against these zeros, so
   # with whenToStop 1 or 2 the loop stops at once when the first
   # precision/recall is 0 - confirm that this is intended.
   prec = 0
   recall = 0
   cont = True
   rulesSet = readChunkRulesFromFile(RULES_FILE)
   while cont:
      oldPrec = prec
      oldRecall = recall
      model.tester.setMatcher(makeMatcher(rulesSet))
      print("Evaluating:")
      score = model.tester.evaluate()
      prec = score.precision()
      recall = score.recall()
      print("Prec: %s" % (prec,))

      # pruning of the rules

      print("scoring rules")
      # score each rule
      ruleScores = scoreRules(model.tester.history, rulesSet)

      print("pruning bad rules")
      rulesCnt = len(rulesSet)
      # remove bad rules
      rulesSet -= pruneFunction(ruleScores)
      print("New size of ruleset: %s" % (len(rulesSet),))

      if   whenToStop == 1: cont = (prec > oldPrec)
      elif whenToStop == 2: cont = (recall > oldRecall)
      elif whenToStop == 3: cont = (rulesCnt - len(rulesSet)) > 0

      if cont:
         rules = "\n".join(["chunk: " + r for r in rulesSet])
         out = open("data/ruleset.txt.prune","w")
         try:
            out.write(rules)
         finally:
            out.close()
         writeResultsToFiles(model.tester, ".prune")

   # The original returned a name bound only inside "if cont", which
   # was the same set object as rulesSet and was unbound when the loop
   # stopped on its first round.
   # NOTE(review): the returned set includes the final pruning step,
   # which the file written above does not.
   return rulesSet

def doEval(model, RULES_FILE, afterFix = False):
   """
   Evaluate the rules of C{RULES_FILE} on the model's corpus and write
   the results with the C{.results} extension.

   @param model: object whose C{tester} evaluates rule sets.
   @param RULES_FILE: path of the rules file to evaluate.
   @param afterFix: when it compares unequal to C{False}, every
       chunked sentence goes through the after-chunking fix before it
       is scored.
   @return: the score returned by the tester.
   """
   def afterFixFunc(chunked):
      "after-chunking fixes - get a chunked text and return a fixed version"
      def removeEich(thetree):
         if not isinstance(thetree,Tree):
            return thetree
         # A tree whose only child is the "Eich" token is replaced by
         # that token.  The child must be a leaf: the original indexed
         # ['TEXT'] on it even when it was itself a Tree.
         if (len(thetree) == 1 and not isinstance(thetree[0],Tree)
               and thetree[0]['TEXT'] == '\xea\xe9\xe0'):
            print("REMOVED THE EICH!")
            return thetree[0]
         tree = Tree(thetree.node,[])
         for sub in thetree:
            tree.append(removeEich(sub))
         return tree

      # remove every "Eich" which is the only thing in its chunk
      return removeEich(chunked)

   rulesSet = readChunkRulesFromFile(RULES_FILE)

   model.tester.setMatcher(makeMatcher(rulesSet))
   print("Evaluating:")
   if (afterFix != False):
      score = model.tester.evaluate(None,afterFixFunc)
   else:
      score = model.tester.evaluate()
   writeResultsToFiles(model.tester, ".results")

   return score

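# Example usage (RULES_FILE, NUM_FROM and NUM_TO come from the commented-out
# sys.argv lines near the top of this file):
#model = Model("corpus/oneperline.basenps.nps.noasterix.txt", NUM_FROM, NUM_TO)
#extractRules(model)
#doGenRuleSet(model, RULES_FILE, makeThresholdPruningFunc(1), 1)
#doGenRuleSet(model, RULES_FILE, makeIncrementalPruningFunc(10), 1)
#doEval(model, RULES_FILE)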

if __name__ == '__main__':
   # This module is meant to be imported; running it only prints a notice.
   print("Don't run me. I'm not that kind of program.")

