diff --git a/.gitignore b/.gitignore index b315aaa..a197244 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ bin/ .idea release_bin vendor +go.work #vscode .vscode diff --git a/example/memory_kv/server.go b/example/memory_kv/server.go index de7dc9d..0887257 100644 --- a/example/memory_kv/server.go +++ b/example/memory_kv/server.go @@ -26,6 +26,10 @@ func main() { var reusePort bool var pprofDebug bool var pprofAddr string + var tlsEnable bool + var tlsCertFile string + var tlsKeyFile string + var tlsAddr string // Parse command-line arguments flag.StringVar(&network, "network", "tcp", "server network (default \"tcp\")") @@ -34,6 +38,10 @@ func main() { flag.BoolVar(&reusePort, "reusePort", false, "enable port reuse") flag.BoolVar(&pprofDebug, "pprofDebug", false, "enable pprof debugging") flag.StringVar(&pprofAddr, "pprofAddr", ":8888", "pprof address") + flag.BoolVar(&tlsEnable, "tls", false, "enable TLS support") + flag.StringVar(&tlsCertFile, "tlsCert", "", "TLS certificate file path") + flag.StringVar(&tlsKeyFile, "tlsKey", "", "TLS key file path") + flag.StringVar(&tlsAddr, "tlsAddr", "", "TLS listener address (default: derived from addr)") flag.Parse() // Start pprof server if debugging is enabled @@ -48,8 +56,12 @@ func main() { // Define RedHub options option := redhub.Options{ - Multicore: multicore, - ReusePort: reusePort, + Multicore: multicore, + ReusePort: reusePort, + TLSListenEnable: tlsEnable, + TLSCertFile: tlsCertFile, + TLSKeyFile: tlsKeyFile, + TLSAddr: tlsAddr, } // Create a new RedHub instance with custom handlers @@ -127,6 +139,13 @@ func main() { // Log the server start log.Printf("started redhub server at %s", addr) + if tlsEnable { + listenAddr := tlsAddr + if listenAddr == "" { + listenAddr = addr + } + log.Printf("TLS listener enabled at %s", listenAddr) + } // Start the RedHub server err := redhub.ListenAndServe(protoAddr, option, rh) diff --git a/redhub.go b/redhub.go index deb6515..ad9ae10 100644 --- a/redhub.go +++ b/redhub.go @@ -57,7 +57,11 @@ package redhub import ( "bytes" "context" + "crypto/tls" "errors" + "net" + "strconv" + "strings" "sync" "time" @@ -192,6 +196,27 @@ type Options struct { // This can reduce the number of system calls but requires careful handling. // Default: false EdgeTriggeredIO bool + + // TLSListenEnable enables TLS support. When true, a TLS listener is started + // alongside the TCP listener. TLS connections are proxied to the TCP server. + // Default: false + TLSListenEnable bool + + // TLSCertFile specifies the path to the TLS certificate file. + // Required when TLSListenEnable is true. + // Default: "" + TLSCertFile string + + // TLSKeyFile specifies the path to the TLS private key file. + // Required when TLSListenEnable is true. + // Default: "" + TLSKeyFile string + + // TLSAddr specifies the address for the TLS listener. + // If empty, it's derived from the main TCP address by changing the port + // (e.g., tcp://127.0.0.1:6379 -> tcp://127.0.0.1:6380). + // Default: "" + TLSAddr string } // RedHub represents the main server structure that manages connections and command processing. @@ -207,8 +232,10 @@ type RedHub struct { connSync *sync.RWMutex mu sync.Mutex addr string + tcpAddr string running bool engine gnet.Engine + tlsListener net.Listener } // connBuffer holds the buffer and commands for each connection. @@ -369,6 +396,126 @@ func (rs *RedHub) OnTick() (delay time.Duration, action gnet.Action) { return 0, gnet.None } +// deriveTLSAddr derives a TLS address from the TCP address by incrementing the port. +func deriveTLSAddr(tcpAddr string) string { + if !strings.HasPrefix(tcpAddr, "tcp://") { + return "" + } + + hostPort := strings.TrimPrefix(tcpAddr, "tcp://") + host, portStr, err := net.SplitHostPort(hostPort) + if err != nil { + return "" + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return "" + } + + return "tcp://" + net.JoinHostPort(host, strconv.Itoa(port+1)) +} + +// startTLSListener starts the TLS listener that proxies connections to the TCP server. +func (rs *RedHub) startTLSListener(options Options) error { + cert, err := tls.LoadX509KeyPair(options.TLSCertFile, options.TLSKeyFile) + if err != nil { + return err + } + + tlsAddr := options.TLSAddr + if tlsAddr == "" { + tlsAddr = deriveTLSAddr(rs.tcpAddr) + if tlsAddr == "" { + return errors.New("failed to derive TLS address from TCP address") + } + } + + listenAddr := tlsAddr + if strings.HasPrefix(tlsAddr, "tcp://") { + listenAddr = strings.TrimPrefix(tlsAddr, "tcp://") + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + rs.tlsListener, err = tls.Listen("tcp", listenAddr, tlsConfig) + if err != nil { + return err + } + + tcpForwardAddr := rs.tcpAddr + if strings.HasPrefix(tcpForwardAddr, "tcp://") { + tcpForwardAddr = strings.TrimPrefix(tcpForwardAddr, "tcp://") + } + + go rs.acceptTLSConnections(tcpForwardAddr) + + return nil +} + +// acceptTLSConnections accepts TLS connections and forwards them to the TCP server. +func (rs *RedHub) acceptTLSConnections(tcpAddr string) { + for { + tlsConn, err := rs.tlsListener.Accept() + if err != nil { + if !rs.running { + return + } + continue + } + + go rs.handleTLSConn(tlsConn, tcpAddr) + } +} + +// handleTLSConn handles a single TLS connection by forwarding data to the TCP server. +func (rs *RedHub) handleTLSConn(tlsConn net.Conn, tcpAddr string) { + defer tlsConn.Close() + + tcpConn, err := net.Dial("tcp", tcpAddr) + if err != nil { + return + } + defer tcpConn.Close() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + buf := make([]byte, 4096) + for { + n, err := tlsConn.Read(buf) + if err != nil { + return + } + _, err = tcpConn.Write(buf[:n]) + if err != nil { + return + } + } + }() + + go func() { + defer wg.Done() + buf := make([]byte, 4096) + for { + n, err := tcpConn.Read(buf) + if err != nil { + return + } + _, err = tlsConn.Write(buf[:n]) + if err != nil { + return + } + } + }() + + wg.Wait() +} + // ListenAndServe starts the RedHub server on the specified address with the given options. // // This is the main entry point for starting a RedHub server. The address should be @@ -394,6 +541,12 @@ func (rs *RedHub) OnTick() (delay time.Duration, action gnet.Action) { // log.Fatal(err) // } func ListenAndServe(addr string, options Options, rh *RedHub) error { + if options.TLSListenEnable { + if options.TLSCertFile == "" || options.TLSKeyFile == "" { + return errors.New("TLSListenEnable requires TLSCertFile and TLSKeyFile") + } + } + var opts []gnet.Option if options.Multicore { @@ -438,15 +591,29 @@ func ListenAndServe(addr string, options Options, rh *RedHub) error { rh.mu.Lock() rh.addr = addr + rh.tcpAddr = addr rh.running = true rh.mu.Unlock() + if options.TLSListenEnable { + if err := rh.startTLSListener(options); err != nil { + rh.mu.Lock() + rh.running = false + rh.mu.Unlock() + return err + } + } + err := gnet.Run(rh, addr, opts...) rh.mu.Lock() rh.running = false rh.mu.Unlock() + if rh.tlsListener != nil { + rh.tlsListener.Close() + } + return err } @@ -465,5 +632,10 @@ func (rs *RedHub) Close() error { } rs.running = false + + if rs.tlsListener != nil { + _ = rs.tlsListener.Close() + } + return rs.engine.Stop(context.Background()) } diff --git a/redhub_test.go b/redhub_test.go index 4b22636..3457b8b 100644 --- a/redhub_test.go +++ b/redhub_test.go @@ -1,6 +1,7 @@ package redhub import ( + "crypto/tls" "net" "testing" "time" @@ -8,6 +9,7 @@ import ( "github.com/IceFireDB/redhub/pkg/resp" "github.com/panjf2000/gnet/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type mockConn struct { @@ -358,7 +360,7 @@ func TestClose_Integration(t *testing.T) { assert.NoError(t, err) // Wait a moment for server to stop - time.Sleep(200 * time.Millisecond) + time.Sleep(time.Second) // Verify connection fails after close conn, err = net.DialTimeout("tcp", "127.0.0.1:16379", 200*time.Millisecond) @@ -375,3 +377,242 @@ func TestClose_Integration(t *testing.T) { t.Error("Server did not stop within timeout") } } + +func TestTLSListenEnable_NoCertFile(t *testing.T) { + rh := NewRedHub( + func(c *Conn) (out []byte, action Action) { + return nil, None + }, + func(c *Conn, err error) (action Action) { + return None + }, + func(cmd resp.Command, out []byte) ([]byte, Action) { + return out, None + }, + ) + + err := ListenAndServe("tcp://127.0.0.1:16380", Options{ + TLSListenEnable: true, + TLSCertFile: "", + TLSKeyFile: "testdata/key.pem", + }, rh) + assert.Error(t, err) + assert.Contains(t, err.Error(), "TLSCertFile and TLSKeyFile") +} + +func TestTLSListenEnable_NoKeyFile(t *testing.T) { + rh := NewRedHub( + func(c *Conn) (out []byte, action Action) { + return nil, None + }, + func(c *Conn, err error) (action Action) { + return None + }, + func(cmd resp.Command, out []byte) ([]byte, Action) { + return out, None + }, + ) + + err := ListenAndServe("tcp://127.0.0.1:16381", Options{ + TLSListenEnable: true, + TLSCertFile: "testdata/cert.pem", + TLSKeyFile: "", + }, rh) + assert.Error(t, err) + assert.Contains(t, err.Error(), "TLSCertFile and TLSKeyFile") +} + +func TestTLSListenEnable_InvalidCertPath(t *testing.T) { + rh := NewRedHub( + func(c *Conn) (out []byte, action Action) { + return nil, None + }, + func(c *Conn, err error) (action Action) { + return None + }, + func(cmd resp.Command, out []byte) ([]byte, Action) { + return out, None + }, + ) + + err := ListenAndServe("tcp://127.0.0.1:16382", Options{ + TLSListenEnable: true, + TLSCertFile: "nonexistent.pem", + TLSKeyFile: "nonexistent.pem", + }, rh) + assert.Error(t, err) +} + +func TestTLSListenEnable_Integration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + rh := NewRedHub( + func(c *Conn) (out []byte, action Action) { + return nil, None + }, + func(c *Conn, err error) (action Action) { + return None + }, + func(cmd resp.Command, out []byte) ([]byte, Action) { + cmdName := string(cmd.Args[0]) + if cmdName == "PING" { + return resp.AppendString(out, "PONG"), None + } + return resp.AppendString(out, "OK"), None + }, + ) + + serverErr := make(chan error, 1) + go func() { + serverErr <- ListenAndServe("tcp://127.0.0.1:16383", Options{ + Multicore: false, + TLSListenEnable: true, + TLSCertFile: "testdata/cert.pem", + TLSKeyFile: "testdata/key.pem", + }, rh) + }() + + time.Sleep(time.Second) + + conn, err := tls.Dial("tcp", "127.0.0.1:16384", &tls.Config{ + InsecureSkipVerify: true, + }) + if err != nil { + t.Skipf("TLS connection failed (no certs?): %v", err) + } + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("*1\r\n$4\r\nPING\r\n")) + require.NoError(t, err) + + buf := make([]byte, 1024) + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Contains(t, string(buf[:n]), "PONG") + + err = rh.Close() + require.NoError(t, err) + + select { + case err := <-serverErr: + assert.NoError(t, err) + case <-time.After(3 * time.Second): + t.Error("Server did not stop within timeout") + } +} + +func TestTLSListenEnable_CloseClosesTLSListener(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + rh := NewRedHub( + func(c *Conn) (out []byte, action Action) { + return nil, None + }, + func(c *Conn, err error) (action Action) { + return None + }, + func(cmd resp.Command, out []byte) ([]byte, Action) { + return resp.AppendString(out, "OK"), None + }, + ) + + serverErr := make(chan error, 1) + go func() { + serverErr <- ListenAndServe("tcp://127.0.0.1:16385", Options{ + Multicore: false, + TLSListenEnable: true, + TLSCertFile: "testdata/cert.pem", + TLSKeyFile: "testdata/key.pem", + }, rh) + }() + + time.Sleep(time.Second) + + conn, err := tls.Dial("tcp", "127.0.0.1:16386", &tls.Config{ + InsecureSkipVerify: true, + }) + if err != nil { + t.Skipf("TLS connection failed (no certs?): %v", err) + } + require.NoError(t, err) + conn.Close() + + err = rh.Close() + require.NoError(t, err) + + select { + case err := <-serverErr: + assert.NoError(t, err) + case <-time.After(2 * time.Second): + t.Error("Server did not stop within timeout") + } +} + +func TestTLSListenEnable_WithCustomTLSAddr(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + rh := NewRedHub( + func(c *Conn) (out []byte, action Action) { + return nil, None + }, + func(c *Conn, err error) (action Action) { + return None + }, + func(cmd resp.Command, out []byte) ([]byte, Action) { + cmdName := string(cmd.Args[0]) + if cmdName == "PING" { + return resp.AppendString(out, "PONG"), None + } + return resp.AppendString(out, "OK"), None + }, + ) + + serverErr := make(chan error, 1) + go func() { + serverErr <- ListenAndServe("tcp://127.0.0.1:16387", Options{ + Multicore: false, + TLSListenEnable: true, + TLSCertFile: "testdata/cert.pem", + TLSKeyFile: "testdata/key.pem", + TLSAddr: "127.0.0.1:16388", + }, rh) + }() + + time.Sleep(time.Second) + + conn, err := tls.Dial("tcp", "127.0.0.1:16388", &tls.Config{ + InsecureSkipVerify: true, + }) + if err != nil { + t.Skipf("TLS connection failed (no certs?): %v", err) + } + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("*1\r\n$4\r\nPING\r\n")) + require.NoError(t, err) + + buf := make([]byte, 1024) + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + require.NoError(t, err) + assert.Contains(t, string(buf[:n]), "PONG") + + err = rh.Close() + require.NoError(t, err) + + select { + case err := <-serverErr: + assert.NoError(t, err) + case <-time.After(3 * time.Second): + t.Error("Server did not stop within timeout") + } +} diff --git a/testdata/cert.pem b/testdata/cert.pem new file mode 100644 index 0000000..085b5e9 --- /dev/null +++ b/testdata/cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDTzCCAjegAwIBAgIUL6Jy5K+tmKgAeXnpoNqEJr8GilwwDQYJKoZIhvcNAQEL +BQAwNzESMBAGA1UEAwwJbG9jYWxob3N0MRQwEgYDVQQKDAtSZWRIdWIgVGVzdDEL +MAkGA1UEBhMCVVMwHhcNMjYwMTExMTMwOTU1WhcNMjcwMTExMTMwOTU1WjA3MRIw +EAYDVQQDDAlsb2NhbGhvc3QxFDASBgNVBAoMC1JlZEh1YiBUZXN0MQswCQYDVQQG +EwJVUzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOVaQzg4w/nNp8WM +m3G2b1VWo2ly2N0mZI+FbKnwhn6k0XRgmWdr/Y+c5TSVIILSJwlGBTB6Xa4a/at1 +u55gZVTSYtGX0RYE4ka2C0blNGCh0uM0klvJqzSneYbtWjIk1nEan9a4yzzTeYq4 +s+TSABj5o0CiOXGOEXLdd+qEyCjRKzFGmGVIMjLOS212/VSpUv1RaMmFMe70Bm35 +1mPa2vcm4yh17sXrSEJajv4NEhi6kBeIjtsvbVWn6+PhSZyz4G6tEm5day80vB+J +GceYLpMJaaqSj8ddfFj4qlR633m3E+fG5XxjLXWmkASYfo5RfVWLn+RpSS9DtcZ+ +FvhIRfMCAwEAAaNTMFEwHQYDVR0OBBYEFBXpwXFhMv0bbv2MsRmU4YZSdz7uMB8G +A1UdIwQYMBaAFBXpwXFhMv0bbv2MsRmU4YZSdz7uMA8GA1UdEwEB/wQFMAMBAf8w +DQYJKoZIhvcNAQELBQADggEBAG/GvVwRQjWcGZSwf/i5C3z5Tn7YbPJAfJBUtYrB +QYO5Na3rBt1vAY7RQIbOxRMCcUldeGoBpvTiQXlqr2YixFwTWi88kc721deRQG10 +E7ChZ6UPrCsfKCXSNFsdysZUOyjxItNgrwfsb9qP43SWpvh6c5UKlQoOwmvOHTQt +DfegQVQ9Bqvq2OX5hRXknTfjJmopCr3pY2PbdcUXEWwyoR4rDKkG3r0l1eOpMsVL +7KmgVpR8sy4Q58YJmdp/Lg9mQWf80H9vg6+qNcIZAG/Wz3PJ+S28OZvCa8Tp+rWc +oTScFGY9TwJ/K1ljWWmT6meR1qa99+MvTFQSfc6HIFlShLg= +-----END CERTIFICATE----- diff --git a/testdata/generate_certs.sh b/testdata/generate_certs.sh new file mode 100755 index 0000000..c93c3fc --- /dev/null +++ b/testdata/generate_certs.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Script to generate test TLS certificates for redhub tests + +set -e + +cd "$(dirname "$0")" + +# Generate private key +openssl genrsa -out key.pem 2048 + +# Generate certificate +openssl req -new -x509 -key key.pem -out cert.pem -days 365 \ + -subj "/C=US/ST=California/L=San Francisco/O=RedHub/CN=localhost" + +echo "Test certificates generated successfully!" diff --git a/testdata/key.pem b/testdata/key.pem new file mode 100644 index 0000000..9948312 --- /dev/null +++ b/testdata/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDlWkM4OMP5zafF +jJtxtm9VVqNpctjdJmSPhWyp8IZ+pNF0YJlna/2PnOU0lSCC0icJRgUwel2uGv2r +dbueYGVU0mLRl9EWBOJGtgtG5TRgodLjNJJbyas0p3mG7VoyJNZxGp/WuMs803mK +uLPk0gAY+aNAojlxjhFy3XfqhMgo0SsxRphlSDIyzkttdv1UqVL9UWjJhTHu9AZt ++dZj2tr3JuMode7F60hCWo7+DRIYupAXiI7bL21Vp+vj4Umcs+BurRJuXWsvNLwf +iRnHmC6TCWmqko/HXXxY+KpUet95txPnxuV8Yy11ppAEmH6OUX1Vi5/kaUkvQ7XG +fhb4SEXzAgMBAAECggEAY+VRu+41pdtmhMv+dKPykCgBWw+T15c+W6jQsKA75HNj +a54bkwldUq0SxDlkBLcGG6rs3bWekhPdg03vX0c7O5u7QPEwN7f+2q+177Yrfx9c +3GtsiCApuvBrJVLCY27aHD9teTfaBe9SVBKpADRbqIUrDx7ZiFVJ0k8WSQZ2rBfD +8c1WHDmglu5SFoZSfPFNILI4Eyx5g8vGAqOzKIYm5NVE1nx0WVi/yqsUL1lCpOP1 +o+DON4xXKCK3wop5ajB1ShuLCmnznIl7cj1B273RIrqqIwqn1+IjyyZ9madTQgBV +z7exy0+85F7TQ1PPJxtwv7EsKIn5JdnLMK9m0QlLAQKBgQD0pM73W2zIZbG+QQVW +T34gCkgtvRYnezs4a/kcl8w8c+5lREN50Q67n/MlI6zcxn3GOrMS2beWqc/PS+YB +jnBqvUMhTKgcweYoPpL9L40CF6rOfBotLNLZoO7ZlCJFLqa0A0r6G48f1hVuVagA +dZ+WtHCSwpyzwCxSsiN7TG9RgQKBgQDv/75F6mPCi+UXYqmfmd1vLjJhFqutowAK +qB2/vbbJ9s/DUTH0w/TewNo8Y49RnWd/8DzA41avd3GNG4tbh2ZAn18CG+ou5idl +M4aym0sHdLU5Ij+Scdf0RiNwqerEzxlPBc5DAeTnHxbzZ9RD2R9QvbKCW/UWXR6R +/QWn6LkpcwKBgDmK9l96MqkkOl2Mv6uggQMaSAXyHt7kfnZz9yFBlzl071MEbnad +tMBvC+rlbEh1q6nPrsU1Tphykr2olY4yKcEBiWOwuy4gcXlv1nUVFS6z0GpHCIUt +sN6dmvC0hicNpQpcZ+tSRiTv3xSXsy+AeywgfwYWHnOtNP+yhOQAg4KBAoGAGtmd +y+yhJI7KHoenOnfYUiv07u++XTqzMn4EdgMfhBDcxZk74YpaxuEEiWUKD7NwdNvH +sDy+4fqW9ZZzTNYlFm2+D1pYJM8S8TuGgkzlY/wmmjG+sv+RjX6bUGtyHHqe9jxM +CysXFNRhmPGwybZszneqlPL8xHe+h86q51IeBQkCgYEA0FH+AHHZZEEIPKAz0GBF +HOyWdOWZes2FqKllqAHioemy85v2opxJHz6aS1l2/KxHb4RBH91ZIdhc99+mzUR2 +KSMGkUOzAlDFQvYrEvtC1rNzwXwnepNGUJ3VhhaJkJ0KwSinhEKP6rkZ4a3oBNb6 ++sS2gjA8dyWuWUy7teTCQdg= +-----END PRIVATE KEY-----