import gzip
import os.path
import sys
import time
from collections import defaultdict

import tsdb
from digest import *

# Author: Bart Cramer
# Usage: python train_gm.py
# Make sure to update the parameters in the main() procedure, at the bottom
# of the file.


def create_lexent_lextype_dict(grammar_dir, filenames):
    """Map lexical entries to lexical types.

    Scans each lexicon .tdl file in ``filenames`` (relative to
    ``grammar_dir``) for definitions of the form ``entry := type & ...``
    and returns a dict {lex_entry: lex_type}.
    """
    m_dict = {}
    for filename in filenames:
        # Decode at the I/O boundary (replaces the Python-2 unicode() call).
        with open(grammar_dir + filename, encoding="utf8") as f_lexicon:
            for line in f_lexicon:
                split = line.split(":=")
                if len(split) == 2:
                    lex_entry = split[0].strip()
                    lex_type = split[1]
                    # Keep only the first conjunct of the type expression.
                    # NOTE(review): when no "&" is present, find() returns -1
                    # and this drops the last character (presumably the
                    # trailing "."), matching the original behavior.
                    lex_type_clean = lex_type[:lex_type.find("&")].strip()
                    m_dict[lex_entry] = lex_type_clean
    return m_dict


class PCFG:
    """Rule counts for a PCFG, serializable to the [incr tsdb()] .gm format."""

    def __init__(self):
        # Maps (X,A) or (X,A,B) tuples to counts.
        self.conditional_counts = defaultdict(int)
        # Maps X to counts.
        self.prior_counts = defaultdict(int)
        # Number of preferred readings seen during training.
        self.n_samples = 0

    def add_datum(self, rule, children):
        """Record one observed application of ``rule`` with 1 or 2 children.

        Rules with any other arity only contribute to the prior count.
        """
        if len(children) == 1:
            self.conditional_counts[(rule, children[0])] += 1
        elif len(children) == 2:
            self.conditional_counts[(rule, children[0], children[1])] += 1
        self.prior_counts[rule] += 1

    def write(self, filename):
        """Write the relative-frequency PCFG estimate to ``filename``.

        Output is the .gm file format understood by the parser; each rule
        line carries its conditional probability and the raw counts.
        """
        # ``with`` guarantees the file is closed even if a write fails
        # (the original leaked the handle on errors).
        with open(filename, "w", encoding="utf8") as out:
            out.write(";;; \n")
            out.write(";;; #[GM, created by train_gm.py ] \n")
            out.write(";;; Created: " + time.asctime() + "\n")
            out.write(";;; \n\n")
            out.write(":begin :pcfg " + str(self.n_samples) + ".\n")
            out.write("*pcfg-use-preterminal-types-p* := yes.\n\n")
            out.write("*pcfg-include-leafs-p* := yes.\n\n")
            out.write("*pcfg-laplace-smoothing-p* := nil.\n\n")
            out.write(":begin :rules " + str(len(self.conditional_counts)) + ".\n\n")
            i_rule = 0
            # Group rules by left-hand side so each gets its LHS prior count.
            for lhs in sorted(self.prior_counts):
                lhs_rules = [rule for rule in self.conditional_counts
                             if rule[0] == lhs]
                prior_count = self.prior_counts[lhs]
                for rule in lhs_rules:
                    # e.g. (2145) [1 (0) lt-noun-nomod-noapp-cp "gefahr"] -1.0002e-4 {1 1}
                    count = self.conditional_counts[rule]
                    w = count / float(prior_count)
                    s = "(%i) [1 (0) %s] %e {%i %i}\n" % (
                        i_rule, " ".join(rule), w, prior_count, count)
                    out.write(s)
                    i_rule += 1
            out.write(":end :rules. \n\n\n")
            out.write(":end :pcfg. \n\n")


def update_pcfg(pcfg, node):
    """Recursively add every non-terminal of a derivation tree to ``pcfg``."""
    if isinstance(node, tsdb.NonTerminal):
        pcfg.add_datum(node.name, [child.name for child in node.children])
        # Recurse; terminals are leaves, so only non-terminals descend.
        for child in node.children:
            update_pcfg(pcfg, child)


def _open_profile_file(profile_dir, basename):
    """Open ``profile_dir + basename`` as text, falling back to the .gz variant.

    Raises IOError (OSError) when neither the plain nor the gzipped file
    exists.
    """
    path = profile_dir + basename
    if os.path.exists(path):
        return open(path, encoding="utf8")
    if os.path.exists(path + ".gz"):
        # "rt" so iteration yields str lines, like the plain-file branch.
        return gzip.open(path + ".gz", "rt", encoding="utf8")
    raise IOError("File not found: " + path + "(.gz)")


def train_pcfg(profile_dirs, grammar_dir, lexent_lextype_dict):
    """Build a PCFG from the preferred readings of the given profiles.

    Reads each profile's ``result`` and ``preference`` files (optionally
    gzipped), keeps only the preferred derivations, relabels lexical
    entries via ``lexent_lextype_dict``, and accumulates rule counts.
    """
    pcfg = PCFG()
    for profile_dir in profile_dirs:
        sys.stderr.write(profile_dir + " ")
        # BUG FIX: the original tested `not profile_dir[-1] != "/"`, which
        # appended the slash exactly when it was already there and skipped
        # it when it was missing.
        if not profile_dir.endswith("/"):
            profile_dir = profile_dir + "/"
        f_result = _open_profile_file(profile_dir, "result")
        f_pref = _open_profile_file(profile_dir, "preference")
        # ``with`` closes both files even if parsing raises.
        with f_result, f_pref:
            # This assumes that the result and preference files are in the
            # same order.
            preferences = set()
            for line_pref in f_pref:
                s = line_pref.split("@")
                i_id = int(s[0])
                reading_id = int(s[2])
                pcfg.n_samples += 1
                preferences.add(str(i_id) + "@" + str(reading_id))
            for line_result in f_result:
                s = line_result.split("@")
                p = s[0] + "@" + s[1]
                sys.stderr.write(p + " ")
                if p in preferences:
                    # Field 10 of the result record holds the derivation.
                    derivation = s[10]
                    tree, pos = tsdb.parse_derivation(derivation, 0)
                    tree.transform(lexent_lextype_dict)
                    update_pcfg(pcfg, tree)
        sys.stderr.write("\n")
    return pcfg


def main():
    # Adapt these to your needs. Keep track of the trailing slashes :)
    grammar_dir = "/home/bart/logon/lingo/erg/"
    lexicon_files = ["lexicon.tdl"]
    treebank_base = grammar_dir + "tsdb/gold/"
    gm_file = grammar_dir + "jh.gm"

    # The list of the profile directories to train on. Adapt this, too.
    # Virtual profiles are not supported.
    profile_dirs = [treebank_base + p + "/" for p in ["jh0", "jh1", "jh2"]]

    # Make a dictionary that translates lexical entries to lexical types.
    m_dict = create_lexent_lextype_dict(grammar_dir, lexicon_files)
    sys.stderr.write("Size of replacement dictionary: "
                     + str(len(m_dict)) + "\n")

    pcfg = train_pcfg(profile_dirs, grammar_dir, m_dict)
    pcfg.write(gm_file)


if __name__ == "__main__":
    main()