diff --git a/src/romulus-m.ts b/src/romulus-m.ts index 9695303..17e1422 100644 --- a/src/romulus-m.ts +++ b/src/romulus-m.ts @@ -91,6 +91,24 @@ function rho (state: number[], mBlock: number[]): [number[], number[]] { return [nextState, c] } +/** + * The state update function. Pads a C block. + * @param state The internal state, S. + * @param cBlock A C block. + * @returns [S', M] where M = C ⊕ G(S) and S' = C ⊕ M. + */ +function inverseRoh (state: number[], cBlock: number[]): [number[], number[]] { + // G(S) + const gOfS = g(state) + + // M = C ⊕ G(S) + const mBlock = [...Array(16).keys()].map(i => cBlock[i] ^ gOfS[i]) + + // S' = S ⊕ M + const nextState = [...Array(16).keys()].map(i => state[i] ^ mBlock[i]) + return [nextState, mBlock] +} + /** * Increments the 56 bit LFSR-based counter. * @param counter The old counter. @@ -243,3 +261,90 @@ export function cryptoAeadEncrypt (message: number[], associatedData: number[], return ciphertext } + +export function cryptoAeadDecrypt (ciphertext: number[], associatedData: number[], nonce: number[], key: number[]): number[] { + // Buffer for decrypted message. + const message = [] + + // The authentication tag is represented by the last 16 bytes of the ciphertext. + const authenticationTag = ciphertext.slice(-16) + ciphertext.length -= 16 + + let state = zeroedBuffer(16) + if (ciphertext.length !== 0) { + state = [...authenticationTag] + const parsedCiphertext = parse(ciphertext, 16) + const parsedCiphertextLength = parsedCiphertext.length - 1 + const finalCiphertextBlockLength = parsedCiphertext[parsedCiphertextLength].length + parsedCiphertext[parsedCiphertextLength] = pad(parsedCiphertext[parsedCiphertextLength], 16) + + var counter = resetCounter() + + for (let i = 1; i < parsedCiphertextLength + 1; i++) { + state = skinnyEncrypt(state, tweakeyEncode(counter, 4, nonce, key)) + let mBlock + [state, mBlock] = inverseRoh(state, parsedCiphertext[i]) + counter = increaseCounter(counter) + if (i < parsedCiphertextLength) { + message.push(...mBlock) + } else { + message.push(...mBlock.slice(0, finalCiphertextBlockLength)) + } + } + } else { + state = [] + } + + state = zeroedBuffer(16) + counter = resetCounter() + + const parsedAssociatedData = parse(associatedData, 16) + const parsedAssociatedDataLength = parsedAssociatedData.length - 1 + + const parsedMessage = parse(message, 16) + const parsedMessageLength = parsedMessage.length - 1 + + // Concatenate the parsed message and the associated data, excluding each array's first element. + const combinedData = parsedAssociatedData.slice(1).concat(parsedMessage.slice(1)) + + // Insert empty array at position 0. + combinedData.splice(0, 0, []) + + // Calculate domain separation for final decryption stage. + const domainSeparation = calculateDomainSeparation(combinedData, parsedMessageLength, parsedAssociatedDataLength) + + // Pad combined data. + combinedData[parsedAssociatedDataLength] = pad(combinedData[parsedAssociatedDataLength], 16) + combinedData[parsedAssociatedDataLength + parsedMessageLength] = pad(combinedData[parsedAssociatedDataLength + parsedMessageLength], 16) + + let x = 8 + for (let i = 1; i < Math.floor((parsedAssociatedDataLength + parsedMessageLength) / 2) + 1; i++) { + [state] = rho(state, combinedData[2 * i - 1]) + counter = increaseCounter(counter) + if (i === Math.floor(parsedAssociatedDataLength / 2) + 1) { + x = x ^ 4 + } + state = skinnyEncrypt(state, tweakeyEncode(counter, x, combinedData[2 * i], key)) + counter = increaseCounter(counter) + } + + if (parsedAssociatedDataLength % 2 === parsedMessageLength % 2) { + [state] = rho(state, zeroedBuffer(16)) + } else { + [state] = rho(state, combinedData[parsedAssociatedDataLength + parsedMessageLength]) + counter = increaseCounter(counter) + } + let computedTag + [state, computedTag] = rho(skinnyEncrypt(state, tweakeyEncode(counter, domainSeparation, nonce, key)), zeroedBuffer(16)) + + let compare = 0 + for (let i = 0; i < 16; i++) { + compare |= (authenticationTag[i] ^ computedTag[i]) + } + + if (compare !== 0) { + return [] + } else { + return message + } +} diff --git a/tests/romulus-m.test.ts b/tests/romulus-m.test.ts index dc32dd9..7fe9538 100644 --- a/tests/romulus-m.test.ts +++ b/tests/romulus-m.test.ts @@ -1,4 +1,4 @@ -import { cryptoAeadEncrypt } from '../src/romulus-m' +import { cryptoAeadDecrypt, cryptoAeadEncrypt } from '../src/romulus-m' function stringToArray (string: string): number[] { const encoder = new TextEncoder() @@ -42,3 +42,41 @@ test('Encrypt a message with associated data.', () => { ] expect(result).toMatchObject(expectedResult) }) + +test('Decrypt a message with no associated data.', () => { + // Given + const ciphertext = [ + 85, 125, 23, 244, 73, 241, 140, 72, 166, 113, 114, 78, 239, 211, 84, 113, 222, + 153, 207, 183, 69, 142, 174, 15, 38, 46, 112, 162, 229, 27, 136, 184, 163, 78, + 132, 42, 107, 160, 74, 115, 28, 251, 209, 37, 48, 57, 184, 204, 199, 247, 93, 5, 208 + ] + const associatedData = stringToArray('') + const nonce = stringToArray('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f') + const key = stringToArray('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f') + + // When + const result = cryptoAeadDecrypt(ciphertext, associatedData, nonce, key) + + // Then + const expectedResult = stringToArray('Hello, World! This is a test message.') + expect(result).toMatchObject(expectedResult) +}) + +test('Decrypt a message with associated data.', () => { + // Given + const ciphertext = [ + 225, 53, 3, 212, 22, 112, 246, 194, 61, 171, 230, 187, 157, 102, 32, 76, 62, 65, + 25, 202, 255, 201, 206, 49, 60, 58, 82, 216, 72, 116, 106, 129, 162, 142, 69, 40, + 167, 88, 94, 195, 174, 217, 242, 149, 224, 125, 196, 237, 172, 165, 116, 119, 128 + ] + const associatedData = stringToArray('Some associated data.') + const nonce = stringToArray('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f') + const key = stringToArray('\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f') + + // When + const result = cryptoAeadDecrypt(ciphertext, associatedData, nonce, key) + + // Then + const expectedResult = stringToArray('Hello, World! This is a test message.') + expect(result).toMatchObject(expectedResult) +})