C#实现MD4算法

更新于 2023-12-26

在用 C# 实现 SMTP 协议中的 NTLM 认证时需要用到 MD4 哈希，但没有找到 .NET 官方的 MD4 实现，因此参考了第三方实现并做了简单修改。

using System;

namespace Zhger.Sparkle.Security
{
    /// <summary>
    /// MD4 message-digest algorithm as specified in RFC 1320.
    /// Adapted from https://github.com/antanaskat/md4_hashing/tree/master
    /// </summary>
    /// <remarks>
    /// MD4 is cryptographically broken (see RFC 6150); it is kept only for legacy
    /// protocols such as NTLM that require it. Do not use it for new designs.
    /// </remarks>
    public class MD4
    {
        // Initial chaining values A, B, C, D (RFC 1320, section 3.3).
        private const uint InitA = 0x67452301;
        private const uint InitB = 0xefcdab89;
        private const uint InitC = 0x98badcfe;
        private const uint InitD = 0x10325476;

        // Additive constants for rounds 2 and 3: sqrt(2) and sqrt(3) scaled by 2^30 (RFC 1320, section 3.4).
        private const uint Round2Constant = 0x5a827999;
        private const uint Round3Constant = 0x6ed9eba1;

        // MD4 processes the message in 64-byte blocks.
        private const int BlockSize = 64;

        // Bytes reserved at the end of the final block for the 64-bit little-endian bit length.
        private const int LengthSize = 8;

        // Order in which round 3 visits the word groups: k, k+8, k+4, k+12 for k = 0, 2, 1, 3.
        private static readonly int[] Round3Order = { 0, 2, 1, 3 };

        /// <summary>
        /// Writes <paramref name="word"/> into <paramref name="buffer"/> at <paramref name="offset"/>
        /// in little-endian byte order.
        /// </summary>
        private static void Word2Bytes(byte[] buffer, int offset, uint word)
        {
            // Truncating casts are intentional; unchecked keeps them legal under /checked builds.
            unchecked
            {
                buffer[offset] = (byte)word;
                buffer[offset + 1] = (byte)(word >> 8);
                buffer[offset + 2] = (byte)(word >> 16);
                buffer[offset + 3] = (byte)(word >> 24);
            }
        }

        /// <summary>
        /// Reads a little-endian 32-bit word from <paramref name="buffer"/> at <paramref name="offset"/>.
        /// </summary>
        private static uint Bytes2Word(byte[] buffer, int offset)
            => buffer[offset]
               | (uint)buffer[offset + 1] << 8
               | (uint)buffer[offset + 2] << 16
               | (uint)buffer[offset + 3] << 24;

        // Auxiliary functions (RFC 1320, section 3.4): F = conditional, G = majority, H = parity.
        private static uint F(uint x, uint y, uint z) => (x & y) | (~x & z);
        private static uint G(uint x, uint y, uint z) => (x & y) | (x & z) | (y & z);
        private static uint H(uint x, uint y, uint z) => x ^ y ^ z;

        /// <summary>Rotates <paramref name="x"/> left by <paramref name="s"/> bits (0 &lt; s &lt; 32).</summary>
        private static uint RotateLeft(uint x, int s) => (x << s) | (x >> (32 - s));

        // One step of each round: a = (a + f(b, c, d) + x [+ constant]) <<< s, with modular addition.
        // The caller rotates the roles of A, B, C, D by passing the registers in a different order.
        private static uint Round1(uint a, uint b, uint c, uint d, uint x, int s)
            => RotateLeft(unchecked(a + F(b, c, d) + x), s);

        private static uint Round2(uint a, uint b, uint c, uint d, uint x, int s)
            => RotateLeft(unchecked(a + G(b, c, d) + x + Round2Constant), s);

        private static uint Round3(uint a, uint b, uint c, uint d, uint x, int s)
            => RotateLeft(unchecked(a + H(b, c, d) + x + Round3Constant), s);

