/* $Id:  XmlEncDecrypt.cs $ 
 *   Last updated:
 *   $Date: 2020-11-26 12:29:00 $
 *   $Version: 2.0.0 $
 */

#define TRACE

using System;
using System.Collections.Generic;
using System.Text;
using System.Diagnostics;
using Xmlsq;
using Pki = CryptoSysPKI;

/******************************* LICENSE ***********************************
 * Copyright (C) 2020 David Ireland, DI Management Services Pty Limited.
 * All rights reserved. <www.di-mgt.com.au> <www.cryptosys.net>
 * The code in this module is licensed under the terms of the MIT license.  
 * For a copy, see <http://opensource.org/licenses/MIT>
****************************************************************************
*/

/* This shows how to extract and decrypt the cipher value from an XMLENC document.
 * It uses a heuristic approach to extract the required values from the XMLENC example documents we use here.
 * It expects a single <EncryptedData> element with a single child <KeyInfo> element.
 * This code is intended to demonstrate principles you might use in your own programs.
 * NOTE: these examples do not use any prefixes: all elements are in the default namespace. 
 * Other XML documents may use the "xenc:" and "ds:" prefixes. You will need to modify this code to cope with that.
 * If prefixes are used they need to be "hardcoded" in the xmlsq search queries (or use the XPath local-name() function).
 * */

/* PSEUDOCODE:
 * Confirm EncryptedData element exists; else stop.
 * Get DataEncryption algorithm, expecting "*-cbc" or "*-gcm".
 * Get CipherValue.
 * Get the key - this may be a shared symmetric key referenced by a KeyName or an EncryptedKey.
 * Use key and the DataEncryption algorithm to decrypt the cipher value.
 * Return the decrypted text - we expect UTF-8 to be returned as a Unicode string.
 * */

/* [2020-11] Added support for rsa-oaep-mgf1p and rsa-oaep */

namespace TestXmlsqPki
{
    class TestXmlsqPki
    {
        /* Lookup tables mapping W3C XMLENC algorithm identifiers to CryptoSys PKI options.
         * Keys are the full algorithm URIs exactly as they appear in an Algorithm attribute;
         * lookups are exact (case-sensitive) string matches. */

        // W3C XMLENC Block Cipher CBC algorithms (all use Mode.CBC, Padding.W3CPadding and Cipher.Opts.PrefixIV)
        private static readonly Dictionary<string, Pki.CipherAlgorithm> W3C_CipherAlgs = new Dictionary<string, Pki.CipherAlgorithm>()
        {
            { "http://www.w3.org/2001/04/xmlenc#tripledes-cbc", Pki.CipherAlgorithm.Tdea },
            { "http://www.w3.org/2001/04/xmlenc#aes128-cbc", Pki.CipherAlgorithm.Aes128 },
            { "http://www.w3.org/2001/04/xmlenc#aes192-cbc", Pki.CipherAlgorithm.Aes192 },
            { "http://www.w3.org/2001/04/xmlenc#aes256-cbc", Pki.CipherAlgorithm.Aes256 },
        };

        // W3C XMLENC Block Cipher GCM algorithms (all use Cipher.Opts.PrefixIV)
        private static readonly Dictionary<string, Pki.AeadAlgorithm> W3C_GcmAlgs = new Dictionary<string, Pki.AeadAlgorithm>()
        {
            { "http://www.w3.org/2009/xmlenc11#aes128-gcm", Pki.AeadAlgorithm.Aes_128_Gcm },
            { "http://www.w3.org/2009/xmlenc11#aes192-gcm", Pki.AeadAlgorithm.Aes_192_Gcm },
            { "http://www.w3.org/2009/xmlenc11#aes256-gcm", Pki.AeadAlgorithm.Aes_256_Gcm },
        };

