/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.snova;

import java.security.SecureRandom;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.CryptoServicesRegistrar;
import org.bouncycastle.crypto.digests.SHAKEDigest;
import org.bouncycastle.crypto.params.ParametersWithRandom;
import org.bouncycastle.pqc.crypto.MessageSigner;
import org.bouncycastle.pqc.crypto.snova.GF16Utils;
import org.bouncycastle.pqc.crypto.snova.MapGroup1;
import org.bouncycastle.pqc.crypto.snova.SnovaEngine;
import org.bouncycastle.pqc.crypto.snova.SnovaKeyElements;
import org.bouncycastle.pqc.crypto.snova.SnovaParameters;
import org.bouncycastle.pqc.crypto.snova.SnovaPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.snova.SnovaPublicKeyParameters;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.GF16;

public class SnovaSigner
implements MessageSigner {
    /** Length in bytes of the public-key seed: leading bytes of the public key and of a seed-form private key. */
    private static final int PUBLIC_KEY_SEED_LENGTH = 16;
    /** Length in bytes of the private-key seed stored at the end of an expanded private key. */
    private static final int PRIVATE_KEY_SEED_LENGTH = 32;

    private SnovaParameters params;
    private SnovaEngine engine;
    private SecureRandom random;
    // Shared by message hashing, signed-hash derivation and vinegar derivation; instances are therefore
    // not safe for concurrent use.
    private final SHAKEDigest shake = new SHAKEDigest(256);
    private SnovaPublicKeyParameters pubKey;
    private SnovaPrivateKeyParameters privKey;

    /**
     * Initialises the signer.
     *
     * @param forSigning {@code true} to sign with a {@link SnovaPrivateKeyParameters} (optionally wrapped in
     *                   {@link ParametersWithRandom}); {@code false} to verify with a {@link SnovaPublicKeyParameters}.
     * @param param      the key parameters. When signing without an explicit random source,
     *                   {@code CryptoServicesRegistrar.getSecureRandom()} is used for salt generation.
     */
    public void init(boolean forSigning, CipherParameters param) {
        if (forSigning) {
            this.pubKey = null;
            if (param instanceof ParametersWithRandom) {
                ParametersWithRandom withRandom = (ParametersWithRandom)param;
                this.privKey = (SnovaPrivateKeyParameters)withRandom.getParameters();
                this.random = withRandom.getRandom();
            } else {
                this.privKey = (SnovaPrivateKeyParameters)param;
                this.random = CryptoServicesRegistrar.getSecureRandom();
            }
            this.params = this.privKey.getParameters();
        } else {
            this.pubKey = (SnovaPublicKeyParameters)param;
            this.params = this.pubKey.getParameters();
            this.privKey = null;
            this.random = null;
        }
        this.engine = new SnovaEngine(this.params);
    }

    /**
     * Signs a message.
     * <p>
     * A fresh salt of {@code params.getSaltLength()} bytes is drawn from the random source chosen in
     * {@link #init}. The secret maps are either regenerated from the stored seeds (when
     * {@code params.isSkIsSeed()}) or decoded from the expanded private key.
     *
     * @param message the message to sign.
     * @return the encoded signature vector, followed by the salt, followed by a copy of {@code message}.
     * @throws IllegalStateException if the signer has not been initialised for signing.
     */
    public byte[] generateSignature(byte[] message) {
        if (this.privKey == null) {
            throw new IllegalStateException("SnovaSigner not initialised for signing");
        }
        byte[] publicKeySeed;
        byte[] ptPrivateKeySeed;
        byte[] hash = this.getMessageHash(message);
        byte[] salt = new byte[this.params.getSaltLength()];
        this.random.nextBytes(salt);
        byte[] signature = new byte[this.signatureLength()];
        SnovaKeyElements keyElements = new SnovaKeyElements(this.params);
        if (this.params.isSkIsSeed()) {
            // Seed-form key: public seed followed by private seed; expand the secret maps from them.
            byte[] seedPair = this.privKey.getPrivateKey();
            publicKeySeed = Arrays.copyOfRange(seedPair, 0, PUBLIC_KEY_SEED_LENGTH);
            ptPrivateKeySeed = Arrays.copyOfRange(seedPair, PUBLIC_KEY_SEED_LENGTH, seedPair.length);
            this.engine.genMap1T12Map2(keyElements, publicKeySeed, ptPrivateKeySeed);
        } else {
            // Expanded key: encoded key elements, then the public seed, then the private seed.
            byte[] privateKey = this.privKey.getPrivateKey();
            int seedsOff = privateKey.length - PUBLIC_KEY_SEED_LENGTH - PRIVATE_KEY_SEED_LENGTH;
            // Decoding yields two entries per encoded byte, hence the doubled buffer.
            byte[] tmp = new byte[seedsOff << 1];
            GF16Utils.decodeMergeInHalf(privateKey, tmp, tmp.length);
            int inOff = 0;
            inOff = SnovaKeyElements.copy3d(tmp, inOff, keyElements.map1.aAlpha);
            inOff = SnovaKeyElements.copy3d(tmp, inOff, keyElements.map1.bAlpha);
            inOff = SnovaKeyElements.copy3d(tmp, inOff, keyElements.map1.qAlpha1);
            inOff = SnovaKeyElements.copy3d(tmp, inOff, keyElements.map1.qAlpha2);
            inOff = SnovaKeyElements.copy3d(tmp, inOff, keyElements.T12);
            inOff = SnovaKeyElements.copy4d(tmp, inOff, keyElements.map2.f11);
            inOff = SnovaKeyElements.copy4d(tmp, inOff, keyElements.map2.f12);
            SnovaKeyElements.copy4d(tmp, inOff, keyElements.map2.f21);
            publicKeySeed = Arrays.copyOfRange(privateKey, seedsOff, seedsOff + PUBLIC_KEY_SEED_LENGTH);
            ptPrivateKeySeed = Arrays.copyOfRange(privateKey, seedsOff + PUBLIC_KEY_SEED_LENGTH, privateKey.length);
        }
        this.signDigestCore(signature, hash, salt, keyElements.map1.aAlpha, keyElements.map1.bAlpha, keyElements.map1.qAlpha1, keyElements.map1.qAlpha2, keyElements.T12, keyElements.map2.f11, keyElements.map2.f12, keyElements.map2.f21, publicKeySeed, ptPrivateKeySeed);
        return Arrays.concatenate(signature, message);
    }

    /**
     * Verifies a signature over a message.
     * <p>
     * Only the leading encoded signature vector and salt of {@code signature} are read; any trailing bytes
     * (such as the message appended by {@link #generateSignature}) are ignored.
     *
     * @param message   the message that was signed.
     * @param signature the signature to check.
     * @return {@code true} if the signature is valid; {@code false} otherwise, including when
     *         {@code signature} is too short to hold a signature vector and salt.
     * @throws IllegalStateException if the signer has not been initialised for verification.
     */
    public boolean verifySignature(byte[] message, byte[] signature) {
        if (this.pubKey == null) {
            throw new IllegalStateException("SnovaSigner not initialised for verification");
        }
        // Reject truncated input before the comparatively expensive public-map expansion.
        if (signature.length < this.signatureLength()) {
            return false;
        }
        byte[] hash = this.getMessageHash(message);
        MapGroup1 map1 = new MapGroup1(this.params);
        byte[] pk = this.pubKey.getEncoded();
        byte[] publicKeySeed = Arrays.copyOf(pk, PUBLIC_KEY_SEED_LENGTH);
        byte[] p22_source = Arrays.copyOfRange(pk, PUBLIC_KEY_SEED_LENGTH, pk.length);
        this.engine.genABQP(map1, publicKeySeed);
        byte[][][][] p22 = new byte[this.params.getM()][this.params.getO()][this.params.getO()][this.params.getLsq()];
        if ((this.params.getLsq() & 1) == 0) {
            MapGroup1.decodeP(p22_source, 0, p22, p22_source.length << 1);
        } else {
            // Odd lsq: decode to a flat element array first, then distribute into p22.
            byte[] p22_gf16s = new byte[p22_source.length << 1];
            GF16.decode(p22_source, p22_gf16s, p22_gf16s.length);
            MapGroup1.fillP(p22_gf16s, 0, p22, p22_gf16s.length);
        }
        return this.verifySignatureCore(hash, signature, publicKeySeed, map1, p22);
    }

    /**
     * Absorbs the public seed, the message digest and the salt into the SHAKE instance and squeezes
     * {@code bytesHash} bytes into {@code signedHashOut}.
     */
    void createSignedHash(byte[] ptPublicKeySeed, int seedLengthPublic, byte[] digest, int bytesDigest, byte[] arraySalt, int saltOff, int bytesSalt, byte[] signedHashOut, int bytesHash) {
        this.shake.update(ptPublicKeySeed, 0, seedLengthPublic);
        this.shake.update(digest, 0, bytesDigest);
        this.shake.update(arraySalt, saltOff, bytesSalt);
        this.shake.doFinal(signedHashOut, 0, bytesHash);
    }

    /**
     * Core signing routine.
     * <p>
     * The routine repeatedly derives vinegar values from the private seed, digest, salt and a one-byte attempt
     * counter. Each attempt builds the linear system for the oil variables and stops once Gaussian elimination
     * succeeds. It then writes the encoded signature vector followed by {@code arraySalt} into {@code ptSignature}.
     *
     * @param ptSignature output buffer; its last {@code arraySalt.length} bytes receive the salt.
     */
    void signDigestCore(byte[] ptSignature, byte[] digest, byte[] arraySalt, byte[][][] Aalpha, byte[][][] Balpha, byte[][][] Qalpha1, byte[][][] Qalpha2, byte[][][] T12, byte[][][][] F11, byte[][][][] F12, byte[][][][] F21, byte[] ptPublicKeySeed, byte[] ptPrivateKeySeed) {
        int m = this.params.getM();
        int l = this.params.getL();
        int lsq = this.params.getLsq();
        int alpha = this.params.getAlpha();
        int v = this.params.getV();
        int o = this.params.getO();
        int n = this.params.getN();
        int mxlsq = m * lsq;
        int oxlsq = o * lsq;
        int vxlsq = v * lsq;
        int bytesHash = (oxlsq + 1) >>> 1;
        // Use the salt's real length (previously hard-coded as 16) so signing agrees with verifySignatureCore.
        int bytesSalt = arraySalt.length;
        byte[][] Gauss = new byte[mxlsq][mxlsq + 1];
        byte[][] Temp = new byte[lsq][lsq];
        byte[] solution = new byte[mxlsq];
        byte[][][] Left = new byte[alpha][v][lsq];
        byte[][][] Right = new byte[alpha][v][lsq];
        byte[] leftXTmp = new byte[lsq];
        byte[] rightXtmp = new byte[lsq];
        byte[] FvvGF16Matrix = new byte[lsq];
        byte[] hashInGF16 = new byte[mxlsq];
        byte[] vinegarGf16 = new byte[n * lsq];
        byte[] signedHash = new byte[bytesHash];
        byte[] vinegarBytes = new byte[(vxlsq + 1) >>> 1];
        byte[] gf16mTemp0 = new byte[l];
        byte numSign = 0;

        this.createSignedHash(ptPublicKeySeed, ptPublicKeySeed.length, digest, digest.length, arraySalt, 0, bytesSalt, signedHash, bytesHash);
        GF16.decode(signedHash, 0, hashInGF16, 0, hashInGF16.length);

        do {
            // Reset the augmented system; column mxlsq holds the right-hand side, starting from the target hash.
            for (int i = 0; i < mxlsq; ++i) {
                Arrays.fill(Gauss[i], (byte)0);
                Gauss[i][mxlsq] = hashInGF16[i];
            }
            numSign = (byte)(numSign + 1);

            // Vinegar values for this attempt depend on the attempt counter, so a retry yields a new system.
            this.shake.update(ptPrivateKeySeed, 0, ptPrivateKeySeed.length);
            this.shake.update(digest, 0, digest.length);
            this.shake.update(arraySalt, 0, arraySalt.length);
            this.shake.update(numSign);
            this.shake.doFinal(vinegarBytes, 0, vinegarBytes.length);
            GF16.decode(vinegarBytes, vinegarGf16, vinegarBytes.length << 1);

            for (int i = 0, ixlsq = 0; i < m; ++i, ixlsq += lsq) {
                // Accumulate the vinegar x vinegar (F11) contribution for equation block i.
                Arrays.fill(FvvGF16Matrix, (byte)0);
                for (int a = 0, miPrime = i; a < alpha; ++a, ++miPrime) {
                    if (miPrime >= o) {
                        miPrime -= o;
                    }
                    for (int j = 0, jxlsq = 0; j < v; ++j, jxlsq += lsq) {
                        GF16Utils.gf16mTranMulMul(vinegarGf16, jxlsq, Aalpha[i][a], Balpha[i][a], Qalpha1[i][a], Qalpha2[i][a], gf16mTemp0, Left[a][j], Right[a][j], l);
                    }
                    for (int j = 0; j < v; ++j) {
                        for (int k = 0; k < v; ++k) {
                            GF16Utils.gf16mMulMulTo(Left[a][j], F11[miPrime][j][k], Right[a][k], gf16mTemp0, FvvGF16Matrix, l);
                        }
                    }
                }

                // Fold the F11 term into the right-hand-side column (XOR is addition/subtraction here).
                int off = 0;
                for (int j = 0; j < l; ++j) {
                    for (int k = 0; k < l; ++k) {
                        Gauss[ixlsq + off][mxlsq] ^= FvvGF16Matrix[off];
                        ++off;
                    }
                }

                // Build the coefficients of oil variable block `index` (from F12 and F21) for equation block i.
                for (int index = 0, idxlsq = 0; index < o; ++index, idxlsq += lsq) {
                    for (int a = 0, miPrime = i; a < alpha; ++a, ++miPrime) {
                        if (miPrime >= o) {
                            miPrime -= o;
                        }
                        for (int ti = 0; ti < lsq; ++ti) {
                            Arrays.fill(Temp[ti], (byte)0);
                        }
                        for (int j = 0; j < v; ++j) {
                            GF16Utils.gf16mMulMul(Left[a][j], F12[miPrime][j][index], Qalpha2[i][a], gf16mTemp0, leftXTmp, l);
                            GF16Utils.gf16mMulMul(Qalpha1[i][a], F21[miPrime][index][j], Right[a][j], gf16mTemp0, rightXtmp, l);
                            for (int ti = 0, colB_colRight = 0, rlraxl = 0; ti < lsq; ++ti, ++colB_colRight) {
                                if (colB_colRight == l) {
                                    colB_colRight = 0;
                                    rlraxl += l;
                                }
                                byte valLeft = leftXTmp[rlraxl];
                                byte valRight = rightXtmp[colB_colRight];
                                int rowB_colA = 0;
                                int colLeft_rowRight = 0;
                                int clrrxl = 0;
                                int rbcaxl = 0;
                                for (int tj = 0; tj < lsq; ++tj) {
                                    if (rowB_colA == l) {
                                        rowB_colA = 0;
                                        rbcaxl = 0;
                                        ++colLeft_rowRight;
                                        clrrxl += l;
                                        valLeft = leftXTmp[rlraxl + colLeft_rowRight];
                                        valRight = rightXtmp[clrrxl + colB_colRight];
                                    }
                                    byte valB = Balpha[i][a][rbcaxl + colB_colRight];
                                    byte valA = Aalpha[i][a][rlraxl + rowB_colA];
                                    Temp[ti][tj] ^= GF16.mul(valLeft, valB) ^ GF16.mul(valA, valRight);
                                    ++rowB_colA;
                                    rbcaxl += l;
                                }
                            }
                        }
                        for (int ti = 0; ti < lsq; ++ti) {
                            for (int tj = 0; tj < lsq; ++tj) {
                                Gauss[ixlsq + ti][idxlsq + tj] ^= Temp[ti][tj];
                            }
                        }
                    }
                }
            }
        } while (this.performGaussianElimination(Gauss, solution, mxlsq) != 0);

        // Combine: each vinegar block receives T12 * oil-solution contributions (gf16mMulTo appears to
        // accumulate into its output), then the oil solution itself follows the vinegar part.
        for (int idx = 0, idxlsq = 0; idx < v; ++idx, idxlsq += lsq) {
            for (int i = 0, ixlsq = 0; i < o; ++i, ixlsq += lsq) {
                GF16Utils.gf16mMulTo(T12[idx][i], solution, ixlsq, vinegarGf16, idxlsq, l);
            }
        }
        System.arraycopy(solution, 0, vinegarGf16, vxlsq, oxlsq);
        GF16.encode(vinegarGf16, ptSignature, vinegarGf16.length);
        System.arraycopy(arraySalt, 0, ptSignature, ptSignature.length - bytesSalt, bytesSalt);
    }

    /**
     * Core verification routine.
     * <p>
     * It recomputes the signed hash from the public seed, the digest and the salt stored after the encoded
     * signature vector. It then evaluates the public map on the decoded signature and compares the two encodings.
     *
     * @return {@code false} if {@code signature} is too short or the hashes differ.
     */
    boolean verifySignatureCore(byte[] digest, byte[] signature, byte[] publicKeySeed, MapGroup1 map1, byte[][][][] p22) {
        int lsq = this.params.getLsq();
        int o = this.params.getO();
        int oxlsq = o * lsq;
        int bytesHash = (oxlsq + 1) >>> 1;
        int bytesSalt = this.params.getSaltLength();
        int m = this.params.getM();
        int n = this.params.getN();
        int nxlsq = n * lsq;
        int bytesSignature = (nxlsq + 1) >>> 1;
        if (signature.length < bytesSignature + bytesSalt) {
            return false;
        }
        byte[] signedHash = new byte[bytesHash];
        this.createSignedHash(publicKeySeed, publicKeySeed.length, digest, digest.length, signature, bytesSignature, bytesSalt, signedHash, bytesHash);
        if ((oxlsq & 1) != 0) {
            // Odd element count: the final byte carries only one element, so clear its unused high nibble.
            signedHash[bytesHash - 1] &= 0xF;
        }
        byte[] decodedSig = new byte[nxlsq];
        GF16.decode(signature, 0, decodedSig, 0, decodedSig.length);
        byte[] computedHashBytes = new byte[m * lsq];
        this.evaluation(computedHashBytes, map1, p22, decodedSig);
        byte[] encodedHash = new byte[bytesHash];
        GF16.encode(computedHashBytes, encodedHash, computedHashBytes.length);
        return Arrays.areEqual(signedHash, encodedHash);
    }

    /**
     * Evaluates the public map on the decoded signature vector, accumulating one lsq-sized block per
     * equation index {@code mi} into {@code hashMatrix}.
     */
    private void evaluation(byte[] hashMatrix, MapGroup1 map1, byte[][][][] p22, byte[] signature) {
        int m = this.params.getM();
        int alpha = this.params.getAlpha();
        int n = this.params.getN();
        int l = this.params.getL();
        int lsq = this.params.getLsq();
        int o = this.params.getO();
        byte[][][] Left = new byte[alpha][n][lsq];
        byte[][][] Right = new byte[alpha][n][lsq];
        byte[] temp = new byte[lsq];
        for (int mi = 0, mixlsq = 0; mi < m; ++mi, mixlsq += lsq) {
            for (int si = 0, sixlsq = 0; si < n; ++si, sixlsq += lsq) {
                for (int a = 0; a < alpha; ++a) {
                    GF16Utils.gf16mTranMulMul(signature, sixlsq, map1.aAlpha[mi][a], map1.bAlpha[mi][a], map1.qAlpha1[mi][a], map1.qAlpha2[mi][a], temp, Left[a][si], Right[a][si], l);
                }
            }
            for (int a = 0, miPrime = mi; a < alpha; ++a, ++miPrime) {
                if (miPrime >= o) {
                    miPrime -= o;
                }
                for (int ni = 0; ni < n; ++ni) {
                    // temp = sum over nj of P[miPrime][ni][nj] * Right[a][nj]
                    byte[] p = this.getPMatrix(map1, p22, miPrime, ni, 0);
                    GF16Utils.gf16mMul(p, Right[a][0], temp, l);
                    for (int nj = 1; nj < n; ++nj) {
                        p = this.getPMatrix(map1, p22, miPrime, ni, nj);
                        GF16Utils.gf16mMulTo(p, Right[a][nj], temp, l);
                    }
                    GF16Utils.gf16mMulTo(Left[a][ni], temp, hashMatrix, mixlsq, l);
                }
            }
        }
    }

    /**
     * Returns the public-map block P[mi][ni][nj], selecting p11, p12, p21 or p22 according to whether
     * {@code ni} and {@code nj} fall in the first {@code v} (vinegar) indices.
     */
    private byte[] getPMatrix(MapGroup1 map1, byte[][][][] p22, int mi, int ni, int nj) {
        int v = this.params.getV();
        if (ni < v) {
            return nj < v ? map1.p11[mi][ni][nj] : map1.p12[mi][ni][nj - v];
        }
        return nj < v ? map1.p21[mi][ni - v][nj] : p22[mi][ni - v][nj - v];
    }

    /**
     * Solves the augmented system {@code Gauss} (size rows, size + 1 columns) by Gaussian elimination with
     * row swapping, then back substitution.
     *
     * @return 0 on success, with the result in {@code solution}; 1 if no non-zero pivot exists for some
     *         column (singular system), in which case {@code solution} is left untouched.
     */
    private int performGaussianElimination(byte[][] Gauss, byte[] solution, int size) {
        int cols = size + 1;
        for (int i = 0; i < size; ++i) {
            int pivot = i;
            while (pivot < size && Gauss[pivot][i] == 0) {
                ++pivot;
            }
            if (pivot >= size) {
                return 1;
            }
            if (pivot != i) {
                byte[] tempRow = Gauss[i];
                Gauss[i] = Gauss[pivot];
                Gauss[pivot] = tempRow;
            }
            // Normalise the pivot row so the pivot becomes 1.
            byte invPivot = GF16.inv(Gauss[i][i]);
            for (int j = i; j < cols; ++j) {
                Gauss[i][j] = GF16.mul(Gauss[i][j], invPivot);
            }
            // Eliminate column i from the rows below.
            for (int j = i + 1; j < size; ++j) {
                byte factor = Gauss[j][i];
                if (factor == 0) {
                    continue;
                }
                for (int k = i; k < cols; ++k) {
                    Gauss[j][k] ^= GF16.mul(Gauss[i][k], factor);
                }
            }
        }
        for (int i = size - 1; i >= 0; --i) {
            byte acc = Gauss[i][size];
            for (int j = i + 1; j < size; ++j) {
                acc ^= GF16.mul(Gauss[i][j], solution[j]);
            }
            solution[i] = acc;
        }
        return 0;
    }

    /** Hashes {@code message} with the shared SHAKE instance, producing {@code shake.getDigestSize()} bytes. */
    private byte[] getMessageHash(byte[] message) {
        byte[] hash = new byte[this.shake.getDigestSize()];
        this.shake.update(message, 0, message.length);
        this.shake.doFinal(hash, 0);
        return hash;
    }

    /** Returns the byte length of an encoded signature vector plus salt, excluding any appended message. */
    private int signatureLength() {
        return ((this.params.getN() * this.params.getLsq() + 1) >>> 1) + this.params.getSaltLength();
    }
}

