コンテンツにスキップ

Cipher

暗号化

AES 暗号化の実装を行います。

SubBytes

SubBytes 処理は、状態行列 state の各ブロックに対して、非線形変換である S-box を適用する処理です。

1
typedef uint8_t state_t[4][4];
2
3
#define get_sbox_value(num) (sbox[(num)])
4
5
void sub_bytes(state_t* state) {
6
uint8_t i, j;
7
for (i = 0; i < 4; ++i) {
8
for (j = 0; j < 4; ++j) {
9
(*state)[j][i] = get_sbox_value((*state)[j][i]);
10
}
11
}
12
}

ShiftRows

ShiftRows 処理は、状態行列 state の各行に対して、シフト操作を行う処理です。

(abcdefghijklmnop)(abcdfgheklijpmno)\begin{pmatrix} a & b & c & d \\ e & f & g & h \\ i & j & k & l \\ m & n & o & p \end{pmatrix} \longrightarrow \begin{pmatrix} a & b & c & d \\ f & g & h & e \\ k & l & i & j \\ p & m & n & o \end{pmatrix}

1 行目はシフトなし、2 行目は左に 1 シフト、3 行目は左に 2 シフト、4 行目は左に 3 シフトされます。

1
void shift_rows(state_t* state) {
2
uint8_t temp;
3
4
// 2行目
5
temp = (*state)[0][1];
6
(*state)[0][1] = (*state)[1][1];
7
(*state)[1][1] = (*state)[2][1];
8
(*state)[2][1] = (*state)[3][1];
9
(*state)[3][1] = temp;
10
11
// 3行目
12
temp = (*state)[0][2];
13
(*state)[0][2] = (*state)[2][2];
14
(*state)[2][2] = temp;
15
temp = (*state)[1][2];
16
(*state)[1][2] = (*state)[3][2];
17
(*state)[3][2] = temp;
18
19
// 4行目
20
temp = (*state)[0][3];
21
(*state)[0][3] = (*state)[3][3];
22
(*state)[3][3] = (*state)[2][3];
23
(*state)[2][3] = (*state)[1][3];
24
(*state)[1][3] = temp;
25
}

MixColumns

MixColumns 処理は、状態行列 state の各列に対して、行列演算を行う処理です。

(s0,is1,is2,is3,i)(xx+1111xx+1111xx+1x+111x)(s0,is1,is2,is3,i)\begin{pmatrix} s_{0,i} \\ s_{1,i} \\ s_{2,i} \\ s_{3,i} \end{pmatrix} \leftarrow \begin{pmatrix} x & x+1 & 1 & 1 \\ 1 & x & x+1 & 1 \\ 1 & 1 & x & x+1 \\ x+1 & 1 & 1 & x \\ \end{pmatrix} \begin{pmatrix} s_{0,i} \\ s_{1,i} \\ s_{2,i} \\ s_{3,i} \end{pmatrix}

xsj,ix \cdot s_{j,i} の計算方法は、以下の通りです。

  1. sj,is_{j,i} を 左に 1 シフト
  2. 最上位ビットが 1 である場合、0x1B で XOR
xsj,ix \cdot s_{j,i} を計算するために sj,is_{j,i} を左に 1 シフトする理由

AES では、各バイトは有限体 GF(28)GF(2^8) の要素として扱われます。この有限体は、以下の形で多項式として表現することができます。

b=b7x7+b6x6+b5x5++b1x+b0b = b_7x^7 + b_6x^6 + b_5x^5 + \dots + b_1x + b_0

有限体 GF(28)GF(2^8)xsj,ix \cdot s_{j,i} を計算するということは、バイト sj,is_{j,i} に相当する多項式を 1 次多項式 xx で乗算することを意味します。以下のように多項式全体を 1 次上にシフトすることと同じです。

b7x7+b6x6++b1x+b0左シフトb7x8+b6x7++b1x2+b0xb_7x^7 + b_6x^6 + \dots + b_1x + b_0 \quad \xrightarrow{\text{左シフト}} \quad b_7x^8 + b_6x^7 + \dots + b_1x^2 + b_0x

この左シフト操作は、多項式の係数に xx を掛けることに相当します。

