romulus-js/src/romulus-m.ts
Jack Hadrill 5a042757dd
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/pr Build is passing
Improve code quality and add decrypt status return value
2022-02-05 22:55:31 +00:00

384 lines
12 KiB
TypeScript

import { COUNTER_LENGTH } from './constants'
import { tweakeyEncode, skinnyEncrypt } from './skinny-128-384-plus'
/**
 * Split a message into consecutive blocks of `blockLength` bytes.
 *
 * Index 0 of the returned array always holds an empty placeholder block, so the
 * real blocks are addressed from index 1 and the block count is `length - 1`.
 * A trailing partial block keeps its original (shorter) length so that it can
 * be padded later. An empty message yields one empty block after the placeholder.
 *
 * @param message The message to parse.
 * @param blockLength The block length; must be a positive integer.
 * @returns The placeholder block followed by the message blocks.
 * @throws RangeError If `blockLength` is not a positive integer.
 */
function parse (message: number[], blockLength: number): number[][] {
  // A non-positive block length would never advance through the message and loop forever.
  if (!Number.isInteger(blockLength) || blockLength <= 0) {
    throw new RangeError(`blockLength must be a positive integer, got ${blockLength}`)
  }
  // Start with the empty placeholder that occupies index 0.
  const blocks: number[][] = [[]]
  // Slice the message into blocks; the last slice may be shorter than blockLength.
  for (let cursor = 0; cursor < message.length; cursor += blockLength) {
    blocks.push(message.slice(cursor, cursor + blockLength))
  }
  // An empty message is still represented by a single (empty) block.
  if (message.length === 0) {
    blocks.push([])
  }
  return blocks
}
/**
 * Pad a block to `padLength` bytes.
 *
 * - An empty block becomes `padLength` zero bytes.
 * - A block that is already `padLength` long is returned as an unmodified copy.
 * - Any other block is extended with zero bytes and its final byte is set to
 *   the length of the original block.
 *
 * The input array is never modified.
 *
 * @param message The block to pad.
 * @param padLength The length to pad the block to; a non-negative integer.
 * @returns A new array of `padLength` bytes.
 * @throws RangeError If `padLength` is not a non-negative integer or the block is longer than `padLength`.
 */
function pad (message: number[], padLength: number): number[] {
  if (!Number.isInteger(padLength) || padLength < 0) {
    throw new RangeError(`padLength must be a non-negative integer, got ${padLength}`)
  }
  if (message.length > padLength) {
    throw new RangeError(`cannot pad a block of ${message.length} bytes to ${padLength} bytes`)
  }
  // An empty block is fully zero padded and carries no length byte.
  if (message.length === 0) {
    return new Array<number>(padLength).fill(0)
  }
  const padded = Array.from(message)
  // A full block needs no padding.
  if (message.length === padLength) {
    return padded
  }
  // Fill with explicit zeroes so that every slot holds a real number.
  while (padded.length < padLength) {
    padded.push(0)
  }
  // The final byte records how many bytes of the block are real data.
  padded[padLength - 1] = message.length
  return padded
}
/**
 * Derive the key stream from the internal state by applying the G
 * transformation to every state byte independently.
 *
 * For a byte value (0-255) the result is the byte shifted right by one bit,
 * with its top bit set to the XOR of the input's top and bottom bits.
 *
 * @param state The state from which the key stream is generated.
 * @returns The key stream, one entry per state entry.
 */
function g (state: number[]): number[] {
  const TOP_BIT = 0x80
  const BOTTOM_BIT = 0x01
  return state.map(byte => {
    const shiftedDown = byte >> 1
    const keptTop = byte & TOP_BIT
    // Move the lowest bit up into the top position.
    const wrappedBottom = (byte & BOTTOM_BIT) << 7
    return shiftedDown ^ keptTop ^ wrappedBottom
  })
}
/**
 * The state update function used when absorbing or encrypting an M block.
 *
 * Exactly 16 positions are processed. A missing entry in either input behaves
 * as 0 under XOR.
 *
 * @param state The internal state, S.
 * @param mBlock An M block.
 * @returns [S', C] where S' = S ⊕ M and C = M ⊕ G(S).
 */
