/*  $Id: SSHKeys.cs $ 
 *   Last updated:
 *   $Date: 2024-01-12 09:19:00 $
 *   $Version: 1.2.0 $
 */

/* Generate a pair of ECC keys using CryptoSys PKI and output the public key and private key as files in OpenSSH format.
 * Supports the `ecdsa-sha2-nistp256` (NIST P-256) and `ssh-ed25519` (Ed25519) key types; private keys are written unencrypted.
 */

/******************************* LICENSE ***********************************
 * Copyright (C) 2023-24 David Ireland, DI Management Services Pty Limited.
 * All rights reserved. <www.di-mgt.com.au> <www.cryptosys.net>
 * The code in this module is licensed under the terms of the MIT license.
 * @license MIT
 * For a copy, see <http://opensource.org/licenses/MIT>
****************************************************************************
*/

/* References:
 * RFC4716 "The Secure Shell (SSH) Public Key File Format" https://datatracker.ietf.org/doc/html/rfc4716
 * AJ ONeal "The OpenSSH Private Key Format" https://coolaj86.com/articles/the-openssh-private-key-format/ 
 * AJ ONeal "The OpenSSH Public Key Format" https://coolaj86.com/articles/the-ssh-public-key-format/ 
 * openssh-portable/sshkey.c https://github.com/openssh/openssh-portable/blob/master/sshkey.c
 */

using System;
using System.Linq;
using System.Text;
using System.Diagnostics;
using System.IO;

/* Requires DI Management CryptoSys PKI Library available from <https://www.cryptosys.net/pki/>
 * to be installed on your system. Then EITHER add a reference to
 * `diCrSysPKINet.dll` (installed by default in `C:\Program Files (x86)\CryptoSysPKI\DotNet`)
 * OR add the C# source code file `CryptoSysPKI.cs` directly to your project.
 */

using CryptoSysPKI;

namespace DIManagement.SSHKeys
{
    class SSHKeys
    {
        /// <summary>
        /// Generate a pair of ECC keys in SSH format.
        /// </summary>
        /// <remarks>
        /// The public key is written as a one-line OpenSSH public key
        /// (type-name, base64 key blob, optional comment) and the private key as an unencrypted
        /// "openssh-key-v1" block between OPENSSH PRIVATE KEY boundaries.
        /// CryptoSys PKI creates the key pair in two temporary files in the current directory;
        /// these are deleted before this method returns. Both new files are echoed to the console.
        /// </remarks>
        /// <param name="curveName">Ecc.CurveName.Ed25519 or Ecc.CurveName.P_256 only</param>
        /// <param name="newprikeyfile">Name of new private key file to be created.</param>
        /// <param name="newpubkeyfile">Name of new public key file to be created.</param>
        /// <param name="userHostName">Optional user@hostname parameter (comment)</param>
        /// <param name="useKnownTest">Set true to use known test case values; else generate at random.</param>
        /// <exception cref="ArgumentException"><paramref name="curveName"/> is not supported.</exception>
        /// <exception cref="InvalidOperationException">A CryptoSys PKI key operation failed.</exception>
        public static void GenSSHKeys(Ecc.CurveName curveName, string newprikeyfile, string newpubkeyfile, string userHostName = "", bool useKnownTest = false)
        {
            // Temp key files and a throw-away password, used only to obtain the raw key values
            const string prikeyfile = "temp_prikey.p8e";
            const string pubkeyfile = "temp_pubkey.p1";
            const string pwd = "password";
            string pubkeyhex, prikeyhex;

            // Reject an unsupported curve before any file is created
            string keyformatstr = ssh_key_type(curveName);

            Console.WriteLine("");
            Console.WriteLine("Creating new ECC key pair using curve {0}{1}...", curveName.ToString(), (useKnownTest ? " [KnownTest]" : " [random]"));

            try {
                // Make key pair in temp files with default encryption params and throw-away password
                generate_keys(pubkeyfile, prikeyfile, curveName, pwd, useKnownTest);

                // Extract the raw public and private key values in hex form
                string pubkeystr = Ecc.ReadPublicKey(pubkeyfile).ToString();
                if (pubkeystr.Length == 0) {
                    throw new InvalidOperationException("Ecc.ReadPublicKey failed");
                }
                string prikeystr = Ecc.ReadPrivateKey(prikeyfile, pwd).ToString();
                if (prikeystr.Length == 0) {
                    throw new InvalidOperationException("Ecc.ReadPrivateKey failed");
                }
                pubkeyhex = Ecc.QueryKey(pubkeystr, "publicKey");
                prikeyhex = Ecc.QueryKey(prikeystr, "privateKey");
                if (String.IsNullOrEmpty(pubkeyhex) || String.IsNullOrEmpty(prikeyhex)) {
                    throw new InvalidOperationException("Ecc.QueryKey failed");
                }
            } finally {
                // The temp private key is protected only by a hard-coded password, so never leave it behind
                delete_temp_file(prikeyfile);
                delete_temp_file(pubkeyfile);
            }
            Trace.WriteLine("pk=" + pubkeyhex);
            Trace.WriteLine("skhex=" + prikeyhex);

            // One-line public key: <type-name> <base64-encoded-ssh-public-key>[ <comment>]
            byte[] pk = Cnv.FromHex(pubkeyhex);
            byte[] bpk = build_public_blob(curveName, keyformatstr, pk);
            Trace.WriteLine(String.Format("bpk={0}, {1} bytes", Cnv.ToHex(bpk), bpk.Length));
            string pubtxt = string.Format("{0} {1}", keyformatstr, Cnv.ToBase64(bpk));
            if (!String.IsNullOrEmpty(userHostName)) {
                pubtxt += " " + userHostName;
            }
            Console.WriteLine(pubtxt);
            write_and_show(newpubkeyfile, pubtxt);

            // OpenSSH-style private key (unencrypted)
            byte[] bsk = build_private_element(curveName, Cnv.FromHex(prikeyhex), pk);
            Trace.WriteLine("bsk=" + Cnv.ToHex(bsk));
            byte[] check = generate_check(useKnownTest, curveName);
            Trace.WriteLine("check=" + Cnv.ToHex(check));
            string pritxt = build_private_key_text(bpk, bsk, check, userHostName);
            write_and_show(newprikeyfile, pritxt);
        }

