/*  $Id: test_MLDSA_json.cs $ 
 *  $Date: 2025-06-01 07:17:00 $
 *  $Version: 1.0.1 $
 */

/* Use CryptoSys PQC to validate against the NIST ACVP test vectors for ML-DSA-sigGen-FIPS204 and ML-DSA-sigVer-FIPS204 RELEASE/v1.1.0.39
 * https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files/ML-DSA-sigGen-FIPS204/internalProjection.json
 * https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files/ML-DSA-sigVer-FIPS204/internalProjection.json
 * and for SLH-DSA-sigGen-FIPS205 and SLH-DSA-sigVer-FIPS205 RELEASE/v1.1.0.38
 * https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files/SLH-DSA-sigGen-FIPS205/internalProjection.json
 * https://github.com/usnistgov/ACVP-Server/blob/master/gen-val/json-files/SLH-DSA-sigVer-FIPS205/internalProjection.json
 * (Note: these are all large files, too big for GitHub to display in its web viewer.)
 * 
 * Requires `CryptoSys PQC` to be installed on your system: available from <https://cryptosys.net/pqc/>.
 * Add a reference to `diCrSysPQCNet.dll` installed in `C:\Program Files (x86)\CrSysPQC\DotNet`,
 * or include the C# source code module `CryptoSysPQC.cs` directly in your project.
 * 
 * This is a Console Application written for target .NET Framework 4.7.2 and above using Newtonsoft.Json.
 * Please report any bugs to <https://cryptosys.net/contact>.
 */
/******************************* LICENSE ***********************************
 * Copyright (C) 2025 David Ireland, DI Management Services Pty Limited.
 * t/a CryptoSys <www.di-mgt.com.au> <www.cryptosys.net> All rights reserved. 
 * The code in this module is licensed under the terms of the MIT license.  
 * For a copy, see <http://opensource.org/licenses/MIT>
****************************************************************************
*/

/*
 * Usage: test_DSA_json {[mldsa]|slhdsa|all} [verify]
 *   mldsa  - do only the ML-DSA tests (default)
 *   slhdsa - do only the SLH-DSA tests
 *   all    - do both the ML-DSA and the SLH-DSA tests
 *   verify - do the signature verification (DSA-sigVer) tests instead of the
 *            signature generation (DSA-sigGen) tests; combine with mldsa|slhdsa|all
 * Default is just to run the ML-DSA-sigGen tests.
 * Warning: the SLH-DSA-sigGen tests take several minutes. Be prepared to wait.
 * 
 * Requires the following files to exist in the current working directory.
 *   `ML-DSA-sigGen-FIPS204-1_1_0_39.json`, 
 *   `ML-DSA-sigGen-preHashes-1_1_0_39.json`,
 *   `SLH-DSA-sigGen-FIPS205-1_1_0_38.json`, 
 *   `SLH-DSA-sigGen-preHashes-1_1_0_38.json`
 *   `ML-DSA-sigVer-FIPS204-1_1_0_39.json`, 
 *   `ML-DSA-sigVer-preHashes-1_1_0_39.json`,
 *   `SLH-DSA-sigVer-FIPS205-1_1_0_38.json`, 
 *   `SLH-DSA-sigVer-preHashes-1_1_0_38.json`,
 * 
 * These can be downloaded from
 * <https://www.cryptosys.net/pqc/ACVP-pqc-json.zip>
 */

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.IO;
using System.Diagnostics;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using CryptoSysPQC;