function rho (state: number[], mBlock: number[]): [number[], number[]] {
  const keyStream = g(state)
  const cBlock: number[] = []
  const nextState: number[] = []
  for (let i = 0; i < 16; i++) {
    // C = M ⊕ G(S)
    cBlock.push(mBlock[i] ^ keyStream[i])
    // S' = S ⊕ M
    nextState.push(state[i] ^ mBlock[i])
  }
  return [nextState, cBlock]
}
/**
 * The inverse state update function used when decrypting a C block.
 *
 * Exactly 16 positions are processed. A missing entry in either input behaves
 * as 0 under XOR.
 *
 * @param state The internal state, S.
 * @param cBlock A C block.
 * @returns [S', M] where M = C ⊕ G(S) and S' = S ⊕ M.
 */
function inverseRoh (state: number[], cBlock: number[]): [number[], number[]] {
  const keyStream = g(state)
  // M = C ⊕ G(S)
  const mBlock = Array.from({ length: 16 }, (_, i) => cBlock[i] ^ keyStream[i])
  // S' = S ⊕ M
  const nextState = mBlock.map((recovered, i) => state[i] ^ recovered)
  return [nextState, mBlock]
}
/**
 * Advance the 56 bit LFSR-based counter by one step.
 *
 * The counter is held as seven bytes with index 0 least significant. All bytes
 * are shifted left by one bit, with the top bit of each byte carried into the
 * byte above it. When the bit shifted out of byte 6 is set, 0x95 is XORed into
 * byte 0. Every byte is masked to 8 bits so that the entries stay valid byte
 * values no matter how many times the counter is advanced.
 *
 * The counter is updated in place and the same array is returned.
 *
 * @param counter The old counter.
 * @returns The incremented counter (the same array instance).
 */
function increaseCounter (counter: number[]): number[] {
  // The bit that falls off the top of the register decides whether feedback is applied.
  const feedback = (counter[6] >> 7) & 0x01
  // Walk from the most significant byte down so each carry reads the not-yet-shifted lower byte.
  for (let i = 6; i > 0; i--) {
    const carry = (counter[i - 1] >> 7) & 0x01
    counter[i] = ((counter[i] << 1) | carry) & 0xff
  }
  counter[0] = (counter[0] << 1) & 0xff
  if (feedback === 1) {
    counter[0] ^= 0x95
  }
  return counter
}
/**
 * Create a counter in its initial state.
 *
 * The counter has `COUNTER_LENGTH` entries: entry 0 is 1 and every other entry
 * is an explicit 0, so the array holds no empty slots.
 *
 * @returns A freshly allocated reset counter.
 */
function resetCounter (): number[] {
  // Zero-fill so every entry is a real number rather than a hole.
  const counter = new Array<number>(COUNTER_LENGTH).fill(0)
  counter[0] = 1
  return counter
}
/**
 * Calculate the domain separation value for the final tag computation.
 *
 * Starting from 16, one bit is toggled for each condition that holds:
 * - 2 when the last associated data block is shorter than 16 bytes,
 * - 1 when the last message block is shorter than 16 bytes,
 * - 8 when the number of associated data blocks is even,
 * - 4 when the number of message blocks is even.
 *
 * @param combinedData The parsed and concatenated associated data and message blocks (before padding), with a placeholder at index 0.
 * @param parsedMessageLength The number of message blocks.
 * @param parsedAssociatedDataLength The number of associated data blocks.
 * @returns The domain separation value.
 */
function calculateDomainSeparation (combinedData: number[][], parsedMessageLength: number, parsedAssociatedDataLength: number): number {
  const lastAssociatedDataIndex = parsedAssociatedDataLength
  const lastMessageIndex = parsedAssociatedDataLength + parsedMessageLength
  // Each entry pairs a condition with the bit it toggles.
  const toggles: Array<[boolean, number]> = [
    [combinedData[lastAssociatedDataIndex].length < 16, 2],
    [combinedData[lastMessageIndex].length < 16, 1],
    [parsedAssociatedDataLength % 2 === 0, 8],
    [parsedMessageLength % 2 === 0, 4]
  ]
  return toggles.reduce((value, [applies, bit]) => (applies ? value ^ bit : value), 16)
}
/**
 * Encrypt and authenticate a message using the Romulus-M cryptography specification.
 * See https://romulusae.github.io/romulus/docs/Romulusv1.3.pdf for more information.
 *
 * The authentication tag is computed first over the associated data and the
 * message; the message is then encrypted starting from that tag.
 *
 * @param message The message to encrypt.
 * @param associatedData The associated data to authenticate alongside the message.
 * @param nonce A 128 bit nonce.
 * @param key A 128 bit encryption key.
 * @returns The encrypted message bytes followed by the 16 byte authentication tag; only the tag when the message is empty.
 */
