diff --git a/agent/agentrsa/key.go b/agent/agentrsa/key.go new file mode 100644 index 0000000000000..fd70d0b7bfa9e --- /dev/null +++ b/agent/agentrsa/key.go @@ -0,0 +1,87 @@ +package agentrsa + +import ( + "crypto/rsa" + "math/big" + "math/rand" +) + +// GenerateDeterministicKey generates an RSA private key deterministically based on the provided seed. +// This function uses a deterministic random source to generate the primes p and q, ensuring that the +// same seed will always produce the same private key. The generated key is 2048 bits in size. +// +// Reference: https://pkg.go.dev/crypto/rsa#GenerateKey +func GenerateDeterministicKey(seed int64) *rsa.PrivateKey { + // Since the standard lib purposefully does not generate + // deterministic rsa keys, we need to do it ourselves. + + // Create deterministic random source + // nolint: gosec + deterministicRand := rand.New(rand.NewSource(seed)) + + // Use fixed values for p and q based on the seed + p := big.NewInt(0) + q := big.NewInt(0) + e := big.NewInt(65537) // Standard RSA public exponent + + for { + // Generate deterministic primes using the seeded random + // Each prime should be ~1024 bits to get a 2048-bit key + for { + p.SetBit(p, 1024, 1) // Ensure it's large enough + for i := range 1024 { + if deterministicRand.Int63()%2 == 1 { + p.SetBit(p, i, 1) + } else { + p.SetBit(p, i, 0) + } + } + p1 := new(big.Int).Sub(p, big.NewInt(1)) + if p.ProbablyPrime(20) && new(big.Int).GCD(nil, nil, e, p1).Cmp(big.NewInt(1)) == 0 { + break + } + } + + for { + q.SetBit(q, 1024, 1) // Ensure it's large enough + for i := range 1024 { + if deterministicRand.Int63()%2 == 1 { + q.SetBit(q, i, 1) + } else { + q.SetBit(q, i, 0) + } + } + q1 := new(big.Int).Sub(q, big.NewInt(1)) + if q.ProbablyPrime(20) && p.Cmp(q) != 0 && new(big.Int).GCD(nil, nil, e, q1).Cmp(big.NewInt(1)) == 0 { + break + } + } + + // Calculate phi = (p-1) * (q-1) + p1 := new(big.Int).Sub(p, big.NewInt(1)) + q1 := new(big.Int).Sub(q, big.NewInt(1)) + phi := new(big.Int).Mul(p1, q1) + + // Calculate private exponent d + d := new(big.Int).ModInverse(e, phi) + if d != nil { + // Calculate n = p * q + n := new(big.Int).Mul(p, q) + + // Create the private key + privateKey := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: n, + E: int(e.Int64()), + }, + D: d, + Primes: []*big.Int{p, q}, + } + + // Compute precomputed values + privateKey.Precompute() + + return privateKey + } + } +} diff --git a/agent/agentrsa/key_test.go b/agent/agentrsa/key_test.go new file mode 100644 index 0000000000000..dc561d09d4e07 --- /dev/null +++ b/agent/agentrsa/key_test.go @@ -0,0 +1,50 @@ +package agentrsa_test + +import ( + "crypto/rsa" + "math/rand/v2" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/coder/coder/v2/agent/agentrsa" +) + +func TestGenerateDeterministicKey(t *testing.T) { + t.Parallel() + + key1 := agentrsa.GenerateDeterministicKey(1234) + key2 := agentrsa.GenerateDeterministicKey(1234) + + assert.Equal(t, key1, key2) + assert.EqualExportedValues(t, key1, key2) +} + +var result *rsa.PrivateKey + +func BenchmarkGenerateDeterministicKey(b *testing.B) { + var r *rsa.PrivateKey + + for range b.N { + // always record the result of DeterministicPrivateKey to prevent + // the compiler eliminating the function call. + r = agentrsa.GenerateDeterministicKey(rand.Int64()) + } + + // always store the result to a package level variable + // so the compiler cannot eliminate the Benchmark itself. + result = r +} + +func FuzzGenerateDeterministicKey(f *testing.F) { + testcases := []int64{0, 1234, 1010101010} + for _, tc := range testcases { + f.Add(tc) // Use f.Add to provide a seed corpus + } + f.Fuzz(func(t *testing.T, seed int64) { + key1 := agentrsa.GenerateDeterministicKey(seed) + key2 := agentrsa.GenerateDeterministicKey(seed) + assert.Equal(t, key1, key2) + assert.EqualExportedValues(t, key1, key2) + }) +} diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index a7e028541aa6e..936ac985d1792 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -3,12 +3,9 @@ package agentssh import ( "bufio" "context" - "crypto/rsa" "errors" "fmt" "io" - "math/big" - "math/rand" "net" "os" "os/exec" @@ -33,6 +30,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/agent/agentexec" + "github.com/coder/coder/v2/agent/agentrsa" "github.com/coder/coder/v2/agent/usershell" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/pty" @@ -1120,75 +1118,7 @@ func CoderSigner(seed int64) (gossh.Signer, error) { // Clients should ignore the host key when connecting. // The agent needs to authenticate with coderd to SSH, // so SSH authentication doesn't improve security. - - // Since the standard lib purposefully does not generate - // deterministic rsa keys, we need to do it ourselves. - coderHostKey := func() *rsa.PrivateKey { - // Create deterministic random source - // nolint: gosec - deterministicRand := rand.New(rand.NewSource(seed)) - - // Use fixed values for p and q based on the seed - p := big.NewInt(0) - q := big.NewInt(0) - e := big.NewInt(65537) // Standard RSA public exponent - - // Generate deterministic primes using the seeded random - // Each prime should be ~1024 bits to get a 2048-bit key - for { - p.SetBit(p, 1024, 1) // Ensure it's large enough - for i := 0; i < 1024; i++ { - if deterministicRand.Int63()%2 == 1 { - p.SetBit(p, i, 1) - } else { - p.SetBit(p, i, 0) - } - } - if p.ProbablyPrime(20) { - break - } - } - - for { - q.SetBit(q, 1024, 1) // Ensure it's large enough - for i := 0; i < 1024; i++ { - if deterministicRand.Int63()%2 == 1 { - q.SetBit(q, i, 1) - } else { - q.SetBit(q, i, 0) - } - } - if q.ProbablyPrime(20) && p.Cmp(q) != 0 { - break - } - } - - // Calculate n = p * q - n := new(big.Int).Mul(p, q) - - // Calculate phi = (p-1) * (q-1) - p1 := new(big.Int).Sub(p, big.NewInt(1)) - q1 := new(big.Int).Sub(q, big.NewInt(1)) - phi := new(big.Int).Mul(p1, q1) - - // Calculate private exponent d - d := new(big.Int).ModInverse(e, phi) - - // Create the private key - privateKey := &rsa.PrivateKey{ - PublicKey: rsa.PublicKey{ - N: n, - E: int(e.Int64()), - }, - D: d, - Primes: []*big.Int{p, q}, - } - - // Compute precomputed values - privateKey.Precompute() - - return privateKey - }() + coderHostKey := agentrsa.GenerateDeterministicKey(seed) coderSigner, err := gossh.NewSignerFromKey(coderHostKey) return coderSigner, err