        // W3C XMLENC Symmetric Key Wrap algorithms
        // (maps the key-wrap URI to the block cipher used to unwrap the key)
        private static readonly Dictionary<string, Pki.CipherAlgorithm> W3C_KeyWrapAlgs = new Dictionary<string, Pki.CipherAlgorithm>()
        {
            { "http://www.w3.org/2001/04/xmlenc#kw-tripledes", Pki.CipherAlgorithm.Tdea },
            { "http://www.w3.org/2001/04/xmlenc#kw-aes128", Pki.CipherAlgorithm.Aes128 },
            { "http://www.w3.org/2001/04/xmlenc#kw-aes192", Pki.CipherAlgorithm.Aes192 },
            { "http://www.w3.org/2001/04/xmlenc#kw-aes256", Pki.CipherAlgorithm.Aes256 },
        };

        // W3C XMLENC RSA Key Transport algorithms
        // NOTE(review): this table does not appear to be referenced by the methods in this file,
        // which select the RSA scheme by testing the end of the algorithm URI instead.
        private static readonly Dictionary<string, Pki.Rsa.EME> W3C_RsaAlgs = new Dictionary<string, Pki.Rsa.EME>()
        {
            { "http://www.w3.org/2001/04/xmlenc#rsa-1_5", Pki.Rsa.EME.PKCSv1_5 },
            { "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p", Pki.Rsa.EME.OAEP },
            { "http://www.w3.org/2009/xmlenc11#rsa-oaep", Pki.Rsa.EME.OAEP },
        };

        // W3C XMLENC Message Digest algorithms for the RSAES-OAEP-ENCRYPT algorithm
        // (expected as a parameter for rsa-oaep or rsa-oaep-mgf1p)
        // NOTE(review): the XMLENC 1.1 algorithm list appears to identify SHA-384 as
        // "http://www.w3.org/2001/04/xmldsig-more#sha384" - confirm whether that URI should be accepted here too.
        private static readonly Dictionary<string, Pki.Rsa.HashAlg> W3C_HashAlgs = new Dictionary<string, Pki.Rsa.HashAlg>()
        {
            { "http://www.w3.org/2000/09/xmldsig#sha1", Pki.Rsa.HashAlg.Sha1 },
            { "http://www.w3.org/2001/04/xmlenc#sha256", Pki.Rsa.HashAlg.Sha256 },
            { "http://www.w3.org/2001/04/xmlenc#sha384", Pki.Rsa.HashAlg.Sha384 },
            { "http://www.w3.org/2001/04/xmlenc#sha512", Pki.Rsa.HashAlg.Sha512 },
        };

        // W3C XMLENC mask generation function URI values for the RSAES-OAEP-ENCRYPT algorithm
        // (expected as a parameter for rsa-oaep)
        // Each MGF1 URI maps to the hash algorithm used inside MGF1.
        private static readonly Dictionary<string, Pki.Rsa.HashAlg> W3C_MGF1Algs = new Dictionary<string, Pki.Rsa.HashAlg>()
        {
            { "http://www.w3.org/2009/xmlenc11#mgf1sha1", Pki.Rsa.HashAlg.Sha1 },
            { "http://www.w3.org/2009/xmlenc11#mgf1sha224", Pki.Rsa.HashAlg.Sha224 },
            { "http://www.w3.org/2009/xmlenc11#mgf1sha256", Pki.Rsa.HashAlg.Sha256 },
            { "http://www.w3.org/2009/xmlenc11#mgf1sha384", Pki.Rsa.HashAlg.Sha384 },
            { "http://www.w3.org/2009/xmlenc11#mgf1sha512", Pki.Rsa.HashAlg.Sha512 },
        };

        // Shared secret keys identified by key name (keys in ASCII string form)
        // The string lengths here are 24, 16, 24 and 32 characters respectively,
        // i.e. one byte of key material per character once converted to bytes.
        // CAUTION: not a secure way to store secret keys!!
        private static readonly Dictionary<string, string> MySecretKeys = new Dictionary<string, string>()
        {
            { "bob", "abcdefghijklmnopqrstuvwx"},
            { "job", "abcdefghijklmnop"},
            { "jeb", "abcdefghijklmnopqrstuvwx"},
            { "jed", "abcdefghijklmnopqrstuvwxyz012345"},
        };

