Implementing AES in Go
You are not expected to understand this.
—Comment in Unix kernel source, quoted in “Lions’ Commentary on Unix”
AES is a powerful encryption algorithm that scrambles your text upside down, backwards, sideways, and inside out. It uses a key to encrypt the plaintext by applying two different but complementary kinds of change: confusion (substituting one byte for another) and diffusion (shuffling bytes around within a block).
We talked about how AES works, at a functional level, in a previous post. And while the specific scheme used by the Rijndael cipher at the heart of AES is complicated in principle, in practice it can be implemented quickly and easily using pre-generated lookup tables.
So easily, in fact, that we’ll be able to read and understand the core AES encryption function in Go’s standard crypto/aes package. Let’s see how the scrambling magic happens.
Getting ready to rumble
Here’s the beginning of the encryptBlockGo function:
// Encrypt one block from src into dst, using the expanded key xk.
func encryptBlockGo(xk []uint32, dst, src []byte) {
	_ = src[15] // early bounds check
	s0 := binary.BigEndian.Uint32(src[0:4])
	s1 := binary.BigEndian.Uint32(src[4:8])
	s2 := binary.BigEndian.Uint32(src[8:12])
	s3 := binary.BigEndian.Uint32(src[12:16])
This is the function that encrypts an individual block, or fixed-size chunk of the input. It takes a block’s worth of plaintext bytes (src) and enciphers them into the destination block (dst), using the set of expanded round keys xk.
Since AES blocks are always 16 bytes, the first thing we do is a quick sanity check: we reference src[15]. If src is too short, this will panic, which was bound to happen sooner or later anyway. We’re just getting it out of the way without wasting any further computation.
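Here’s a small sketch of the same idiom in isolation. The lastByteChecked helper is hypothetical and not part of crypto/aes. It touches index 15 up front and uses recover to report whether that panicked. Besides failing fast, this pattern can also let the Go compiler prove that the later indexes are in range and drop their bounds checks.

```go
package main

import "fmt"

// lastByteChecked is a hypothetical helper mimicking the "_ = src[15]"
// idiom from encryptBlockGo. It reports false if src is too short to hold
// a 16-byte block, by recovering from the index-out-of-range panic.
func lastByteChecked(src []byte) (ok bool) {
	defer func() {
		if recover() != nil {
			ok = false
		}
	}()
	_ = src[15] // panics if len(src) < 16, before any real work is done
	return true
}

func main() {
	fmt.Println(lastByteChecked(make([]byte, 16))) // prints: true
	fmt.Println(lastByteChecked(make([]byte, 15))) // prints: false
}
```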
Then we extract the four 4-byte columns of our grid, s0 through s3, from the source block. Each column is read in big-endian order, so its first byte lands in the most significant bits. These values will be modified by the various rounds as we go along. For convenience, we treat each 4-byte chunk as a single Go uint32 value.
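Here’s a small sketch of that conversion on its own. It assumes a 16-byte block and uses the same encoding/binary calls. The columns and uncolumns helpers are hypothetical names for illustration, and uncolumns mirrors the PutUint32 calls at the end of the function. The input is the plaintext block from the FIPS-197 AES-128 example.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// columns splits a 16-byte block into four big-endian uint32 words, the way
// encryptBlockGo loads s0..s3: each word is one column of the 4x4 state,
// with the column's top byte in the most significant bits.
func columns(block []byte) [4]uint32 {
	_ = block[15] // early bounds check, as in encryptBlockGo
	var s [4]uint32
	for i := range s {
		s[i] = binary.BigEndian.Uint32(block[4*i : 4*i+4])
	}
	return s
}

// uncolumns reverses columns, writing each word back as four bytes.
func uncolumns(s [4]uint32) []byte {
	out := make([]byte, 16)
	for i, w := range s {
		binary.BigEndian.PutUint32(out[4*i:4*i+4], w)
	}
	return out
}

func main() {
	block := []byte{
		0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
		0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
	}
	s := columns(block)
	fmt.Printf("%08x %08x %08x %08x\n", s[0], s[1], s[2], s[3])
	// prints: 00112233 44556677 8899aabb ccddeeff
	fmt.Printf("%x\n", uncolumns(s)) // prints: 00112233445566778899aabbccddeeff
}
```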
Ding ding, round one
We’re now ready to start the first round, which, as you may recall from Part 1, consists only of AddRoundKey:
	// First round just XORs input with key.
	s0 ^= xk[0]
	s1 ^= xk[1]
	s2 ^= xk[2]
	s3 ^= xk[3]
Recall that xk contains the set of round keys that were expanded from the original key. Each round key is 128 bits, or 16 bytes, so it takes up four uint32 elements in the xk slice. In other words, the first four elements of xk are the first round key.
The XOR operation is represented in Go by ^, so this code updates each chunk of data by XORing its bits with the corresponding bits of the round key.
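Here’s a sketch of AddRoundKey as a standalone function. It assumes the xk layout described above, where round key i occupies xk[4*i : 4*i+4]. The name addRoundKey is hypothetical. For AES-128, the first round key is simply the cipher key itself, so the values below are the key and plaintext from the FIPS-197 example. XOR is its own inverse, so applying the same round key twice restores the original state. That is one reason this step is trivial to undo when decrypting.

```go
package main

import "fmt"

// addRoundKey XORs the four state words with round key number round,
// which occupies xk[4*round : 4*round+4] in the expanded key.
func addRoundKey(s *[4]uint32, xk []uint32, round int) {
	rk := xk[4*round : 4*round+4]
	for i := range s {
		s[i] ^= rk[i]
	}
}

func main() {
	// Only the first round key is included here: for AES-128 it equals the key.
	xk := []uint32{0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f}
	s := [4]uint32{0x00112233, 0x44556677, 0x8899aabb, 0xccddeeff}

	addRoundKey(&s, xk, 0)
	fmt.Printf("%08x\n", s[0]) // prints: 00102030

	// Applying the same round key again restores the input.
	addRoundKey(&s, xk, 0)
	fmt.Printf("%08x\n", s[0]) // prints: 00112233
}
```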
The diffusion loop
Next, since all the remaining rounds except the last are identical, we can do them in a loop. But the number of rounds depends on the key size, so we first calculate it from the length of xk:
	// Middle rounds shuffle using tables.
	// Number of rounds is set by length of expanded key.
	nr := len(xk)/4 - 2 // - 2: one above, one more below
In other words, since each group of 4 elements of xk represents a single round key, we can find out how many rounds there are in total by dividing the length of xk by 4. Since the first and last rounds are special, we won’t include them in this loop, hence the need to subtract 2.
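Here’s that formula worked through for each key size. This sketch assumes the FIPS-197 round counts of 10, 12, and 14 rounds for 128-, 192-, and 256-bit keys. Each round needs one four-word round key, plus one extra for the initial AddRoundKey. The function expandedKeyWords is a hypothetical helper.

```go
package main

import "fmt"

// expandedKeyWords returns how many uint32 words an AES expanded key holds
// for a key of keyBytes bytes (16, 24 or 32): four words per round key,
// and Nr+1 round keys, where Nr is 10, 12 or 14 rounds.
func expandedKeyWords(keyBytes int) int {
	nr := keyBytes/4 + 6 // 10, 12 or 14 rounds
	return 4 * (nr + 1)
}

func main() {
	for _, keyBytes := range []int{16, 24, 32} {
		xkLen := expandedKeyWords(keyBytes)
		middle := xkLen/4 - 2 // same formula as nr in encryptBlockGo
		fmt.Printf("AES-%d: len(xk)=%d, loop iterations=%d\n", keyBytes*8, xkLen, middle)
	}
	// prints:
	// AES-128: len(xk)=44, loop iterations=9
	// AES-192: len(xk)=52, loop iterations=11
	// AES-256: len(xk)=60, loop iterations=13
}
```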
Here’s the loop, then:
	k := 4
	var t0, t1, t2, t3 uint32
	for r := 0; r < nr; r++ {
		t0 = xk[k+0] ^ te0[uint8(s0>>24)] ^ te1[uint8(s1>>16)] ^
			te2[uint8(s2>>8)] ^ te3[uint8(s3)]
		t1 = xk[k+1] ^ te0[uint8(s1>>24)] ^ te1[uint8(s2>>16)] ^
			te2[uint8(s3>>8)] ^ te3[uint8(s0)]
		t2 = xk[k+2] ^ te0[uint8(s2>>24)] ^ te1[uint8(s3>>16)] ^
			te2[uint8(s0>>8)] ^ te3[uint8(s1)]
		t3 = xk[k+3] ^ te0[uint8(s3>>24)] ^ te1[uint8(s0>>16)] ^
			te2[uint8(s1>>8)] ^ te3[uint8(s2)]
		k += 4
		s0, s1, s2, s3 = t0, t1, t2, t3
	}
Yes, it’s a little dense, but it’s written for efficiency, not readability. It combines SubBytes, ShiftRows, MixColumns, and AddRoundKey into a single operation on each of our four chunks, s0 through s3.
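We implement SubBytes and MixColumns together by looking up each data byte in the pre-generated tables te0 through te3. Each table entry is a 32-bit word holding that byte’s S-box substitution, already multiplied by the MixColumns coefficients for its row. The bit-shifts (s0>>24, s1>>16, and so on) pick out individual bytes, and the careful choice of which chunk each byte comes from implements ShiftRows. Finally, XORing in xk[k+0] through xk[k+3] adds the round key to each of the resulting chunks.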
Fortunately, you don’t need to understand every detail here to get the general idea: it’s a bunch of lookups and bit-shifts. Indeed, we could have written out each of the round steps separately, to make it clearer what’s happening at each stage, but the standard library authors understandably aren’t concerned with that.
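For a sense of what writing a step out separately would look like, here is MixColumns on its own for a single column. This is a hypothetical, readable sketch, not anything from crypto/aes. xtime is the FIPS-197 name for multiplying by 2 in GF(2^8). The input column db 13 53 45 is a widely used MixColumns test vector.

```go
package main

import "fmt"

// xtime multiplies a byte by 2 in GF(2^8), reducing by the AES polynomial.
func xtime(b byte) byte {
	if b&0x80 != 0 {
		return b<<1 ^ 0x1b
	}
	return b << 1
}

// mixColumn applies the AES MixColumns matrix to one 4-byte column:
//
//	[2 3 1 1]
//	[1 2 3 1]
//	[1 1 2 3]
//	[3 1 1 2]
//
// Multiplying by 3 is computed as xtime(b) ^ b.
func mixColumn(c [4]byte) [4]byte {
	a0, a1, a2, a3 := c[0], c[1], c[2], c[3]
	return [4]byte{
		xtime(a0) ^ (xtime(a1) ^ a1) ^ a2 ^ a3,
		a0 ^ xtime(a1) ^ (xtime(a2) ^ a2) ^ a3,
		a0 ^ a1 ^ xtime(a2) ^ (xtime(a3) ^ a3),
		(xtime(a0) ^ a0) ^ a1 ^ a2 ^ xtime(a3),
	}
}

func main() {
	out := mixColumn([4]byte{0xdb, 0x13, 0x53, 0x45})
	fmt.Printf("% x\n", out[:]) // prints: 8e 4d a1 bc
}
```

The te tables precompute exactly this kind of multiplication, combined with the S-box, for every possible byte value, which is why the loop above needs only lookups and XORs.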
In my book Explore Go: Cryptography, though, we’ll take a deep dive into the history and technology of cryptography, working our way up from simple shift ciphers to modern cryptosystems like AES. By building some encryption and decryption projects in Go, you’ll get a solid, intuitive understanding of what makes ciphers secure (or not), and how to use them as a programmer. Do check out the book if you like this kind of thing; I think you’ll enjoy it.
The last round
Okay, back to the function. Here’s the last round:
	// Last round uses s-box directly and XORs to produce output.
	s0 = uint32(sbox0[t0>>24])<<24 | uint32(sbox0[t1>>16&0xff])<<16 |
		uint32(sbox0[t2>>8&0xff])<<8 | uint32(sbox0[t3&0xff])
	s1 = uint32(sbox0[t1>>24])<<24 | uint32(sbox0[t2>>16&0xff])<<16 |
		uint32(sbox0[t3>>8&0xff])<<8 | uint32(sbox0[t0&0xff])
	s2 = uint32(sbox0[t2>>24])<<24 | uint32(sbox0[t3>>16&0xff])<<16 |
		uint32(sbox0[t0>>8&0xff])<<8 | uint32(sbox0[t1&0xff])
	s3 = uint32(sbox0[t3>>24])<<24 | uint32(sbox0[t0>>16&0xff])<<16 |
		uint32(sbox0[t1>>8&0xff])<<8 | uint32(sbox0[t2&0xff])

	s0 ^= xk[k+0]
	s1 ^= xk[k+1]
	s2 ^= xk[k+2]
	s3 ^= xk[k+3]
In the middle rounds, the S-box substitutions were already factored into the pre-generated lookup tables, which also included the effect of MixColumns. The last round skips MixColumns, so here we can refer directly to the sbox0 table, which is just a fixed lookup table of byte substitutions.
Having thoroughly munged our data chunks s0 through s3, we finally write them back to the output block dst:
	_ = dst[15] // early bounds check
	binary.BigEndian.PutUint32(dst[0:4], s0)
	binary.BigEndian.PutUint32(dst[4:8], s1)
	binary.BigEndian.PutUint32(dst[8:12], s2)
	binary.BigEndian.PutUint32(dst[12:16], s3)
}
Again, we check that dst is long enough before writing the enciphered block to it. So, that’s it: the core of AES in just 45 lines of Go (if you don’t count the constant lookup tables). Not bad!
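Application code doesn’t call encryptBlockGo directly; it goes through the exported crypto/aes API. Here’s a minimal sketch using aes.NewCipher and the Encrypt method of the returned cipher.Block, checked against the AES-128 example vector from FIPS-197 Appendix C.1. Depending on the platform and Go version, Encrypt may use assembly with hardware AES instructions rather than the generic Go code shown above, but the output is the same. encryptBlockHex is a hypothetical wrapper for this demo. Encrypting a single block like this is only an illustration; real programs need a mode of operation on top.

```go
package main

import (
	"crypto/aes"
	"encoding/hex"
	"fmt"
)

// encryptBlockHex encrypts exactly one 16-byte block with AES, taking the
// key and plaintext as hex strings and returning the ciphertext as hex.
func encryptBlockHex(keyHex, plaintextHex string) (string, error) {
	key, err := hex.DecodeString(keyHex)
	if err != nil {
		return "", err
	}
	src, err := hex.DecodeString(plaintextHex)
	if err != nil {
		return "", err
	}
	block, err := aes.NewCipher(key) // key must be 16, 24 or 32 bytes
	if err != nil {
		return "", err
	}
	if len(src) != aes.BlockSize {
		return "", fmt.Errorf("plaintext must be %d bytes, got %d", aes.BlockSize, len(src))
	}
	dst := make([]byte, aes.BlockSize)
	block.Encrypt(dst, src)
	return hex.EncodeToString(dst), nil
}

func main() {
	ct, err := encryptBlockHex(
		"000102030405060708090a0b0c0d0e0f", // FIPS-197 C.1 key
		"00112233445566778899aabbccddeeff", // FIPS-197 C.1 plaintext
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(ct) // prints: 69c4e0d86a7b0430d8cdb78070b4c55a
}
```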
In the next post, Encrypting with AES, we’ll put this code to use in some simple Go programs. See you over there.