        /// <summary>
        /// Return the SSH key type name for a supported curve.
        /// </summary>
        /// <param name="curveName">Ecc.CurveName.Ed25519 or Ecc.CurveName.P_256</param>
        /// <returns>"ssh-ed25519" or "ecdsa-sha2-nistp256"</returns>
        /// <exception cref="ArgumentException">Any other curve.</exception>
        static string ssh_key_type(Ecc.CurveName curveName)
        {
            switch (curveName) {
                case Ecc.CurveName.Ed25519:
                    return "ssh-ed25519";
                case Ecc.CurveName.P_256:
                    return "ecdsa-sha2-nistp256";
                default:
                    Console.WriteLine("Invalid CurveName");
                    throw new ArgumentException("Invalid CurveName", "curveName");
            }
        }

        /// <summary>
        /// Build the SSH wire-format public key blob.
        /// </summary>
        /// <param name="curveName">Ecc.CurveName.Ed25519 or Ecc.CurveName.P_256</param>
        /// <param name="keyformatstr">SSH key type name for the curve</param>
        /// <param name="pk">Public key bytes as decoded from Ecc.QueryKey(..., "publicKey")</param>
        /// <returns>
        /// Ed25519: string(type) string(pk), expected 51 bytes.
        /// P-256: string(type) string("nistp256") string(pk), expected 104 bytes.
        /// </returns>
        static byte[] build_public_blob(Ecc.CurveName curveName, string keyformatstr, byte[] pk)
        {
            byte[] blob = ssh_rfc4251_string(keyformatstr);
            switch (curveName) {
                case Ecc.CurveName.Ed25519:
                    break;
                case Ecc.CurveName.P_256:
                    // ECDSA blobs also carry the curve identifier
                    blob = blob.Concat(ssh_rfc4251_string("nistp256")).ToArray();
                    break;
                default:
                    Console.WriteLine("Invalid CurveName");
                    throw new ArgumentException("Invalid CurveName", "curveName");
            }
            return blob.Concat(ssh_rfc4251_bytes(pk)).ToArray();
        }

