Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions crypto/etls/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ func Dial(network, address string, cipher *Cipher) (c *CryptoConn, err error) {
return
}

// RawRead is the raw net.Conn.Read
func (c *CryptoConn) RawRead(b []byte) (n int, err error) {
return c.Conn.Read(b)
}

// Read iv and Encrypted data
func (c *CryptoConn) Read(b []byte) (n int, err error) {
if c.decStream == nil {
Expand Down Expand Up @@ -104,11 +99,6 @@ func (c *CryptoConn) Read(b []byte) (n int, err error) {
return
}

// RawWrite is the raw net.Conn.Write
func (c *CryptoConn) RawWrite(b []byte) (n int, err error) {
return c.Conn.Read(b)
}

// Write iv and Encrypted data
func (c *CryptoConn) Write(b []byte) (n int, err error) {
var iv []byte
Expand Down
198 changes: 165 additions & 33 deletions crypto/etls/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,25 @@ import (
"net"
"net/rpc"
"strings"
"sync"
"testing"

"github.com/CovenantSQL/CovenantSQL/crypto/hash"
"github.com/CovenantSQL/CovenantSQL/utils/log"
. "github.com/smartystreets/goconvey/convey"
)

const service = "127.0.0.1:28000"
const serviceComplex = "127.0.0.1:28001"
const contentLength = 9999
const pass = "123"

var simpleCipherHandler CipherHandler = func(conn net.Conn) (cryptoConn *CryptoConn, err error) {
cipher := NewCipher([]byte(pass))
cryptoConn = NewConn(conn, cipher, nil)
return
}

type Foo bool

type Result struct {
Expand All @@ -34,19 +47,34 @@ type Result struct {

func (f *Foo) Bar(args *string, res *Result) error {
res.Data = len(*args)
log.Printf("Received %q, its length is %d", *args, res.Data)
//return fmt.Error("whoops, error happened")
log.Debugf("Received %q, its length is %d", *args, res.Data)
return nil
}

const service = "127.0.0.1:28000"
const contentLength = 9999
const pass = "123"
type FooComplex bool

var simpleCipherHandler CipherHandler = func(conn net.Conn) (cryptoConn *CryptoConn, err error) {
cipher := NewCipher([]byte(pass))
cryptoConn = NewConn(conn, cipher, nil)
return
type QueryComplex struct {
DataS struct {
Strs []string
}
}

type ResultComplex struct {
Count int
Hash struct {
Str string
}
}

func qHash(q *QueryComplex) string {
return string(hash.THashB([]byte(strings.Join(q.DataS.Strs, ""))))
}

func (f *FooComplex) Bar(args *QueryComplex, res *ResultComplex) error {
res.Count = len(args.DataS.Strs)
res.Hash.Str = qHash(args)
log.Debugf("Received %v", *args)
return nil
}

func server() *CryptoListener {
Expand All @@ -59,15 +87,15 @@ func server() *CryptoListener {
if err != nil {
log.Errorf("server: listen: %s", err)
}
log.Print("server: listening")
log.Debug("server: listening")
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("server: accept: %s", err)
log.Debugf("server: accept: %s", err)
break
}
log.Printf("server: accepted from %s", conn.RemoteAddr())
log.Debugf("server: accepted from %s", conn.RemoteAddr())
go handleClient(conn)
}
}()
Expand All @@ -89,26 +117,72 @@ func client(pass string) (ret int, err error) {
//conn.SetReadDeadline(time.Time{})
//conn.SetWriteDeadline(time.Time{})

log.Println("client: connected to: ", conn.RemoteAddr())
log.Println("client: LocalAddr: ", conn.LocalAddr())
log.Debugln("client: connected to: ", conn.RemoteAddr())
log.Debugln("client: LocalAddr: ", conn.LocalAddr())
rpcClient := rpc.NewClient(conn)
res := new(Result)
args := strings.Repeat("a", contentLength)
if err := rpcClient.Call("Foo.Bar", args, &res); err != nil {
log.Error("failed to call RPC", err)
log.Errorf("failed to call RPC %v", err)
return 0, err
}
log.Printf("Returned result is %d", res.Data)
log.Debugf("Returned result is %d", res.Data)
return res.Data, err
}

func serverComplex() *CryptoListener {
if err := rpc.Register(new(FooComplex)); err != nil {
log.Error("failed to register RPC method")
}

listener, err := NewCryptoListener("tcp", serviceComplex, simpleCipherHandler)
if err != nil {
log.Errorf("server: listen: %s", err)
}
log.Debug("server: listening")
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Debugf("server: accept: %s", err)
break
}
log.Debugf("server: accepted from %s", conn.RemoteAddr())
go handleClient(conn)
}
}()
return listener
}

func clientComplex(pass string, args *QueryComplex) (ret *ResultComplex, err error) {
cipher := NewCipher([]byte(pass))

conn, err := Dial("tcp", serviceComplex, cipher)
if err != nil {
log.Errorf("client: dial: %s", err)
return nil, err
}
defer conn.Close()

log.Debugln("client: connected to: ", conn.RemoteAddr())
log.Debugln("client: LocalAddr: ", conn.LocalAddr())
rpcClient := rpc.NewClient(conn)
res := new(ResultComplex)

if err := rpcClient.Call("FooComplex.Bar", args, &res); err != nil {
log.Errorf("failed to call RPC %v", err)
return nil, err
}
return res, err
}

func handleClient(conn net.Conn) {
defer conn.Close()
rpc.ServeConn(conn)
log.Println("server: conn: closed")
log.Debugln("server: conn: closed")
}

func TestConn(t *testing.T) {
func TestSimpleRPC(t *testing.T) {
l := server()
Convey("get addr", t, func() {
addr := l.Addr().String()
Expand All @@ -119,38 +193,96 @@ func TestConn(t *testing.T) {
So(ret, ShouldEqual, contentLength)
So(err, ShouldBeNil)
})
Convey("server client OK 100", t, func() {
for i := 0; i < 100; i++ {
ret, err := client(pass)
So(ret, ShouldEqual, contentLength)
So(err, ShouldBeNil)
}
})
Convey("pass not match", t, func() {
ret, err := client("1234")
So(ret, ShouldEqual, 0)
So(err, ShouldNotBeNil)
})
Convey("pass not match 100", t, func() {
for i := 0; i < 100; i++ {
ret, err := client("12345")
So(ret, ShouldEqual, 0)
So(err, ShouldNotBeNil)
}
})
Convey("server close", t, func() {
err := l.Close()
So(err, ShouldBeNil)
})
}

func TestComplexRPC(t *testing.T) {
l := serverComplex()
Convey("get addr", t, func() {
addr := l.Addr().String()
So(addr, ShouldEqual, serviceComplex)
})
Convey("server client OK", t, func() {
argsComplex := &QueryComplex{
DataS: struct{ Strs []string }{Strs: []string{"aaa", "bbbbbbb"}},
}
ret, err := clientComplex(pass, argsComplex)
So(ret.Count, ShouldEqual, len(argsComplex.DataS.Strs))
So(ret.Hash.Str, ShouldResemble, qHash(argsComplex))
So(err, ShouldBeNil)
})
Convey("server client pass error", t, func() {
argsComplex := &QueryComplex{
DataS: struct{ Strs []string }{Strs: []string{"aaa", "bbbbbbb"}},
}
ret, err := clientComplex(pass+"1", argsComplex)
So(ret, ShouldBeNil)
So(err, ShouldNotBeNil)
})

Convey("server client pass error", t, func() {
argsComplex := &QueryComplex{
DataS: struct{ Strs []string }{Strs: []string{"aaa", "bbbbbbb"}},
}
ret, err := clientComplex(strings.Repeat(pass, 100), argsComplex)
So(ret, ShouldBeNil)
So(err, ShouldNotBeNil)
})

Convey("server close", t, func() {
err := l.Close()
So(err, ShouldBeNil)
})
}

func TestCryptoConn_RawRead(t *testing.T) {
func TestCryptoConn_RW(t *testing.T) {
cipher := NewCipher([]byte(pass))
var nilCipherHandler CipherHandler = func(conn net.Conn) (cryptoConn *CryptoConn, err error) {
cryptoConn = NewConn(conn, nil, nil)
cryptoConn = NewConn(conn, cipher, nil)
return
}

Convey("server client OK", t, func() {
Convey("server client OK", t, func(c C) {
msg := "xxxxxxxxxxxxxxxx"
l, _ := NewCryptoListener("tcp", "127.0.0.1:0", nilCipherHandler)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
rBuf := make([]byte, 1)
c, err := l.Accept()
cc, _ := c.(*CryptoConn)
_, err = cc.RawRead(rBuf)
log.Errorf("RawRead: %s", err)
So(rBuf[0], ShouldEqual, 'x')
So(err, ShouldBeNil)
}()
conn, _ := Dial("tcp", l.Addr().String(), nil)
go func() {
n, err := conn.RawWrite([]byte("xxxxxxxxxxxxxxxx"))
log.Errorf("RawWrite: %d %s", n, err)
rBuf := make([]byte, len(msg))
conn, err := l.Accept()
cc, _ := conn.(*CryptoConn)
n, err := cc.Read(rBuf)
c.So(n, ShouldEqual, len(msg))
c.So(string(rBuf), ShouldResemble, msg)
c.So(err, ShouldBeNil)
wg.Done()
}()
conn, _ := Dial("tcp", l.Addr().String(), cipher)
n, err := conn.Write([]byte(msg))
So(n, ShouldEqual, len(msg))
So(err, ShouldBeNil)
wg.Wait()
})
}