        // Index of RSA key files (all unencrypted)
        // The "default" entry is the fallback used when no key name is given or the name is not listed.
        private static readonly Dictionary<string, string> MyRsaKeyFiles = new Dictionary<string, string>()
        {
            { "default", "rsa-w3c.p8"},
            { "bob@smime.example", "bob-smime.p8"},
        };

        /// <summary>
        /// Look up a shared secret key by its key name.
        /// </summary>
        /// <param name="keyname">Key name to look up in the MySecretKeys table (may be null or empty).</param>
        /// <returns>
        /// The key as a byte array, or null if no key is registered under that name
        /// (in which case an error message is written to the console).
        /// </returns>
        static byte[] getKeyUsingKeyName(string keyname)
        {
            string keystr = null;

            // CAUTION: you may need a more secure way to do this :-)
            // A null name can never match (and would make the dictionary lookup throw),
            // so report it the same way as any other unknown name.
            if (keyname == null || !MySecretKeys.TryGetValue(keyname, out keystr)) {
                Console.WriteLine("**ERROR: Cannot find matching key for " + keyname);
                return null;
            }

            // keystr is an ASCII string; we need a byte array.
            // Use an explicit encoding: Encoding.Default depends on the platform/runtime,
            // whereas the key bytes must be the same everywhere.
            return System.Text.Encoding.ASCII.GetBytes(keystr);
        }

        /// <summary>
        /// Get the name of the RSA private key file to use for a given key name.
        /// </summary>
        /// <param name="keyname">Key name from the document; may be null or empty if none was given.</param>
        /// <returns>
        /// The file name registered for <paramref name="keyname"/> in MyRsaKeyFiles,
        /// or the "default" entry if the name is null, empty or not listed.
        /// </returns>
        static string getRsaKeyFileName(string keyname)
        {
            string keyfilename;

            // Single lookup; a null or empty name goes straight to the default entry.
            if (!string.IsNullOrEmpty(keyname) && MyRsaKeyFiles.TryGetValue(keyname, out keyfilename)) {
                return keyfilename;
            }
            return MyRsaKeyFiles["default"];
        }