export function cryptoAeadEncrypt (message: number[], associatedData: number[], nonce: number[], key: number[]): number[] {
  // Carve both inputs into blocks; index 0 of each result is a placeholder.
  const messageBlocks = parse(message, 16)
  const associatedDataBlocks = parse(associatedData, 16)
  const messageBlockCount = messageBlocks.length - 1
  const associatedDataBlockCount = associatedDataBlocks.length - 1
  const lastBlockIndex = associatedDataBlockCount + messageBlockCount

  // Associated data blocks followed by message blocks, again behind a placeholder at index 0.
  const combinedDataBlocks: number[][] = [[], ...associatedDataBlocks.slice(1), ...messageBlocks.slice(1)]

  // The domain separation depends on the unpadded block lengths, so compute it before padding.
  const domainSeparation = calculateDomainSeparation(combinedDataBlocks, messageBlockCount, associatedDataBlockCount)
  combinedDataBlocks[associatedDataBlockCount] = pad(combinedDataBlocks[associatedDataBlockCount], 16)
  combinedDataBlocks[lastBlockIndex] = pad(combinedDataBlocks[lastBlockIndex], 16)

  // Absorb the combined data two blocks at a time.
  let state: number[] = Array(16)
  let counter = resetCounter()
  const pairCount = Math.floor(lastBlockIndex / 2)
  const switchPair = Math.floor(associatedDataBlockCount / 2) + 1
  let tweakDomain = 8
  for (let pair = 1; pair <= pairCount; pair++) {
    state = rho(state, combinedDataBlocks[2 * pair - 1])[0]
    counter = increaseCounter(counter)
    // Toggle the domain bit once, at the pair that follows the associated data.
    if (pair === switchPair) {
      tweakDomain ^= 4
    }
    const tweakey = tweakeyEncode(counter, tweakDomain, combinedDataBlocks[2 * pair], key)
    state = skinnyEncrypt(state, tweakey)
    counter = increaseCounter(counter)
  }

  // Absorb either an all-zero block or the unpaired final block, depending on block count parity.
  const sameParity = associatedDataBlockCount % 2 === messageBlockCount % 2
  if (sameParity) {
    state = rho(state, Array(16))[0]
  } else {
    state = rho(state, combinedDataBlocks[lastBlockIndex])[0]
    counter = increaseCounter(counter)
  }

  // Generate authentication tag.
  const tagState = skinnyEncrypt(state, tweakeyEncode(counter, domainSeparation, nonce, key))
  const authenticationTag = rho(tagState, Array(16))[1]
  if (message.length === 0) {
    return authenticationTag
  }

  // Encrypt the message, starting from a copy of the tag and a fresh counter.
  state = Array.from(authenticationTag)
  counter = resetCounter()
  const finalBlockLength = messageBlocks[messageBlockCount].length
  messageBlocks[messageBlockCount] = pad(messageBlocks[messageBlockCount], 16)
  const encrypted: number[] = []
  for (let i = 1; i <= messageBlockCount; i++) {
    state = skinnyEncrypt(state, tweakeyEncode(counter, 4, nonce, key))
    const [updatedState, cBlock] = rho(state, messageBlocks[i])
    state = updatedState
    counter = increaseCounter(counter)
    // Only the bytes that correspond to real message data are emitted for the last block.
    const emitted = i === messageBlockCount ? cBlock.slice(0, finalBlockLength) : cBlock
    encrypted.push(...emitted)
  }

  // Store the authentication tag in the final 16 bytes of the ciphertext.
  return encrypted.concat(authenticationTag)
}
/**
* Return interface for decrypting a message.
* Produced by cryptoAeadDecrypt to report both the verification outcome and the recovered data.
*/
export interface DecryptResult {
// True when the recomputed authentication tag matched the tag supplied with the ciphertext.
success: boolean
// The decrypted message bytes when success is true; an empty array when authentication failed.
plaintext: number[]
}
/**
 * Decrypt and verify a message using the Romulus-M cryptography specification.
 * See https://romulusae.github.io/romulus/docs/Romulusv1.3.pdf for more information.
 *
 * The final 16 bytes of `ciphertext` are treated as the authentication tag.
 * The message is decrypted first, then the tag is recomputed over the
 * associated data and the decrypted message and compared with the supplied tag.
 * The `ciphertext` array passed in is not modified.
 *
 * @param ciphertext The encrypted message followed by the 16 byte authentication tag.
 * @param associatedData The associated data.
 * @param nonce The nonce.
 * @param key The key.
 * @returns `success: true` with the decrypted plaintext when the tag verifies; `success: false` with an empty plaintext when it does not, or when `ciphertext` is too short to contain a tag.
 */
