Skip to content

Commit e6f3261

Browse files
authored
Merge pull request pkg#283 from pkg/open-on-open
Initialize opening of files/directories upon receiving open packets
2 parents a741a7f + 05f7f0a commit e6f3261

File tree

5 files changed

+142
-120
lines changed

5 files changed

+142
-120
lines changed

request-example.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"path/filepath"
1313
"sort"
1414
"sync"
15+
"syscall"
1516
"time"
1617
)
1718

@@ -29,6 +30,7 @@ func (fs *root) Fileread(r *Request) (io.ReaderAt, error) {
2930
if fs.mockErr != nil {
3031
return nil, fs.mockErr
3132
}
33+
_ = r.WithContext(r.Context()) // initialize context for deadlock testing
3234
fs.filesLock.Lock()
3335
defer fs.filesLock.Unlock()
3436
file, err := fs.fetch(r.Filepath)
@@ -48,6 +50,7 @@ func (fs *root) Filewrite(r *Request) (io.WriterAt, error) {
4850
if fs.mockErr != nil {
4951
return nil, fs.mockErr
5052
}
53+
_ = r.WithContext(r.Context()) // initialize context for deadlock testing
5154
fs.filesLock.Lock()
5255
defer fs.filesLock.Unlock()
5356
file, err := fs.fetch(r.Filepath)
@@ -69,6 +72,7 @@ func (fs *root) Filecmd(r *Request) error {
6972
if fs.mockErr != nil {
7073
return fs.mockErr
7174
}
75+
_ = r.WithContext(r.Context()) // initialize context for deadlock testing
7276
fs.filesLock.Lock()
7377
defer fs.filesLock.Unlock()
7478
switch r.Method {
@@ -129,11 +133,20 @@ func (fs *root) Filelist(r *Request) (ListerAt, error) {
129133
if fs.mockErr != nil {
130134
return nil, fs.mockErr
131135
}
136+
_ = r.WithContext(r.Context()) // initialize context for deadlock testing
132137
fs.filesLock.Lock()
133138
defer fs.filesLock.Unlock()
134139

140+
file, err := fs.fetch(r.Filepath)
141+
if err != nil {
142+
return nil, err
143+
}
144+
135145
switch r.Method {
136146
case "List":
147+
if !file.IsDir() {
148+
return nil, syscall.ENOTDIR
149+
}
137150
ordered_names := []string{}
138151
for fn, _ := range fs.files {
139152
if filepath.Dir(fn) == r.Filepath {
@@ -147,16 +160,8 @@ func (fs *root) Filelist(r *Request) (ListerAt, error) {
147160
}
148161
return listerat(list), nil
149162
case "Stat":
150-
file, err := fs.fetch(r.Filepath)
151-
if err != nil {
152-
return nil, err
153-
}
154163
return listerat([]os.FileInfo{file}), nil
155164
case "Readlink":
156-
file, err := fs.fetch(r.Filepath)
157-
if err != nil {
158-
return nil, err
159-
}
160165
if file.symlink != "" {
161166
file, err = fs.fetch(file.symlink)
162167
if err != nil {

request-server.go

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package sftp
33
import (
44
"context"
55
"io"
6-
"os"
76
"path"
87
"path/filepath"
98
"strconv"
@@ -56,37 +55,24 @@ func (rs *RequestServer) nextRequest(r *Request) string {
5655
defer rs.openRequestLock.Unlock()
5756
rs.handleCount++
5857
handle := strconv.Itoa(rs.handleCount)
58+
r.handle = handle
5959
rs.openRequests[handle] = r
6060
return handle
6161
}
6262

63-
// Returns Request from openRequests, bool is false if it is missing
64-
// If the method is different, save/return a new Request w/ that Method.
63+
// Returns Request from openRequests, bool is false if it is missing.
6564
//
6665
// The Requests in openRequests work essentially as open file descriptors that
6766
// you can do different things with. What you are doing with it are denoted by
68-
// the first packet of that type (read/write/etc). We create a new Request when
69-
// it changes to set the request.Method attribute in a thread safe way.
70-
func (rs *RequestServer) getRequest(handle, method string) (*Request, bool) {
67+
// the first packet of that type (read/write/etc).
68+
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
7169
rs.openRequestLock.RLock()
70+
defer rs.openRequestLock.RUnlock()
7271
r, ok := rs.openRequests[handle]
73-
rs.openRequestLock.RUnlock()
74-
if !ok || r.Method == method {
75-
return r, ok
76-
}
77-
// if we make it here we need to replace the request
78-
rs.openRequestLock.Lock()
79-
defer rs.openRequestLock.Unlock()
80-
r, ok = rs.openRequests[handle]
81-
if !ok || r.Method == method { // re-check needed b/c lock race
82-
return r, ok
83-
}
84-
r = r.copy()
85-
r.Method = method
86-
rs.openRequests[handle] = r
8772
return r, ok
8873
}
8974

75+
// Close the Request and clear from openRequests map
9076
func (rs *RequestServer) closeRequest(handle string) error {
9177
rs.openRequestLock.Lock()
9278
defer rs.openRequestLock.Unlock()
@@ -173,28 +159,24 @@ func (rs *RequestServer) packetWorker(
173159
rpkt = cleanPacketPath(pkt)
174160
case *sshFxpOpendirPacket:
175161
request := requestFromPacket(ctx, pkt)
176-
rpkt = request.call(rs.Handlers, pkt)
177-
if stat, ok := rpkt.(*sshFxpStatResponse); ok {
178-
if stat.info.IsDir() {
179-
handle := rs.nextRequest(request)
180-
rpkt = sshFxpHandlePacket{ID: pkt.id(), Handle: handle}
181-
} else {
182-
rpkt = statusFromError(pkt, &os.PathError{
183-
Path: request.Filepath, Err: syscall.ENOTDIR})
184-
}
185-
}
162+
rs.nextRequest(request)
163+
rpkt = request.opendir(rs.Handlers, pkt)
186164
case *sshFxpOpenPacket:
187165
request := requestFromPacket(ctx, pkt)
188-
handle := rs.nextRequest(request)
189-
rpkt = sshFxpHandlePacket{ID: pkt.id(), Handle: handle}
190-
if pkt.hasPflags(ssh_FXF_CREAT) {
191-
if p := request.call(rs.Handlers, pkt); !statusOk(p) {
192-
rpkt = p // if error in write, return it
193-
}
166+
rs.nextRequest(request)
167+
rpkt = request.open(rs.Handlers, pkt)
168+
case *sshFxpFstatPacket:
169+
handle := pkt.getHandle()
170+
request, ok := rs.getRequest(handle)
171+
if !ok {
172+
rpkt = statusFromError(pkt, syscall.EBADF)
173+
} else {
174+
request = NewRequest("Stat", request.Filepath)
175+
rpkt = request.call(rs.Handlers, pkt)
194176
}
195177
case hasHandle:
196178
handle := pkt.getHandle()
197-
request, ok := rs.getRequest(handle, requestMethod(pkt))
179+
request, ok := rs.getRequest(handle)
198180
if !ok {
199181
rpkt = statusFromError(pkt, syscall.EBADF)
200182
} else {
@@ -214,12 +196,6 @@ func (rs *RequestServer) packetWorker(
214196
return nil
215197
}
216198

217-
// True is responsePacket is an OK status packet
218-
func statusOk(rpkt responsePacket) bool {
219-
p, ok := rpkt.(sshFxpStatusPacket)
220-
return ok && p.StatusError.Code == ssh_FX_OK
221-
}
222-
223199
// clean and return name packet for file
224200
func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket {
225201
path := cleanPath(pkt.getPath())

request-server_test.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func TestRequestCache(t *testing.T) {
8484
fh := p.svr.nextRequest(foo)
8585
bh := p.svr.nextRequest(bar)
8686
assert.Len(t, p.svr.openRequests, 2)
87-
_foo, ok := p.svr.getRequest(fh, "")
87+
_foo, ok := p.svr.getRequest(fh)
8888
assert.Equal(t, foo.Method, _foo.Method)
8989
assert.Equal(t, foo.Filepath, _foo.Filepath)
9090
assert.Equal(t, foo.Target, _foo.Target)
@@ -94,7 +94,7 @@ func TestRequestCache(t *testing.T) {
9494
assert.NotNil(t, _foo.ctx)
9595
assert.Equal(t, _foo.Context().Err(), nil, "context is still valid")
9696
assert.True(t, ok)
97-
_, ok = p.svr.getRequest("zed", "")
97+
_, ok = p.svr.getRequest("zed")
9898
assert.False(t, ok)
9999
p.svr.closeRequest(fh)
100100
assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled")
@@ -147,7 +147,7 @@ func TestRequestWriteEmpty(t *testing.T) {
147147
f, err := r.fetch("/foo")
148148
if assert.Nil(t, err) {
149149
assert.False(t, f.isdir)
150-
assert.Equal(t, f.content, []byte(""))
150+
assert.Len(t, f.content, 0)
151151
}
152152
// lets test with an error
153153
r.returnErr(os.ErrInvalid)
@@ -170,7 +170,7 @@ func TestRequestFilename(t *testing.T) {
170170
assert.Error(t, err)
171171
}
172172

173-
func TestRequestRead(t *testing.T) {
173+
func TestRequestJustRead(t *testing.T) {
174174
p := clientRequestServerPair(t)
175175
defer p.Close()
176176
_, err := putTestFile(p.cli, "/foo", "hello")
@@ -187,21 +187,18 @@ func TestRequestRead(t *testing.T) {
187187
assert.Equal(t, "hello", string(contents[0:5]))
188188
}
189189

190-
func TestRequestReadFail(t *testing.T) {
190+
func TestRequestOpenFail(t *testing.T) {
191191
p := clientRequestServerPair(t)
192192
defer p.Close()
193193
rf, err := p.cli.Open("/foo")
194-
assert.Nil(t, err)
195-
contents := make([]byte, 5)
196-
n, err := rf.Read(contents)
197-
assert.Equal(t, n, 0)
198194
assert.Exactly(t, os.ErrNotExist, err)
195+
assert.Nil(t, rf)
199196
}
200197

201-
func TestRequestOpen(t *testing.T) {
198+
func TestRequestCreate(t *testing.T) {
202199
p := clientRequestServerPair(t)
203200
defer p.Close()
204-
fh, err := p.cli.Open("foo")
201+
fh, err := p.cli.Create("foo")
205202
assert.Nil(t, err)
206203
err = fh.Close()
207204
assert.Nil(t, err)
@@ -354,7 +351,9 @@ func TestRequestReaddir(t *testing.T) {
354351
for i := 0; i < 100; i++ {
355352
fname := fmt.Sprintf("/foo_%02d", i)
356353
_, err := putTestFile(p.cli, fname, fname)
357-
assert.Nil(t, err)
354+
if err != nil {
355+
t.Fatal("expected no error, got:", err)
356+
}
358357
}
359358
_, err := p.cli.ReadDir("/foo_01")
360359
assert.Equal(t, &StatusError{Code: ssh_FX_FAILURE,

0 commit comments

Comments
 (0)