        /// <summary>
        /// Obtain the content-encryption key for the EncryptedData element of an XMLENC document.
        /// </summary>
        /// <remarks>
        /// In our simple example here, we either have
        /// (1) a KeyInfo/KeyName element which we can use to lookup a symmetric block cipher key, or
        /// (2) an EncryptedKey using one of
        ///     (a) a key wrap algorithm,
        ///     (b) rsa-1_5,
        ///     (c) rsa-oaep-mgf1p,
        ///     (d) rsa-oaep.
        /// CAUTION: the recovered key is written to the Trace output for demonstration purposes;
        /// do not do this in production code.
        /// </remarks>
        /// <param name="xmlfile">File path to XML file or string containing XML.</param>
        /// <returns>
        /// The key as a byte array, or null if no key could be found or recovered
        /// (an error message is written to the console where the cause is known).
        /// </returns>
        static byte[] getKey(string xmlfile)
        {
            byte[] key;
            string query;
            string s;
            int n;
            string enckey_alg;
            string enckey_keyname;
            string enckey_ciphervalue;
            string rsa_keyfile;
            byte[] kek;
            Pki.CipherAlgorithm cipherAlg;
            Pki.Rsa.HashAlg digestAlg;
            Pki.Rsa.AdvOptions mgfOpt;

            // (1) Do we have a simple key name?
            query = "//EncryptedData/KeyInfo/KeyName";
            s = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine(query + " => '" + s + "'");
            if (s.Length > 0) {
                return getKeyUsingKeyName(s);
            }

            // (2) Do we have an EncryptedKey?
            query = "//EncryptedData/KeyInfo/EncryptedKey";
            n = Xmlsq.Query.Count(xmlfile, query);
            Trace.WriteLine("COUNT: '" + query + "' = " + n);
            if (n < 1) {
                return null;
            }

            query = "//EncryptedData/KeyInfo/EncryptedKey/EncryptionMethod/@Algorithm";
            enckey_alg = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine(query + " => '" + enckey_alg + "'");

            query = "//EncryptedData/KeyInfo/EncryptedKey/CipherData/CipherValue";
            enckey_ciphervalue = Xmlsq.Query.GetText(xmlfile, query, Query.Opts.Trim);
            Trace.WriteLine("ciphervalue='" + enckey_ciphervalue + "'");

            // Do we have a KeyName
            query = "//EncryptedData/KeyInfo/EncryptedKey/KeyInfo/KeyName";
            enckey_keyname = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine(query + " => '" + enckey_keyname + "'");

            // Algorithm URIs are identifiers, not linguistic text, so compare them ordinally.
            if (enckey_alg.EndsWith("rsa-1_5", StringComparison.Ordinal)) {
                Trace.WriteLine("Using rsa-1_5");
                // RSA Version 1.5
                rsa_keyfile = getRsaKeyFileName(enckey_keyname);    // no password
                key = Pki.Rsa.Decrypt(Pki.Cnv.FromBase64(enckey_ciphervalue), rsa_keyfile, "");
                Trace.WriteLine("key=0x" + Pki.Cnv.ToHex(key));
            }
            else if (enckey_alg.EndsWith("#rsa-oaep-mgf1p", StringComparison.Ordinal)) {
                Trace.WriteLine("Using rsa-oaep-mgf1p");
                // RSA-OAEP using default MGF1withSHA1
                rsa_keyfile = getRsaKeyFileName(enckey_keyname);    // no password
                digestAlg = getOaepDigestAlg(xmlfile);
                mgfOpt = Pki.Rsa.AdvOptions.Mgf1_Sha1;  // Always MGF1 with SHA-1 for this algorithm
                key = Pki.Rsa.Decrypt(Pki.Cnv.FromBase64(enckey_ciphervalue), rsa_keyfile, "", Pki.Rsa.EME.OAEP, digestAlg, mgfOpt);
                Trace.WriteLine("key=0x" + Pki.Cnv.ToHex(key));
            }
            else if (enckey_alg.EndsWith("#rsa-oaep", StringComparison.Ordinal)) {
                Trace.WriteLine("Using rsa-oaep");
                // RSA-OAEP
                rsa_keyfile = getRsaKeyFileName(enckey_keyname);    // no password
                digestAlg = getOaepDigestAlg(xmlfile);
                if (!tryGetOaepMgfOption(xmlfile, digestAlg, out mgfOpt)) {
                    return null;    // error already reported
                }
                key = Pki.Rsa.Decrypt(Pki.Cnv.FromBase64(enckey_ciphervalue), rsa_keyfile, "", Pki.Rsa.EME.OAEP, digestAlg, mgfOpt);
                Trace.WriteLine("key=0x" + Pki.Cnv.ToHex(key));
            }
            else if (enckey_alg.Contains("#kw-")) {
                Trace.WriteLine("Using Key Wrap");
                // Lookup block cipher algorithm
                if (!W3C_KeyWrapAlgs.TryGetValue(enckey_alg, out cipherAlg)) {
                    Console.WriteLine("**ERROR: Cannot find matching cipher algorithm for " + enckey_alg);
                    return null;
                }
                // Uses Key Wrap, expecting KeyName
                kek = getKeyUsingKeyName(enckey_keyname);
                if (kek == null) {
                    // No key-encryption key available (already reported): do not attempt to unwrap with null.
                    return null;
                }
                key = Pki.Cipher.KeyUnwrap(Pki.Cnv.FromBase64(enckey_ciphervalue), kek, cipherAlg);
            }
            else {
                Console.WriteLine("**ERROR: Unsupported EncryptedKey algorithm '" + enckey_alg + "'");
                return null;
            }

            // The CryptoSys PKI routines are understood to signal failure by returning an empty array
            // (see the library documentation); treat that the same as "no key" so the caller's
            // null check catches it instead of trying to decrypt with a zero-length key.
            if (key == null || key.Length == 0) {
                Console.WriteLine("**ERROR: Failed to recover key from EncryptedKey");
                return null;
            }

            return key;
        }

