diff --git a/crypto/etls/conn.go b/crypto/etls/conn.go index 150d85a3e..4b25bbb5d 100644 --- a/crypto/etls/conn.go +++ b/crypto/etls/conn.go @@ -64,11 +64,6 @@ func Dial(network, address string, cipher *Cipher) (c *CryptoConn, err error) { return } -// RawRead is the raw net.Conn.Read -func (c *CryptoConn) RawRead(b []byte) (n int, err error) { - return c.Conn.Read(b) -} - // Read iv and Encrypted data func (c *CryptoConn) Read(b []byte) (n int, err error) { if c.decStream == nil { @@ -104,11 +99,6 @@ func (c *CryptoConn) Read(b []byte) (n int, err error) { return } -// RawWrite is the raw net.Conn.Write -func (c *CryptoConn) RawWrite(b []byte) (n int, err error) { - return c.Conn.Read(b) -} - // Write iv and Encrypted data func (c *CryptoConn) Write(b []byte) (n int, err error) { var iv []byte diff --git a/crypto/etls/conn_test.go b/crypto/etls/conn_test.go index 4778c12b8..ffa9b1010 100644 --- a/crypto/etls/conn_test.go +++ b/crypto/etls/conn_test.go @@ -20,12 +20,25 @@ import ( "net" "net/rpc" "strings" + "sync" "testing" + "github.com/CovenantSQL/CovenantSQL/crypto/hash" "github.com/CovenantSQL/CovenantSQL/utils/log" . "github.com/smartystreets/goconvey/convey" ) +const service = "127.0.0.1:28000" +const serviceComplex = "127.0.0.1:28001" +const contentLength = 9999 +const pass = "123" + +var simpleCipherHandler CipherHandler = func(conn net.Conn) (cryptoConn *CryptoConn, err error) { + cipher := NewCipher([]byte(pass)) + cryptoConn = NewConn(conn, cipher, nil) + return +} + type Foo bool type Result struct { @@ -34,19 +47,34 @@ type Result struct { func (f *Foo) Bar(args *string, res *Result) error { res.Data = len(*args) - log.Printf("Received %q, its length is %d", *args, res.Data) - //return fmt.Error("whoops, error happened") + log.Debugf("Received %q, its length is %d", *args, res.Data) return nil } -const service = "127.0.0.1:28000" -const contentLength = 9999 -const pass = "123" +type FooComplex bool -var simpleCipherHandler CipherHandler = func(conn net.Conn) (cryptoConn *CryptoConn, err error) { - cipher := NewCipher([]byte(pass)) - cryptoConn = NewConn(conn, cipher, nil) - return +type QueryComplex struct { + DataS struct { + Strs []string + } +} + +type ResultComplex struct { + Count int + Hash struct { + Str string + } +} + +func qHash(q *QueryComplex) string { + return string(hash.THashB([]byte(strings.Join(q.DataS.Strs, "")))) +} + +func (f *FooComplex) Bar(args *QueryComplex, res *ResultComplex) error { + res.Count = len(args.DataS.Strs) + res.Hash.Str = qHash(args) + log.Debugf("Received %v", *args) + return nil } func server() *CryptoListener { @@ -59,15 +87,15 @@ func server() *CryptoListener { if err != nil { log.Errorf("server: listen: %s", err) } - log.Print("server: listening") + log.Debug("server: listening") go func() { for { conn, err := listener.Accept() if err != nil { - log.Printf("server: accept: %s", err) + log.Debugf("server: accept: %s", err) break } - log.Printf("server: accepted from %s", conn.RemoteAddr()) + log.Debugf("server: accepted from %s", conn.RemoteAddr()) go handleClient(conn) } }() @@ -89,26 +117,72 @@ func client(pass string) (ret int, err error) { //conn.SetReadDeadline(time.Time{}) //conn.SetWriteDeadline(time.Time{}) - log.Println("client: connected to: ", conn.RemoteAddr()) - log.Println("client: LocalAddr: ", conn.LocalAddr()) + log.Debugln("client: connected to: ", conn.RemoteAddr()) + log.Debugln("client: LocalAddr: ", conn.LocalAddr()) rpcClient := rpc.NewClient(conn) res := new(Result) args := strings.Repeat("a", contentLength) if err := rpcClient.Call("Foo.Bar", args, &res); err != nil { - log.Error("failed to call RPC", err) + log.Errorf("failed to call RPC %v", err) return 0, err } - log.Printf("Returned result is %d", res.Data) + log.Debugf("Returned result is %d", res.Data) return res.Data, err } +func serverComplex() *CryptoListener { + if err := rpc.Register(new(FooComplex)); err != nil { + log.Error("failed to register RPC method") + } + + listener, err := NewCryptoListener("tcp", serviceComplex, simpleCipherHandler) + if err != nil { + log.Errorf("server: listen: %s", err) + } + log.Debug("server: listening") + go func() { + for { + conn, err := listener.Accept() + if err != nil { + log.Debugf("server: accept: %s", err) + break + } + log.Debugf("server: accepted from %s", conn.RemoteAddr()) + go handleClient(conn) + } + }() + return listener +} + +func clientComplex(pass string, args *QueryComplex) (ret *ResultComplex, err error) { + cipher := NewCipher([]byte(pass)) + + conn, err := Dial("tcp", serviceComplex, cipher) + if err != nil { + log.Errorf("client: dial: %s", err) + return nil, err + } + defer conn.Close() + + log.Debugln("client: connected to: ", conn.RemoteAddr()) + log.Debugln("client: LocalAddr: ", conn.LocalAddr()) + rpcClient := rpc.NewClient(conn) + res := new(ResultComplex) + + if err := rpcClient.Call("FooComplex.Bar", args, &res); err != nil { + log.Errorf("failed to call RPC %v", err) + return nil, err + } + return res, err +} + func handleClient(conn net.Conn) { defer conn.Close() rpc.ServeConn(conn) - log.Println("server: conn: closed") + log.Debugln("server: conn: closed") } -func TestConn(t *testing.T) { +func TestSimpleRPC(t *testing.T) { l := server() Convey("get addr", t, func() { addr := l.Addr().String() @@ -119,38 +193,96 @@ func TestConn(t *testing.T) { So(ret, ShouldEqual, contentLength) So(err, ShouldBeNil) }) + Convey("server client OK 100", t, func() { + for i := 0; i < 100; i++ { + ret, err := client(pass) + So(ret, ShouldEqual, contentLength) + So(err, ShouldBeNil) + } + }) Convey("pass not match", t, func() { ret, err := client("1234") So(ret, ShouldEqual, 0) So(err, ShouldNotBeNil) }) + Convey("pass not match 100", t, func() { + for i := 0; i < 100; i++ { + ret, err := client("12345") + So(ret, ShouldEqual, 0) + So(err, ShouldNotBeNil) + } + }) + Convey("server close", t, func() { + err := l.Close() + So(err, ShouldBeNil) + }) +} + +func TestComplexRPC(t *testing.T) { + l := serverComplex() + Convey("get addr", t, func() { + addr := l.Addr().String() + So(addr, ShouldEqual, serviceComplex) + }) + Convey("server client OK", t, func() { + argsComplex := &QueryComplex{ + DataS: struct{ Strs []string }{Strs: []string{"aaa", "bbbbbbb"}}, + } + ret, err := clientComplex(pass, argsComplex) + So(ret.Count, ShouldEqual, len(argsComplex.DataS.Strs)) + So(ret.Hash.Str, ShouldResemble, qHash(argsComplex)) + So(err, ShouldBeNil) + }) + Convey("server client pass error", t, func() { + argsComplex := &QueryComplex{ + DataS: struct{ Strs []string }{Strs: []string{"aaa", "bbbbbbb"}}, + } + ret, err := clientComplex(pass+"1", argsComplex) + So(ret, ShouldBeNil) + So(err, ShouldNotBeNil) + }) + + Convey("server client pass error", t, func() { + argsComplex := &QueryComplex{ + DataS: struct{ Strs []string }{Strs: []string{"aaa", "bbbbbbb"}}, + } + ret, err := clientComplex(strings.Repeat(pass, 100), argsComplex) + So(ret, ShouldBeNil) + So(err, ShouldNotBeNil) + }) + Convey("server close", t, func() { err := l.Close() So(err, ShouldBeNil) }) } -func TestCryptoConn_RawRead(t *testing.T) { +func TestCryptoConn_RW(t *testing.T) { + cipher := NewCipher([]byte(pass)) var nilCipherHandler CipherHandler = func(conn net.Conn) (cryptoConn *CryptoConn, err error) { - cryptoConn = NewConn(conn, nil, nil) + cryptoConn = NewConn(conn, cipher, nil) return } - Convey("server client OK", t, func() { + Convey("server client OK", t, func(c C) { + msg := "xxxxxxxxxxxxxxxx" l, _ := NewCryptoListener("tcp", "127.0.0.1:0", nilCipherHandler) + wg := sync.WaitGroup{} + wg.Add(1) go func() { - rBuf := make([]byte, 1) - c, err := l.Accept() - cc, _ := c.(*CryptoConn) - _, err = cc.RawRead(rBuf) - log.Errorf("RawRead: %s", err) - So(rBuf[0], ShouldEqual, 'x') - So(err, ShouldBeNil) - }() - conn, _ := Dial("tcp", l.Addr().String(), nil) - go func() { - n, err := conn.RawWrite([]byte("xxxxxxxxxxxxxxxx")) - log.Errorf("RawWrite: %d %s", n, err) + rBuf := make([]byte, len(msg)) + conn, err := l.Accept() + cc, _ := conn.(*CryptoConn) + n, err := cc.Read(rBuf) + c.So(n, ShouldEqual, len(msg)) + c.So(string(rBuf), ShouldResemble, msg) + c.So(err, ShouldBeNil) + wg.Done() }() + conn, _ := Dial("tcp", l.Addr().String(), cipher) + n, err := conn.Write([]byte(msg)) + So(n, ShouldEqual, len(msg)) + So(err, ShouldBeNil) + wg.Wait() }) }