package se.lth.cs.nlp.opinions; import se.lth.cs.nlp.nlputils.ml_long.*; import se.lth.cs.nlp.nlputils.annotations.*; import se.lth.cs.nlp.nlputils.core.*; import java.io.*; import java.util.*; import java.util.zip.GZIPInputStream; import gnu.trove.*; public class SeqLabeler { private static final int NBITS_FEATURE_DATA = 58; private static void initWordEncoder(SymbolEncoder enc, String trainingFile, int col) throws IOException { for(String line: Util.readLines(trainingFile)) { line = line.trim(); if(line.isEmpty()) continue; String[] ts = line.split("\t"); enc.encode(ts[col]); } } public static void train(String[] argv) throws IOException { System.out.println(Arrays.toString(argv)); String trainingFile = argv[1]; String templateFile = argv[2]; String modelFile = argv[3]; boolean secondOrder = Boolean.parseBoolean(argv[4]); int costType = Integer.parseInt(argv[5]); String methodName = argv[6]; String methodArgs = argv[7]; FeatureTemplateSet fe = new FeatureTemplateSet(Util.readFile(templateFile), NBITS_FEATURE_DATA); SymbolEncoder wordEncoder = new SymbolEncoder(); SymbolEncoder labelEncoder = new SymbolEncoder(); int outsideEnc = labelEncoder.encode(""); if(outsideEnc != 1) throw new RuntimeException("outsideEnc != 1"); if(wordEncoder.encode("") != 1) throw new RuntimeException("outsideEnc != 1"); /* EPE hack. */ initWordEncoder(wordEncoder, trainingFile, 3); initWordEncoder(wordEncoder, trainingFile, 1); ArrayList> trainingSet = readLabeledFile(trainingFile, wordEncoder, labelEncoder); int[] range = getLabelRange(trainingSet); System.out.println("Training set size: " + trainingSet.size()); wordEncoder.freeze(); labelEncoder.freeze(); SeqLabelingDefinition def = new SeqLabelingDefinition(fe, secondOrder, range, labelEncoder.inverse(), costType); LearningAlgorithm method = new AlgorithmFactory().create(methodName, methodArgs); Classifier cl = method.train(def, trainingSet); trainingSet = null; TaggingModel tm = new TaggingModel(cl, labelEncoder, wordEncoder, def); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile)); oos.writeObject(tm); oos.close(); } private static Pair readLabeledSentence(BufferedReader br, SymbolEncoder wordEncoder, SymbolEncoder labelEncoder) throws IOException { ArrayList sen = new ArrayList(); String line = br.readLine(); while(true) { if(line == null) return null; line = line.trim(); if(!line.equals("")) { String[] ss = line.split("\t"); sen.add(ss); } else { Sentence senc = new Sentence(sen, wordEncoder); Labeling lenc = new Labeling(sen, labelEncoder); return new Pair(senc, lenc); } line = br.readLine(); } } private static ArrayList> readLabeledFile(String file, SymbolEncoder wordEncoder, SymbolEncoder labelEncoder) throws IOException { BufferedReader br; InputStream is; if(file.endsWith("gz")) is = new GZIPInputStream(new FileInputStream(file)); else is = new FileInputStream(file); br = new BufferedReader(new InputStreamReader(is)); ArrayList> fullSet = new ArrayList(); Pair sen = readLabeledSentence(br, wordEncoder, labelEncoder); while(sen != null) { fullSet.add(sen); sen = readLabeledSentence(br, wordEncoder, labelEncoder); } br.close(); return fullSet; } private static int[] getLabelRange(ArrayList> s) { TIntHashSet labelSet = new TIntHashSet(); for(Pair p: s) for(int l: p.right.labels) labelSet.add(l); int[] range = labelSet.toArray(); Arrays.sort(range); return range; } public static double[] cv(String[] argv) throws IOException { //System.out.println(Arrays.toString(argv)); String trainingFile = argv[1]; String templateFile = argv[2]; int nFolds = Integer.parseInt(argv[3]); boolean secondOrder = Boolean.parseBoolean(argv[4]); int costType = Integer.parseInt(argv[5]); String methodName = argv[6]; String methodArgs = argv[7]; FeatureTemplateSet fe = new FeatureTemplateSet(Util.readFile(templateFile), NBITS_FEATURE_DATA); SymbolEncoder wordEncoder = new SymbolEncoder(); SymbolEncoder labelEncoder = new SymbolEncoder(); int outsideEnc = labelEncoder.encode(""); if(outsideEnc != 1) throw new RuntimeException("outsideEnc != 1"); if(wordEncoder.encode("") != 1) throw new RuntimeException("outsideEnc != 1"); ArrayList> fullSet = readLabeledFile(trainingFile, wordEncoder, labelEncoder); Collections.shuffle(fullSet, new Random(0)); wordEncoder.freeze(); labelEncoder.freeze(); ArrayList labelDecoder = labelEncoder.inverse(); double[] stats = new double[6]; AlgorithmFactory.setVerbosity(0); for(int i = 0; i < nFolds; i++) { ArrayList> trs = new ArrayList(); ArrayList> tes = new ArrayList(); CollectionUtils.cvSplit(fullSet, trs, tes, nFolds, i); int[] range = getLabelRange(trs); SeqLabelingDefinition def = new SeqLabelingDefinition(fe, secondOrder, range, labelEncoder.inverse(), costType); LearningAlgorithm method = new AlgorithmFactory().create(methodName, methodArgs); Classifier cl = method.train(def, trs); for(Pair p: tes) { Sentence senc = p.left; Labeling lenc = p.right; Labeling yhat = cl.classify(senc); evalIOB(labelDecoder, lenc, yhat, stats); } } return stats; } public static TaggingModel loadModel(String fileName) throws IOException { InputStream is = Util.openFileStream(fileName); ObjectInputStream ois = new ObjectInputStream(is); try { TaggingModel tm = (TaggingModel) ois.readObject(); ois.close(); return tm; } catch(Exception e) { e.printStackTrace(); throw new IOException(e.getMessage()); } } public static void printSentence(TaggingModel model, ArrayList tokens, Labeling labeling) { for(int i = 0; i < labeling.labels.length; i++) { for(String t: tokens.get(i)) System.out.print(t + "\t"); String lbl = model.labelDecoder.get(labeling.labels[i]); System.out.println(lbl); } System.out.println(); } /* public static ArrayList toSpans(TaggingModel model, Labeling labeling) { ArrayList out = new ArrayList(); Span current = null; for(int i = 0; i < labeling.labels.length; i++) { String lbl = model.labelDecoder.get(labeling.labels[i]); if(current != null) current.tokenEnd = i; if(lbl.startsWith("B-") || lbl.startsWith("I-") && (current == null || !lbl.substring(2).equals(current.label))) { current = new Span(); //i, -1, lbl.substring(2)); current.tokenStart = i; current.tokenEnd = -1; current.label = lbl.substring(2); out.add(current); } else if(lbl.equals("O")) current = null; } if(current != null) current.end = labeling.labels.length; return out; }*/ public static void printSentenceErrors(TaggingModel model, ArrayList tokens, Sentence x, Labeling y, Labeling yhat) { for(int i = 0; i < y.labels.length; i++) { if(y.labels[i] != yhat.labels[i]) System.out.print("!"); System.out.print("\t"); if(x.features[i][0] == 0) System.out.print("U"); System.out.print("\t"); System.out.print(tokens.get(i)[0]); System.out.print("\t"); System.out.print(model.labelDecoder.get(y.labels[i])); System.out.print("\t"); System.out.print(model.labelDecoder.get(yhat.labels[i])); System.out.println(); } System.out.println(); } static AnnotationLayer toAnnotationLayer(ArrayList dict, Labeling y) { AnnotationLayer l = new AnnotationLayer(); Span current = null; for(int i = 0; i < y.labels.length; i++) { String lbl = dict.get(y.labels[i]); //if(lbl.startsWith("B-") || // lbl.startsWith("I-") && (current == null || !lbl.equals("I-" + current.label))) { //} if(lbl.equals("O") && current != null) { current.tokenEnd = i; current = null; } if(lbl.startsWith("B-") || lbl.startsWith("I-") && (current == null || !lbl.equals("I-" + current.label))) { if(current != null) current.tokenEnd = i; current = new Span(); current.tokenStart = i; current.tokenEnd = Integer.MIN_VALUE; current.label = lbl.substring(2); l.add(current); } } if(current != null) current.tokenEnd = y.labels.length; for(Span s: l) if(s.tokenStart < 0 || s.tokenEnd < 0) { System.out.println("y = " + y); System.out.println("l = " + l); throw new RuntimeException("!!!"); } for(Span s: l) { // hack to make span evaluation work s.start = s.tokenStart; s.end = s.tokenEnd; } return l; } static void evalIOB(ArrayList dict, Labeling y, Labeling yhat, double[] scores) { // ncorrect, nguesses, ningold, propGuessCorrect, propFoundCorrect AnnotationLayer l1 = toAnnotationLayer(dict, y); AnnotationLayer l2 = toAnnotationLayer(dict, yhat); double[] comp = l1.compareApprox(l2); /*CompareResult senRes = l1.compare(l2); res.add(senRes); if(senRes.nCorrect != comp[3]) { System.out.println(y); System.out.println(l1); System.out.println(yhat); System.out.println(l2); System.out.println(Arrays.toString(comp)); System.out.println(senRes); System.exit(0); }*/ scores[0] += comp[3]; scores[1] += comp[4]; scores[2] += comp[5]; scores[3] += comp[6]; scores[4] += comp[7]; scores[5] += comp[8]; } public static void tagStdin(TaggingModel model) throws IOException { BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); long t0 = System.currentTimeMillis(); int nWords = 0; String line = br.readLine(); ArrayList sen = new ArrayList(); while(line != null) { line = line.trim(); if(!line.equals("")) { String[] ss = line.split("\t"); sen.add(ss); } else { Sentence senc = new Sentence(sen, model.inEncoder); //Labeling lenc = new Labeling(sen, model.outEncoder); Labeling yhat = model.cl.classify(senc); nWords += yhat.labels.length; printSentence(model, sen, yhat); sen.clear(); } line = br.readLine(); } System.out.flush(); System.out.close(); long t1 = System.currentTimeMillis(); System.err.println("Tagging time: " + (t1 - t0) + " ms."); System.err.println(1000.0 * nWords / (double)(t1-t0) + " words/sec."); } public static void debug(String[] argv) { String modelFile = argv[1]; String testFile = argv[2]; try { TaggingModel model = loadModel(modelFile); BufferedReader br = new BufferedReader(new FileReader(testFile)); long t0 = System.currentTimeMillis(); int nWords = 0; String line = br.readLine(); ArrayList sen = new ArrayList(); while(line != null) { line = line.trim(); if(!line.equals("")) { String[] ss = line.split("\t"); sen.add(ss); } else { Sentence senc = new Sentence(sen, model.inEncoder); Labeling lenc = new Labeling(sen, model.outEncoder); Labeling yhat = model.cl.classify(senc); nWords += yhat.labels.length; printSentenceErrors(model, sen, senc, lenc, yhat); sen.clear(); } line = br.readLine(); } System.out.flush(); System.out.close(); long t1 = System.currentTimeMillis(); System.err.println("Tagging time: " + (t1- t0) + " ms."); System.err.println(1000.0 * nWords / (double)(t1-t0) + " words/sec."); } catch(Exception e) { e.printStackTrace(); } } private static void tag(String[] argv) { try { TaggingModel tm = loadModel(argv[1]); tagStdin(tm); } catch(Exception e) { e.printStackTrace(); } } public static void tagKBest(String[] argv) { try { TaggingModel model = loadModel(argv[1]); int k = Integer.parseInt(argv[2]); BufferedReader br; PrintWriter out; if(argv.length <= 3) { br = new BufferedReader(new InputStreamReader(System.in)); out = new PrintWriter(new OutputStreamWriter(System.out, "UTF-8")); } else { br = Util.openFileReader(argv[3], "UTF-8"); out = Util.openFileWriter(argv[4], "UTF-8"); } Labeling[] results = new Labeling[k]; double[] scores = new double[k]; SparseVector[] reps = new SparseVector[k]; long t0 = System.currentTimeMillis(); int nWords = 0; double[] stats = new double[6]; int[] positionHist = new int[k]; String line = br.readLine(); ArrayList sen = new ArrayList(); while(line != null) { line = line.trim(); if(!line.equals("")) { String[] ss = line.split("\t"); sen.add(ss); } else { for(int i = 0; i < sen.size(); i++) { if(i > 0) out.print(" "); out.print(sen.get(i)[0]); } out.println(); Sentence senc = new Sentence(sen, model.inEncoder); Labeling lenc = new Labeling(sen, model.outEncoder); AnnotationLayer goldLayer = toAnnotationLayer(model.labelDecoder, lenc); for(int j = 0; j < goldLayer.spans.size(); j++) { if(j > 0) out.print(" "); Span s = goldLayer.spans.get(j); out.print(s.label + "," + s.tokenStart + "," + s.tokenEnd); } out.println(); int nfound = model.cl.nbest(k, senc, results, scores, reps); nWords += results[0].labels.length; //double bestScore = Double.NEGATIVE_INFINITY; //double[] best = null; double minErr = Double.POSITIVE_INFINITY; int minErrPos = -1; double[] best = null; for(int i = 0; i < nfound; i++) { out.println(i); out.println(scores[i]); double[] senStats = new double[6]; evalIOB(model.labelDecoder, lenc, results[i], senStats); double propGuessWrong = senStats[1] - senStats[3]; double propNotFound = senStats[2] - senStats[4]; //double score = senStats[3] + senStats[4]; double err = propGuessWrong + propNotFound; /*if(score > bestScore) { bestScore = score; best = senStats; }*/ if(err < minErr) { minErr = err; best = senStats; minErrPos = i; } for(int j = 0; j < 6; j++) { if(j > 0) out.print(" "); out.print(senStats[j]); } out.println(); //out.println(propGuessWrong + " " + propNotFound); AnnotationLayer layer = toAnnotationLayer(model.labelDecoder, results[i]); for(int j = 0; j < layer.spans.size(); j++) { if(j > 0) out.print(" "); Span s = layer.spans.get(j); out.print(s.label + "," + s.tokenStart + "," + s.tokenEnd); } out.println(); //out.println(toAnnotationLayer(model.labelDecoder, results[i]).spans); //printSentence(model, sen, results[i]); } out.println("---"); for(int i = 0; i < best.length; i++) stats[i] += best[i]; positionHist[minErrPos]++; sen.clear(); } line = br.readLine(); } out.flush(); out.close(); long t1 = System.currentTimeMillis(); System.err.println(Viterbi.countPops); System.err.println(Viterbi.countRedundantPops); System.err.println("Tagging time: " + (t1 - t0) + " ms."); System.err.println(1000.0 * nWords / (double)(t1-t0) + " words/sec."); //System.err.println(Arrays.toString(positionHist)); double nCorrectSpans = stats[0]; double nGuesses = stats[1]; double nInGold = stats[2]; double propGuessCorrect = stats[3]; double propFoundCorrect = stats[4]; double nOverlap = stats[5]; double pHard = nCorrectSpans / nGuesses; double rHard = nCorrectSpans / nInGold; double fHard = 2*pHard*rHard / (pHard + rHard); double pSoft = propGuessCorrect / nGuesses; double rSoft = propFoundCorrect / nInGold; double fSoft = 2*pSoft*rSoft / (pSoft + rSoft); double pOver = nOverlap / nGuesses; double rOver = nOverlap / nInGold; double fOver = 2*pOver*rOver / (pOver + rOver); System.err.println("Hard: p = " + pHard + ", r = " + rHard + ", f1 = " + fHard); System.err.println("Soft: p = " + pSoft + ", r = " + rSoft + ", f1 = " + fSoft); System.err.println("Overlap: p = " + pOver + ", r = " + rOver + ", f1 = " + fOver); } catch(Exception e) { e.printStackTrace(); } } public static AnnotationLayer tagSentence(TaggingModel model, String[][] words) { ArrayList sen = new ArrayList(); for(String[] s: words) sen.add(s); Sentence senc = new Sentence(sen, model.inEncoder); Labeling yhat = model.cl.classify(senc); return toAnnotationLayer(model.labelDecoder, yhat); } public static ArrayList> tagSentenceKBest(TaggingModel model, String[][] words, int k) { // if(true) // for(String[] s: words) // System.out.println(Arrays.toString(s)); ArrayList sen = new ArrayList(); for(String[] s: words) sen.add(s); Sentence senc = new Sentence(sen, model.inEncoder); if(k > 1) { Labeling[] results = new Labeling[k]; double[] scores = new double[k]; SparseVector[] reps = new SparseVector[k]; int nfound = model.cl.nbest(k, senc, results, scores, reps); ArrayList> out = new ArrayList(); for(int i = 0; i < nfound; i++) { AnnotationLayer l = toAnnotationLayer(model.labelDecoder, results[i]); out.add(new DoubleObjPair(scores[i], l)); } // if(true) { // for(DoubleObjPair l: out) // System.out.println(l); // } return out; } else { Labeling y = model.cl.classify(senc); ArrayList> out = new ArrayList(); AnnotationLayer l = toAnnotationLayer(model.labelDecoder, y); out.add(new DoubleObjPair(0.0, l)); return out; } } public static void eval(String[] argv) throws IOException { String modelName = argv[1]; String evalFileName = argv[2]; TaggingModel model = loadModel(modelName); BufferedReader br = Util.openFileReader(evalFileName); long t0 = System.currentTimeMillis(); int nWords = 0, nCorrect = 0; int nSeen = 0, nUnseen = 0, nSeenCorrect = 0, nUnseenCorrect = 0; int nAmb = 0, nUnamb = 0, nAmbCorrect = 0, nUnambCorrect = 0; double[] stats = new double[6]; Pair p = readLabeledSentence(br, model.inEncoder, model.outEncoder); while(p != null) { Sentence senc = p.left; Labeling lenc = p.right; Labeling yhat = model.cl.classify(senc); evalIOB(model.labelDecoder, lenc, yhat, stats); for(int i = 0; i < yhat.labels.length; i++) { if(yhat.labels[i] == lenc.labels[i]) nCorrect++; //else // errWords.add(sen.get(i)[0].toLowerCase()); if(senc.features[i][0] != 0) { nSeen++; if(yhat.labels[i] == lenc.labels[i]) nSeenCorrect++; } else { nUnseen++; if(yhat.labels[i] == lenc.labels[i]) nUnseenCorrect++; } } nWords += lenc.labels.length; p = readLabeledSentence(br, model.inEncoder, model.outEncoder); } long t1 = System.currentTimeMillis(); System.out.println("Tagging time: " + (t1- t0) + " ms."); System.out.println(1000.0 * nWords / (double)(t1-t0) + " words/sec."); System.out.println("Accuracy: " + nCorrect + "/" + nWords + " = " + (double) nCorrect / nWords); System.out.println(); System.out.println("Accuracy (seen): " + nSeenCorrect + "/" + nSeen + " = " + (double) nSeenCorrect / nSeen); System.out.println("Accuracy (unseen): " + nUnseenCorrect + "/" + nUnseen + " = " + (double) nUnseenCorrect / nUnseen); System.out.println(); System.out.println("Accuracy (ambig): " + nAmbCorrect + "/" + nAmb + " = " + (double) nAmbCorrect / nAmb); System.out.println("Accuracy (unambig): " + nUnambCorrect + "/" + nUnamb + " = " + (double) nUnambCorrect / nUnamb); System.out.println(); double nCorrectSpans = stats[0]; double nGuesses = stats[1]; double nInGold = stats[2]; double propGuessCorrect = stats[3]; double propFoundCorrect = stats[4]; //System.out.println(Arrays.toString(stats)); double pHard = nCorrectSpans / nGuesses; double rHard = nCorrectSpans / nInGold; double fHard = 2*pHard*rHard / (pHard + rHard); double pSoft = propGuessCorrect / nGuesses; double rSoft = propFoundCorrect / nInGold; double fSoft = 2*pSoft*rSoft / (pSoft + rSoft); System.out.format("In gold: %d, guessed: %d, correct: %d.\n", (int) nInGold, (int) nGuesses, (int) nCorrectSpans); System.out.format("Hard: p = %f, r = %f, f1 = %f\n", pHard, rHard, fHard); System.out.format("Soft: p = %f, r = %f, f1 = %f\n", pSoft, rSoft, fSoft); /* System.out.println(); int count = 0; for(IntObjPair p: errWords.asSortedList()) { System.out.println(p.right + "\t" + p.left); count++; if(count == 20) break; }*/ } public static void main(String[] argv) { try { if(argv[0].equals("-train")) { train(argv); /*String prefix = "/home/richard/work/workspace/mpqa_structlearn/epe/lth_test/seq"; String seqTemplate = "template_100402b.txt"; String modelFile = prefix + ".model"; String secondOrder = "true"; String costType = "1"; // for hamming loss String methodName = "OnlineAlgorithm"; String methodArgs = "nRounds=1 printDots=100 useNewLinear=true modelSize=1400000 PAUpdate C=0.1 sorted=true sqrt=true maximizeLoss=true"; String[] argv2 = new String[] { null, prefix, seqTemplate, modelFile, secondOrder, costType, methodName, methodArgs }; train(argv2);*/ } else if(argv[0].equals("-tag")) tag(argv); else if(argv[0].equals("-kbest")) tagKBest(argv); else if(argv[0].equals("-debug")) debug(argv); else if(argv[0].equals("-eval")) eval(argv); else { System.err.println("Error: illegal mode " + argv[0]); } } catch(Exception e) { e.printStackTrace(); System.exit(1); } } }