        /// <summary>
        /// Build the length-prefixed private key field that follows the public fields in the private section.
        /// </summary>
        /// <param name="curveName">Ecc.CurveName.Ed25519 or Ecc.CurveName.P_256</param>
        /// <param name="sk">Private key bytes as decoded from Ecc.QueryKey(..., "privateKey")</param>
        /// <param name="pk">Public key bytes (used for Ed25519 only)</param>
        /// <returns>Ed25519: string(sk || pk), 64 data bytes. P-256: mpint(sk).</returns>
        static byte[] build_private_element(Ecc.CurveName curveName, byte[] sk, byte[] pk)
        {
            switch (curveName) {
                case Ecc.CurveName.Ed25519:
                    // OpenSSH stores the Ed25519 private key as private (32) + public (32) bytes
                    return ssh_rfc4251_bytes(sk.Concat(pk).ToArray());
                case Ecc.CurveName.P_256:
                    // The ECDSA private scalar is an RFC4251 mpint (not a BIT STRING)
                    return ssh_rfc4251_mpint(sk);
                default:
                    Console.WriteLine("Invalid CurveName");
                    throw new ArgumentException("Invalid CurveName", "curveName");
            }
        }

        /// <summary>
        /// Assemble the unencrypted "openssh-key-v1" private key and wrap it in PEM-style boundaries.
        /// </summary>
        /// <remarks>
        /// STRUCTURE (after the magic "openssh-key-v1" followed by 0x00)
        ///   string("none") string("none") string("") uint32(1)  -- cipher, kdf, empty kdf options, number of keys
        ///   string(bpk)                                        -- public key blob
        ///   string(sk_inner)
        /// sk_inner :=
        ///   check check                                        -- the same 4-byte value twice
        ///   bpk                                                -- public key fields, not length-prefixed as a whole
        ///   private key element
        ///   string(comment)                                    -- always present, may be empty
        ///   padding                                            -- 0-7 bytes 01 02 03 ... so sk_inner is a multiple of 8
        /// </remarks>
        /// <param name="bpk">Public key blob</param>
        /// <param name="bskElement">Length-prefixed private key element</param>
        /// <param name="check">Four-byte check value</param>
        /// <param name="comment">Comment (user@hostname); null or empty writes an empty string</param>
        /// <returns>PEM-style text ending with a newline</returns>
        static string build_private_key_text(byte[] bpk, byte[] bskElement, byte[] check, string comment)
        {
            const string magichex = "6f70656e7373682d6b65792d763100";  // "openssh-key-v1"0x00
            // string("none") string("none") string("") uint32(1)
            const string headerhex = "000000046e6f6e65" + "000000046e6f6e65" + "00000000" + "00000001";
            const string begin_label = "-----BEGIN OPENSSH PRIVATE KEY-----\n";
            const string end_label = "-----END OPENSSH PRIVATE KEY-----\n";

            // OpenSSH always writes and reads the comment string, even when it is empty
            byte[] sk_inner = check
                .Concat(check)
                .Concat(bpk)
                .Concat(bskElement)
                .Concat(ssh_rfc4251_string(comment ?? ""))
                .ToArray();
            Trace.WriteLine("sk_inner=" + Cnv.ToHex(sk_inner));

            // Deterministic padding 01 02 03 ... up to the 8-byte block size used with cipher "none"
            int npad = (8 - sk_inner.Length % 8) % 8;
            byte[] padding = new byte[npad];
            for (int i = 0; i < npad; i++) {
                padding[i] = (byte)(i + 1);
            }
            sk_inner = sk_inner.Concat(padding).ToArray();
            Trace.WriteLine(String.Format("npad={0} len(sk_inner)={1} blocks={2:F3}", npad, sk_inner.Length, sk_inner.Length / 8.0));

            byte[] skblock = Cnv.FromHex(magichex)
                .Concat(Cnv.FromHex(headerhex))
                .Concat(ssh_rfc4251_bytes(bpk))
                .Concat(ssh_rfc4251_bytes(sk_inner))
                .ToArray();

            // Base64, word-wrapped at 70 characters (every line newline-terminated), inside the boundaries
            return begin_label + Wrap(Cnv.ToBase64(skblock), 70) + end_label;
        }

        /// <summary>
        /// Write text to a new file, then read it back and echo it to the console.
        /// </summary>
        /// <param name="filename">File to create (overwritten if it exists)</param>
        /// <param name="contents">Text to write</param>
        static void write_and_show(string filename, string contents)
        {
            File.WriteAllText(filename, contents);
            Console.WriteLine("Created new file '{0}'", filename);
            Console.WriteLine("FILE: {0}", filename);
            Console.WriteLine(File.ReadAllText(filename));
        }