namespace CryptoSys.Test_JSON
{
    /// <summary>
    /// Console driver that checks CryptoSys PQC ML-DSA and SLH-DSA signature generation (sigGen)
    /// and signature verification (sigVer) against NIST ACVP JSON test-vector files.
    /// The first mismatch stops the run with an <see cref="ApplicationException"/>.
    /// </summary>
    class test_DSA_json
    {
        // Mapping of json parameterSet string to CryptoSysPQC Dsa.Alg
        public static readonly Dictionary<string, Dsa.Alg> DsaAlgs =
            new Dictionary<string, Dsa.Alg>(){
                {"ML-DSA-44", Dsa.Alg.ML_DSA_44},
                {"ML-DSA-65", Dsa.Alg.ML_DSA_65},
                {"ML-DSA-87", Dsa.Alg.ML_DSA_87},
                {"SLH-DSA-SHA2-128s", Dsa.Alg.SLH_DSA_SHA2_128S},
                {"SLH-DSA-SHA2-128f", Dsa.Alg.SLH_DSA_SHA2_128F},
                {"SLH-DSA-SHA2-192s", Dsa.Alg.SLH_DSA_SHA2_192S},
                {"SLH-DSA-SHA2-192f", Dsa.Alg.SLH_DSA_SHA2_192F},
                {"SLH-DSA-SHA2-256s", Dsa.Alg.SLH_DSA_SHA2_256S},
                {"SLH-DSA-SHA2-256f", Dsa.Alg.SLH_DSA_SHA2_256F},
                {"SLH-DSA-SHAKE-128s", Dsa.Alg.SLH_DSA_SHAKE_128S},
                {"SLH-DSA-SHAKE-128f", Dsa.Alg.SLH_DSA_SHAKE_128F},
                {"SLH-DSA-SHAKE-192s", Dsa.Alg.SLH_DSA_SHAKE_192S},
                {"SLH-DSA-SHAKE-192f", Dsa.Alg.SLH_DSA_SHAKE_192F},
                {"SLH-DSA-SHAKE-256s", Dsa.Alg.SLH_DSA_SHAKE_256S},
                {"SLH-DSA-SHAKE-256f", Dsa.Alg.SLH_DSA_SHAKE_256F},
            };

        // Mapping of json hashAlg strings to CryptoSysPQC PreHashAlg codes
        public static readonly Dictionary<string, Dsa.PreHashAlg> PreHashAlgs =
            new Dictionary<string, Dsa.PreHashAlg>(){
                {"SHA2-224", Dsa.PreHashAlg.Sha224},
                {"SHA2-256", Dsa.PreHashAlg.Sha256},
                {"SHA2-384", Dsa.PreHashAlg.Sha384},
                {"SHA2-512", Dsa.PreHashAlg.Sha512},
                {"SHA2-512/224", Dsa.PreHashAlg.Sha512_224},
                {"SHA2-512/256", Dsa.PreHashAlg.Sha512_256},
                {"SHA3-224", Dsa.PreHashAlg.Sha3_224},
                {"SHA3-256", Dsa.PreHashAlg.Sha3_256},
                {"SHA3-384", Dsa.PreHashAlg.Sha3_384},
                {"SHA3-512", Dsa.PreHashAlg.Sha3_512},
                {"SHAKE-128", Dsa.PreHashAlg.Shake128_256},
                {"SHAKE-256", Dsa.PreHashAlg.Shake256_512},
            };