export function cryptoAeadDecrypt (ciphertext: number[], associatedData: number[], nonce: number[], key: number[]): DecryptResult {
  // Anything shorter than a tag cannot be authentic.
  if (ciphertext.length < 16) {
    return { success: false, plaintext: [] }
  }
  // The authentication tag is represented by the final 16 bytes of the ciphertext.
  const authenticationTag = ciphertext.slice(-16)
  // Work on a copy of the encrypted part so the caller's array is left untouched.
  const encryptedMessage = ciphertext.slice(0, ciphertext.length - 16)

  // Buffer for decrypted message.
  const cleartext: number[] = []
  let state: number[]
  let counter = resetCounter()
  if (encryptedMessage.length !== 0) {
    // Decryption starts from a copy of the tag.
    state = Array.from(authenticationTag)
    const ciphertextBlocks = parse(encryptedMessage, 16)
    const ciphertextBlockCount = ciphertextBlocks.length - 1
    const finalCiphertextBlockLength = ciphertextBlocks[ciphertextBlockCount].length
    ciphertextBlocks[ciphertextBlockCount] = pad(ciphertextBlocks[ciphertextBlockCount], 16)
    for (let i = 1; i <= ciphertextBlockCount; i++) {
      state = skinnyEncrypt(state, tweakeyEncode(counter, 4, nonce, key))
      const [updatedState, mBlock] = inverseRoh(state, ciphertextBlocks[i])
      state = updatedState
      counter = increaseCounter(counter)
      // Only the bytes that correspond to real ciphertext are kept for the last block.
      const recovered = i === ciphertextBlockCount ? mBlock.slice(0, finalCiphertextBlockLength) : mBlock
      cleartext.push(...recovered)
    }
  }

  // Reset state and counter before recomputing the tag.
  state = Array(16)
  counter = resetCounter()
  // Carve the message and associated data into blocks; index 0 of each result is a placeholder.
  const messageBlocks = parse(cleartext, 16)
  const messageBlockCount = messageBlocks.length - 1
  const associatedDataBlocks = parse(associatedData, 16)
  const associatedDataBlockCount = associatedDataBlocks.length - 1
  const lastBlockIndex = associatedDataBlockCount + messageBlockCount
  // Concatenate the associated data and message blocks behind a placeholder at index 0.
  const combinedData: number[][] = [[], ...associatedDataBlocks.slice(1), ...messageBlocks.slice(1)]
  // The domain separation depends on the unpadded block lengths, so compute it before padding.
  const domainSeparation = calculateDomainSeparation(combinedData, messageBlockCount, associatedDataBlockCount)
  // Pad combined data.
  combinedData[associatedDataBlockCount] = pad(combinedData[associatedDataBlockCount], 16)
  combinedData[lastBlockIndex] = pad(combinedData[lastBlockIndex], 16)

  // Absorb the combined data two blocks at a time.
  let x = 8
  const pairCount = Math.floor(lastBlockIndex / 2)
  for (let i = 1; i <= pairCount; i++) {
    state = rho(state, combinedData[2 * i - 1])[0]
    counter = increaseCounter(counter)
    // Toggle the domain bit once, at the pair that follows the associated data.
    if (i === Math.floor(associatedDataBlockCount / 2) + 1) {
      x = x ^ 4
    }
    state = skinnyEncrypt(state, tweakeyEncode(counter, x, combinedData[2 * i], key))
    counter = increaseCounter(counter)
  }
  // Absorb either an all-zero block or the unpaired final block, depending on block count parity.
  if (associatedDataBlockCount % 2 === messageBlockCount % 2) {
    state = rho(state, Array(16))[0]
  } else {
    state = rho(state, combinedData[lastBlockIndex])[0]
    counter = increaseCounter(counter)
  }

  // Calculate authentication tag.
  const computedTag = rho(skinnyEncrypt(state, tweakeyEncode(counter, domainSeparation, nonce, key)), Array(16))[1]
  // Validate authentication tag, accumulating differences so every byte is always inspected.
  let compare = 0
  for (let i = 0; i < 16; i++) {
    compare |= (authenticationTag[i] ^ computedTag[i])
  }
  if (compare !== 0) {
    // Authentication failed: do not release any decrypted data.
    return {
      success: false,
      plaintext: []
    }
  }
  // Decrypted successfully.
  return {
    success: true,
    plaintext: cleartext
  }
}