        /// <summary>
        /// Delete a temporary file, reporting instead of throwing on failure so that cleanup
        /// in a finally block cannot hide an earlier exception.
        /// </summary>
        /// <param name="path">File to delete; a missing file is not an error</param>
        static void delete_temp_file(string path)
        {
            try {
                File.Delete(path);
            } catch (IOException ex) {
                Console.Error.WriteLine("Warning: could not delete temporary file '{0}': {1}", path, ex.Message);
            } catch (UnauthorizedAccessException ex) {
                Console.Error.WriteLine("Warning: could not delete temporary file '{0}': {1}", path, ex.Message);
            }
        }

        /// <summary>
        /// Wrap a text string at exactly wraplen characters with a newline.
        /// </summary>
        /// <param name="s">String to be wrapped</param>
        /// <param name="wraplen">Line length to wrap at (must be positive)</param>
        /// <returns>Wrapped text where every line, including the last, ends with a newline; empty input gives an empty string</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="wraplen"/> is zero or negative.</exception>
        static string Wrap(string s, int wraplen)
        {
            if (wraplen <= 0) {
                // A non-positive step would never advance the loop below
                throw new ArgumentOutOfRangeException("wraplen", "Line length must be positive");
            }
            StringBuilder sb = new StringBuilder(s.Length + s.Length / wraplen + 1);
            for (int i = 0; i < s.Length; i += wraplen) {
                sb.Append(s, i, Math.Min(wraplen, s.Length - i));
                sb.Append('\n');
            }
            return sb.ToString();
        }

        /// <summary>
        /// Create a new RFC4251 "string" element from text.
        /// </summary>
        /// <param name="s">String value, encoded as UTF-8</param>
        /// <returns>uint32 byte count (big-endian) followed by the UTF-8 bytes</returns>
        static byte[] ssh_rfc4251_string(string s)
        {
            // The length prefix counts encoded bytes, not characters (they differ for non-ASCII text)
            byte[] bs = Encoding.UTF8.GetBytes(s);
            return ssh_rfc4251_bytes(bs);
        }

        /// <summary>
        /// Create a new RFC4251 "string" element from arbitrary binary data.
        /// </summary>
        /// <param name="bin">Data bytes</param>
        /// <returns>uint32 length as returned by Cnv.NumToBytes followed by the data</returns>
        static byte[] ssh_rfc4251_bytes(byte[] bin)
        {
            byte[] bout = Cnv.NumToBytes((uint)bin.Length);
            return bout.Concat(bin).ToArray();
        }

        /// <summary>
        /// Create a new RFC4251 "mpint" element from a non-negative big-endian integer.
        /// </summary>
        /// <remarks>
        /// Unnecessary leading zero bytes are removed, and a single 0x00 is prepended when the
        /// most significant bit is set so the value is not read as negative. Zero encodes as an empty string.
        /// </remarks>
        /// <param name="unsignedBigEndian">Magnitude bytes, most significant first</param>
        /// <returns>Length-prefixed mpint</returns>
        static byte[] ssh_rfc4251_mpint(byte[] unsignedBigEndian)
        {
            int start = 0;
            while (start < unsignedBigEndian.Length && unsignedBigEndian[start] == 0) {
                start++;
            }
            byte[] magnitude = unsignedBigEndian.Skip(start).ToArray();
            if (magnitude.Length > 0 && (magnitude[0] & 0x80) != 0) {
                magnitude = new byte[] { 0x00 }.Concat(magnitude).ToArray();
            }
            return ssh_rfc4251_bytes(magnitude);
        }