        /// <summary>
        /// Check test vectors in sigGen file.
        /// For every test case, sign the given input with the given private key and
        /// compare the result with the expected signature.
        /// </summary>
        /// <param name="fname">Input JSON file</param>
        /// <param name="phfile">Input file of preHash values</param>
        /// <exception cref="ApplicationException">A generated signature does not match the expected one.</exception>
        static void ProcessTestJsonSign(string fname, string phfile)
        {
            Stopwatch watch = Stopwatch.StartNew();
            Console.WriteLine("FILE: {0}", Path.GetFullPath(fname));
            Console.WriteLine("PH FILE: {0}", Path.GetFullPath(phfile));

            // Read in the set of test vectors (internalProjection.json)
            JObject jobj = LoadJson(fname);
            // Read in set of pre-hash values (`ph`), looked up by tcId for preHash tests
            JObject phobj = LoadJson(phfile);

            int ngroups = 0;
            int ntests = 0;
            int ntestsdone = 0;

            // Iterate through each test group...
            foreach (JObject tg in jobj["testGroups"]) {
                ngroups++;
                Console.WriteLine(tg.First);  // Expecting `"tgId": N`
                Console.WriteLine("parameterSet=" + tg["parameterSet"]);
                Dsa.Alg alg = LookupValue(DsaAlgs, tg["parameterSet"].ToString(), "parameterSet");
                Console.WriteLine("deterministic=" + tg["deterministic"]);
                bool deterministic = (bool)tg["deterministic"];
                bool isprehash, externalmu, isinternal;
                ReadGroupMode(tg, out isprehash, out externalmu, out isinternal);

                JArray tests = (JArray)tg["tests"];
                Console.WriteLine("ntests= {0}", tests.Count);
                ntests += tests.Count;  // Count tests expected to be done

                // Iterate through each test for this group...
                foreach (JObject test in tests) {
                    ntestsdone++;
                    Dsa.SigOpts opts = Dsa.SigOpts.Default;
                    Console.WriteLine("  " + test.First);  // Expecting `"tcId": N`
                    // Keep tcId as a string: it is the key into the pre-hash dictionary
                    string tcid = test["tcId"].ToString();

                    // Read in hex strings as byte arrays
                    byte[] sk = test["sk"].ToString().FromHex();
                    Console.WriteLine("  |sk|={0} bytes", sk.Length);
                    byte[] context;
                    byte[] msg = ReadMessageAndContext(test, externalmu, isinternal, ref opts, out context);
                    byte[] sigok = test["signature"].ToString().FromHex();

                    // Set deterministic vs hedged
                    string rnd;
                    if (deterministic) {
                        rnd = "";
                        opts |= Dsa.SigOpts.Deterministic;
                    } else {
                        // NOTE: We want the plain hex string for paramstr
                        // ML-DSA uses "rnd" but SLH-DSA uses "additionalRandomness"
                        rnd = test.ContainsKey("additionalRandomness")
                            ? test["additionalRandomness"].ToString()
                            : test["rnd"].ToString();
                    }

                    // Go create the signature
                    byte[] sig;
                    if (isprehash) {
                        Dsa.PreHashAlg preHashAlg = LookupValue(PreHashAlgs, test["hashAlg"].ToString(), "hashAlg");
                        byte[] ph = LookupPreHash(phobj, tcid);
                        Console.WriteLine("  hashAlg={0}", preHashAlg);
                        Console.WriteLine(HeadTail("  ph=", ph.ToHex(), 48, "  ..."));
                        // Generate a "pre-hash" signature
                        sig = Dsa.SignPreHash(alg, ph, preHashAlg, sk, opts, context: context, paramstr: rnd);
                    } else {
                        // Generate a "pure" signature
                        sig = Dsa.Sign(alg, msg, sk, opts, context: context, paramstr: rnd);
                    }
                    // We have a signature, show it
                    Console.WriteLine("  |sig|={0} bytes", sig.Length);
                    Console.WriteLine(HeadTail("sig=", sig.ToHex(), 96));
                    Console.WriteLine(HeadTail("OK =", sigok.ToHex(), 96));

                    // Check signature equals what is expected
                    if (!ByteArraysEqual(sig, sigok)) {
                        Console.WriteLine("**ERROR: signature does not match expected!");
                        throw new ApplicationException("Signature does not match");
                    }
                }
            }
            Console.WriteLine("\nALL DONE. {0} groups, {1} tests found, {2} tests done", ngroups, ntests, ntestsdone);
            PrintElapsed(watch);
        }

