Skip to content

Commit dbcf49f

Browse files
authored
Merge pull request gopherdata#31 from gopherds/refactor
Refactor
2 parents ef837ae + 3975935 commit dbcf49f

File tree

2 files changed

+87
-25
lines changed

2 files changed

+87
-25
lines changed

gophernotes.go

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010

1111
zmq "github.com/alecthomas/gozmq"
12+
"github.com/pkg/errors"
1213
)
1314

1415
var logger *log.Logger
@@ -37,17 +38,16 @@ type SocketGroup struct {
3738
}
3839

3940
// PrepareSockets sets up the ZMQ sockets through which the kernel will communicate.
40-
func PrepareSockets(connInfo ConnectionInfo) (sg SocketGroup) {
41+
func PrepareSockets(connInfo ConnectionInfo) (SocketGroup, error) {
4142

42-
// TODO handle errors.
43-
context, _ := zmq.NewContext()
44-
sg.ShellSocket, _ = context.NewSocket(zmq.ROUTER)
45-
sg.ControlSocket, _ = context.NewSocket(zmq.ROUTER)
46-
sg.StdinSocket, _ = context.NewSocket(zmq.ROUTER)
47-
sg.IOPubSocket, _ = context.NewSocket(zmq.PUB)
43+
// Initialize the Socket Group.
44+
context, sg, err := createSockets()
45+
if err != nil {
46+
return sg, errors.Wrap(err, "Could not initialize context and Socket Group")
47+
}
4848

49+
// Bind the sockets.
4950
address := fmt.Sprintf("%v://%v:%%v", connInfo.Transport, connInfo.IP)
50-
5151
sg.ShellSocket.Bind(fmt.Sprintf(address, connInfo.ShellPort))
5252
sg.ControlSocket.Bind(fmt.Sprintf(address, connInfo.ControlPort))
5353
sg.StdinSocket.Bind(fmt.Sprintf(address, connInfo.StdinPort))
@@ -57,11 +57,46 @@ func PrepareSockets(connInfo ConnectionInfo) (sg SocketGroup) {
5757
sg.Key = []byte(connInfo.Key)
5858

5959
// Start the heartbeat device
60-
HBSocket, _ := context.NewSocket(zmq.REP)
60+
HBSocket, err := context.NewSocket(zmq.REP)
61+
if err != nil {
62+
return sg, errors.Wrap(err, "Could not get the Heartbeat device socket")
63+
}
6164
HBSocket.Bind(fmt.Sprintf(address, connInfo.HBPort))
6265
go zmq.Device(zmq.FORWARDER, HBSocket, HBSocket)
6366

64-
return
67+
return sg, nil
68+
}
69+
70+
// createSockets initializes the sockets for the socket group based on values from zmq.
71+
func createSockets() (*zmq.Context, SocketGroup, error) {
72+
73+
context, err := zmq.NewContext()
74+
if err != nil {
75+
return context, SocketGroup{}, errors.Wrap(err, "Could not create zmq Context")
76+
}
77+
78+
var sg SocketGroup
79+
sg.ShellSocket, err = context.NewSocket(zmq.ROUTER)
80+
if err != nil {
81+
return context, sg, errors.Wrap(err, "Could not get Shell Socket")
82+
}
83+
84+
sg.ControlSocket, err = context.NewSocket(zmq.ROUTER)
85+
if err != nil {
86+
return context, sg, errors.Wrap(err, "Could not get Control Socket")
87+
}
88+
89+
sg.StdinSocket, err = context.NewSocket(zmq.ROUTER)
90+
if err != nil {
91+
return context, sg, errors.Wrap(err, "Could not get Stdin Socket")
92+
}
93+
94+
sg.IOPubSocket, err = context.NewSocket(zmq.PUB)
95+
if err != nil {
96+
return context, sg, errors.Wrap(err, "Could not get IOPub Socket")
97+
}
98+
99+
return context, sg, nil
65100
}
66101

67102
// HandleShellMsg responds to a message on the shell ROUTER socket.
@@ -126,14 +161,16 @@ func RunKernel(connectionFile string, logwriter io.Writer) {
126161
if err != nil {
127162
log.Fatalln(err)
128163
}
129-
err = json.Unmarshal(bs, &connInfo)
130-
if err != nil {
164+
if err = json.Unmarshal(bs, &connInfo); err != nil {
131165
log.Fatalln(err)
132166
}
133167
logger.Printf("%+v\n", connInfo)
134168

135169
// Set up the ZMQ sockets through which the kernel will communicate
136-
sockets := PrepareSockets(connInfo)
170+
sockets, err := PrepareSockets(connInfo)
171+
if err != nil {
172+
log.Fatalln(err)
173+
}
137174

138175
pi := zmq.PollItems{
139176
zmq.PollItem{Socket: sockets.ShellSocket, Events: zmq.POLLIN},
@@ -144,8 +181,7 @@ func RunKernel(connectionFile string, logwriter io.Writer) {
144181
var msgparts [][]byte
145182
// Message receiving loop:
146183
for {
147-
_, err = zmq.Poll(pi, -1)
148-
if err != nil {
184+
if _, err = zmq.Poll(pi, -1); err != nil {
149185
log.Fatalln(err)
150186
}
151187
switch {

messages.go

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ import (
55
"crypto/sha256"
66
"encoding/hex"
77
"encoding/json"
8+
"log"
89

910
zmq "github.com/alecthomas/gozmq"
1011
uuid "github.com/nu7hatch/gouuid"
12+
"github.com/pkg/errors"
1113
)
1214

1315
// MsgHeader encodes header info for ZMQ messages
@@ -44,7 +46,6 @@ func WireMsgToComposedMsg(msgparts [][]byte, signkey []byte) (msg ComposedMsg,
4446
i++
4547
}
4648
identities = msgparts[:i]
47-
// msgparts[i] is the delimiter
4849

4950
// Validate signature
5051
if len(signkey) != 0 {
@@ -67,18 +68,34 @@ func WireMsgToComposedMsg(msgparts [][]byte, signkey []byte) (msg ComposedMsg,
6768

6869
// ToWireMsg translates a ComposedMsg into a multipart ZMQ message ready to send, and
6970
// signs it. This does not add the return identities or the delimiter.
70-
func (msg ComposedMsg) ToWireMsg(signkey []byte) (msgparts [][]byte) {
71-
msgparts = make([][]byte, 5)
72-
header, _ := json.Marshal(msg.Header)
71+
func (msg ComposedMsg) ToWireMsg(signkey []byte) ([][]byte, error) {
72+
73+
msgparts := make([][]byte, 5)
74+
header, err := json.Marshal(msg.Header)
75+
if err != nil {
76+
return msgparts, errors.Wrap(err, "Could not marshal message header")
77+
}
7378
msgparts[1] = header
74-
parentHeader, _ := json.Marshal(msg.ParentHeader)
79+
80+
parentHeader, err := json.Marshal(msg.ParentHeader)
81+
if err != nil {
82+
return msgparts, errors.Wrap(err, "Could not marshal parent header")
83+
}
7584
msgparts[2] = parentHeader
85+
7686
if msg.Metadata == nil {
7787
msg.Metadata = make(map[string]interface{})
7888
}
79-
metadata, _ := json.Marshal(msg.Metadata)
89+
metadata, err := json.Marshal(msg.Metadata)
90+
if err != nil {
91+
return msgparts, errors.Wrap(err, "Could not marshal metadata")
92+
}
8093
msgparts[3] = metadata
81-
content, _ := json.Marshal(msg.Content)
94+
95+
content, err := json.Marshal(msg.Content)
96+
if err != nil {
97+
return msgparts, errors.Wrap(err, "Could not marshal content")
98+
}
8299
msgparts[4] = content
83100

84101
// Sign the message
@@ -90,7 +107,7 @@ func (msg ComposedMsg) ToWireMsg(signkey []byte) (msgparts [][]byte) {
90107
msgparts[0] = make([]byte, hex.EncodedLen(mac.Size()))
91108
hex.Encode(msgparts[0], mac.Sum(nil))
92109
}
93-
return
110+
return msgparts, nil
94111
}
95112

96113
// MsgReceipt represents a received message, its return identities, and the sockets for
@@ -103,9 +120,15 @@ type MsgReceipt struct {
103120

104121
// SendResponse sends a message back to return identites of the received message.
105122
func (receipt *MsgReceipt) SendResponse(socket *zmq.Socket, msg ComposedMsg) {
123+
106124
socket.SendMultipart(receipt.Identities, zmq.SNDMORE)
107125
socket.Send([]byte("<IDS|MSG>"), zmq.SNDMORE)
108-
socket.SendMultipart(msg.ToWireMsg(receipt.Sockets.Key), 0)
126+
127+
msgParts, err := msg.ToWireMsg(receipt.Sockets.Key)
128+
if err != nil {
129+
log.Fatalln(err)
130+
}
131+
socket.SendMultipart(msgParts, 0)
109132
logger.Println("<--", msg.Header.MsgType)
110133
logger.Printf("%+v\n", msg.Content)
111134
}
@@ -117,7 +140,10 @@ func NewMsg(msgType string, parent ComposedMsg) (msg ComposedMsg) {
117140
msg.Header.Session = parent.Header.Session
118141
msg.Header.Username = parent.Header.Username
119142
msg.Header.MsgType = msgType
120-
u, _ := uuid.NewV4()
143+
u, err := uuid.NewV4()
144+
if err != nil {
145+
log.Fatalln(errors.Wrap(err, "Could not generate UUID"))
146+
}
121147
msg.Header.MsgID = u.String()
122148
return
123149
}

0 commit comments

Comments
 (0)