Skip to content

encoding/base64: add constant-time behavior, enabled by default #73909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 43 additions & 22 deletions src/encoding/base64/base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
* Encodings
*/

type MappingFunc func (in uint) byte
// An Encoding is a radix 64 encoding/decoding scheme, defined by a
// 64-character alphabet. The most common encoding is the "base64"
// encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM
Expand All @@ -26,6 +27,8 @@ type Encoding struct {
decodeMap [256]uint8 // mapping of symbol byte value to symbol index
padChar rune
strict bool
encodeMapFunc MappingFunc // optional mapping function to replace table look-ups when encoding
decodeMapFunc MappingFunc // optional mapping function to replace table look-ups when decoding
}

const (
Expand Down Expand Up @@ -83,6 +86,10 @@ func NewEncoding(encoder string) *Encoding {
}
e.decodeMap[encoder[i]] = uint8(i)
}
var encodeFunc = e.encodeMapDefault
e.encodeMapFunc = encodeFunc
var decodeFunc = e.decodeMapDefault
e.decodeMapFunc = decodeFunc
return e
}

Expand All @@ -104,6 +111,17 @@ func (enc Encoding) WithPadding(padding rune) *Encoding {
return &enc
}

// WithDecodeMappingFunc sets the value fo encodeMapFunc
func (enc Encoding) WithDecodeMappingFunc(f MappingFunc) *Encoding {
enc.decodeMapFunc = f
return &enc
}
// WithEncodeMappingFunc sets the value fo encodeMapFunc
func (enc Encoding) WithEncodeMappingFunc(f MappingFunc) *Encoding {
enc.encodeMapFunc = f
return &enc
}

// Strict creates a new encoding identical to enc except with
// strict decoding enabled. In this mode, the decoder requires that
// trailing padding bits are zero, as described in RFC 4648 section 3.5.
Expand All @@ -116,11 +134,15 @@ func (enc Encoding) Strict() *Encoding {
}

// StdEncoding is the standard base64 encoding, as defined in RFC 4648.
var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/").
WithEncodeMappingFunc(StandardBase64Encode).
WithDecodeMappingFunc(StandardBase64Decode)

// URLEncoding is the alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
var URLEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
var URLEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_").
WithEncodeMappingFunc(UrlSafeBase64Encode).
WithDecodeMappingFunc(UrlSafeBase64Decode)

// RawStdEncoding is the standard raw, unpadded base64 encoding,
// as defined in RFC 4648 section 3.2.
Expand Down Expand Up @@ -157,10 +179,10 @@ func (enc *Encoding) Encode(dst, src []byte) {
// Convert 3x 8bit source bytes into 4 bytes
val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])

dst[di+0] = enc.encode[val>>18&0x3F]
dst[di+1] = enc.encode[val>>12&0x3F]
dst[di+2] = enc.encode[val>>6&0x3F]
dst[di+3] = enc.encode[val&0x3F]
dst[di+0] = enc.encodeMapFunc(val>>18&0x3F)
dst[di+1] = enc.encodeMapFunc(val>>12&0x3F)
dst[di+2] = enc.encodeMapFunc(val>>6&0x3F)
dst[di+3] = enc.encodeMapFunc(val&0x3F)

si += 3
di += 4
Expand All @@ -176,8 +198,8 @@ func (enc *Encoding) Encode(dst, src []byte) {
val |= uint(src[si+1]) << 8
}

dst[di+0] = enc.encode[val>>18&0x3F]
dst[di+1] = enc.encode[val>>12&0x3F]
dst[di+0] = enc.encodeMapFunc(val>>18&0x3F)
dst[di+1] = enc.encodeMapFunc(val>>12&0x3F)

switch remain {
case 2:
Expand Down Expand Up @@ -330,8 +352,7 @@ func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err err
}
in := src[si]
si++

