package se.lth.cs.nlp.opinions;

import java.util.ArrayList;
import java.util.Arrays;

import se.lth.cs.nlp.nlputils.ml_long.*;

/**
 * Problem definition for sequence labeling.
 *
 * <p>Provides the pieces a structured learner needs: a cost function between
 * a gold and a predicted labeling, decoding via {@code Viterbi} (first or
 * second order), n-best decoding, and encoding of a (sentence, labeling)
 * pair into a sparse feature vector.
 *
 * <p>Instances reuse an internal scratch buffer in {@link #cost}, so a single
 * instance must not be used for concurrent cost computations.
 */
public class SeqLabelingDefinition extends ProblemDefinition {

    private static final long serialVersionUID = 0L;

    /** Flag forwarded unchanged to every Viterbi search call. */
    // NOTE(review): presumably enables square-root length scaling of the cost
    // inside the search; confirm against Viterbi.
    private static final boolean SQRT = true;

    /** Cost type: number of token positions whose labels differ. */
    private static final int TOKEN_COST = 1;
    /** Cost type: 1 if any token label differs, otherwise 0. */
    private static final int ZERO_ONE_COST = 2;
    /** Cost type: nGuessed + nInGold - 2 * nCorrect, from the IOB evaluation. */
    private static final int SPAN_COST_HARD = 3;
    /** Cost type: nGuessed + nInGold - propGuessedCorrect - propFoundCorrect. */
    private static final int SPAN_COST_SOFT = 4;
    /** Cost type: 1 - F1, with F1 computed from the soft precision and recall. */
    private static final int SPAN_F1_COST_SOFT = 5;

    /** Whether second-order Viterbi search and order-3 feature extraction are used. */
    private boolean secondOrder;

    /** Feature templates used both by the search and by {@link #encode}. */
    private FeatureTemplateSet fe;

    /** Forwarded unchanged to the Viterbi search calls. */
    private int[] range;

    /** Label dictionary handed to {@code SeqLabeler.evalIOB}; used in the span costs. */
    private ArrayList labelDict;

    /**
     * Scratch buffer filled by {@code SeqLabeler.evalIOB}. As read by this
     * class: [0] nCorrect, [1] nGuessed, [2] nInGold, [3] propGuessedCorrect,
     * [4] propFoundCorrect. Index 5 is not read here.
     */
    private double[] iobResult = new double[6];

    /** One of the {@code *_COST} constants (1..5); any other value makes {@link #cost} throw. */
    private int costType = TOKEN_COST;

    /**
     * Creates a sequence labeling problem definition.
     *
     * @param fe          feature template set
     * @param secondOrder whether to use second-order search
     * @param range       forwarded to the Viterbi search
     * @param labelDict   label dictionary used by the span-based costs
     * @param costType    cost selector: 1 = token, 2 = zero-one, 3 = hard span,
     *                    4 = soft span, 5 = soft span F1
     * @throws IllegalArgumentException if the maximum range of {@code fe}
     *         exceeds 4, or equals 4 while {@code secondOrder} is false
     */
    public SeqLabelingDefinition(FeatureTemplateSet fe, boolean secondOrder,
                                 int[] range, ArrayList labelDict, int costType) {
        this.fe = fe;
        this.range = range;
        this.secondOrder = secondOrder;
        this.labelDict = labelDict;
        if (fe.getMaximumRange() > 4) {
            throw new IllegalArgumentException("Dynamic features too far back");
        }
        if (fe.getMaximumRange() == 4 && !secondOrder) {
            throw new IllegalArgumentException("This feature set requires second-order search");
        }
        this.costType = costType;
    }

    /**
     * Computes the cost of predicting {@code yhat} when the gold labeling is
     * {@code y}, according to the cost type chosen at construction.
     *
     * @param x    the sentence (not read by any of the cost types)
     * @param y    gold labeling
     * @param yhat predicted labeling
     * @return the non-negative cost
     * @throws RuntimeException if the cost type is not one of 1..5, or if a
     *         span-based cost comes out negative
     */
    public double cost(Sentence x, Labeling y, Labeling yhat) {
        switch (costType) {
            case TOKEN_COST:
                return hammingCost(y, yhat);
            case ZERO_ONE_COST:
                return zeroOneCost(y, yhat);
            case SPAN_COST_HARD:
                return hardSpanCost(y, yhat);
            case SPAN_COST_SOFT:
                return softSpanCost(y, yhat);
            case SPAN_F1_COST_SOFT:
                return softSpanF1Cost(y, yhat);
            default:
                throw new RuntimeException("Illegal cost type");
        }
    }

