package mpqa_seq_reranker; import se.lth.cs.nlp.depsrl.format.PAStructure; import se.lth.cs.nlp.nlputils.ml_long.*; import se.lth.cs.nlp.nlputils.annotations.Span; import se.lth.cs.nlp.nlputils.core.*; import se.lth.cs.nlp.nlputils.depgraph.*; import java.io.Serializable; import java.util.*; public class RerankingFE implements Serializable { private static final long serialVersionUID = 0; //SparseVector tmp = new SparseVector(); private boolean useBigrams; private boolean useSyntax; private boolean useSpans; private boolean useSemantics; //private boolean scoreSources; public RerankingFE(boolean useBigrams, boolean useSyntax, boolean useSpans, boolean useSemantics) { this.useBigrams = useBigrams; this.useSyntax = useSyntax; this.useSpans = useSpans; this.useSemantics = useSemantics; } public void extractFeatures(Candidate candidate, SynSemParse parse, double baseScore, SparseVector sv) { if((useSyntax || useSemantics) && parse == null) throw new IllegalArgumentException("parse == null"); sv.clear(); if(useBigrams) extractBigrams(candidate, sv); //extractUnigrams(candidate, sv); if(useSyntax) extractPaths(candidate, parse, sv); if(useSemantics) extractSemPaths(candidate, parse, sv); /*if(candidate.spans.size() > 1) { double scale = 1.0 / (candidate.spans.size() - 1); for(int i = 0; i < sv.index; i++) sv.values[i] *= scale; }*/ if(useSpans) extractSpanFeatures(candidate, parse, sv); sv.put(1, baseScore); sv.sortIndices(); } public void extractSourceFeatures(Candidate c, SynSemParse parse, SparseVector sv) { sv.put(hash("srcBase"), c.baseSrcScore); for(int i = 0; i < c.spans.size(); i++) { Span s1 = c.spans.get(i); for(int j = i+1; j < c.spans.size(); j++) { if(i == j) continue; Span s2 = c.spans.get(j); DepNode[] ns = findClosestNodes(s1, s2, parse); if(false) { String depPath = depPath(ns[0], ns[1], ns[2]); String w1 = c.isWriter[i]? "t": "f"; String w2 = c.isWriter[j]? "t": "f"; String i1 = c.isImplicit[i]? "t": "f"; String i2 = c.isImplicit[j]? "t": "f"; String n1 = c.srcIndex[i] > 0? "t": "f"; String n2 = c.srcIndex[j] > 0? "t": "f"; String f2 = "src:" + s1.label + ":" + depPath + ":" + s2.label; f2 += w1 + w2 + i1 + i2 + n1 + n2; putFeature(f2, sv); String path1 = ""; if(c.srcIndex[i] > 0) { Span s_src = new Span(); s_src.tokenStart = c.srcIndex[i] - 1; s_src.tokenEnd = c.srcIndex[i]; DepNode[] ns_src = findClosestNodes(s1, s_src, parse); path1 = depPath(ns_src[0], ns_src[1], ns_src[2]); } String f3 = "srcX:" + s1.label + ":" + depPath + ":" + path1 + ":" + s2.label; putFeature(f3, sv); } if(true) { if(ns[2] == ns[0]) { String depPath = depPath(ns[0], ns[1], ns[2]); String f = s1.label + /*":" + depPath +*/ ":" + s2.label; f = "shs:" + sharedSource(c, i, j) + ":" + f; putFeature(f, sv); String w1 = c.isWriter[i]? "t": "f"; String w2 = c.isWriter[j]? "t": "f"; String i1 = c.isImplicit[i]? "t": "f"; String i2 = c.isImplicit[j]? "t": "f"; String n1 = c.srcIndex[i] > 0? "t": "f"; String n2 = c.srcIndex[j] > 0? "t": "f"; String f2 = "src:" + s1.label + ":" + depPath + ":" + s2.label; f2 += w1 + w2 + i1 + i2 + n1 + n2; putFeature(f2, sv); /*String f3 = "src2:" + s1.label + ":" + depPath + ":" + s2.label; f3 += getSourceString(c, i, parse.dg) + ":" + getSourceString(c, j, parse.dg); putFeature(f3, sv);*/ } else if(ns[2] == ns[1]) { String depPath = depPath(ns[1], ns[0], ns[2]); String f = s2.label + /*":" + depPath +*/ ":" + s1.label; f = "shs:" + sharedSource(c, i, j) + ":" + f; putFeature(f, sv); String w2 = c.isWriter[i]? "t": "f"; String w1 = c.isWriter[j]? "t": "f"; String i2 = c.isImplicit[i]? "t": "f"; String i1 = c.isImplicit[j]? "t": "f"; String n2 = c.srcIndex[i] > 0? "t": "f"; String n1 = c.srcIndex[j] > 0? "t": "f"; String f2 = "src:" + s2.label + ":" + depPath + ":" + s1.label; f2 += w1 + w2 + i1 + i2 + n1 + n2; putFeature(f2, sv); /*String f3 = "src2:" + s2.label + ":" + depPath + ":" + s1.label; f3 += getSourceString(c, j, parse.dg) + ":" + getSourceString(c, i, parse.dg); putFeature(f3, sv);*/ } else continue; } } } sv.sortIndices(); } private String getSourceString(Candidate c, int i, DepGraph dg) { if(c.isWriter[i]) return "W"; if(c.isImplicit[i]) return "I"; if(c.srcIndex[i] > 0) return dg.nodes[c.srcIndex[i]].word.toLowerCase(); return "E"; } private boolean sharedSource(Candidate c, int i, int j) { if(c.isWriter[i]) return c.isWriter[j]; if(c.isImplicit[i]) return c.isImplicit[j]; if(c.srcIndex[i] > 0) return c.srcIndex[i] == c.srcIndex[j]; return false; } private static void extractBigrams(Candidate c, SparseVector sv) { if(false) { String prev = ""; for(Span s: c.spans) { String bigram = "BG:" + prev + "|" + s.label; sv.put(hash(bigram), 1.0); prev = s.label; } String bigram = "BG:" + prev + "|"; sv.put(hash(bigram), 1.0); } else { for(int i = 0; i < c.spans.size(); i++) { for(int j = 0; j < c.spans.size(); j++) { if(i == j) continue; Span s1 = c.spans.get(i); Span s2 = c.spans.get(j); String bigram; if(s1.label.compareTo(s2.label) < 0) bigram = s1.label + "_" + s2.label; else bigram = s2.label + "_" + s2.label; sv.put(hash("BG2:" + bigram), 1.0); } } } } private static void extractUnigrams(Candidate c, SparseVector sv) { if(false) { for(Span s: c.spans) { String unigram = "UG:" + s.label; sv.put(hash(unigram), 1.0); } } else { for(Span s: c.spans) { String unigram = "UG2"; sv.put(hash(unigram), 1.0); } } } private static IntObjPair findCommonAncestor(DepNode n1, DepNode n2) { HashMap m = new HashMap(); int steps = 0; while(true) { m.put(n1, steps); if(n1.parents.length == 0) break; else n1 = n1.parents[0]; steps++; } steps = 0; while(true) { Integer otherSteps = m.get(n2); if(otherSteps != null) return new IntObjPair(otherSteps + steps, n2); if(n2.parents.length == 0) break; else n2 = n2.parents[0]; steps++; } throw new RuntimeException("coulnd't find ancestor"); } private static DepNode[] findClosestNodes(Span s1, Span s2, SynSemParse parse) { DepGraph dg = parse.dg; int minLength = Integer.MAX_VALUE; DepNode bestN1 = null, bestN2 = null, bestAnc = null; for(int i = s1.tokenStart; i < s1.tokenEnd; i++) { DepNode n1 = dg.nodes[i+1]; for(int j = s2.tokenStart; j < s2.tokenEnd; j++) { DepNode n2 = dg.nodes[j+1]; IntObjPair p = findCommonAncestor(n1, n2); if(p.left < minLength) { minLength = p.left; bestAnc = p.right; bestN1 = n1; bestN2 = n2; } } } return new DepNode[] { bestN1, bestN2, bestAnc }; } private static String depPath(DepNode n1, DepNode n2, DepNode ancestor) { StringBuilder sb = new StringBuilder(); while(n1 != ancestor) { sb.append(n1.relations[0] + "_"); n1 = n1.parents[0]; } sb.append("|"); while(n2 != ancestor) { sb.append(n2.relations[0] + "_"); n2 = n2.parents[0]; } return sb.toString(); } private static void extractPaths(Candidate c, SynSemParse parse, SparseVector sv) { for(int i = 0; i < c.spans.size(); i++) { Span s1 = c.spans.get(i); for(int j = i + 1; j < c.spans.size(); j++) { Span s2 = c.spans.get(j); DepNode[] ns = findClosestNodes(s1, s2, parse); String depPath = depPath(ns[0], ns[1], ns[2]); String depPathFeature = "p:" + s1.label + "-" + depPath + "-" + s2.label; if(true) putFeature(depPathFeature, sv); String depPathWordFeature = "pw:" + s1.label + "|" + ns[0].word.toLowerCase() + "|" + depPath + "|" + ns[1].word.toLowerCase() + "|" + s2.label; if(true) putFeature(depPathWordFeature, sv); String dominanceFeature = null; if(ns[2] == ns[0]) { dominanceFeature = "dom_w:" + s1.label + "/" + ns[0].word.toLowerCase() + "->" + s2.label + "/" + ns[1].word.toLowerCase(); } else if(ns[2] == ns[1]) { dominanceFeature = "dom_w:" + s2.label + "/" + ns[1].word.toLowerCase() + "->" + s1.label + "/" + ns[0].word.toLowerCase(); } if(dominanceFeature != null) { if(true) putFeature(dominanceFeature, sv); } } } } private static void extractSpanFeatures(Candidate c, SynSemParse parse, SparseVector sv) { for(int i = 0; i < c.spans.size(); i++) { Span s = c.spans.get(i); ArrayList ns = new ArrayList(); for(int j = s.tokenStart; j < s.tokenEnd; j++) ns.add(parse.dg.nodes[j + 1]); String startWord = ns.get(0).word.toLowerCase(); String endWord = ns.get(ns.size() - 1).word.toLowerCase(); putFeature("st:" + s.label + ":" + startWord, sv); putFeature("en:" + s.label + ":" + endWord, sv); putFeature("st_en:" + s.label + ":" + startWord + ":" + endWord, sv); } } private static void extractSemPaths(Candidate c, SynSemParse parse, SparseVector sv) { HashMap m = new HashMap(); for(Span s: c.spans) for(int i = s.tokenStart; i < s.tokenEnd; i++) m.put(parse.dg.nodes[i+1], s); for(PAStructure pa: parse.pas) { int predPosition = pa.pred.position - 1; for(Span s: c.spans) if(s.tokenStart <= predPosition && s.tokenEnd > predPosition) { // PREDICATE if(true) putFeature("prd:" + s.label + ":" + pa.lemma, sv); for(int i = 0; i < pa.argLabels.size(); i++) { String al = pa.argLabels.get(i); // PREDICATE + ARGLABEL if(true) putFeature("prd_al:" + s.label + ":" + pa.lemma + ":" + al, sv); int argPosition = pa.args.get(i).position - 1; for(Span s2: c.spans) // todo kolla också att det är ett annat arg? if(s2.tokenStart <= argPosition && s2.tokenEnd > argPosition) { if(false) putFeature("prd_al_l:" + s.label + ":" + pa.lemma + ":" + al + ":" + s2.label, sv); // CONNECTING ARGLABEL else if(true) putFeature("prd_al_l:" + s.label + /*":" + pa.lemma +*/ ":" + al + ":" + s2.label, sv); break; } } if(true) for(Span s2: c.spans) { if(s == s2) continue; // check whether s2 is dominated by some argument DepNode[] ns = findClosestNodes(s, s2, parse); //return new DepNode[] { bestN1, bestN2, bestAnc }; DepNode n1= ns[0]; DepNode n2 = ns[1]; DepNode anc = ns[2]; //boolean found = pa.args.contains(n2); int ix = pa.args.indexOf(n2); while(ix == -1) { if(n2 != anc) { n2 = n2.parents[0]; Span s_n2 = m.get(n2); if(s_n2 != null && s_n2 != s && s_n2 != s2) break; ix = pa.args.indexOf(n2); } else break; } while(n1 != anc) { n1 = n1.parents[0]; Span s_n1 = m.get(n1); if(s_n1 != null && s_n1 != s && s_n1 != s2) { ix = -1; break; } } if(ix == -1) continue; String al = pa.argLabels.get(ix); // CONNECTING ARGLABEL VARIANT 2 putFeature("prd_al2_l:" + s.label + /*":" + pa.lemma +*/ ":" + al + ":" + s2.label, sv); } break; } } } private static SymbolEncoder enc = new SymbolEncoder(); static { if(enc.encode("(base model score)") != 1) throw new RuntimeException("!!!"); } private static void putFeature(String f, SparseVector sv) { if(Reranker.PRINT_TOP_FEATURES) { int encoded = enc.encode(f); sv.put(encoded, 1.0); } else sv.put(hash(f), 1.0); } static void printProminentFeatures(LinkedList l) { int index = 0; ArrayList encInv = enc.inverse(); for(DoubleIntPair p: l) { index++; System.out.println(index + ":\t" + p.left + "\t" + encInv.get(p.right)); } } private static long hash(String s) { long hash = 0; int l = s.length(); for(int i = 0; i < l; i++) { hash += s.charAt(i); hash += (hash << 10L); hash ^= (hash >> 6L); } hash += (hash << 3L); hash ^= (hash >> 11L); hash += (hash << 15L); return hash; } public String toString() { return "FE:" + useBigrams + "/" + useSyntax + "/" + useSpans + "/" + useSemantics; } }