package is2.transitionS2mwe;

import is2.data.Cluster;
import is2.data.F2SF;
import is2.data.Instances;
import is2.data.InstancesTagger;
import is2.data.Long2IntInterface;
import is2.data.Parse;
import is2.util.IntStack;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Beam-search decoder for the joint POS/morphology tagging and transition-based parsing model of
 * {@code is2.transitionS2mwe}.
 *
 * <p>Each search step expands every beam state in parallel (one {@link Extender} per beam slot, run
 * on {@link #executerService}), keeps only the best-scoring state per state signature, prunes to at
 * most {@link #beam} states within {@link #BEAM_MARGIN} of the best score, and then re-adds up to
 * {@link #ht} (resp. {@link #hm}) same-signature siblings whose first difference from a kept state is
 * a POS (resp. morphological) tag.
 *
 * <p>When a {@link GuideOracle} is supplied (training), the oracle derivation is advanced in
 * lock-step with the beam. The search stops as soon as the oracle state is no longer represented in
 * the beam. Parameters are then updated from the predicted and oracle histories.
 *
 * <p>The mutable static fields are global timing/statistics counters shared by all instances; no
 * synchronization is applied to them.
 */
public final class Decoder {
    /** Transition codes as stored in {@code O.o}. */
    static final int LA = 1;
    static final int RA = 2;
    static final int SHIFT = 3;
    static final int REDUCE = 4;
    static final int SWAP = 5;

    /** Accumulated wall-clock nanoseconds spent in {@link #decode}. */
    public static long timeDecoder;
    /** Accumulated nanoseconds per decode phase (setup, task prep, expansion, pruning, oracle check, termination check, statistics). */
    public static long t1;
    public static long t2;
    public static long t3;
    public static long t4;
    public static long t5;
    public static long t6;
    public static long t7;

    public ExecutorService executerService;
    private final Guide g;
    ExtractorPet[] extp;
    ExtractorR[] extractor;
    Extender[] extender;
    Long2IntInterface li;

    /** Number of search steps that contributed to {@link #avgHighest}/{@link #avgLowest}. */
    static int avgCount;
    /** Number of searches aborted because they exceeded the step limit. */
    public static int forcedStop;

    /** Per-operation score sums, counts and squared-deviation sums: O = top beam entry only, B = whole final beam. */
    static float[] scoresO = new float[6];
    static float[] countSO = new float[6];
    static float[] sdO = new float[6];
    static float[] scoresB = new float[6];
    static float[] countSB = new float[6];
    static float[] sdB = new float[6];

    /** Maximum number of regular beam entries per step. */
    static int beam = 40;
    /** Maximum number of extra morphology-alternative entries per step. */
    static int hm = 2;
    /** Maximum number of extra POS-alternative entries per step. */
    static int ht = 2;

    public static int correctPos = 0;
    public static int wrongPos = 0;
    static float avgHighest = 0.0f;
    static float avgLowest = 0.0f;
    public static float allNth = 0.0f;
    static int errors = 0;
    static int count = 0;

    /** States scoring more than this below the best candidate are not admitted to the beam. */
    private static final double BEAM_MARGIN = 0.12d;

    /** Result of {@link Decoder#decode}: the parse plus the chosen POS and morphology tag ids. */
    public static class R {
        Parse parse;
        short[] pos;
        short[] mos;

        public R(Parse parse, short[] pos, short[] mos) {
            this.parse = parse;
            this.pos = pos;
            this.mos = mos;
        }
    }

    public Decoder(Long2IntInterface long2IntInterface, ParametersFloat parametersFloat, Cluster cluster, Lexicon lexicon) {
        this(long2IntInterface, parametersFloat, cluster, 0, lexicon);
    }

    /**
     * Creates a decoder with one extractor/extender slot per possible beam entry
     * ({@code beam + ht + hm + 2} slots).
     *
     * @param long2IntInterface passed to every extractor and kept in {@link #li}
     * @param parametersFloat its {@code getFVP()} vector is passed to every {@code ExtractorPet}
     * @param cluster passed to every {@code ExtractorPet}
     * @param i unused
     * @param lexicon passed to every {@code ExtractorR}
     */
    public Decoder(Long2IntInterface long2IntInterface, ParametersFloat parametersFloat, Cluster cluster, int i, Lexicon lexicon) {
        this.executerService = Executors.newFixedThreadPool(Parser.THREADS);
        int slots = beam + ht + hm + 2;
        this.extp = new ExtractorPet[slots];
        this.extractor = new ExtractorR[slots];
        this.extender = new Extender[slots];
        this.li = long2IntInterface;
        for (int k = 0; k < this.extractor.length; k++) {
            this.extractor[k] = new ExtractorR(long2IntInterface, false, 1, lexicon);
            this.extractor[k].init();
        }
        for (int k = 0; k < this.extp.length; k++) {
            this.extp[k] = new ExtractorPet(parametersFloat.getFVP(), cluster, long2IntInterface);
        }
        this.g = new Guide();
    }

    /**
     * Decodes one sentence with beam search.
     *
     * <p>Training mode ({@code guideOracle != null}) also performs a parameter update when the best
     * predicted state is wrong and differs from the oracle state.
     *
     * @param z unused by this method
     * @param guideOracle oracle that supplies gold transitions; {@code null} at prediction time
     * @param parametersFloat model parameters used for scoring and, in training mode, updated
     * @param i forwarded to {@code ParametersFloat.update} (presumably an update counter — TODO confirm)
     * @param is the instances container
     * @param sentence index of the sentence in {@code is}
     * @param cluster word clusters forwarded to extenders/encoders
     * @param posCandidates {@code posCandidates[k][w]} is the rank-{@code k} POS candidate for word
     *     {@code w}; rank 0 is used as the default tag
     * @param morphCandidates same layout as {@code posCandidates}, for morphological tags
     * @return the best parse and its tag assignments
     * @throws InterruptedException if interrupted while waiting for worker tasks
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // task lists are raw; see their declaration
    public R decode(boolean z, GuideOracle guideOracle, ParametersFloat parametersFloat, int i,
            InstancesTagger is, int sentence, Cluster cluster, POS[][] posCandidates, POS[][] morphCandidates)
            throws InterruptedException {
        final long start = System.nanoTime();
        final boolean training = guideOracle != null;
        this.g.initialize(is, sentence);

        // Oracle derivation, advanced in lock-step with the beam (training only).
        State gold = training ? buildOracleState(is, sentence, posCandidates, morphCandidates) : null;

        if (this.executerService.isShutdown()) {
            // Recreate the same kind of pool the constructor uses.
            this.executerService = Executors.newFixedThreadPool(Parser.THREADS);
        }

        List<State> beamStates = new ArrayList<State>();
        State initial = new State(is.length(sentence), training);
        beamStates.add(initial);
        initial.pos = new short[posCandidates[0].length];
        for (int w = 0; w < posCandidates[0].length; w++) {
            initial.pos[w] = (short) posCandidates[0][w].p;
        }
        initial.mos = new short[morphCandidates[0].length];
        for (int w = 0; w < morphCandidates[0].length; w++) {
            initial.mos[w] = (short) morphCandidates[0][w].p;
        }

        // Kept raw: invokeAll's type parameter depends on how Extender/Encoder implement
        // Callable, which is declared outside this file.
        ArrayList tasks = new ArrayList();
        List<State> bestPerSignature = new ArrayList<State>();
        Map<Object, List<State>> bySignature = new HashMap<Object, List<State>>();
        List<State> posAlternatives = new ArrayList<State>();
        List<State> morphAlternatives = new ArrayList<State>();

        if (this.extender[0] == null) {
            for (int k = 0; k < this.extender.length; k++) {
                this.extender[k] = new Extender(this.extractor[k], this.extp[k], parametersFloat.getFV(), training, cluster, posCandidates, morphCandidates);
            }
        }
        for (Extender e : this.extender) {
            e.is = is;
            e.i = sentence;
            e.g = this.g;
            e.px = posCandidates;
            e.mx = morphCandidates;
        }
        t1 += System.nanoTime() - start;

        int steps = 0;
        int maxSteps = is.length(sentence) * 10;
        beamSearch:
        while (true) {
            long stepStart = System.nanoTime();
            steps++;
            if (steps > maxSteps) {
                forcedStop++;
                break beamSearch;
            }

            if (training) {
                // Same empty-buffer guard for the oracle query and the legality check.
                int bufferTop = gold.b.size() > 0 ? gold.b.peek() : -1;
                O operation = guideOracle.getOperation(gold.s, gold.b, bufferTop, gold.p);
                if (operation == null) {
                    break beamSearch;
                }
                if (operation.o == LA || operation.o == RA) {
                    recordEdgeLabel(operation, gold);
                }
                if (!possible(operation.o, gold.s, bufferTop, gold.p)) {
                    break beamSearch;
                }
                gold = State.perform(operation, gold);
            }

            // Expand every beam state in parallel, one extender per beam slot.
            tasks.clear();
            for (int k = 0; k < beamStates.size(); k++) {
                this.extender[k].s = beamStates.get(k);
                tasks.add(this.extender[k]);
            }
            long expandStart = System.nanoTime();
            t2 += expandStart - stepStart;
            this.executerService.invokeAll(tasks);
            long expandEnd = System.nanoTime();
            t3 += expandEnd - expandStart;

            // Group successors by signature and keep the best of each group; the remaining group
            // members stay in bySignature as candidate tag alternatives.
            bestPerSignature.clear();
            bySignature.clear();
            for (int k = 0; k < tasks.size(); k++) {
                Iterator<State> produced = ((Extender) tasks.get(k)).r.iterator();
                while (produced.hasNext()) {
                    State next = produced.next();
                    List<State> group = bySignature.get(next.sig);
                    if (group == null) {
                        group = new ArrayList<State>();
                        bySignature.put(next.sig, group);
                    }
                    group.add(next);
                }
            }
            for (List<State> group : bySignature.values()) {
                Collections.sort(group);
                bestPerSignature.add(group.remove(0));
            }
            Collections.sort(bestPerSignature);
            if (bestPerSignature.isEmpty()) {
                // No successor at all: keep the previous beam instead of failing on get(0).
                break beamSearch;
            }

            beamStates.clear();
            float bestScore = bestPerSignature.get(0).score();
            for (int k = 0; k < beam && k < bestPerSignature.size(); k++) {
                State candidate = bestPerSignature.get(k);
                if (candidate.score() + BEAM_MARGIN < bestScore) {
                    break;
                }
                beamStates.add(candidate);
            }

            posAlternatives.clear();
            morphAlternatives.clear();
            float highest = -1999.0f;
            float lowest = 2000.0f;
            for (State kept : beamStates) {
                if (highest < kept.score()) {
                    highest = kept.score();
                }
                if (lowest > kept.score()) {
                    lowest = kept.score();
                }
                for (State sibling : bySignature.get(kept.sig)) {
                    classifyTagAlternative(kept, sibling, posAlternatives, morphAlternatives);
                }
            }
            avgHighest += highest;
            avgLowest += lowest;
            avgCount++;

            Collections.sort(posAlternatives);
            if (Tagger2.msize <= 1) {
                morphAlternatives.clear();
            }
            Collections.sort(morphAlternatives);
            beamStates.addAll(posAlternatives.subList(0, Math.min(ht, posAlternatives.size())));
            beamStates.addAll(morphAlternatives.subList(0, Math.min(hm, morphAlternatives.size())));
            long pruneEnd = System.nanoTime();
            t4 += pruneEnd - expandEnd;

            if (training) {
                // Stop once the oracle state has fallen out of the beam.
                boolean goldInBeam = false;
                for (State candidate : beamStates) {
                    if (gold.contains(candidate)) {
                        goldInBeam = true;
                        break;
                    }
                }
                if (!goldInBeam) {
                    break beamSearch;
                }
            }
            long checkEnd = System.nanoTime();
            t5 += checkEnd - pruneEnd;

            boolean allFinal = true;
            for (State candidate : beamStates) {
                if (candidate.b.size() > 0 || candidate.s.size() > 1) {
                    allFinal = false;
                }
                // Once any state has finished, cap the remaining search at twice its step count.
                if (candidate.b.isEmpty() && candidate.s.size() < 2 && maxSteps > 2 * candidate.steps) {
                    maxSteps = 2 * candidate.steps;
                }
            }
            t6 += System.nanoTime() - checkEnd;
            if (allFinal) {
                break beamSearch;
            }
        }

        long statsStart = System.nanoTime();
        if (training) {
            State predicted = beamStates.get(0);
            double predictedErrors = Pipe.errors(is, sentence, predicted.p);
            errors = (int) (errors + predictedErrors);
            count += is.length(sentence);
            if (!gold.contains(predicted) && predictedErrors > 0.0d) {
                float diff = gold.diff(predicted) + 1;
                Encoder predEncoder = new Encoder(is, sentence, predicted.pos, predicted.mos, predicted.p.heads, predicted.p.labels, cluster,
                        new State(is.pposs[sentence].length, true), false, this.extractor[0], this.extp[0], posCandidates, morphCandidates,
                        predicted.getHistory(), predicted.nth, predicted.mth, "pred");
                Encoder goldEncoder = new Encoder(is, sentence, gold.pos, gold.mos, gold.p.heads, gold.p.labels, cluster,
                        new State(is.pposs[sentence].length, true), false, this.extractor[1], this.extp[1], posCandidates, morphCandidates,
                        gold.getHistory(), gold.nth, gold.mth, "gold");
                ArrayList encoders = new ArrayList();
                encoders.add(predEncoder);
                encoders.add(goldEncoder);
                this.executerService.invokeAll(encoders);
                parametersFloat.update(goldEncoder.f, predEncoder.f, i, diff, predicted.steps, gold.steps);
            }
        }
        accumulateOperationStatistics(beamStates);
        t7 += System.nanoTime() - statsStart;

        State best = beamStates.get(0);
        Parse parse = best.p;
        accumulatePosStatistics(is, sentence, best);

        parse.heads[0] = -1;
        if (!training) {
            // Attach every word the transition sequence left without a head.
            for (int w = 1; w < parse.heads.length; w++) {
                if (parse.heads[w] == -1) {
                    attach(is, sentence, this.extp[0], this.g, guideOracle, parametersFloat, parse, 3, w, posCandidates, morphCandidates, best.pos, best.nth, best);
                }
            }
        }
        timeDecoder += System.nanoTime() - start;
        return new R(parse, best.pos, best.mos);
    }

    /**
     * Builds the initial oracle state. Each word starts with its rank-0 POS/morphology candidate.
     * A lower-ranked candidate replaces it when it equals the gold tag and scores within the
     * {@code Tagger2} threshold of rank 0; its rank is recorded in {@code nth}/{@code mth}.
     */
    private static State buildOracleState(InstancesTagger is, int sentence, POS[][] posCandidates, POS[][] morphCandidates) {
        State gold = new State(is.pposs[sentence].length, true);
        for (int w = 0; w < posCandidates[0].length; w++) {
            gold.pos[w] = (short) posCandidates[0][w].p;
            for (int k = 1; k < Tagger2.size; k++) {
                POS alt = posCandidates[k][w];
                if (alt.p == is.gpos[sentence][w] && posCandidates[0][w].s - alt.s <= Tagger2.THRESHOLD) {
                    gold.pos[w] = (short) alt.p;
                    gold.nth[w] = (short) k;
                }
            }
        }
        for (int w = 0; w < morphCandidates[0].length; w++) {
            gold.mos[w] = (short) morphCandidates[0][w].p;
            for (int k = 1; k < Tagger2.msize; k++) {
                POS alt = morphCandidates[k][w];
                if (alt.p == is.gfeats[sentence][w] && morphCandidates[0][w].s - alt.s <= Tagger2.MTHRESHOLD) {
                    gold.mos[w] = (short) alt.p;
                    gold.mth[w] = (short) k;
                }
            }
        }
        return gold;
    }

    /**
     * Registers the oracle arc label for the POS pair of the two topmost stack items in
     * {@code Edges} (both directions) if it is not yet known. Does nothing when the stack holds
     * fewer than two items; the original indexed {@code pos[-1]} in that case.
     */
    private static void recordEdgeLabel(O operation, State gold) {
        int size = gold.s.size();
        int below = size > 1 ? gold.s.get(size - 2) : -1;
        int top = size > 0 ? gold.s.get(size - 1) : -1;
        if (below < 0 || top < 0) {
            return;
        }
        short[] known = Edges.get(gold.pos[below], gold.pos[top]);
        for (short label : known) {
            if (operation.l == label) {
                return;
            }
        }
        Edges.put(gold.pos[top], gold.pos[below], (short) operation.l);
        Edges.put(gold.pos[below], gold.pos[top], (short) operation.l);
    }

    /**
     * Files {@code sibling} under the POS alternatives if the first word where it differs from
     * {@code kept} differs in POS, or under the morphology alternatives if it differs there in
     * morphology. Siblings with identical tags are ignored.
     */
    private static void classifyTagAlternative(State kept, State sibling, List<State> posAlternatives, List<State> morphAlternatives) {
        for (int w = 0; w < kept.pos.length; w++) {
            if (kept.pos[w] != sibling.pos[w]) {
                posAlternatives.add(sibling);
                return;
            }
            if (kept.mos[w] != sibling.mos[w]) {
                morphAlternatives.add(sibling);
                return;
            }
        }
    }

    /**
     * Walks each final beam state's derivation back through {@code previous}. Accumulates per-operation
     * score sums, counts and squared deviations from the running mean. The B arrays cover all states;
     * the O arrays cover only the top-ranked one.
     */
    private static void accumulateOperationStatistics(List<State> finalBeam) {
        for (int rank = 0; rank < finalBeam.size(); rank++) {
            for (State s = finalBeam.get(rank); s != null && s.oper != null; s = s.previous) {
                int op = s.oper.o;
                float p = s.oper.p;
                scoresB[op] += p;
                countSB[op] += 1.0f;
                sdB[op] += (float) Math.pow((scoresB[op] / countSB[op]) - p, 2.0d);
                if (rank == 0) {
                    scoresO[op] += p;
                    countSO[op] += 1.0f;
                    sdO[op] += (float) Math.pow((scoresO[op] / countSO[op]) - p, 2.0d);
                }
            }
        }
    }

    /**
     * Updates POS accuracy counters and the average candidate rank for the best state, skipping
     * index 0. The average is skipped for sentences with no word beyond index 0, which previously
     * added NaN to {@link #allNth}.
     */
    private static void accumulatePosStatistics(InstancesTagger is, int sentence, State best) {
        int length = best.nth.length;
        float nthSum = 0.0f;
        for (int w = 1; w < length; w++) {
            if (is.gpos[sentence][w] == best.pos[w]) {
                correctPos++;
            } else {
                wrongPos++;
            }
            nthSum += best.nth[w];
        }
        if (length > 1) {
            allNth += nthSum / (length - 1);
        }
    }

    /**
     * Chooses a head and label for {@code dependent}. Candidate heads to the left are scored by the
     * guide's RA operations and candidates to the right by its LA operations. The highest-scoring
     * operation above -100 wins; otherwise the head is set to -1.
     *
     * <p>NOTE(review): {@code stack}/{@code buffer} are filled per candidate but never handed to the
     * guide. Every candidate is scored with the same {@code state}, so (assuming
     * {@code getOperation} is deterministic) the scores do not depend on {@code head}. Only head 0
     * or {@code dependent + 1} can win. Probably the stacks were meant to be placed into a state —
     * confirm against the original sources.
     */
    private static void attach(Instances instances, int i, ExtractorPet extractorPet, Guide guide, GuideOracle guideOracle,
            ParametersFloat parametersFloat, Parse parse, int i2, int dependent, POS[][] posArr, POS[][] posArr2,
            short[] sArr, short[] sArr2, State state) {
        IntStack stack = new IntStack(parse.heads.length);
        IntStack buffer = new IntStack(parse.heads.length);
        F2SF fv = parametersFloat.getFV();
        int bestHead = -1;
        int bestLabel = 0;
        float bestScore = -100.0f;
        for (int head = 0; head < parse.heads.length; head++) {
            stack.clear();
            buffer.clear();
            if (head < dependent) {
                stack.push(head);
                stack.push(dependent);
                for (O o : guide.getOperation(fv, extractorPet, state, posArr, posArr2)) {
                    if (o.o == RA && bestScore < o.p) {
                        bestScore = o.p;
                        bestHead = head;
                        bestLabel = o.l;
                    }
                }
            } else if (head > dependent) {
                stack.push(dependent);
                stack.push(head);
                for (O o : guide.getOperation(fv, extractorPet, state, posArr, posArr2)) {
                    if (o.o == LA && bestScore < o.p) {
                        bestScore = o.p;
                        bestHead = head;
                        bestLabel = o.l;
                    }
                }
            }
        }
        parse.heads[dependent] = (short) bestHead;
        parse.labels[dependent] = (short) bestLabel;
    }

    /** Returns whether {@code num} is non-null and already has a head (>= 0) in {@code parse}. */
    static boolean hasHead(Parse parse, Integer num) {
        return num != null && parse.heads[num.intValue()] >= 0;
    }

    public static String getInfo() {
        return "Beam size: " + beam;
    }

    /**
     * Checks whether a transition is applicable.
     *
     * <ul>
     *   <li>SHIFT: the buffer top index is {@code >= 0}.
     *   <li>LA: at least two stack items, and the lower one is not index 0.
     *   <li>RA: at least two stack items.
     *   <li>SWAP: at least two stack items, and the lower one has a smaller index than the top.
     *   <li>Any other code, including REDUCE: never.
     * </ul>
     *
     * @param operation transition code
     * @param stack the parser stack
     * @param bufferTop index at the front of the buffer, or a negative value if empty
     * @param parse unused
     */
    public static final boolean possible(int operation, IntStack stack, int bufferTop, Parse parse) {
        switch (operation) {
            case SHIFT:
                return bufferTop >= 0;
            case LA:
                return stack.size() > 1 && stack.get(stack.size() - 2) != 0;
            case RA:
                return stack.size() > 1;
            case SWAP:
                return stack.size() > 1 && stack.get(stack.size() - 2) < stack.peek();
            default:
                return false;
        }
    }
}