    /** Counts the positions (over the length of {@code y}) where the labels differ. */
    private double hammingCost(Labeling y, Labeling yhat) {
        double mismatches = 0;
        for (int i = 0; i < y.labels.length; i++) {
            if (y.labels[i] != yhat.labels[i]) {
                mismatches++;
            }
        }
        return mismatches;
    }

    /** Returns 1 if any position of {@code y} has a different label in {@code yhat}, else 0. */
    private double zeroOneCost(Labeling y, Labeling yhat) {
        for (int i = 0; i < y.labels.length; i++) {
            if (y.labels[i] != yhat.labels[i]) {
                return 1;
            }
        }
        return 0;
    }

    /** Resets the scratch buffer and fills it with the IOB evaluation of {@code yhat} against {@code y}. */
    private void evaluateSpans(Labeling y, Labeling yhat) {
        Arrays.fill(iobResult, 0.0);
        SeqLabeler.evalIOB(labelDict, y, yhat, iobResult);
    }

    /** Hard span cost: guessed plus gold count, minus twice the correct count. */
    private double hardSpanCost(Labeling y, Labeling yhat) {
        evaluateSpans(y, yhat);
        double nCorrect = iobResult[0];
        double nGuessed = iobResult[1];
        double nInGold = iobResult[2];
        double cost = nGuessed + nInGold - 2 * nCorrect;
        if (cost < 0) {
            throw new RuntimeException("cost < 0");
        }
        return cost;
    }

    /** Soft span cost: guessed plus gold count, minus the two soft correctness totals. */
    private double softSpanCost(Labeling y, Labeling yhat) {
        evaluateSpans(y, yhat);
        double nGuessed = iobResult[1];
        double nInGold = iobResult[2];
        double propGuessedCorrect = iobResult[3];
        double propFoundCorrect = iobResult[4];
        double cost = nGuessed + nInGold - propGuessedCorrect - propFoundCorrect;
        if (cost < 0) {
            throw new RuntimeException("cost < 0");
        }
        return cost;
    }

    /**
     * Soft span F1 cost: {@code 1 - F1}, where precision is
     * {@code propGuessedCorrect / nGuessed} and recall is
     * {@code propFoundCorrect / nInGold}. When all four quantities are zero
     * the cost is 0.
     */
    private double softSpanF1Cost(Labeling y, Labeling yhat) {
        evaluateSpans(y, yhat);
        double nGuessed = iobResult[1];
        double nInGold = iobResult[2];
        double propGuessedCorrect = iobResult[3];
        double propFoundCorrect = iobResult[4];

        double cost;
        if (nGuessed == 0 && propGuessedCorrect == 0
                && nInGold == 0 && propFoundCorrect == 0) {
            cost = 0;
        } else {
            // A zero denominator would yield NaN (0/0), and NaN passes the
            // "cost < 0" guard below unnoticed. Define the ratio as 0 in that
            // case, which gives F1 = 0 and therefore cost = 1.
            double p = nGuessed > 0 ? propGuessedCorrect / nGuessed : 0;
            double r = nInGold > 0 ? propFoundCorrect / nInGold : 0;
            double f;
            if (p == 0 || r == 0) {
                f = 0;
            } else {
                f = 2 * p * r / (p + r);
            }
            cost = 1 - f;
        }
        if (cost < 0) {
            throw new RuntimeException("cost < 0");
        }
        return cost;
    }

    /**
     * Finds the highest-scoring labeling for {@code x} with Viterbi search and
     * writes its feature encoding into {@code sv}.
     *
     * <p>As a consistency check, the model score of the encoded vector is
     * compared with {@code result[0]} as left by the search; if they differ by
     * more than 1e-5 the sentence is printed to standard output and a
     * {@link RuntimeException} is thrown.
     *
     * @param x      the sentence to label
     * @param model  the scoring model
     * @param sv     output: cleared and filled with the encoding of the result
     * @param result passed to the search; {@code result[0]} is read afterwards
     * @return the labeling returned by the search
     */
    public Labeling argmax(Sentence x, Model model, SparseVector sv, double[] result) {
        Labeling yhat;
        if (secondOrder) {
            yhat = Viterbi.search2(x, null, false, SQRT, fe, range, model, result);
        } else {
            yhat = Viterbi.search1(x, null, false, SQRT, fe, range, model, result);
        }

        sv.clear();
        encode(x, yhat, sv);

        double rescored = model.score(sv);
        double diff = rescored - result[0];
        if (Math.abs(diff) > 1e-5) {
            System.out.println(x);
            throw new RuntimeException("diff = " + diff);
        }
        return yhat;
    }