        /// <summary>
        /// Check test vectors in sigVer file.
        /// For every test case, verify the given signature with the given public key and
        /// check that the outcome agrees with the expected `testPassed` value.
        /// </summary>
        /// <param name="fname">Input JSON file</param>
        /// <param name="phfile">Input file of preHash values</param>
        /// <exception cref="ApplicationException">A verification result differs from the expected one.</exception>
        static void ProcessTestJsonVerify(string fname, string phfile)
        {
            Stopwatch watch = Stopwatch.StartNew();
            Console.WriteLine("FILE: {0}", Path.GetFullPath(fname));
            Console.WriteLine("PH FILE: {0}", Path.GetFullPath(phfile));

            // Read in the set of test vectors (internalProjection.json)
            JObject jobj = LoadJson(fname);
            // Read in set of pre-hash values (`ph`), looked up by tcId for preHash tests
            JObject phobj = LoadJson(phfile);

            int ngroups = 0;
            int ntests = 0;
            int ntestsdone = 0;

            // Iterate through each test group...
            foreach (JObject tg in jobj["testGroups"]) {
                ngroups++;
                Console.WriteLine(tg.First);  // Expecting `"tgId": N`
                Console.WriteLine("parameterSet=" + tg["parameterSet"]);
                Dsa.Alg alg = LookupValue(DsaAlgs, tg["parameterSet"].ToString(), "parameterSet");
                bool isprehash, externalmu, isinternal;
                ReadGroupMode(tg, out isprehash, out externalmu, out isinternal);

                JArray tests = (JArray)tg["tests"];
                Console.WriteLine("ntests= {0}", tests.Count);
                ntests += tests.Count;  // Count tests expected to be done

                // Iterate through each test for this group...
                foreach (JObject test in tests) {
                    ntestsdone++;
                    Dsa.SigOpts opts = Dsa.SigOpts.Default;
                    Console.WriteLine("  " + test.First);  // Expecting `"tcId": N`
                    // Keep tcId as a string: it is the key into the pre-hash dictionary
                    string tcid = test["tcId"].ToString();
                    bool testPassed = (bool)test["testPassed"];
                    string reason = test["reason"].ToString();

                    // Read in hex strings as byte arrays
                    byte[] pk = test["pk"].ToString().FromHex();
                    Console.WriteLine("  |pk|={0} bytes", pk.Length);
                    byte[] context;
                    byte[] msg = ReadMessageAndContext(test, externalmu, isinternal, ref opts, out context);
                    byte[] sig = test["signature"].ToString().FromHex();

                    // OK, go ahead and verify what we have
                    bool isok;
                    if (isprehash) {
                        Dsa.PreHashAlg prehashalg = LookupValue(PreHashAlgs, test["hashAlg"].ToString(), "hashAlg");
                        byte[] ph = LookupPreHash(phobj, tcid);
                        Console.WriteLine("  hashAlg={0}", prehashalg);
                        Console.WriteLine(HeadTail("  ph=", ph.ToHex(), 48, "  ..."));
                        isok = TryVerify(() => Dsa.VerifyPreHash(alg, sig, ph, prehashalg, pk, context: context));
                    } else {
                        // Verify a "pure" signature
                        isok = TryVerify(() => Dsa.Verify(alg, sig, msg, pk, context, opts));
                    }

                    // The test passes when the verification outcome equals the expected outcome
                    Console.WriteLine("  Dsa.Verify returns {0}, expected {1}", isok, testPassed);
                    if (isok == testPassed) {
                        Console.WriteLine("  reason: {0}", reason);
                    } else {
                        Console.WriteLine("**ERROR: verification did not work as expected!");
                        throw new ApplicationException("Verification did not work");
                    }
                }
            }
            Console.WriteLine("\nALL DONE. {0} groups, {1} tests found, {2} tests done", ngroups, ntests, ntestsdone);
            PrintElapsed(watch);
        }

        /// <summary>
        /// Entry point. Accepts arguments {[mldsa]|slhdsa|all} [verify] (case sensitive).
        /// The algorithm selectors mldsa/slhdsa/all are applied in order, so the last one wins;
        /// `verify` switches from the sigGen tests to the sigVer tests.
        /// Default is just to run the ML-DSA-sigGen tests.
        /// </summary>
        /// <param name="args">Command-line arguments</param>
        static void Main(string[] args)
        {
            Console.WriteLine("Using PQC version {0} [{1}]", CryptoSysPQC.General.Version(), CryptoSysPQC.General.DllInfo());
            bool do_mldsa = true;   // Default
            bool do_slhdsa = false;
            bool do_verify = false; // Default = test sigGen only

            // Parse arguments (case sensitive)
            foreach (string arg in args) {
                switch (arg) {
                    case "mldsa":   // Do only ML-DSA
                        do_mldsa = true;
                        do_slhdsa = false;
                        break;
                    case "slhdsa":  // Do only SLH-DSA
                        do_mldsa = false;
                        do_slhdsa = true;
                        break;
                    case "all":     // Do both ML-DSA and SLH-DSA
                        do_mldsa = true;
                        do_slhdsa = true;
                        break;
                    case "verify":  // Do sigVer instead of sigGen
                        do_verify = true;
                        break;
                    default:
                        Console.Error.WriteLine("WARNING: ignoring unrecognized argument '{0}'", arg);
                        break;
                }
            }

            if (do_verify) {
                // Check against DSA sigVer test vectors
                if (do_mldsa) {
                    ProcessTestJsonVerify("ML-DSA-sigVer-FIPS204-1_1_0_39.json", "ML-DSA-sigVer-preHashes-1_1_0_39.json");
                }
                if (do_slhdsa) {
                    ProcessTestJsonVerify("SLH-DSA-sigVer-FIPS205-1_1_0_38.json", "SLH-DSA-sigVer-preHashes-1_1_0_38.json");
                }
            } else {
                // 1. Check against ML-DSA sigGen test vectors
                if (do_mldsa) {
                    ProcessTestJsonSign("ML-DSA-sigGen-FIPS204-1_1_0_39.json", "ML-DSA-sigGen-preHashes-1_1_0_39.json");
                }
                // 2. Check against SLH-DSA sigGen test vectors
                // WARNING: this takes several minutes
                if (do_slhdsa) {
                    ProcessTestJsonSign("SLH-DSA-sigGen-FIPS205-1_1_0_38.json", "SLH-DSA-sigGen-preHashes-1_1_0_38.json");
                }
            }
        }

