Implementing AES in Go

You are not expected to understand this.
—Comment in Unix kernel source, quoted in “Lions’ Commentary on Unix”

AES is a powerful encryption algorithm that scrambles your text upside down, backwards, sideways, and inside out. It uses a key to encrypt the plaintext using two different, but complementary, kinds of changes: confusion (substituting one byte for another), and diffusion (shuffling bytes around within a block).

We talked about how AES works, at a functional level, in a previous post. And while the specific scheme used by the Rijndael cipher at the heart of AES is complicated in principle, in practice it can be implemented quickly and easily using pre-generated lookup tables.

So easily, in fact, that we’ll be able to read and understand the core AES encryption function in Go’s standard crypto/aes package. Let’s see how the scrambling magic happens.

Getting ready to rumble

Here’s the beginning of the encryptBlockGo function:

// Encrypt one block from src into dst, using the expanded key xk.
func encryptBlockGo(xk []uint32, dst, src []byte) {
    _ = src[15] // early bounds check
    s0 := binary.BigEndian.Uint32(src[0:4])
    s1 := binary.BigEndian.Uint32(src[4:8])
    s2 := binary.BigEndian.Uint32(src[8:12])
    s3 := binary.BigEndian.Uint32(src[12:16])

(crypto/aes/block.go)

This is the function that encrypts an individual block, or fixed-size chunk of the input. It takes a block’s worth of plaintext bytes (src), and enciphers it into the destination block (dst), using the set of expanded round keys xk.

Since AES blocks are always 16 bytes in size, the first thing we do is a quick sanity check: we reference src[15]. If src is too short, this will panic, which was bound to happen sooner or later anyway. We’re just getting it out of the way without wasting any further computation, and a single check up front also lets the compiler skip the bounds checks on the slice expressions that follow.

Then, we extract each of the four 4-byte columns of our grid, s0 through s3, from the source block. These will be modified by the various rounds as we go along. For convenience, we’re dealing with each 4-byte chunk as a single Go uint32 value.
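To see what those four values look like, here’s a minimal sketch that does the same extraction on a sample block. The helper name loadColumns is my own invention for illustration; the standard library does this inline.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// loadColumns (a hypothetical helper) splits a 16-byte block into four
// big-endian 32-bit words, the same way encryptBlockGo does inline.
func loadColumns(src []byte) (s0, s1, s2, s3 uint32) {
	_ = src[15] // panics early if src is shorter than 16 bytes
	s0 = binary.BigEndian.Uint32(src[0:4])
	s1 = binary.BigEndian.Uint32(src[4:8])
	s2 = binary.BigEndian.Uint32(src[8:12])
	s3 = binary.BigEndian.Uint32(src[12:16])
	return s0, s1, s2, s3
}

func main() {
	src := []byte{
		0x00, 0x11, 0x22, 0x33,
		0x44, 0x55, 0x66, 0x77,
		0x88, 0x99, 0xaa, 0xbb,
		0xcc, 0xdd, 0xee, 0xff,
	}
	s0, s1, s2, s3 := loadColumns(src)
	fmt.Printf("%08x %08x %08x %08x\n", s0, s1, s2, s3)
	// prints: 00112233 44556677 8899aabb ccddeeff
}
```

Big-endian means the first byte of each chunk ends up in the most significant position of the uint32, which is why the printed words read in the same order as the input bytes.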

Ding ding, round one

We’re now ready to start the first round, which as you may recall from Part 1, consists only of AddRoundKey:

    // First round just XORs input with key.
    s0 ^= xk[0]
    s1 ^= xk[1]
    s2 ^= xk[2]
    s3 ^= xk[3]

(crypto/aes/block.go)

Recall that xk contains the set of round keys that was expanded from the original key. Each round key is 128 bits, or 16 bytes, so it takes up four uint32 elements in the xk slice. In other words, the first four elements of xk are the first round key.

The XOR operation is represented in Go by ^, so this code updates each chunk of data by XORing its bits with the corresponding bits of the round key.
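As a quick illustration, here’s a small sketch of that step on its own. The addRoundKey function and the sample values are assumptions for the example, not code from the standard library. It also shows a handy property of XOR: applying the same key twice gets you back where you started.

```go
package main

import "fmt"

// addRoundKey (a hypothetical helper) XORs each state word with the
// matching round-key word. Arrays are passed by value in Go, so the
// caller's copy of s is left untouched.
func addRoundKey(s, rk [4]uint32) [4]uint32 {
	for i := range s {
		s[i] ^= rk[i]
	}
	return s
}

func main() {
	state := [4]uint32{0x00112233, 0x44556677, 0x8899aabb, 0xccddeeff}
	key := [4]uint32{0x00010203, 0x04050607, 0x08090a0b, 0x0c0d0e0f}

	mixed := addRoundKey(state, key)
	fmt.Printf("%08x\n", mixed[0]) // prints: 00102030

	// XORing with the same key again undoes the operation.
	restored := addRoundKey(mixed, key)
	fmt.Println(restored == state) // prints: true
}
```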

The diffusion loop

Next, since all the remaining rounds but the last will be the same, we can do them in a loop. But the number of rounds varies depending on the key size being used, so we first do a calculation based on the size of xk:

    // Middle rounds shuffle using tables.
    // Number of rounds is set by length of expanded key.
    nr := len(xk)/4 - 2 // - 2: one above, one more below

(crypto/aes/block.go)

In other words, since each group of 4 elements of xk represents a single round key, dividing the length of xk by 4 tells us how many round keys there are: 11, 13, or 15 for 128-, 192-, or 256-bit keys respectively. Each round uses up one round key. Since the first and last rounds are special, we won’t include them in this loop, hence the need to subtract 2.
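Here’s a rough sketch of that arithmetic for the three AES key sizes. Both function names are made up for this example, and expandedKeyWords is just my restatement of the key schedule sizes in FIPS 197, not the library’s own code.

```go
package main

import "fmt"

// expandedKeyWords (hypothetical) returns how many uint32 words the
// expanded key holds for a key of keyBytes bytes. AES uses
// keyBytes/4 + 6 rounds, and needs one more round key than that,
// at 4 words per round key.
func expandedKeyWords(keyBytes int) int {
	rounds := keyBytes/4 + 6 // 10, 12, or 14
	return 4 * (rounds + 1)
}

// middleRounds (hypothetical) mirrors the nr calculation in
// encryptBlockGo: round keys in total, minus the first and the last.
func middleRounds(xkLen int) int {
	return xkLen/4 - 2
}

func main() {
	for _, keyBytes := range []int{16, 24, 32} {
		n := expandedKeyWords(keyBytes)
		fmt.Printf("%d-bit key: len(xk)=%d, nr=%d\n",
			keyBytes*8, n, middleRounds(n))
	}
	// prints:
	// 128-bit key: len(xk)=44, nr=9
	// 192-bit key: len(xk)=52, nr=11
	// 256-bit key: len(xk)=60, nr=13
}
```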

Here’s the loop, then:

    k := 4
    var t0, t1, t2, t3 uint32
    for r := 0; r < nr; r++ {
        t0 = xk[k+0] ^ te0[uint8(s0>>24)] ^ te1[uint8(s1>>16)] ^
            te2[uint8(s2>>8)] ^ te3[uint8(s3)]
        t1 = xk[k+1] ^ te0[uint8(s1>>24)] ^ te1[uint8(s2>>16)] ^
            te2[uint8(s3>>8)] ^ te3[uint8(s0)]
        t2 = xk[k+2] ^ te0[uint8(s2>>24)] ^ te1[uint8(s3>>16)] ^
            te2[uint8(s0>>8)] ^ te3[uint8(s1)]
        t3 = xk[k+3] ^ te0[uint8(s3>>24)] ^ te1[uint8(s0>>16)] ^
            te2[uint8(s1>>8)] ^ te3[uint8(s2)]
        k += 4
        s0, s1, s2, s3 = t0, t1, t2, t3
    }

(crypto/aes/block.go)

Yes, it’s a little dense, but it’s written for efficiency, not readability. It combines SubBytes, ShiftRows, MixColumns, and AddRoundKey in a single operation, on each of our four chunks s0 through s3.

We implement SubBytes by using each data byte as an index into the pre-generated lookup tables te0 through te3. Each lookup returns a whole 32-bit word, with the S-box substitution and that byte’s MixColumns contribution already baked in. The bit-shifting picks out individual bytes, and taking each byte from a different chunk (s0, s1, s2, s3, rotating as we go) gives us ShiftRows. Finally, we XOR each of the resulting chunks with its corresponding part of the round key.
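To make the table trick a little more concrete, here’s a sketch of how a single table entry could be built from an S-box output. The names xtime and te0Word are my own, and the byte layout (2·s, s, s, 3·s) is my reading of the first table; treat it as an illustration rather than the library’s generator. As far as I can tell, the other three tables hold the same words rotated by one, two, and three bytes.

```go
package main

import "fmt"

// xtime multiplies b by 2 in the AES finite field GF(2^8):
// shift left, then reduce with 0x1b if the top bit was set.
func xtime(b byte) byte {
	if b&0x80 != 0 {
		return b<<1 ^ 0x1b
	}
	return b << 1
}

// te0Word (hypothetical) packs the MixColumns contribution of an
// S-box output s into one word: the bytes 2·s, s, s, 3·s,
// most significant first.
func te0Word(s byte) uint32 {
	s2 := xtime(s) // 2·s
	s3 := s2 ^ s   // 3·s = 2·s XOR s
	return uint32(s2)<<24 | uint32(s)<<16 | uint32(s)<<8 | uint32(s3)
}

func main() {
	// 0x63 is the S-box output for the input byte 0x00.
	fmt.Printf("%08x\n", te0Word(0x63)) // prints: c66363a5
}
```

With the tables built this way, one lookup per byte plus three XORs produces a whole output column, which is why the loop body needs no multiplication at all.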

Fortunately, you don’t need to understand every detail here to get the general idea: it’s a bunch of lookups and bit-shifts. Indeed, we could have written out each of the round steps separately, to make it clearer what’s happening at each stage, but the standard library authors understandably aren’t concerned with that.

In my book Explore Go: Cryptography, though, we’ll take a deep dive into the history and technology of cryptography, working our way up from simple shift ciphers to modern cryptosystems like AES. By building some encryption and decryption projects in Go, you’ll get a solid, intuitive understanding of what makes ciphers secure (or not), and how to use them as a programmer. Do check out the book if you like this kind of thing; I think you’ll enjoy it.

The last round

Okay, back to the function. Here’s the last round:

    // Last round uses s-box directly and XORs to produce output.
    s0 = uint32(sbox0[t0>>24])<<24 | uint32(sbox0[t1>>16&0xff])<<16 |
        uint32(sbox0[t2>>8&0xff])<<8 | uint32(sbox0[t3&0xff])
    s1 = uint32(sbox0[t1>>24])<<24 | uint32(sbox0[t2>>16&0xff])<<16 |
        uint32(sbox0[t3>>8&0xff])<<8 | uint32(sbox0[t0&0xff])
    s2 = uint32(sbox0[t2>>24])<<24 | uint32(sbox0[t3>>16&0xff])<<16 |
        uint32(sbox0[t0>>8&0xff])<<8 | uint32(sbox0[t1&0xff])
    s3 = uint32(sbox0[t3>>24])<<24 | uint32(sbox0[t0>>16&0xff])<<16 |
        uint32(sbox0[t1>>8&0xff])<<8 | uint32(sbox0[t2&0xff])

    s0 ^= xk[k+0]
    s1 ^= xk[k+1]
    s2 ^= xk[k+2]
    s3 ^= xk[k+3]

(crypto/aes/block.go)

In the middle rounds, the S-box substitutions were already factored into the pre-generated lookup tables, which also included the effect of MixColumns. Here, since the last round skips the MixColumns step, we can refer directly to the sbox0 table, which is just a fixed 256-byte S-box. ShiftRows still happens, through the choice of which chunk (t0 through t3) supplies each byte position.
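If expressions like t1>>16&0xff look ambiguous, here’s a small sketch of the byte extraction on its own. The bytesOf helper is hypothetical. In Go, >> and & share the same precedence and group left to right, so the shift happens first and the mask second.

```go
package main

import "fmt"

// bytesOf (a hypothetical helper) returns the four bytes of w, most
// significant first, using the same shift-and-mask expressions as
// the last round.
func bytesOf(w uint32) [4]byte {
	return [4]byte{
		byte(w >> 24),
		byte(w >> 16 & 0xff), // parsed as (w >> 16) & 0xff
		byte(w >> 8 & 0xff),
		byte(w & 0xff),
	}
}

func main() {
	w := uint32(0xdeadbeef)
	b := bytesOf(w)
	fmt.Printf("%02x %02x %02x %02x\n", b[0], b[1], b[2], b[3])
	// prints: de ad be ef

	// The middle rounds get the same byte by converting to uint8,
	// which simply discards the upper bits.
	fmt.Println(uint8(w>>16) == b[1]) // prints: true
}
```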

Having thoroughly munged our data chunks s0 through s3, we finally write them back to the output block dst:

    _ = dst[15] // early bounds check
    binary.BigEndian.PutUint32(dst[0:4], s0)
    binary.BigEndian.PutUint32(dst[4:8], s1)
    binary.BigEndian.PutUint32(dst[8:12], s2)
    binary.BigEndian.PutUint32(dst[12:16], s3)
}

(crypto/aes/block.go)

Again, we check that dst is long enough before writing the enciphered block to it. So, that’s it: the core of AES in just 45 lines of Go (if you don’t count the constant lookup tables). Not bad!
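If you’d like to check a block for yourself, here’s a minimal sketch using the public crypto/aes API with the AES-128 example key and plaintext from Appendix C.1 of FIPS 197. The encryptOneBlock wrapper is my own name. Note the assumption here: encryptBlockGo isn’t exported, and on machines with hardware AES support the package may well use an accelerated path instead, but the output should be the same either way.

```go
package main

import (
	"crypto/aes"
	"encoding/hex"
	"fmt"
	"log"
)

// encryptOneBlock (a hypothetical wrapper) enciphers a single
// 16-byte block with the given key, via the public cipher.Block API.
func encryptOneBlock(key, plaintext []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	dst := make([]byte, aes.BlockSize)
	block.Encrypt(dst, plaintext) // panics if plaintext is too short
	return dst, nil
}

func main() {
	key, err := hex.DecodeString("000102030405060708090a0b0c0d0e0f")
	if err != nil {
		log.Fatal(err)
	}
	plaintext, err := hex.DecodeString("00112233445566778899aabbccddeeff")
	if err != nil {
		log.Fatal(err)
	}
	ciphertext, err := encryptOneBlock(key, plaintext)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(hex.EncodeToString(ciphertext))
	// FIPS 197 lists 69c4e0d86a7b0430d8cdb78070b4c55a for this input.
}
```

Encrypting one block like this is only good for checking the cipher itself; real messages need a proper mode of operation on top.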

In the next post, Encrypting with AES, we’ll talk about putting this code to use in some simple Go programs. See you over there.
