diff --git a/Gopkg.lock b/Gopkg.lock
index dca93369b..f938ce406 100644
--- a/Gopkg.lock
+++ b/Gopkg.lock
@@ -559,12 +559,12 @@
   source = "github.com/CovenantSQL/usql"
 
 [[projects]]
-  digest = "1:07aaea55c778d9f0461a429eb19a134f17e9ac3232a5a28b16dfe57c637889bc"
+  branch = "master"
+  digest = "1:fd9ee7072a8121eb2a1592611023924f00492432ef923e561110e4c5e380f285"
   name = "github.com/xtaci/smux"
   packages = ["."]
   pruneopts = "UT"
-  revision = "545ecee9d2a96ef4cf3c420c6b4095ac313fe870"
-  version = "v1.09"
+  revision = "6cf098d439391c8f8f6a485f8928f47575b6002e"
 
 [[projects]]
   digest = "1:4619abe2e9ceabced45ff40a4826866c48f264bb58384efe799a8fb83c2256e0"
diff --git a/Gopkg.toml b/Gopkg.toml
index 89678703c..0b2e9a1b8 100644
--- a/Gopkg.toml
+++ b/Gopkg.toml
@@ -61,6 +61,10 @@
     name = "github.com/CovenantSQL/xurls"
     branch = "master"
 
+[[override]]
+    name = "github.com/xtaci/smux"
+    branch = "master"
+
 [[override]]
     name = "github.com/siddontang/go-mysql"
     source = "github.com/CovenantSQL/go-mysql"
diff --git a/cmd/cql-minerd/integration_test.go b/cmd/cql-minerd/integration_test.go
index 8d296713e..eacb16bce 100644
--- a/cmd/cql-minerd/integration_test.go
+++ b/cmd/cql-minerd/integration_test.go
@@ -35,7 +35,10 @@ import (
     "time"
 
     "github.com/CovenantSQL/CovenantSQL/client"
+    "github.com/CovenantSQL/CovenantSQL/conf"
     "github.com/CovenantSQL/CovenantSQL/crypto/asymmetric"
+    "github.com/CovenantSQL/CovenantSQL/proto"
+    "github.com/CovenantSQL/CovenantSQL/rpc"
     "github.com/CovenantSQL/CovenantSQL/utils"
     "github.com/CovenantSQL/CovenantSQL/utils/log"
     "github.com/CovenantSQL/go-sqlite3-encrypt"
@@ -316,7 +319,7 @@ func stopNodes() {
 
 func TestFullProcess(t *testing.T) {
     log.SetLevel(log.DebugLevel)
-    Convey("test full process", t, func() {
+    Convey("test full process", t, func(c C) {
         startNodes()
         defer stopNodes()
         var err error
@@ -384,6 +387,68 @@ func TestFullProcess(t *testing.T) {
         So(err, ShouldBeNil)
         So(resultBytes, ShouldResemble, []byte("ha\001ppy"))
 
+        Convey("test query cancel", FailureContinues, func(c C) {
+            /* test cancel write query */
+            wg := sync.WaitGroup{}
+            wg.Add(1)
+            go func() {
+                defer wg.Done()
+                db.Exec("INSERT INTO test VALUES(sleep(10000000000))")
+            }()
+            time.Sleep(time.Second)
+            wg.Add(1)
+            go func() {
+                defer wg.Done()
+                var err error
+                _, err = db.Exec("UPDATE test SET test = 100;")
+                // should be canceled
+                c.So(err, ShouldNotBeNil)
+            }()
+            time.Sleep(time.Second)
+            for _, n := range conf.GConf.KnownNodes {
+                if n.Role == proto.Miner {
+                    rpc.GetSessionPoolInstance().Remove(n.ID)
+                }
+            }
+            time.Sleep(time.Second)
+
+            // ensure connection
+            db.Query("SELECT 1")
+
+            // test before write operation complete
+            var result int
+            err = db.QueryRow("SELECT * FROM test WHERE test = 4 LIMIT 1").Scan(&result)
+            c.So(err, ShouldBeNil)
+            c.So(result, ShouldEqual, 4)
+
+            wg.Wait()
+
+            /* test cancel read query */
+            go func() {
+                _, err = db.Query("SELECT * FROM test WHERE test = sleep(10000000000)")
+                // call write query using read query interface
+                //_, err = db.Query("INSERT INTO test VALUES(sleep(10000000000))")
+                c.So(err, ShouldNotBeNil)
+            }()
+            time.Sleep(time.Second)
+            for _, n := range conf.GConf.KnownNodes {
+                if n.Role == proto.Miner {
+                    rpc.GetSessionPoolInstance().Remove(n.ID)
+                }
+            }
+            time.Sleep(time.Second)
+            // ensure connection
+            db.Query("SELECT 1")
+
+            /* test long running write query */
+            row = db.QueryRow("SELECT * FROM test WHERE test = 10000000000 LIMIT 1")
+            err = row.Scan(&result)
+            c.So(err, ShouldBeNil)
+            c.So(result, ShouldEqual, 10000000000)
+
+            c.So(err, ShouldBeNil)
+        })
+
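The new "test query cancel" sub-test drives cancellation from the outside: a sleep() UDF (registered further down in xenomint/sqlite/sqlite.go) holds a write open, and removing the miner entries from the RPC session pool tears the connections down, which the rest of this patch translates into server-side context cancellation. For comparison, the usual client-side way to bound a query is a deadline context. A minimal, driver-agnostic sketch — the DSN, the driver registration, and the blocking query are placeholders, not repository code:

package main

import (
    "context"
    "database/sql"
    "fmt"
    "time"

    _ "github.com/CovenantSQL/CovenantSQL/client" // assumed driver registration
)

func main() {
    db, err := sql.Open("covenantsql", "covenantsql://placeholder-dsn")
    if err != nil {
        fmt.Println("open:", err)
        return
    }
    defer db.Close()

    // Bound the query with a one-second deadline; a context-aware driver
    // returns an error once the deadline passes instead of blocking forever.
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    _, err = db.QueryContext(ctx, "SELECT * FROM test WHERE test = sleep(10000000000)")
    fmt.Println("query:", err) // expect a context deadline error
}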
         err = db.Close()
         So(err, ShouldBeNil)
diff --git a/kayak/runtime.go b/kayak/runtime.go
index 8c763ceda..1858a803f 100644
--- a/kayak/runtime.go
+++ b/kayak/runtime.go
@@ -35,7 +35,7 @@ import (
 
 const (
     // commit channel window size
-    commitWindow = 10
+    commitWindow = 0
     // prepare window
     trackerWindow = 10
 )
@@ -246,6 +246,7 @@ func (r *Runtime) Shutdown() (err error) {
 
 // Apply defines entry for Leader node.
 func (r *Runtime) Apply(ctx context.Context, req interface{}) (result interface{}, logIndex uint64, err error) {
     var commitFuture <-chan *commitResult
+    var cResult *commitResult
     var tmStart, tmLeaderPrepare, tmFollowerPrepare, tmCommitEnqueue,
         tmLeaderRollback, tmRollback, tmCommitDequeue, tmLeaderCommit, tmCommit time.Time
@@ -350,37 +351,36 @@ func (r *Runtime) Apply(ctx context.Context, req interface{}) (result interface{
     tmCommitEnqueue = time.Now()
 
-    select {
-    case cResult := <-commitFuture:
-        if cResult != nil {
-            logIndex = prepareLog.Index
-            result = cResult.result
-            err = cResult.err
-
-            tmCommitDequeue = cResult.start
-            dbCost = cResult.dbCost
-            tmLeaderCommit = time.Now()
-
-            // wait until context deadline or commit done
-            if cResult.rpc != nil {
-                cResult.rpc.get(ctx)
-            }
-        } else {
-            log.Fatal("IMPOSSIBLE BRANCH")
-            select {
-            case <-ctx.Done():
-                err = errors.Wrap(ctx.Err(), "process commit timeout")
-                goto ROLLBACK
-            default:
-            }
-        }
-    case <-ctx.Done():
-        // pipeline commit timeout
+    if commitFuture == nil {
         logIndex = prepareLog.Index
         err = errors.Wrap(ctx.Err(), "enqueue commit timeout")
         goto ROLLBACK
     }
+    cResult = <-commitFuture
+    if cResult != nil {
+        logIndex = prepareLog.Index
+        result = cResult.result
+        err = cResult.err
+
+        tmCommitDequeue = cResult.start
+        dbCost = cResult.dbCost
+        tmLeaderCommit = time.Now()
+
+        // wait until context deadline or commit done
+        if cResult.rpc != nil {
+            cResult.rpc.get(ctx)
+        }
+    } else {
+        log.Fatal("IMPOSSIBLE BRANCH")
+        select {
+        case <-ctx.Done():
+            err = errors.Wrap(ctx.Err(), "process commit timeout")
+            goto ROLLBACK
+        default:
+        }
+    }
+
     tmCommit = time.Now()
 
     return
@@ -572,6 +572,7 @@ func (r *Runtime) leaderCommitResult(ctx context.Context, reqPayload interface{}
 
     select {
     case <-ctx.Done():
+        res = nil
     case r.commitCh <- req:
     }
diff --git a/proto/proto.go b/proto/proto.go
index 2f5a812f8..1e04c2999 100644
--- a/proto/proto.go
+++ b/proto/proto.go
@@ -18,6 +18,7 @@
 package proto
 
 import (
+    "context"
     "time"
 )
@@ -30,19 +31,22 @@ type EnvelopeAPI interface {
     GetTTL() time.Duration
     GetExpire() time.Duration
     GetNodeID() *RawNodeID
+    GetContext() context.Context
 
     SetVersion(string)
     SetTTL(time.Duration)
     SetExpire(time.Duration)
     SetNodeID(*RawNodeID)
+    SetContext(context.Context)
 }
 
 // Envelope is the protocol header
 type Envelope struct {
-    Version string        `json:"v"`
-    TTL     time.Duration `json:"t"`
-    Expire  time.Duration `json:"e"`
-    NodeID  *RawNodeID    `json:"id"`
+    Version string          `json:"v"`
+    TTL     time.Duration   `json:"t"`
+    Expire  time.Duration   `json:"e"`
+    NodeID  *RawNodeID      `json:"id"`
+    _ctx    context.Context `json:"-"`
 }
 
 // PingReq is Ping RPC request
@@ -120,6 +124,14 @@ func (e *Envelope) GetNodeID() *RawNodeID {
     return e.NodeID
 }
 
+// GetContext returns the context from the envelope, which is set in server Accept
+func (e *Envelope) GetContext() context.Context {
+    if e._ctx == nil {
+        return context.Background()
+    }
+    return e._ctx
+}
+
 // SetVersion implements EnvelopeAPI.SetVersion
 func (e *Envelope) SetVersion(ver string) {
     e.Version = ver
@@ -140,5 +152,10 @@ func (e *Envelope) SetNodeID(nodeID *RawNodeID) {
     e.NodeID = nodeID
 }
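Envelope now carries a non-serialized context (_ctx is tagged json:"-" and never crosses the wire), and GetContext falls back to context.Background() so untouched call sites keep working without injection. A hypothetical handler shows the intended consumption pattern; EchoService, EchoReq and EchoResp are illustrative types, not part of this patch:

package example

import (
    "time"

    "github.com/CovenantSQL/CovenantSQL/proto"
)

// EchoReq embeds proto.Envelope, so NodeAwareServerCodec can inject the
// connection-scoped context into it via SetContext.
type EchoReq struct {
    proto.Envelope
    Msg string
}

type EchoResp struct {
    proto.Envelope
    Msg string
}

type EchoService struct{}

// Echo reads the context injected in ReadRequestBody and abandons its work
// when the underlying connection dies.
func (s *EchoService) Echo(req *EchoReq, resp *EchoResp) error {
    ctx := req.GetContext() // context.Background() if nothing was injected
    select {
    case <-ctx.Done():
        return ctx.Err() // connection is gone; stop instead of leaking work
    case <-time.After(10 * time.Millisecond): // stand-in for real work
        resp.Msg = req.Msg
        return nil
    }
}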
+// SetContext sets a ctx in the envelope
+func (e *Envelope) SetContext(ctx context.Context) {
+    e._ctx = ctx
+}
+
 // DatabaseID is database name, will be generated from UUID
 type DatabaseID string
diff --git a/proto/proto_test.go b/proto/proto_test.go
index 82428d995..50db0d90e 100644
--- a/proto/proto_test.go
+++ b/proto/proto_test.go
@@ -17,6 +17,7 @@
 package proto
 
 import (
+    "context"
    "testing"
    "time"
@@ -41,5 +42,11 @@ func TestEnvelope_GetSet(t *testing.T) {
 
        env.SetVersion("0.0.1")
        So(env.GetVersion(), ShouldEqual, "0.0.1")
+
+       ctx := env.GetContext()
+       So(ctx, ShouldEqual, context.Background())
+       cldCtx, _ := context.WithCancel(ctx)
+       env.SetContext(cldCtx)
+       So(env.GetContext(), ShouldEqual, cldCtx)
    })
 }
diff --git a/rpc/codec.go b/rpc/codec.go
index 63b8f53cc..98f38ce33 100644
--- a/rpc/codec.go
+++ b/rpc/codec.go
@@ -17,6 +17,7 @@
 package rpc
 
 import (
+    "context"
    "net/rpc"
 
    "github.com/CovenantSQL/CovenantSQL/proto"
@@ -26,13 +27,15 @@
 type NodeAwareServerCodec struct {
    rpc.ServerCodec
    NodeID *proto.RawNodeID
+   Ctx    context.Context
 }
 
 // NewNodeAwareServerCodec returns new NodeAwareServerCodec with normal rpc.ServerCode and proto.RawNodeID
-func NewNodeAwareServerCodec(codec rpc.ServerCodec, nodeID *proto.RawNodeID) *NodeAwareServerCodec {
+func NewNodeAwareServerCodec(ctx context.Context, codec rpc.ServerCodec, nodeID *proto.RawNodeID) *NodeAwareServerCodec {
    return &NodeAwareServerCodec{
        ServerCodec: codec,
        NodeID:      nodeID,
+       Ctx:         ctx,
    }
 }
@@ -51,6 +54,8 @@ func (nc *NodeAwareServerCodec) ReadRequestBody(body interface{}) (err error) {
    if r, ok := body.(proto.EnvelopeAPI); ok {
        // inject node id to rpc envelope
        r.SetNodeID(nc.NodeID)
+       // inject context
+       r.SetContext(nc.Ctx)
    }
 
    return
diff --git a/rpc/rpcutil.go b/rpc/rpcutil.go
index d39b12c86..d439dee95 100644
--- a/rpc/rpcutil.go
+++ b/rpc/rpcutil.go
@@ -23,6 +23,7 @@ import (
    "math/rand"
    "net"
    "net/rpc"
+   "strings"
    "sync"
 
    "github.com/CovenantSQL/CovenantSQL/crypto/kms"
@@ -65,6 +66,7 @@ func (c *PersistentCaller) initClient(isAnonymous bool) (err error) {
    c.Lock()
    defer c.Unlock()
    if c.client == nil {
+       log.Debug("init new rpc client")
        var conn net.Conn
        conn, err = DialToNode(c.TargetID, c.pool, isAnonymous)
        if err != nil {
@@ -93,11 +95,13 @@ func (c *PersistentCaller) Call(method string, args interface{}, reply interface
        if err == io.EOF ||
            err == io.ErrUnexpectedEOF ||
            err == io.ErrClosedPipe ||
-           err == rpc.ErrShutdown {
+           err == rpc.ErrShutdown ||
+           strings.Contains(strings.ToLower(err.Error()), "shut down") ||
+           strings.Contains(strings.ToLower(err.Error()), "broken pipe") {
            // if got EOF, retry once
-           err = c.Reconnect(method)
-           if err != nil {
-               log.WithField("rpc", method).WithError(err).Error("reconnect failed")
+           reconnectErr := c.ResetClient(method)
+           if reconnectErr != nil {
+               log.WithField("rpc", method).WithError(reconnectErr).Error("reconnect failed")
            }
        }
        log.WithField("rpc", method).WithError(err).Error("call RPC failed")
@@ -105,17 +109,12 @@ func (c *PersistentCaller) Call(method string, args interface{}, reply interface
    return
 }
 
-// Reconnect tries to rebuild RPC client
-func (c *PersistentCaller) Reconnect(method string) (err error) {
+// ResetClient resets the client, dropping any open connection.
+func (c *PersistentCaller) ResetClient(method string) (err error) {
    c.Lock()
    c.Close()
    c.client = nil
    c.Unlock()
-   err = c.initClient(method == route.DHTPing.String())
-   if err != nil {
-       log.WithField("rpc", method).WithError(err).Error("second init client for RPC failed")
-       return
-   }
    return
 }
diff --git a/rpc/server.go b/rpc/server.go
index 159ad76a0..3a56db669 100644
--- a/rpc/server.go
+++ b/rpc/server.go
@@ -17,6 +17,7 @@
 package rpc
 
 import (
+    "context"
    "io"
    "net"
    "net/rpc"
@@ -149,7 +150,12 @@ sessionLoop:
            }
            break sessionLoop
        }
-       nodeAwareCodec := NewNodeAwareServerCodec(utils.GetMsgPackServerCodec(muxConn), remoteNodeID)
+       ctx, cancelFunc := context.WithCancel(context.Background())
+       go func() {
+           <-muxConn.GetDieCh()
+           cancelFunc()
+       }()
+       nodeAwareCodec := NewNodeAwareServerCodec(ctx, utils.GetMsgPackServerCodec(muxConn), remoteNodeID)
        go s.rpcServer.ServeCodec(nodeAwareCodec)
    }
 }
diff --git a/sqlchain/chain.go b/sqlchain/chain.go
index a95f81c45..326619662 100644
--- a/sqlchain/chain.go
+++ b/sqlchain/chain.go
@@ -18,6 +18,7 @@
 package sqlchain
 
 import (
    "bytes"
+   "context"
    "encoding/binary"
    "fmt"
    "os"
@@ -112,8 +113,8 @@ type Chain struct {
    st *x.State
    cl *rpc.Caller
    rt *runtime
+   ctx context.Context // ctx is the root context of Chain
 
-   stopCh    chan struct{}
    blocks    chan *types.Block
    heights   chan int32
    responses chan *types.ResponseHeader
@@ -132,8 +133,6 @@ type Chain struct {
    observerReplicators map[proto.NodeID]*observerReplicator
    // replCh defines the replication trigger channel for replication check.
    replCh chan struct{}
-   // replWg defines the waitGroups for running replications.
-   replWg sync.WaitGroup
 
    // Cached fileds, may need to renew some of this fields later.
    //
@@ -143,6 +142,11 @@
 
 // NewChain creates a new sql-chain struct.
 func NewChain(c *Config) (chain *Chain, err error) {
+   return NewChainWithContext(context.Background(), c)
+}
+
+// NewChainWithContext creates a new sql-chain struct with context.
+func NewChainWithContext(ctx context.Context, c *Config) (chain *Chain, err error) {
    // TODO(leventeliu): this is a rough solution, you may also want to clean database file and
    // force rebuilding.
    var fi os.FileInfo
@@ -202,8 +206,8 @@ func NewChain(c *Config) (chain *Chain, err error) {
        ai:        newAckIndex(),
        st:        state,
        cl:        rpc.NewCaller(),
-       rt:        newRunTime(c),
-       stopCh:    make(chan struct{}),
+       rt:        newRunTime(ctx, c),
+       ctx:       ctx,
        blocks:    make(chan *types.Block),
        heights:   make(chan int32, 1),
        responses: make(chan *types.ResponseHeader),
@@ -229,6 +233,12 @@
 
 // LoadChain loads the chain state from the specified database and rebuilds a memory index.
 func LoadChain(c *Config) (chain *Chain, err error) {
+   return LoadChainWithContext(context.Background(), c)
+}
+
+// LoadChainWithContext loads the chain state from the specified database and rebuilds
+// a memory index with context.
+func LoadChainWithContext(ctx context.Context, c *Config) (chain *Chain, err error) {
    // Open LevelDB for block and state
    bdbFile := c.ChainFilePrefix + "-block-state.ldb"
    bdb, err := leveldb.OpenFile(bdbFile, &leveldbConf)
@@ -272,8 +282,8 @@
        ai:        newAckIndex(),
        st:        xstate,
        cl:        rpc.NewCaller(),
-       rt:        newRunTime(c),
-       stopCh:    make(chan struct{}),
+       rt:        newRunTime(ctx, c),
+       ctx:       ctx,
        blocks:    make(chan *types.Block),
        heights:   make(chan int32, 1),
        responses: make(chan *types.ResponseHeader),
@@ -572,7 +582,12 @@ func (c *Chain) produceBlockV2(now time.Time) (err error) {
        return
    }
    // Send to pending list
-   c.blocks <- block
+   select {
+   case c.blocks <- block:
+   case <-c.rt.ctx.Done():
+       err = c.rt.ctx.Err()
+       return
+   }
    log.WithFields(log.Fields{
        "peer": c.rt.getPeerInfoString(),
        "time": c.rt.getChainTimeString(),
@@ -609,28 +624,28 @@ func (c *Chain) produceBlockV2(now time.Time) (err error) {
            go func(id proto.NodeID) {
                defer wg.Done()
                resp := &MuxAdviseNewBlockResp{}
-               if err := c.cl.CallNode(
-                   id, route.SQLCAdviseNewBlock.String(), req, resp); err != nil {
+               if err := c.cl.CallNodeWithContext(
+                   c.rt.ctx, id, route.SQLCAdviseNewBlock.String(), req, resp,
+               ); err != nil {
                    log.WithFields(log.Fields{
                        "peer":            c.rt.getPeerInfoString(),
                        "time":            c.rt.getChainTimeString(),
                        "curr_turn":       c.rt.getNextTurn(),
                        "using_timestamp": now.Format(time.RFC3339Nano),
                        "block_hash":      block.BlockHash().String(),
-                   }).WithError(err).Error(
-                       "Failed to advise new block")
+                   }).WithError(err).Error("Failed to advise new block")
                }
            }(s)
        }
    }
    wg.Wait()
    // fire replication to observers
-   c.startStopReplication()
+   c.startStopReplication(c.rt.ctx)
    return
 }
 
 func (c *Chain) syncHead() {
-   // Try to fetch if the the block of the current turn is not advised yet
+   // Try to fetch if the block of the current turn is not advised yet
    if h := c.rt.getNextTurn() - 1; c.rt.getHead().Height < h {
        var err error
        req := &MuxFetchBlockReq{
@@ -662,7 +677,12 @@ func (c *Chain) syncHead() {
                "Failed to fetch block from peer")
        } else {
            statBlock(resp.Block)
-           c.blocks <- resp.Block
+           select {
+           case c.blocks <- resp.Block:
+           case <-c.rt.ctx.Done():
+               err = c.rt.ctx.Err()
+               return
+           }
            log.WithFields(log.Fields{
                "peer": c.rt.getPeerInfoString(),
                "time": c.rt.getChainTimeString(),
@@ -739,16 +759,10 @@ func (c *Chain) runCurrentTurn(now time.Time) {
 }
 
 // mainCycle runs main cycle of the sql-chain.
-func (c *Chain) mainCycle() {
-   defer func() {
-       c.rt.wg.Done()
-       // Signal worker goroutines to stop
-       close(c.stopCh)
-   }()
-
+func (c *Chain) mainCycle(ctx context.Context) {
    for {
        select {
-       case <-c.rt.stopCh:
+       case <-ctx.Done():
            return
        default:
            c.syncHead()
@@ -795,24 +809,26 @@ func (c *Chain) sync() (err error) {
    return
 }
 
-func (c *Chain) processBlocks() {
-   rsCh := make(chan struct{})
-   rsWG := &sync.WaitGroup{}
+func (c *Chain) processBlocks(ctx context.Context) {
+   var (
+       cld, ccl = context.WithCancel(ctx)
+       wg       = &sync.WaitGroup{}
+   )
+
    returnStash := func(stash []*types.Block) {
-       defer rsWG.Done()
+       defer wg.Done()
        for _, block := range stash {
            select {
            case c.blocks <- block:
-           case <-rsCh:
+           case <-cld.Done():
                return
            }
        }
    }
 
    defer func() {
-       close(rsCh)
-       rsWG.Wait()
-       c.rt.wg.Done()
+       ccl()
+       wg.Wait()
    }()
 
    var stash []*types.Block
@@ -825,7 +841,7 @@
                "stashs": len(stash),
            }).Debug("Read new height from channel")
            if stash != nil {
-               rsWG.Add(1)
+               wg.Add(1)
                go returnStash(stash)
                stash = nil
            }
@@ -863,30 +879,8 @@
            }
        }
        // fire replication to observers
-       c.startStopReplication()
-       case <-c.stopCh:
-           return
-       }
-   }
-}
-
-func (c *Chain) processResponses() {
-   // TODO(leventeliu): implement that
-   defer c.rt.wg.Done()
-   for {
-       select {
-       case <-c.stopCh:
-           return
-       }
-   }
-}
-
-func (c *Chain) processAcks() {
-   // TODO(leventeliu): implement that
-   defer c.rt.wg.Done()
-   for {
-       select {
-       case <-c.stopCh:
+       c.startStopReplication(c.rt.ctx)
+       case <-ctx.Done():
            return
        }
    }
 }
@@ -898,16 +892,9 @@ func (c *Chain) Start() (err error) {
        return
    }
 
-   c.rt.wg.Add(1)
-   go c.processBlocks()
-   c.rt.wg.Add(1)
-   go c.processResponses()
-   c.rt.wg.Add(1)
-   go c.processAcks()
-   c.rt.wg.Add(1)
-   go c.mainCycle()
-   c.rt.wg.Add(1)
-   go c.replicationCycle()
+   c.rt.goFunc(c.processBlocks)
+   c.rt.goFunc(c.mainCycle)
+   c.rt.goFunc(c.replicationCycle)
    c.rt.startService(c)
    return
 }
@@ -923,7 +910,7 @@ func (c *Chain) Stop() (err error) {
    log.WithFields(log.Fields{
        "peer": c.rt.getPeerInfoString(),
        "time": c.rt.getChainTimeString(),
-   }).Debug("Chain service stopped")
+   }).Debug("Chain service and workers stopped")
    // Close LevelDB file
    var ierr error
    if ierr = c.bdb.Close(); ierr != nil && err == nil {
@@ -1040,7 +1027,7 @@ func (c *Chain) CheckAndPushNewBlock(block *types.Block) (err error) {
    // }
 
    // Replicate local state from the new block
-   if err = c.st.ReplayBlock(block); err != nil {
+   if err = c.st.ReplayBlockWithContext(c.rt.ctx, block); err != nil {
        return
    }
@@ -1159,10 +1146,9 @@ func (c *Chain) getBilling(low, high int32) (req *pt.BillingRequest, err error) {
    return
 }
 
-func (c *Chain) collectBillingSignatures(billings *pt.BillingRequest) {
-   defer c.rt.wg.Done()
+func (c *Chain) collectBillingSignatures(ctx context.Context, billings *pt.BillingRequest) {
    // Process sign billing responses, note that range iterating over channel will only break if
-   // the channle is closed
+   // the channel is closed
    req := &MuxSignBillingReq{
        Envelope: proto.Envelope{
            // TODO(leventeliu): Add fields.
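Start now hands goroutine lifecycle to the runtime's goFunc helper (defined in the sqlchain/runtime.go hunks below): every worker receives the runtime's root context and registers with a shared WaitGroup, so Stop reduces to cancel-then-wait. The same pattern, extracted into a standalone sketch rather than repository code:

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

type runtime struct {
    wg     sync.WaitGroup
    ctx    context.Context
    cancel context.CancelFunc
}

func newRuntime(parent context.Context) *runtime {
    ctx, cancel := context.WithCancel(parent)
    return &runtime{ctx: ctx, cancel: cancel}
}

// goFunc starts f with the root context and tracks it for shutdown.
func (r *runtime) goFunc(f func(context.Context)) {
    r.wg.Add(1)
    go func() {
        defer r.wg.Done()
        f(r.ctx)
    }()
}

// stop cancels every worker started via goFunc and waits for them to exit.
func (r *runtime) stop() {
    r.cancel()
    r.wg.Wait()
}

func main() {
    r := newRuntime(context.Background())
    r.goFunc(func(ctx context.Context) {
        for {
            select {
            case <-ctx.Done():
                fmt.Println("worker stopped:", ctx.Err())
                return
            case <-time.After(100 * time.Millisecond):
                fmt.Println("tick")
            }
        }
    })
    time.Sleep(300 * time.Millisecond)
    r.stop()
}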
@@ -1212,7 +1198,9 @@ func (c *Chain) collectBillingSignatures(ctx context.Context, billings *pt.Billi
        }
 
        var resp interface{}
-       if err = c.cl.CallNode(bpNodeID, route.MCCAdviseBillingRequest.String(), bpReq, resp); err != nil {
+       if err = c.cl.CallNodeWithContext(
+           ctx, bpNodeID, route.MCCAdviseBillingRequest.String(), bpReq, resp,
+       ); err != nil {
            return
        }
    }()
@@ -1232,12 +1220,13 @@
            defer rpcWG.Done()
            resp := &MuxSignBillingResp{}
 
-           if err := c.cl.CallNode(id, route.SQLCSignBilling.String(), req, resp); err != nil {
+           if err := c.cl.CallNodeWithContext(
+               ctx, id, route.SQLCSignBilling.String(), req, resp,
+           ); err != nil {
                log.WithFields(log.Fields{
                    "peer": c.rt.getPeerInfoString(),
                    "time": c.rt.getChainTimeString(),
-               }).WithError(err).Error(
-                   "Failed to send sign billing request")
+               }).WithError(err).Error("Failed to send sign billing request")
            }
 
            respC <- &resp.SignBillingResp
@@ -1268,8 +1257,7 @@ func (c *Chain) LaunchBilling(low, high int32) (err error) {
        return
    }
 
-   c.rt.wg.Add(1)
-   go c.collectBillingSignatures(req)
+   c.rt.goFunc(func(ctx context.Context) { c.collectBillingSignatures(ctx, req) })
    return
 }
@@ -1307,7 +1295,7 @@ func (c *Chain) addSubscription(nodeID proto.NodeID, startHeight int32) (err err
    c.observerLock.Lock()
    defer c.observerLock.Unlock()
    c.observers[nodeID] = startHeight
-   c.startStopReplication()
+   c.startStopReplication(c.rt.ctx)
    return
 }
@@ -1316,15 +1304,15 @@ func (c *Chain) cancelSubscription(nodeID proto.NodeID) (err error) {
    c.observerLock.Lock()
    defer c.observerLock.Unlock()
    delete(c.observers, nodeID)
-   c.startStopReplication()
+   c.startStopReplication(c.rt.ctx)
    return
 }
 
-func (c *Chain) startStopReplication() {
+func (c *Chain) startStopReplication(ctx context.Context) {
    if c.replCh != nil {
        select {
        case c.replCh <- struct{}{}:
-       case <-c.stopCh:
+       case <-ctx.Done():
        default:
        }
    }
@@ -1344,10 +1332,9 @@ func (c *Chain) populateObservers() {
        }
    } else {
        // start new replication routine
-       c.replWg.Add(1)
        replicator := newObserverReplicator(nodeID, startHeight, c)
        c.observerReplicators[nodeID] = replicator
-       go replicator.run()
+       c.rt.goFunc(replicator.run)
    }
 }
@@ -1360,12 +1347,7 @@
    }
 }
 
-func (c *Chain) replicationCycle() {
-   defer func() {
-       c.replWg.Wait()
-       c.rt.wg.Done()
-   }()
-
+func (c *Chain) replicationCycle(ctx context.Context) {
    for {
        select {
        case <-c.replCh:
@@ -1375,7 +1357,7 @@
            for _, replicator := range c.observerReplicators {
                replicator.tick()
            }
-       case <-c.stopCh:
+       case <-ctx.Done():
            return
        }
    }
@@ -1384,7 +1366,9 @@
 
 // Query queries req from local chain state and returns the query results in resp.
 func (c *Chain) Query(req *types.Request) (resp *types.Response, err error) {
    var ref *x.QueryTracker
-   if ref, resp, err = c.st.Query(req); err != nil {
+   // TODO(leventeliu): we're using an external context passed by request. Make sure that
+   // cancelling will be propagated to this context before chain instance stops.
+   if ref, resp, err = c.st.QueryWithContext(req.GetContext(), req); err != nil {
        return
    }
    if err = resp.Sign(c.pk); err != nil {
@@ -1403,7 +1387,7 @@ func (c *Chain) Replay(req *types.Request, resp *types.Response) (err error) {
    case types.ReadQuery:
        return
    case types.WriteQuery:
-       return c.st.Replay(req, resp)
+       return c.st.ReplayWithContext(req.GetContext(), req, resp)
    default:
        err = ErrInvalidRequest
    }
diff --git a/sqlchain/observer.go b/sqlchain/observer.go
index e1fa19960..5e293578c 100644
--- a/sqlchain/observer.go
+++ b/sqlchain/observer.go
@@ -17,6 +17,7 @@
 package sqlchain
 
 import (
+    "context"
    "sync"
 
    "github.com/CovenantSQL/CovenantSQL/proto"
@@ -59,7 +60,6 @@ func newObserverReplicator(nodeID proto.NodeID, startHeight int32, c *Chain) *ob
 func (r *observerReplicator) setNewHeight(newHeight int32) {
    r.replLock.Lock()
    defer r.replLock.Unlock()
-
    r.height = newHeight
 }
@@ -241,15 +241,13 @@ func (r *observerReplicator) tick() {
    default:
    }
 }
 
-func (r *observerReplicator) run() {
-   defer r.c.replWg.Done()
-
+func (r *observerReplicator) run(ctx context.Context) {
    for {
        select {
        case <-r.triggerCh:
            // replication
            r.replicate()
-       case <-r.c.stopCh:
+       case <-ctx.Done():
            r.stop()
            return
        case <-r.stopCh:
diff --git a/sqlchain/runtime.go b/sqlchain/runtime.go
index b1935558d..14cdabf32 100644
--- a/sqlchain/runtime.go
+++ b/sqlchain/runtime.go
@@ -17,6 +17,7 @@
 package sqlchain
 
 import (
+    "context"
    "fmt"
    "sync"
    "time"
@@ -29,8 +30,9 @@
 
 // runtime represents a chain runtime state.
 type runtime struct {
-   wg     sync.WaitGroup
-   stopCh chan struct{}
+   wg     *sync.WaitGroup
+   ctx    context.Context
+   cancel context.CancelFunc
 
    // chainInitTime is the initial cycle time, when the Genesis blcok is produced.
    chainInitTime time.Time
@@ -85,9 +87,13 @@ type runtime struct {
 }
 
 // newRunTime returns a new sql-chain runtime instance with the specified config.
-func newRunTime(c *Config) (r *runtime) {
+func newRunTime(ctx context.Context, c *Config) (r *runtime) {
+   var cld, ccl = context.WithCancel(ctx)
    r = &runtime{
-       stopCh:     make(chan struct{}),
+       wg:     &sync.WaitGroup{},
+       ctx:    cld,
+       cancel: ccl,
+
        databaseID: c.DatabaseID,
        period:     c.Period,
        tick:       c.Tick,
@@ -204,11 +210,7 @@ func (r *runtime) getQueryGas(t types.QueryType) uint64 {
 
 // stop sends a signal to the Runtime stop channel by closing it.
 func (r *runtime) stop() {
    r.stopService()
-   select {
-   case <-r.stopCh:
-   default:
-       close(r.stopCh)
-   }
+   r.cancel()
    r.wg.Wait()
 }
@@ -328,3 +330,11 @@ func (r *runtime) setHead(head *state) {
    defer r.stateMutex.Unlock()
    r.head = head
 }
+
+func (r *runtime) goFunc(f func(context.Context)) {
+   r.wg.Add(1)
+   go func() {
+       defer r.wg.Done()
+       f(r.ctx)
+   }()
+}
diff --git a/vendor/github.com/xtaci/smux/stream.go b/vendor/github.com/xtaci/smux/stream.go
index 2ce00d2d9..2a2b82fcc 100644
--- a/vendor/github.com/xtaci/smux/stream.go
+++ b/vendor/github.com/xtaci/smux/stream.go
@@ -146,6 +146,12 @@ func (s *Stream) Close() error {
    }
 }
 
+// GetDieCh returns a read-only channel which becomes readable
+// when the stream is about to be closed.
+func (s *Stream) GetDieCh() <-chan struct{} {
+   return s.die
+}
+
 // SetReadDeadline sets the read deadline as defined by
 // net.Conn.SetReadDeadline.
 // A zero time value disables the deadline.
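GetDieCh is the hook that the rpc/server.go hunk above uses to turn a dying smux stream into context cancellation for every request decoded on that connection. The bridge reduces to a few lines; this sketch adds a ctx.Done() case so the helper goroutine is also released when the context finishes first, a small refinement over the minimal version in the patch:

package main

import (
    "context"
    "fmt"
)

// bridge cancels the returned context as soon as dieCh becomes readable.
// dieCh stands in for (*smux.Stream).GetDieCh().
func bridge(parent context.Context, dieCh <-chan struct{}) (context.Context, context.CancelFunc) {
    ctx, cancel := context.WithCancel(parent)
    go func() {
        select {
        case <-dieCh:
            cancel() // connection died: propagate cancellation to handlers
        case <-ctx.Done():
            // context finished first; nothing to do, just exit
        }
    }()
    return ctx, cancel
}

func main() {
    die := make(chan struct{})
    ctx, cancel := bridge(context.Background(), die)
    defer cancel()
    close(die) // simulate the peer tearing the stream down
    <-ctx.Done()
    fmt.Println("canceled:", ctx.Err())
}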
diff --git a/worker/db.go b/worker/db.go
index ba2c7289f..96095d271 100644
--- a/worker/db.go
+++ b/worker/db.go
@@ -17,9 +17,9 @@
 package worker
 
 import (
-   "context"
    "os"
    "path/filepath"
+   //"runtime/trace"
    "sync"
    "time"
@@ -296,7 +296,7 @@ func (db *Database) writeQuery(request *types.Request) (response *types.Response
 
    // call kayak runtime Process
    var result interface{}
-   if result, _, err = db.kayakRuntime.Apply(context.Background(), request); err != nil {
+   if result, _, err = db.kayakRuntime.Apply(request.GetContext(), request); err != nil {
        err = errors.Wrap(err, "apply failed")
        return
    }
diff --git a/worker/db_storage.go b/worker/db_storage.go
index b56f6bbee..1b49fce83 100644
--- a/worker/db_storage.go
+++ b/worker/db_storage.go
@@ -19,6 +19,7 @@ package worker
 
 import (
    "bytes"
    "container/list"
+   "context"
 
    "github.com/CovenantSQL/CovenantSQL/types"
    "github.com/CovenantSQL/CovenantSQL/utils"
@@ -99,6 +100,9 @@ func (db *Database) Commit(rawReq interface{}) (result interface{}, err error) {
        return
    }
 
+   // reset context, commit should never be canceled
+   req.SetContext(context.Background())
+
    // execute
    return db.chain.Query(req)
 }
diff --git a/worker/db_test.go b/worker/db_test.go
index 36561d5bc..e1689cb0e 100644
--- a/worker/db_test.go
+++ b/worker/db_test.go
@@ -18,6 +18,7 @@ package worker
 
 import (
    "bytes"
+   "context"
    "fmt"
    "io/ioutil"
    "math/rand"
@@ -631,6 +632,7 @@ func buildQueryEx(queryType types.QueryType, connID uint64, seqNo uint64, timeSh
            Queries: realQueries,
        },
    }
+   query.SetContext(context.Background())
 
    err = query.Sign(privateKey)
diff --git a/xenomint/sqlite/sqlite.go b/xenomint/sqlite/sqlite.go
index 0ea8bf846..0c2780250 100644
--- a/xenomint/sqlite/sqlite.go
+++ b/xenomint/sqlite/sqlite.go
@@ -18,22 +18,41 @@ package sqlite
 
 import (
    "database/sql"
+   "time"
 
    "github.com/CovenantSQL/CovenantSQL/storage"
+   "github.com/CovenantSQL/CovenantSQL/utils/log"
    "github.com/CovenantSQL/go-sqlite3-encrypt"
 )
 
 const (
-   serializableDriver = "sqlite3"
+   serializableDriver = "sqlite3-custom"
    dirtyReadDriver    = "sqlite3-dirty-reader"
 )
 
 func init() {
+   sleepFunc := func(t int64) int64 {
+       log.Info("sqlite func sleep start")
+       time.Sleep(time.Duration(t))
+       log.Info("sqlite func sleep end")
+       return t
+   }
    sql.Register(dirtyReadDriver, &sqlite3.SQLiteDriver{
        ConnectHook: func(c *sqlite3.SQLiteConn) (err error) {
            if _, err = c.Exec("PRAGMA read_uncommitted=1", nil); err != nil {
                return
            }
+           if err = c.RegisterFunc("sleep", sleepFunc, true); err != nil {
+               return
+           }
+           return
+       },
+   })
+   sql.Register(serializableDriver, &sqlite3.SQLiteDriver{
+       ConnectHook: func(c *sqlite3.SQLiteConn) (err error) {
+           if err = c.RegisterFunc("sleep", sleepFunc, true); err != nil {
+               return
+           }
            return
        },
    })
diff --git a/xenomint/state.go b/xenomint/state.go
index 8599cf012..d6a986d61 100644
--- a/xenomint/state.go
+++ b/xenomint/state.go
@@ -17,6 +17,7 @@
 package xenomint
 
 import (
+    "context"
    "database/sql"
    "io"
    "strings"
@@ -67,20 +68,21 @@ func NewState(nodeID proto.NodeID, strg xi.Storage) (s *State, err error) {
 }
 
 func (s *State) incSeq() {
-   s.current++
+   atomic.AddUint64(&s.current, 1)
 }
 
 func (s *State) setNextTxID() {
-   s.origin = s.current
-   s.cmpoint = s.current
+   current := s.getID()
+   s.origin = current
+   s.cmpoint = current
 }
 
 func (s *State) setCommitPoint() {
-   s.cmpoint = s.current
+   s.cmpoint = s.getID()
 }
 
 func (s *State) rollbackID(id uint64) {
-   s.current = id
+   atomic.StoreUint64(&s.current, id)
 }
 
 // InitTx sets the initial id of the current transaction.
 // This method is not safe for concurrency.
@@ -88,7 +90,7 @@ func (s *State) rollbackID(id uint64) {
 func (s *State) InitTx(id uint64) {
    s.origin = id
    s.cmpoint = id
-   s.current = id
+   s.rollbackID(id)
    s.setSavepoint()
 }
@@ -103,14 +105,18 @@
    }
    if s.unc != nil {
        if commit {
+           s.Lock()
+           defer s.Unlock()
            if err = s.uncCommit(); err != nil {
                return
            }
        } else {
-           // Only rollback to last commmit point
+           // Only rollback to last commit point
            if err = s.rollback(); err != nil {
                return
            }
+           s.Lock()
+           defer s.Unlock()
            if err = s.uncCommit(); err != nil {
                return
            }
@@ -199,10 +205,13 @@ func buildTypeNamesFromSQLColumnTypes(types []*sql.ColumnType) (names []string)
 
 type sqlQuerier interface {
    Query(query string, args ...interface{}) (*sql.Rows, error)
+   QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
 }
 
 func readSingle(
-   qer sqlQuerier, q *types.Query) (names []string, types []string, data [][]interface{}, err error,
+   ctx context.Context, qer sqlQuerier, q *types.Query,
+) (
+   names []string, types []string, data [][]interface{}, err error,
 ) {
    var (
        rows *sql.Rows
@@ -214,7 +223,7 @@ func readSingle(
    if _, pattern, args, err = convertQueryAndBuildArgs(q.Pattern, q.Args); err != nil {
        return
    }
-   if rows, err = qer.Query(pattern, args...); err != nil {
+   if rows, err = qer.QueryContext(ctx, pattern, args...); err != nil {
        return
    }
    defer rows.Close()
@@ -253,6 +262,12 @@ func buildRowsFromNativeData(data [][]interface{}) (rows []types.ResponseRow) {
 }
 
 func (s *State) read(req *types.Request) (ref *QueryTracker, resp *types.Response, err error) {
+   return s.readWithContext(context.Background(), req)
+}
+
+func (s *State) readWithContext(
+   ctx context.Context, req *types.Request) (ref *QueryTracker, resp *types.Response, err error,
+) {
    var (
        ierr           error
        cnames, ctypes []string
        data           [][]interface{}
    )
    // TODO(leventeliu): no need to run every read query here.
    for i, v := range req.Payload.Queries {
-       if cnames, ctypes, data, ierr = readSingle(s.strg.DirtyReader(), &v); ierr != nil {
+       if cnames, ctypes, data, ierr = readSingle(ctx, s.strg.DirtyReader(), &v); ierr != nil {
            err = errors.Wrapf(ierr, "query at #%d failed", i)
            // Add to failed pool list
            s.pool.setFailed(req)
@@ -288,7 +303,9 @@
    return
 }
 
-func (s *State) readTx(req *types.Request) (ref *QueryTracker, resp *types.Response, err error) {
+func (s *State) readTx(
+   ctx context.Context, req *types.Request) (ref *QueryTracker, resp *types.Response, err error,
+) {
    var (
        tx      *sql.Tx
        id      uint64
@@ -297,15 +314,18 @@ func (s *State) readTx(
        data    [][]interface{}
        querier sqlQuerier
    )
-   id = s.getID()
    if atomic.LoadUint32(&s.hasSchemaChange) == 1 {
        // lock transaction
        s.Lock()
        defer s.Unlock()
+       id = s.getID()
        s.setSavepoint()
        querier = s.unc
        defer s.rollbackTo(id)
+
+       // TODO(): should detect the query type; any timed-out write query will cause the underlying transaction to roll back
    } else {
+       id = s.getID()
        if tx, ierr = s.strg.DirtyReader().Begin(); ierr != nil {
            err = errors.Wrap(ierr, "open tx failed")
            return
@@ -314,8 +334,18 @@ func (s *State) readTx(
        defer tx.Rollback()
    }
 
+   defer func() {
+       if ctx.Err() != nil {
+           log.WithError(ctx.Err()).WithFields(log.Fields{
+               "req":       req,
+               "id":        id,
+               "dirtyRead": atomic.LoadUint32(&s.hasSchemaChange) != 1,
+           }).Warning("read query canceled")
+       }
+   }()
+
    for i, v := range req.Payload.Queries {
-       if cnames, ctypes, data, ierr = readSingle(querier, &v); ierr != nil {
+       if cnames, ctypes, data, ierr = readSingle(ctx, querier, &v); ierr != nil {
            err = errors.Wrapf(ierr, "query at #%d failed", i)
            // Add to failed pool list
            s.pool.setFailed(req)
@@ -343,7 +373,9 @@ func (s *State) readTx(
    return
 }
 
-func (s *State) writeSingle(q *types.Query) (res sql.Result, err error) {
+func (s *State) writeSingle(
+   ctx context.Context, q *types.Query) (res sql.Result, err error,
+) {
    var (
        containsDDL bool
        pattern     string
@@ -353,7 +385,7 @@ func (s *State) writeSingle(
    if containsDDL, pattern, args, err = convertQueryAndBuildArgs(q.Pattern, q.Args); err != nil {
        return
    }
-   if res, err = s.unc.Exec(pattern, args...); err == nil {
+   if res, err = s.unc.ExecContext(ctx, pattern, args...); err == nil {
        if containsDDL {
            atomic.StoreUint32(&s.hasSchemaChange, 1)
        }
@@ -373,7 +405,9 @@ func (s *State) rollbackTo(savepoint uint64) {
    s.unc.Exec("ROLLBACK TO \"?\"", savepoint)
 }
 
-func (s *State) write(req *types.Request) (ref *QueryTracker, resp *types.Response, err error) {
+func (s *State) write(
+   ctx context.Context, req *types.Request) (ref *QueryTracker, resp *types.Response, err error,
+) {
    var (
        savepoint    uint64
        query        = &QueryTracker{Req: req}
@@ -382,6 +416,12 @@ func (s *State) write(
        lastInsertID int64
    )
 
+   defer func() {
+       if ctx.Err() != nil {
+           log.WithError(err).WithField("req", req).Warning("write query canceled")
+       }
+   }()
+
    // TODO(leventeliu): savepoint is a sqlite-specified solution for nested transaction.
    if err = func() (err error) {
        var ierr error
@@ -390,7 +430,7 @@ func (s *State) write(
        savepoint = s.getID()
        for i, v := range req.Payload.Queries {
            var res sql.Result
-           if res, ierr = s.writeSingle(&v); ierr != nil {
+           if res, ierr = s.writeSingle(ctx, &v); ierr != nil {
                err = errors.Wrapf(ierr, "execute at #%d failed", i)
                // Add to failed pool list
                s.pool.setFailed(req)
@@ -426,7 +466,7 @@ func (s *State) write(
    return
 }
 
-func (s *State) replay(req *types.Request, resp *types.Response) (err error) {
+func (s *State) replay(ctx context.Context, req *types.Request, resp *types.Response) (err error) {
    var (
        ierr      error
        savepoint uint64
@@ -443,7 +483,7 @@ func (s *State) replay(
        return
    }
    for i, v := range req.Payload.Queries {
-       if _, ierr = s.writeSingle(&v); ierr != nil {
+       if _, ierr = s.writeSingle(ctx, &v); ierr != nil {
            err = errors.Wrapf(ierr, "execute at #%d failed", i)
            s.rollbackTo(savepoint)
            return
@@ -457,6 +497,12 @@
 
 // ReplayBlock replays the queries from block. It also checks and skips some preceding pooled
 // queries.
 func (s *State) ReplayBlock(block *types.Block) (err error) {
+   return s.ReplayBlockWithContext(context.Background(), block)
+}
+
+// ReplayBlockWithContext replays the queries from block with context. It also checks and
+// skips some preceding pooled queries.
+func (s *State) ReplayBlockWithContext(ctx context.Context, block *types.Block) (err error) {
    var (
        ierr   error
        lastsp uint64 // Last savepoint
@@ -488,7 +534,7 @@ func (s *State) ReplayBlockWithContext(
            s.rollbackTo(lastsp)
            return
        }
-       if _, ierr = s.writeSingle(&v); ierr != nil {
+       if _, ierr = s.writeSingle(ctx, &v); ierr != nil {
            err = errors.Wrapf(ierr, "execute at %d:%d failed", i, j)
            s.rollbackTo(lastsp)
            return
@@ -501,7 +547,7 @@
    for _, r := range block.FailedReqs {
        s.pool.removeFailed(r)
    }
-   // Check if the current transaction is ok to commit
+   // Check if the current transaction is OK to commit
    if s.pool.matchLast(lastsp) {
        if err = s.uncCommit(); err != nil {
            // FATAL ERROR
@@ -541,13 +587,21 @@ func (s *State) commit() (err error) {
 
 // CommitEx commits the current transaction and returns all the pooled queries.
 func (s *State) CommitEx() (failed []*types.Request, queries []*QueryTracker, err error) {
+   return s.CommitExWithContext(context.Background())
+}
+
+// CommitExWithContext commits the current transaction and returns all the pooled queries
+// with context.
+func (s *State) CommitExWithContext(
+   ctx context.Context) (failed []*types.Request, queries []*QueryTracker, err error,
+) {
    s.Lock()
    defer s.Unlock()
    if err = s.uncCommit(); err != nil {
        // FATAL ERROR
        return
    }
-   if s.unc, err = s.strg.Writer().Begin(); err != nil {
+   if s.unc, err = s.strg.Writer().BeginTx(ctx, nil); err != nil {
        // FATAL ERROR
        return
    }
@@ -575,7 +629,6 @@ func (s *State) rollback() (err error) {
    s.Lock()
    defer s.Unlock()
    s.rollbackTo(s.cmpoint)
-   s.current = s.cmpoint
 
    return
 }
@@ -586,11 +639,19 @@ func (s *State) getLocalTime() time.Time {
 
 // Query does the query(ies) in req, pools the request and persists any change to
 // the underlying storage.
 func (s *State) Query(req *types.Request) (ref *QueryTracker, resp *types.Response, err error) {
+   return s.QueryWithContext(context.Background(), req)
+}
+
+// QueryWithContext does the query(ies) in req, pools the request and persists any change to
+// the underlying storage.
+func (s *State) QueryWithContext(
+   ctx context.Context, req *types.Request) (ref *QueryTracker, resp *types.Response, err error,
+) {
    switch req.Header.QueryType {
    case types.ReadQuery:
-       return s.readTx(req)
+       return s.readTx(ctx, req)
    case types.WriteQuery:
-       return s.write(req)
+       return s.write(ctx, req)
    default:
        err = ErrInvalidRequest
    }
@@ -599,6 +660,13 @@
 
 // Replay replays a write log from other peer to replicate storage state.
 func (s *State) Replay(req *types.Request, resp *types.Response) (err error) {
+   return s.ReplayWithContext(context.Background(), req, resp)
+}
+
+// ReplayWithContext replays a write log from other peer to replicate storage state with context.
+func (s *State) ReplayWithContext(
+   ctx context.Context, req *types.Request, resp *types.Response) (err error,
+) {
    // NOTE(leventeliu): in the current implementation, failed requests are not tracked in remote
    // nodes (while replaying via Replay calls). Because we don't want to actually replay read
    // queries in all synchronized nodes, meanwhile, whether a request will fail or not
@@ -609,7 +677,7 @@ func (s *State) ReplayWithContext(
    case types.ReadQuery:
        return
    case types.WriteQuery:
-       return s.replay(req, resp)
+       return s.replay(ctx, req, resp)
    default:
        err = ErrInvalidRequest
    }
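The sleep() function that the integration test leans on is a plain SQLite user-defined function registered from a ConnectHook, as in the xenomint/sqlite/sqlite.go hunk above. The same mechanism sketched against the upstream mattn/go-sqlite3 driver — the vendored CovenantSQL fork is assumed to mirror this RegisterFunc API:

package main

import (
    "database/sql"
    "fmt"
    "log"
    "time"

    sqlite3 "github.com/mattn/go-sqlite3"
)

func init() {
    sql.Register("sqlite3-with-sleep", &sqlite3.SQLiteDriver{
        ConnectHook: func(c *sqlite3.SQLiteConn) error {
            // t is interpreted as nanoseconds, matching time.Duration,
            // so sleep(10000000000) blocks for ten seconds.
            return c.RegisterFunc("sleep", func(t int64) int64 {
                time.Sleep(time.Duration(t))
                return t
            }, true)
        },
    })
}

func main() {
    db, err := sql.Open("sqlite3-with-sleep", ":memory:")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    var v int64
    if err := db.QueryRow("SELECT sleep(1000000)").Scan(&v); err != nil { // ~1ms
        log.Fatal(err)
    }
    fmt.Println("slept ns:", v)
}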
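With the context plumbing in place, the State write path boils down to one long-lived transaction per commit window, context-aware statement execution, and numbered savepoints for per-request rollback. A reduced, self-contained model of that flow — the table, driver and data are illustrative, and the savepoint identifier is formatted into the statement because savepoint names cannot be bound as SQL parameters:

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"

    _ "github.com/mattn/go-sqlite3" // any sqlite3 driver with context support
)

func main() {
    db, err := sql.Open("sqlite3", ":memory:")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    ctx := context.Background()
    tx, err := db.BeginTx(ctx, nil) // counterpart of strg.Writer().BeginTx
    if err != nil {
        log.Fatal(err)
    }
    if _, err = tx.ExecContext(ctx, "CREATE TABLE t (v INT)"); err != nil {
        log.Fatal(err)
    }

    // setSavepoint / rollbackTo, modeled on xenomint/state.go.
    savepoint := uint64(1)
    if _, err = tx.ExecContext(ctx, fmt.Sprintf(`SAVEPOINT "%d"`, savepoint)); err != nil {
        log.Fatal(err)
    }
    if _, err = tx.ExecContext(ctx, "INSERT INTO t VALUES (1)"); err != nil {
        // A failed or canceled statement only rolls back to the savepoint,
        // leaving earlier work in the transaction intact.
        tx.ExecContext(ctx, fmt.Sprintf(`ROLLBACK TO "%d"`, savepoint))
    }
    if err = tx.Commit(); err != nil {
        log.Fatal(err)
    }
    fmt.Println("committed")
}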