左シフト後に最上位ビットが 1 である場合、固定値0x1Bを XOR することで剰余を取ります。以下のxtimextime 関数で実装されています。

1
uint8_t xtime(uint8_t x) { return ((x << 1) ^ (((x >> 7) & 1) * 0x1b)); }
  • x << 1 は左シフト操作です。
  • ((x >> 7) & 1) は、最上位ビットをチェックします。
  • ((x >> 7) & 1) * 0x1b は、最上位ビットが 1 であった場合に x8x^8 の項を除去するために 0x1B を XOR します。
固定値 0x1B が使用される理由

0x1Bが使用される理由は以下の通りです。

既約多項式 m(x)=x8+x4+x3+x+1m(x) = x^8 + x^4 + x^3 + x + 1 は、AES で使用される有限体 GF(28)GF(2^8) を定義するためのものです。この多項式を 2 進数で表すと、各ビットが多項式の係数に対応します。

m(x)=x8+x4+x3+x+1m(x) = x^8 + x^4 + x^3 + x + 1m(x)=1000110112m(x) = 100011011_2

左から右に向かって、x8x^8x7x^7x6x^6x5x^5x4x^4x3x^3x2x^2x1x^1x0x^0 の係数を表しています。

AES では、乗算結果が 8 ビットを超える場合、すなわち多項式の次数が 8 以上になる場合に、この多項式 m(x)m(x) での剰余を取ります。

100011011をそのまま 16 進数に変換した 0x11B を使用しても動作しますが、ここではuint8_tを使用しており 9 ビット目以降は無視されるため、 00011011 を 16 進数で表した 0x1B で XOR を行っています。

1
uint8_t xtime(uint8_t x) { return ((x << 1) ^ (((x >> 7) & 1) * 0x1b)); }
2
3
void mix_columns(state_t* state) {
4
uint8_t i;
5
uint8_t tmp, tm, t;
6
for (i = 0; i < 4; ++i) {
7
t = (*state)[i][0];
8
tmp = (*state)[i][0] ^ (*state)[i][1] ^ (*state)[i][2] ^ (*state)[i][3];
9
tm = (*state)[i][0] ^ (*state)[i][1];
10
tm = xtime(tm);
11
(*state)[i][0] ^= tm ^ tmp;
12
tm = (*state)[i][1] ^ (*state)[i][2];
13
tm = xtime(tm);
14
(*state)[i][1] ^= tm ^ tmp;
15
tm = (*state)[i][2] ^ (*state)[i][3];
16
tm = xtime(tm);
17
(*state)[i][2] ^= tm ^ tmp;
18
tm = (*state)[i][3] ^ t;
19
tm = xtime(tm);
20
(*state)[i][3] ^= tm ^ tmp;
21
}
22
}

以下の式の一行目の行列計算を例にして、mix_columns 関数の動作を説明します。

(s0,is1,is2,is3,i)(xx+1111xx+1111xx+1x+111x)(s0,is1,is2,is3,i)\begin{pmatrix} s'_{0,i} \\ s'_{1,i} \\ s'_{2,i} \\ s'_{3,i} \end{pmatrix} \leftarrow \begin{pmatrix} x & x+1 & 1 & 1 \\ 1 & x & x+1 & 1 \\ 1 & 1 & x & x+1 \\ x+1 & 1 & 1 & x \\ \end{pmatrix} \begin{pmatrix} s_{0,i} \\ s_{1,i} \\ s_{2,i} \\ s_{3,i} \end{pmatrix} s0,i=(xs0,i)((x+1)s1,i)(1s2,i)(1s3,i)s'_{0,i} = (x \cdot s_{0,i}) \oplus ((x + 1) \cdot s_{1,i}) \oplus (1 \cdot s_{2,i}) \oplus (1 \cdot s_{3,i})
  1. (x+1)s1,i(x + 1) \cdot s_{1,i} の展開

    (x+1)s1,i=(xs1,i)s1,i(x + 1) \cdot s_{1,i} = (x \cdot s_{1,i}) \oplus s_{1,i}

    この部分を元の式に代入します。

  2. 全体の式の展開

    s0,i=(xs0,i)(xs1,i)s1,is2,is3,i=(x(s0,is1,i))s1,is2,is3,i=(x(s0,is1,i))(s0,is0,i)s1,is2,is3,i=s0,i(x(s0,is1,i))(s0,is1,is2,is3,i)=s0,itmtmp\begin{align*} s'_{0,i} &= (x \cdot s_{0,i}) \oplus (x \cdot s_{1,i}) \oplus s_{1,i} \oplus s_{2,i} \oplus s_{3,i} \\ &= (x \cdot (s_{0,i} \oplus s_{1,i})) \oplus s_{1,i} \oplus s_{2,i} \oplus s_{3,i} \\ &= (x \cdot (s_{0,i} \oplus s_{1,i})) \oplus (s_{0,i} \oplus s_{0,i}) \oplus s_{1,i} \oplus s_{2,i} \oplus s_{3,i} \\ &= s_{0,i} \oplus \left( x \cdot (s_{0,i} \oplus s_{1,i}) \right) \oplus \left( s_{0,i} \oplus s_{1,i} \oplus s_{2,i} \oplus s_{3,i} \right) \\ &= s_{0,i} \oplus tm \oplus tmp \\ \end{align*}