out := enc.decodeMap[in]
out := enc.decodeMapFunc(uint(in))
if out != 0xff {
dbuf[j] = out
continue
Expand Down Expand Up @@ -529,14 +550,14 @@ func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
src2 := src[si : si+8]
if dn, ok := assemble64(
enc.decodeMap[src2[0]],
enc.decodeMap[src2[1]],
enc.decodeMap[src2[2]],
enc.decodeMap[src2[3]],
enc.decodeMap[src2[4]],
enc.decodeMap[src2[5]],
enc.decodeMap[src2[6]],
enc.decodeMap[src2[7]],
enc.decodeMapFunc(uint(src2[0])),
enc.decodeMapFunc(uint(src2[1])),
enc.decodeMapFunc(uint(src2[2])),
enc.decodeMapFunc(uint(src2[3])),
enc.decodeMapFunc(uint(src2[4])),
enc.decodeMapFunc(uint(src2[5])),
enc.decodeMapFunc(uint(src2[6])),
enc.decodeMapFunc(uint(src2[7])),
); ok {
byteorder.BEPutUint64(dst[n:], dn)
n += 6
Expand All @@ -554,10 +575,10 @@ func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
for len(src)-si >= 4 && len(dst)-n >= 4 {
src2 := src[si : si+4]
if dn, ok := assemble32(
enc.decodeMap[src2[0]],
enc.decodeMap[src2[1]],
enc.decodeMap[src2[2]],
enc.decodeMap[src2[3]],
enc.decodeMapFunc(uint(src2[0])),
enc.decodeMapFunc(uint(src2[1])),
enc.decodeMapFunc(uint(src2[2])),
enc.decodeMapFunc(uint(src2[3])),
); ok {
byteorder.BEPutUint32(dst[n:], dn)
n += 3
Expand Down
94 changes: 94 additions & 0 deletions src/encoding/base64/mapping.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package base64 implements base64 encoding as specified by RFC 4648.
package base64

func (enc Encoding) encodeMapDefault(in uint) byte {
return enc.encode[in]
}

func (enc Encoding) decodeMapDefault(in uint) byte {
return enc.decodeMap[in]
}

func StandardBase64Decode(in uint) byte {
ch := int(in)
ret := -1

// if (ch > 0x40 && ch < 0x5b) ret += ch - 0x41 + 1; // -64
ret += (((0x40 - ch) & (ch - 0x5b)) >> 8) & (ch - 64)

// if (ch > 0x60 && ch < 0x7b) ret += ch - 0x61 + 26 + 1; // -70
ret += (((0x60 - ch) & (ch - 0x7b)) >> 8) & (ch - 70)

// if (ch > 0x2f && ch < 0x3a) ret += ch - 0x30 + 52 + 1; // 5
ret += (((0x2f - ch) & (ch - 0x3a)) >> 8) & (ch + 5)

// if (ch == 0x2b) ret += 62 + 1
ret += (((0x2a - ch) & (ch - 0x2c)) >> 8) & 63

// if (ch == 0x2f) ret += 63 + 1;
ret += (((0x2e - ch) & (ch - 0x30)) >> 8) & 64

return byte(ret)
}

func StandardBase64Encode(in uint) byte {
src := int(in)
diff := int(0x41)

// if (in > 25) diff += 0x61 - 0x41 - 26; // 6
diff += ((25 - src) >> 8) & 6;

// if (in > 51) diff += 0x30 - 0x61 - 26; // -75
diff -= ((51 - src) >> 8) & 75;

// if (in > 61) diff += 0x2b - 0x30 - 10; // -15
diff -= ((61 - src) >> 8) & 15;

// if (in > 62) diff += 0x2f - 0x2b - 1; // 3
diff += ((62 - src) >> 8) & 3
return byte(src + diff)
}

func UrlSafeBase64Decode(in uint) byte {
ch := int(in)
ret := -1

// if (ch > 0x40 && ch < 0x5b) ret += ch - 0x41 + 1; // -64
ret += (((0x40 - ch) & (ch - 0x5b)) >> 8) & (ch - 64)

// if (ch > 0x60 && ch < 0x7b) ret += ch - 0x61 + 26 + 1; // -70
ret += (((0x60 - ch) & (ch - 0x7b)) >> 8) & (ch - 70)

// if (ch > 0x2f && ch < 0x3a) ret += ch - 0x30 + 52 + 1; // 5
ret += (((0x2f - ch) & (ch - 0x3a)) >> 8) & (ch + 5)

// if (ch == 0x2c) ret += 62 + 1;
ret += (((0x2c - ch) & (ch - 0x2e)) >> 8) & 63

// if (ch == 0x5f) ret += 63 + 1;
ret += (((0x5e - ch) & (ch - 0x60)) >> 8) & 64

return byte(ret)
}


func UrlSafeBase64Encode(in uint) byte {
src := int(in)
diff := int(0x41)
// if (src > 25) diff += 0x61 - 0x41 - 26; // 6
diff += ((25 - src) >> 8) & 6

// if (src > 51) diff += 0x30 - 0x61 - 26; // -75
diff -= ((51 - src) >> 8) & 75

// if (src > 61) diff += 0x2d - 0x30 - 10; // -13
diff -= ((61 - src) >> 8) & 13

// if (src > 62) diff += 0x5f - 0x2b - 1; // 3
diff += ((62 - src) >> 8) & 49
return byte(src + diff)
}