        /// <summary>
        /// Runs the MD4 compression function on one 64-byte block and folds the result into
        /// <paramref name="state"/> in place.
        /// </summary>
        /// <param name="state">The four chaining values A, B, C, D.</param>
        /// <param name="data">Buffer containing the block.</param>
        /// <param name="offset">Index of the first byte of the block within <paramref name="data"/>.</param>
        /// <param name="x">Scratch array of 16 words, reused across blocks to avoid per-block allocations.</param>
        private static void ProcessBlock(uint[] state, byte[] data, int offset, uint[] x)
        {
            for (int j = 0; j < 16; j++)
            {
                x[j] = Bytes2Word(data, offset + j * 4);
            }

            uint a = state[0];
            uint b = state[1];
            uint c = state[2];
            uint d = state[3];

            // Round 1: words in natural order, shifts 3/7/11/19.
            for (int i = 0; i < 16; i += 4)
            {
                a = Round1(a, b, c, d, x[i], 3);
                d = Round1(d, a, b, c, x[i + 1], 7);
                c = Round1(c, d, a, b, x[i + 2], 11);
                b = Round1(b, c, d, a, x[i + 3], 19);
            }

            // Round 2: words column-wise (i, i+4, i+8, i+12), shifts 3/5/9/13.
            for (int i = 0; i < 4; i++)
            {
                a = Round2(a, b, c, d, x[i], 3);
                d = Round2(d, a, b, c, x[i + 4], 5);
                c = Round2(c, d, a, b, x[i + 8], 9);
                b = Round2(b, c, d, a, x[i + 12], 13);
            }

            // Round 3: words in bit-reversed order, shifts 3/9/11/15.
            foreach (int k in Round3Order)
            {
                a = Round3(a, b, c, d, x[k], 3);
                d = Round3(d, a, b, c, x[k + 8], 9);
                c = Round3(c, d, a, b, x[k + 4], 11);
                b = Round3(b, c, d, a, x[k + 12], 15);
            }

            // Chaining values are updated with modular (wrapping) addition.
            unchecked
            {
                state[0] += a;
                state[1] += b;
                state[2] += c;
                state[3] += d;
            }
        }

        /// <summary>
        /// Computes the MD4 hash of the whole <paramref name="buffer"/>.
        /// </summary>
        /// <param name="buffer">Input data.</param>
        /// <returns>The 16-byte MD4 digest.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        public static byte[] Compute(byte[] buffer)
        {
            if (buffer == null)
            {
                throw new ArgumentNullException(nameof(buffer));
            }

            return Compute(buffer, 0, buffer.Length);
        }

        /// <summary>
        /// Computes the MD4 hash of <paramref name="count"/> bytes of <paramref name="buffer"/>
        /// starting at <paramref name="offset"/>.
        /// </summary>
        /// <param name="buffer">Input data.</param>
        /// <param name="offset">Index of the first byte to hash.</param>
        /// <param name="count">Number of bytes to hash.</param>
        /// <returns>The 16-byte MD4 digest.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="offset"/> or <paramref name="count"/> is negative.</exception>
        /// <exception cref="ArgumentException">The range exceeds the bounds of <paramref name="buffer"/>.</exception>
        public static byte[] Compute(byte[] buffer, int offset, int count)
        {
            if (buffer == null)
            {
                throw new ArgumentNullException(nameof(buffer));
            }
            if (offset < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(offset), "Offset must be non-negative.");
            }
            if (count < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(count), "Count must be non-negative.");
            }
            if (buffer.Length - offset < count)
            {
                throw new ArgumentException("Offset and count exceed the bounds of the buffer.");
            }

            uint[] state = { InitA, InitB, InitC, InitD };
            uint[] x = new uint[16];

            // Hash every complete 64-byte block directly from the caller's buffer (no full copy).
            int fullBlocksEnd = offset + (count / BlockSize) * BlockSize;
            for (int pos = offset; pos < fullBlocksEnd; pos += BlockSize)
            {
                ProcessBlock(state, buffer, pos, x);
            }

            // Final block(s): leftover bytes, the 0x80 marker, zero padding, then the bit length.
            // One block suffices when the leftovers leave room for the marker and the length
            // (fewer than 56 bytes); otherwise the padding spills into a second block.
            int remaining = count % BlockSize;
            byte[] tail = new byte[remaining < BlockSize - LengthSize ? BlockSize : 2 * BlockSize];
            Buffer.BlockCopy(buffer, fullBlocksEnd, tail, 0, remaining);
            tail[remaining] = 0x80;

            // The length field is 64 bits; computing it as ulong keeps the high word correct
            // for inputs of 512 MiB or more (the previous code always wrote 0 there).
            ulong bitLength = (ulong)count * 8;
            Word2Bytes(tail, tail.Length - LengthSize, unchecked((uint)bitLength));
            Word2Bytes(tail, tail.Length - LengthSize + 4, (uint)(bitLength >> 32));

            for (int pos = 0; pos < tail.Length; pos += BlockSize)
            {
                ProcessBlock(state, tail, pos, x);
            }

            // Digest is A, B, C, D serialized little-endian.
            byte[] hash = new byte[16];
            for (int i = 0; i < 4; i++)
            {
                Word2Bytes(hash, i * 4, state[i]);
            }
            return hash;
        }
    }
}