        /// <summary>
        /// Generate ECC private and public key files.
        /// </summary>
        /// <param name="pubkeyfile">Filename of public key to be created</param>
        /// <param name="prikeyfile">Filename of private key to be created</param>
        /// <param name="curveName">ECC curve name (<c>Ecc.CurveName.Ed25519</c> or <c>Ecc.CurveName.P_256</c> only)</param>
        /// <param name="pwd">Password for private key file (throw-away)</param>
        /// <param name="useKnownTest">If true, use fixed test values for ECC keys; else generate new random keys</param>
        /// <exception cref="ArgumentException">Unsupported curve when <paramref name="useKnownTest"/> is true.</exception>
        /// <exception cref="InvalidOperationException">A CryptoSys PKI call reported failure.</exception>
        static void generate_keys(string pubkeyfile, string prikeyfile, Ecc.CurveName curveName, string pwd, bool useKnownTest = false)
        {
            int r;
            if (!useKnownTest) {
                // Generate new random keys and save as key files
                r = Ecc.MakeKeys(pubkeyfile, prikeyfile, curveName, pwd);
                if (r != 0) {
                    throw new InvalidOperationException(String.Format("Ecc.MakeKeys failed (error code {0})", r));
                }
                return;
            }

            // Known private key for one of two specific test cases
            string skhex;
            switch (curveName) {
                case Ecc.CurveName.Ed25519:
                    skhex = "ae63a9e08d44ccbfc5d04ad45e936b968021f824b5717d41a6ccdb021317e551";
                    break;
                case Ecc.CurveName.P_256:
                    skhex = "a99ff78ae96a630ff9a367c426036ce37cf4dbcb3d4c8162b0b5c6c8c0644aa0";
                    break;
                default:
                    Console.WriteLine("Invalid CurveName");
                    throw new ArgumentException("Invalid CurveName", "curveName");
            }
            string prikeystr = Ecc.ReadKeyByCurve(skhex, curveName);
            if (String.IsNullOrEmpty(prikeystr)) {
                throw new InvalidOperationException("Ecc.ReadKeyByCurve failed");
            }
            string pubkeystr = Ecc.PublicKeyFromPrivate(prikeystr);
            if (String.IsNullOrEmpty(pubkeystr)) {
                throw new InvalidOperationException("Ecc.PublicKeyFromPrivate failed");
            }

            // Save as key files in the same form Ecc.MakeKeys would produce
            r = Ecc.SaveEncKey(prikeyfile, prikeystr, pwd, Ecc.PbeScheme.Default, "", Ecc.Format.Default);
            if (r != 0) {
                throw new InvalidOperationException(String.Format("Ecc.SaveEncKey failed (error code {0})", r));
            }
            r = Ecc.SaveKey(pubkeyfile, pubkeystr);
            if (r != 0) {
                throw new InvalidOperationException(String.Format("Ecc.SaveKey failed (error code {0})", r));
            }
        }

        /// <summary>
        /// Generate a random 32-bit check value
        /// </summary>
        /// <param name="useKnownTest">If true, then return known fixed test value</param>
        /// <param name="curveName">Only relevant if <c>useKnownTest</c> is true</param>
        /// <returns>Four-byte check value</returns>
        /// <exception cref="ArgumentException">Unsupported curve when <paramref name="useKnownTest"/> is true.</exception>
        static byte[] generate_check(bool useKnownTest = false, Ecc.CurveName curveName = Ecc.CurveName.Ed25519)
        {
            if (!useKnownTest) {
                // Generate a random value
                return Rng.Bytes(4);
            }
            // Return known fixed values for specific test case
            switch (curveName) {
                case Ecc.CurveName.Ed25519:
                    return Cnv.FromHex("0f 42 b3 2f");
                case Ecc.CurveName.P_256:
                    return Cnv.FromHex("87 b6 2c a2");
                default:
                    throw new ArgumentException("Invalid CurveName", "curveName");
            }
        }

    }

    class Program
    {
        /// <summary>
        /// Demo entry point: creates two known-answer key pairs and two fresh random key pairs
        /// (one Ed25519 and one ECDSA P-256 of each), writing them to the current directory.
        /// </summary>
        static void Main(string[] args)
        {
            // Setup to display Trace info in console
            Trace.Listeners.Add(new TextWriterTraceListener(Console.Out));
            // Then turn off (comment out next line to show debugging)
            Trace.Listeners.Clear();

            Trace.WriteLine("PKI Version=" + General.Version());
            // Generate known fixed test keys...
            SSHKeys.GenSSHKeys(Ecc.CurveName.Ed25519, "knownprikey_ed25519.pem", "knownpubkey_ed25519.pub", "user@example.com", useKnownTest: true);
            SSHKeys.GenSSHKeys(Ecc.CurveName.P_256, "knownprikey_ecdsa256.pem", "knownpubkey_ecdsa256.pub", "user@example.com", useKnownTest: true);

            // Generate fresh random keys...
            SSHKeys.GenSSHKeys(Ecc.CurveName.Ed25519, "newprikey_ed25519.pem", "newpubkey_ed25519.pub", "user@example.com");
            // Distinct file names so the P-256 pair does not overwrite the Ed25519 pair created above
            SSHKeys.GenSSHKeys(Ecc.CurveName.P_256, "newprikey_ecdsa256.pem", "newpubkey_ecdsa256.pub", "user@example.com");
        }
    }
}