        /// <summary>
        /// Read the OAEP message digest algorithm from EncryptedKey/EncryptionMethod/DigestMethod.
        /// </summary>
        /// <remarks>
        /// The message digest function SHOULD be specified using the Algorithm attribute of the ds:DigestMethod child
        /// element of the xenc:EncryptionMethod element. If it is not specified, the default value of SHA1 is to be used.
        /// </remarks>
        /// <param name="xmlfile">File path to XML file or string containing XML.</param>
        /// <returns>
        /// The matching hash algorithm; SHA-1 if DigestMethod is absent or its URI is not in W3C_HashAlgs
        /// (a warning is written to the console in the latter case).
        /// </returns>
        private static Pki.Rsa.HashAlg getOaepDigestAlg(string xmlfile)
        {
            const string query = "//EncryptedData/KeyInfo/EncryptedKey/EncryptionMethod/DigestMethod/@Algorithm";
            Pki.Rsa.HashAlg digestAlg;

            string s = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine(query + " => '" + s + "'");
            if (s.Length == 0) {
                return Pki.Rsa.HashAlg.Sha1;
            }
            if (W3C_HashAlgs.TryGetValue(s, out digestAlg)) {
                return digestAlg;
            }
            // Keep the existing SHA-1 fallback, but make the mismatch visible:
            // decryption is likely to fail if the document really used another digest.
            Console.WriteLine("**WARNING: Unrecognized DigestMethod '" + s + "'; using SHA-1");
            return Pki.Rsa.HashAlg.Sha1;
        }

        /// <summary>
        /// Work out the MGF option for rsa-oaep from EncryptedKey/EncryptionMethod/MGF.
        /// </summary>
        /// <remarks>
        /// We only support MGF1Alg = same as DigestMethod (PKI default) or MGF1Alg = SHA-1 (strict default).
        /// NOTE(review): when no MGF element is present this uses the PKI default (same hash as DigestMethod);
        /// the XMLENC 1.1 specification appears to define mgf1sha1 as the default in that case - confirm
        /// against the documents you need to process.
        /// </remarks>
        /// <param name="xmlfile">File path to XML file or string containing XML.</param>
        /// <param name="digestAlg">Message digest algorithm already selected for OAEP.</param>
        /// <param name="mgfOpt">Receives the advanced option to pass to Rsa.Decrypt.</param>
        /// <returns>
        /// True if a supported option was selected; false if the document asks for an MGF1 hash
        /// we cannot handle (an error is written to the console).
        /// </returns>
        private static bool tryGetOaepMgfOption(string xmlfile, Pki.Rsa.HashAlg digestAlg, out Pki.Rsa.AdvOptions mgfOpt)
        {
            const string query = "//EncryptedData/KeyInfo/EncryptedKey/EncryptionMethod/MGF/@Algorithm";
            Pki.Rsa.HashAlg mgfDigAlg;

            mgfOpt = Pki.Rsa.AdvOptions.Default;

            string s = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine(query + " => '" + s + "'");
            if (s.Length == 0) {
                return true;    // No MGF element: use PKI default
            }
            if (!W3C_MGF1Algs.TryGetValue(s, out mgfDigAlg)) {
                // Keep the existing fallback to the PKI default, but make it visible.
                Console.WriteLine("**WARNING: Unrecognized MGF algorithm '" + s + "'; using default");
                return true;
            }
            if (mgfDigAlg == digestAlg) {
                return true;    // Same hash as DigestMethod is the PKI default
            }
            if (mgfDigAlg == Pki.Rsa.HashAlg.Sha1) {
                mgfOpt = Pki.Rsa.AdvOptions.Mgf1_Sha1;
                return true;
            }
            Console.WriteLine("**ERROR: Unsupported digest algorithm for MGF1 " + s);
            return false;
        }