        // HELPERS

        /// <summary>
        /// Read a whole JSON file and parse it as a JSON object.
        /// </summary>
        /// <param name="fname">Path of the JSON file</param>
        /// <returns>Parsed top-level object</returns>
        static JObject LoadJson(string fname)
        {
            return JObject.Parse(File.ReadAllText(fname));
        }

        /// <summary>
        /// Look up a JSON string value in one of the mapping tables above.
        /// </summary>
        /// <param name="map">Mapping table</param>
        /// <param name="key">JSON string to look up</param>
        /// <param name="what">Name of the JSON field, used in the error message</param>
        /// <returns>The mapped value</returns>
        /// <exception cref="KeyNotFoundException">The key is not in the table; the message names the key.</exception>
        static TValue LookupValue<TValue>(Dictionary<string, TValue> map, string key, string what)
        {
            TValue value;
            if (!map.TryGetValue(key, out value)) {
                throw new KeyNotFoundException(String.Format("Unsupported {0} '{1}'", what, key));
            }
            return value;
        }

        /// <summary>
        /// Print the group's `preHash` field and decode the flags that apply to every test in the group.
        /// </summary>
        /// <param name="tg">Test group object</param>
        /// <param name="isprehash">True if `preHash` is "preHash" (otherwise a pure signature)</param>
        /// <param name="externalmu">Value of `externalMu`, or false if the group has no such field</param>
        /// <param name="isinternal">True if not externalMu and `signatureInterface` is "internal"</param>
        static void ReadGroupMode(JObject tg, out bool isprehash, out bool externalmu, out bool isinternal)
        {
            Console.WriteLine("preHash=" + tg["preHash"]);
            // Catch preHash vs pure
            isprehash = (tg["preHash"].ToString() == "preHash");
            // And externalMu and internal, if applicable
            externalmu = tg.ContainsKey("externalMu") && (bool)tg["externalMu"];
            isinternal = (!externalmu && (tg["signatureInterface"].ToString() == "internal"));
        }

        /// <summary>
        /// Read the message (or the external `mu` value) and the context for one test case,
        /// adding the signature option flag required by the group's interface.
        /// </summary>
        /// <param name="test">Test case object</param>
        /// <param name="externalmu">Group uses external mu: read `mu` and add SigOpts.ExternalMu</param>
        /// <param name="isinternal">Group uses the internal interface: add SigOpts.Internal</param>
        /// <param name="opts">Signature options, updated in place</param>
        /// <param name="context">Decoded `context`, or null for externalMu/internal groups</param>
        /// <returns>Decoded message bytes (the `mu` value for externalMu groups)</returns>
        static byte[] ReadMessageAndContext(JObject test, bool externalmu, bool isinternal, ref Dsa.SigOpts opts, out byte[] context)
        {
            if (externalmu) {
                opts |= Dsa.SigOpts.ExternalMu;
                context = null;
                return test["mu"].ToString().FromHex();
            }
            if (isinternal) {
                opts |= Dsa.SigOpts.Internal;
                context = null;
                return test["message"].ToString().FromHex();
            }
            byte[] msg = test["message"].ToString().FromHex();
            context = test["context"].ToString().FromHex();
            return msg;
        }