    /**
     * Runs the Viterbi search with the gold labeling supplied, encodes the
     * resulting labeling into {@code sv}, and stores its model score in
     * {@code result[0]} and its cost against {@code y} in {@code result[1]}.
     *
     * @param x      the sentence
     * @param y      gold labeling
     * @param model  the scoring model
     * @param sv     output: cleared and filled with the encoding of the result
     * @param result output: [0] model score, [1] cost
     * @return the labeling returned by the search
     */
    // NOTE(review): the 'true' flag presumably requests cost-augmented search; confirm in Viterbi.
    public Labeling maxLoss(Sentence x, Labeling y, Model model, SparseVector sv, double[] result) {
        Labeling yhat;
        if (secondOrder) {
            yhat = Viterbi.search2(x, y, true, SQRT, fe, range, model, result);
        } else {
            yhat = Viterbi.search1(x, y, true, SQRT, fe, range, model, result);
        }

        sv.clear();
        encode(x, yhat, sv);
        result[0] = model.score(sv);
        result[1] = cost(x, y, yhat);

        // Disabled sanity check, kept for reference:
        //   SparseVector svtmp = new SparseVector();
        //   encode(x, y, svtmp);
        //   double yscore = model.score(svtmp);
        //   double fakeloss = result[0] + result[1] / Math.sqrt(y.labels.length) - yscore;
        //   if (fakeloss < -1e-5) throw new RuntimeException("fakeloss = " + fakeloss);
        return yhat;
    }

    /**
     * Finds up to {@code n} best labelings with first-order k-best Viterbi.
     *
     * <p>The feature representations are not computed: {@code reps[i]} is set
     * to {@code null} for every returned candidate.
     *
     * @param x      the sentence
     * @param model  the scoring model
     * @param n      number of candidates requested
     * @param ybars  output: candidate labelings
     * @param reps   output: set to {@code null} for each candidate
     * @param scores output: filled by the search
     * @return the number of candidates reported by the search
     * @throws UnsupportedOperationException if this definition is second-order
     */
    public int findNBest(Sentence x, Model model, int n, Labeling[] ybars,
                         SparseVector[] reps, double[] scores) {
        if (secondOrder) {
            throw new UnsupportedOperationException("Unimplemented!");
        }
        int nout = Viterbi.search1KBest(x, null, n, ybars, scores, false, SQRT, fe, range, model);
        for (int i = 0; i < nout; i++) {
            // Encoding and rescoring of each candidate is intentionally skipped
            // (original author's note: "why is this needed?").
            reps[i] = null;
        }
        return nout;
    }

    /**
     * Encodes the pair ({@code x}, {@code y}) into {@code sv}.
     *
     * <p>Walks the sentence left to right keeping a window of the three most
     * recent labels (initially all {@code FeatureTemplate.OUTSIDE_LABEL}),
     * extracting features at every position, and finally at one extra
     * position past the end with {@code OUTSIDE_LABEL} as the current label.
     * The vector's indices are sorted before returning.
     *
     * @param x  the sentence
     * @param y  the labeling; must have one label per element of {@code x.features}
     * @param sv output: cleared, then filled with the extracted features
     * @throws IllegalArgumentException if the lengths of {@code x.features}
     *         and {@code y.labels} differ
     */
    public void encode(Sentence x, Labeling y, SparseVector sv) {
        sv.clear();
        if (x.features.length != y.labels.length) {
            throw new IllegalArgumentException("illegal length");
        }

        // ofs[0] = current label, ofs[1] = previous, ofs[2] = the one before that.
        int[] ofs = new int[3];
        Arrays.fill(ofs, FeatureTemplate.OUTSIDE_LABEL);

        for (int i = 0; i < y.labels.length; i++) {
            pushLabel(ofs, y.labels[i]);
            extractAllOrders(sv, x, ofs, i);
        }

        // End-of-sentence boundary position.
        // Original author's note: a trigram is probably not needed for the end?
        pushLabel(ofs, FeatureTemplate.OUTSIDE_LABEL);
        extractAllOrders(sv, x, ofs, x.features.length);

        sv.sortIndices();
    }

    /** Shifts the label window one step back and stores {@code label} as the current label. */
    private static void pushLabel(int[] ofs, int label) {
        ofs[2] = ofs[1];
        ofs[1] = ofs[0];
        ofs[0] = label;
    }

    /**
     * Calls {@code extractFeatures} at {@code position} with last argument
     * 0, 1 and 2, and additionally 3 when second-order.
     */
    private void extractAllOrders(SparseVector sv, Sentence x, int[] ofs, int position) {
        fe.extractFeatures(sv, x.features, ofs, position, 0);
        fe.extractFeatures(sv, x.features, ofs, position, 1);
        fe.extractFeatures(sv, x.features, ofs, position, 2);
        if (secondOrder) {
            fe.extractFeatures(sv, x.features, ofs, position, 3);
        }
    }
}