        /// <summary>
        /// Extract and decrypt cipher text from XMLENC document using block cipher.
        /// </summary>
        /// <remarks>
        /// Expects a data encryption algorithm URI ending in "-cbc" or "-gcm" that is listed in
        /// W3C_CipherAlgs or W3C_GcmAlgs respectively. The key is obtained via getKey.
        /// Every failure is reported on the console and results in an empty string.
        /// </remarks>
        /// <param name="xmlfile">File path to XML file or string containing XML.</param>
        /// <returns>Decrypted text or empty string if error occurred.</returns>
        static string getXmlEncDecryptedText(string xmlfile)
        {
            string s;
            int n;
            string query;
            string encalgstr;
            string ciphervalue;
            Pki.CipherAlgorithm cipherAlg;
            Pki.AeadAlgorithm aeadAlg;
            byte[] key, ct;
            byte[] pt;

            // Confirm EncryptedData existence
            Trace.WriteLine("Confirm EncryptedData exists using Xmlsq.Query.Count (expecting > 0)");
            query = "//EncryptedData";
            n = Xmlsq.Query.Count(xmlfile, query);
            Trace.WriteLine("COUNT: '" + query + "' = " + n);
            if (n < 1) {
                Console.WriteLine("**ERROR: EncryptedData element not found");
                return "";
            }

            // Get data encryption algorithm: expecting "xmlenc#*-cbc" or "xmlenc#*-gcm"
            Trace.WriteLine("Get data encryption algorithm...");
            query = "//EncryptedData/EncryptionMethod/@Algorithm";
            Trace.WriteLine("Query: " + query);
            encalgstr = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine("encalg=" + encalgstr);
            if (encalgstr.Length == 0) {
                Console.WriteLine("**ERROR: Cannot find EncryptionMethod");
                return "";
            }

            // Show EncryptedData attribute details, if present
            query = "//EncryptedData/@Type";
            s = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine(query + " => '" + s + "'");
            query = "//EncryptedData/@MimeType";
            s = Xmlsq.Query.GetText(xmlfile, query);
            Trace.WriteLine(query + " => '" + s + "'");

            // Get cipher value
            Trace.WriteLine("Get cipher value...");
            query = "//EncryptedData/CipherData/CipherValue";
            Trace.WriteLine("Query: " + query);
            ciphervalue = Xmlsq.Query.GetText(xmlfile, query, Query.Opts.Trim); // NB Trim whitespace here
            Trace.WriteLine("ciphervalue='" + ciphervalue + "'");
            if (ciphervalue.Length == 0) {
                Console.WriteLine("**ERROR: Cannot find CipherValue");
                return "";
            }

            // Get the key for symmetric encryption.
            // A zero-length key is never usable, so reject it here along with null.
            key = getKey(xmlfile);
            if (null == key || key.Length == 0) {
                Console.WriteLine("**ERROR: Failed to get a valid key");
                return "";
            }

            // Decrypt cipher value using key and data encryption algorithm...
            // Expecting algorithm string to end with either "-cbc" or "-gcm".
            // Algorithm URIs are identifiers, not linguistic text, so compare them ordinally.
            if (encalgstr.EndsWith("-cbc", StringComparison.Ordinal)) {
                // Lookup block cipher algorithm
                if (!W3C_CipherAlgs.TryGetValue(encalgstr, out cipherAlg)) {
                    Console.WriteLine("**ERROR: Cannot find matching cipher algorithm for " + encalgstr);
                    return "";
                }

                // Do block cipher decryption.
                // Mode, padding and options are always the same for "-cbc" algorithms:
                // CBC mode, W3C padding, with the IV prefixed to the ciphertext.
                ct = Pki.Cnv.FromBase64(ciphervalue);
                pt = Pki.Cipher.Decrypt(ct, key, null, cipherAlg, Pki.Mode.CBC, Pki.Padding.W3CPadding, Pki.Cipher.Opts.PrefixIV);
            }
            else if (encalgstr.EndsWith("-gcm", StringComparison.Ordinal)) {
                // Lookup block cipher GCM algorithm
                if (!W3C_GcmAlgs.TryGetValue(encalgstr, out aeadAlg)) {
                    Console.WriteLine("**ERROR: Cannot find matching cipher algorithm for " + encalgstr);
                    return "";
                }

                // Do AEAD decryption (IV prefixed to the ciphertext, no additional authenticated data)
                ct = Pki.Cnv.FromBase64(ciphervalue);
                pt = Pki.Cipher.DecryptAEAD(ct, key, null, null, aeadAlg, Pki.Cipher.Opts.PrefixIV);
            }
            else {
                Console.WriteLine("**ERROR: expecting -cbc or -gcm in Algorithm");
                return "";
            }

            // A null result would make GetString throw, and an empty result yields an empty string
            // anyway, so report both explicitly rather than failing silently.
            if (pt == null || pt.Length == 0) {
                Console.WriteLine("**ERROR: Decryption failed or produced no data");
                return "";
            }

            // Convert bytes to string: we assume UTF-8 encoded text
            return System.Text.Encoding.UTF8.GetString(pt);
        }