        /// <summary>
        /// Look up the digest value PH_M=H(M) for a test case in the separate pre-hash JSON file.
        /// </summary>
        /// <remarks>
        /// Expected JSON format:
        /// <code>
        /// "preHashes": {
        ///     "16": {
        ///         "tcId": 16,
        ///         "messagelength": 859,
        ///         "hashAlg": "SHA2-512/256",
        ///         "ph": "ee6dc6912b33dcf6eb46a262a0452caa96e688aef2f221576e20df597e4fa8c2"
        ///     },
        /// </code>
        /// </remarks>
        /// <param name="phobj">Parsed pre-hash file</param>
        /// <param name="tcid">Test case id, as a string</param>
        /// <returns>Decoded `ph` bytes</returns>
        /// <exception cref="KeyNotFoundException">No `ph` entry exists for this tcId.</exception>
        static byte[] LookupPreHash(JObject phobj, string tcid)
        {
            JToken preHashes = phobj["preHashes"];
            JToken entry = (preHashes == null) ? null : preHashes[tcid];
            JToken phtok = (entry == null) ? null : entry["ph"];
            if (phtok == null) {
                throw new KeyNotFoundException("No pre-hash value found for tcId=" + tcid);
            }
            return phtok.ToString().FromHex();
        }

        /// <summary>
        /// Run a verification call and treat any exception it throws as a failed verification.
        /// Some negative test cases may make the library throw rather than return false, so the
        /// exception message is printed and the result is counted as `false`.
        /// </summary>
        /// <param name="verify">Verification call to run</param>
        /// <returns>The call's result, or false if it threw</returns>
        static bool TryVerify(Func<bool> verify)
        {
            try {
                return verify();
            }
            catch (Exception e) {
                Console.WriteLine("  {0}", e.Message);
                return false;
            }
        }

        /// <summary>
        /// Stop the stopwatch and display the elapsed time in seconds and as hh:mm:ss.fff.
        /// </summary>
        /// <param name="watch">Running stopwatch</param>
        static void PrintElapsed(Stopwatch watch)
        {
            watch.Stop();
            TimeSpan ts = watch.Elapsed;
            // Milliseconds range 0-999, so they need three digits; two digits would show 5 ms as ".05".
            // TotalHours keeps any whole days in the hours field.
            string elapsedTime = String.Format("{0:00}:{1:00}:{2:00}.{3:000}",
                        (int)ts.TotalHours, ts.Minutes, ts.Seconds,
                        ts.Milliseconds);
            Console.WriteLine("Elapsed time: {0} seconds. {1}", (double)watch.ElapsedMilliseconds / 1000, elapsedTime);
        }

        // UTILITIES
        /// <summary>
        /// Show the first and last n characters of a string with an ellipsis in between.
        /// Strings no longer than 2n characters are shown in full.
        /// </summary>
        /// <param name="pre">Prefix</param>
        /// <param name="s">string</param>
        /// <param name="headlen">Length n of head and tail substrings</param>
        /// <param name="inter">String to separate head and tail (default=ellipsis); it is placed on a new line</param>
        /// <returns>Formatted string</returns>
        static string HeadTail(string pre, string s, int headlen, string inter = "...")
        {
            if (s.Length <= 2 * headlen) {
                return pre + s;
            }
            return pre + s.Substring(0, headlen) + "\n" + inter + s.Substring(s.Length - headlen);
        }

        /// <summary>
        /// Compare two byte arrays for equality
        /// </summary>
        /// <param name="data1">First byte array</param>
        /// <param name="data2">Second byte array</param>
        /// <returns>true if byte arrays are equal (two nulls count as equal).</returns>
        static bool ByteArraysEqual(byte[] data1, byte[] data2)
        {   // Thanks to Jon Skeet http://www.pobox.com/~skeet
            // If both are null, they're equal
            if (data1 == null && data2 == null) {
                return true;
            }
            // If either but not both are null, they're not equal
            if (data1 == null || data2 == null) {
                return false;
            }
            if (data1.Length != data2.Length) {
                return false;
            }
            for (int i = 0; i < data1.Length; i++) {
                if (data1[i] != data2[i]) {
                    return false;
                }
            }
            return true;
        }
    }
}