これが、以下のコードに対応します。

8
tmp = (*state)[i][0] ^ (*state)[i][1] ^ (*state)[i][2] ^ (*state)[i][3];
9
tm = (*state)[i][0] ^ (*state)[i][1];
10
tm = xtime(tm);
11
(*state)[i][0] ^= tm ^ tmp;

AddRoundKey

AddRoundKey 処理は、状態行列 state とラウンド鍵 round_key の XOR 演算を行う処理です。

1
void add_round_key(uint8_t round, state_t* state, const uint8_t* round_key) {
2
uint8_t i, j;
3
for (i = 0; i < 4; ++i) {
4
for (j = 0; j < 4; ++j) {
5
(*state)[i][j] ^= round_key[(round * Nb * 4) + (i * Nb) + j];
6
}
7
}
8
}

コード

cipher.c
1
#include "cipher.h"
2
3
#define Nb 4
4
#define Nr 10
5
6
extern uint8_t sbox[256];
7
8
#define get_sbox_value(num) (sbox[(num)])
9
10
void add_round_key(uint8_t round, state_t* state, const uint8_t* round_key) {
11
uint8_t i, j;
12
for (i = 0; i < 4; ++i) {
13
for (j = 0; j < 4; ++j) {
14
(*state)[i][j] ^= round_key[(round * Nb * 4) + (i * Nb) + j];
15
}
16
}
17
}
18
19
void sub_bytes(state_t* state) {
20
uint8_t i, j;
21
for (i = 0; i < 4; ++i) {
22
for (j = 0; j < 4; ++j) {
23
(*state)[j][i] = get_sbox_value((*state)[j][i]);
24
}
25
}
26
}
27
28
void shift_rows(state_t* state) {
29
uint8_t temp;
30
31
temp = (*state)[0][1];
32
(*state)[0][1] = (*state)[1][1];
33
(*state)[1][1] = (*state)[2][1];
34
(*state)[2][1] = (*state)[3][1];
35
(*state)[3][1] = temp;
36
37
temp = (*state)[0][2];
38
(*state)[0][2] = (*state)[2][2];
39
(*state)[2][2] = temp;
40
41
temp = (*state)[1][2];
42
(*state)[1][2] = (*state)[3][2];
43
(*state)[3][2] = temp;
44
45
temp = (*state)[0][3];
46
(*state)[0][3] = (*state)[3][3];
47
(*state)[3][3] = (*state)[2][3];
48
(*state)[2][3] = (*state)[1][3];
49
(*state)[1][3] = temp;
50
}
51
52
uint8_t xtime(uint8_t x) { return ((x << 1) ^ (((x >> 7) & 1) * 0x1b)); }
53
54
void mix_columns(state_t* state) {
55
uint8_t i;
56
uint8_t tmp, tm, t;
57
for (i = 0; i < 4; ++i) {
58
t = (*state)[i][0];
59
tmp = (*state)[i][0] ^ (*state)[i][1] ^ (*state)[i][2] ^ (*state)[i][3];
60
tm = (*state)[i][0] ^ (*state)[i][1];
61
tm = xtime(tm);
62
(*state)[i][0] ^= tm ^ tmp;
63
tm = (*state)[i][1] ^ (*state)[i][2];
64
tm = xtime(tm);
65
(*state)[i][1] ^= tm ^ tmp;
66
tm = (*state)[i][2] ^ (*state)[i][3];
67
tm = xtime(tm);
68
(*state)[i][2] ^= tm ^ tmp;
69
tm = (*state)[i][3] ^ t;
70
tm = xtime(tm);
71
(*state)[i][3] ^= tm ^ tmp;
72
}
73
}
74
75
void cipher(state_t* state, const uint8_t* round_key) {
76
uint8_t round = 0;
77
78
add_round_key(0, state, round_key);
79
80
for (round = 1;; ++round) {
81
sub_bytes(state);
82
shift_rows(state);
83
if (round == Nr) {
84
break;
85
}
86
mix_columns(state);
87
add_round_key(round, state, round_key);
88
}
89
90
add_round_key(Nr, state, round_key);
91
}
cipher.h
1
#ifndef _CIPHER_H_
2
#define _CIPHER_H_
3
4
#include <stdint.h>
5
6
typedef uint8_t state_t[4][4];
7
void cipher(state_t* state, const uint8_t* round_key);
8
9
#endif // _CIPHER_H_
test_cipher.h
1
#include <stdio.h>
2
#include <string.h>
3
4
#include "cipher.h"
5
#include "key_expansion.h"
6
#include "rcon.h"
7
#include "sbox.h"
8
9
#define Nb 4
10
#define Nk 4
11
#define Nr 10
12
13
extern uint8_t sbox[256];
14
extern int rcon[Nr + 1];
15
16
static int test_cipher(void);
17
18
int main(void) {
19
int exit;
20
21
int num_rounds = Nr + 1;
22
initialize_aes_sbox(sbox);
23
calculate_rcon(num_rounds, rcon);
24
25
exit = test_cipher();
26
27
return exit;
28
}
29
30
static int test_cipher(void) {
31
uint8_t round_key[176];
32
uint8_t key[16] = {0x2B, 0x7E, 0x15, 0x16, 0x28, 0xAE, 0xD2, 0xA6,
33
0xAB, 0xF7, 0x15, 0x88, 0x09, 0xCF, 0x4F, 0x3C};
34
35
key_expansion(round_key, key);
36
37
uint8_t encrypted[] = {0x3A, 0xD7, 0x7B, 0xB4, 0x0D, 0x7A, 0x36, 0x60,
38
0xA8, 0x9E, 0xCA, 0xF3, 0x24, 0x66, 0xEF, 0x97};
39
uint8_t in[] = {0x6B, 0xC1, 0xBE, 0xE2, 0x2E, 0x40, 0x9F, 0x96,
40
0xE9, 0x3D, 0x7E, 0x11, 0x73, 0x93, 0x17, 0x2A};
41
42
cipher((state_t*)in, round_key);
43
44
printf("AES Cipher test: ");
45
46
if (0 == memcmp((char*)in, (char*)encrypted, sizeof encrypted)) {
47
printf("SUCCESS!\n");
48
return (0);
49
} else {
50
printf("FAILURE!\n");
51
return (1);
52
}
53
}
その他のソースコード
key_expansion.c
1
#include "key_expansion.h"
2
3
#define Nb 4
4
#define Nk 4
5
#define Nr 10
6
7
extern uint8_t sbox[256];
8
extern int rcon[Nr + 1];
9
10
#define get_sbox_value(num) (sbox[(num)])
11
12
void key_expansion(uint8_t* round_key, const uint8_t* key) {
13
unsigned i, j, k;
14
uint8_t tempa[4];
15
16
for (i = 0; i < Nk; ++i) {
17
round_key[(i * 4) + 0] = key[(i * 4) + 0];
18
round_key[(i * 4) + 1] = key[(i * 4) + 1];
19
round_key[(i * 4) + 2] = key[(i * 4) + 2];
20
round_key[(i * 4) + 3] = key[(i * 4) + 3];
21
}
22
23
for (i = Nk; i < Nb * (Nr + 1); ++i) {
24
{
25
k = (i - 1) * 4;
26
tempa[0] = round_key[k + 0];
27
tempa[1] = round_key[k + 1];
28
tempa[2] = round_key[k + 2];
29
tempa[3] = round_key[k + 3];
30
}
31
32
if (i % Nk == 0) {
33
// Function RotWord()
34
{
35
const uint8_t u8tmp = tempa[0];
36
tempa[0] = tempa[1];
37
tempa[1] = tempa[2];
38
tempa[2] = tempa[3];
39
tempa[3] = u8tmp;
40
}
41
42
// Function Subword()
43
{
44
tempa[0] = get_sbox_value(tempa[0]);
45
tempa[1] = get_sbox_value(tempa[1]);
46
tempa[2] = get_sbox_value(tempa[2]);
47
tempa[3] = get_sbox_value(tempa[3]);
48
}
49
50
tempa[0] = tempa[0] ^ rcon[i / Nk];
51
}
52
53
j = i * 4;
54
k = (i - Nk) * 4;
55
round_key[j + 0] = round_key[k + 0] ^ tempa[0];
56
round_key[j + 1] = round_key[k + 1] ^ tempa[1];
57
round_key[j + 2] = round_key[k + 2] ^ tempa[2];
58
round_key[j + 3] = round_key[k + 3] ^ tempa[3];
59
}
60
}
key_expansion.h
1
#ifndef _KEY_EXPANSION_H_
2
#define _KEY_EXPANSION_H_
3
4
#include <stdint.h>
5
6
void key_expansion(uint8_t* round_key, const uint8_t* key);
7
8
#endif // _KEY_EXPANSION_H_
sbox.c
1
#include "sbox.h"
2
3
uint8_t sbox[256];
4
uint8_t inv_sbox[256];
5
6
#define ROTL8(x, shift) ((uint8_t)((x) << (shift)) | ((x) >> (8 - (shift))))
7
8
void initialize_aes_sbox(uint8_t sbox[256]) {
9
uint8_t p = 1, q = 1;
10
11
do {
12
p = p ^ (p << 1) ^ (p & 0x80 ? 0x1B : 0);
13
14
q ^= q << 1;
15
q ^= q << 2;
16
q ^= q << 4;
17
q ^= q & 0x80 ? 0x09 : 0;
18
19
uint8_t xformed =
20
q ^ ROTL8(q, 1) ^ ROTL8(q, 2) ^ ROTL8(q, 3) ^ ROTL8(q, 4);
21
22
sbox[p] = xformed ^ 0x63;
23
} while (p != 1);
24
25
sbox[0] = 0x63;
26
}
27
28
void initialize_inverse_aes_sbox(uint8_t inv_sbox[256],
29
const uint8_t sbox[256]) {
30
for (int i = 0; i < 256; i++) {
31
inv_sbox[sbox[i]] = i;
32
}
33
}
sbox.h
1
#ifndef _SBOX_H_
2
#define _SBOX_H_
3
4
#include <stdint.h>
5
6
void initialize_aes_sbox(uint8_t sbox[256]);
7
void initialize_inverse_aes_sbox(uint8_t inv_sbox[256],
8
const uint8_t sbox[256]);
9
10
#endif // _SBOX_H_
rcon.c
1
#define Nr 10
2
3
int rcon[Nr + 1];
4
5
void calculate_rcon(int num_rounds, int* rcon) {
6
rcon[0] = 0x8d;
7
for (int i = 1; i < num_rounds; ++i) {
8
rcon[i] = rcon[i - 1] << 1;
9
if (rcon[i] & 0x100) {
10
rcon[i] ^= 0x11B;
11
}
12
}
13
}
rcon.h
1
#ifndef _RCON_H_
2
#define _RCON_H_
3
4
void calculate_rcon(int num_rounds, int* rcon);
5
6
#endif // _RCON_H_

実行

Terminal window
gcc cipher.c test_cipher.c key_expansion.c sbox.c rcon.c && ./a.out
実行結果
AES Cipher test: SUCCESS!