        /// <summary>
        /// Decrypt one XMLENC example document and print the recovered plaintext to the console.
        /// </summary>
        /// <param name="xmlfile">File path to XML file or string containing XML.</param>
        static void test_XmlEncDecrypt(string xmlfile)
        {
            Console.WriteLine("\nFILE: {0}", xmlfile);

            string plaintext = getXmlEncDecryptedText(xmlfile);

            // An empty result means the decryption routine reported an error
            Debug.Assert(plaintext.Length > 0, "Failed to find any data");
            Console.WriteLine("plaintext=[{0}]", plaintext);
        }


        /// <summary>
        /// Program entry point: shows the library versions, then decrypts each example document in turn.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        static void Main(string[] args)
        {
            /* These example files are from (or adapted from)
             * W3C "XML Encryption Implementation and Interoperability Report"
             * https://www.w3.org/Encryption/2002/02-xenc-interop.html
             * merlin-xmlenc-five.tar.gz
             * https://lists.w3.org/Archives/Public/xml-encryption/2002Mar/0008.html
             * */
            string[] exampleFiles = {
                // Decrypting data using block ciphers
                "encrypt-data-aes128-cbc.xml",
                "encrypt-content-tripledes-cbc.xml",
                "encrypt-content-aes256-cbc-prop.xml",
                "encrypt-element-aes128-gcm.xml",
                "encrypt-data-aes256-gcm.xml",

                // Using key wrap
                "encrypt-content-aes128-cbc-kw-aes192.xml",
                "encrypt-data-aes192-cbc-kw-aes256.xml",

                // Using RSA key transport
                "encrypt-element-aes128-cbc-rsa-1_5.xml",
                "encrypt-data-tripledes-cbc-rsa-oaep-mgf1p.xml",
                "encrypt-content-aes128-gcm-rsa-oaep-sha256.xml",
                "root-bob-oaep-gcm.xml",
            };

            // Setup to display Trace info in console
            // Comment out the next line to turn off Trace
            Trace.Listeners.Add(new TextWriterTraceListener(Console.Out));

            // Show versions of core DLLs...
            Trace.WriteLine(string.Format("Xmlsq Version={0:D5}", Xmlsq.Gen.Version()));
            Trace.WriteLine(string.Format("CryptoSys PKI Version={0:D5}", Pki.General.Version()));

            // Run every example in the order listed above
            foreach (string exampleFile in exampleFiles) {
                test_XmlEncDecrypt(exampleFile);
            }
        }
    }
}