Skip to content

fix(agent/agentssh): ensure RSA key generation always produces valid keys #16694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions agent/agentrsa/key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package agentrsa

import (
"crypto/rsa"
"math/big"
"math/rand"
)

// GenerateDeterministicKey generates an RSA private key deterministically based on the provided seed.
// This function uses a deterministic random source to generate the primes p and q, ensuring that the
// same seed will always produce the same private key. The generated key is 2048 bits in size.
//
// Reference: https://pkg.go.dev/crypto/rsa#GenerateKey
func GenerateDeterministicKey(seed int64) *rsa.PrivateKey {
// Since the standard lib purposefully does not generate
// deterministic rsa keys, we need to do it ourselves.

// Create deterministic random source
// nolint: gosec
deterministicRand := rand.New(rand.NewSource(seed))

// Use fixed values for p and q based on the seed
p := big.NewInt(0)
q := big.NewInt(0)
e := big.NewInt(65537) // Standard RSA public exponent

for {
// Generate deterministic primes using the seeded random
// Each prime should be ~1024 bits to get a 2048-bit key
for {
p.SetBit(p, 1024, 1) // Ensure it's large enough
for i := range 1024 {
if deterministicRand.Int63()%2 == 1 {
p.SetBit(p, i, 1)
} else {
p.SetBit(p, i, 0)
}
}
p1 := new(big.Int).Sub(p, big.NewInt(1))
if p.ProbablyPrime(20) && new(big.Int).GCD(nil, nil, e, p1).Cmp(big.NewInt(1)) == 0 {
break
}
}

for {
q.SetBit(q, 1024, 1) // Ensure it's large enough
for i := range 1024 {
if deterministicRand.Int63()%2 == 1 {
q.SetBit(q, i, 1)
} else {
q.SetBit(q, i, 0)
}
}
q1 := new(big.Int).Sub(q, big.NewInt(1))
if q.ProbablyPrime(20) && p.Cmp(q) != 0 && new(big.Int).GCD(nil, nil, e, q1).Cmp(big.NewInt(1)) == 0 {
break
}
}

// Calculate phi = (p-1) * (q-1)
p1 := new(big.Int).Sub(p, big.NewInt(1))
q1 := new(big.Int).Sub(q, big.NewInt(1))
phi := new(big.Int).Mul(p1, q1)

// Calculate private exponent d
d := new(big.Int).ModInverse(e, phi)
if d != nil {
// Calculate n = p * q
n := new(big.Int).Mul(p, q)

// Create the private key
privateKey := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: n,
E: int(e.Int64()),
},
D: d,
Primes: []*big.Int{p, q},
}

// Compute precomputed values
privateKey.Precompute()

return privateKey
}
}
}
50 changes: 50 additions & 0 deletions agent/agentrsa/key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package agentrsa_test

import (
"crypto/rsa"
"math/rand/v2"
"testing"

"github.com/stretchr/testify/assert"

"github.com/coder/coder/v2/agent/agentrsa"
)

func TestGenerateDeterministicKey(t *testing.T) {
t.Parallel()

key1 := agentrsa.GenerateDeterministicKey(1234)
key2 := agentrsa.GenerateDeterministicKey(1234)

assert.Equal(t, key1, key2)
assert.EqualExportedValues(t, key1, key2)
}

var result *rsa.PrivateKey

func BenchmarkGenerateDeterministicKey(b *testing.B) {
var r *rsa.PrivateKey

for range b.N {
// always record the result of DeterministicPrivateKey to prevent
// the compiler eliminating the function call.
r = agentrsa.GenerateDeterministicKey(rand.Int64())
}

// always store the result to a package level variable
// so the compiler cannot eliminate the Benchmark itself.
result = r
}

func FuzzGenerateDeterministicKey(f *testing.F) {
testcases := []int64{0, 1234, 1010101010}
for _, tc := range testcases {
f.Add(tc) // Use f.Add to provide a seed corpus
}
f.Fuzz(func(t *testing.T, seed int64) {
key1 := agentrsa.GenerateDeterministicKey(seed)
key2 := agentrsa.GenerateDeterministicKey(seed)
assert.Equal(t, key1, key2)
assert.EqualExportedValues(t, key1, key2)
})
}
74 changes: 2 additions & 72 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@ package agentssh
import (
"bufio"
"context"
"crypto/rsa"
"errors"
"fmt"
"io"
"math/big"
"math/rand"
"net"
"os"
"os/exec"
Expand All @@ -33,6 +30,7 @@ import (
"cdr.dev/slog"

"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentrsa"
"github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty"
Expand Down Expand Up @@ -1120,75 +1118,7 @@ func CoderSigner(seed int64) (gossh.Signer, error) {
// Clients should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.

// Since the standard lib purposefully does not generate
// deterministic rsa keys, we need to do it ourselves.
coderHostKey := func() *rsa.PrivateKey {
// Create deterministic random source
// nolint: gosec
deterministicRand := rand.New(rand.NewSource(seed))

// Use fixed values for p and q based on the seed
p := big.NewInt(0)
q := big.NewInt(0)
e := big.NewInt(65537) // Standard RSA public exponent

// Generate deterministic primes using the seeded random
// Each prime should be ~1024 bits to get a 2048-bit key
for {
p.SetBit(p, 1024, 1) // Ensure it's large enough
for i := 0; i < 1024; i++ {
if deterministicRand.Int63()%2 == 1 {
p.SetBit(p, i, 1)
} else {
p.SetBit(p, i, 0)
}
}
if p.ProbablyPrime(20) {
break
}
}

for {
q.SetBit(q, 1024, 1) // Ensure it's large enough
for i := 0; i < 1024; i++ {
if deterministicRand.Int63()%2 == 1 {
q.SetBit(q, i, 1)
} else {
q.SetBit(q, i, 0)
}
}
if q.ProbablyPrime(20) && p.Cmp(q) != 0 {
break
}
}

// Calculate n = p * q
n := new(big.Int).Mul(p, q)

// Calculate phi = (p-1) * (q-1)
p1 := new(big.Int).Sub(p, big.NewInt(1))
q1 := new(big.Int).Sub(q, big.NewInt(1))
phi := new(big.Int).Mul(p1, q1)

// Calculate private exponent d
d := new(big.Int).ModInverse(e, phi)

// Create the private key
privateKey := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: n,
E: int(e.Int64()),
},
D: d,
Primes: []*big.Int{p, q},
}

// Compute precomputed values
privateKey.Precompute()

return privateKey
}()
coderHostKey := agentrsa.GenerateDeterministicKey(seed)

coderSigner, err := gossh.NewSignerFromKey(coderHostKey)
return coderSigner, err
Expand Down
Loading