diff --git a/.github/workflows/.editorconfig b/.github/workflows/.editorconfig new file mode 100644 index 00000000..7bd3346f --- /dev/null +++ b/.github/workflows/.editorconfig @@ -0,0 +1,2 @@ +[*.yml] +indent_size = 2 diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 00000000..2671aaa1 --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,38 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + +jobs: + run-tests: + name: Run test cases + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + go: ['1.24', '1.23'] + exclude: + - os: macos-latest + go: '1.24' + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go }} + + - name: Run tests + run: | + make integration + make integration_w_race + + - name: Run tests on 32-bit arch + if: startsWith(matrix.os, 'ubuntu-') + run: | + make integration + env: + GOARCH: 386 diff --git a/.github/workflows/cifuzz.yml b/.github/workflows/cifuzz.yml new file mode 100644 index 00000000..e3e91c87 --- /dev/null +++ b/.github/workflows/cifuzz.yml @@ -0,0 +1,26 @@ +name: CIFuzz +on: [pull_request] +jobs: + Fuzzing: + runs-on: ubuntu-latest + steps: + - name: Build Fuzzers + id: build + uses: google/oss-fuzz/infra/cifuzz/actions/build_fuzzers@master + with: + oss-fuzz-project-name: 'go-sftp' + dry-run: false + language: go + - name: Run Fuzzers + uses: google/oss-fuzz/infra/cifuzz/actions/run_fuzzers@master + with: + oss-fuzz-project-name: 'go-sftp' + fuzz-seconds: 300 + dry-run: false + language: go + - name: Upload Crash + uses: actions/upload-artifact@v4 + if: failure() && steps.build.outcome == 'success' + with: + name: artifacts + path: ./out/artifacts diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..caf2dca2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.*.swo +.*.swp + +server_standalone/server_standalone + +examples/*/id_rsa 
+examples/*/id_rsa.pub + +memprofile.out +memprofile.svg diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 7eff8236..5c7196ae 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -1,2 +1,3 @@ Dave Cheney Saulius Gurklys +John Eikenberry diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..4d3a0079 --- /dev/null +++ b/Makefile @@ -0,0 +1,27 @@ +.PHONY: integration integration_w_race benchmark + +integration: + go test -integration -v ./... + go test -testserver -v ./... + go test -integration -testserver -v ./... + go test -integration -allocator -v ./... + go test -testserver -allocator -v ./... + go test -integration -testserver -allocator -v ./... + +integration_w_race: + go test -race -integration -v ./... + go test -race -testserver -v ./... + go test -race -integration -testserver -v ./... + go test -race -integration -allocator -v ./... + go test -race -testserver -allocator -v ./... + go test -race -integration -allocator -testserver -v ./... + +COUNT ?= 1 +BENCHMARK_PATTERN ?= "." + +benchmark: + go test -integration -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) + +benchmark_w_memprofile: + go test -integration -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) -memprofile memprofile.out + go tool pprof -svg -output=memprofile.svg memprofile.out diff --git a/README.md b/README.md index d7058eb6..5e78cd39 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,44 @@ sftp ---- -The `sftp` package provides support for file system operations on remote ssh servers using the SFTP subsystem. +The `sftp` package provides support for file system operations on remote ssh +servers using the SFTP subsystem. It also implements an SFTP server for serving +files from the filesystem. 
-[![Build Status](https://drone.io/github.com/pkg/sftp/status.png)](https://drone.io/github.com/pkg/sftp/latest) +![CI Status](https://github.com/pkg/sftp/workflows/CI/badge.svg?branch=master&event=push) [![Go Reference](https://pkg.go.dev/badge/github.com/pkg/sftp.svg)](https://pkg.go.dev/github.com/pkg/sftp) usage and examples ------------------ -See [godoc.org/github.com/pkg/sftp](http://godoc.org/github.com/pkg/sftp) for examples and usage. +See [https://pkg.go.dev/github.com/pkg/sftp](https://pkg.go.dev/github.com/pkg/sftp) for +examples and usage. -The basic operation of the package mirrors the facilities of the [os](http://golang.org/pkg/os) package. +The basic operation of the package mirrors the facilities of the +[os](http://golang.org/pkg/os) package. -The Walker interface for directory traversal is heavily inspired by Keith Rarick's [fs](http://godoc.org/github.com/kr/fs) package. +The Walker interface for directory traversal is heavily inspired by Keith +Rarick's [fs](https://pkg.go.dev/github.com/kr/fs) package. roadmap ------- - * Currently all traffic with the server is serialized, this can be improved by allowing overlapping requests/responses. - * There is way too much duplication in the Client methods. If there was an unmarshal(interface{}) method this would reduce a heap of the duplication. - * Implement integration tests by talking directly to a real opensftp-server process. This shouldn't be too difficult to implement with a small refactoring to the sftp.NewClient method. These tests should be gated on an -sftp.integration test flag. _in progress_ +* There is way too much duplication in the Client methods. If there was an + unmarshal(interface{}) method this would reduce a heap of the duplication. contributing ------------ -Features, Issues, and Pull Requests are always welcome. +We welcome pull requests, bug fixes and issue reports. + +Before proposing a large change, first please discuss your change by raising an +issue. 
+ +For API/code bugs, please include a small, self contained code example to +reproduce the issue. For pull requests, remember test coverage. + +We try to handle issues and pull requests with a 0 open philosophy. That means +we will try to address the submission as soon as possible and will work toward +a resolution. If progress can no longer be made (eg. unreproducible bug) or +stops (eg. unresponsive submitter), we will close the bug. + +Thanks. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..9ebc2631 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,13 @@ +# Security Policy + +## Supported Versions + +Security updates are provided for the latest released version of this package. +We also welcome vulnerability reports for the development version to help us ensure it is secure before the next release. + +## Reporting a Vulnerability + +If you believe you’ve found a security vulnerability in this project, we strongly encourage you to report it privately using GitHub’s [security advisory system](https://github.com/pkg/sftp/security/advisories/new). +This will allow us to review and address the issue before public disclosure. + +Thank you for helping us keep the project secure. diff --git a/allocator.go b/allocator.go new file mode 100644 index 00000000..3e67e543 --- /dev/null +++ b/allocator.go @@ -0,0 +1,96 @@ +package sftp + +import ( + "sync" +) + +type allocator struct { + sync.Mutex + available [][]byte + // map key is the request order + used map[uint32][][]byte +} + +func newAllocator() *allocator { + return &allocator{ + // micro optimization: initialize available pages with an initial capacity + available: make([][]byte, 0, SftpServerWorkerCount*2), + used: make(map[uint32][][]byte), + } +} + +// GetPage returns a previously allocated and unused []byte or create a new one. 
+// The slice have a fixed size = maxMsgLength, this value is suitable for both +// receiving new packets and reading the files to serve +func (a *allocator) GetPage(requestOrderID uint32) []byte { + a.Lock() + defer a.Unlock() + + var result []byte + + // get an available page and remove it from the available ones. + if len(a.available) > 0 { + truncLength := len(a.available) - 1 + result = a.available[truncLength] + + a.available[truncLength] = nil // clear out the internal pointer + a.available = a.available[:truncLength] // truncate the slice + } + + // no preallocated slice found, just allocate a new one + if result == nil { + result = make([]byte, maxMsgLength) + } + + // put result in used pages + a.used[requestOrderID] = append(a.used[requestOrderID], result) + + return result +} + +// ReleasePages marks unused all pages in use for the given requestID +func (a *allocator) ReleasePages(requestOrderID uint32) { + a.Lock() + defer a.Unlock() + + if used := a.used[requestOrderID]; len(used) > 0 { + a.available = append(a.available, used...) + } + delete(a.used, requestOrderID) +} + +// Free removes all the used and available pages. 
+// Call this method when the allocator is not needed anymore +func (a *allocator) Free() { + a.Lock() + defer a.Unlock() + + a.available = nil + a.used = make(map[uint32][][]byte) +} + +func (a *allocator) countUsedPages() int { + a.Lock() + defer a.Unlock() + + num := 0 + for _, p := range a.used { + num += len(p) + } + return num +} + +func (a *allocator) countAvailablePages() int { + a.Lock() + defer a.Unlock() + + return len(a.available) +} + +func (a *allocator) isRequestOrderIDUsed(requestOrderID uint32) bool { + a.Lock() + defer a.Unlock() + + _, ok := a.used[requestOrderID] + return ok +} diff --git a/allocator_test.go b/allocator_test.go new file mode 100644 index 00000000..74f4da1a --- /dev/null +++ b/allocator_test.go @@ -0,0 +1,135 @@ +package sftp + +import ( + "strconv" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAllocator(t *testing.T) { + allocator := newAllocator() + // get a page for request order id 1 + page := allocator.GetPage(1) + page[1] = uint8(1) + assert.Equal(t, maxMsgLength, len(page)) + assert.Equal(t, 1, allocator.countUsedPages()) + // get another page for request order id 1, we now have 2 used pages + page = allocator.GetPage(1) + page[0] = uint8(2) + assert.Equal(t, 2, allocator.countUsedPages()) + // get another page for request order id 1, we now have 3 used pages + page = allocator.GetPage(1) + page[2] = uint8(3) + assert.Equal(t, 3, allocator.countUsedPages()) + // release the page for request order id 1, we now have 3 available pages + allocator.ReleasePages(1) + assert.NotContains(t, allocator.used, 1) + assert.Equal(t, 3, allocator.countAvailablePages()) + // get a page for request order id 2 + // we get the latest released page, let's verify that by checking the previously written values + // so we are sure we are reusing a previously allocated page + page = allocator.GetPage(2) + assert.Equal(t, uint8(3), page[2]) + assert.Equal(t, 2, allocator.countAvailablePages()) + 
assert.Equal(t, 1, allocator.countUsedPages()) + page = allocator.GetPage(2) + assert.Equal(t, uint8(2), page[0]) + assert.Equal(t, 1, allocator.countAvailablePages()) + assert.Equal(t, 2, allocator.countUsedPages()) + page = allocator.GetPage(2) + assert.Equal(t, uint8(1), page[1]) + // we now have 3 used pages for request order id 2 and no available pages + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // release some request order id with no allocated pages, should have no effect + allocator.ReleasePages(1) + allocator.ReleasePages(3) + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // now get some pages for another request order id + allocator.GetPage(3) + // we now must have 3 used pages for request order id 2 and 1 used page for request order id 3 + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 4, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.True(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // get another page for request order id 3 + allocator.GetPage(3) + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 5, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.True(t, allocator.isRequestOrderIDUsed(3), "page 
with request order id 3 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // now release the pages for request order id 3 + allocator.ReleasePages(3) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + assert.False(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be not used") + // again check we are reusing previously allocated pages. + // We have written nothing to the 2 last requested page so release them and get the third one + allocator.ReleasePages(2) + assert.Equal(t, 5, allocator.countAvailablePages()) + assert.Equal(t, 0, allocator.countUsedPages()) + assert.False(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be not used") + allocator.GetPage(4) + allocator.GetPage(4) + page = allocator.GetPage(4) + assert.Equal(t, uint8(3), page[2]) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(4), "page with request order id 4 must be used") + // free the allocator + allocator.Free() + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 0, allocator.countUsedPages()) +} + +func BenchmarkAllocatorSerial(b *testing.B) { + allocator := newAllocator() + for i := 0; i < b.N; i++ { + benchAllocator(allocator, uint32(i)) + } +} + +func BenchmarkAllocatorParallel(b *testing.B) { + var counter uint32 + allocator := newAllocator() + for i := 1; i <= 8; i *= 2 { + b.Run(strconv.Itoa(i), func(b *testing.B) { + b.SetParallelism(i) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + benchAllocator(allocator, atomic.AddUint32(&counter, 1)) + } + }) + }) + } +} + +func benchAllocator(allocator 
*allocator, requestOrderID uint32) { + // simulates the page requested in recvPacket + allocator.GetPage(requestOrderID) + // simulates the page requested in fileget for downloads + allocator.GetPage(requestOrderID) + // release the allocated pages + allocator.ReleasePages(requestOrderID) +} + +// useful for debug +func printAllocatorContents(allocator *allocator) { + for o, u := range allocator.used { + debug("used order id: %v, values: %+v", o, u) + } + for _, v := range allocator.available { + debug("available, values: %+v", v) + } +} diff --git a/attrs.go b/attrs.go index 0d37db08..74ac03b7 100644 --- a/attrs.go +++ b/attrs.go @@ -1,138 +1,136 @@ package sftp // ssh_FXP_ATTRS support -// see http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-5 +// see https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-5 import ( "os" - "syscall" "time" ) const ( - ssh_FILEXFER_ATTR_SIZE = 0x00000001 - ssh_FILEXFER_ATTR_UIDGID = 0x00000002 - ssh_FILEXFER_ATTR_PERMISSIONS = 0x00000004 - ssh_FILEXFER_ATTR_ACMODTIME = 0x00000008 - ssh_FILEXFER_ATTR_EXTENDED = 0x80000000 + sshFileXferAttrSize = 0x00000001 + sshFileXferAttrUIDGID = 0x00000002 + sshFileXferAttrPermissions = 0x00000004 + sshFileXferAttrACmodTime = 0x00000008 + sshFileXferAttrExtended = 0x80000000 + + sshFileXferAttrAll = sshFileXferAttrSize | sshFileXferAttrUIDGID | sshFileXferAttrPermissions | + sshFileXferAttrACmodTime | sshFileXferAttrExtended ) // fileInfo is an artificial type designed to satisfy os.FileInfo. type fileInfo struct { - name string - size int64 - mode os.FileMode - mtime time.Time - sys interface{} + name string + stat *FileStat } // Name returns the base name of the file. func (fi *fileInfo) Name() string { return fi.name } // Size returns the length in bytes for regular files; system-dependent for others. -func (fi *fileInfo) Size() int64 { return fi.size } +func (fi *fileInfo) Size() int64 { return int64(fi.stat.Size) } // Mode returns file mode bits. 
-func (fi *fileInfo) Mode() os.FileMode { return fi.mode } +func (fi *fileInfo) Mode() os.FileMode { return fi.stat.FileMode() } // ModTime returns the last modification time of the file. -func (fi *fileInfo) ModTime() time.Time { return fi.mtime } +func (fi *fileInfo) ModTime() time.Time { return fi.stat.ModTime() } // IsDir returns true if the file is a directory. func (fi *fileInfo) IsDir() bool { return fi.Mode().IsDir() } -func (fi *fileInfo) Sys() interface{} { return fi.sys } +func (fi *fileInfo) Sys() interface{} { return fi.stat } -// FileStat holds the original unmarshalled values from a call to READDIR or *STAT. -// It is exported for the purposes of accessing the raw values via os.FileInfo.Sys() +// FileStat holds the original unmarshalled values from a call to READDIR or +// *STAT. It is exported for the purposes of accessing the raw values via +// os.FileInfo.Sys(). It is also used server side to store the unmarshalled +// values for SetStat. type FileStat struct { Size uint64 Mode uint32 Mtime uint32 Atime uint32 - Uid uint32 - Gid uint32 + UID uint32 + GID uint32 Extended []StatExtended } +// ModTime returns the Mtime SFTP file attribute converted to a time.Time +func (fs *FileStat) ModTime() time.Time { + return time.Unix(int64(fs.Mtime), 0) +} + +// AccessTime returns the Atime SFTP file attribute converted to a time.Time +func (fs *FileStat) AccessTime() time.Time { + return time.Unix(int64(fs.Atime), 0) +} + +// FileMode returns the Mode SFTP file attribute converted to an os.FileMode +func (fs *FileStat) FileMode() os.FileMode { + return toFileMode(fs.Mode) +} + +// StatExtended contains additional, extended information for a FileStat. 
type StatExtended struct { ExtType string ExtData string } -func fileInfoFromStat(st *FileStat, name string) os.FileInfo { - fs := &fileInfo{ - name: name, - size: int64(st.Size), - mode: toFileMode(st.Mode), - mtime: time.Unix(int64(st.Mtime), 0), - sys: st, +func fileInfoFromStat(stat *FileStat, name string) os.FileInfo { + return &fileInfo{ + name: name, + stat: stat, } - return fs } -func unmarshalAttrs(b []byte) (*FileStat, []byte) { - flags, b := unmarshalUint32(b) - var fs FileStat - if flags&ssh_FILEXFER_ATTR_SIZE == ssh_FILEXFER_ATTR_SIZE { - fs.Size, b = unmarshalUint64(b) - } - if flags&ssh_FILEXFER_ATTR_UIDGID == ssh_FILEXFER_ATTR_UIDGID { - fs.Uid, b = unmarshalUint32(b) - } - if flags&ssh_FILEXFER_ATTR_UIDGID == ssh_FILEXFER_ATTR_UIDGID { - fs.Gid, b = unmarshalUint32(b) - } - if flags&ssh_FILEXFER_ATTR_PERMISSIONS == ssh_FILEXFER_ATTR_PERMISSIONS { - fs.Mode, b = unmarshalUint32(b) - } - if flags&ssh_FILEXFER_ATTR_ACMODTIME == ssh_FILEXFER_ATTR_ACMODTIME { - fs.Atime, b = unmarshalUint32(b) - fs.Mtime, b = unmarshalUint32(b) - } - if flags&ssh_FILEXFER_ATTR_EXTENDED == ssh_FILEXFER_ATTR_EXTENDED { - var count uint32 - count, b = unmarshalUint32(b) - ext := make([]StatExtended, count, count) - for i := uint32(0); i < count; i++ { - var typ string - var data string - typ, b = unmarshalString(b) - data, b = unmarshalString(b) - ext[i] = StatExtended{typ, data} - } - fs.Extended = ext - } - return &fs, b +// FileInfoUidGid extends os.FileInfo and adds callbacks for Uid and Gid retrieval, +// as an alternative to *syscall.Stat_t objects on unix systems. 
+type FileInfoUidGid interface { + os.FileInfo + Uid() uint32 + Gid() uint32 } -// toFileMode converts sftp filemode bits to the os.FileMode specification -func toFileMode(mode uint32) os.FileMode { - var fm = os.FileMode(mode & 0777) - switch mode & syscall.S_IFMT { - case syscall.S_IFBLK: - fm |= os.ModeDevice - case syscall.S_IFCHR: - fm |= os.ModeDevice | os.ModeCharDevice - case syscall.S_IFDIR: - fm |= os.ModeDir - case syscall.S_IFIFO: - fm |= os.ModeNamedPipe - case syscall.S_IFLNK: - fm |= os.ModeSymlink - case syscall.S_IFREG: - // nothing to do - case syscall.S_IFSOCK: - fm |= os.ModeSocket - } - if mode&syscall.S_ISGID != 0 { - fm |= os.ModeSetgid +// FileInfoUidGid extends os.FileInfo and adds a callbacks for extended data retrieval. +type FileInfoExtendedData interface { + os.FileInfo + Extended() []StatExtended +} + +func fileStatFromInfo(fi os.FileInfo) (uint32, *FileStat) { + mtime := fi.ModTime().Unix() + atime := mtime + var flags uint32 = sshFileXferAttrSize | + sshFileXferAttrPermissions | + sshFileXferAttrACmodTime + + fileStat := &FileStat{ + Size: uint64(fi.Size()), + Mode: fromFileMode(fi.Mode()), + Mtime: uint32(mtime), + Atime: uint32(atime), } - if mode&syscall.S_ISUID != 0 { - fm |= os.ModeSetuid + + // os specific file stat decoding + fileStatFromInfoOs(fi, &flags, fileStat) + + // The call above will include the sshFileXferAttrUIDGID in case + // the os.FileInfo can be casted to *syscall.Stat_t on unix. + // If fi implements FileInfoUidGid, retrieve Uid, Gid from it instead. 
+ if fiExt, ok := fi.(FileInfoUidGid); ok { + flags |= sshFileXferAttrUIDGID + fileStat.UID = fiExt.Uid() + fileStat.GID = fiExt.Gid() } - if mode&syscall.S_ISVTX != 0 { - fm |= os.ModeSticky + + // if fi implements FileInfoExtendedData, retrieve extended data from it + if fiExt, ok := fi.(FileInfoExtendedData); ok { + fileStat.Extended = fiExt.Extended() + if len(fileStat.Extended) > 0 { + flags |= sshFileXferAttrExtended + } } - return fm + + return flags, fileStat } diff --git a/attrs_stubs.go b/attrs_stubs.go new file mode 100644 index 00000000..d20348c1 --- /dev/null +++ b/attrs_stubs.go @@ -0,0 +1,12 @@ +//go:build plan9 || windows || android +// +build plan9 windows android + +package sftp + +import ( + "os" +) + +func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) { + // todo +} diff --git a/attrs_test.go b/attrs_test.go index d5529059..a755df62 100644 --- a/attrs_test.go +++ b/attrs_test.go @@ -1,45 +1,8 @@ package sftp import ( - "bytes" "os" - "reflect" - "testing" - "time" ) // ensure that attrs implemenst os.FileInfo var _ os.FileInfo = new(fileInfo) - -var unmarshalAttrsTests = []struct { - b []byte - want *fileInfo - rest []byte -}{ - {marshal(nil, struct{ Flags uint32 }{}), &fileInfo{mtime: time.Unix(int64(0), 0)}, nil}, - {marshal(nil, struct { - Flags uint32 - Size uint64 - }{ssh_FILEXFER_ATTR_SIZE, 20}), &fileInfo{size: 20, mtime: time.Unix(int64(0), 0)}, nil}, - {marshal(nil, struct { - Flags uint32 - Size uint64 - Permissions uint32 - }{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, - {marshal(nil, struct { - Flags uint32 - Size uint64 - Uid, Gid, Permissions uint32 - }{ssh_FILEXFER_ATTR_SIZE | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_UIDGID | ssh_FILEXFER_ATTR_PERMISSIONS, 20, 1000, 1000, 0644}), &fileInfo{size: 20, mode: os.FileMode(0644), mtime: time.Unix(int64(0), 0)}, nil}, -} - -func TestUnmarshalAttrs(t 
*testing.T) { - for _, tt := range unmarshalAttrsTests { - stat, rest := unmarshalAttrs(tt.b) - got := fileInfoFromStat(stat, "") - tt.want.sys = got.Sys() - if !reflect.DeepEqual(got, tt.want) || !bytes.Equal(tt.rest, rest) { - t.Errorf("unmarshalAttrs(%#v): want %#v, %#v, got: %#v, %#v", tt.b, tt.want, tt.rest, got, rest) - } - } -} diff --git a/attrs_unix.go b/attrs_unix.go new file mode 100644 index 00000000..96ffc03d --- /dev/null +++ b/attrs_unix.go @@ -0,0 +1,17 @@ +//go:build darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || aix || js || zos +// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js zos + +package sftp + +import ( + "os" + "syscall" +) + +func fileStatFromInfoOs(fi os.FileInfo, flags *uint32, fileStat *FileStat) { + if statt, ok := fi.Sys().(*syscall.Stat_t); ok { + *flags |= sshFileXferAttrUIDGID + fileStat.UID = statt.Uid + fileStat.GID = statt.Gid + } +} diff --git a/client.go b/client.go index f48b0be7..307a35ea 100644 --- a/client.go +++ b/client.go @@ -1,27 +1,205 @@ package sftp import ( - "encoding" + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" "io" + "math" "os" "path" "sync" + "sync/atomic" + "syscall" "time" "github.com/kr/fs" - "golang.org/x/crypto/ssh" + + "github.com/pkg/sftp/internal/encoding/ssh/filexfer/openssh" +) + +var ( + // ErrInternalInconsistency indicates the packets sent and the data queued to be + // written to the file don't match up. It is an unusual error and usually is + // caused by bad behavior server side or connection issues. The error is + // limited in scope to the call where it happened, the client object is still + // OK to use as long as the connection is still open. + ErrInternalInconsistency = errors.New("internal inconsistency") + // InternalInconsistency alias for ErrInternalInconsistency. 
+ // + // Deprecated: please use ErrInternalInconsistency + InternalInconsistency = ErrInternalInconsistency ) -// New creates a new SFTP client on conn. -func NewClient(conn *ssh.Client) (*Client, error) { +// A ClientOption is a function which applies configuration to a Client. +type ClientOption func(*Client) error + +// MaxPacketChecked sets the maximum size of the payload, measured in bytes. +// This option only accepts sizes servers should support, ie. <= 32768 bytes. +// +// If you get the error "failed to send packet header: EOF" when copying a +// large file, try lowering this number. +// +// The default packet size is 32768 bytes. +func MaxPacketChecked(size int) ClientOption { + return func(c *Client) error { + if size < 1 { + return errors.New("size must be greater or equal to 1") + } + if size > 32768 { + return errors.New("sizes larger than 32KB might not work with all servers") + } + c.maxPacket = size + return nil + } +} + +// MaxPacketUnchecked sets the maximum size of the payload, measured in bytes. +// It accepts sizes larger than the 32768 bytes all servers should support. +// Only use a setting higher than 32768 if your application always connects to +// the same server or after sufficiently broad testing. +// +// If you get the error "failed to send packet header: EOF" when copying a +// large file, try lowering this number. +// +// The default packet size is 32768 bytes. +func MaxPacketUnchecked(size int) ClientOption { + return func(c *Client) error { + if size < 1 { + return errors.New("size must be greater or equal to 1") + } + c.maxPacket = size + return nil + } +} + +// MaxPacket sets the maximum size of the payload, measured in bytes. +// This option only accepts sizes servers should support, ie. <= 32768 bytes. +// This is a synonym for MaxPacketChecked that provides backward compatibility. +// +// If you get the error "failed to send packet header: EOF" when copying a +// large file, try lowering this number. 
+// +// The default packet size is 32768 bytes. +func MaxPacket(size int) ClientOption { + return MaxPacketChecked(size) +} + +// MaxConcurrentRequestsPerFile sets the maximum concurrent requests allowed for a single file. +// +// The default maximum concurrent requests is 64. +func MaxConcurrentRequestsPerFile(n int) ClientOption { + return func(c *Client) error { + if n < 1 { + return errors.New("n must be greater or equal to 1") + } + c.maxConcurrentRequests = n + return nil + } +} + +// UseConcurrentWrites allows the Client to perform concurrent Writes. +// +// Using concurrency while doing writes, requires special consideration. +// A write to a later offset in a file after an error, +// could end up with a file length longer than what was successfully written. +// +// When using this option, if you receive an error during `io.Copy` or `io.WriteTo`, +// you may need to `Truncate` the target Writer to avoid “holes” in the data written. +func UseConcurrentWrites(value bool) ClientOption { + return func(c *Client) error { + c.useConcurrentWrites = value + return nil + } +} + +// UseConcurrentReads allows the Client to perform concurrent Reads. +// +// Concurrent reads are generally safe to use and not using them will degrade +// performance, so this option is enabled by default. +// +// When enabled, WriteTo will use Stat/Fstat to get the file size and determines +// how many concurrent workers to use. +// Some "read once" servers will delete the file if they receive a stat call on an +// open file and then the download will fail. +// Disabling concurrent reads you will be able to download files from these servers. +// If concurrent reads are disabled, the UseFstat option is ignored. +func UseConcurrentReads(value bool) ClientOption { + return func(c *Client) error { + c.disableConcurrentReads = !value + return nil + } +} + +// UseFstat sets whether to use Fstat or Stat when File.WriteTo is called +// (usually when copying files). 
+// Some servers limit the amount of open files and calling Stat after opening +// the file will throw an error From the server. Setting this flag will call +// Fstat instead of Stat which is suppose to be called on an open file handle. +// +// It has been found that that with IBM Sterling SFTP servers which have +// "extractability" level set to 1 which means only 1 file can be opened at +// any given time. +// +// If the server you are working with still has an issue with both Stat and +// Fstat calls you can always open a file and read it until the end. +// +// Another reason to read the file until its end and Fstat doesn't work is +// that in some servers, reading a full file will automatically delete the +// file as some of these mainframes map the file to a message in a queue. +// Once the file has been read it will get deleted. +func UseFstat(value bool) ClientOption { + return func(c *Client) error { + c.useFstat = value + return nil + } +} + +// CopyStderrTo specifies a writer to which the standard error of the remote sftp-server command should be written. +// +// The writer passed in will not be automatically closed. +// It is the responsibility of the caller to coordinate closure of any writers. +func CopyStderrTo(wr io.Writer) ClientOption { + return func(c *Client) error { + c.stderrTo = wr + return nil + } +} + +// Client represents an SFTP session on a *ssh.ClientConn SSH connection. +// Multiple Clients can be active on a single SSH connection, and a Client +// may be called concurrently from multiple Goroutines. +// +// Client implements the github.com/kr/fs.FileSystem interface. +type Client struct { + clientConn + + stderrTo io.Writer + + ext map[string]string // Extensions (name -> data). + + maxPacket int // max packet size read or written. + maxConcurrentRequests int + nextid uint32 + + // write concurrency is… error prone. + // Default behavior should be to not use it. 
+ useConcurrentWrites bool + useFstat bool + disableConcurrentReads bool +} + +// NewClient creates a new SFTP client on conn, using zero or more option +// functions. +func NewClient(conn *ssh.Client, opts ...ClientOption) (*Client, error) { s, err := conn.NewSession() if err != nil { return nil, err } - if err := s.RequestSubsystem("sftp"); err != nil { - return nil, err - } + pw, err := s.StdinPipe() if err != nil { return nil, err @@ -30,100 +208,186 @@ func NewClient(conn *ssh.Client) (*Client, error) { if err != nil { return nil, err } + perr, err := s.StderrPipe() + if err != nil { + return nil, err + } + + if err := s.RequestSubsystem("sftp"); err != nil { + return nil, err + } - return NewClientPipe(pr, pw) + return newClientPipe(pr, perr, pw, s.Wait, opts...) } // NewClientPipe creates a new SFTP client given a Reader and a WriteCloser. // This can be used for connecting to an SFTP server over TCP/TLS or by using // the system's ssh client program (e.g. via exec.Command). -func NewClientPipe(rd io.Reader, wr io.WriteCloser) (*Client, error) { - sftp := &Client{ - w: wr, - r: rd, +func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Client, error) { + return newClientPipe(rd, nil, wr, nil, opts...) +} + +func newClientPipe(rd, stderr io.Reader, wr io.WriteCloser, wait func() error, opts ...ClientOption) (*Client, error) { + c := &Client{ + clientConn: clientConn{ + conn: conn{ + Reader: rd, + WriteCloser: wr, + }, + inflight: make(map[uint32]chan<- result), + closed: make(chan struct{}), + wait: wait, + }, + + ext: make(map[string]string), + + maxPacket: 1 << 15, + maxConcurrentRequests: 64, } - if err := sftp.sendInit(); err != nil { - return nil, err + + for _, opt := range opts { + if err := opt(c); err != nil { + wr.Close() + return nil, err + } } - return sftp, sftp.recvVersion() -} -// Client represents an SFTP session on a *ssh.ClientConn SSH connection. 
-// Multiple Clients can be active on a single SSH connection, and a Client -// may be called concurrently from multiple Goroutines. -// -// Client implements the github.com/kr/fs.FileSystem interface. -type Client struct { - w io.WriteCloser - r io.Reader - mu sync.Mutex // locks mu and seralises commands to the server - nextid uint32 -} + if stderr != nil { + wr := io.Discard + if c.stderrTo != nil { + wr = c.stderrTo + } + + go func() { + // DO NOT close the writer! + // Programs may pass in `os.Stderr` to write the remote stderr to, + // and the program may continue after disconnect by reconnecting. + // But if we've closed their stderr, then we just messed everything up. + + if _, err := io.Copy(wr, stderr); err != nil { + debug("error copying stderr: %v", err) + } + }() + } + + if err := c.sendInit(); err != nil { + wr.Close() + return nil, fmt.Errorf("error sending init packet to server: %w", err) + } + + if err := c.recvVersion(); err != nil { + wr.Close() + return nil, fmt.Errorf("error receiving version packet from server: %w", err) + } + + c.clientConn.wg.Add(1) + go func() { + defer c.clientConn.wg.Done() + + if err := c.clientConn.recv(); err != nil { + c.clientConn.broadcastErr(err) + } + }() -// Close closes the SFTP session. -func (c *Client) Close() error { return c.w.Close() } + return c, nil +} -// Create creates the named file mode 0666 (before umask), truncating it if -// it already exists. If successful, methods on the returned File can be -// used for I/O; the associated file descriptor has mode O_RDWR. +// Create creates the named file mode 0666 (before umask), truncating it if it +// already exists. If successful, methods on the returned File can be used for +// I/O; the associated file descriptor has mode O_RDWR. If you need more +// control over the flags/mode used to open the file see client.OpenFile. +// +// Note that some SFTP servers (eg. AWS Transfer) do not support opening files +// read/write at the same time. 
For those services you will need to use +// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`. func (c *Client) Create(path string) (*File, error) { - return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) + return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) } -const sftpProtocolVersion = 3 // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 +const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt func (c *Client) sendInit() error { - return sendPacket(c.w, sshFxInitPacket{ - Version: sftpProtocolVersion, // http://tools.ietf.org/html/draft-ietf-secsh-filexfer-02 + return c.clientConn.conn.sendPacket(&sshFxInitPacket{ + Version: sftpProtocolVersion, // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt }) } -// returns the current value of c.nextid and increments it -// callers is expected to hold c.mu -func (c *Client) nextId() uint32 { - v := c.nextid - c.nextid++ - return v +// returns the next value of c.nextid +func (c *Client) nextID() uint32 { + return atomic.AddUint32(&c.nextid, 1) } func (c *Client) recvVersion() error { - typ, data, err := recvPacket(c.r) + typ, data, err := c.recvPacket(0) if err != nil { + if err == io.EOF { + return fmt.Errorf("server unexpectedly closed connection: %w", io.ErrUnexpectedEOF) + } + return err } - if typ != ssh_FXP_VERSION { - return &unexpectedPacketErr{ssh_FXP_VERSION, typ} + + if typ != sshFxpVersion { + return &unexpectedPacketErr{sshFxpVersion, typ} + } + + version, data, err := unmarshalUint32Safe(data) + if err != nil { + return err } - version, _ := unmarshalUint32(data) if version != sftpProtocolVersion { return &unexpectedVersionErr{sftpProtocolVersion, version} } + for len(data) > 0 { + var ext extensionPair + ext, data, err = unmarshalExtensionPair(data) + if err != nil { + return err + } + c.ext[ext.Name] = ext.Data + } + return nil } +// HasExtension checks whether the server supports a named extension. 
+// +// The first return value is the extension data reported by the server +// (typically a version number). +func (c *Client) HasExtension(name string) (string, bool) { + data, ok := c.ext[name] + return data, ok +} + // Walk returns a new Walker rooted at root. func (c *Client) Walk(root string) *fs.Walker { return fs.WalkFS(root, c) } -// ReadDir reads the directory named by dirname and returns a list of -// directory entries. +// ReadDir reads the directory named by p +// and returns a list of directory entries. func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { - handle, err := c.opendir(p) + return c.ReadDirContext(context.Background(), p) +} + +// ReadDirContext reads the directory named by p +// and returns a list of directory entries. +// The passed context can be used to cancel the operation +// returning all entries listed up to the cancellation. +func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, error) { + handle, err := c.opendir(ctx, p) if err != nil { return nil, err } defer c.close(handle) // this has to defer earlier than the lock below - var attrs []os.FileInfo - c.mu.Lock() - defer c.mu.Unlock() + var entries []os.FileInfo var done = false for !done { - id := c.nextId() - typ, data, err1 := c.sendRequest(sshFxpReaddirPacket{ - Id: id, + id := c.nextID() + typ, data, err1 := c.sendPacket(ctx, nil, &sshFxpReaddirPacket{ + ID: id, Handle: handle, }) if err1 != nil { @@ -132,10 +396,10 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { break } switch typ { - case ssh_FXP_NAME: + case sshFxpName: sid, data := unmarshalUint32(data) if sid != id { - return nil, &unexpectedIdErr{id, sid} + return nil, &unexpectedIDErr{id, sid} } count, data := unmarshalUint32(data) for i := uint32(0); i < count; i++ { @@ -143,15 +407,18 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { filename, data = unmarshalString(data) _, data = unmarshalString(data) // discard longname var attr *FileStat - attr, data = 
unmarshalAttrs(data) + attr, data, err = unmarshalAttrs(data) + if err != nil { + return nil, err + } if filename == "." || filename == ".." { continue } - attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename))) + entries = append(entries, fileInfoFromStat(attr, path.Base(filename))) } - case ssh_FXP_STATUS: + case sshFxpStatus: // TODO(dfc) scope warning! - err = eofOrErr(unmarshalStatus(id, data)) + err = normaliseError(unmarshalStatus(id, data)) done = true default: return nil, unimplementedPacketErr(typ) @@ -160,55 +427,68 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { if err == io.EOF { err = nil } - return attrs, err + return entries, err } -func (c *Client) opendir(path string) (string, error) { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpOpendirPacket{ - Id: id, + +func (c *Client) opendir(ctx context.Context, path string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(ctx, nil, &sshFxpOpendirPacket{ + ID: id, Path: path, }) if err != nil { return "", err } switch typ { - case ssh_FXP_HANDLE: + case sshFxpHandle: sid, data := unmarshalUint32(data) if sid != id { - return "", &unexpectedIdErr{id, sid} + return "", &unexpectedIDErr{id, sid} } handle, _ := unmarshalString(data) return handle, nil - case ssh_FXP_STATUS: - return "", unmarshalStatus(id, data) + case sshFxpStatus: + return "", normaliseError(unmarshalStatus(id, data)) default: return "", unimplementedPacketErr(typ) } } +// Stat returns a FileInfo structure describing the file specified by path 'p'. +// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file. +func (c *Client) Stat(p string) (os.FileInfo, error) { + fs, err := c.stat(p) + if err != nil { + return nil, err + } + return fileInfoFromStat(fs, path.Base(p)), nil +} + +// Lstat returns a FileInfo structure describing the file specified by path 'p'. 
+// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link. func (c *Client) Lstat(p string) (os.FileInfo, error) { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpLstatPacket{ - Id: id, + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpLstatPacket{ + ID: id, Path: p, }) if err != nil { return nil, err } switch typ { - case ssh_FXP_ATTRS: + case sshFxpAttrs: sid, data := unmarshalUint32(data) if sid != id { - return nil, &unexpectedIdErr{id, sid} + return nil, &unexpectedIDErr{id, sid} + } + attr, _, err := unmarshalAttrs(data) + if err != nil { + // avoid returning a valid value from fileInfoFromStats if err != nil. + return nil, err } - attr, _ := unmarshalAttrs(data) return fileInfoFromStat(attr, path.Base(p)), nil - case ssh_FXP_STATUS: - return nil, unmarshalStatus(id, data) + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) default: return nil, unimplementedPacketErr(typ) } @@ -216,21 +496,19 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { // ReadLink reads the target of a symbolic link. 
func (c *Client) ReadLink(p string) (string, error) { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpReadlinkPacket{ - Id: id, + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpReadlinkPacket{ + ID: id, Path: p, }) if err != nil { return "", err } switch typ { - case ssh_FXP_NAME: + case sshFxpName: sid, data := unmarshalUint32(data) if sid != id { - return "", &unexpectedIdErr{id, sid} + return "", &unexpectedIDErr{id, sid} } count, data := unmarshalUint32(data) if count != 1 { @@ -238,20 +516,75 @@ func (c *Client) ReadLink(p string) (string, error) { } filename, _ := unmarshalString(data) // ignore dummy attributes return filename, nil - case ssh_FXP_STATUS: - return "", unmarshalStatus(id, data) + case sshFxpStatus: + return "", normaliseError(unmarshalStatus(id, data)) default: return "", unimplementedPacketErr(typ) } } +// Link creates a hard link at 'newname', pointing at the same inode as 'oldname' +func (c *Client) Link(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpHardlinkPacket{ + ID: id, + Oldpath: oldname, + Newpath: newname, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// Symlink creates a symbolic link at 'newname', pointing at target 'oldname' +func (c *Client) Symlink(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSymlinkPacket{ + ID: id, + Linkpath: newname, + Targetpath: oldname, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error { + id := c.nextID() + typ, data, err := 
c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{ + ID: id, + Handle: handle, + Flags: flags, + Attrs: attrs, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + // setstat is a convience wrapper to allow for changing of various parts of the file descriptor. func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpSetstatPacket{ - Id: id, + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSetstatPacket{ + ID: id, Path: path, Flags: flags, Attrs: attrs, @@ -260,8 +593,8 @@ func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { return err } switch typ { - case ssh_FXP_STATUS: - return okOrErr(unmarshalStatus(id, data)) + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) default: return unimplementedPacketErr(typ) } @@ -274,22 +607,26 @@ func (c *Client) Chtimes(path string, atime time.Time, mtime time.Time) error { Mtime uint32 } attrs := times{uint32(atime.Unix()), uint32(mtime.Unix())} - return c.setstat(path, ssh_FILEXFER_ATTR_ACMODTIME, attrs) + return c.setstat(path, sshFileXferAttrACmodTime, attrs) } // Chown changes the user and group owners of the named file. func (c *Client) Chown(path string, uid, gid int) error { type owner struct { - Uid uint32 - Gid uint32 + UID uint32 + GID uint32 } attrs := owner{uint32(uid), uint32(gid)} - return c.setstat(path, ssh_FILEXFER_ATTR_UIDGID, attrs) + return c.setstat(path, sshFileXferAttrUIDGID, attrs) } // Chmod changes the permissions of the named file. +// +// Chmod does not apply a umask, because even retrieving the umask is not +// possible in a portable way without causing a race condition. Callers +// should mask off umask bits, if desired. 
func (c *Client) Chmod(path string, mode os.FileMode) error { - return c.setstat(path, ssh_FILEXFER_ATTR_PERMISSIONS, uint32(mode)) + return c.setstat(path, sshFileXferAttrPermissions, toChmodPerm(mode)) } // Truncate sets the size of the named file. Although it may be safely assumed @@ -297,29 +634,41 @@ func (c *Client) Chmod(path string, mode os.FileMode) error { // the SFTP protocol does not specify what behavior the server should do when setting // size greater than the current size. func (c *Client) Truncate(path string, size int64) error { - return c.setstat(path, ssh_FILEXFER_ATTR_SIZE, uint64(size)) + return c.setstat(path, sshFileXferAttrSize, uint64(size)) +} + +// SetExtendedData sets extended attributes of the named file. It uses the +// SSH_FILEXFER_ATTR_EXTENDED flag in the setstat request. +// +// This flag provides a general extension mechanism for vendor-specific extensions. +// Names of the attributes should be a string of the format "name@domain", where "domain" +// is a valid, registered domain name and "name" identifies the method. Server +// implementations SHOULD ignore extended data fields that they do not understand. +func (c *Client) SetExtendedData(path string, extended []StatExtended) error { + attrs := &FileStat{ + Extended: extended, + } + return c.setstat(path, sshFileXferAttrExtended, attrs) } // Open opens the named file for reading. If successful, methods on the // returned file can be used for reading; the associated file descriptor // has mode O_RDONLY. func (c *Client) Open(path string) (*File, error) { - return c.open(path, flags(os.O_RDONLY)) + return c.open(path, toPflags(os.O_RDONLY)) } // OpenFile is the generalized open call; most users will use Open or // Create instead. It opens the named file with specified flag (O_RDONLY // etc.). If successful, methods on the returned File can be used for I/O. 
func (c *Client) OpenFile(path string, f int) (*File, error) { - return c.open(path, flags(f)) + return c.open(path, toPflags(f)) } func (c *Client) open(path string, pflags uint32) (*File, error) { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpOpenPacket{ - Id: id, + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpOpenPacket{ + ID: id, Path: path, Pflags: pflags, }) @@ -327,94 +676,118 @@ func (c *Client) open(path string, pflags uint32) (*File, error) { return nil, err } switch typ { - case ssh_FXP_HANDLE: + case sshFxpHandle: sid, data := unmarshalUint32(data) if sid != id { - return nil, &unexpectedIdErr{id, sid} + return nil, &unexpectedIDErr{id, sid} } handle, _ := unmarshalString(data) return &File{c: c, path: path, handle: handle}, nil - case ssh_FXP_STATUS: - return nil, unmarshalStatus(id, data) + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) default: return nil, unimplementedPacketErr(typ) } } -// readAt reads len(buf) bytes from the remote file indicated by handle starting -// from offset. -func (c *Client) readAt(handle string, offset uint64, buf []byte) (uint32, error) { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpReadPacket{ - Id: id, +// close closes a handle handle previously returned in the response +// to SSH_FXP_OPEN or SSH_FXP_OPENDIR. The handle becomes invalid +// immediately after this request has been sent. 
+func (c *Client) close(handle string) error { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpClosePacket{ + ID: id, Handle: handle, - Offset: offset, - Len: uint32(len(buf)), }) if err != nil { - return 0, err + return err } switch typ { - case ssh_FXP_DATA: - sid, data := unmarshalUint32(data) - if sid != id { - return 0, &unexpectedIdErr{id, sid} - } - l, data := unmarshalUint32(data) - n := copy(buf, data[:l]) - return uint32(n), nil - case ssh_FXP_STATUS: - return 0, eofOrErr(unmarshalStatus(id, data)) + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) default: - return 0, unimplementedPacketErr(typ) + return unimplementedPacketErr(typ) } } -// close closes a handle handle previously returned in the response -// to SSH_FXP_OPEN or SSH_FXP_OPENDIR. The handle becomes invalid -// immediately after this request has been sent. -func (c *Client) close(handle string) error { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpClosePacket{ - Id: id, - Handle: handle, +func (c *Client) stat(path string) (*FileStat, error) { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatPacket{ + ID: id, + Path: path, }) if err != nil { - return err + return nil, err } switch typ { - case ssh_FXP_STATUS: - return okOrErr(unmarshalStatus(id, data)) + case sshFxpAttrs: + sid, data := unmarshalUint32(data) + if sid != id { + return nil, &unexpectedIDErr{id, sid} + } + attr, _, err := unmarshalAttrs(data) + return attr, err + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) default: - return unimplementedPacketErr(typ) + return nil, unimplementedPacketErr(typ) } } func (c *Client) fstat(handle string) (*FileStat, error) { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpFstatPacket{ - Id: id, + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, 
&sshFxpFstatPacket{ + ID: id, Handle: handle, }) if err != nil { return nil, err } switch typ { - case ssh_FXP_ATTRS: + case sshFxpAttrs: sid, data := unmarshalUint32(data) if sid != id { - return nil, &unexpectedIdErr{id, sid} + return nil, &unexpectedIDErr{id, sid} + } + attr, _, err := unmarshalAttrs(data) + return attr, err + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) + default: + return nil, unimplementedPacketErr(typ) + } +} + +// StatVFS retrieves VFS statistics from a remote host. +// +// It implements the statvfs@openssh.com SSH_FXP_EXTENDED feature +// from http://www.opensource.apple.com/source/OpenSSH/OpenSSH-175/openssh/PROTOCOL?txt. +func (c *Client) StatVFS(path string) (*StatVFS, error) { + // send the StatVFS packet to the server + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatvfsPacket{ + ID: id, + Path: path, + }) + if err != nil { + return nil, err + } + + switch typ { + // server responded with valid data + case sshFxpExtendedReply: + var response StatVFS + err = binary.Read(bytes.NewReader(data), binary.BigEndian, &response) + if err != nil { + return nil, errors.New("can not parse reply") } - attr, _ := unmarshalAttrs(data) - return attr, nil - case ssh_FXP_STATUS: - return nil, unmarshalStatus(id, data) + + return &response, nil + + // the resquest failed + case sshFxpStatus: + return nil, normaliseError(unmarshalStatus(id, data)) + default: return nil, unimplementedPacketErr(typ) } @@ -429,46 +802,87 @@ func (c *Client) Join(elem ...string) string { return path.Join(elem...) } // file or directory with the specified path exists, or if the specified directory // is not empty. 
func (c *Client) Remove(path string) error { - err := c.removeFile(path) - if status, ok := err.(*StatusError); ok && status.Code == ssh_FX_FAILURE { - err = c.removeDirectory(path) + errF := c.removeFile(path) + if errF == nil { + return nil + } + + errD := c.RemoveDirectory(path) + if errD == nil { + return nil + } + + // Both failed: figure out which error to return. + + if errF, ok := errF.(*os.PathError); ok { + // The only time it makes sense to compare errors, is when both are `*os.PathError`. + // We cannot test these directly with errF == errD, as that would be a pointer comparison. + + if errD, ok := errD.(*os.PathError); ok && errors.Is(errF.Err, errD.Err) { + // If they are both pointers to PathError, + // and the same underlying error, then return that. + return errF + } + } + + fi, err := c.Stat(path) + if err != nil { + return err } - return err + + if fi.IsDir() { + return errD + } + + return errF } func (c *Client) removeFile(path string) error { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpRemovePacket{ - Id: id, + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRemovePacket{ + ID: id, Filename: path, }) if err != nil { return err } switch typ { - case ssh_FXP_STATUS: - return okOrErr(unmarshalStatus(id, data)) + case sshFxpStatus: + err = normaliseError(unmarshalStatus(id, data)) + if err == nil { + return nil + } + return &os.PathError{ + Op: "remove", + Path: path, + Err: err, + } default: return unimplementedPacketErr(typ) } } -func (c *Client) removeDirectory(path string) error { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpRmdirPacket{ - Id: id, +// RemoveDirectory removes a directory path. 
+func (c *Client) RemoveDirectory(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRmdirPacket{ + ID: id, Path: path, }) if err != nil { return err } switch typ { - case ssh_FXP_STATUS: - return okOrErr(unmarshalStatus(id, data)) + case sshFxpStatus: + err = normaliseError(unmarshalStatus(id, data)) + if err == nil { + return nil + } + return &os.PathError{ + Op: "remove", + Path: path, + Err: err, + } default: return unimplementedPacketErr(typ) } @@ -476,11 +890,9 @@ func (c *Client) removeDirectory(path string) error { // Rename renames a file. func (c *Client) Rename(oldname, newname string) error { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpRenamePacket{ - Id: id, + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRenamePacket{ + ID: id, Oldpath: oldname, Newpath: newname, }) @@ -488,103 +900,709 @@ func (c *Client) Rename(oldname, newname string) error { return err } switch typ { - case ssh_FXP_STATUS: - return okOrErr(unmarshalStatus(id, data)) + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) default: return unimplementedPacketErr(typ) } } -func (c *Client) sendRequest(p encoding.BinaryMarshaler) (byte, []byte, error) { - if err := sendPacket(c.w, p); err != nil { - return 0, nil, err - } - return recvPacket(c.r) -} - -// writeAt writes len(buf) bytes from the remote file indicated by handle starting -// from offset. -func (c *Client) writeAt(handle string, offset uint64, buf []byte) (uint32, error) { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpWritePacket{ - Id: id, - Handle: handle, - Offset: offset, - Length: uint32(len(buf)), - Data: buf, +// PosixRename renames a file using the posix-rename@openssh.com extension +// which will replace newname if it already exists. 
+func (c *Client) PosixRename(oldname, newname string) error { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpPosixRenamePacket{ + ID: id, + Oldpath: oldname, + Newpath: newname, }) if err != nil { - return 0, err + return err } switch typ { - case ssh_FXP_STATUS: - if err := okOrErr(unmarshalStatus(id, data)); err != nil { - return 0, err - } - return uint32(len(buf)), nil + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) default: - return 0, unimplementedPacketErr(typ) + return unimplementedPacketErr(typ) } } -// Creates the specified directory. An error will be returned if a file or -// directory with the specified path already exists, or if the directory's -// parent folder does not exist (the method cannot create complete paths). -func (c *Client) Mkdir(path string) error { - c.mu.Lock() - defer c.mu.Unlock() - id := c.nextId() - typ, data, err := c.sendRequest(sshFxpMkdirPacket{ - Id: id, +// RealPath can be used to have the server canonicalize any given path name to an absolute path. +// +// This is useful for converting path names containing ".." components, +// or relative pathnames without a leading slash into absolute paths. +func (c *Client) RealPath(path string) (string, error) { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRealpathPacket{ + ID: id, Path: path, }) if err != nil { - return err + return "", err } switch typ { - case ssh_FXP_STATUS: - return okOrErr(unmarshalStatus(id, data)) - default: - return unimplementedPacketErr(typ) - } -} - -// File represents a remote file. 
-type File struct { - c *Client - path string + case sshFxpName: + sid, data := unmarshalUint32(data) + if sid != id { + return "", &unexpectedIDErr{id, sid} + } + count, data := unmarshalUint32(data) + if count != 1 { + return "", unexpectedCount(1, count) + } + filename, _ := unmarshalString(data) // ignore attributes + return filename, nil + case sshFxpStatus: + return "", normaliseError(unmarshalStatus(id, data)) + default: + return "", unimplementedPacketErr(typ) + } +} + +// Getwd returns the current working directory of the server. Operations +// involving relative paths will be based at this location. +func (c *Client) Getwd() (string, error) { + return c.RealPath(".") +} + +// Mkdir creates the specified directory. An error will be returned if a file or +// directory with the specified path already exists, or if the directory's +// parent folder does not exist (the method cannot create complete paths). +func (c *Client) Mkdir(path string) error { + id := c.nextID() + typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpMkdirPacket{ + ID: id, + Path: path, + }) + if err != nil { + return err + } + switch typ { + case sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return unimplementedPacketErr(typ) + } +} + +// MkdirAll creates a directory named path, along with any necessary parents, +// and returns nil, or else returns an error. +// If path is already a directory, MkdirAll does nothing and returns nil. +// If, while making any directory, that path is found to already be a regular file, an error is returned. +func (c *Client) MkdirAll(path string) error { + // Most of this code mimics https://golang.org/src/os/path.go?s=514:561#L13 + // Fast path: if we can tell whether path is a directory or file, stop with success or error. 
+ dir, err := c.Stat(path) + if err == nil { + if dir.IsDir() { + return nil + } + return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR} + } + + // Slow path: make sure parent exists and then call Mkdir for path. + i := len(path) + for i > 0 && path[i-1] == '/' { // Skip trailing path separator. + i-- + } + + j := i + for j > 0 && path[j-1] != '/' { // Scan backward over element. + j-- + } + + if j > 1 { + // Create parent + err = c.MkdirAll(path[0 : j-1]) + if err != nil { + return err + } + } + + // Parent now exists; invoke Mkdir and use its result. + err = c.Mkdir(path) + if err != nil { + // Handle arguments like "foo/." by + // double-checking that directory doesn't exist. + dir, err1 := c.Lstat(path) + if err1 == nil && dir.IsDir() { + return nil + } + return err + } + return nil +} + +// RemoveAll delete files recursively in the directory and Recursively delete subdirectories. +// An error will be returned if no file or directory with the specified path exists +func (c *Client) RemoveAll(path string) error { + + // Get the file/directory information + fi, err := c.Stat(path) + if err != nil { + return err + } + + if fi.IsDir() { + // Delete files recursively in the directory + files, err := c.ReadDir(path) + if err != nil { + return err + } + + for _, file := range files { + if file.IsDir() { + // Recursively delete subdirectories + err = c.RemoveAll(path + "/" + file.Name()) + if err != nil { + return err + } + } else { + // Delete individual files + err = c.Remove(path + "/" + file.Name()) + if err != nil { + return err + } + } + } + + } + + return c.Remove(path) + +} + +// File represents a remote file. +type File struct { + c *Client + path string + + mu sync.RWMutex handle string - offset uint64 // current offset within remote file + offset int64 // current offset within remote file } // Close closes the File, rendering it unusable for I/O. It returns an // error, if any. 
func (f *File) Close() error { - return f.c.close(f.handle) + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return os.ErrClosed + } + + // The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`, + // it will unconditionally mark the handle as unused, + // so we need to also unconditionally mark this handle as invalid. + // By invalidating our local copy of the handle, + // we ensure that there cannot be any erroneous use-after-close requests sent after Close. + + handle := f.handle + f.handle = "" + + return f.c.close(handle) } -// Read reads up to len(b) bytes from the File. It returns the number of -// bytes read and an error, if any. EOF is signaled by a zero count with -// err set to io.EOF. +// Name returns the name of the file as presented to Open or Create. +func (f *File) Name() string { + return f.path +} + +// Read reads up to len(b) bytes from the File. It returns the number of bytes +// read and an error, if any. Read follows io.Reader semantics, so when Read +// encounters an error or EOF condition after successfully reading n > 0 bytes, +// it returns the number of bytes read. +// +// To maximise throughput for transferring the entire file (especially +// over high latency links) it is recommended to use WriteTo rather +// than calling Read multiple times. io.Copy will do this +// automatically. func (f *File) Read(b []byte) (int, error) { - var read int - for len(b) > 0 { - n, err := f.c.readAt(f.handle, f.offset, b[:min(len(b), maxWritePacket)]) - f.offset += uint64(n) - read += int(n) + f.mu.Lock() + defer f.mu.Unlock() + + n, err := f.readAt(b, f.offset) + f.offset += int64(n) + return n, err +} + +// readChunkAt attempts to read the whole entire length of the buffer from the file starting at the offset. +// It will continue progressively reading into the buffer until it fills the whole buffer, or an error occurs. 
+func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) { + for err == nil && n < len(b) { + id := f.c.nextID() + typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpReadPacket{ + ID: id, + Handle: f.handle, + Offset: uint64(off) + uint64(n), + Len: uint32(len(b) - n), + }) + if err != nil { + return n, err + } + + switch typ { + case sshFxpStatus: + return n, normaliseError(unmarshalStatus(id, data)) + + case sshFxpData: + sid, data := unmarshalUint32(data) + if id != sid { + return n, &unexpectedIDErr{id, sid} + } + + l, data := unmarshalUint32(data) + n += copy(b[n:], data[:l]) + + default: + return n, unimplementedPacketErr(typ) + } + } + + return +} + +func (f *File) readAtSequential(b []byte, off int64) (read int, err error) { + for read < len(b) { + rb := b[read:] + if len(rb) > f.c.maxPacket { + rb = rb[:f.c.maxPacket] + } + n, err := f.readChunkAt(nil, rb, off+int64(read)) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") + } + if n > 0 { + read += n + } if err != nil { return read, err } - b = b[n:] } return read, nil } +// ReadAt reads up to len(b) byte from the File at a given offset `off`. It returns +// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, +// so the file offset is not altered during the read. +func (f *File) ReadAt(b []byte, off int64) (int, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + return f.readAt(b, off) +} + +// readAt must be called while holding either the Read or Write mutex in File. +// This code is concurrent safe with itself, but not with Close. +func (f *File) readAt(b []byte, off int64) (int, error) { + if f.handle == "" { + return 0, os.ErrClosed + } + + if len(b) <= f.c.maxPacket { + // This should be able to be serviced with 1/2 requests. + // So, just do it directly. 
+ return f.readChunkAt(nil, b, off) + } + + if f.c.disableConcurrentReads { + return f.readAtSequential(b, off) + } + + // Split the read into multiple maxPacket-sized concurrent reads bounded by maxConcurrentRequests. + // This allows writes with a suitably large buffer to transfer data at a much faster rate + // by overlapping round trip times. + + cancel := make(chan struct{}) + + concurrency := len(b)/f.c.maxPacket + 1 + if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { + concurrency = f.c.maxConcurrentRequests + } + + resPool := newResChanPool(concurrency) + + type work struct { + id uint32 + res chan result + + b []byte + off int64 + } + workCh := make(chan work) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + b := b + offset := off + chunkSize := f.c.maxPacket + + for len(b) > 0 { + rb := b + if len(rb) > chunkSize { + rb = rb[:chunkSize] + } + + id := f.c.nextID() + res := resPool.Get() + + f.c.dispatchRequest(res, &sshFxpReadPacket{ + ID: id, + Handle: f.handle, + Offset: uint64(offset), + Len: uint32(len(rb)), + }) + + select { + case workCh <- work{id, res, rb, offset}: + case <-cancel: + return + } + + offset += int64(len(rb)) + b = b[len(rb):] + } + }() + + type rErr struct { + off int64 + err error + } + errCh := make(chan rErr) + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and then performs the Read into its buffer from its respective offset. 
+ go func() { + defer wg.Done() + + for packet := range workCh { + var n int + + s := <-packet.res + resPool.Put(packet.res) + + err := s.err + if err == nil { + switch s.typ { + case sshFxpStatus: + err = normaliseError(unmarshalStatus(packet.id, s.data)) + + case sshFxpData: + sid, data := unmarshalUint32(s.data) + if packet.id != sid { + err = &unexpectedIDErr{packet.id, sid} + + } else { + l, data := unmarshalUint32(data) + n = copy(packet.b, data[:l]) + + // For normal disk files, it is guaranteed that this will read + // the specified number of bytes, or up to end of file. + // This implies, if we have a short read, that means EOF. + if n < len(packet.b) { + err = io.EOF + } + } + + default: + err = unimplementedPacketErr(s.typ) + } + } + + if err != nil { + // return the offset as the start + how much we read before the error. + errCh <- rErr{packet.off + int64(n), err} + + // DO NOT return. + // We want to ensure that workCh is drained before wg.Wait returns. + } + } + }() + } + + // Wait for long tail, before closing results. + go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: collect all the results into a relevant return: the earliest offset to return an error. + firstErr := rErr{math.MaxInt64, nil} + for rErr := range errCh { + if rErr.off <= firstErr.off { + firstErr = rErr + } + + select { + case <-cancel: + default: + // stop any more work from being distributed. (Just in case.) + close(cancel) + } + } + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off > our starting offset. + return int(firstErr.off - off), firstErr.err + } + + // As per spec for io.ReaderAt, we return nil error if and only if we read everything. + return len(b), nil +} + +// writeToSequential implements WriteTo, but works sequentially with no parallelism. 
+func (f *File) writeToSequential(w io.Writer) (written int64, err error) { + b := make([]byte, f.c.maxPacket) + ch := make(chan result, 1) // reusable channel + + for { + n, err := f.readChunkAt(ch, b, f.offset) + if n < 0 { + panic("sftp.File: returned negative count from readChunkAt") + } + + if n > 0 { + f.offset += int64(n) + + m, err := w.Write(b[:n]) + written += int64(m) + + if err != nil { + return written, err + } + } + + if err != nil { + if err == io.EOF { + return written, nil // return nil explicitly. + } + + return written, err + } + } +} + +// WriteTo writes the file to the given Writer. +// The return value is the number of bytes written. +// Any error encountered during the write is also returned. +// +// This method is preferred over calling Read multiple times +// to maximise throughput for transferring the entire file, +// especially over high latency links. +func (f *File) WriteTo(w io.Writer) (written int64, err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + if f.c.disableConcurrentReads { + return f.writeToSequential(w) + } + + // For concurrency, we want to guess how many concurrent workers we should use. + var fileStat *FileStat + if f.c.useFstat { + fileStat, err = f.c.fstat(f.handle) + } else { + fileStat, err = f.c.stat(f.path) + } + if err != nil { + return 0, err + } + + fileSize := fileStat.Size + if fileSize <= uint64(f.c.maxPacket) || !isRegular(fileStat.Mode) { + // only regular files are guaranteed to return (full read) xor (partial read, next error) + return f.writeToSequential(w) + } + + concurrency64 := fileSize/uint64(f.c.maxPacket) + 1 // a bad guess, but better than no guess + if concurrency64 > uint64(f.c.maxConcurrentRequests) || concurrency64 < 1 { + concurrency64 = uint64(f.c.maxConcurrentRequests) + } + // Now that concurrency64 is saturated to an int value, we know this assignment cannot possibly overflow. 
+ concurrency := int(concurrency64) + + chunkSize := f.c.maxPacket + pool := newBufPool(concurrency, chunkSize) + resPool := newResChanPool(concurrency) + + cancel := make(chan struct{}) + var wg sync.WaitGroup + defer func() { + // Once the writing Reduce phase has ended, all the feed work needs to unconditionally stop. + close(cancel) + + // We want to wait until all outstanding goroutines with an `f` or `f.c` reference have completed. + // Just to be sure we don’t orphan any goroutines any hanging references. + wg.Wait() + }() + + type writeWork struct { + b []byte + off int64 + err error + + next chan writeWork + } + writeCh := make(chan writeWork) + + type readWork struct { + id uint32 + res chan result + off int64 + + cur, next chan writeWork + } + readCh := make(chan readWork) + + // Slice: hand out chunks of work on demand, with a `cur` and `next` channel built-in for sequencing. + go func() { + defer close(readCh) + + off := f.offset + + cur := writeCh + for { + id := f.c.nextID() + res := resPool.Get() + + next := make(chan writeWork) + readWork := readWork{ + id: id, + res: res, + off: off, + + cur: cur, + next: next, + } + + f.c.dispatchRequest(res, &sshFxpReadPacket{ + ID: id, + Handle: f.handle, + Offset: uint64(off), + Len: uint32(chunkSize), + }) + + select { + case readCh <- readWork: + case <-cancel: + return + } + + off += int64(chunkSize) + cur = next + } + }() + + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets readWork, and does the Read into a buffer at the given offset. 
+ go func() { + defer wg.Done() + + for readWork := range readCh { + var b []byte + var n int + + s := <-readWork.res + resPool.Put(readWork.res) + + err := s.err + if err == nil { + switch s.typ { + case sshFxpStatus: + err = normaliseError(unmarshalStatus(readWork.id, s.data)) + + case sshFxpData: + sid, data := unmarshalUint32(s.data) + if readWork.id != sid { + err = &unexpectedIDErr{readWork.id, sid} + + } else { + l, data := unmarshalUint32(data) + b = pool.Get()[:l] + n = copy(b, data[:l]) + b = b[:n] + } + + default: + err = unimplementedPacketErr(s.typ) + } + } + + writeWork := writeWork{ + b: b, + off: readWork.off, + err: err, + + next: readWork.next, + } + + select { + case readWork.cur <- writeWork: + case <-cancel: + } + + // DO NOT return. + // We want to ensure that readCh is drained before wg.Wait returns. + } + }() + } + + // Reduce: serialize the results from the reads into sequential writes. + cur := writeCh + for { + packet, ok := <-cur + if !ok { + return written, errors.New("sftp.File.WriteTo: unexpectedly closed channel") + } + + // Because writes are serialized, this will always be the last successfully read byte. + f.offset = packet.off + int64(len(packet.b)) + + if len(packet.b) > 0 { + n, err := w.Write(packet.b) + written += int64(n) + if err != nil { + return written, err + } + } + + if packet.err != nil { + if packet.err == io.EOF { + return written, nil + } + + return written, packet.err + } + + pool.Put(packet.b) + cur = packet.next + } +} + // Stat returns the FileInfo structure describing file. If there is an // error. 
func (f *File) Stat() (os.FileInfo, error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return nil, os.ErrClosed + } + + return f.stat() +} + +func (f *File) stat() (os.FileInfo, error) { fs, err := f.c.fstat(f.handle) if err != nil { return nil, err @@ -592,126 +1610,689 @@ func (f *File) Stat() (os.FileInfo, error) { return fileInfoFromStat(fs, path.Base(f.path)), nil } -// clamp writes to less than 32k -const maxWritePacket = 1 << 15 - // Write writes len(b) bytes to the File. It returns the number of bytes // written and an error, if any. Write returns a non-nil error when n != // len(b). +// +// To maximise throughput for transferring the entire file (especially +// over high latency links) it is recommended to use ReadFrom rather +// than calling Write multiple times. io.Copy will do this +// automatically. func (f *File) Write(b []byte) (int, error) { - var written int - for len(b) > 0 { - n, err := f.c.writeAt(f.handle, f.offset, b[:min(len(b), maxWritePacket)]) - f.offset += uint64(n) - written += int(n) + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + n, err := f.writeAt(b, f.offset) + f.offset += int64(n) + return n, err +} + +func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) { + typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpWritePacket{ + ID: f.c.nextID(), + Handle: f.handle, + Offset: uint64(off), + Length: uint32(len(b)), + Data: b, + }) + if err != nil { + return 0, err + } + + switch typ { + case sshFxpStatus: + id, _ := unmarshalUint32(data) + err := normaliseError(unmarshalStatus(id, data)) + if err != nil { + return 0, err + } + + default: + return 0, unimplementedPacketErr(typ) + } + + return len(b), nil +} + +// writeAtConcurrent implements WriterAt, but works concurrently rather than sequentially. 
+func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { + // Split the write into multiple maxPacket sized concurrent writes + // bounded by maxConcurrentRequests. This allows writes with a suitably + // large buffer to transfer data at a much faster rate due to + // overlapping round trip times. + + cancel := make(chan struct{}) + + type work struct { + id uint32 + res chan result + + off int64 + } + workCh := make(chan work) + + concurrency := len(b)/f.c.maxPacket + 1 + if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { + concurrency = f.c.maxConcurrentRequests + } + + pool := newResChanPool(concurrency) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + var read int + chunkSize := f.c.maxPacket + + for read < len(b) { + wb := b[read:] + if len(wb) > chunkSize { + wb = wb[:chunkSize] + } + + id := f.c.nextID() + res := pool.Get() + off := off + int64(read) + + f.c.dispatchRequest(res, &sshFxpWritePacket{ + ID: id, + Handle: f.handle, + Offset: uint64(off), + Length: uint32(len(wb)), + Data: wb, + }) + + select { + case workCh <- work{id, res, off}: + case <-cancel: + return + } + + read += len(wb) + } + }() + + type wErr struct { + off int64 + err error + } + errCh := make(chan wErr) + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and does the Write from each buffer to its respective offset. + go func() { + defer wg.Done() + + for work := range workCh { + s := <-work.res + pool.Put(work.res) + + err := s.err + if err == nil { + switch s.typ { + case sshFxpStatus: + err = normaliseError(unmarshalStatus(work.id, s.data)) + default: + err = unimplementedPacketErr(s.typ) + } + } + + if err != nil { + errCh <- wErr{work.off, err} + } + } + }() + } + + // Wait for long tail, before closing results. 
+ go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: collect all the results into a relevant return: the earliest offset to return an error. + firstErr := wErr{math.MaxInt64, nil} + for wErr := range errCh { + if wErr.off <= firstErr.off { + firstErr = wErr + } + + select { + case <-cancel: + default: + // stop any more work from being distributed. (Just in case.) + close(cancel) + } + } + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off >= our starting offset. + return int(firstErr.off - off), firstErr.err + } + + return len(b), nil +} + +// WriteAt writes up to len(b) byte to the File at a given offset `off`. It returns +// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, +// so the file offset is not altered during the write. +func (f *File) WriteAt(b []byte, off int64) (written int, err error) { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + return f.writeAt(b, off) +} + +// writeAt must be called while holding either the Read or Write mutex in File. +// This code is concurrent safe with itself, but not with Close. +func (f *File) writeAt(b []byte, off int64) (written int, err error) { + if len(b) <= f.c.maxPacket { + // We can do this in one write. + return f.writeChunkAt(nil, b, off) + } + + if f.c.useConcurrentWrites { + return f.writeAtConcurrent(b, off) + } + + ch := make(chan result, 1) // reusable channel + + chunkSize := f.c.maxPacket + + for written < len(b) { + wb := b[written:] + if len(wb) > chunkSize { + wb = wb[:chunkSize] + } + + n, err := f.writeChunkAt(ch, wb, off+int64(written)) + if n > 0 { + written += n + } + if err != nil { return written, err } - b = b[n:] } - return written, nil + + return len(b), nil +} + +// ReadFromWithConcurrency implements ReaderFrom, +// but uses the given concurrency to issue multiple requests at the same time. 
+// +// Giving a concurrency of less than one will default to the Client’s max concurrency. +// +// Otherwise, the given concurrency will be capped by the Client's max concurrency. +// +// When one needs to guarantee concurrent reads/writes, this method is preferred +// over ReadFrom. +func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { + f.mu.Lock() + defer f.mu.Unlock() + + return f.readFromWithConcurrency(r, concurrency) +} + +func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { + if f.handle == "" { + return 0, os.ErrClosed + } + + // Split the write into multiple maxPacket sized concurrent writes. + // This allows writes with a suitably large reader + // to transfer data at a much faster rate due to overlapping round trip times. + + cancel := make(chan struct{}) + + type work struct { + id uint32 + res chan result + + off int64 + } + workCh := make(chan work) + + type rwErr struct { + off int64 + err error + } + errCh := make(chan rwErr) + + if concurrency > f.c.maxConcurrentRequests || concurrency < 1 { + concurrency = f.c.maxConcurrentRequests + } + + pool := newResChanPool(concurrency) + + // Slice: cut up the Read into any number of buffers of length <= f.c.maxPacket, and at appropriate offsets. + go func() { + defer close(workCh) + + b := make([]byte, f.c.maxPacket) + off := f.offset + + for { + // Fill the entire buffer. 
+ n, err := io.ReadFull(r, b) + + if n > 0 { + read += int64(n) + + id := f.c.nextID() + res := pool.Get() + + f.c.dispatchRequest(res, &sshFxpWritePacket{ + ID: id, + Handle: f.handle, + Offset: uint64(off), + Length: uint32(n), + Data: b[:n], + }) + + select { + case workCh <- work{id, res, off}: + case <-cancel: + return + } + + off += int64(n) + } + + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { + errCh <- rwErr{off, err} + } + return + } + } + }() + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + // Map_i: each worker gets work, and does the Write from each buffer to its respective offset. + go func() { + defer wg.Done() + + for work := range workCh { + s := <-work.res + pool.Put(work.res) + + err := s.err + if err == nil { + switch s.typ { + case sshFxpStatus: + err = normaliseError(unmarshalStatus(work.id, s.data)) + default: + err = unimplementedPacketErr(s.typ) + } + } + + if err != nil { + errCh <- rwErr{work.off, err} + + // DO NOT return. + // We want to ensure that workCh is drained before wg.Wait returns. + } + } + }() + } + + // Wait for long tail, before closing results. + go func() { + wg.Wait() + close(errCh) + }() + + // Reduce: Collect all the results into a relevant return: the earliest offset to return an error. + firstErr := rwErr{math.MaxInt64, nil} + for rwErr := range errCh { + if rwErr.off <= firstErr.off { + firstErr = rwErr + } + + select { + case <-cancel: + default: + // stop any more work from being distributed. + close(cancel) + } + } + + if firstErr.err != nil { + // firstErr.err != nil if and only if firstErr.off is a valid offset. + // + // firstErr.off will then be the lesser of: + // * the offset of the first error from writing, + // * the last successfully read offset. + // + // This could be less than the last successfully written offset, + // which is the whole reason for the UseConcurrentWrites() ClientOption. 
+ // + // Callers are responsible for truncating any SFTP files to a safe length. + f.offset = firstErr.off + + // ReadFrom is defined to return the read bytes, regardless of any writer errors. + return read, firstErr.err + } + + f.offset += read + return read, nil +} + +// ReadFrom reads data from r until EOF and writes it to the file. The return +// value is the number of bytes read. Any error except io.EOF encountered +// during the read is also returned. +// +// This method is preferred over calling Write multiple times +// to maximise throughput for transferring the entire file, +// especially over high-latency links. +// +// To ensure concurrent writes, the given r needs to implement one of +// the following receiver methods: +// +// Len() int +// Size() int64 +// Stat() (os.FileInfo, error) +// +// or be an instance of [io.LimitedReader] to determine the number of possible +// concurrent requests. Otherwise, reads/writes are performed sequentially. +// ReadFromWithConcurrency can be used explicitly to guarantee concurrent +// processing of the reader. +func (f *File) ReadFrom(r io.Reader) (int64, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + + if f.c.useConcurrentWrites { + var remain int64 + switch r := r.(type) { + case interface{ Len() int }: + remain = int64(r.Len()) + + case interface{ Size() int64 }: + remain = r.Size() + + case *io.LimitedReader: + remain = r.N + + case interface{ Stat() (os.FileInfo, error) }: + info, err := r.Stat() + if err == nil { + remain = info.Size() + } + } + + if remain < 0 { + // We can strongly assert that we want default max concurrency here. + return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests) + } + + if remain > int64(f.c.maxPacket) { + // Otherwise, only use concurrency, if it would be at least two packets. + + // This is the best reasonable guess we can make. 
+ concurrency64 := remain/int64(f.c.maxPacket) + 1 + + // We need to cap this value to an `int` size value to avoid overflow on 32-bit machines. + // So, we may as well pre-cap it to `f.c.maxConcurrentRequests`. + if concurrency64 > int64(f.c.maxConcurrentRequests) { + concurrency64 = int64(f.c.maxConcurrentRequests) + } + + return f.readFromWithConcurrency(r, int(concurrency64)) + } + } + + ch := make(chan result, 1) // reusable channel + + b := make([]byte, f.c.maxPacket) + + var read int64 + for { + // Fill the entire buffer. + n, err := io.ReadFull(r, b) + if n < 0 { + panic("sftp.File: reader returned negative count from Read") + } + + if n > 0 { + read += int64(n) + + m, err2 := f.writeChunkAt(ch, b[:n], f.offset) + f.offset += int64(m) + + if err == nil { + err = err2 + } + } + + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return read, nil // return nil explicitly. + } + + return read, err + } + } } // Seek implements io.Seeker by setting the client offset for the next Read or // Write. It returns the next offset read. Seeking before or after the end of // the file is undefined. Seeking relative to the end calls Stat. 
func (f *File) Seek(offset int64, whence int) (int64, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return 0, os.ErrClosed + } + switch whence { - case os.SEEK_SET: - f.offset = uint64(offset) - case os.SEEK_CUR: - f.offset = uint64(int64(f.offset) + offset) - case os.SEEK_END: - fi, err := f.Stat() + case io.SeekStart: + case io.SeekCurrent: + offset += f.offset + case io.SeekEnd: + fi, err := f.stat() if err != nil { - return int64(f.offset), err + return f.offset, err } - f.offset = uint64(fi.Size() + offset) + offset += fi.Size() default: - return int64(f.offset), unimplementedSeekWhence(whence) + return f.offset, unimplementedSeekWhence(whence) + } + + if offset < 0 { + return f.offset, os.ErrInvalid } - return int64(f.offset), nil + + f.offset = offset + return f.offset, nil } // Chown changes the uid/gid of the current file. func (f *File) Chown(uid, gid int) error { - return f.c.Chown(f.path, uid, gid) + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{ + UID: uint32(uid), + GID: uint32(gid), + }) } // Chmod changes the permissions of the current file. +// +// See Client.Chmod for details. func (f *File) Chmod(mode os.FileMode) error { - return f.c.Chmod(f.path, mode) + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) +} + +// SetExtendedData sets extended attributes of the current file. It uses the +// SSH_FILEXFER_ATTR_EXTENDED flag in the setstat request. +// +// This flag provides a general extension mechanism for vendor-specific extensions. +// Names of the attributes should be a string of the format "name@domain", where "domain" +// is a valid, registered domain name and "name" identifies the method. Server +// implementations SHOULD ignore extended data fields that they do not understand. 
+func (f *File) SetExtendedData(path string, extended []StatExtended) error { + f.mu.RLock() + defer f.mu.RUnlock() + + if f.handle == "" { + return os.ErrClosed + } + + attrs := &FileStat{ + Extended: extended, + } + + return f.c.fsetstat(f.handle, sshFileXferAttrExtended, attrs) } // Truncate sets the size of the current file. Although it may be safely assumed // that if the size is less than its current size it will be truncated to fit, // the SFTP protocol does not specify what behavior the server should do when setting // size greater than the current size. +// We send a SSH_FXP_FSETSTAT here since we have a file handle func (f *File) Truncate(size int64) error { - return f.c.Truncate(f.path, size) -} + f.mu.RLock() + defer f.mu.RUnlock() -func min(a, b int) int { - if a > b { - return b + if f.handle == "" { + return os.ErrClosed } - return a + + return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size)) } -// okOrErr returns nil if Err.Code is SSH_FX_OK, otherwise it returns the error. -func okOrErr(err error) error { - if err, ok := err.(*StatusError); ok && err.Code == ssh_FX_OK { - return nil +// Sync requests a flush of the contents of a File to stable storage. +// +// Sync requires the server to support the fsync@openssh.com extension. 
+func (f *File) Sync() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.handle == "" { + return os.ErrClosed } - return err -} -func eofOrErr(err error) error { - if err, ok := err.(*StatusError); ok && err.Code == ssh_FX_EOF { - return io.EOF + if data, ok := f.c.HasExtension(openssh.ExtensionFSync().Name); !ok || data != "1" { + return &StatusError{ + Code: sshFxOPUnsupported, + msg: "fsync not supported", + } } - return err -} -func unmarshalStatus(id uint32, data []byte) error { - sid, data := unmarshalUint32(data) - if sid != id { - return &unexpectedIdErr{id, sid} + id := f.c.nextID() + typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{ + ID: id, + Handle: f.handle, + }) + + switch { + case err != nil: + return err + case typ == sshFxpStatus: + return normaliseError(unmarshalStatus(id, data)) + default: + return &unexpectedPacketErr{want: sshFxpStatus, got: typ} } - code, data := unmarshalUint32(data) - msg, data := unmarshalString(data) - lang, _ := unmarshalString(data) - return &StatusError{ - Code: code, - msg: msg, - lang: lang, +} + +// normaliseError normalises an error into a more standard form that can be +// checked against stdlib errors like io.EOF or os.ErrNotExist. +func normaliseError(err error) error { + switch err := err.(type) { + case *StatusError: + switch err.Code { + case sshFxEOF: + return io.EOF + case sshFxNoSuchFile: + return os.ErrNotExist + case sshFxPermissionDenied: + return os.ErrPermission + case sshFxOk: + return nil + default: + return err + } + default: + return err } } // flags converts the flags passed to OpenFile into ssh flags. // Unsupported flags are ignored. 
-func flags(f int) uint32 { +func toPflags(f int) uint32 { var out uint32 - switch f & os.O_WRONLY { - case os.O_WRONLY: - out |= ssh_FXF_WRITE + switch f & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) { case os.O_RDONLY: - out |= ssh_FXF_READ - } - if f&os.O_RDWR == os.O_RDWR { - out |= ssh_FXF_READ | ssh_FXF_WRITE + out |= sshFxfRead + case os.O_WRONLY: + out |= sshFxfWrite + case os.O_RDWR: + out |= sshFxfRead | sshFxfWrite } if f&os.O_APPEND == os.O_APPEND { - out |= ssh_FXF_APPEND + out |= sshFxfAppend } if f&os.O_CREATE == os.O_CREATE { - out |= ssh_FXF_CREAT + out |= sshFxfCreat } if f&os.O_TRUNC == os.O_TRUNC { - out |= ssh_FXF_TRUNC + out |= sshFxfTrunc } if f&os.O_EXCL == os.O_EXCL { - out |= ssh_FXF_EXCL + out |= sshFxfExcl } return out } + +// toChmodPerm converts Go permission bits to POSIX permission bits. +// +// This differs from fromFileMode in that we preserve the POSIX versions of +// setuid, setgid and sticky in m, because we've historically supported those +// bits, and we mask off any non-permission bits. 
+func toChmodPerm(m os.FileMode) (perm uint32) { + const mask = os.ModePerm | os.FileMode(s_ISUID|s_ISGID|s_ISVTX) + perm = uint32(m & mask) + + if m&os.ModeSetuid != 0 { + perm |= s_ISUID + } + if m&os.ModeSetgid != 0 { + perm |= s_ISGID + } + if m&os.ModeSticky != 0 { + perm |= s_ISVTX + } + + return perm +} diff --git a/client_integration_darwin_test.go b/client_integration_darwin_test.go new file mode 100644 index 00000000..826a1b0b --- /dev/null +++ b/client_integration_darwin_test.go @@ -0,0 +1,40 @@ +package sftp + +import ( + "syscall" + "testing" +) + +func TestClientStatVFS(t *testing.T) { + if *testServerImpl { + t.Skipf("go server does not support FXP_EXTENDED") + } + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + vfs, err := sftp.StatVFS("/") + if err != nil { + t.Fatal(err) + } + + // get system stats + s := syscall.Statfs_t{} + err = syscall.Statfs("/", &s) + if err != nil { + t.Fatal(err) + } + + // check some stats + if vfs.Files != uint64(s.Files) { + t.Fatal("fr_size does not match") + } + + if vfs.Bfree != uint64(s.Bfree) { + t.Fatal("f_bsize does not match") + } + + if vfs.Favail != uint64(s.Ffree) { + t.Fatal("f_namemax does not match") + } +} diff --git a/client_integration_linux_test.go b/client_integration_linux_test.go new file mode 100644 index 00000000..58d4d84a --- /dev/null +++ b/client_integration_linux_test.go @@ -0,0 +1,45 @@ +package sftp + +import ( + "syscall" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClientStatVFS(t *testing.T) { + if *testServerImpl { + t.Skipf("go server does not support FXP_EXTENDED") + } + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + _, ok := sftp.HasExtension("statvfs@openssh.com") + require.True(t, ok, "server doesn't list statvfs extension") + + vfs, err := sftp.StatVFS("/") + if err != nil { + t.Fatal(err) + } + + // get system stats + s := syscall.Statfs_t{} + err = syscall.Statfs("/", 
&s) + if err != nil { + t.Fatal(err) + } + + // check some stats + if vfs.Frsize != uint64(s.Frsize) { + t.Fatalf("fr_size does not match, expected: %v, got: %v", s.Frsize, vfs.Frsize) + } + + if vfs.Bsize != uint64(s.Bsize) { + t.Fatalf("f_bsize does not match, expected: %v, got: %v", s.Bsize, vfs.Bsize) + } + + if vfs.Namemax != uint64(s.Namelen) { + t.Fatalf("f_namemax does not match, expected: %v, got: %v", s.Namelen, vfs.Namemax) + } +} diff --git a/client_integration_test.go b/client_integration_test.go index 828deea3..546ba211 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -4,73 +4,242 @@ package sftp // enable with -integration import ( + "bytes" "crypto/sha1" - "flag" + "errors" + "fmt" "io" "io/ioutil" "math/rand" + "net" "os" "os/exec" + "os/user" "path" "path/filepath" "reflect" + "regexp" + "runtime" + "sort" + "strconv" + "sync" "testing" "testing/quick" + "time" "github.com/kr/fs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( - READONLY = true - READWRITE = false + READONLY = true + READWRITE = false + NODELAY time.Duration = 0 debuglevel = "ERROR" // set to "DEBUG" for debugging ) -var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") -var testSftp = flag.String("sftp", "/usr/lib/openssh/sftp-server", "location of the sftp server binary") +type delayedWrite struct { + t time.Time + b []byte +} + +// delayedWriter wraps a writer and artificially delays the write. This is +// meant to mimic connections with various latencies. Error's returned from the +// underlying writer will panic so this should only be used over reliable +// connections. 
+type delayedWriter struct { + closed chan struct{} + + mu sync.Mutex + ch chan delayedWrite + closing chan struct{} +} + +func newDelayedWriter(w io.WriteCloser, delay time.Duration) io.WriteCloser { + dw := &delayedWriter{ + ch: make(chan delayedWrite, 128), + closed: make(chan struct{}), + closing: make(chan struct{}), + } + + go func() { + defer close(dw.closed) + defer w.Close() + + for writeMsg := range dw.ch { + time.Sleep(time.Until(writeMsg.t.Add(delay))) + + n, err := w.Write(writeMsg.b) + if err != nil { + panic("write error") + } + + if n < len(writeMsg.b) { + panic("short write") + } + } + }() + + return dw +} + +func (dw *delayedWriter) Write(b []byte) (int, error) { + dw.mu.Lock() + defer dw.mu.Unlock() + + write := delayedWrite{ + t: time.Now(), + b: append([]byte(nil), b...), + } + + select { + case <-dw.closing: + return 0, errors.New("delayedWriter is closing") + case dw.ch <- write: + } + + return len(b), nil +} + +func (dw *delayedWriter) Close() error { + dw.mu.Lock() + defer dw.mu.Unlock() + + select { + case <-dw.closing: + default: + close(dw.ch) + close(dw.closing) + } + + <-dw.closed + return nil +} + +// netPipe provides a pair of io.ReadWriteClosers connected to each other. +// The function is identical to os.Pipe with the exception that netPipe +// provides the Read/Close guarantees that os.File derived pipes do not. 
+func netPipe(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { + type result struct { + net.Conn + error + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + closeListener := make(chan struct{}, 1) + closeListener <- struct{}{} + + ch := make(chan result, 1) + go func() { + conn, err := l.Accept() + ch <- result{conn, err} + + if _, ok := <-closeListener; ok { + err = l.Close() + if err != nil { + t.Error(err) + } + close(closeListener) + } + }() + + c1, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + if _, ok := <-closeListener; ok { + l.Close() + close(closeListener) + } + t.Fatal(err) + } + + r := <-ch + if r.error != nil { + t.Fatal(err) + } + + return c1, r.Conn +} + +func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration, opts ...ClientOption) (*Client, *exec.Cmd) { + c1, c2 := netPipe(t) + + options := []ServerOption{WithDebug(os.Stderr)} + if readonly { + options = append(options, ReadOnly()) + } + + server, err := NewServer(c1, options...) + if err != nil { + t.Fatal(err) + } + go server.Serve() + + var wr io.WriteCloser = c2 + if delay > NODELAY { + wr = newDelayedWriter(wr, delay) + } + + client, err := NewClientPipe(c2, wr, opts...) + if err != nil { + t.Fatal(err) + } -// testClient returns a *Client connected to a localy running sftp-server + // dummy command... + return client, exec.Command("true") +} + +// testClient returns a *Client connected to a locally running sftp-server // the *exec.Cmd returned must be defer Wait'd. -func testClient(t testing.TB, readonly bool) (*Client, *exec.Cmd) { +func testClient(t testing.TB, readonly bool, delay time.Duration, opts ...ClientOption) (*Client, *exec.Cmd) { if !*testIntegration { - t.Skip("skipping intergration test") + t.Skip("skipping integration test") + } + + if *testServerImpl { + return testClientGoSvr(t, readonly, delay, opts...) 
} + cmd := exec.Command(*testSftp, "-e", "-R", "-l", debuglevel) // log to stderr, read only if !readonly { cmd = exec.Command(*testSftp, "-e", "-l", debuglevel) // log to stderr } + cmd.Stderr = os.Stdout + pw, err := cmd.StdinPipe() if err != nil { t.Fatal(err) } + + if delay > NODELAY { + pw = newDelayedWriter(pw, delay) + } + pr, err := cmd.StdoutPipe() if err != nil { t.Fatal(err) } + if err := cmd.Start(); err != nil { t.Skipf("could not start sftp-server process: %v", err) } - sftp, err := NewClientPipe(pr, pw) + sftp, err := NewClientPipe(pr, pw, opts...) if err != nil { t.Fatal(err) } - if err := sftp.sendInit(); err != nil { - defer cmd.Wait() - t.Fatal(err) - } - if err := sftp.recvVersion(); err != nil { - defer cmd.Wait() - t.Fatal(err) - } return sftp, cmd } func TestNewClient(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() if err := sftp.Close(); err != nil { @@ -79,14 +248,15 @@ func TestNewClient(t *testing.T) { } func TestClientLstat(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-lstat") if err != nil { t.Fatal(err) } + f.Close() defer os.Remove(f.Name()) want, err := os.Lstat(f.Name()) @@ -104,32 +274,34 @@ func TestClientLstat(t *testing.T) { } } -func TestClientLstatMissing(t *testing.T) { - sftp, cmd := testClient(t, READONLY) +func TestClientLstatIsNotExist(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-lstatisnotexist") if err != nil { t.Fatal(err) } + f.Close() os.Remove(f.Name()) - _, err = sftp.Lstat(f.Name()) - if err1, ok := err.(*StatusError); !ok || err1.Code != ssh_FX_NO_SUCH_FILE { - t.Fatalf("Lstat: want: %v, got %#v", ssh_FX_NO_SUCH_FILE, err) + if 
_, err := sftp.Lstat(f.Name()); !os.IsNotExist(err) { + t.Errorf("os.IsNotExist(%v) = false, want true", err) } } func TestClientMkdir(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - dir, err := ioutil.TempDir("", "sftptest") + dir, err := ioutil.TempDir("", "sftptest-mkdir") if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) + sub := path.Join(dir, "mkdir1") if err := sftp.Mkdir(sub); err != nil { t.Fatal(err) @@ -138,16 +310,40 @@ func TestClientMkdir(t *testing.T) { t.Fatal(err) } } +func TestClientMkdirAll(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + dir, err := ioutil.TempDir("", "sftptest-mkdirall") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + sub := path.Join(dir, "mkdir1", "mkdir2", "mkdir3") + if err := sftp.MkdirAll(sub); err != nil { + t.Fatal(err) + } + info, err := os.Lstat(sub) + if err != nil { + t.Fatal(err) + } + if !info.IsDir() { + t.Fatalf("Expected mkdirall to create dir at: %s", sub) + } +} func TestClientOpen(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-open") if err != nil { t.Fatal(err) } + f.Close() defer os.Remove(f.Name()) got, err := sftp.Open(f.Name()) @@ -159,6 +355,26 @@ func TestClientOpen(t *testing.T) { } } +func TestClientOpenIsNotExist(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + if _, err := sftp.Open("/doesnt/exist"); !os.IsNotExist(err) { + t.Errorf("os.IsNotExist(%v) = false, want true", err) + } +} + +func TestClientStatIsNotExist(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + if _, err := sftp.Stat("/doesnt/exist"); !os.IsNotExist(err) { + 
t.Errorf("os.IsNotExist(%v) = false, want true", err) + } +} + const seekBytes = 128 * 1024 type seek struct { @@ -171,7 +387,7 @@ func (s seek) Generate(r *rand.Rand, _ int) reflect.Value { } func (s seek) set(t *testing.T, r io.ReadSeeker) { - if _, err := r.Seek(s.offset, os.SEEK_SET); err != nil { + if _, err := r.Seek(s.offset, io.SeekStart); err != nil { t.Fatalf("error while seeking with %+v: %v", s, err) } } @@ -184,29 +400,30 @@ func (s seek) current(t *testing.T, r io.ReadSeeker) { skip = -skip } - if _, err := r.Seek(mid, os.SEEK_SET); err != nil { + if _, err := r.Seek(mid, io.SeekStart); err != nil { t.Fatalf("error seeking to midpoint with %+v: %v", s, err) } - if _, err := r.Seek(skip, os.SEEK_CUR); err != nil { + if _, err := r.Seek(skip, io.SeekCurrent); err != nil { t.Fatalf("error seeking from %d with %+v: %v", mid, s, err) } } func (s seek) end(t *testing.T, r io.ReadSeeker) { - if _, err := r.Seek(-s.offset, os.SEEK_END); err != nil { + if _, err := r.Seek(-s.offset, io.SeekEnd); err != nil { t.Fatalf("error seeking from end with %+v: %v", s, err) } } func TestClientSeek(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() - fOS, err := ioutil.TempFile("", "seek-test") + fOS, err := ioutil.TempFile("", "sftptest-seek") if err != nil { t.Fatal(err) } + defer os.Remove(fOS.Name()) defer fOS.Close() fSFTP, err := sftp.Open(fOS.Name()) @@ -243,11 +460,11 @@ func TestClientSeek(t *testing.T) { } func TestClientCreate(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-create") if err != nil { t.Fatal(err) } @@ -262,11 +479,11 @@ func TestClientCreate(t *testing.T) { } func TestClientAppend(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) + sftp, cmd := testClient(t, READWRITE, NODELAY) 
defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-append") if err != nil { t.Fatal(err) } @@ -281,32 +498,51 @@ func TestClientAppend(t *testing.T) { } func TestClientCreateFailed(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") - if err != nil { - t.Fatal(err) - } + f, err := ioutil.TempFile("", "sftptest-createfailed") + require.NoError(t, err) + defer f.Close() defer os.Remove(f.Name()) f2, err := sftp.Create(f.Name()) - if err1, ok := err.(*StatusError); !ok || err1.Code != ssh_FX_PERMISSION_DENIED { - t.Fatalf("Create: want: %v, got %#v", ssh_FX_PERMISSION_DENIED, err) - } + require.True(t, os.IsPermission(err)) if err == nil { f2.Close() } } +func TestClientFileName(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-filename") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + f2, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + + if got, want := f2.Name(), f.Name(); got != want { + t.Fatalf("Name: got %q want %q", want, got) + } +} + func TestClientFileStat(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-filestat") if err != nil { t.Fatal(err) } @@ -321,6 +557,7 @@ func TestClientFileStat(t *testing.T) { if err != nil { t.Fatal(err) } + defer f2.Close() got, err := f2.Stat() if err != nil { @@ -332,15 +569,80 @@ func TestClientFileStat(t *testing.T) { } } +func TestClientStatLink(t *testing.T) { + skipIfWindows(t) // Windows does not support links. 
+ + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-statlink") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + realName := f.Name() + linkName := f.Name() + ".softlink" + + // create a symlink that points at sftptest + if err := os.Symlink(realName, linkName); err != nil { + t.Fatal(err) + } + defer os.Remove(linkName) + + // compare Lstat of links + wantLstat, err := os.Lstat(linkName) + if err != nil { + t.Fatal(err) + } + wantStat, err := os.Stat(linkName) + if err != nil { + t.Fatal(err) + } + + gotLstat, err := sftp.Lstat(linkName) + if err != nil { + t.Fatal(err) + } + gotStat, err := sftp.Stat(linkName) + if err != nil { + t.Fatal(err) + } + + // check that stat is not lstat from os package + if sameFile(wantLstat, wantStat) { + t.Fatalf("Lstat / Stat(%q): both %#v %#v", f.Name(), wantLstat, wantStat) + } + + // compare Lstat of links + if !sameFile(wantLstat, gotLstat) { + t.Fatalf("Lstat(%q): want %#v, got %#v", f.Name(), wantLstat, gotLstat) + } + + // compare Stat of links + if !sameFile(wantStat, gotStat) { + t.Fatalf("Stat(%q): want %#v, got %#v", f.Name(), wantStat, gotStat) + } + + // check that stat is not lstat + if sameFile(gotLstat, gotStat) { + t.Fatalf("Lstat / Stat(%q): both %#v %#v", f.Name(), gotLstat, gotStat) + } +} + func TestClientRemove(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-remove") if err != nil { t.Fatal(err) } + defer os.Remove(f.Name()) + f.Close() + if err := sftp.Remove(f.Name()); err != nil { t.Fatal(err) } @@ -349,15 +651,80 @@ func TestClientRemove(t *testing.T) { } } +func TestClientRemoveAll(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + // Create a temporary directory for 
testing + tempDir, err := ioutil.TempDir("", "sftptest-removeAll") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Create a directory tree + dir1, err := ioutil.TempDir(tempDir, "foo") + if err != nil { + t.Fatal(err) + } + dir2, err := ioutil.TempDir(dir1, "bar") + if err != nil { + t.Fatal(err) + } + + // Create some files within the directory tree + file1 := tempDir + "/file1.txt" + file2 := dir1 + "/file2.txt" + file3 := dir2 + "/file3.txt" + err = ioutil.WriteFile(file1, []byte("File 1"), 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + err = ioutil.WriteFile(file2, []byte("File 2"), 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + err = ioutil.WriteFile(file3, []byte("File 3"), 0644) + if err != nil { + t.Fatalf("Failed to create file: %v", err) + } + + // Call the function to delete the files recursively + err = sftp.RemoveAll(tempDir) + if err != nil { + t.Fatalf("Failed to delete files recursively: %v", err) + } + + // Check if the directories and files have been deleted + if _, err := os.Stat(dir1); !os.IsNotExist(err) { + t.Errorf("Directory %s still exists", dir1) + } + if _, err := os.Stat(dir2); !os.IsNotExist(err) { + t.Errorf("Directory %s still exists", dir2) + } + if _, err := os.Stat(file1); !os.IsNotExist(err) { + t.Errorf("File %s still exists", file1) + } + if _, err := os.Stat(file2); !os.IsNotExist(err) { + t.Errorf("File %s still exists", file2) + } + if _, err := os.Stat(file3); !os.IsNotExist(err) { + t.Errorf("File %s still exists", file3) + } +} + func TestClientRemoveDir(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - dir, err := ioutil.TempDir("", "sftptest") + dir, err := ioutil.TempDir("", "sftptest-removedir") if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) + if err := sftp.Remove(dir); err != nil { t.Fatal(err) } @@ -367,14 +734,16 @@ func 
TestClientRemoveDir(t *testing.T) { } func TestClientRemoveFailed(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-removefailed") if err != nil { t.Fatal(err) } + defer os.Remove(f.Name()) + if err := sftp.Remove(f.Name()); err == nil { t.Fatalf("Remove(%v): want: permission denied, got %v", f.Name(), err) } @@ -384,15 +753,18 @@ func TestClientRemoveFailed(t *testing.T) { } func TestClientRename(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") - if err != nil { - t.Fatal(err) - } - f2 := f.Name() + ".new" + dir, err := ioutil.TempDir("", "sftptest-rename") + require.NoError(t, err) + defer os.RemoveAll(dir) + f, err := os.Create(filepath.Join(dir, "old")) + require.NoError(t, err) + f.Close() + + f2 := filepath.Join(dir, "new") if err := sftp.Rename(f.Name(), f2); err != nil { t.Fatal(err) } @@ -404,180 +776,1100 @@ func TestClientRename(t *testing.T) { } } -func TestClientReadLine(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) +func TestClientPosixRename(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - f, err := ioutil.TempFile("", "sftptest") - if err != nil { + dir, err := ioutil.TempDir("", "sftptest-posixrename") + require.NoError(t, err) + defer os.RemoveAll(dir) + f, err := os.Create(filepath.Join(dir, "old")) + require.NoError(t, err) + f.Close() + + f2 := filepath.Join(dir, "new") + if err := sftp.PosixRename(f.Name(), f2); err != nil { t.Fatal(err) } - f2 := f.Name() + ".sym" - if err := os.Symlink(f.Name(), f2); err != nil { + if _, err := os.Lstat(f.Name()); !os.IsNotExist(err) { t.Fatal(err) } - if _, err := sftp.ReadLink(f2); err != nil { + if _, err := os.Lstat(f2); err != nil { 
t.Fatal(err) } } -func sameFile(want, got os.FileInfo) bool { - return want.Name() == got.Name() && - want.Size() == got.Size() +func TestClientGetwd(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + lwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + rwd, err := sftp.Getwd() + if err != nil { + t.Fatal(err) + } + if !filepath.IsAbs(rwd) { + t.Fatalf("Getwd: wanted absolute path, got %q", rwd) + } + if filepath.ToSlash(lwd) != filepath.ToSlash(rwd) { + t.Fatalf("Getwd: want %q, got %q", lwd, rwd) + } } -var clientReadTests = []struct { - n int64 -}{ - {0}, - {1}, - {1000}, - {1024}, - {1025}, - {2048}, - {4096}, - {1 << 12}, - {1 << 13}, - {1 << 14}, - {1 << 15}, - {1 << 16}, - {1 << 17}, - {1 << 18}, - {1 << 19}, - {1 << 20}, +func TestClientReadLink(t *testing.T) { + if runtime.GOOS == "windows" && *testServerImpl { + // os.Symlink requires privilege escalation. + t.Skip() + } + + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + dir, err := ioutil.TempDir("", "sftptest-readlink") + require.NoError(t, err) + defer os.RemoveAll(dir) + f, err := os.Create(filepath.Join(dir, "file")) + require.NoError(t, err) + f.Close() + + f2 := filepath.Join(dir, "symlink") + if err := os.Symlink(f.Name(), f2); err != nil { + t.Fatal(err) + } + if rl, err := sftp.ReadLink(f2); err != nil { + t.Fatal(err) + } else if rl != f.Name() { + t.Fatalf("unexpected link target: %v, not %v", rl, f.Name()) + } } -func TestClientRead(t *testing.T) { - sftp, cmd := testClient(t, READONLY) +func TestClientLink(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - d, err := ioutil.TempDir("", "sftptest") + dir, err := ioutil.TempDir("", "sftptest-link") + require.NoError(t, err) + defer os.RemoveAll(dir) + + f, err := os.Create(filepath.Join(dir, "file")) + require.NoError(t, err) + data := []byte("linktest") + _, err = f.Write(data) + 
f.Close() if err != nil { t.Fatal(err) } - defer os.RemoveAll(d) - for _, tt := range clientReadTests { - f, err := ioutil.TempFile(d, "read-test") + f2 := filepath.Join(dir, "link") + if err := sftp.Link(f.Name(), f2); err != nil { + t.Fatal(err) + } + if st2, err := sftp.Stat(f2); err != nil { + t.Fatal(err) + } else if int(st2.Size()) != len(data) { + t.Fatalf("unexpected link size: %v, not %v", st2.Size(), len(data)) + } +} + +func TestClientSymlink(t *testing.T) { + if runtime.GOOS == "windows" && *testServerImpl { + // os.Symlink requires privilege escalation. + t.Skip() + } + + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + dir, err := ioutil.TempDir("", "sftptest-symlink") + require.NoError(t, err) + defer os.RemoveAll(dir) + f, err := os.Create(filepath.Join(dir, "file")) + require.NoError(t, err) + f.Close() + + f2 := filepath.Join(dir, "symlink") + if err := sftp.Symlink(f.Name(), f2); err != nil { + t.Fatal(err) + } + if rl, err := sftp.ReadLink(f2); err != nil { + t.Fatal(err) + } else if rl != f.Name() { + t.Fatalf("unexpected link target: %v, not %v", rl, f.Name()) + } +} + +func TestClientChmod(t *testing.T) { + skipIfWindows(t) // No UNIX permissions. + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-chmod") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + if err := sftp.Chmod(f.Name(), 0531); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(f.Name()); err != nil { + t.Fatal(err) + } else if stat.Mode()&os.ModePerm != 0531 { + t.Fatalf("invalid perm %o\n", stat.Mode()) + } + + sf, err := sftp.Open(f.Name()) + require.NoError(t, err) + require.NoError(t, sf.Chmod(0500)) + sf.Close() + + stat, err := os.Stat(f.Name()) + require.NoError(t, err) + require.EqualValues(t, 0500, stat.Mode()) +} + +func TestClientChmodReadonly(t *testing.T) { + skipIfWindows(t) // No UNIX permissions. 
+ sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-chmodreadonly") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + if err := sftp.Chmod(f.Name(), 0531); err == nil { + t.Fatal("expected error") + } +} + +func TestClientSetuid(t *testing.T) { + skipIfWindows(t) // No UNIX permissions. + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-setuid") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + const allPerm = os.ModePerm | os.ModeSetuid | os.ModeSetgid | os.ModeSticky | + os.FileMode(s_ISUID|s_ISGID|s_ISVTX) + + for _, c := range []struct { + goPerm os.FileMode + posixPerm uint32 + }{ + {os.ModeSetuid, s_ISUID}, + {os.ModeSetgid, s_ISGID}, + {os.ModeSticky, s_ISVTX}, + {os.ModeSetuid | os.ModeSticky, s_ISUID | s_ISVTX}, + } { + goPerm := 0700 | c.goPerm + posixPerm := 0700 | c.posixPerm + + err = sftp.Chmod(f.Name(), goPerm) + require.NoError(t, err) + + info, err := sftp.Stat(f.Name()) + require.NoError(t, err) + require.Equal(t, goPerm, info.Mode()&allPerm) + + err = sftp.Chmod(f.Name(), 0700) // Reset funny bits. + require.NoError(t, err) + + // For historical reasons, we also support literal POSIX mode bits in + // Chmod. Stat should still translate these to Go os.FileMode bits. + err = sftp.Chmod(f.Name(), os.FileMode(posixPerm)) + require.NoError(t, err) + + info, err = sftp.Stat(f.Name()) + require.NoError(t, err) + require.Equal(t, goPerm, info.Mode()&allPerm) + } +} + +func TestClientChown(t *testing.T) { + skipIfWindows(t) // No UNIX permissions. 
+ sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + if usr.Uid != "0" { + t.Log("must be root to run chown tests") + t.Skip() + } + + chownto, err := user.Lookup("daemon") // seems common-ish... + if err != nil { + t.Fatal(err) + } + toUID, err := strconv.Atoi(chownto.Uid) + if err != nil { + t.Fatal(err) + } + toGID, err := strconv.Atoi(chownto.Gid) + if err != nil { + t.Fatal(err) + } + + f, err := ioutil.TempFile("", "sftptest-chown") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + before, err := exec.Command("ls", "-nl", f.Name()).Output() + if err != nil { + t.Fatal(err) + } + if err := sftp.Chown(f.Name(), toUID, toGID); err != nil { + t.Fatal(err) + } + after, err := exec.Command("ls", "-nl", f.Name()).Output() + if err != nil { + t.Fatal(err) + } + + spaceRegex := regexp.MustCompile(`\s+`) + + beforeWords := spaceRegex.Split(string(before), -1) + if beforeWords[2] != "0" { + t.Fatalf("bad previous user? should be root") + } + afterWords := spaceRegex.Split(string(after), -1) + if afterWords[2] != chownto.Uid || afterWords[3] != chownto.Gid { + t.Fatalf("bad chown: %#v", afterWords) + } + t.Logf("before: %v", string(before)) + t.Logf(" after: %v", string(after)) +} + +func TestClientChownReadonly(t *testing.T) { + skipIfWindows(t) // No UNIX permissions. + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + if usr.Uid != "0" { + t.Log("must be root to run chown tests") + t.Skip() + } + + chownto, err := user.Lookup("daemon") // seems common-ish... 
+ if err != nil { + t.Fatal(err) + } + toUID, err := strconv.Atoi(chownto.Uid) + if err != nil { + t.Fatal(err) + } + toGID, err := strconv.Atoi(chownto.Gid) + if err != nil { + t.Fatal(err) + } + + f, err := ioutil.TempFile("", "sftptest-chownreadonly") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + if err := sftp.Chown(f.Name(), toUID, toGID); err == nil { + t.Fatal("expected error") + } +} + +func TestClientChtimes(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-chtimes") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + atime := time.Date(2013, 2, 23, 13, 24, 35, 0, time.UTC) + mtime := time.Date(1985, 6, 12, 6, 6, 6, 0, time.UTC) + if err := sftp.Chtimes(f.Name(), atime, mtime); err != nil { + t.Fatal(err) + } + if stat, err := os.Stat(f.Name()); err != nil { + t.Fatal(err) + } else if stat.ModTime().Sub(mtime) != 0 { + t.Fatalf("incorrect mtime: %v vs %v", stat.ModTime(), mtime) + } +} + +func TestClientChtimesReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-chtimesreadonly") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + atime := time.Date(2013, 2, 23, 13, 24, 35, 0, time.UTC) + mtime := time.Date(1985, 6, 12, 6, 6, 6, 0, time.UTC) + if err := sftp.Chtimes(f.Name(), atime, mtime); err == nil { + t.Fatal("expected error") + } +} + +func TestClientTruncate(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-truncate") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + fname := f.Name() + + if n, err := f.Write([]byte("hello world")); n != 11 || err != nil { + t.Fatal(err) + } + f.Close() + + if err := sftp.Truncate(fname, 5); err != nil { + 
t.Fatal(err) + } + if stat, err := os.Stat(fname); err != nil { + t.Fatal(err) + } else if stat.Size() != 5 { + t.Fatalf("unexpected size: %d", stat.Size()) + } +} + +func TestClientTruncateReadonly(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-truncreadonly") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + fname := f.Name() + + if n, err := f.Write([]byte("hello world")); n != 11 || err != nil { + t.Fatal(err) + } + f.Close() + + if err := sftp.Truncate(fname, 5); err == nil { + t.Fatal("expected error") + } + if stat, err := os.Stat(fname); err != nil { + t.Fatal(err) + } else if stat.Size() != 11 { + t.Fatalf("unexpected size: %d", stat.Size()) + } +} + +func sameFile(want, got os.FileInfo) bool { + _, wantName := filepath.Split(want.Name()) + _, gotName := filepath.Split(got.Name()) + return wantName == gotName && + want.Size() == got.Size() +} + +func TestClientReadSimple(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest-readsimple") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f, err := ioutil.TempFile(d, "read-test") + if err != nil { + t.Fatal(err) + } + fname := f.Name() + f.Write([]byte("hello")) + f.Close() + + f2, err := sftp.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + stuff := make([]byte, 32) + n, err := f2.Read(stuff) + if err != nil && err != io.EOF { + t.Fatalf("err: %v", err) + } + if n != 5 { + t.Fatalf("n should be 5, is %v", n) + } + if string(stuff[0:5]) != "hello" { + t.Fatalf("invalid contents") + } +} + +func TestClientReadSequential(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + sftp.disableConcurrentReads = true + d, err := ioutil.TempDir("", "sftptest-readsequential") + require.NoError(t, err) + + defer os.RemoveAll(d) 
+ + f, err := ioutil.TempFile(d, "read-sequential-test") + require.NoError(t, err) + fname := f.Name() + content := []byte("hello world") + f.Write(content) + f.Close() + + for _, maxPktSize := range []int{1, 2, 3, 4} { + sftp.maxPacket = maxPktSize + + sftpFile, err := sftp.Open(fname) + require.NoError(t, err) + + stuff := make([]byte, 32) + n, err := sftpFile.Read(stuff) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, len(content), n) + require.Equal(t, content, stuff[0:len(content)]) + + err = sftpFile.Close() + require.NoError(t, err) + + sftpFile, err = sftp.Open(fname) + require.NoError(t, err) + + stuff = make([]byte, 5) + n, err = sftpFile.Read(stuff) + require.NoError(t, err) + require.Equal(t, len(stuff), n) + require.Equal(t, content[:len(stuff)], stuff) + + err = sftpFile.Close() + require.NoError(t, err) + + // now read from a offset + off := int64(3) + sftpFile, err = sftp.Open(fname) + require.NoError(t, err) + + stuff = make([]byte, 5) + n, err = sftpFile.ReadAt(stuff, off) + require.NoError(t, err) + require.Equal(t, len(stuff), n) + require.Equal(t, content[off:off+int64(len(stuff))], stuff) + + err = sftpFile.Close() + require.NoError(t, err) + } +} + +// this writer requires maxPacket = 3 and always returns an error for the second write call +type lastChunkErrSequentialWriter struct { + counter int +} + +func (w *lastChunkErrSequentialWriter) Write(b []byte) (int, error) { + w.counter++ + if w.counter == 1 { + if len(b) != 3 { + return 0, errors.New("this writer requires maxPacket = 3, please set MaxPacketChecked(3)") + } + return len(b), nil + } + return 1, errors.New("this writer fails after the first write") +} + +func TestClientWriteSequentialWriterErr(t *testing.T) { + client, cmd := testClient(t, READONLY, NODELAY, MaxPacketChecked(3)) + defer cmd.Wait() + defer client.Close() + + d, err := ioutil.TempDir("", "sftptest-writesequential-writeerr") + require.NoError(t, err) + + defer os.RemoveAll(d) + + f, err := ioutil.TempFile(d, 
"write-sequential-writeerr-test") + require.NoError(t, err) + fname := f.Name() + _, err = f.Write([]byte("12345")) + require.NoError(t, err) + require.NoError(t, f.Close()) + + sftpFile, err := client.Open(fname) + require.NoError(t, err) + defer sftpFile.Close() + + w := &lastChunkErrSequentialWriter{} + written, err := sftpFile.writeToSequential(w) + assert.Error(t, err) + expected := int64(4) + if written != expected { + t.Errorf("sftpFile.Write() = %d, but expected %d", written, expected) + } + assert.Equal(t, 2, w.counter) +} + +func TestClientReadDir(t *testing.T) { + sftp1, cmd1 := testClient(t, READONLY, NODELAY) + sftp2, cmd2 := testClientGoSvr(t, READONLY, NODELAY) + defer cmd1.Wait() + defer cmd2.Wait() + defer sftp1.Close() + defer sftp2.Close() + + dir := os.TempDir() + + d, err := os.Open(dir) + if err != nil { + t.Fatal(err) + } + defer d.Close() + osfiles, err := d.Readdir(4096) + if err != nil { + t.Fatal(err) + } + + sftp1Files, err := sftp1.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + sftp2Files, err := sftp2.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + + osFilesByName := map[string]os.FileInfo{} + for _, f := range osfiles { + osFilesByName[f.Name()] = f + } + sftp1FilesByName := map[string]os.FileInfo{} + for _, f := range sftp1Files { + sftp1FilesByName[f.Name()] = f + } + sftp2FilesByName := map[string]os.FileInfo{} + for _, f := range sftp2Files { + sftp2FilesByName[f.Name()] = f + } + + if len(osFilesByName) != len(sftp1FilesByName) || len(sftp1FilesByName) != len(sftp2FilesByName) { + t.Fatalf("os gives %v, sftp1 gives %v, sftp2 gives %v", len(osFilesByName), len(sftp1FilesByName), len(sftp2FilesByName)) + } + + for name, osF := range osFilesByName { + sftp1F, ok := sftp1FilesByName[name] + if !ok { + t.Fatalf("%v present in os but not sftp1", name) + } + sftp2F, ok := sftp2FilesByName[name] + if !ok { + t.Fatalf("%v present in os but not sftp2", name) + } + + //t.Logf("%v: %v %v %v", name, osF, sftp1F, sftp2F) + if 
osF.Size() != sftp1F.Size() || sftp1F.Size() != sftp2F.Size() { + t.Fatalf("size %v %v %v", osF.Size(), sftp1F.Size(), sftp2F.Size()) + } + if osF.IsDir() != sftp1F.IsDir() || sftp1F.IsDir() != sftp2F.IsDir() { + t.Fatalf("isdir %v %v %v", osF.IsDir(), sftp1F.IsDir(), sftp2F.IsDir()) + } + if osF.ModTime().Sub(sftp1F.ModTime()) > time.Second || sftp1F.ModTime() != sftp2F.ModTime() { + t.Fatalf("modtime %v %v %v", osF.ModTime(), sftp1F.ModTime(), sftp2F.ModTime()) + } + if osF.Mode() != sftp1F.Mode() || sftp1F.Mode() != sftp2F.Mode() { + t.Fatalf("mode %x %x %x", osF.Mode(), sftp1F.Mode(), sftp2F.Mode()) + } + } +} + +var clientReadTests = []struct { + n int64 +}{ + {0}, + {1}, + {1000}, + {1024}, + {1025}, + {2048}, + {4096}, + {1 << 12}, + {1 << 13}, + {1 << 14}, + {1 << 15}, + {1 << 16}, + {1 << 17}, + {1 << 18}, + {1 << 19}, + {1 << 20}, +} + +func TestClientRead(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest-read") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + for _, disableConcurrentReads := range []bool{true, false} { + for _, tt := range clientReadTests { + f, err := ioutil.TempFile(d, "read-test") + if err != nil { + t.Fatal(err) + } + defer f.Close() + hash := writeN(t, f, tt.n) + sftp.disableConcurrentReads = disableConcurrentReads + f2, err := sftp.Open(f.Name()) + if err != nil { + t.Fatal(err) + } + defer f2.Close() + hash2, n := readHash(t, f2) + if hash != hash2 || tt.n != n { + t.Errorf("Read: hash: want: %q, got %q, read: want: %v, got %v", hash, hash2, tt.n, n) + } + } + } +} + +// readHash reads r until EOF returning the number of bytes read +// and the hash of the contents. 
+func readHash(t *testing.T, r io.Reader) (string, int64) { + h := sha1.New() + read, err := io.Copy(h, r) + if err != nil { + t.Fatal(err) + } + return string(h.Sum(nil)), read +} + +// writeN writes n bytes of random data to w and returns the +// hash of that data. +func writeN(t *testing.T, w io.Writer, n int64) string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + h := sha1.New() + + mw := io.MultiWriter(w, h) + + written, err := io.CopyN(mw, r, n) + if err != nil { + t.Fatal(err) + } + if written != n { + t.Fatalf("CopyN(%v): wrote: %v", n, written) + } + return string(h.Sum(nil)) +} + +var clientWriteTests = []struct { + n int + total int64 // cumulative file size +}{ + {0, 0}, + {1, 1}, + {0, 1}, + {999, 1000}, + {24, 1024}, + {1023, 2047}, + {2048, 4095}, + {1 << 12, 8191}, + {1 << 13, 16383}, + {1 << 14, 32767}, + {1 << 15, 65535}, + {1 << 16, 131071}, + {1 << 17, 262143}, + {1 << 18, 524287}, + {1 << 19, 1048575}, + {1 << 20, 2097151}, + {1 << 21, 4194303}, +} + +func TestClientWrite(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest-write") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + for _, tt := range clientWriteTests { + got, err := w.Write(make([]byte, tt.n)) if err != nil { t.Fatal(err) } - defer f.Close() - hash := writeN(t, f, tt.n) - f2, err := sftp.Open(f.Name()) + if got != tt.n { + t.Errorf("Write(%v): wrote: want: %v, got %v", tt.n, tt.n, got) + } + fi, err := os.Stat(f) if err != nil { t.Fatal(err) } - defer f2.Close() - hash2, n := readHash(t, f2) - if hash != hash2 || tt.n != n { - t.Errorf("Read: hash: want: %q, got %q, read: want: %v, got %v", hash, hash2, tt.n, n) + if total := fi.Size(); total != tt.total { + t.Errorf("Write(%v): size: want: %v, got %v", tt.n, tt.total, total) + } + } +} + 
+// ReadFrom is basically Write with io.Reader as the arg +func TestClientReadFrom(t *testing.T) { + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest-readfrom") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + for _, tt := range clientWriteTests { + got, err := w.ReadFrom(bytes.NewReader(make([]byte, tt.n))) + if err != nil { + t.Fatal(err) + } + if got != int64(tt.n) { + t.Errorf("Write(%v): wrote: want: %v, got %v", tt.n, tt.n, got) + } + fi, err := os.Stat(f) + if err != nil { + t.Fatal(err) } + if total := fi.Size(); total != tt.total { + t.Errorf("Write(%v): size: want: %v, got %v", tt.n, tt.total, total) + } + } +} + +// A sizedReader is a Reader with a completely arbitrary Size. +type sizedReader struct { + io.Reader + size int +} + +func (r *sizedReader) Size() int { return r.size } + +// Test File.ReadFrom's handling of a Reader's Size: +// it should be used as a heuristic for determining concurrency only. 
+func TestClientReadFromSizeMismatch(t *testing.T) { + const ( + packetSize = 1024 + filesize = 4 * packetSize + ) + + sftp, cmd := testClient(t, READWRITE, NODELAY, MaxPacketChecked(packetSize), UseConcurrentWrites(true)) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest-readfrom-size-mismatch") + if err != nil { + t.Fatal("cannot create temp dir:", err) + } + defer os.RemoveAll(d) + + buf := make([]byte, filesize) + + for i, reportedSize := range []int{ + -1, filesize - 100, filesize, filesize + 100, + } { + t.Run(fmt.Sprint(i), func(t *testing.T) { + r := &sizedReader{Reader: bytes.NewReader(buf), size: reportedSize} + + f := path.Join(d, fmt.Sprint(i)) + w, err := sftp.Create(f) + if err != nil { + t.Fatal("unexpected error:", err) + } + defer w.Close() + + n, err := w.ReadFrom(r) + assert.EqualValues(t, filesize, n) + + fi, err := os.Stat(f) + if err != nil { + t.Fatal("unexpected error:", err) + } + assert.EqualValues(t, filesize, fi.Size()) + }) } } -// readHash reads r until EOF returning the number of bytes read -// and the hash of the contents. -func readHash(t *testing.T, r io.Reader) (string, int64) { - h := sha1.New() - tr := io.TeeReader(r, h) - read, err := io.Copy(ioutil.Discard, tr) - if err != nil { - t.Fatal(err) +// Issue #145 in github +// Deadlock in ReadFrom when network drops after 1 good packet. +// Deadlock would occur anytime desiredInFlight-inFlight==2 and 2 errors +// occurred in a row. The channel to report the errors only had a buffer +// of 1 and 2 would be sent. 
+var errFakeNet = errors.New("Fake network issue") + +func TestClientReadFromDeadlock(t *testing.T) { + for i := 0; i < 5; i++ { + clientWriteDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) + content := bytes.NewReader(b) + _, err := f.ReadFrom(content) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't receive correct error:", err) + } + }) } - return string(h.Sum(nil)), read } -// writeN writes n bytes of random data to w and returns the -// hash of that data. -func writeN(t *testing.T, w io.Writer, n int64) string { - rand, err := os.Open("/dev/urandom") - if err != nil { - t.Fatal(err) +// Write has exact same problem +func TestClientWriteDeadlock(t *testing.T) { + for i := 0; i < 5; i++ { + clientWriteDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) + + _, err := f.Write(b) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't receive correct error:", err) + } + }) } - defer rand.Close() +} - h := sha1.New() +type timeBombWriter struct { + count int + w io.WriteCloser +} - mw := io.MultiWriter(w, h) +func (w *timeBombWriter) Write(b []byte) (int, error) { + if w.count < 1 { + return 0, errFakeNet + } + + w.count-- + return w.w.Write(b) +} + +func (w *timeBombWriter) Close() error { + return w.w.Close() +} + +// shared body for both previous tests +func clientWriteDeadlock(t *testing.T, N int, badfunc func(*File)) { + if !*testServerImpl { + t.Skipf("skipping without -testserver") + } + + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() - written, err := io.CopyN(mw, rand, n) + d, err := ioutil.TempDir("", "sftptest-writedeadlock") if err != nil { t.Fatal(err) } - if written != n { - t.Fatalf("CopyN(%v): wrote: %v", n, written) + defer os.RemoveAll(d) + + f := path.Join(d, "writeTest") + w, err := sftp.Create(f) + if err != nil { + t.Fatal(err) } - return string(h.Sum(nil)) + defer w.Close() + + // Override the clienConn Writer with a failing version + // Replicates network error/drop part way 
through (after N good writes) + wrap := sftp.clientConn.conn.WriteCloser + sftp.clientConn.conn.WriteCloser = &timeBombWriter{ + count: N, + w: wrap, + } + + // this locked (before the fix) + badfunc(w) } -var clientWriteTests = []struct { - n int - total int64 // cumulative file size -}{ - {0, 0}, - {1, 1}, - {0, 1}, - {999, 1000}, - {24, 1024}, - {1023, 2047}, - {2048, 4095}, - {1 << 12, 8191}, - {1 << 13, 16383}, - {1 << 14, 32767}, - {1 << 15, 65535}, - {1 << 16, 131071}, - {1 << 17, 262143}, - {1 << 18, 524287}, - {1 << 19, 1048575}, - {1 << 20, 2097151}, - {1 << 21, 4194303}, +// Read/WriteTo has this issue as well +func TestClientReadDeadlock(t *testing.T) { + for i := 0; i < 3; i++ { + clientReadDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) + + _, err := f.Read(b) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't receive correct error:", err) + } + }) + } } -func TestClientWrite(t *testing.T) { - sftp, cmd := testClient(t, READWRITE) +func TestClientWriteToDeadlock(t *testing.T) { + for i := 0; i < 3; i++ { + clientReadDeadlock(t, i, func(f *File) { + b := make([]byte, 32768*4) + + buf := bytes.NewBuffer(b) + + _, err := f.WriteTo(buf) + if !errors.Is(err, errFakeNet) { + t.Fatal("Didn't receive correct error:", err) + } + }) + } +} + +func clientReadDeadlock(t *testing.T, N int, badfunc func(*File)) { + if !*testServerImpl { + t.Skipf("skipping without -testserver") + } + sftp, cmd := testClient(t, READWRITE, NODELAY) defer cmd.Wait() defer sftp.Close() - d, err := ioutil.TempDir("", "sftptest") + d, err := ioutil.TempDir("", "sftptest-readdeadlock") if err != nil { t.Fatal(err) } defer os.RemoveAll(d) f := path.Join(d, "writeTest") + w, err := sftp.Create(f) if err != nil { t.Fatal(err) } defer w.Close() - for _, tt := range clientWriteTests { - got, err := w.Write(make([]byte, tt.n)) - if err != nil { - t.Fatal(err) - } - if got != tt.n { - t.Errorf("Write(%v): wrote: want: %v, got %v", tt.n, tt.n, got) - } - fi, err := os.Stat(f) - if 
err != nil { - t.Fatal(err) - } - if total := fi.Size(); total != tt.total { - t.Errorf("Write(%v): size: want: %v, got %v", tt.n, tt.total, total) - } + // write the data for the read tests + b := make([]byte, 32768*4) + w.Write(b) + + // open new copy of file for read tests + r, err := sftp.Open(f) + if err != nil { + t.Fatal(err) } + defer r.Close() + + // Override the clienConn Writer with a failing version + // Replicates network error/drop part way through (after N good writes) + wrap := sftp.clientConn.conn.WriteCloser + sftp.clientConn.conn.WriteCloser = &timeBombWriter{ + count: N, + w: wrap, + } + + // this locked (before the fix) + badfunc(r) } -// taken from github.com/kr/fs/walk_test.go +func TestClientSyncGo(t *testing.T) { + if !*testServerImpl { + t.Skipf("skipping without -testserver") + } + err := testClientSync(t) + + // Since Server does not support the fsync extension, we can only + // check that we get the right error. + require.Error(t, err) + + switch err := err.(type) { + case *StatusError: + assert.Equal(t, ErrSSHFxOpUnsupported, err.FxCode()) + default: + t.Error(err) + } +} + +func TestClientSyncSFTP(t *testing.T) { + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + err := testClientSync(t) + assert.NoError(t, err) +} + +func testClientSync(t *testing.T) error { + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest.sync") + require.NoError(t, err) + defer os.RemoveAll(d) + + f := path.Join(d, "syncTest") + w, err := sftp.Create(f) + require.NoError(t, err) + defer w.Close() -type PathTest struct { - path, result string + return w.Sync() } +// taken from github.com/kr/fs/walk_test.go + type Node struct { name string entries []*Node // nil if the entry is a file @@ -664,21 +1956,21 @@ func mark(path string, info os.FileInfo, err error, errors *[]error, clear bool) } func TestClientWalk(t *testing.T) { - sftp, cmd := testClient(t, READONLY) + 
sftp, cmd := testClient(t, READONLY, NODELAY) defer cmd.Wait() defer sftp.Close() makeTree(t) errors := make([]error, 0, 10) clear := true - markFn := func(walker *fs.Walker) (err error) { + markFn := func(walker *fs.Walker) error { for walker.Step() { - err = mark(walker.Path(), walker.Stat(), walker.Err(), &errors, clear) + err := mark(walker.Path(), walker.Stat(), walker.Err(), &errors, clear) if err != nil { - break + return err } } - return err + return nil } // Expect no errors. err := markFn(sftp.Walk(tree.name)) @@ -717,41 +2009,389 @@ func TestClientWalk(t *testing.T) { checkMarks(t, true) errors = errors[0:0] - // 4) capture errors, stop after first error. - // mark respective subtrees manually - markTree(tree.entries[1]) - markTree(tree.entries[3]) - // correct double-marking of directory itself - tree.entries[1].mark-- - tree.entries[3].mark-- - clear = false // error will stop processing - err = markFn(sftp.Walk(tree.name)) - if err == nil { - t.Fatalf("expected error return from Walk") - } - if len(errors) != 1 { - t.Errorf("expected 1 error, got %d: %s", len(errors), errors) + // 4) capture errors, stop after first error. 
+ // mark respective subtrees manually + markTree(tree.entries[1]) + markTree(tree.entries[3]) + // correct double-marking of directory itself + tree.entries[1].mark-- + tree.entries[3].mark-- + clear = false // error will stop processing + err = markFn(sftp.Walk(tree.name)) + if err == nil { + t.Fatalf("expected error return from Walk") + } + if len(errors) != 1 { + t.Errorf("expected 1 error, got %d: %s", len(errors), errors) + } + // the inaccessible subtrees were marked manually + checkMarks(t, false) + errors = errors[0:0] + + // restore permissions + os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0770) + os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0770) + } + + // cleanup + if err := os.RemoveAll(tree.name); err != nil { + t.Errorf("removeTree: %v", err) + } +} + +type MatchTest struct { + pattern, s string + match bool + err error +} + +var matchTests = []MatchTest{ + {"abc", "abc", true, nil}, + {"*", "abc", true, nil}, + {"*c", "abc", true, nil}, + {"a*", "a", true, nil}, + {"a*", "abc", true, nil}, + {"a*", "ab/c", false, nil}, + {"a*/b", "abc/b", true, nil}, + {"a*/b", "a/c/b", false, nil}, + {"a*b*c*d*e*/f", "axbxcxdxe/f", true, nil}, + {"a*b*c*d*e*/f", "axbxcxdxexxx/f", true, nil}, + {"a*b*c*d*e*/f", "axbxcxdxe/xxx/f", false, nil}, + {"a*b*c*d*e*/f", "axbxcxdxexxx/fff", false, nil}, + {"a*b?c*x", "abxbbxdbxebxczzx", true, nil}, + {"a*b?c*x", "abxbbxdbxebxczzy", false, nil}, + {"ab[c]", "abc", true, nil}, + {"ab[b-d]", "abc", true, nil}, + {"ab[e-g]", "abc", false, nil}, + {"ab[^c]", "abc", false, nil}, + {"ab[^b-d]", "abc", false, nil}, + {"ab[^e-g]", "abc", true, nil}, + {"a\\*b", "a*b", true, nil}, + {"a\\*b", "ab", false, nil}, + {"a?b", "a☺b", true, nil}, + {"a[^a]b", "a☺b", true, nil}, + {"a???b", "a☺b", false, nil}, + {"a[^a][^a][^a]b", "a☺b", false, nil}, + {"[a-ζ]*", "α", true, nil}, + {"*[a-ζ]", "A", false, nil}, + {"a?b", "a/b", false, nil}, + {"a*b", "a/b", false, nil}, + {"[\\]a]", "]", true, nil}, + {"[\\-]", "-", 
true, nil}, + {"[x\\-]", "x", true, nil}, + {"[x\\-]", "-", true, nil}, + {"[x\\-]", "z", false, nil}, + {"[\\-x]", "x", true, nil}, + {"[\\-x]", "-", true, nil}, + {"[\\-x]", "a", false, nil}, + {"[]a]", "]", false, ErrBadPattern}, + {"[-]", "-", false, ErrBadPattern}, + {"[x-]", "x", false, ErrBadPattern}, + {"[x-]", "-", false, ErrBadPattern}, + {"[x-]", "z", false, ErrBadPattern}, + {"[-x]", "x", false, ErrBadPattern}, + {"[-x]", "-", false, ErrBadPattern}, + {"[-x]", "a", false, ErrBadPattern}, + {"\\", "a", false, ErrBadPattern}, + {"[a-b-c]", "a", false, ErrBadPattern}, + {"[", "a", false, ErrBadPattern}, + {"[^", "a", false, ErrBadPattern}, + {"[^bc", "a", false, ErrBadPattern}, + {"a[", "ab", false, ErrBadPattern}, + {"*x", "xxx", true, nil}, + + // The following test behaves differently on Go 1.15.3 and Go tip as + // https://github.com/golang/go/commit/b5ddc42b465dd5b9532ee336d98343d81a6d35b2 + // (pre-Go 1.16). TODO: reevaluate when Go 1.16 is released. + //{"a[", "a", false, nil}, +} + +func errp(e error) string { + if e == nil { + return "" + } + return e.Error() +} + +// contains returns true if vector contains the string s. 
+func contains(vector []string, s string) bool { + for _, elem := range vector { + if elem == s { + return true + } + } + return false +} + +var globTests = []struct { + pattern, result string +}{ + {"match.go", "match.go"}, + {"mat?h.go", "match.go"}, + {"ma*ch.go", "match.go"}, + {`\m\a\t\c\h\.\g\o`, "match.go"}, + {"../*/match.go", "../sftp/match.go"}, +} + +type globTest struct { + pattern string + matches []string +} + +func (test *globTest) buildWant(root string) []string { + var want []string + for _, m := range test.matches { + want = append(want, root+filepath.FromSlash(m)) + } + sort.Strings(want) + return want +} + +func TestMatch(t *testing.T) { + for _, tt := range matchTests { + pattern := tt.pattern + s := tt.s + ok, err := Match(pattern, s) + if ok != tt.match || err != tt.err { + t.Errorf("Match(%#q, %#q) = %v, %q want %v, %q", pattern, s, ok, errp(err), tt.match, errp(tt.err)) + } + } +} + +func TestGlob(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + for _, tt := range globTests { + pattern := tt.pattern + result := tt.result + matches, err := sftp.Glob(pattern) + if err != nil { + t.Errorf("Glob error for %q: %s", pattern, err) + continue + } + if !contains(matches, result) { + t.Errorf("Glob(%#q) = %#v want %v", pattern, matches, result) + } + } + for _, pattern := range []string{"no_match", "../*/no_match"} { + matches, err := sftp.Glob(pattern) + if err != nil { + t.Errorf("Glob error for %q: %s", pattern, err) + continue + } + if len(matches) != 0 { + t.Errorf("Glob(%#q) = %#v want []", pattern, matches) + } + } +} + +func TestGlobError(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + _, err := sftp.Glob("[7]") + if err != nil { + t.Error("expected error for bad pattern; got none") + } +} + +func TestGlobUNC(t *testing.T) { + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + // Just make 
sure this runs without crashing for now. + // See issue 15879. + sftp.Glob(`\\?\C:\*`) +} + +// sftp/issue/42, abrupt server hangup would result in client hangs. +func TestServerRoughDisconnect(t *testing.T) { + skipIfWindows(t) + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := sftp.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + _, err = io.Copy(ioutil.Discard, f) + assert.Error(t, err) +} + +// sftp/issue/181, abrupt server hangup would result in client hangs. +// due to broadcastErr filling up the request channel +// this reproduces it about 50% of the time +func TestServerRoughDisconnect2(t *testing.T) { + skipIfWindows(t) + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + f, err := sftp.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + b := make([]byte, 32768*100) + go func() { + time.Sleep(1 * time.Millisecond) + cmd.Process.Kill() + }() + for { + _, err = f.Read(b) + if err != nil { + break + } + } +} + +// sftp/issue/234 - abrupt shutdown during ReadFrom hangs client +func TestServerRoughDisconnect3(t *testing.T) { + skipIfWindows(t) + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + dest, err := sftp.OpenFile("/dev/null", os.O_RDWR) + if err != nil { + t.Fatal(err) + } + defer dest.Close() + + src, err := os.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer src.Close() + + go func() { + time.Sleep(10 * time.Millisecond) + cmd.Process.Kill() + }() + + _, err = io.Copy(dest, src) + assert.Error(t, err) +} + +// sftp/issue/234 - also affected Write +func TestServerRoughDisconnect4(t 
*testing.T) { + skipIfWindows(t) + if *testServerImpl { + t.Skipf("skipping with -testserver") + } + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + defer sftp.Close() + + dest, err := sftp.OpenFile("/dev/null", os.O_RDWR) + if err != nil { + t.Fatal(err) + } + defer dest.Close() + + src, err := os.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer src.Close() + + go func() { + time.Sleep(10 * time.Millisecond) + cmd.Process.Kill() + }() + + b := make([]byte, 32768*200) + src.Read(b) + for { + _, err = dest.Write(b) + if err != nil { + assert.NotEqual(t, io.EOF, err) + break } - // the inaccessible subtrees were marked manually - checkMarks(t, false) - errors = errors[0:0] + } - // restore permissions - os.Chmod(filepath.Join(tree.name, tree.entries[1].name), 0770) - os.Chmod(filepath.Join(tree.name, tree.entries[3].name), 0770) + _, err = io.Copy(dest, src) + assert.Error(t, err) +} + +// sftp/issue/390 - server disconnect should not cause io.EOF or +// io.ErrUnexpectedEOF in sftp.File.Read, because those confuse io.ReadFull. +func TestServerRoughDisconnectEOF(t *testing.T) { + skipIfWindows(t) + if *testServerImpl { + t.Skipf("skipping with -testserver") } + sftp, cmd := testClient(t, READONLY, NODELAY) + defer cmd.Wait() + defer sftp.Close() - // cleanup - if err := os.RemoveAll(tree.name); err != nil { - t.Errorf("removeTree: %v", err) + f, err := sftp.Open("/dev/null") + if err != nil { + t.Fatal(err) + } + defer f.Close() + go func() { + time.Sleep(100 * time.Millisecond) + cmd.Process.Kill() + }() + + _, err = io.ReadFull(f, make([]byte, 10)) + assert.Error(t, err) + assert.NotEqual(t, io.ErrUnexpectedEOF, err) +} + +// sftp/issue/26 writing to a read only file caused client to loop. 
+func TestClientWriteToROFile(t *testing.T) { + skipIfWindows(t) + + sftp, cmd := testClient(t, READWRITE, NODELAY) + defer cmd.Wait() + + defer func() { + err := sftp.Close() + assert.NoError(t, err) + }() + + // TODO (puellanivis): /dev/zero is not actually a read-only file. + // So, this test works purely by accident. + f, err := sftp.Open("/dev/zero") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + _, err = f.Write([]byte("hello")) + if err == nil { + t.Fatal("expected error, got", err) } } -func benchmarkRead(b *testing.B, bufsize int) { +func benchmarkRead(b *testing.B, bufsize int, delay time.Duration) { + skipIfWindows(b) size := 10*1024*1024 + 123 // ~10MiB // open sftp client - sftp, cmd := testClient(b, READONLY) + sftp, cmd := testClient(b, READONLY, delay) defer cmd.Wait() defer sftp.Close() @@ -767,7 +2407,6 @@ func benchmarkRead(b *testing.B, bufsize int) { if err != nil { b.Fatal(err) } - defer f2.Close() for offset < size { n, err := io.ReadFull(f2, buf) @@ -782,42 +2421,56 @@ func benchmarkRead(b *testing.B, bufsize int) { offset += n } + + f2.Close() } } func BenchmarkRead1k(b *testing.B) { - benchmarkRead(b, 1*1024) + benchmarkRead(b, 1*1024, NODELAY) } func BenchmarkRead16k(b *testing.B) { - benchmarkRead(b, 16*1024) + benchmarkRead(b, 16*1024, NODELAY) } func BenchmarkRead32k(b *testing.B) { - benchmarkRead(b, 32*1024) + benchmarkRead(b, 32*1024, NODELAY) } func BenchmarkRead128k(b *testing.B) { - benchmarkRead(b, 128*1024) + benchmarkRead(b, 128*1024, NODELAY) } func BenchmarkRead512k(b *testing.B) { - benchmarkRead(b, 512*1024) + benchmarkRead(b, 512*1024, NODELAY) } func BenchmarkRead1MiB(b *testing.B) { - benchmarkRead(b, 1024*1024) + benchmarkRead(b, 1024*1024, NODELAY) } func BenchmarkRead4MiB(b *testing.B) { - benchmarkRead(b, 4*1024*1024) + benchmarkRead(b, 4*1024*1024, NODELAY) +} + +func BenchmarkRead4MiBDelay10Msec(b *testing.B) { + benchmarkRead(b, 4*1024*1024, 10*time.Millisecond) +} + +func 
BenchmarkRead4MiBDelay50Msec(b *testing.B) { + benchmarkRead(b, 4*1024*1024, 50*time.Millisecond) +} + +func BenchmarkRead4MiBDelay150Msec(b *testing.B) { + benchmarkRead(b, 4*1024*1024, 150*time.Millisecond) } -func benchmarkWrite(b *testing.B, bufsize int) { +func benchmarkWrite(b *testing.B, bufsize int, delay time.Duration) { size := 10*1024*1024 + 123 // ~10MiB // open sftp client - sftp, cmd := testClient(b, false) + sftp, cmd := testClient(b, false, delay) defer cmd.Wait() defer sftp.Close() @@ -829,20 +2482,24 @@ func benchmarkWrite(b *testing.B, bufsize int) { for i := 0; i < b.N; i++ { offset := 0 - f, err := ioutil.TempFile("", "sftptest") + f, err := ioutil.TempFile("", "sftptest-benchwrite") if err != nil { b.Fatal(err) } - defer os.Remove(f.Name()) + defer os.Remove(f.Name()) // actually queue up a series of removes for these files f2, err := sftp.Create(f.Name()) if err != nil { b.Fatal(err) } - defer f2.Close() for offset < size { - n, err := f2.Write(data[offset:min(len(data), offset+bufsize)]) + buf := data[offset:] + if len(buf) > bufsize { + buf = buf[:bufsize] + } + + n, err := f2.Write(buf) if err != nil { b.Fatal(err) } @@ -870,29 +2527,359 @@ func benchmarkWrite(b *testing.B, bufsize int) { } func BenchmarkWrite1k(b *testing.B) { - benchmarkWrite(b, 1*1024) + benchmarkWrite(b, 1*1024, NODELAY) } func BenchmarkWrite16k(b *testing.B) { - benchmarkWrite(b, 16*1024) + benchmarkWrite(b, 16*1024, NODELAY) } func BenchmarkWrite32k(b *testing.B) { - benchmarkWrite(b, 32*1024) + benchmarkWrite(b, 32*1024, NODELAY) } func BenchmarkWrite128k(b *testing.B) { - benchmarkWrite(b, 128*1024) + benchmarkWrite(b, 128*1024, NODELAY) } func BenchmarkWrite512k(b *testing.B) { - benchmarkWrite(b, 512*1024) + benchmarkWrite(b, 512*1024, NODELAY) } func BenchmarkWrite1MiB(b *testing.B) { - benchmarkWrite(b, 1024*1024) + benchmarkWrite(b, 1024*1024, NODELAY) } func BenchmarkWrite4MiB(b *testing.B) { - benchmarkWrite(b, 4*1024*1024) + benchmarkWrite(b, 4*1024*1024, 
NODELAY) +} + +func BenchmarkWrite4MiBDelay10Msec(b *testing.B) { + benchmarkWrite(b, 4*1024*1024, 10*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay50Msec(b *testing.B) { + benchmarkWrite(b, 4*1024*1024, 50*time.Millisecond) +} + +func BenchmarkWrite4MiBDelay150Msec(b *testing.B) { + benchmarkWrite(b, 4*1024*1024, 150*time.Millisecond) +} + +func benchmarkReadFrom(b *testing.B, bufsize int, delay time.Duration) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer sftp.Close() + + data := make([]byte, size) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + f, err := ioutil.TempFile("", "sftptest-benchreadfrom") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + f2, err := sftp.Create(f.Name()) + if err != nil { + b.Fatal(err) + } + defer f2.Close() + + f2.ReadFrom(bytes.NewReader(data)) + f2.Close() + + fi, err := os.Stat(f.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != int64(size) { + b.Fatalf("wrong file size: want %d, got %d", size, fi.Size()) + } + + os.Remove(f.Name()) + } +} + +func BenchmarkReadFrom1k(b *testing.B) { + benchmarkReadFrom(b, 1*1024, NODELAY) +} + +func BenchmarkReadFrom16k(b *testing.B) { + benchmarkReadFrom(b, 16*1024, NODELAY) +} + +func BenchmarkReadFrom32k(b *testing.B) { + benchmarkReadFrom(b, 32*1024, NODELAY) +} + +func BenchmarkReadFrom128k(b *testing.B) { + benchmarkReadFrom(b, 128*1024, NODELAY) +} + +func BenchmarkReadFrom512k(b *testing.B) { + benchmarkReadFrom(b, 512*1024, NODELAY) +} + +func BenchmarkReadFrom1MiB(b *testing.B) { + benchmarkReadFrom(b, 1024*1024, NODELAY) +} + +func BenchmarkReadFrom4MiB(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, NODELAY) +} + +func BenchmarkReadFrom4MiBDelay10Msec(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, 10*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay50Msec(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, 
50*time.Millisecond) +} + +func BenchmarkReadFrom4MiBDelay150Msec(b *testing.B) { + benchmarkReadFrom(b, 4*1024*1024, 150*time.Millisecond) +} + +func benchmarkWriteTo(b *testing.B, bufsize int, delay time.Duration) { + size := 10*1024*1024 + 123 // ~10MiB + + // open sftp client + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer sftp.Close() + + f, err := ioutil.TempFile("", "sftptest-benchwriteto") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + + data := make([]byte, size) + + f.Write(data) + f.Close() + + buf := bytes.NewBuffer(make([]byte, 0, size)) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + buf.Reset() + + f2, err := sftp.Open(f.Name()) + if err != nil { + b.Fatal(err) + } + + f2.WriteTo(buf) + f2.Close() + + if buf.Len() != size { + b.Fatalf("wrote buffer size: want %d, got %d", size, buf.Len()) + } + } +} + +func BenchmarkWriteTo1k(b *testing.B) { + benchmarkWriteTo(b, 1*1024, NODELAY) +} + +func BenchmarkWriteTo16k(b *testing.B) { + benchmarkWriteTo(b, 16*1024, NODELAY) +} + +func BenchmarkWriteTo32k(b *testing.B) { + benchmarkWriteTo(b, 32*1024, NODELAY) +} + +func BenchmarkWriteTo128k(b *testing.B) { + benchmarkWriteTo(b, 128*1024, NODELAY) +} + +func BenchmarkWriteTo512k(b *testing.B) { + benchmarkWriteTo(b, 512*1024, NODELAY) +} + +func BenchmarkWriteTo1MiB(b *testing.B) { + benchmarkWriteTo(b, 1024*1024, NODELAY) +} + +func BenchmarkWriteTo4MiB(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, NODELAY) +} + +func BenchmarkWriteTo4MiBDelay10Msec(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, 10*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay50Msec(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, 50*time.Millisecond) +} + +func BenchmarkWriteTo4MiBDelay150Msec(b *testing.B) { + benchmarkWriteTo(b, 4*1024*1024, 150*time.Millisecond) +} + +func benchmarkCopyDown(b *testing.B, fileSize int64, delay time.Duration) { + skipIfWindows(b) + // Create a temp file and fill it 
with zero's. + src, err := ioutil.TempFile("", "sftptest-benchcopydown") + if err != nil { + b.Fatal(err) + } + defer src.Close() + srcFilename := src.Name() + defer os.Remove(srcFilename) + zero, err := os.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + n, err := io.Copy(src, io.LimitReader(zero, fileSize)) + if err != nil { + b.Fatal(err) + } + if n < fileSize { + b.Fatal("short copy") + } + zero.Close() + src.Close() + + sftp, cmd := testClient(b, READONLY, delay) + defer cmd.Wait() + defer sftp.Close() + b.ResetTimer() + b.SetBytes(fileSize) + + for i := 0; i < b.N; i++ { + dst, err := ioutil.TempFile("", "sftptest-benchcopydown") + if err != nil { + b.Fatal(err) + } + defer os.Remove(dst.Name()) + + src, err := sftp.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + n, err := io.Copy(dst, src) + if err != nil { + b.Fatalf("copy error: %v", err) + } + if n < fileSize { + b.Fatal("unable to copy all bytes") + } + dst.Close() + fi, err := os.Stat(dst.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != fileSize { + b.Fatalf("wrong file size: want %d, got %d", fileSize, fi.Size()) + } + os.Remove(dst.Name()) + } +} + +func BenchmarkCopyDown10MiBDelay10Msec(b *testing.B) { + benchmarkCopyDown(b, 10*1024*1024, 10*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay50Msec(b *testing.B) { + benchmarkCopyDown(b, 10*1024*1024, 50*time.Millisecond) +} + +func BenchmarkCopyDown10MiBDelay150Msec(b *testing.B) { + benchmarkCopyDown(b, 10*1024*1024, 150*time.Millisecond) +} + +func benchmarkCopyUp(b *testing.B, fileSize int64, delay time.Duration) { + skipIfWindows(b) + // Create a temp file and fill it with zero's. 
+ src, err := ioutil.TempFile("", "sftptest-benchcopyup") + if err != nil { + b.Fatal(err) + } + defer src.Close() + srcFilename := src.Name() + defer os.Remove(srcFilename) + zero, err := os.Open("/dev/zero") + if err != nil { + b.Fatal(err) + } + n, err := io.Copy(src, io.LimitReader(zero, fileSize)) + if err != nil { + b.Fatal(err) + } + if n < fileSize { + b.Fatal("short copy") + } + zero.Close() + src.Close() + + sftp, cmd := testClient(b, false, delay) + defer cmd.Wait() + defer sftp.Close() + + b.ResetTimer() + b.SetBytes(fileSize) + + for i := 0; i < b.N; i++ { + tmp, err := ioutil.TempFile("", "sftptest-benchcopyup") + if err != nil { + b.Fatal(err) + } + tmp.Close() + defer os.Remove(tmp.Name()) + + dst, err := sftp.Create(tmp.Name()) + if err != nil { + b.Fatal(err) + } + defer dst.Close() + src, err := os.Open(srcFilename) + if err != nil { + b.Fatal(err) + } + defer src.Close() + n, err := io.Copy(dst, src) + if err != nil { + b.Fatalf("copy error: %v", err) + } + if n < fileSize { + b.Fatal("unable to copy all bytes") + } + + fi, err := os.Stat(tmp.Name()) + if err != nil { + b.Fatal(err) + } + + if fi.Size() != fileSize { + b.Fatalf("wrong file size: want %d, got %d", fileSize, fi.Size()) + } + os.Remove(tmp.Name()) + } +} + +func BenchmarkCopyUp10MiBDelay10Msec(b *testing.B) { + benchmarkCopyUp(b, 10*1024*1024, 10*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay50Msec(b *testing.B) { + benchmarkCopyUp(b, 10*1024*1024, 50*time.Millisecond) +} + +func BenchmarkCopyUp10MiBDelay150Msec(b *testing.B) { + benchmarkCopyUp(b, 10*1024*1024, 150*time.Millisecond) } diff --git a/client_test.go b/client_test.go index 9ade6d1a..dda8af2b 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,8 @@ package sftp import ( + "bytes" + "errors" "io" "os" "testing" @@ -14,42 +16,54 @@ var _ fs.FileSystem = new(Client) // assert that *File implements io.ReadWriteCloser var _ io.ReadWriteCloser = new(File) -var ok = &StatusError{Code: ssh_FX_OK} -var eof = 
&StatusError{Code: ssh_FX_EOF} -var fail = &StatusError{Code: ssh_FX_FAILURE} +func TestNormaliseError(t *testing.T) { + var ( + ok = &StatusError{Code: sshFxOk} + eof = &StatusError{Code: sshFxEOF} + fail = &StatusError{Code: sshFxFailure} + noSuchFile = &StatusError{Code: sshFxNoSuchFile} + foo = errors.New("foo") + ) -var eofOrErrTests = []struct { - err, want error -}{ - {nil, nil}, - {eof, io.EOF}, - {ok, ok}, - {io.EOF, io.EOF}, -} - -func TestEofOrErr(t *testing.T) { - for _, tt := range eofOrErrTests { - got := eofOrErr(tt.err) - if got != tt.want { - t.Errorf("eofOrErr(%#v): want: %#v, got: %#v", tt.err, tt.want, got) - } + var tests = []struct { + desc string + err error + want error + }{ + { + desc: "nil error", + }, + { + desc: "not *StatusError", + err: foo, + want: foo, + }, + { + desc: "*StatusError with ssh_FX_EOF", + err: eof, + want: io.EOF, + }, + { + desc: "*StatusError with ssh_FX_NO_SUCH_FILE", + err: noSuchFile, + want: os.ErrNotExist, + }, + { + desc: "*StatusError with ssh_FX_OK", + err: ok, + }, + { + desc: "*StatusError with ssh_FX_FAILURE", + err: fail, + want: fail, + }, } -} -var okOrErrTests = []struct { - err, want error -}{ - {nil, nil}, - {eof, eof}, - {ok, nil}, - {io.EOF, io.EOF}, -} - -func TestOkOrErr(t *testing.T) { - for _, tt := range okOrErrTests { - got := okOrErr(tt.err) + for _, tt := range tests { + got := normaliseError(tt.err) if got != tt.want { - t.Errorf("okOrErr(%#v): want: %#v, got: %#v", tt.err, tt.want, got) + t.Errorf("normaliseError(%#v), test %q\n- want: %#v\n- got: %#v", + tt.err, tt.desc, tt.want, got) } } } @@ -58,18 +72,129 @@ var flagsTests = []struct { flags int want uint32 }{ - {os.O_RDONLY, ssh_FXF_READ}, - {os.O_WRONLY, ssh_FXF_WRITE}, - {os.O_RDWR, ssh_FXF_READ | ssh_FXF_WRITE}, - {os.O_RDWR | os.O_CREATE | os.O_TRUNC, ssh_FXF_READ | ssh_FXF_WRITE | ssh_FXF_CREAT | ssh_FXF_TRUNC}, - {os.O_WRONLY | os.O_APPEND, ssh_FXF_WRITE | ssh_FXF_APPEND}, + {os.O_RDONLY, sshFxfRead}, + {os.O_WRONLY, 
sshFxfWrite}, + {os.O_RDWR, sshFxfRead | sshFxfWrite}, + {os.O_RDWR | os.O_CREATE | os.O_TRUNC, sshFxfRead | sshFxfWrite | sshFxfCreat | sshFxfTrunc}, + {os.O_WRONLY | os.O_APPEND, sshFxfWrite | sshFxfAppend}, } func TestFlags(t *testing.T) { for i, tt := range flagsTests { - got := flags(tt.flags) + got := toPflags(tt.flags) if got != tt.want { t.Errorf("test %v: flags(%x): want: %x, got: %x", i, tt.flags, tt.want, got) } } } + +type packetSizeTest struct { + size int + valid bool +} + +var maxPacketCheckedTests = []packetSizeTest{ + {size: 0, valid: false}, + {size: 1, valid: true}, + {size: 32768, valid: true}, + {size: 32769, valid: false}, +} + +var maxPacketUncheckedTests = []packetSizeTest{ + {size: 0, valid: false}, + {size: 1, valid: true}, + {size: 32768, valid: true}, + {size: 32769, valid: true}, +} + +func TestMaxPacketChecked(t *testing.T) { + for _, tt := range maxPacketCheckedTests { + testMaxPacketOption(t, MaxPacketChecked(tt.size), tt) + } +} + +func TestMaxPacketUnchecked(t *testing.T) { + for _, tt := range maxPacketUncheckedTests { + testMaxPacketOption(t, MaxPacketUnchecked(tt.size), tt) + } +} + +func TestMaxPacket(t *testing.T) { + for _, tt := range maxPacketCheckedTests { + testMaxPacketOption(t, MaxPacket(tt.size), tt) + } +} + +func testMaxPacketOption(t *testing.T, o ClientOption, tt packetSizeTest) { + var c Client + + err := o(&c) + if (err == nil) != tt.valid { + t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.valid, err == nil) + } + if c.maxPacket != tt.size && tt.valid { + t.Errorf("MaxPacketChecked(%v)\n- want: %v\n- got: %v", tt.size, tt.size, c.maxPacket) + } +} + +func testFstatOption(t *testing.T, o ClientOption, value bool) { + var c Client + + err := o(&c) + if err == nil && c.useFstat != value { + t.Errorf("UseFStat(%v)\n- want: %v\n- got: %v", value, value, c.useFstat) + } +} + +func TestUseFstatChecked(t *testing.T) { + testFstatOption(t, UseFstat(true), true) + testFstatOption(t, UseFstat(false), 
false) +} + +type sink struct{} + +func (*sink) Close() error { return nil } +func (*sink) Write(p []byte) (int, error) { return len(p), nil } + +func TestClientZeroLengthPacket(t *testing.T) { + // Packet length zero (never valid). This used to crash the client. + packet := []byte{0, 0, 0, 0} + + r := bytes.NewReader(packet) + c, err := NewClientPipe(r, &sink{}) + if err == nil { + t.Error("expected an error, got nil") + } + if c != nil { + c.Close() + } +} + +func TestClientShortPacket(t *testing.T) { + // init packet too short. + packet := []byte{0, 0, 0, 1, 2} + + r := bytes.NewReader(packet) + _, err := NewClientPipe(r, &sink{}) + if !errors.Is(err, errShortPacket) { + t.Fatalf("expected error: %v, got: %v", errShortPacket, err) + } +} + +// Issue #418: panic in clientConn.recv when the sid is incomplete. +func TestClientNoSid(t *testing.T) { + stream := new(bytes.Buffer) + sendPacket(stream, &sshFxVersionPacket{Version: sftpProtocolVersion}) + // Next packet has the sid cut short after two bytes. + stream.Write([]byte{0, 0, 0, 10, 0, 0}) + + c, err := NewClientPipe(stream, &sink{}) + if err != nil { + t.Fatal(err) + } + + _, err = c.Stat("anything") + if !errors.Is(err, ErrSSHFxConnectionLost) { + t.Fatal("expected ErrSSHFxConnectionLost, got", err) + } +} diff --git a/conn.go b/conn.go new file mode 100644 index 00000000..e68a2bd0 --- /dev/null +++ b/conn.go @@ -0,0 +1,212 @@ +package sftp + +import ( + "context" + "encoding" + "fmt" + "io" + "sync" +) + +// conn implements a bidirectional channel on which client and server +// connections are multiplexed. +type conn struct { + io.Reader + io.WriteCloser + // this is the same allocator used in packet manager + alloc *allocator + sync.Mutex // used to serialise writes to sendPacket +} + +// the orderID is used in server mode if the allocator is enabled. +// For the client mode just pass 0. +// It returns io.EOF if the connection is closed and +// there are no more packets to read. 
+func (c *conn) recvPacket(orderID uint32) (fxp, []byte, error) { + return recvPacket(c, c.alloc, orderID) +} + +func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { + c.Lock() + defer c.Unlock() + + return sendPacket(c, m) +} + +func (c *conn) Close() error { + c.Lock() + defer c.Unlock() + return c.WriteCloser.Close() +} + +type clientConn struct { + conn + wg sync.WaitGroup + + wait func() error // if non-nil, call this during Wait() to get a possible remote status error. + + sync.Mutex // protects inflight + inflight map[uint32]chan<- result // outstanding requests + + closed chan struct{} + err error +} + +// Wait blocks until the conn has shut down, and return the error +// causing the shutdown. It can be called concurrently from multiple +// goroutines. +func (c *clientConn) Wait() error { + <-c.closed + + if c.wait == nil { + // Only return this error if c.wait won't return something more useful. + return c.err + } + + if err := c.wait(); err != nil { + + // TODO: when https://github.com/golang/go/issues/35025 is fixed, + // we can remove this if block entirely. + // Right now, it’s always going to return this, so it is not useful. + // But we have this code here so that as soon as the ssh library is updated, + // we can return a possibly more useful error. + if err.Error() == "ssh: session not started" { + return c.err + } + + return err + } + + // c.wait returned no error; so, let's return something maybe more useful. + return c.err +} + +// Close closes the SFTP session. +func (c *clientConn) Close() error { + defer c.wg.Wait() + return c.conn.Close() +} + +// recv continuously reads from the server and forwards responses to the +// appropriate channel. 
+func (c *clientConn) recv() error { + defer c.conn.Close() + + for { + typ, data, err := c.recvPacket(0) + if err != nil { + return err + } + sid, _, err := unmarshalUint32Safe(data) + if err != nil { + return err + } + + ch, ok := c.getChannel(sid) + if !ok { + // This is an unexpected occurrence. Send the error + // back to all listeners so that they terminate + // gracefully. + return fmt.Errorf("sid not found: %d", sid) + } + + ch <- result{typ: typ, data: data} + } +} + +func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool { + c.Lock() + defer c.Unlock() + + select { + case <-c.closed: + // already closed with broadcastErr, return error on chan. + ch <- result{err: ErrSSHFxConnectionLost} + return false + default: + } + + c.inflight[sid] = ch + return true +} + +func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) { + c.Lock() + defer c.Unlock() + + ch, ok := c.inflight[sid] + delete(c.inflight, sid) + + return ch, ok +} + +// result captures the result of receiving a packet from the server +type result struct { + typ fxp + data []byte + err error +} + +type idmarshaler interface { + id() uint32 + encoding.BinaryMarshaler +} + +func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (fxp, []byte, error) { + if cap(ch) < 1 { + ch = make(chan result, 1) + } + + c.dispatchRequest(ch, p) + + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case s := <-ch: + return s.typ, s.data, s.err + } +} + +// dispatchRequest should ideally only be called by race-detection tests outside of this file, +// where you have to ensure two packets are in flight sequentially after each other. +func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) { + sid := p.id() + + if !c.putChannel(ch, sid) { + // already closed. 
+ return + } + + if err := c.conn.sendPacket(p); err != nil { + if ch, ok := c.getChannel(sid); ok { + ch <- result{err: err} + } + } +} + +// broadcastErr sends an error to all goroutines waiting for a response. +func (c *clientConn) broadcastErr(err error) { + c.Lock() + defer c.Unlock() + + bcastRes := result{err: ErrSSHFxConnectionLost} + for sid, ch := range c.inflight { + ch <- bcastRes + + // Replace the chan in inflight, + // we have hijacked this chan, + // and this guarantees always-only-once sending. + c.inflight[sid] = make(chan<- result, 1) + } + + c.err = err + close(c.closed) +} + +type serverConn struct { + conn +} + +func (s *serverConn) sendError(id uint32, err error) error { + return s.sendPacket(statusFromError(id, err)) +} diff --git a/debug.go b/debug.go index 3e264abe..f0db14d3 100644 --- a/debug.go +++ b/debug.go @@ -1,3 +1,4 @@ +//go:build debug // +build debug package sftp diff --git a/errno_plan9.go b/errno_plan9.go new file mode 100644 index 00000000..cf9d3902 --- /dev/null +++ b/errno_plan9.go @@ -0,0 +1,42 @@ +package sftp + +import ( + "os" + "syscall" +) + +var EBADF = syscall.NewError("fd out of range or not open") + +func wrapPathError(filepath string, err error) error { + if errno, ok := err.(syscall.ErrorString); ok { + return &os.PathError{Path: filepath, Err: errno} + } + return err +} + +// translateErrno translates a syscall error number to a SFTP error code. 
+func translateErrno(errno syscall.ErrorString) uint32 { + switch errno { + case "": + return sshFxOk + case syscall.ENOENT: + return sshFxNoSuchFile + case syscall.EPERM: + return sshFxPermissionDenied + } + + return sshFxFailure +} + +func translateSyscallError(err error) (uint32, bool) { + switch e := err.(type) { + case syscall.ErrorString: + return translateErrno(e), true + case *os.PathError: + debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err) + if errno, ok := e.Err.(syscall.ErrorString); ok { + return translateErrno(errno), true + } + } + return 0, false +} diff --git a/errno_posix.go b/errno_posix.go new file mode 100644 index 00000000..cd87e1b5 --- /dev/null +++ b/errno_posix.go @@ -0,0 +1,45 @@ +//go:build !plan9 +// +build !plan9 + +package sftp + +import ( + "os" + "syscall" +) + +const EBADF = syscall.EBADF + +func wrapPathError(filepath string, err error) error { + if errno, ok := err.(syscall.Errno); ok { + return &os.PathError{Path: filepath, Err: errno} + } + return err +} + +// translateErrno translates a syscall error number to a SFTP error code. 
+func translateErrno(errno syscall.Errno) uint32 { + switch errno { + case 0: + return sshFxOk + case syscall.ENOENT: + return sshFxNoSuchFile + case syscall.EACCES, syscall.EPERM: + return sshFxPermissionDenied + } + + return sshFxFailure +} + +func translateSyscallError(err error) (uint32, bool) { + switch e := err.(type) { + case syscall.Errno: + return translateErrno(e), true + case *os.PathError: + debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err) + if errno, ok := e.Err.(syscall.Errno); ok { + return translateErrno(errno), true + } + } + return 0, false +} diff --git a/example_test.go b/example_test.go index 3f73726d..0093eb0c 100644 --- a/example_test.go +++ b/example_test.go @@ -1,26 +1,31 @@ package sftp_test import ( + "bufio" "fmt" + "io" "log" "os" "os/exec" - - "golang.org/x/crypto/ssh" + "path" + "strings" "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" ) -func Example(conn *ssh.Client) { +func Example() { + var conn *ssh.Client + // open an SFTP session over an existing ssh connection. - sftp, err := sftp.NewClient(conn) + client, err := sftp.NewClient(conn) if err != nil { log.Fatal(err) } - defer sftp.Close() + defer client.Close() // walk a directory - w := sftp.Walk("/home/user") + w := client.Walk("/home/user") for w.Step() { if w.Err() != nil { continue @@ -29,16 +34,17 @@ func Example(conn *ssh.Client) { } // leave your mark - f, err := sftp.Create("hello.txt") + f, err := client.Create("hello.txt") if err != nil { log.Fatal(err) } if _, err := f.Write([]byte("Hello world!")); err != nil { log.Fatal(err) } + f.Close() // check it's there - fi, err := sftp.Lstat("hello.txt") + fi, err := client.Lstat("hello.txt") if err != nil { log.Fatal(err) } @@ -89,3 +95,71 @@ func ExampleNewClientPipe() { // close the connection client.Close() } + +func ExampleClient_Mkdir_parents() { + // Example of mimicking 'mkdir --parents'; I.E. recursively create + // directories and don't error if any directories already exist. 
+ var conn *ssh.Client + + client, err := sftp.NewClient(conn) + if err != nil { + log.Fatal(err) + } + defer client.Close() + + sshFxFailure := uint32(4) + mkdirParents := func(client *sftp.Client, dir string) (err error) { + var parents string + + if path.IsAbs(dir) { + // Otherwise, an absolute path given below would be turned in to a relative one + // by splitting on "/" + parents = "/" + } + + for _, name := range strings.Split(dir, "/") { + if name == "" { + // Paths with double-/ in them should just move along + // this will also catch the case of the first character being a "/", i.e. an absolute path + continue + } + parents = path.Join(parents, name) + err = client.Mkdir(parents) + if status, ok := err.(*sftp.StatusError); ok { + if status.Code == sshFxFailure { + var fi os.FileInfo + fi, err = client.Stat(parents) + if err == nil { + if !fi.IsDir() { + return fmt.Errorf("file exists: %s", parents) + } + } + } + } + if err != nil { + break + } + } + return err + } + + err = mkdirParents(client, "/tmp/foo/bar") + if err != nil { + log.Fatal(err) + } +} + +func ExampleFile_ReadFrom_bufio() { + // Using Bufio to buffer writes going to an sftp.File won't buffer as it + // skips buffering if the underlying writer supports ReadFrom. The + // workaround is to wrap your writer in a struct that only implements + // io.Writer. 
+ // + // For background see github.com/pkg/sftp/issues/125 + + var data_source io.Reader + var f *sftp.File + type writerOnly struct{ io.Writer } + bw := bufio.NewWriter(writerOnly{f}) // no ReadFrom() + bw.ReadFrom(data_source) +} diff --git a/examples/buffered-read-benchmark/main.go b/examples/buffered-read-benchmark/main.go index 5d63f18c..7f2adc4a 100644 --- a/examples/buffered-read-benchmark/main.go +++ b/examples/buffered-read-benchmark/main.go @@ -22,6 +22,7 @@ var ( HOST = flag.String("host", "localhost", "ssh server hostname") PORT = flag.Int("port", 22, "ssh server port") PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") ) func init() { @@ -39,8 +40,9 @@ func main() { } config := ssh.ClientConfig{ - User: *USER, - Auth: auths, + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), } addr := fmt.Sprintf("%s:%d", *HOST, *PORT) conn, err := ssh.Dial("tcp", addr, &config) @@ -49,7 +51,7 @@ func main() { } defer conn.Close() - c, err := sftp.NewClient(conn) + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) if err != nil { log.Fatalf("unable to start sftp subsytem: %v", err) } diff --git a/examples/buffered-write-benchmark/main.go b/examples/buffered-write-benchmark/main.go index c0ea133a..7a0594cf 100644 --- a/examples/buffered-write-benchmark/main.go +++ b/examples/buffered-write-benchmark/main.go @@ -22,6 +22,7 @@ var ( HOST = flag.String("host", "localhost", "ssh server hostname") PORT = flag.Int("port", 22, "ssh server port") PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") ) func init() { @@ -39,8 +40,9 @@ func main() { } config := ssh.ClientConfig{ - User: *USER, - Auth: auths, + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), } addr := fmt.Sprintf("%s:%d", *HOST, *PORT) conn, err := ssh.Dial("tcp", addr, &config) @@ -49,7 +51,7 @@ func 
main() { } defer conn.Close() - c, err := sftp.NewClient(conn) + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) if err != nil { log.Fatalf("unable to start sftp subsytem: %v", err) } diff --git a/examples/go-sftp-server/README.md b/examples/go-sftp-server/README.md new file mode 100644 index 00000000..bd96f2d8 --- /dev/null +++ b/examples/go-sftp-server/README.md @@ -0,0 +1,12 @@ +Example SFTP server implementation +=== + +In order to use this example you will need an RSA key. + +On linux-like systems with openssh installed, you can use the command: + +``` +ssh-keygen -t rsa -f id_rsa +``` + +Then you will be able to run the sftp-server command in the current directory. diff --git a/examples/go-sftp-server/main.go b/examples/go-sftp-server/main.go new file mode 100644 index 00000000..aef436cb --- /dev/null +++ b/examples/go-sftp-server/main.go @@ -0,0 +1,154 @@ +// An example SFTP server implementation using the golang SSH package. +// Serves the whole filesystem visible to the user, and has a hard-coded username and password, +// so not for real use! +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +// Based on example server code from golang.org/x/crypto/ssh and server_standalone +func main() { + var ( + readOnly bool + debugStderr bool + winRoot bool + ) + + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.BoolVar(&winRoot, "wr", false, "windows root") + + flag.Parse() + + debugStream := io.Discard + if debugStderr { + debugStream = os.Stderr + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. 
+ fmt.Fprintf(debugStream, "Login: %s\n", c.User()) + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key", err) + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection", err) + } + fmt.Printf("Listening on %v\n", listener.Addr()) + + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake", err) + } + fmt.Fprintf(debugStream, "SSH server established\n") + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of an SFTP session, this is "subsystem" + // with a payload string of "sftp" + fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType()) + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType()) + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatal("could not accept channel.", err) + } + fmt.Fprintf(debugStream, "Channel accepted\n") + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". 
Here we handle only the + // "subsystem" request. + go func(in <-chan *ssh.Request) { + for req := range in { + fmt.Fprintf(debugStream, "Request: %v\n", req.Type) + ok := false + switch req.Type { + case "subsystem": + fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:]) + if string(req.Payload[4:]) == "sftp" { + ok = true + } + } + fmt.Fprintf(debugStream, " - accepted: %v\n", ok) + req.Reply(ok, nil) + } + }(requests) + + serverOptions := []sftp.ServerOption{ + sftp.WithDebug(debugStream), + } + + if readOnly { + serverOptions = append(serverOptions, sftp.ReadOnly()) + fmt.Fprintf(debugStream, "Read-only server\n") + } else { + fmt.Fprintf(debugStream, "Read write server\n") + } + + if winRoot { + serverOptions = append(serverOptions, sftp.WindowsRootEnumeratesDrives()) + fmt.Fprintf(debugStream, "Windows root enabled\n") + } + + server, err := sftp.NewServer( + channel, + serverOptions..., + ) + if err != nil { + log.Fatal(err) + } + if err := server.Serve(); err != nil { + if err != io.EOF { + log.Fatal("sftp server completed with error:", err) + } + } + server.Close() + log.Print("sftp client exited session.") + } +} diff --git a/examples/gsftp/main.go b/examples/gsftp/main.go deleted file mode 100644 index a1a8824b..00000000 --- a/examples/gsftp/main.go +++ /dev/null @@ -1,147 +0,0 @@ -// gsftp implements a simple sftp client. 
-// -// gsftp understands the following commands: -// -// List a directory (and its subdirectories) -// gsftp ls DIR -// -// Fetch a remote file -// gsftp fetch FILE -// -// Put the contents of stdin to a remote file -// cat LOCALFILE | gsftp put REMOTEFILE -// -// Print the details of a remote file -// gsftp stat FILE -// -// Remove a remote file -// gsftp rm FILE -// -// Rename a file -// gsftp mv OLD NEW -// -package main - -import ( - "flag" - "fmt" - "io" - "log" - "net" - "os" - - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" - - "github.com/pkg/sftp" -) - -var ( - USER = flag.String("user", os.Getenv("USER"), "ssh username") - HOST = flag.String("host", "localhost", "ssh server hostname") - PORT = flag.Int("port", 22, "ssh server port") - PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") -) - -func init() { - flag.Parse() - if len(flag.Args()) < 1 { - log.Fatal("subcommand required") - } -} - -func main() { - var auths []ssh.AuthMethod - if aconn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { - auths = append(auths, ssh.PublicKeysCallback(agent.NewClient(aconn).Signers)) - - } - if *PASS != "" { - auths = append(auths, ssh.Password(*PASS)) - } - - config := ssh.ClientConfig{ - User: *USER, - Auth: auths, - } - addr := fmt.Sprintf("%s:%d", *HOST, *PORT) - conn, err := ssh.Dial("tcp", addr, &config) - if err != nil { - log.Fatalf("unable to connect to [%s]: %v", addr, err) - } - defer conn.Close() - - client, err := sftp.NewClient(conn) - if err != nil { - log.Fatalf("unable to start sftp subsytem: %v", err) - } - defer client.Close() - switch cmd := flag.Args()[0]; cmd { - case "ls": - if len(flag.Args()) < 2 { - log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) - } - walker := client.Walk(flag.Args()[1]) - for walker.Step() { - if err := walker.Err(); err != nil { - log.Println(err) - continue - } - fmt.Println(walker.Path()) - } - case "fetch": - if len(flag.Args()) < 2 { - 
log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) - } - f, err := client.Open(flag.Args()[1]) - if err != nil { - log.Fatal(err) - } - defer f.Close() - if _, err := io.Copy(os.Stdout, f); err != nil { - log.Fatal(err) - } - case "put": - if len(flag.Args()) < 2 { - log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) - } - f, err := client.Create(flag.Args()[1]) - if err != nil { - log.Fatal(err) - } - defer f.Close() - if _, err := io.Copy(f, os.Stdin); err != nil { - log.Fatal(err) - } - case "stat": - if len(flag.Args()) < 2 { - log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) - } - f, err := client.Open(flag.Args()[1]) - if err != nil { - log.Fatal(err) - } - defer f.Close() - fi, err := f.Stat() - if err != nil { - log.Fatalf("unable to stat file: %v", err) - } - fmt.Printf("%s %d %v\n", fi.Name(), fi.Size(), fi.Mode()) - case "rm": - if len(flag.Args()) < 2 { - log.Fatalf("%s %s: remote path required", cmd, os.Args[0]) - } - if err := client.Remove(flag.Args()[1]); err != nil { - log.Fatalf("unable to remove file: %v", err) - } - case "mv": - if len(flag.Args()) < 3 { - log.Fatalf("%s %s: old and new name required", cmd, os.Args[0]) - } - if err := client.Rename(flag.Args()[1], flag.Args()[2]); err != nil { - log.Fatalf("unable to rename file: %v", err) - } - default: - log.Fatalf("unknown subcommand: %v", cmd) - } -} diff --git a/examples/request-server/main.go b/examples/request-server/main.go new file mode 100644 index 00000000..37a3a55f --- /dev/null +++ b/examples/request-server/main.go @@ -0,0 +1,130 @@ +// An example SFTP server implementation using the golang SSH package. +// Serves the whole filesystem visible to the user, and has a hard-coded username and password, +// so not for real use! 
+package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +// Based on example server code from golang.org/x/crypto/ssh and server_standalone +func main() { + var ( + readOnly bool + debugStderr bool + ) + + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.Parse() + + debugStream := io.Discard + if debugStderr { + debugStream = os.Stderr + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + fmt.Fprintf(debugStream, "Login: %s\n", c.User()) + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key", err) + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection", err) + } + fmt.Printf("Listening on %v\n", listener.Addr()) + + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection", err) + } + + // Before use, a handshake must be performed on the incoming net.Conn. 
+ sconn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake", err) + } + log.Println("login detected:", sconn.User()) + fmt.Fprintf(debugStream, "SSH server established\n") + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of an SFTP session, this is "subsystem" + // with a payload string of "sftp" + fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType()) + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType()) + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatal("could not accept channel.", err) + } + fmt.Fprintf(debugStream, "Channel accepted\n") + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "subsystem" request. 
+ go func(in <-chan *ssh.Request) { + for req := range in { + fmt.Fprintf(debugStream, "Request: %v\n", req.Type) + ok := false + switch req.Type { + case "subsystem": + fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:]) + if string(req.Payload[4:]) == "sftp" { + ok = true + } + } + fmt.Fprintf(debugStream, " - accepted: %v\n", ok) + req.Reply(ok, nil) + } + }(requests) + + root := sftp.InMemHandler() + server := sftp.NewRequestServer(channel, root) + if err := server.Serve(); err != nil { + if err != io.EOF { + log.Fatal("sftp server completed with error:", err) + } + } + server.Close() + log.Print("sftp client exited session.") + } +} diff --git a/examples/streaming-read-benchmark/main.go b/examples/streaming-read-benchmark/main.go index 7ebdbd46..3f0f4893 100644 --- a/examples/streaming-read-benchmark/main.go +++ b/examples/streaming-read-benchmark/main.go @@ -23,6 +23,7 @@ var ( HOST = flag.String("host", "localhost", "ssh server hostname") PORT = flag.Int("port", 22, "ssh server port") PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") ) func init() { @@ -40,8 +41,9 @@ func main() { } config := ssh.ClientConfig{ - User: *USER, - Auth: auths, + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), } addr := fmt.Sprintf("%s:%d", *HOST, *PORT) conn, err := ssh.Dial("tcp", addr, &config) @@ -50,7 +52,7 @@ func main() { } defer conn.Close() - c, err := sftp.NewClient(conn) + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) if err != nil { log.Fatalf("unable to start sftp subsytem: %v", err) } diff --git a/examples/streaming-write-benchmark/main.go b/examples/streaming-write-benchmark/main.go index 63b27b55..2139d97a 100644 --- a/examples/streaming-write-benchmark/main.go +++ b/examples/streaming-write-benchmark/main.go @@ -23,6 +23,7 @@ var ( HOST = flag.String("host", "localhost", "ssh server hostname") PORT = flag.Int("port", 22, "ssh server port") 
PASS = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") + SIZE = flag.Int("s", 1<<15, "set max packet size") ) func init() { @@ -40,8 +41,9 @@ func main() { } config := ssh.ClientConfig{ - User: *USER, - Auth: auths, + User: *USER, + Auth: auths, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), } addr := fmt.Sprintf("%s:%d", *HOST, *PORT) conn, err := ssh.Dial("tcp", addr, &config) @@ -50,7 +52,7 @@ func main() { } defer conn.Close() - c, err := sftp.NewClient(conn) + c, err := sftp.NewClient(conn, sftp.MaxPacket(*SIZE)) if err != nil { log.Fatalf("unable to start sftp subsytem: %v", err) } diff --git a/fuzz.go b/fuzz.go new file mode 100644 index 00000000..f2f1fc31 --- /dev/null +++ b/fuzz.go @@ -0,0 +1,23 @@ +//go:build gofuzz +// +build gofuzz + +package sftp + +import "bytes" + +type sinkfuzz struct{} + +func (*sinkfuzz) Close() error { return nil } +func (*sinkfuzz) Write(p []byte) (int, error) { return len(p), nil } + +var devnull = &sinkfuzz{} + +// To run: go-fuzz-build && go-fuzz +func Fuzz(data []byte) int { + c, err := NewClientPipe(bytes.NewReader(data), devnull) + if err != nil { + return 0 + } + c.Close() + return 1 +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..ab298e17 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module github.com/pkg/sftp + +go 1.23.0 + +require ( + github.com/kr/fs v0.1.0 + github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.41.0 + golang.org/x/sys v0.35.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..7ffe0ee1 --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= 
+github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/encoding/ssh/filexfer/attrs.go b/internal/encoding/ssh/filexfer/attrs.go new file mode 100644 index 00000000..3aec937f --- /dev/null +++ b/internal/encoding/ssh/filexfer/attrs.go @@ -0,0 +1,296 @@ +package sshfx + +// Attributes related flags. 
+const ( + AttrSize = 1 << iota // SSH_FILEXFER_ATTR_SIZE + AttrUIDGID // SSH_FILEXFER_ATTR_UIDGID + AttrPermissions // SSH_FILEXFER_ATTR_PERMISSIONS + AttrACModTime // SSH_FILEXFER_ACMODTIME + + AttrExtended = 1 << 31 // SSH_FILEXFER_ATTR_EXTENDED +) + +// Attributes defines the file attributes type defined in draft-ietf-secsh-filexfer-02 +// +// Defined in: https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-5 +type Attributes struct { + Flags uint32 + + // AttrSize + Size uint64 + + // AttrUIDGID + UID uint32 + GID uint32 + + // AttrPermissions + Permissions FileMode + + // AttrACmodTime + ATime uint32 + MTime uint32 + + // AttrExtended + ExtendedAttributes []ExtendedAttribute +} + +// GetSize returns the Size field and a bool that is true if and only if the value is valid/defined. +func (a *Attributes) GetSize() (size uint64, ok bool) { + return a.Size, a.Flags&AttrSize != 0 +} + +// SetSize is a convenience function that sets the Size field, +// and marks the field as valid/defined in Flags. +func (a *Attributes) SetSize(size uint64) { + a.Flags |= AttrSize + a.Size = size +} + +// GetUIDGID returns the UID and GID fields and a bool that is true if and only if the values are valid/defined. +func (a *Attributes) GetUIDGID() (uid, gid uint32, ok bool) { + return a.UID, a.GID, a.Flags&AttrUIDGID != 0 +} + +// SetUIDGID is a convenience function that sets the UID and GID fields, +// and marks the fields as valid/defined in Flags. +func (a *Attributes) SetUIDGID(uid, gid uint32) { + a.Flags |= AttrUIDGID + a.UID = uid + a.GID = gid +} + +// GetPermissions returns the Permissions field and a bool that is true if and only if the value is valid/defined. +func (a *Attributes) GetPermissions() (perms FileMode, ok bool) { + return a.Permissions, a.Flags&AttrPermissions != 0 +} + +// SetPermissions is a convenience function that sets the Permissions field, +// and marks the field as valid/defined in Flags. 
+func (a *Attributes) SetPermissions(perms FileMode) { + a.Flags |= AttrPermissions + a.Permissions = perms +} + +// GetACModTime returns the ATime and MTime fields and a bool that is true if and only if the values are valid/defined. +func (a *Attributes) GetACModTime() (atime, mtime uint32, ok bool) { + return a.ATime, a.MTime, a.Flags&AttrACModTime != 0 +} + +// SetACModTime is a convenience function that sets the ATime and MTime fields, +// and marks the fields as valid/defined in Flags. +func (a *Attributes) SetACModTime(atime, mtime uint32) { + a.Flags |= AttrACModTime + a.ATime = atime + a.MTime = mtime +} + +// Len returns the number of bytes a would marshal into. +func (a *Attributes) Len() int { + length := 4 + + if a.Flags&AttrSize != 0 { + length += 8 + } + + if a.Flags&AttrUIDGID != 0 { + length += 4 + 4 + } + + if a.Flags&AttrPermissions != 0 { + length += 4 + } + + if a.Flags&AttrACModTime != 0 { + length += 4 + 4 + } + + if a.Flags&AttrExtended != 0 { + length += 4 + + for _, ext := range a.ExtendedAttributes { + length += ext.Len() + } + } + + return length +} + +// MarshalInto marshals e onto the end of the given Buffer. +func (a *Attributes) MarshalInto(buf *Buffer) { + buf.AppendUint32(a.Flags) + + if a.Flags&AttrSize != 0 { + buf.AppendUint64(a.Size) + } + + if a.Flags&AttrUIDGID != 0 { + buf.AppendUint32(a.UID) + buf.AppendUint32(a.GID) + } + + if a.Flags&AttrPermissions != 0 { + buf.AppendUint32(uint32(a.Permissions)) + } + + if a.Flags&AttrACModTime != 0 { + buf.AppendUint32(a.ATime) + buf.AppendUint32(a.MTime) + } + + if a.Flags&AttrExtended != 0 { + buf.AppendUint32(uint32(len(a.ExtendedAttributes))) + + for _, ext := range a.ExtendedAttributes { + ext.MarshalInto(buf) + } + } +} + +// MarshalBinary returns a as the binary encoding of a. 
+func (a *Attributes) MarshalBinary() ([]byte, error) {
+	buf := NewBuffer(make([]byte, 0, a.Len()))
+	a.MarshalInto(buf)
+	return buf.Bytes(), nil
+}
+
+// UnmarshalFrom unmarshals an Attributes from the given Buffer into a.
+//
+// NOTE: The values of fields not covered in the a.Flags are explicitly undefined.
+func (a *Attributes) UnmarshalFrom(buf *Buffer) (err error) {
+	flags := buf.ConsumeUint32()
+
+	return a.XXX_UnmarshalByFlags(flags, buf)
+}
+
+// XXX_UnmarshalByFlags uses the pre-existing a.Flags field to determine which fields to decode.
+// DO NOT USE THIS: it is an anti-corruption function to implement existing internal usage in pkg/sftp.
+// This function is not a part of any compatibility promise.
+func (a *Attributes) XXX_UnmarshalByFlags(flags uint32, buf *Buffer) (err error) {
+	a.Flags = flags
+
+	// Short-circuit dummy attributes.
+	if a.Flags == 0 {
+		return buf.Err
+	}
+
+	if a.Flags&AttrSize != 0 {
+		a.Size = buf.ConsumeUint64()
+	}
+
+	if a.Flags&AttrUIDGID != 0 {
+		a.UID = buf.ConsumeUint32()
+		a.GID = buf.ConsumeUint32()
+	}
+
+	if a.Flags&AttrPermissions != 0 {
+		a.Permissions = FileMode(buf.ConsumeUint32())
+	}
+
+	if a.Flags&AttrACModTime != 0 {
+		a.ATime = buf.ConsumeUint32()
+		a.MTime = buf.ConsumeUint32()
+	}
+
+	if a.Flags&AttrExtended != 0 {
+		count := buf.ConsumeCount()
+
+		a.ExtendedAttributes = make([]ExtendedAttribute, count)
+		for i := range a.ExtendedAttributes {
+			a.ExtendedAttributes[i].UnmarshalFrom(buf)
+		}
+	}
+
+	return buf.Err
+}
+
+// UnmarshalBinary decodes the binary encoding of Attributes into a.
+func (a *Attributes) UnmarshalBinary(data []byte) error {
+	return a.UnmarshalFrom(NewBuffer(data))
+}
+
+// ExtendedAttribute defines the extended file attribute type defined in draft-ietf-secsh-filexfer-02
+//
+// Defined in: https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-5
+type ExtendedAttribute struct {
+	Type string
+	Data string
+}
+
+// Len returns the number of bytes e would marshal into.
+func (e *ExtendedAttribute) Len() int {
+	return 4 + len(e.Type) + 4 + len(e.Data)
+}
+
+// MarshalInto marshals e onto the end of the given Buffer.
+func (e *ExtendedAttribute) MarshalInto(buf *Buffer) {
+	buf.AppendString(e.Type)
+	buf.AppendString(e.Data)
+}
+
+// MarshalBinary returns e as the binary encoding of e.
+func (e *ExtendedAttribute) MarshalBinary() ([]byte, error) {
+	buf := NewBuffer(make([]byte, 0, e.Len()))
+	e.MarshalInto(buf)
+	return buf.Bytes(), nil
+}
+
+// UnmarshalFrom unmarshals an ExtendedAttribute from the given Buffer into e.
+func (e *ExtendedAttribute) UnmarshalFrom(buf *Buffer) (err error) {
+	*e = ExtendedAttribute{
+		Type: buf.ConsumeString(),
+		Data: buf.ConsumeString(),
+	}
+
+	return buf.Err
+}
+
+// UnmarshalBinary decodes the binary encoding of ExtendedAttribute into e.
+func (e *ExtendedAttribute) UnmarshalBinary(data []byte) error {
+	return e.UnmarshalFrom(NewBuffer(data))
+}
+
+// NameEntry implements the SSH_FXP_NAME repeated data type from draft-ietf-secsh-filexfer-02
+//
+// This type is incompatible with versions 4 or higher.
+type NameEntry struct {
+	Filename string
+	Longname string
+	Attrs    Attributes
+}
+
+// Len returns the number of bytes e would marshal into.
+func (e *NameEntry) Len() int {
+	return 4 + len(e.Filename) + 4 + len(e.Longname) + e.Attrs.Len()
+}
+
+// MarshalInto marshals e onto the end of the given Buffer.
+func (e *NameEntry) MarshalInto(buf *Buffer) { + buf.AppendString(e.Filename) + buf.AppendString(e.Longname) + + e.Attrs.MarshalInto(buf) +} + +// MarshalBinary returns e as the binary encoding of e. +func (e *NameEntry) MarshalBinary() ([]byte, error) { + buf := NewBuffer(make([]byte, 0, e.Len())) + e.MarshalInto(buf) + return buf.Bytes(), nil +} + +// UnmarshalFrom unmarshals an NameEntry from the given Buffer into e. +// +// NOTE: The values of fields not covered in the a.Flags are explicitly undefined. +func (e *NameEntry) UnmarshalFrom(buf *Buffer) (err error) { + *e = NameEntry{ + Filename: buf.ConsumeString(), + Longname: buf.ConsumeString(), + } + + return e.Attrs.UnmarshalFrom(buf) +} + +// UnmarshalBinary decodes the binary encoding of NameEntry into e. +func (e *NameEntry) UnmarshalBinary(data []byte) error { + return e.UnmarshalFrom(NewBuffer(data)) +} diff --git a/internal/encoding/ssh/filexfer/attrs_test.go b/internal/encoding/ssh/filexfer/attrs_test.go new file mode 100644 index 00000000..1e620faa --- /dev/null +++ b/internal/encoding/ssh/filexfer/attrs_test.go @@ -0,0 +1,231 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +func TestAttributes(t *testing.T) { + const ( + size uint64 = 0x123456789ABCDEF0 + uid = 1000 + gid = 100 + perms FileMode = 0x87654321 + atime = 0x2A2B2C2D + mtime = 0x42434445 + ) + + extAttr := ExtendedAttribute{ + Type: "foo", + Data: "bar", + } + + attr := &Attributes{ + Size: size, + UID: uid, + GID: gid, + Permissions: perms, + ATime: atime, + MTime: mtime, + ExtendedAttributes: []ExtendedAttribute{ + extAttr, + }, + } + + type test struct { + name string + flags uint32 + encoded []byte + } + + tests := []test{ + { + name: "empty", + encoded: []byte{ + 0x00, 0x00, 0x00, 0x00, + }, + }, + { + name: "size", + flags: AttrSize, + encoded: []byte{ + 0x00, 0x00, 0x00, 0x01, + 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, + }, + }, + { + name: "uidgid", + flags: AttrUIDGID, + encoded: []byte{ + 0x00, 0x00, 0x00, 0x02, 
+ 0x00, 0x00, 0x03, 0xE8, + 0x00, 0x00, 0x00, 100, + }, + }, + { + name: "permissions", + flags: AttrPermissions, + encoded: []byte{ + 0x00, 0x00, 0x00, 0x04, + 0x87, 0x65, 0x43, 0x21, + }, + }, + { + name: "acmodtime", + flags: AttrACModTime, + encoded: []byte{ + 0x00, 0x00, 0x00, 0x08, + 0x2A, 0x2B, 0x2C, 0x2D, + 0x42, 0x43, 0x44, 0x45, + }, + }, + { + name: "extended", + flags: AttrExtended, + encoded: []byte{ + 0x80, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r', + }, + }, + { + name: "size uidgid permisssions acmodtime extended", + flags: AttrSize | AttrUIDGID | AttrPermissions | AttrACModTime | AttrExtended, + encoded: []byte{ + 0x80, 0x00, 0x00, 0x0F, + 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, + 0x00, 0x00, 0x03, 0xE8, + 0x00, 0x00, 0x00, 100, + 0x87, 0x65, 0x43, 0x21, + 0x2A, 0x2B, 0x2C, 0x2D, + 0x42, 0x43, 0x44, 0x45, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r', + }, + }, + } + + for _, tt := range tests { + attr := *attr + + t.Run(tt.name, func(t *testing.T) { + attr.Flags = tt.flags + + buf, err := attr.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + if !bytes.Equal(buf, tt.encoded) { + t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, tt.encoded) + } + + attr = Attributes{} + + if err := attr.UnmarshalBinary(buf); err != nil { + t.Fatal("unexpected error:", err) + } + + if attr.Flags != tt.flags { + t.Errorf("UnmarshalBinary(): Flags was %x, but wanted %x", attr.Flags, tt.flags) + } + + if attr.Flags&AttrSize != 0 && attr.Size != size { + t.Errorf("UnmarshalBinary(): Size was %x, but wanted %x", attr.Size, size) + } + + if attr.Flags&AttrUIDGID != 0 { + if attr.UID != uid { + t.Errorf("UnmarshalBinary(): UID was %x, but wanted %x", attr.UID, uid) + } + + if attr.GID != gid { + t.Errorf("UnmarshalBinary(): GID was %x, but wanted %x", attr.GID, gid) + } + } + + if 
attr.Flags&AttrPermissions != 0 && attr.Permissions != perms { + t.Errorf("UnmarshalBinary(): Permissions was %#v, but wanted %#v", attr.Permissions, perms) + } + + if attr.Flags&AttrACModTime != 0 { + if attr.ATime != atime { + t.Errorf("UnmarshalBinary(): ATime was %x, but wanted %x", attr.ATime, atime) + } + + if attr.MTime != mtime { + t.Errorf("UnmarshalBinary(): MTime was %x, but wanted %x", attr.MTime, mtime) + } + } + + if attr.Flags&AttrExtended != 0 { + extAttrs := attr.ExtendedAttributes + + if count := len(extAttrs); count != 1 { + t.Fatalf("UnmarshalBinary(): len(ExtendedAttributes) was %d, but wanted %d", count, 1) + } + + if got := extAttrs[0]; got != extAttr { + t.Errorf("UnmarshalBinary(): ExtendedAttributes[0] was %#v, but wanted %#v", got, extAttr) + } + } + }) + } +} + +func TestNameEntry(t *testing.T) { + const ( + filename = "foo" + longname = "bar" + perms FileMode = 0x87654321 + ) + + e := &NameEntry{ + Filename: filename, + Longname: longname, + Attrs: Attributes{ + Flags: AttrPermissions, + Permissions: perms, + }, + } + + buf, err := e.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 0x03, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x03, 'b', 'a', 'r', + 0x00, 0x00, 0x00, 0x04, + 0x87, 0x65, 0x43, 0x21, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want) + } + + *e = NameEntry{} + + if err := e.UnmarshalBinary(buf); err != nil { + t.Fatal("unexpected error:", err) + } + + if e.Filename != filename { + t.Errorf("UnmarhsalFrom(): Filename was %q, but expected %q", e.Filename, filename) + } + + if e.Longname != longname { + t.Errorf("UnmarhsalFrom(): Longname was %q, but expected %q", e.Longname, longname) + } + + if e.Attrs.Flags != AttrPermissions { + t.Errorf("UnmarshalBinary(): Attrs.Flag was %#x, but expected %#x", e.Attrs.Flags, AttrPermissions) + } + + if e.Attrs.Permissions != perms { + t.Errorf("UnmarshalBinary(): Attrs.Permissions 
was %#v, but expected %#v", e.Attrs.Permissions, perms) + } +} diff --git a/internal/encoding/ssh/filexfer/buffer.go b/internal/encoding/ssh/filexfer/buffer.go new file mode 100644 index 00000000..bd4783bb --- /dev/null +++ b/internal/encoding/ssh/filexfer/buffer.go @@ -0,0 +1,340 @@ +package sshfx + +import ( + "encoding/binary" + "errors" +) + +// Various encoding errors. +var ( + ErrShortPacket = errors.New("packet too short") + ErrLongPacket = errors.New("packet too long") +) + +// Buffer wraps up the various encoding details of the SSH format. +// +// Data types are encoded as per section 4 from https://tools.ietf.org/html/draft-ietf-secsh-architecture-09#page-8 +type Buffer struct { + b []byte + off int + Err error +} + +// NewBuffer creates and initializes a new buffer using buf as its initial contents. +// The new buffer takes ownership of buf, and the caller should not use buf after this call. +// +// In most cases, new(Buffer) (or just declaring a Buffer variable) is sufficient to initialize a Buffer. +func NewBuffer(buf []byte) *Buffer { + return &Buffer{ + b: buf, + } +} + +// NewMarshalBuffer creates a new Buffer ready to start marshaling a Packet into. +// It preallocates enough space for uint32(length), uint8(type), uint32(request-id) and size more bytes. +func NewMarshalBuffer(size int) *Buffer { + return NewBuffer(make([]byte, 4+1+4+size)) +} + +// Bytes returns a slice of length b.Len() holding the unconsumed bytes in the Buffer. +// The slice is valid for use only until the next buffer modification +// (that is, only until the next call to an Append or Consume method). +func (b *Buffer) Bytes() []byte { + return b.b[b.off:] +} + +// Len returns the number of unconsumed bytes in the buffer. +func (b *Buffer) Len() int { return len(b.b) - b.off } + +// Cap returns the capacity of the buffer’s underlying byte slice, +// that is, the total space allocated for the buffer’s data. 
+func (b *Buffer) Cap() int { return cap(b.b) } + +// Reset resets the buffer to be empty, but it retains the underlying storage for use by future Appends. +func (b *Buffer) Reset() { + *b = Buffer{ + b: b.b[:0], + } +} + +// StartPacket resets and initializes the buffer to be ready to start marshaling a packet into. +// It truncates the buffer, reserves space for uint32(length), then appends the given packetType and requestID. +func (b *Buffer) StartPacket(packetType PacketType, requestID uint32) { + *b = Buffer{ + b: append(b.b[:0], make([]byte, 4)...), + } + + b.AppendUint8(uint8(packetType)) + b.AppendUint32(requestID) +} + +// Packet finalizes the packet started from StartPacket. +// It is expected that this will end the ownership of the underlying byte-slice, +// and so the returned byte-slices may be reused the same as any other byte-slice, +// the caller should not use this buffer after this call. +// +// It writes the packet body length into the first four bytes of the buffer in network byte order (big endian). +// The packet body length is the length of this buffer less the 4-byte length itself, plus the length of payload. +// +// It is assumed that no Consume methods have been called on this buffer, +// and so it returns the whole underlying slice. +func (b *Buffer) Packet(payload []byte) (header, payloadPassThru []byte, err error) { + b.PutLength(len(b.b) - 4 + len(payload)) + + return b.b, payload, nil +} + +// ConsumeUint8 consumes a single byte from the buffer. +// If the buffer does not have enough data, it will set Err to ErrShortPacket. +func (b *Buffer) ConsumeUint8() uint8 { + if b.Err != nil { + return 0 + } + + if b.Len() < 1 { + b.off = len(b.b) + b.Err = ErrShortPacket + return 0 + } + + var v uint8 + v, b.off = b.b[b.off], b.off+1 + return v +} + +// AppendUint8 appends a single byte into the buffer. 
+func (b *Buffer) AppendUint8(v uint8) { + b.b = append(b.b, v) +} + +// ConsumeBool consumes a single byte from the buffer, and returns true if that byte is non-zero. +// If the buffer does not have enough data, it will set Err to ErrShortPacket. +func (b *Buffer) ConsumeBool() bool { + return b.ConsumeUint8() != 0 +} + +// AppendBool appends a single bool into the buffer. +// It encodes it as a single byte, with false as 0, and true as 1. +func (b *Buffer) AppendBool(v bool) { + if v { + b.AppendUint8(1) + } else { + b.AppendUint8(0) + } +} + +// ConsumeUint16 consumes a single uint16 from the buffer, in network byte order (big-endian). +// If the buffer does not have enough data, it will set Err to ErrShortPacket. +func (b *Buffer) ConsumeUint16() uint16 { + if b.Err != nil { + return 0 + } + + if b.Len() < 2 { + b.off = len(b.b) + b.Err = ErrShortPacket + return 0 + } + + v := binary.BigEndian.Uint16(b.b[b.off:]) + b.off += 2 + return v +} + +// AppendUint16 appends single uint16 into the buffer, in network byte order (big-endian). +func (b *Buffer) AppendUint16(v uint16) { + b.b = append(b.b, + byte(v>>8), + byte(v>>0), + ) +} + +// unmarshalUint32 is used internally to read the packet length. +// It is unsafe, and so not exported. +// Even within this package, its use should be avoided. +func unmarshalUint32(b []byte) uint32 { + return binary.BigEndian.Uint32(b[:4]) +} + +// ConsumeUint32 consumes a single uint32 from the buffer, in network byte order (big-endian). +// If the buffer does not have enough data, it will set Err to ErrShortPacket. +func (b *Buffer) ConsumeUint32() uint32 { + if b.Err != nil { + return 0 + } + + if b.Len() < 4 { + b.off = len(b.b) + b.Err = ErrShortPacket + return 0 + } + + v := binary.BigEndian.Uint32(b.b[b.off:]) + b.off += 4 + return v +} + +// AppendUint32 appends a single uint32 into the buffer, in network byte order (big-endian). 
+func (b *Buffer) AppendUint32(v uint32) { + b.b = append(b.b, + byte(v>>24), + byte(v>>16), + byte(v>>8), + byte(v>>0), + ) +} + +// ConsumeCount consumes a single uint32 count from the buffer, in network byte order (big-endian) as an int. +// If the buffer does not have enough data, it will set Err to ErrShortPacket. +func (b *Buffer) ConsumeCount() int { + return int(b.ConsumeUint32()) +} + +// AppendCount appends a single int length as a uint32 into the buffer, in network byte order (big-endian). +func (b *Buffer) AppendCount(v int) { + b.AppendUint32(uint32(v)) +} + +// ConsumeUint64 consumes a single uint64 from the buffer, in network byte order (big-endian). +// If the buffer does not have enough data, it will set Err to ErrShortPacket. +func (b *Buffer) ConsumeUint64() uint64 { + if b.Err != nil { + return 0 + } + + if b.Len() < 8 { + b.off = len(b.b) + b.Err = ErrShortPacket + return 0 + } + + v := binary.BigEndian.Uint64(b.b[b.off:]) + b.off += 8 + return v +} + +// AppendUint64 appends a single uint64 into the buffer, in network byte order (big-endian). +func (b *Buffer) AppendUint64(v uint64) { + b.b = append(b.b, + byte(v>>56), + byte(v>>48), + byte(v>>40), + byte(v>>32), + byte(v>>24), + byte(v>>16), + byte(v>>8), + byte(v>>0), + ) +} + +// ConsumeInt64 consumes a single int64 from the buffer, in network byte order (big-endian) with two’s complement. +// If the buffer does not have enough data, it will set Err to ErrShortPacket. +func (b *Buffer) ConsumeInt64() int64 { + return int64(b.ConsumeUint64()) +} + +// AppendInt64 appends a single int64 into the buffer, in network byte order (big-endian) with two’s complement. +func (b *Buffer) AppendInt64(v int64) { + b.AppendUint64(uint64(v)) +} + +// ConsumeByteSlice consumes a single string of raw binary data from the buffer. +// A string is a uint32 length, followed by that number of raw bytes. 
+// If the buffer does not have enough data, or defines a length larger than available, it will set Err to ErrShortPacket. +// +// The returned slice aliases the buffer contents, and is valid only as long as the buffer is not reused +// (that is, only until the next call to Reset, PutLength, StartPacket, or UnmarshalBinary). +// +// In no case will any Consume calls return overlapping slice aliases, +// and Append calls are guaranteed to not disturb this slice alias. +func (b *Buffer) ConsumeByteSlice() []byte { + length := int(b.ConsumeUint32()) + if b.Err != nil { + return nil + } + + if b.Len() < length || length < 0 { + b.off = len(b.b) + b.Err = ErrShortPacket + return nil + } + + v := b.b[b.off:] + if len(v) > length || cap(v) > length { + v = v[:length:length] + } + b.off += int(length) + return v +} + +// ConsumeByteSliceCopy consumes a single string of raw binary data as a copy from the buffer. +// A string is a uint32 length, followed by that number of raw bytes. +// If the buffer does not have enough data, or defines a length larger than available, it will set Err to ErrShortPacket. +// +// The returned slice does not alias any buffer contents, +// and will therefore be valid even if the buffer is later reused. +// +// If hint has sufficient capacity to hold the data, it will be reused and overwritten, +// otherwise a new backing slice will be allocated and returned. +func (b *Buffer) ConsumeByteSliceCopy(hint []byte) []byte { + data := b.ConsumeByteSlice() + + if grow := len(data) - len(hint); grow > 0 { + hint = append(hint, make([]byte, grow)...) + } + + n := copy(hint, data) + hint = hint[:n] + return hint +} + +// AppendByteSlice appends a single string of raw binary data into the buffer. +// A string is a uint32 length, followed by that number of raw bytes. +func (b *Buffer) AppendByteSlice(v []byte) { + b.AppendUint32(uint32(len(v))) + b.b = append(b.b, v...) +} + +// ConsumeString consumes a single string of binary data from the buffer. 
+// A string is a uint32 length, followed by that number of raw bytes. +// If the buffer does not have enough data, or defines a length larger than available, it will set Err to ErrShortPacket. +// +// NOTE: Go implicitly assumes that strings contain UTF-8 encoded data. +// All caveats on using arbitrary binary data in Go strings applies. +func (b *Buffer) ConsumeString() string { + return string(b.ConsumeByteSlice()) +} + +// AppendString appends a single string of binary data into the buffer. +// A string is a uint32 length, followed by that number of raw bytes. +func (b *Buffer) AppendString(v string) { + b.AppendByteSlice([]byte(v)) +} + +// PutLength writes the given size into the first four bytes of the buffer in network byte order (big endian). +func (b *Buffer) PutLength(size int) { + if len(b.b) < 4 { + b.b = append(b.b, make([]byte, 4-len(b.b))...) + } + + binary.BigEndian.PutUint32(b.b, uint32(size)) +} + +// MarshalBinary returns a clone of the full internal buffer. +func (b *Buffer) MarshalBinary() ([]byte, error) { + clone := make([]byte, len(b.b)) + n := copy(clone, b.b) + return clone[:n], nil +} + +// UnmarshalBinary sets the internal buffer of b to be a clone of data, and zeros the internal offset. +func (b *Buffer) UnmarshalBinary(data []byte) error { + if grow := len(data) - len(b.b); grow > 0 { + b.b = append(b.b, make([]byte, grow)...) + } + + n := copy(b.b, data) + b.b = b.b[:n] + b.off = 0 + return nil +} diff --git a/internal/encoding/ssh/filexfer/extended_packets.go b/internal/encoding/ssh/filexfer/extended_packets.go new file mode 100644 index 00000000..f7174253 --- /dev/null +++ b/internal/encoding/ssh/filexfer/extended_packets.go @@ -0,0 +1,143 @@ +package sshfx + +import ( + "encoding" + "sync" +) + +// ExtendedData aliases the untyped interface composition of encoding.BinaryMarshaler and encoding.BinaryUnmarshaler. 
+type ExtendedData = interface {
+	encoding.BinaryMarshaler
+	encoding.BinaryUnmarshaler
+}
+
+// ExtendedDataConstructor defines a function that returns a new(ArbitraryExtendedPacket).
+type ExtendedDataConstructor func() ExtendedData
+
+var extendedPacketTypes = struct {
+	mu           sync.RWMutex
+	constructors map[string]ExtendedDataConstructor
+}{
+	constructors: make(map[string]ExtendedDataConstructor),
+}
+
+// RegisterExtendedPacketType defines a specific ExtendedDataConstructor for the given extension string.
+func RegisterExtendedPacketType(extension string, constructor ExtendedDataConstructor) {
+	extendedPacketTypes.mu.Lock()
+	defer extendedPacketTypes.mu.Unlock()
+
+	if _, exist := extendedPacketTypes.constructors[extension]; exist {
+		panic("encoding/ssh/filexfer: multiple registration of extended packet type " + extension)
+	}
+
+	extendedPacketTypes.constructors[extension] = constructor
+}
+
+func newExtendedPacket(extension string) ExtendedData {
+	extendedPacketTypes.mu.RLock()
+	defer extendedPacketTypes.mu.RUnlock()
+
+	if f := extendedPacketTypes.constructors[extension]; f != nil {
+		return f()
+	}
+
+	return new(Buffer)
+}
+
+// ExtendedPacket defines the SSH_FXP_EXTENDED packet.
+type ExtendedPacket struct {
+	ExtendedRequest string
+
+	Data ExtendedData
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ExtendedPacket) Type() PacketType {
+	return PacketTypeExtended
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+//
+// The Data is marshaled into binary, and returned as the payload.
+func (p *ExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+	buf := NewBuffer(b)
+	if buf.Cap() < 9 {
+		size := 4 + len(p.ExtendedRequest) // string(extended-request)
+		buf = NewMarshalBuffer(size)
+	}
+
+	buf.StartPacket(PacketTypeExtended, reqid)
+	buf.AppendString(p.ExtendedRequest)
+
+	if p.Data != nil {
+		payload, err = p.Data.MarshalBinary()
+		if err != nil {
+			return nil, nil, err
+		}
+	}
+
+	return buf.Packet(payload)
+}
+
+// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
+// It is assumed that the uint32(request-id) has already been consumed.
+//
+// If p.Data is nil, and the extension has been registered, a new type will be made from the registration.
+// If the extension has not been registered, then a new Buffer will be allocated.
+// Then the request-specific-data will be unmarshaled from the rest of the buffer.
+func (p *ExtendedPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
+	p.ExtendedRequest = buf.ConsumeString()
+	if buf.Err != nil {
+		return buf.Err
+	}
+
+	if p.Data == nil {
+		p.Data = newExtendedPacket(p.ExtendedRequest)
+	}
+
+	return p.Data.UnmarshalBinary(buf.Bytes())
+}
+
+// ExtendedReplyPacket defines the SSH_FXP_EXTENDED_REPLY packet.
+type ExtendedReplyPacket struct {
+	Data ExtendedData
+}
+
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ExtendedReplyPacket) Type() PacketType {
+	return PacketTypeExtendedReply
+}
+
+// MarshalPacket returns p as a two-part binary encoding of p.
+//
+// The Data is marshaled into binary, and returned as the payload.
+func (p *ExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + buf = NewMarshalBuffer(0) + } + + buf.StartPacket(PacketTypeExtendedReply, reqid) + + if p.Data != nil { + payload, err = p.Data.MarshalBinary() + if err != nil { + return nil, nil, err + } + } + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +// +// If p.Data is nil, and there is request-specific-data, +// then the request-specific-data will be wrapped in a Buffer and assigned to p.Data. +func (p *ExtendedReplyPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + if p.Data == nil { + p.Data = new(Buffer) + } + + return p.Data.UnmarshalBinary(buf.Bytes()) +} diff --git a/internal/encoding/ssh/filexfer/extended_packets_test.go b/internal/encoding/ssh/filexfer/extended_packets_test.go new file mode 100644 index 00000000..d7a54d3c --- /dev/null +++ b/internal/encoding/ssh/filexfer/extended_packets_test.go @@ -0,0 +1,240 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +type testExtendedData struct { + value uint8 +} + +func (d *testExtendedData) MarshalBinary() ([]byte, error) { + buf := NewBuffer(make([]byte, 0, 4)) + + buf.AppendUint8(d.value ^ 0x2a) + + return buf.Bytes(), nil +} + +func (d *testExtendedData) UnmarshalBinary(data []byte) error { + buf := NewBuffer(data) + + v := buf.ConsumeUint8() + if buf.Err != nil { + return buf.Err + } + + d.value = v ^ 0x2a + + return nil +} + +var _ Packet = &ExtendedPacket{} + +func TestExtendedPacketNoData(t *testing.T) { + const ( + id = 42 + extendedRequest = "foo@example" + ) + + p := &ExtendedPacket{ + ExtendedRequest: extendedRequest, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 20, + 200, + 0x00, 0x00, 0x00, 42, + 
0x00, 0x00, 0x00, 11, 'f', 'o', 'o', '@', 'e', 'x', 'a', 'm', 'p', 'l', 'e', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ExtendedPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extendedRequest { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest) + } +} + +func TestExtendedPacketTestData(t *testing.T) { + const ( + id = 42 + extendedRequest = "foo@example" + textValue = 13 + ) + + const value = 13 + + p := &ExtendedPacket{ + ExtendedRequest: extendedRequest, + Data: &testExtendedData{ + value: textValue, + }, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 21, + 200, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 11, 'f', 'o', 'o', '@', 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x27, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ExtendedPacket{ + Data: new(testExtendedData), + } + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extendedRequest { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest) + } + + if buf, ok := p.Data.(*testExtendedData); !ok { + t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) + + } else if buf.value != value { + t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", buf.value, value) + } + + *p = ExtendedPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extendedRequest { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest) + } + + wantBuffer := []byte{0x27} + + if buf, ok := p.Data.(*Buffer); !ok { + t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) + + } else if !bytes.Equal(buf.b, wantBuffer) { + t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", buf.b, wantBuffer) + } +} + +var _ Packet = &ExtendedReplyPacket{} + +func TestExtendedReplyNoData(t *testing.T) { + const ( + id = 42 + ) + + p := &ExtendedReplyPacket{} + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 5, + 201, + 0x00, 0x00, 0x00, 42, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ExtendedReplyPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } +} + +func TestExtendedReplyPacketTestData(t *testing.T) { + const ( + id = 42 + textValue = 13 + ) + + const value = 13 + + p := &ExtendedReplyPacket{ + Data: &testExtendedData{ + value: textValue, + }, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 6, + 201, + 0x00, 0x00, 0x00, 42, + 0x27, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ExtendedReplyPacket{ + Data: new(testExtendedData), + } + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if buf, ok := p.Data.(*testExtendedData); !ok { + t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) + + } else if buf.value != value { + t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", buf.value, value) + } + + *p = ExtendedReplyPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + wantBuffer := []byte{0x27} + + if buf, ok := p.Data.(*Buffer); !ok { + t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf) + + } else if !bytes.Equal(buf.b, wantBuffer) { + t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", buf.b, wantBuffer) + } +} diff --git a/internal/encoding/ssh/filexfer/extensions.go b/internal/encoding/ssh/filexfer/extensions.go new file mode 100644 index 00000000..c425780c --- /dev/null +++ b/internal/encoding/ssh/filexfer/extensions.go @@ -0,0 +1,43 @@ +package sshfx + +// ExtensionPair defines the extension-pair type defined in draft-ietf-secsh-filexfer-13. +// This type is backwards-compatible with how draft-ietf-secsh-filexfer-02 defines extensions. +// +// Defined in: https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-4.2 +type ExtensionPair struct { + Name string + Data string +} + +// Len returns the number of bytes e would marshal into. +func (e *ExtensionPair) Len() int { + return 4 + len(e.Name) + 4 + len(e.Data) +} + +// MarshalInto marshals e onto the end of the given Buffer. +func (e *ExtensionPair) MarshalInto(buf *Buffer) { + buf.AppendString(e.Name) + buf.AppendString(e.Data) +} + +// MarshalBinary returns e as the binary encoding of e. +func (e *ExtensionPair) MarshalBinary() ([]byte, error) { + buf := NewBuffer(make([]byte, 0, e.Len())) + e.MarshalInto(buf) + return buf.Bytes(), nil +} + +// UnmarshalFrom unmarshals an ExtensionPair from the given Buffer into e. +func (e *ExtensionPair) UnmarshalFrom(buf *Buffer) (err error) { + *e = ExtensionPair{ + Name: buf.ConsumeString(), + Data: buf.ConsumeString(), + } + + return buf.Err +} + +// UnmarshalBinary decodes the binary encoding of ExtensionPair into e. 
+func (e *ExtensionPair) UnmarshalBinary(data []byte) error { + return e.UnmarshalFrom(NewBuffer(data)) +} diff --git a/internal/encoding/ssh/filexfer/extensions_test.go b/internal/encoding/ssh/filexfer/extensions_test.go new file mode 100644 index 00000000..15f2a792 --- /dev/null +++ b/internal/encoding/ssh/filexfer/extensions_test.go @@ -0,0 +1,49 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +func TestExtensionPair(t *testing.T) { + const ( + name = "foo" + data = "1" + ) + + pair := &ExtensionPair{ + Name: name, + Data: data, + } + + buf, err := pair.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 3, + 'f', 'o', 'o', + 0x00, 0x00, 0x00, 1, + '1', + } + + if !bytes.Equal(buf, want) { + t.Errorf("ExtensionPair.MarshalBinary() = %X, but wanted %X", buf, want) + } + + *pair = ExtensionPair{} + + if err := pair.UnmarshalBinary(buf); err != nil { + t.Fatal("unexpected error:", err) + } + + if pair.Name != name { + t.Errorf("ExtensionPair.UnmarshalBinary(): Name was %q, but expected %q", pair.Name, name) + } + + if pair.Data != data { + t.Errorf("RawPacket.UnmarshalBinary(): Data was %q, but expected %q", pair.Data, data) + } + +} diff --git a/internal/encoding/ssh/filexfer/filexfer.go b/internal/encoding/ssh/filexfer/filexfer.go new file mode 100644 index 00000000..d3009994 --- /dev/null +++ b/internal/encoding/ssh/filexfer/filexfer.go @@ -0,0 +1,54 @@ +// Package sshfx implements the wire encoding for secsh-filexfer as described in https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt +package sshfx + +// PacketMarshaller narrowly defines packets that will only be transmitted. +// +// ExtendedPacket types will often only implement this interface, +// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field. +type PacketMarshaller interface { + // MarshalPacket is the primary intended way to encode a packet. 
+	// The request-id for the packet is set from reqid.
+	//
+	// An optional buffer may be given in b.
+	// If the buffer has a minimum capacity, it shall be truncated and used to marshal the header into.
+	// The minimum capacity for the packet must be a constant expression, and should be at least 9.
+	//
+	// It shall return the main body of the encoded packet in header,
+	// and may optionally return an additional payload to be written immediately after the header.
+	//
+	// It shall encode in the first 4-bytes of the header the proper length of the rest of the header+payload.
+	MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error)
+}
+
+// Packet defines the behavior of a full generic SFTP packet.
+//
+// InitPacket, and VersionPacket are not generic SFTP packets, and instead implement (Un)MarshalBinary.
+//
+// ExtendedPacket types should not implement this interface,
+// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field.
+type Packet interface {
+	PacketMarshaller
+
+	// Type returns the SSH_FXP_xy value associated with the specific packet.
+	Type() PacketType
+
+	// UnmarshalPacketBody decodes a packet body from the given Buffer.
+	// It is assumed that the common header values of the length, type and request-id have already been consumed.
+	//
+	// Implementations should not alias the given Buffer,
+	// instead they can consider prepopulating an internal buffer as a hint,
+	// and copying into that buffer if it has sufficient length.
+	UnmarshalPacketBody(buf *Buffer) error
+}
+
+// ComposePacket converts returns from MarshalPacket into an equivalent call to MarshalBinary.
+func ComposePacket(header, payload []byte, err error) ([]byte, error) {
+	return append(header, payload...), err
+}
+
+// Default length values,
+// Defined in draft-ietf-secsh-filexfer-02 section 3.
+const ( + DefaultMaxPacketLength = 34000 + DefaultMaxDataLength = 32768 +) diff --git a/internal/encoding/ssh/filexfer/fx.go b/internal/encoding/ssh/filexfer/fx.go new file mode 100644 index 00000000..9abcbafc --- /dev/null +++ b/internal/encoding/ssh/filexfer/fx.go @@ -0,0 +1,147 @@ +package sshfx + +import ( + "fmt" +) + +// Status defines the SFTP error codes used in SSH_FXP_STATUS response packets. +type Status uint32 + +// Defines the various SSH_FX_* values. +const ( + // see draft-ietf-secsh-filexfer-02 + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-7 + StatusOK = Status(iota) + StatusEOF + StatusNoSuchFile + StatusPermissionDenied + StatusFailure + StatusBadMessage + StatusNoConnection + StatusConnectionLost + StatusOPUnsupported + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-03.txt#section-7 + StatusV4InvalidHandle + StatusV4NoSuchPath + StatusV4FileAlreadyExists + StatusV4WriteProtect + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-04.txt#section-7 + StatusV4NoMedia + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-05.txt#section-7 + StatusV5NoSpaceOnFilesystem + StatusV5QuotaExceeded + StatusV5UnknownPrincipal + StatusV5LockConflict + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-06.txt#section-8 + StatusV6DirNotEmpty + StatusV6NotADirectory + StatusV6InvalidFilename + StatusV6LinkLoop + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-07.txt#section-8 + StatusV6CannotDelete + StatusV6InvalidParameter + StatusV6FileIsADirectory + StatusV6ByteRangeLockConflict + StatusV6ByteRangeLockRefused + StatusV6DeletePending + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-08.txt#section-8.1 + StatusV6FileCorrupt + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-10.txt#section-9.1 + StatusV6OwnerInvalid + StatusV6GroupInvalid + + // 
https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1 + StatusV6NoMatchingByteRangeLock +) + +func (s Status) Error() string { + return s.String() +} + +// Is returns true if the target is the same Status code, +// or target is a StatusPacket with the same Status code. +func (s Status) Is(target error) bool { + if target, ok := target.(*StatusPacket); ok { + return target.StatusCode == s + } + + return s == target +} + +func (s Status) String() string { + switch s { + case StatusOK: + return "SSH_FX_OK" + case StatusEOF: + return "SSH_FX_EOF" + case StatusNoSuchFile: + return "SSH_FX_NO_SUCH_FILE" + case StatusPermissionDenied: + return "SSH_FX_PERMISSION_DENIED" + case StatusFailure: + return "SSH_FX_FAILURE" + case StatusBadMessage: + return "SSH_FX_BAD_MESSAGE" + case StatusNoConnection: + return "SSH_FX_NO_CONNECTION" + case StatusConnectionLost: + return "SSH_FX_CONNECTION_LOST" + case StatusOPUnsupported: + return "SSH_FX_OP_UNSUPPORTED" + case StatusV4InvalidHandle: + return "SSH_FX_INVALID_HANDLE" + case StatusV4NoSuchPath: + return "SSH_FX_NO_SUCH_PATH" + case StatusV4FileAlreadyExists: + return "SSH_FX_FILE_ALREADY_EXISTS" + case StatusV4WriteProtect: + return "SSH_FX_WRITE_PROTECT" + case StatusV4NoMedia: + return "SSH_FX_NO_MEDIA" + case StatusV5NoSpaceOnFilesystem: + return "SSH_FX_NO_SPACE_ON_FILESYSTEM" + case StatusV5QuotaExceeded: + return "SSH_FX_QUOTA_EXCEEDED" + case StatusV5UnknownPrincipal: + return "SSH_FX_UNKNOWN_PRINCIPAL" + case StatusV5LockConflict: + return "SSH_FX_LOCK_CONFLICT" + case StatusV6DirNotEmpty: + return "SSH_FX_DIR_NOT_EMPTY" + case StatusV6NotADirectory: + return "SSH_FX_NOT_A_DIRECTORY" + case StatusV6InvalidFilename: + return "SSH_FX_INVALID_FILENAME" + case StatusV6LinkLoop: + return "SSH_FX_LINK_LOOP" + case StatusV6CannotDelete: + return "SSH_FX_CANNOT_DELETE" + case StatusV6InvalidParameter: + return "SSH_FX_INVALID_PARAMETER" + case StatusV6FileIsADirectory: + return "SSH_FX_FILE_IS_A_DIRECTORY" + case 
StatusV6ByteRangeLockConflict: + return "SSH_FX_BYTE_RANGE_LOCK_CONFLICT" + case StatusV6ByteRangeLockRefused: + return "SSH_FX_BYTE_RANGE_LOCK_REFUSED" + case StatusV6DeletePending: + return "SSH_FX_DELETE_PENDING" + case StatusV6FileCorrupt: + return "SSH_FX_FILE_CORRUPT" + case StatusV6OwnerInvalid: + return "SSH_FX_OWNER_INVALID" + case StatusV6GroupInvalid: + return "SSH_FX_GROUP_INVALID" + case StatusV6NoMatchingByteRangeLock: + return "SSH_FX_NO_MATCHING_BYTE_RANGE_LOCK" + default: + return fmt.Sprintf("SSH_FX_UNKNOWN(%d)", s) + } +} diff --git a/internal/encoding/ssh/filexfer/fx_test.go b/internal/encoding/ssh/filexfer/fx_test.go new file mode 100644 index 00000000..95480591 --- /dev/null +++ b/internal/encoding/ssh/filexfer/fx_test.go @@ -0,0 +1,102 @@ +package sshfx + +import ( + "bufio" + "errors" + "regexp" + "strconv" + "strings" + "testing" +) + +// This string data is copied verbatim from https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-13.txt +var fxStandardsText = ` +SSH_FX_OK 0 +SSH_FX_EOF 1 +SSH_FX_NO_SUCH_FILE 2 +SSH_FX_PERMISSION_DENIED 3 +SSH_FX_FAILURE 4 +SSH_FX_BAD_MESSAGE 5 +SSH_FX_NO_CONNECTION 6 +SSH_FX_CONNECTION_LOST 7 +SSH_FX_OP_UNSUPPORTED 8 +SSH_FX_INVALID_HANDLE 9 +SSH_FX_NO_SUCH_PATH 10 +SSH_FX_FILE_ALREADY_EXISTS 11 +SSH_FX_WRITE_PROTECT 12 +SSH_FX_NO_MEDIA 13 +SSH_FX_NO_SPACE_ON_FILESYSTEM 14 +SSH_FX_QUOTA_EXCEEDED 15 +SSH_FX_UNKNOWN_PRINCIPAL 16 +SSH_FX_LOCK_CONFLICT 17 +SSH_FX_DIR_NOT_EMPTY 18 +SSH_FX_NOT_A_DIRECTORY 19 +SSH_FX_INVALID_FILENAME 20 +SSH_FX_LINK_LOOP 21 +SSH_FX_CANNOT_DELETE 22 +SSH_FX_INVALID_PARAMETER 23 +SSH_FX_FILE_IS_A_DIRECTORY 24 +SSH_FX_BYTE_RANGE_LOCK_CONFLICT 25 +SSH_FX_BYTE_RANGE_LOCK_REFUSED 26 +SSH_FX_DELETE_PENDING 27 +SSH_FX_FILE_CORRUPT 28 +SSH_FX_OWNER_INVALID 29 +SSH_FX_GROUP_INVALID 30 +SSH_FX_NO_MATCHING_BYTE_RANGE_LOCK 31 +` + +func TestFxNames(t *testing.T) { + whitespace := regexp.MustCompile(`[[:space:]]+`) + + scan := bufio.NewScanner(strings.NewReader(fxStandardsText)) + + 
for scan.Scan() { + line := scan.Text() + if i := strings.Index(line, "//"); i >= 0 { + line = line[:i] + } + + line = strings.TrimSpace(line) + if line == "" { + continue + } + + fields := whitespace.Split(line, 2) + if len(fields) < 2 { + t.Fatalf("unexpected standards text line: %q", line) + } + + name, value := fields[0], fields[1] + n, err := strconv.Atoi(value) + if err != nil { + t.Fatal("unexpected error:", err) + } + + fx := Status(n) + + if got := fx.String(); got != name { + t.Errorf("fx name mismatch for %d: got %q, but want %q", n, got, name) + } + } + + if err := scan.Err(); err != nil { + t.Fatal("unexpected error:", err) + } +} + +func TestStatusIs(t *testing.T) { + status := StatusFailure + + if !errors.Is(status, StatusFailure) { + t.Error("errors.Is(StatusFailure, StatusFailure) != true") + } + if !errors.Is(status, &StatusPacket{StatusCode: StatusFailure}) { + t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) != true") + } + if errors.Is(status, StatusOK) { + t.Error("errors.Is(StatusFailure, StatusFailure) == true") + } + if errors.Is(status, &StatusPacket{StatusCode: StatusOK}) { + t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) == true") + } +} diff --git a/internal/encoding/ssh/filexfer/fxp.go b/internal/encoding/ssh/filexfer/fxp.go new file mode 100644 index 00000000..78080021 --- /dev/null +++ b/internal/encoding/ssh/filexfer/fxp.go @@ -0,0 +1,169 @@ +package sshfx + +import ( + "fmt" +) + +// PacketType defines the various SFTP packet types. +type PacketType uint8 + +// Request packet types. 
+const ( + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3 + PacketTypeInit = PacketType(iota + 1) + PacketTypeVersion + PacketTypeOpen + PacketTypeClose + PacketTypeRead + PacketTypeWrite + PacketTypeLStat + PacketTypeFStat + PacketTypeSetstat + PacketTypeFSetstat + PacketTypeOpenDir + PacketTypeReadDir + PacketTypeRemove + PacketTypeMkdir + PacketTypeRmdir + PacketTypeRealPath + PacketTypeStat + PacketTypeRename + PacketTypeReadLink + PacketTypeSymlink + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-07.txt#section-3.3 + PacketTypeV6Link + + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-08.txt#section-3.3 + PacketTypeV6Block + PacketTypeV6Unblock +) + +// Response packet types. +const ( + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3 + PacketTypeStatus = PacketType(iota + 101) + PacketTypeHandle + PacketTypeData + PacketTypeName + PacketTypeAttrs +) + +// Extended packet types. 
+const ( + // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3 + PacketTypeExtended = PacketType(iota + 200) + PacketTypeExtendedReply +) + +func (f PacketType) String() string { + switch f { + case PacketTypeInit: + return "SSH_FXP_INIT" + case PacketTypeVersion: + return "SSH_FXP_VERSION" + case PacketTypeOpen: + return "SSH_FXP_OPEN" + case PacketTypeClose: + return "SSH_FXP_CLOSE" + case PacketTypeRead: + return "SSH_FXP_READ" + case PacketTypeWrite: + return "SSH_FXP_WRITE" + case PacketTypeLStat: + return "SSH_FXP_LSTAT" + case PacketTypeFStat: + return "SSH_FXP_FSTAT" + case PacketTypeSetstat: + return "SSH_FXP_SETSTAT" + case PacketTypeFSetstat: + return "SSH_FXP_FSETSTAT" + case PacketTypeOpenDir: + return "SSH_FXP_OPENDIR" + case PacketTypeReadDir: + return "SSH_FXP_READDIR" + case PacketTypeRemove: + return "SSH_FXP_REMOVE" + case PacketTypeMkdir: + return "SSH_FXP_MKDIR" + case PacketTypeRmdir: + return "SSH_FXP_RMDIR" + case PacketTypeRealPath: + return "SSH_FXP_REALPATH" + case PacketTypeStat: + return "SSH_FXP_STAT" + case PacketTypeRename: + return "SSH_FXP_RENAME" + case PacketTypeReadLink: + return "SSH_FXP_READLINK" + case PacketTypeSymlink: + return "SSH_FXP_SYMLINK" + case PacketTypeV6Link: + return "SSH_FXP_LINK" + case PacketTypeV6Block: + return "SSH_FXP_BLOCK" + case PacketTypeV6Unblock: + return "SSH_FXP_UNBLOCK" + case PacketTypeStatus: + return "SSH_FXP_STATUS" + case PacketTypeHandle: + return "SSH_FXP_HANDLE" + case PacketTypeData: + return "SSH_FXP_DATA" + case PacketTypeName: + return "SSH_FXP_NAME" + case PacketTypeAttrs: + return "SSH_FXP_ATTRS" + case PacketTypeExtended: + return "SSH_FXP_EXTENDED" + case PacketTypeExtendedReply: + return "SSH_FXP_EXTENDED_REPLY" + default: + return fmt.Sprintf("SSH_FXP_UNKNOWN(%d)", f) + } +} + +func newPacketFromType(typ PacketType) (Packet, error) { + switch typ { + case PacketTypeOpen: + return new(OpenPacket), nil + case PacketTypeClose: + return 
new(ClosePacket), nil + case PacketTypeRead: + return new(ReadPacket), nil + case PacketTypeWrite: + return new(WritePacket), nil + case PacketTypeLStat: + return new(LStatPacket), nil + case PacketTypeFStat: + return new(FStatPacket), nil + case PacketTypeSetstat: + return new(SetstatPacket), nil + case PacketTypeFSetstat: + return new(FSetstatPacket), nil + case PacketTypeOpenDir: + return new(OpenDirPacket), nil + case PacketTypeReadDir: + return new(ReadDirPacket), nil + case PacketTypeRemove: + return new(RemovePacket), nil + case PacketTypeMkdir: + return new(MkdirPacket), nil + case PacketTypeRmdir: + return new(RmdirPacket), nil + case PacketTypeRealPath: + return new(RealPathPacket), nil + case PacketTypeStat: + return new(StatPacket), nil + case PacketTypeRename: + return new(RenamePacket), nil + case PacketTypeReadLink: + return new(ReadLinkPacket), nil + case PacketTypeSymlink: + return new(SymlinkPacket), nil + case PacketTypeExtended: + return new(ExtendedPacket), nil + default: + return nil, fmt.Errorf("unexpected request packet type: %v", typ) + } +} diff --git a/internal/encoding/ssh/filexfer/fxp_test.go b/internal/encoding/ssh/filexfer/fxp_test.go new file mode 100644 index 00000000..b0a3a830 --- /dev/null +++ b/internal/encoding/ssh/filexfer/fxp_test.go @@ -0,0 +1,85 @@ +package sshfx + +import ( + "bufio" + "regexp" + "strconv" + "strings" + "testing" +) + +// This string data is copied verbatim from https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-13.txt +// except where commented that it was taken from a different source. 
+var fxpStandardsText = ` +SSH_FXP_INIT 1 +SSH_FXP_VERSION 2 +SSH_FXP_OPEN 3 +SSH_FXP_CLOSE 4 +SSH_FXP_READ 5 +SSH_FXP_WRITE 6 +SSH_FXP_LSTAT 7 +SSH_FXP_FSTAT 8 +SSH_FXP_SETSTAT 9 +SSH_FXP_FSETSTAT 10 +SSH_FXP_OPENDIR 11 +SSH_FXP_READDIR 12 +SSH_FXP_REMOVE 13 +SSH_FXP_MKDIR 14 +SSH_FXP_RMDIR 15 +SSH_FXP_REALPATH 16 +SSH_FXP_STAT 17 +SSH_FXP_RENAME 18 +SSH_FXP_READLINK 19 +SSH_FXP_SYMLINK 20 // Deprecated in filexfer-13 added from filexfer-02 +SSH_FXP_LINK 21 +SSH_FXP_BLOCK 22 +SSH_FXP_UNBLOCK 23 + +SSH_FXP_STATUS 101 +SSH_FXP_HANDLE 102 +SSH_FXP_DATA 103 +SSH_FXP_NAME 104 +SSH_FXP_ATTRS 105 + +SSH_FXP_EXTENDED 200 +SSH_FXP_EXTENDED_REPLY 201 +` + +func TestFxpNames(t *testing.T) { + whitespace := regexp.MustCompile(`[[:space:]]+`) + + scan := bufio.NewScanner(strings.NewReader(fxpStandardsText)) + + for scan.Scan() { + line := scan.Text() + if i := strings.Index(line, "//"); i >= 0 { + line = line[:i] + } + + line = strings.TrimSpace(line) + if line == "" { + continue + } + + fields := whitespace.Split(line, 2) + if len(fields) < 2 { + t.Fatalf("unexpected standards text line: %q", line) + } + + name, value := fields[0], fields[1] + n, err := strconv.Atoi(value) + if err != nil { + t.Fatal("unexpected error:", err) + } + + fxp := PacketType(n) + + if got := fxp.String(); got != name { + t.Errorf("fxp name mismatch for %d: got %q, but want %q", n, got, name) + } + } + + if err := scan.Err(); err != nil { + t.Fatal("unexpected error:", err) + } +} diff --git a/internal/encoding/ssh/filexfer/handle_packets.go b/internal/encoding/ssh/filexfer/handle_packets.go new file mode 100644 index 00000000..44594acf --- /dev/null +++ b/internal/encoding/ssh/filexfer/handle_packets.go @@ -0,0 +1,230 @@ +package sshfx + +// ClosePacket defines the SSH_FXP_CLOSE packet. +type ClosePacket struct { + Handle string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. 
+func (p *ClosePacket) Type() PacketType { + return PacketTypeClose +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *ClosePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) // string(handle) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeClose, reqid) + buf.AppendString(p.Handle) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *ClosePacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = ClosePacket{ + Handle: buf.ConsumeString(), + } + + return buf.Err +} + +// ReadPacket defines the SSH_FXP_READ packet. +type ReadPacket struct { + Handle string + Offset uint64 + Length uint32 +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *ReadPacket) Type() PacketType { + return PacketTypeRead +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *ReadPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(handle) + uint64(offset) + uint32(len) + size := 4 + len(p.Handle) + 8 + 4 + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeRead, reqid) + buf.AppendString(p.Handle) + buf.AppendUint64(p.Offset) + buf.AppendUint32(p.Length) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *ReadPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = ReadPacket{ + Handle: buf.ConsumeString(), + Offset: buf.ConsumeUint64(), + Length: buf.ConsumeUint32(), + } + + return buf.Err +} + +// WritePacket defines the SSH_FXP_WRITE packet. 
+type WritePacket struct { + Handle string + Offset uint64 + Data []byte +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *WritePacket) Type() PacketType { + return PacketTypeWrite +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *WritePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(handle) + uint64(offset) + uint32(len(data)); data content in payload + size := 4 + len(p.Handle) + 8 + 4 + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeWrite, reqid) + buf.AppendString(p.Handle) + buf.AppendUint64(p.Offset) + buf.AppendUint32(uint32(len(p.Data))) + + return buf.Packet(p.Data) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +// +// If p.Data is already populated, and of sufficient length to hold the data, +// then this will copy the data into that byte slice. +// +// If p.Data has a length insufficient to hold the data, +// then this will make a new slice of sufficient length, and copy the data into that. +// +// This means this _does not_ alias any of the data buffer that is passed in. +func (p *WritePacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = WritePacket{ + Handle: buf.ConsumeString(), + Offset: buf.ConsumeUint64(), + Data: buf.ConsumeByteSliceCopy(p.Data), + } + + return buf.Err +} + +// FStatPacket defines the SSH_FXP_FSTAT packet. +type FStatPacket struct { + Handle string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *FStatPacket) Type() PacketType { + return PacketTypeFStat +} + +// MarshalPacket returns p as a two-part binary encoding of p. 
+func (p *FStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) // string(handle) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeFStat, reqid) + buf.AppendString(p.Handle) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *FStatPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = FStatPacket{ + Handle: buf.ConsumeString(), + } + + return buf.Err +} + +// FSetstatPacket defines the SSH_FXP_FSETSTAT packet. +type FSetstatPacket struct { + Handle string + Attrs Attributes +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *FSetstatPacket) Type() PacketType { + return PacketTypeFSetstat +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *FSetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) + p.Attrs.Len() // string(handle) + ATTRS(attrs) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeFSetstat, reqid) + buf.AppendString(p.Handle) + + p.Attrs.MarshalInto(buf) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *FSetstatPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = FSetstatPacket{ + Handle: buf.ConsumeString(), + } + + return p.Attrs.UnmarshalFrom(buf) +} + +// ReadDirPacket defines the SSH_FXP_READDIR packet. +type ReadDirPacket struct { + Handle string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. 
+func (p *ReadDirPacket) Type() PacketType { + return PacketTypeReadDir +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *ReadDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Handle) // string(handle) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeReadDir, reqid) + buf.AppendString(p.Handle) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *ReadDirPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = ReadDirPacket{ + Handle: buf.ConsumeString(), + } + + return buf.Err +} diff --git a/internal/encoding/ssh/filexfer/handle_packets_test.go b/internal/encoding/ssh/filexfer/handle_packets_test.go new file mode 100644 index 00000000..249ff323 --- /dev/null +++ b/internal/encoding/ssh/filexfer/handle_packets_test.go @@ -0,0 +1,282 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +var _ Packet = &ClosePacket{} + +func TestClosePacket(t *testing.T) { + const ( + id = 42 + handle = "somehandle" + ) + + p := &ClosePacket{ + Handle: "somehandle", + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 19, + 4, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ClosePacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Handle != handle { + t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle) + } +} + +var _ Packet = &ReadPacket{} + +func TestReadPacket(t *testing.T) { + const ( + id = 42 + handle = "somehandle" + offset uint64 = 0x123456789ABCDEF0 + length uint32 = 0xFEDCBA98 + ) + + p := &ReadPacket{ + Handle: "somehandle", + Offset: offset, + Length: length, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 31, + 5, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e', + 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, + 0xFE, 0xDC, 0xBA, 0x98, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ReadPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Handle != handle { + t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle) + } + + if p.Offset != offset { + t.Errorf("UnmarshalPacketBody(): Offset was %x, but expected %x", p.Offset, offset) + } + + if p.Length != length { + t.Errorf("UnmarshalPacketBody(): Length was %x, but expected %x", p.Length, length) + } +} + +var _ Packet = &WritePacket{} + +func TestWritePacket(t *testing.T) { + const ( + id = 42 + handle = "somehandle" + offset uint64 = 0x123456789ABCDEF0 + ) + + var payload = []byte(`foobar`) + + p := &WritePacket{ + Handle: "somehandle", + Offset: offset, + Data: payload, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 37, + 6, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e', + 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, + 0x00, 0x00, 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = WritePacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Handle != handle { + t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle) + } + + if p.Offset != offset { + t.Errorf("UnmarshalPacketBody(): Offset was %x, but expected %x", p.Offset, offset) + } + + if !bytes.Equal(p.Data, payload) { + t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", p.Data, payload) + } +} + +var _ Packet = &FStatPacket{} + +func TestFStatPacket(t *testing.T) { + const ( + id = 42 + handle = "somehandle" + ) + + p := &FStatPacket{ + Handle: "somehandle", + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 19, + 8, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = FStatPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Handle != handle { + t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle) + } +} + +var _ Packet = &FSetstatPacket{} + +func TestFSetstatPacket(t *testing.T) { + const ( + id = 42 + handle = "somehandle" + perms = 0x87654321 + ) + + p := &FSetstatPacket{ + Handle: "somehandle", + Attrs: Attributes{ + Flags: AttrPermissions, + Permissions: perms, + }, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 27, + 10, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e', + 0x00, 0x00, 0x00, 0x04, + 0x87, 0x65, 0x43, 0x21, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = FSetstatPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Handle != handle { + t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle) + } +} + +var _ Packet = &ReadDirPacket{} + +func TestReadDirPacket(t *testing.T) { + const ( + id = 42 + handle = "somehandle" + ) + + p := &ReadDirPacket{ + Handle: "somehandle", + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 19, + 12, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ReadDirPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Handle != handle { + t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", p.Handle, handle) + } +} diff --git a/internal/encoding/ssh/filexfer/init_packets.go b/internal/encoding/ssh/filexfer/init_packets.go new file mode 100644 index 00000000..c553ee2e --- /dev/null +++ b/internal/encoding/ssh/filexfer/init_packets.go @@ -0,0 +1,99 @@ +package sshfx + +// InitPacket defines the SSH_FXP_INIT packet. +type InitPacket struct { + Version uint32 + Extensions []*ExtensionPair +} + +// MarshalBinary returns p as the binary encoding of p. +func (p *InitPacket) MarshalBinary() ([]byte, error) { + size := 1 + 4 // byte(type) + uint32(version) + + for _, ext := range p.Extensions { + size += ext.Len() + } + + b := NewBuffer(make([]byte, 4, 4+size)) + b.AppendUint8(uint8(PacketTypeInit)) + b.AppendUint32(p.Version) + + for _, ext := range p.Extensions { + ext.MarshalInto(b) + } + + b.PutLength(size) + + return b.Bytes(), nil +} + +// UnmarshalBinary unmarshals a full raw packet out of the given data. +// It is assumed that the uint32(length) has already been consumed to receive the data. +// It is also assumed that the uint8(type) has already been consumed to which packet to unmarshal into. +func (p *InitPacket) UnmarshalBinary(data []byte) (err error) { + buf := NewBuffer(data) + + *p = InitPacket{ + Version: buf.ConsumeUint32(), + } + + for buf.Len() > 0 { + var ext ExtensionPair + if err := ext.UnmarshalFrom(buf); err != nil { + return err + } + + p.Extensions = append(p.Extensions, &ext) + } + + return buf.Err +} + +// VersionPacket defines the SSH_FXP_VERSION packet. +type VersionPacket struct { + Version uint32 + Extensions []*ExtensionPair +} + +// MarshalBinary returns p as the binary encoding of p. 
+func (p *VersionPacket) MarshalBinary() ([]byte, error) { + size := 1 + 4 // byte(type) + uint32(version) + + for _, ext := range p.Extensions { + size += ext.Len() + } + + b := NewBuffer(make([]byte, 4, 4+size)) + b.AppendUint8(uint8(PacketTypeVersion)) + b.AppendUint32(p.Version) + + for _, ext := range p.Extensions { + ext.MarshalInto(b) + } + + b.PutLength(size) + + return b.Bytes(), nil +} + +// UnmarshalBinary unmarshals a full raw packet out of the given data. +// It is assumed that the uint32(length) has already been consumed to receive the data. +// It is also assumed that the uint8(type) has already been consumed to which packet to unmarshal into. +func (p *VersionPacket) UnmarshalBinary(data []byte) (err error) { + buf := NewBuffer(data) + + *p = VersionPacket{ + Version: buf.ConsumeUint32(), + } + + for buf.Len() > 0 { + var ext ExtensionPair + if err := ext.UnmarshalFrom(buf); err != nil { + return err + } + + p.Extensions = append(p.Extensions, &ext) + } + + return nil +} diff --git a/internal/encoding/ssh/filexfer/init_packets_test.go b/internal/encoding/ssh/filexfer/init_packets_test.go new file mode 100644 index 00000000..0c466bf1 --- /dev/null +++ b/internal/encoding/ssh/filexfer/init_packets_test.go @@ -0,0 +1,114 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +func TestInitPacket(t *testing.T) { + var version uint8 = 3 + + p := &InitPacket{ + Version: uint32(version), + Extensions: []*ExtensionPair{ + { + Name: "foo", + Data: "1", + }, + }, + } + + buf, err := p.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 17, + 1, + 0x00, 0x00, 0x00, version, + 0x00, 0x00, 0x00, 3, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 1, '1', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want) + } + + *p = InitPacket{} + + // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed. 
+ if err := p.UnmarshalBinary(buf[5:]); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Version != uint32(version) { + t.Errorf("UnmarshalBinary(): Version was %d, but expected %d", p.Version, version) + } + + if len(p.Extensions) != 1 { + t.Fatalf("UnmarshalBinary(): len(p.Extensions) was %d, but expected %d", len(p.Extensions), 1) + } + + if got, want := p.Extensions[0].Name, "foo"; got != want { + t.Errorf("UnmarshalBinary(): p.Extensions[0].Name was %q, but expected %q", got, want) + } + + if got, want := p.Extensions[0].Data, "1"; got != want { + t.Errorf("UnmarshalBinary(): p.Extensions[0].Data was %q, but expected %q", got, want) + } +} + +func TestVersionPacket(t *testing.T) { + var version uint8 = 3 + + p := &VersionPacket{ + Version: uint32(version), + Extensions: []*ExtensionPair{ + { + Name: "foo", + Data: "1", + }, + }, + } + + buf, err := p.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 17, + 2, + 0x00, 0x00, 0x00, version, + 0x00, 0x00, 0x00, 3, 'f', 'o', 'o', + 0x00, 0x00, 0x00, 1, '1', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want) + } + + *p = VersionPacket{} + + // UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed. 
+ if err := p.UnmarshalBinary(buf[5:]); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Version != uint32(version) { + t.Errorf("UnmarshalBinary(): Version was %d, but expected %d", p.Version, version) + } + + if len(p.Extensions) != 1 { + t.Fatalf("UnmarshalBinary(): len(p.Extensions) was %d, but expected %d", len(p.Extensions), 1) + } + + if got, want := p.Extensions[0].Name, "foo"; got != want { + t.Errorf("UnmarshalBinary(): p.Extensions[0].Name was %q, but expected %q", got, want) + } + + if got, want := p.Extensions[0].Data, "1"; got != want { + t.Errorf("UnmarshalBinary(): p.Extensions[0].Data was %q, but expected %q", got, want) + } +} diff --git a/internal/encoding/ssh/filexfer/open_packets.go b/internal/encoding/ssh/filexfer/open_packets.go new file mode 100644 index 00000000..896ba16e --- /dev/null +++ b/internal/encoding/ssh/filexfer/open_packets.go @@ -0,0 +1,86 @@ +package sshfx + +// SSH_FXF_* flags. +const ( + FlagRead = 1 << iota // SSH_FXF_READ + FlagWrite // SSH_FXF_WRITE + FlagAppend // SSH_FXF_APPEND + FlagCreate // SSH_FXF_CREAT + FlagTruncate // SSH_FXF_TRUNC + FlagExclusive // SSH_FXF_EXCL +) + +// OpenPacket defines the SSH_FXP_OPEN packet. +type OpenPacket struct { + Filename string + PFlags uint32 + Attrs Attributes +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *OpenPacket) Type() PacketType { + return PacketTypeOpen +} + +// MarshalPacket returns p as a two-part binary encoding of p. 
+func (p *OpenPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(filename) + uint32(pflags) + ATTRS(attrs) + size := 4 + len(p.Filename) + 4 + p.Attrs.Len() + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeOpen, reqid) + buf.AppendString(p.Filename) + buf.AppendUint32(p.PFlags) + + p.Attrs.MarshalInto(buf) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *OpenPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = OpenPacket{ + Filename: buf.ConsumeString(), + PFlags: buf.ConsumeUint32(), + } + + return p.Attrs.UnmarshalFrom(buf) +} + +// OpenDirPacket defines the SSH_FXP_OPENDIR packet. +type OpenDirPacket struct { + Path string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *OpenDirPacket) Type() PacketType { + return PacketTypeOpenDir +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *OpenDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeOpenDir, reqid) + buf.AppendString(p.Path) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. 
+func (p *OpenDirPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = OpenDirPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} diff --git a/internal/encoding/ssh/filexfer/open_packets_test.go b/internal/encoding/ssh/filexfer/open_packets_test.go new file mode 100644 index 00000000..a792bf45 --- /dev/null +++ b/internal/encoding/ssh/filexfer/open_packets_test.go @@ -0,0 +1,107 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +var _ Packet = &OpenPacket{} + +func TestOpenPacket(t *testing.T) { + const ( + id = 42 + filename = "/foo" + perms FileMode = 0x87654321 + ) + + p := &OpenPacket{ + Filename: "/foo", + PFlags: FlagRead, + Attrs: Attributes{ + Flags: AttrPermissions, + Permissions: perms, + }, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 25, + 3, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + 0x00, 0x00, 0x00, 1, + 0x00, 0x00, 0x00, 0x04, + 0x87, 0x65, 0x43, 0x21, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = OpenPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Filename != filename { + t.Errorf("UnmarshalPacketBody(): Filename was %q, but expected %q", p.Filename, filename) + } + + if p.PFlags != FlagRead { + t.Errorf("UnmarshalPacketBody(): PFlags was %#x, but expected %#x", p.PFlags, FlagRead) + } + + if p.Attrs.Flags != AttrPermissions { + t.Errorf("UnmarshalPacketBody(): Attrs.Flags was %#x, but expected %#x", p.Attrs.Flags, AttrPermissions) + } + + if p.Attrs.Permissions != perms { + t.Errorf("UnmarshalPacketBody(): Attrs.Permissions was %#v, but expected %#v", p.Attrs.Permissions, perms) + } +} + +var _ Packet = &OpenDirPacket{} + +func TestOpenDirPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + p := &OpenDirPacket{ + Path: path, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 13, + 11, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = OpenDirPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } +} diff --git a/internal/encoding/ssh/filexfer/openssh/fsync.go b/internal/encoding/ssh/filexfer/openssh/fsync.go new file mode 100644 index 00000000..708a4ba7 --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/fsync.go @@ -0,0 +1,73 @@ +package openssh + +import ( + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +const extensionFSync = "fsync@openssh.com" + +// RegisterExtensionFSync registers the "fsync@openssh.com" extended packet with the encoding/ssh/filexfer package. 
+func RegisterExtensionFSync() { + sshfx.RegisterExtendedPacketType(extensionFSync, func() sshfx.ExtendedData { + return new(FSyncExtendedPacket) + }) +} + +// ExtensionFSync returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket. +func ExtensionFSync() *sshfx.ExtensionPair { + return &sshfx.ExtensionPair{ + Name: extensionFSync, + Data: "1", + } +} + +// FSyncExtendedPacket defines the fsync@openssh.com extend packet. +type FSyncExtendedPacket struct { + Handle string +} + +// Type returns the SSH_FXP_EXTENDED packet type. +func (ep *FSyncExtendedPacket) Type() sshfx.PacketType { + return sshfx.PacketTypeExtended +} + +// MarshalPacket returns ep as a two-part binary encoding of the full extended packet. +func (ep *FSyncExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + p := &sshfx.ExtendedPacket{ + ExtendedRequest: extensionFSync, + + Data: ep, + } + return p.MarshalPacket(reqid, b) +} + +// MarshalInto encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data. +func (ep *FSyncExtendedPacket) MarshalInto(buf *sshfx.Buffer) { + buf.AppendString(ep.Handle) +} + +// MarshalBinary encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data. +// +// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet. +func (ep *FSyncExtendedPacket) MarshalBinary() ([]byte, error) { + // string(handle) + size := 4 + len(ep.Handle) + + buf := sshfx.NewBuffer(make([]byte, 0, size)) + ep.MarshalInto(buf) + return buf.Bytes(), nil +} + +// UnmarshalFrom decodes the fsync@openssh.com extended packet-specific data from buf. +func (ep *FSyncExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) { + *ep = FSyncExtendedPacket{ + Handle: buf.ConsumeString(), + } + + return buf.Err +} + +// UnmarshalBinary decodes the fsync@openssh.com extended packet-specific data into ep. 
+func (ep *FSyncExtendedPacket) UnmarshalBinary(data []byte) (err error) { + return ep.UnmarshalFrom(sshfx.NewBuffer(data)) +} diff --git a/internal/encoding/ssh/filexfer/openssh/fsync_test.go b/internal/encoding/ssh/filexfer/openssh/fsync_test.go new file mode 100644 index 00000000..f9e878fb --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/fsync_test.go @@ -0,0 +1,62 @@ +package openssh + +import ( + "bytes" + "testing" + + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +var _ sshfx.PacketMarshaller = &FSyncExtendedPacket{} + +func init() { + RegisterExtensionFSync() +} + +func TestFSyncExtendedPacket(t *testing.T) { + const ( + id = 42 + handle = "somehandle" + ) + + ep := &FSyncExtendedPacket{ + Handle: handle, + } + + data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 40, + 200, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 17, 'f', 's', 'y', 'n', 'c', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e', + } + + if !bytes.Equal(data, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want) + } + + var p sshfx.ExtendedPacket + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extensionFSync { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionFSync) + } + + ep, ok := p.Data.(*FSyncExtendedPacket) + if !ok { + t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *FSyncExtendedPacket", p.Data) + } + + if ep.Handle != handle { + t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", ep.Handle, handle) + } +} diff --git a/internal/encoding/ssh/filexfer/openssh/hardlink.go b/internal/encoding/ssh/filexfer/openssh/hardlink.go new file mode 100644 index 00000000..f48d25a2 --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/hardlink.go @@ -0,0 +1,76 @@ +package openssh + +import ( + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +const extensionHardlink = "hardlink@openssh.com" + +// RegisterExtensionHardlink registers the "hardlink@openssh.com" extended packet with the encoding/ssh/filexfer package. +func RegisterExtensionHardlink() { + sshfx.RegisterExtendedPacketType(extensionHardlink, func() sshfx.ExtendedData { + return new(HardlinkExtendedPacket) + }) +} + +// ExtensionHardlink returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket. +func ExtensionHardlink() *sshfx.ExtensionPair { + return &sshfx.ExtensionPair{ + Name: extensionHardlink, + Data: "1", + } +} + +// HardlinkExtendedPacket defines the hardlink@openssh.com extend packet. +type HardlinkExtendedPacket struct { + OldPath string + NewPath string +} + +// Type returns the SSH_FXP_EXTENDED packet type. +func (ep *HardlinkExtendedPacket) Type() sshfx.PacketType { + return sshfx.PacketTypeExtended +} + +// MarshalPacket returns ep as a two-part binary encoding of the full extended packet. 
+func (ep *HardlinkExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + p := &sshfx.ExtendedPacket{ + ExtendedRequest: extensionHardlink, + + Data: ep, + } + return p.MarshalPacket(reqid, b) +} + +// MarshalInto encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data. +func (ep *HardlinkExtendedPacket) MarshalInto(buf *sshfx.Buffer) { + buf.AppendString(ep.OldPath) + buf.AppendString(ep.NewPath) +} + +// MarshalBinary encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data. +// +// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet. +func (ep *HardlinkExtendedPacket) MarshalBinary() ([]byte, error) { + // string(oldpath) + string(newpath) + size := 4 + len(ep.OldPath) + 4 + len(ep.NewPath) + + buf := sshfx.NewBuffer(make([]byte, 0, size)) + ep.MarshalInto(buf) + return buf.Bytes(), nil +} + +// UnmarshalFrom decodes the hardlink@openssh.com extended packet-specific data from buf. +func (ep *HardlinkExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) { + *ep = HardlinkExtendedPacket{ + OldPath: buf.ConsumeString(), + NewPath: buf.ConsumeString(), + } + + return buf.Err +} + +// UnmarshalBinary decodes the hardlink@openssh.com extended packet-specific data into ep. 
+func (ep *HardlinkExtendedPacket) UnmarshalBinary(data []byte) (err error) { + return ep.UnmarshalFrom(sshfx.NewBuffer(data)) +} diff --git a/internal/encoding/ssh/filexfer/openssh/hardlink_test.go b/internal/encoding/ssh/filexfer/openssh/hardlink_test.go new file mode 100644 index 00000000..5d3be06b --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/hardlink_test.go @@ -0,0 +1,69 @@ +package openssh + +import ( + "bytes" + "testing" + + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +var _ sshfx.PacketMarshaller = &HardlinkExtendedPacket{} + +func init() { + RegisterExtensionHardlink() +} + +func TestHardlinkExtendedPacket(t *testing.T) { + const ( + id = 42 + oldpath = "/foo" + newpath = "/bar" + ) + + ep := &HardlinkExtendedPacket{ + OldPath: oldpath, + NewPath: newpath, + } + + data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 45, + 200, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 20, 'h', 'a', 'r', 'd', 'l', 'i', 'n', 'k', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r', + } + + if !bytes.Equal(data, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want) + } + + var p sshfx.ExtendedPacket + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extensionHardlink { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionHardlink) + } + + ep, ok := p.Data.(*HardlinkExtendedPacket) + if !ok { + t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *HardlinkExtendedPacket", p.Data) + } + + if ep.OldPath != oldpath { + t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", ep.OldPath, oldpath) + } + + if ep.NewPath != newpath { + t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", ep.NewPath, newpath) + } +} diff --git a/internal/encoding/ssh/filexfer/openssh/openssh.go b/internal/encoding/ssh/filexfer/openssh/openssh.go new file mode 100644 index 00000000..f93ff177 --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/openssh.go @@ -0,0 +1,2 @@ +// Package openssh implements the openssh secsh-filexfer extensions as described in https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +package openssh diff --git a/internal/encoding/ssh/filexfer/openssh/posix-rename.go b/internal/encoding/ssh/filexfer/openssh/posix-rename.go new file mode 100644 index 00000000..5166489c --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/posix-rename.go @@ -0,0 +1,76 @@ +package openssh + +import ( + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +const extensionPOSIXRename = "posix-rename@openssh.com" + +// RegisterExtensionPOSIXRename registers the "posix-rename@openssh.com" extended packet with the encoding/ssh/filexfer package. +func RegisterExtensionPOSIXRename() { + sshfx.RegisterExtendedPacketType(extensionPOSIXRename, func() sshfx.ExtendedData { + return new(POSIXRenameExtendedPacket) + }) +} + +// ExtensionPOSIXRename returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket. 
+func ExtensionPOSIXRename() *sshfx.ExtensionPair { + return &sshfx.ExtensionPair{ + Name: extensionPOSIXRename, + Data: "1", + } +} + +// POSIXRenameExtendedPacket defines the posix-rename@openssh.com extend packet. +type POSIXRenameExtendedPacket struct { + OldPath string + NewPath string +} + +// Type returns the SSH_FXP_EXTENDED packet type. +func (ep *POSIXRenameExtendedPacket) Type() sshfx.PacketType { + return sshfx.PacketTypeExtended +} + +// MarshalPacket returns ep as a two-part binary encoding of the full extended packet. +func (ep *POSIXRenameExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + p := &sshfx.ExtendedPacket{ + ExtendedRequest: extensionPOSIXRename, + + Data: ep, + } + return p.MarshalPacket(reqid, b) +} + +// MarshalInto encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data. +func (ep *POSIXRenameExtendedPacket) MarshalInto(buf *sshfx.Buffer) { + buf.AppendString(ep.OldPath) + buf.AppendString(ep.NewPath) +} + +// MarshalBinary encodes ep into the binary encoding of the hardlink@openssh.com extended packet-specific data. +// +// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet. +func (ep *POSIXRenameExtendedPacket) MarshalBinary() ([]byte, error) { + // string(oldpath) + string(newpath) + size := 4 + len(ep.OldPath) + 4 + len(ep.NewPath) + + buf := sshfx.NewBuffer(make([]byte, 0, size)) + ep.MarshalInto(buf) + return buf.Bytes(), nil +} + +// UnmarshalFrom decodes the hardlink@openssh.com extended packet-specific data from buf. +func (ep *POSIXRenameExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) { + *ep = POSIXRenameExtendedPacket{ + OldPath: buf.ConsumeString(), + NewPath: buf.ConsumeString(), + } + + return buf.Err +} + +// UnmarshalBinary decodes the hardlink@openssh.com extended packet-specific data into ep. 
+func (ep *POSIXRenameExtendedPacket) UnmarshalBinary(data []byte) (err error) { + return ep.UnmarshalFrom(sshfx.NewBuffer(data)) +} diff --git a/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go b/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go new file mode 100644 index 00000000..58e4fffb --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go @@ -0,0 +1,69 @@ +package openssh + +import ( + "bytes" + "testing" + + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +var _ sshfx.PacketMarshaller = &POSIXRenameExtendedPacket{} + +func init() { + RegisterExtensionPOSIXRename() +} + +func TestPOSIXRenameExtendedPacket(t *testing.T) { + const ( + id = 42 + oldpath = "/foo" + newpath = "/bar" + ) + + ep := &POSIXRenameExtendedPacket{ + OldPath: oldpath, + NewPath: newpath, + } + + data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 49, + 200, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 24, 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r', + } + + if !bytes.Equal(data, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want) + } + + var p sshfx.ExtendedPacket + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extensionPOSIXRename { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionPOSIXRename) + } + + ep, ok := p.Data.(*POSIXRenameExtendedPacket) + if !ok { + t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *POSIXRenameExtendedPacket", p.Data) + } + + if ep.OldPath != oldpath { + t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", ep.OldPath, oldpath) + } + + if ep.NewPath != newpath { + t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", ep.NewPath, newpath) + } +} diff --git a/internal/encoding/ssh/filexfer/openssh/statvfs.go b/internal/encoding/ssh/filexfer/openssh/statvfs.go new file mode 100644 index 00000000..51029ca0 --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/statvfs.go @@ -0,0 +1,236 @@ +package openssh + +import ( + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +const extensionStatVFS = "statvfs@openssh.com" + +// RegisterExtensionStatVFS registers the "statvfs@openssh.com" extended packet with the encoding/ssh/filexfer package. +func RegisterExtensionStatVFS() { + sshfx.RegisterExtendedPacketType(extensionStatVFS, func() sshfx.ExtendedData { + return new(StatVFSExtendedPacket) + }) +} + +// ExtensionStatVFS returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket. +func ExtensionStatVFS() *sshfx.ExtensionPair { + return &sshfx.ExtensionPair{ + Name: extensionStatVFS, + Data: "2", + } +} + +// StatVFSExtendedPacket defines the statvfs@openssh.com extend packet. +type StatVFSExtendedPacket struct { + Path string +} + +// Type returns the SSH_FXP_EXTENDED packet type. +func (ep *StatVFSExtendedPacket) Type() sshfx.PacketType { + return sshfx.PacketTypeExtended +} + +// MarshalPacket returns ep as a two-part binary encoding of the full extended packet. 
+func (ep *StatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + p := &sshfx.ExtendedPacket{ + ExtendedRequest: extensionStatVFS, + + Data: ep, + } + return p.MarshalPacket(reqid, b) +} + +// MarshalInto encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data. +func (ep *StatVFSExtendedPacket) MarshalInto(buf *sshfx.Buffer) { + buf.AppendString(ep.Path) +} + +// MarshalBinary encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data. +// +// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet. +func (ep *StatVFSExtendedPacket) MarshalBinary() ([]byte, error) { + size := 4 + len(ep.Path) // string(path) + + buf := sshfx.NewBuffer(make([]byte, 0, size)) + + ep.MarshalInto(buf) + + return buf.Bytes(), nil +} + +// UnmarshalFrom decodes the statvfs@openssh.com extended packet-specific data into ep. +func (ep *StatVFSExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) { + *ep = StatVFSExtendedPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// UnmarshalBinary decodes the statvfs@openssh.com extended packet-specific data into ep. +func (ep *StatVFSExtendedPacket) UnmarshalBinary(data []byte) (err error) { + return ep.UnmarshalFrom(sshfx.NewBuffer(data)) +} + +const extensionFStatVFS = "fstatvfs@openssh.com" + +// RegisterExtensionFStatVFS registers the "fstatvfs@openssh.com" extended packet with the encoding/ssh/filexfer package. +func RegisterExtensionFStatVFS() { + sshfx.RegisterExtendedPacketType(extensionFStatVFS, func() sshfx.ExtendedData { + return new(FStatVFSExtendedPacket) + }) +} + +// ExtensionFStatVFS returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket. 
+func ExtensionFStatVFS() *sshfx.ExtensionPair { + return &sshfx.ExtensionPair{ + Name: extensionFStatVFS, + Data: "2", + } +} + +// FStatVFSExtendedPacket defines the fstatvfs@openssh.com extend packet. +type FStatVFSExtendedPacket struct { + Path string +} + +// Type returns the SSH_FXP_EXTENDED packet type. +func (ep *FStatVFSExtendedPacket) Type() sshfx.PacketType { + return sshfx.PacketTypeExtended +} + +// MarshalPacket returns ep as a two-part binary encoding of the full extended packet. +func (ep *FStatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + p := &sshfx.ExtendedPacket{ + ExtendedRequest: extensionFStatVFS, + + Data: ep, + } + return p.MarshalPacket(reqid, b) +} + +// MarshalInto encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data. +func (ep *FStatVFSExtendedPacket) MarshalInto(buf *sshfx.Buffer) { + buf.AppendString(ep.Path) +} + +// MarshalBinary encodes ep into the binary encoding of the statvfs@openssh.com extended packet-specific data. +// +// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet. +func (ep *FStatVFSExtendedPacket) MarshalBinary() ([]byte, error) { + size := 4 + len(ep.Path) // string(path) + + buf := sshfx.NewBuffer(make([]byte, 0, size)) + + ep.MarshalInto(buf) + + return buf.Bytes(), nil +} + +// UnmarshalFrom decodes the statvfs@openssh.com extended packet-specific data into ep. +func (ep *FStatVFSExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) { + *ep = FStatVFSExtendedPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// UnmarshalBinary decodes the statvfs@openssh.com extended packet-specific data into ep. +func (ep *FStatVFSExtendedPacket) UnmarshalBinary(data []byte) (err error) { + return ep.UnmarshalFrom(sshfx.NewBuffer(data)) +} + +// The values for the MountFlags field. 
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +const ( + MountFlagsReadOnly = 0x1 // SSH_FXE_STATVFS_ST_RDONLY + MountFlagsNoSUID = 0x2 // SSH_FXE_STATVFS_ST_NOSUID +) + +// StatVFSExtendedReplyPacket defines the extended reply packet for statvfs@openssh.com and fstatvfs@openssh.com requests. +type StatVFSExtendedReplyPacket struct { + BlockSize uint64 /* f_bsize: file system block size */ + FragmentSize uint64 /* f_frsize: fundamental fs block size / fagment size */ + Blocks uint64 /* f_blocks: number of blocks (unit f_frsize) */ + BlocksFree uint64 /* f_bfree: free blocks in filesystem */ + BlocksAvail uint64 /* f_bavail: free blocks for non-root */ + Files uint64 /* f_files: total file inodes */ + FilesFree uint64 /* f_ffree: free file inodes */ + FilesAvail uint64 /* f_favail: free file inodes for to non-root */ + FilesystemID uint64 /* f_fsid: file system id */ + MountFlags uint64 /* f_flag: bit mask of mount flag values */ + MaxNameLength uint64 /* f_namemax: maximum filename length */ +} + +// Type returns the SSH_FXP_EXTENDED_REPLY packet type. +func (ep *StatVFSExtendedReplyPacket) Type() sshfx.PacketType { + return sshfx.PacketTypeExtendedReply +} + +// MarshalPacket returns ep as a two-part binary encoding of the full extended reply packet. +func (ep *StatVFSExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + p := &sshfx.ExtendedReplyPacket{ + Data: ep, + } + return p.MarshalPacket(reqid, b) +} + +// UnmarshalPacketBody returns ep as a two-part binary encoding of the full extended reply packet. +func (ep *StatVFSExtendedReplyPacket) UnmarshalPacketBody(buf *sshfx.Buffer) (err error) { + p := &sshfx.ExtendedReplyPacket{ + Data: ep, + } + return p.UnmarshalPacketBody(buf) +} + +// MarshalInto encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data. 
+func (ep *StatVFSExtendedReplyPacket) MarshalInto(buf *sshfx.Buffer) { + buf.AppendUint64(ep.BlockSize) + buf.AppendUint64(ep.FragmentSize) + buf.AppendUint64(ep.Blocks) + buf.AppendUint64(ep.BlocksFree) + buf.AppendUint64(ep.BlocksAvail) + buf.AppendUint64(ep.Files) + buf.AppendUint64(ep.FilesFree) + buf.AppendUint64(ep.FilesAvail) + buf.AppendUint64(ep.FilesystemID) + buf.AppendUint64(ep.MountFlags) + buf.AppendUint64(ep.MaxNameLength) +} + +// MarshalBinary encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data. +// +// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended reply packet. +func (ep *StatVFSExtendedReplyPacket) MarshalBinary() ([]byte, error) { + size := 11 * 8 // 11 × uint64(various) + + b := sshfx.NewBuffer(make([]byte, 0, size)) + ep.MarshalInto(b) + return b.Bytes(), nil +} + +// UnmarshalFrom decodes the fstatvfs@openssh.com extended reply packet-specific data into ep. +func (ep *StatVFSExtendedReplyPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) { + *ep = StatVFSExtendedReplyPacket{ + BlockSize: buf.ConsumeUint64(), + FragmentSize: buf.ConsumeUint64(), + Blocks: buf.ConsumeUint64(), + BlocksFree: buf.ConsumeUint64(), + BlocksAvail: buf.ConsumeUint64(), + Files: buf.ConsumeUint64(), + FilesFree: buf.ConsumeUint64(), + FilesAvail: buf.ConsumeUint64(), + FilesystemID: buf.ConsumeUint64(), + MountFlags: buf.ConsumeUint64(), + MaxNameLength: buf.ConsumeUint64(), + } + + return buf.Err +} + +// UnmarshalBinary decodes the fstatvfs@openssh.com extended reply packet-specific data into ep. 
+func (ep *StatVFSExtendedReplyPacket) UnmarshalBinary(data []byte) (err error) { + return ep.UnmarshalFrom(sshfx.NewBuffer(data)) +} diff --git a/internal/encoding/ssh/filexfer/openssh/statvfs_test.go b/internal/encoding/ssh/filexfer/openssh/statvfs_test.go new file mode 100644 index 00000000..014aa637 --- /dev/null +++ b/internal/encoding/ssh/filexfer/openssh/statvfs_test.go @@ -0,0 +1,239 @@ +package openssh + +import ( + "bytes" + "testing" + + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +var _ sshfx.PacketMarshaller = &StatVFSExtendedPacket{} + +func init() { + RegisterExtensionStatVFS() +} + +func TestStatVFSExtendedPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + ep := &StatVFSExtendedPacket{ + Path: path, + } + + data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 36, + 200, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 19, 's', 't', 'a', 't', 'v', 'f', 's', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(data, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want) + } + + var p sshfx.ExtendedPacket + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extensionStatVFS { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionStatVFS) + } + + ep, ok := p.Data.(*StatVFSExtendedPacket) + if !ok { + t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *StatVFSExtendedPacket", p.Data) + } + + if ep.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", ep.Path, path) + } +} + +var _ sshfx.PacketMarshaller = &FStatVFSExtendedPacket{} + +func init() { + RegisterExtensionFStatVFS() +} + +func TestFStatVFSExtendedPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + ep := &FStatVFSExtendedPacket{ + Path: path, + } + + data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 37, + 200, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 20, 'f', 's', 't', 'a', 't', 'v', 'f', 's', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(data, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want) + } + + var p sshfx.ExtendedPacket + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.ExtendedRequest != extensionFStatVFS { + t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionFStatVFS) + } + + ep, ok := p.Data.(*FStatVFSExtendedPacket) + if !ok { + t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *FStatVFSExtendedPacket", p.Data) + } + + if ep.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", ep.Path, path) + } +} + +var _ sshfx.Packet = &StatVFSExtendedReplyPacket{} + +func TestStatVFSExtendedReplyPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + const ( + BlockSize = uint64(iota + 13) + FragmentSize + Blocks + BlocksFree + BlocksAvail + Files + FilesFree + FilesAvail + FilesystemID + MountFlags + MaxNameLength + ) + + ep := &StatVFSExtendedReplyPacket{ + BlockSize: BlockSize, + FragmentSize: FragmentSize, + Blocks: Blocks, + BlocksFree: BlocksFree, + BlocksAvail: BlocksAvail, + Files: Files, + FilesFree: FilesFree, + FilesAvail: FilesAvail, + FilesystemID: FilesystemID, + MountFlags: MountFlags, + MaxNameLength: MaxNameLength, + } + + data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 93, + 201, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 13, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 14, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 15, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 16, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 17, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 18, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 19, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 21, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 22, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 23, + } + + if !bytes.Equal(data, want) { + t.Fatalf("MarshalPacket() = 
%X, but wanted %X", data, want) + } + + *ep = StatVFSExtendedReplyPacket{} + + p := sshfx.ExtendedReplyPacket{ + Data: ep, + } + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + ep, ok := p.Data.(*StatVFSExtendedReplyPacket) + if !ok { + t.Fatalf("UnmarshaledPacketBody(): Data was type %T, but expected *StatVFSExtendedReplyPacket", p.Data) + } + + if ep.BlockSize != BlockSize { + t.Errorf("UnmarshalPacketBody(): BlockSize was %d, but expected %d", ep.BlockSize, BlockSize) + } + + if ep.FragmentSize != FragmentSize { + t.Errorf("UnmarshalPacketBody(): FragmentSize was %d, but expected %d", ep.FragmentSize, FragmentSize) + } + + if ep.Blocks != Blocks { + t.Errorf("UnmarshalPacketBody(): Blocks was %d, but expected %d", ep.Blocks, Blocks) + } + + if ep.BlocksFree != BlocksFree { + t.Errorf("UnmarshalPacketBody(): BlocksFree was %d, but expected %d", ep.BlocksFree, BlocksFree) + } + + if ep.BlocksAvail != BlocksAvail { + t.Errorf("UnmarshalPacketBody(): BlocksAvail was %d, but expected %d", ep.BlocksAvail, BlocksAvail) + } + + if ep.Files != Files { + t.Errorf("UnmarshalPacketBody(): Files was %d, but expected %d", ep.Files, Files) + } + + if ep.FilesFree != FilesFree { + t.Errorf("UnmarshalPacketBody(): FilesFree was %d, but expected %d", ep.FilesFree, FilesFree) + } + + if ep.FilesAvail != FilesAvail { + t.Errorf("UnmarshalPacketBody(): FilesAvail was %d, but expected %d", ep.FilesAvail, FilesAvail) + } + + if ep.FilesystemID != FilesystemID { + t.Errorf("UnmarshalPacketBody(): FilesystemID was %d, but expected %d", ep.FilesystemID, FilesystemID) + } + + if ep.MountFlags != MountFlags { + t.Errorf("UnmarshalPacketBody(): MountFlags was %d, but expected %d", ep.MountFlags, MountFlags) + } + + if ep.MaxNameLength != MaxNameLength { + t.Errorf("UnmarshalPacketBody(): MaxNameLength was %d, but expected %d", 
ep.MaxNameLength, MaxNameLength) + } +} diff --git a/internal/encoding/ssh/filexfer/packets.go b/internal/encoding/ssh/filexfer/packets.go new file mode 100644 index 00000000..fdf65d05 --- /dev/null +++ b/internal/encoding/ssh/filexfer/packets.go @@ -0,0 +1,273 @@ +package sshfx + +import ( + "errors" + "io" +) + +// smallBufferSize is an initial allocation minimal capacity. +const smallBufferSize = 64 + +// RawPacket implements the general packet format from draft-ietf-secsh-filexfer-02 +// +// RawPacket is intended for use in clients receiving responses, +// where a response will be expected to be of a limited number of types, +// and unmarshaling unknown/unexpected response packets is unnecessary. +// +// For servers expecting to receive arbitrary request packet types, +// use RequestPacket. +// +// Defined in https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3 +type RawPacket struct { + PacketType PacketType + RequestID uint32 + + Data Buffer +} + +// Type returns the Type field defining the SSH_FXP_xy type for this packet. +func (p *RawPacket) Type() PacketType { + return p.PacketType +} + +// Reset clears the pointers and reference-semantic variables of RawPacket, +// releasing underlying resources, and making them and the RawPacket suitable to be reused, +// so long as no other references have been kept. +func (p *RawPacket) Reset() { + p.Data = Buffer{} +} + +// MarshalPacket returns p as a two-part binary encoding of p. +// +// The internal p.RequestID is overridden by the reqid argument. +func (p *RawPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + buf = NewMarshalBuffer(0) + } + + buf.StartPacket(p.PacketType, reqid) + + return buf.Packet(p.Data.Bytes()) +} + +// MarshalBinary returns p as the binary encoding of p. +// +// This is a convenience implementation primarily intended for tests, +// because it is inefficient with allocations. 
+func (p *RawPacket) MarshalBinary() ([]byte, error) { + return ComposePacket(p.MarshalPacket(p.RequestID, nil)) +} + +// UnmarshalFrom decodes a RawPacket from the given Buffer into p. +// +// The Data field will alias the passed in Buffer, +// so the buffer passed in should not be reused before RawPacket.Reset(). +func (p *RawPacket) UnmarshalFrom(buf *Buffer) error { + *p = RawPacket{ + PacketType: PacketType(buf.ConsumeUint8()), + RequestID: buf.ConsumeUint32(), + } + + p.Data = *buf + + return buf.Err +} + +// UnmarshalBinary decodes a full raw packet out of the given data. +// It is assumed that the uint32(length) has already been consumed to receive the data. +// +// This is a convenience implementation primarily intended for tests, +// because this must clone the given data byte slice, +// as Data is not allowed to alias any part of the data byte slice. +func (p *RawPacket) UnmarshalBinary(data []byte) error { + clone := make([]byte, len(data)) + n := copy(clone, data) + return p.UnmarshalFrom(NewBuffer(clone[:n])) +} + +// readPacket reads a uint32 length-prefixed binary data packet from r. +// using the given byte slice as a backing array. +// +// If the packet length read from r is bigger than maxPacketLength, +// or greater than math.MaxInt32 on a 32-bit implementation, +// then a `ErrLongPacket` error will be returned. +// +// If the given byte slice is insufficient to hold the packet, +// then it will be extended to fill the packet size. +func readPacket(r io.Reader, b []byte, maxPacketLength uint32) ([]byte, error) { + if cap(b) < 4 { + // We will need allocate our own buffer just for reading the packet length. + + // However, we don’t really want to allocate an extremely narrow buffer (4-bytes), + // and cause unnecessary allocation churn from both length reads and small packet reads, + // so we use smallBufferSize from the bytes package as a reasonable guess. 
+ + // But if callers really do want to force narrow throw-away allocation of every packet body, + // they can do so with a buffer of capacity 4. + b = make([]byte, smallBufferSize) + } + + if _, err := io.ReadFull(r, b[:4]); err != nil { + return nil, err + } + + length := unmarshalUint32(b) + if int(length) < 5 { + // Must have at least uint8(type) and uint32(request-id) + + if int(length) < 0 { + // Only possible when strconv.IntSize == 32, + // the packet length is longer than math.MaxInt32, + // and thus longer than any possible slice. + return nil, ErrLongPacket + } + + return nil, ErrShortPacket + } + if length > maxPacketLength { + return nil, ErrLongPacket + } + + if int(length) > cap(b) { + // We know int(length) must be positive, because of tests above. + b = make([]byte, length) + } + + n, err := io.ReadFull(r, b[:length]) + return b[:n], err +} + +// ReadFrom provides a simple functional packet reader, +// using the given byte slice as a backing array. +// +// To protect against potential denial of service attacks, +// if the read packet length is longer than maxPacketLength, +// then no packet data will be read, and ErrLongPacket will be returned. +// (On 32-bit int architectures, all packets >= 2^31 in length +// will return ErrLongPacket regardless of maxPacketLength.) +// +// If the read packet length is longer than cap(b), +// then a throw-away slice will allocated to meet the exact packet length. +// This can be used to limit the length of reused buffers, +// while still allowing reception of occasional large packets. +// +// The Data field may alias the passed in byte slice, +// so the byte slice passed in should not be reused before RawPacket.Reset(). 
+func (p *RawPacket) ReadFrom(r io.Reader, b []byte, maxPacketLength uint32) error { + b, err := readPacket(r, b, maxPacketLength) + if err != nil { + return err + } + + return p.UnmarshalFrom(NewBuffer(b)) +} + +// RequestPacket implements the general packet format from draft-ietf-secsh-filexfer-02 +// but also automatically decode/encodes valid request packets (2 < type < 100 || type == 200). +// +// RequestPacket is intended for use in servers receiving requests, +// where any arbitrary request may be received, and so decoding them automatically +// is useful. +// +// For clients expecting to receive specific response packet types, +// where automatic unmarshaling of the packet body does not make sense, +// use RawPacket. +// +// Defined in https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-3 +type RequestPacket struct { + RequestID uint32 + + Request Packet +} + +// Type returns the SSH_FXP_xy value associated with the underlying packet. +func (p *RequestPacket) Type() PacketType { + return p.Request.Type() +} + +// Reset clears the pointers and reference-semantic variables in RequestPacket, +// releasing underlying resources, and making them and the RequestPacket suitable to be reused, +// so long as no other references have been kept. +func (p *RequestPacket) Reset() { + p.Request = nil +} + +// MarshalPacket returns p as a two-part binary encoding of p. +// +// The internal p.RequestID is overridden by the reqid argument. +func (p *RequestPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + if p.Request == nil { + return nil, nil, errors.New("empty request packet") + } + + return p.Request.MarshalPacket(reqid, b) +} + +// MarshalBinary returns p as the binary encoding of p. +// +// This is a convenience implementation primarily intended for tests, +// because it is inefficient with allocations. 
+func (p *RequestPacket) MarshalBinary() ([]byte, error) { + return ComposePacket(p.MarshalPacket(p.RequestID, nil)) +} + +// UnmarshalFrom decodes a RequestPacket from the given Buffer into p. +// +// The Request field may alias the passed in Buffer, (e.g. SSH_FXP_WRITE), +// so the buffer passed in should not be reused before RequestPacket.Reset(). +func (p *RequestPacket) UnmarshalFrom(buf *Buffer) error { + typ := PacketType(buf.ConsumeUint8()) + if buf.Err != nil { + return buf.Err + } + + req, err := newPacketFromType(typ) + if err != nil { + return err + } + + *p = RequestPacket{ + RequestID: buf.ConsumeUint32(), + Request: req, + } + + return p.Request.UnmarshalPacketBody(buf) +} + +// UnmarshalBinary decodes a full request packet out of the given data. +// It is assumed that the uint32(length) has already been consumed to receive the data. +// +// This is a convenience implementation primarily intended for tests, +// because this must clone the given data byte slice, +// as Request is not allowed to alias any part of the data byte slice. +func (p *RequestPacket) UnmarshalBinary(data []byte) error { + clone := make([]byte, len(data)) + n := copy(clone, data) + return p.UnmarshalFrom(NewBuffer(clone[:n])) +} + +// ReadFrom provides a simple functional packet reader, +// using the given byte slice as a backing array. +// +// To protect against potential denial of service attacks, +// if the read packet length is longer than maxPacketLength, +// then no packet data will be read, and ErrLongPacket will be returned. +// (On 32-bit int architectures, all packets >= 2^31 in length +// will return ErrLongPacket regardless of maxPacketLength.) +// +// If the read packet length is longer than cap(b), +// then a throw-away slice will allocated to meet the exact packet length. +// This can be used to limit the length of reused buffers, +// while still allowing reception of occasional large packets. 
+// +// The Request field may alias the passed in byte slice, +// so the byte slice passed in should not be reused before RawPacket.Reset(). +func (p *RequestPacket) ReadFrom(r io.Reader, b []byte, maxPacketLength uint32) error { + b, err := readPacket(r, b, maxPacketLength) + if err != nil { + return err + } + + return p.UnmarshalFrom(NewBuffer(b)) +} diff --git a/internal/encoding/ssh/filexfer/packets_test.go b/internal/encoding/ssh/filexfer/packets_test.go new file mode 100644 index 00000000..f6dac8e4 --- /dev/null +++ b/internal/encoding/ssh/filexfer/packets_test.go @@ -0,0 +1,132 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +func TestRawPacket(t *testing.T) { + const ( + id = 42 + errMsg = "eof" + langTag = "en" + ) + + p := &RawPacket{ + PacketType: PacketTypeStatus, + RequestID: id, + Data: Buffer{ + b: []byte{ + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x03, 'e', 'o', 'f', + 0x00, 0x00, 0x00, 0x02, 'e', 'n', + }, + }, + } + + buf, err := p.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 22, + 101, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 3, 'e', 'o', 'f', + 0x00, 0x00, 0x00, 2, 'e', 'n', + } + + if !bytes.Equal(buf, want) { + t.Errorf("RawPacket.MarshalBinary() = %X, but wanted %X", buf, want) + } + + *p = RawPacket{} + + if err := p.ReadFrom(bytes.NewReader(buf), nil, DefaultMaxPacketLength); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.PacketType != PacketTypeStatus { + t.Errorf("RawPacket.UnmarshalBinary(): Type was %v, but expected %v", p.PacketType, PacketTypeStat) + } + + if p.RequestID != uint32(id) { + t.Errorf("RawPacket.UnmarshalBinary(): RequestID was %d, but expected %d", p.RequestID, id) + } + + want = []byte{ + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 3, 'e', 'o', 'f', + 0x00, 0x00, 0x00, 2, 'e', 'n', + } + + if !bytes.Equal(p.Data.Bytes(), want) { + t.Fatalf("RawPacket.UnmarshalBinary(): Data was %X, but expected 
%X", p.Data, want) + } + + var resp StatusPacket + resp.UnmarshalPacketBody(&p.Data) + + if resp.StatusCode != StatusEOF { + t.Errorf("UnmarshalPacketBody(): StatusCode was %v, but expected %v", resp.StatusCode, StatusEOF) + } + + if resp.ErrorMessage != errMsg { + t.Errorf("UnmarshalPacketBody(): ErrorMessage was %q, but expected %q", resp.ErrorMessage, errMsg) + } + + if resp.LanguageTag != langTag { + t.Errorf("UnmarshalPacketBody(): LanguageTag was %q, but expected %q", resp.LanguageTag, langTag) + } +} + +func TestRequestPacket(t *testing.T) { + const ( + id = 42 + path = "foo" + ) + + p := &RequestPacket{ + RequestID: id, + Request: &StatPacket{ + Path: path, + }, + } + + buf, err := p.MarshalBinary() + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 12, + 17, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 3, 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Errorf("RequestPacket.MarshalBinary() = %X, but wanted %X", buf, want) + } + + *p = RequestPacket{} + + if err := p.ReadFrom(bytes.NewReader(buf), nil, DefaultMaxPacketLength); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.RequestID != uint32(id) { + t.Errorf("RequestPacket.UnmarshalBinary(): RequestID was %d, but expected %d", p.RequestID, id) + } + + req, ok := p.Request.(*StatPacket) + if !ok { + t.Fatalf("unexpected Request type was %T, but expected %T", p.Request, req) + } + + if req.Path != path { + t.Errorf("RequestPacket.UnmarshalBinary(): Request.Path was %q, but expected %q", req.Path, path) + } +} diff --git a/internal/encoding/ssh/filexfer/path_packets.go b/internal/encoding/ssh/filexfer/path_packets.go new file mode 100644 index 00000000..0180326f --- /dev/null +++ b/internal/encoding/ssh/filexfer/path_packets.go @@ -0,0 +1,362 @@ +package sshfx + +// LStatPacket defines the SSH_FXP_LSTAT packet. +type LStatPacket struct { + Path string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. 
+func (p *LStatPacket) Type() PacketType { + return PacketTypeLStat +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *LStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeLStat, reqid) + buf.AppendString(p.Path) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *LStatPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = LStatPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// SetstatPacket defines the SSH_FXP_SETSTAT packet. +type SetstatPacket struct { + Path string + Attrs Attributes +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *SetstatPacket) Type() PacketType { + return PacketTypeSetstat +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *SetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeSetstat, reqid) + buf.AppendString(p.Path) + + p.Attrs.MarshalInto(buf) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *SetstatPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = SetstatPacket{ + Path: buf.ConsumeString(), + } + + return p.Attrs.UnmarshalFrom(buf) +} + +// RemovePacket defines the SSH_FXP_REMOVE packet. +type RemovePacket struct { + Path string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. 
+func (p *RemovePacket) Type() PacketType { + return PacketTypeRemove +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *RemovePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeRemove, reqid) + buf.AppendString(p.Path) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *RemovePacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = RemovePacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// MkdirPacket defines the SSH_FXP_MKDIR packet. +type MkdirPacket struct { + Path string + Attrs Attributes +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *MkdirPacket) Type() PacketType { + return PacketTypeMkdir +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *MkdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) + p.Attrs.Len() // string(path) + ATTRS(attrs) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeMkdir, reqid) + buf.AppendString(p.Path) + + p.Attrs.MarshalInto(buf) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *MkdirPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = MkdirPacket{ + Path: buf.ConsumeString(), + } + + return p.Attrs.UnmarshalFrom(buf) +} + +// RmdirPacket defines the SSH_FXP_RMDIR packet. +type RmdirPacket struct { + Path string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. 
+func (p *RmdirPacket) Type() PacketType { + return PacketTypeRmdir +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *RmdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeRmdir, reqid) + buf.AppendString(p.Path) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *RmdirPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = RmdirPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// RealPathPacket defines the SSH_FXP_REALPATH packet. +type RealPathPacket struct { + Path string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *RealPathPacket) Type() PacketType { + return PacketTypeRealPath +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *RealPathPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeRealPath, reqid) + buf.AppendString(p.Path) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *RealPathPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = RealPathPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// StatPacket defines the SSH_FXP_STAT packet. +type StatPacket struct { + Path string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. 
+func (p *StatPacket) Type() PacketType { + return PacketTypeStat +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *StatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeStat, reqid) + buf.AppendString(p.Path) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *StatPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = StatPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// RenamePacket defines the SSH_FXP_RENAME packet. +type RenamePacket struct { + OldPath string + NewPath string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *RenamePacket) Type() PacketType { + return PacketTypeRename +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *RenamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(oldpath) + string(newpath) + size := 4 + len(p.OldPath) + 4 + len(p.NewPath) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeRename, reqid) + buf.AppendString(p.OldPath) + buf.AppendString(p.NewPath) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *RenamePacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = RenamePacket{ + OldPath: buf.ConsumeString(), + NewPath: buf.ConsumeString(), + } + + return buf.Err +} + +// ReadLinkPacket defines the SSH_FXP_READLINK packet. +type ReadLinkPacket struct { + Path string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. 
+func (p *ReadLinkPacket) Type() PacketType { + return PacketTypeReadLink +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *ReadLinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + size := 4 + len(p.Path) // string(path) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeReadLink, reqid) + buf.AppendString(p.Path) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. +// It is assumed that the uint32(request-id) has already been consumed. +func (p *ReadLinkPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = ReadLinkPacket{ + Path: buf.ConsumeString(), + } + + return buf.Err +} + +// SymlinkPacket defines the SSH_FXP_SYMLINK packet. +// +// The order of the arguments to the SSH_FXP_SYMLINK method was inadvertently reversed. +// Unfortunately, the reversal was not noticed until the server was widely deployed. +// Covered in Section 4.1 of https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +type SymlinkPacket struct { + LinkPath string + TargetPath string +} + +// Type returns the SSH_FXP_xy value associated with this packet type. +func (p *SymlinkPacket) Type() PacketType { + return PacketTypeSymlink +} + +// MarshalPacket returns p as a two-part binary encoding of p. +func (p *SymlinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) { + buf := NewBuffer(b) + if buf.Cap() < 9 { + // string(targetpath) + string(linkpath) + size := 4 + len(p.TargetPath) + 4 + len(p.LinkPath) + buf = NewMarshalBuffer(size) + } + + buf.StartPacket(PacketTypeSymlink, reqid) + + // Arguments were inadvertently reversed. + buf.AppendString(p.TargetPath) + buf.AppendString(p.LinkPath) + + return buf.Packet(payload) +} + +// UnmarshalPacketBody unmarshals the packet body from the given Buffer. 
+// It is assumed that the uint32(request-id) has already been consumed. +func (p *SymlinkPacket) UnmarshalPacketBody(buf *Buffer) (err error) { + *p = SymlinkPacket{ + // Arguments were inadvertently reversed. + TargetPath: buf.ConsumeString(), + LinkPath: buf.ConsumeString(), + } + + return buf.Err +} diff --git a/internal/encoding/ssh/filexfer/path_packets_test.go b/internal/encoding/ssh/filexfer/path_packets_test.go new file mode 100644 index 00000000..f016f7e9 --- /dev/null +++ b/internal/encoding/ssh/filexfer/path_packets_test.go @@ -0,0 +1,450 @@ +package sshfx + +import ( + "bytes" + "testing" +) + +var _ Packet = &LStatPacket{} + +func TestLStatPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + p := &LStatPacket{ + Path: path, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 13, + 7, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = LStatPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } +} + +var _ Packet = &SetstatPacket{} + +func TestSetstatPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + perms FileMode = 0x87654321 + ) + + p := &SetstatPacket{ + Path: "/foo", + Attrs: Attributes{ + Flags: AttrPermissions, + Permissions: perms, + }, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 21, + 9, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x04, + 0x87, 0x65, 0x43, 0x21, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = SetstatPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } + + if p.Attrs.Flags != AttrPermissions { + t.Errorf("UnmarshalPacketBody(): Attrs.Flags was %#x, but expected %#x", p.Attrs.Flags, AttrPermissions) + } + + if p.Attrs.Permissions != perms { + t.Errorf("UnmarshalPacketBody(): Attrs.Permissions was %#v, but expected %#v", p.Attrs.Permissions, perms) + } +} + +var _ Packet = &RemovePacket{} + +func TestRemovePacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + p := &RemovePacket{ + Path: path, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 13, + 13, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + 
} + + *p = RemovePacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } +} + +var _ Packet = &MkdirPacket{} + +func TestMkdirPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + perms FileMode = 0x87654321 + ) + + p := &MkdirPacket{ + Path: "/foo", + Attrs: Attributes{ + Flags: AttrPermissions, + Permissions: perms, + }, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 21, + 14, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + 0x00, 0x00, 0x00, 0x04, + 0x87, 0x65, 0x43, 0x21, + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = MkdirPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } + + if p.Attrs.Flags != AttrPermissions { + t.Errorf("UnmarshalPacketBody(): Attrs.Flags was %#x, but expected %#x", p.Attrs.Flags, AttrPermissions) + } + + if p.Attrs.Permissions != perms { + t.Errorf("UnmarshalPacketBody(): Attrs.Permissions was %#v, but expected %#v", p.Attrs.Permissions, perms) + } +} + +var _ Packet = &RmdirPacket{} + +func TestRmdirPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + p := &RmdirPacket{ + Path: path, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 13, + 15, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = RmdirPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } +} + +var _ Packet = &RealPathPacket{} + +func TestRealPathPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + p := &RealPathPacket{ + Path: path, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 13, + 16, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = RealPathPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } +} + +var _ Packet = &StatPacket{} + +func TestStatPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + p := &StatPacket{ + Path: path, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 13, + 17, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = StatPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } +} + +var _ Packet = &RenamePacket{} + +func TestRenamePacket(t *testing.T) { + const ( + id = 42 + oldpath = "/foo" + newpath = "/bar" + ) + + p := &RenamePacket{ + OldPath: oldpath, + NewPath: newpath, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 21, + 18, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = RenamePacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.OldPath != oldpath { + t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", p.OldPath, oldpath) + } + + if p.NewPath != newpath { + t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", p.NewPath, newpath) + } +} + +var _ Packet = &ReadLinkPacket{} + +func TestReadLinkPacket(t *testing.T) { + const ( + id = 42 + path = "/foo" + ) + + p := &ReadLinkPacket{ + Path: path, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 13, + 19, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = ReadLinkPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. + if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.Path != path { + t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", p.Path, path) + } +} + +var _ Packet = &SymlinkPacket{} + +func TestSymlinkPacket(t *testing.T) { + const ( + id = 42 + linkpath = "/foo" + targetpath = "/bar" + ) + + p := &SymlinkPacket{ + LinkPath: linkpath, + TargetPath: targetpath, + } + + buf, err := ComposePacket(p.MarshalPacket(id, nil)) + if err != nil { + t.Fatal("unexpected error:", err) + } + + want := []byte{ + 0x00, 0x00, 0x00, 21, + 20, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r', // Arguments were inadvertently reversed. + 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o', + } + + if !bytes.Equal(buf, want) { + t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want) + } + + *p = SymlinkPacket{} + + // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed. 
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil { + t.Fatal("unexpected error:", err) + } + + if p.LinkPath != linkpath { + t.Errorf("UnmarshalPacketBody(): LinkPath was %q, but expected %q", p.LinkPath, linkpath) + } + + if p.TargetPath != targetpath { + t.Errorf("UnmarshalPacketBody(): TargetPath was %q, but expected %q", p.TargetPath, targetpath) + } +} diff --git a/internal/encoding/ssh/filexfer/permissions.go b/internal/encoding/ssh/filexfer/permissions.go new file mode 100644 index 00000000..0143ec0c --- /dev/null +++ b/internal/encoding/ssh/filexfer/permissions.go @@ -0,0 +1,114 @@ +package sshfx + +// FileMode represents a file’s mode and permission bits. +// The bits are defined according to POSIX standards, +// and may not apply to the OS being built for. +type FileMode uint32 + +// Permission flags, defined here to avoid potential inconsistencies in individual OS implementations. +const ( + ModePerm FileMode = 0o0777 // S_IRWXU | S_IRWXG | S_IRWXO + ModeUserRead FileMode = 0o0400 // S_IRUSR + ModeUserWrite FileMode = 0o0200 // S_IWUSR + ModeUserExec FileMode = 0o0100 // S_IXUSR + ModeGroupRead FileMode = 0o0040 // S_IRGRP + ModeGroupWrite FileMode = 0o0020 // S_IWGRP + ModeGroupExec FileMode = 0o0010 // S_IXGRP + ModeOtherRead FileMode = 0o0004 // S_IROTH + ModeOtherWrite FileMode = 0o0002 // S_IWOTH + ModeOtherExec FileMode = 0o0001 // S_IXOTH + + ModeSetUID FileMode = 0o4000 // S_ISUID + ModeSetGID FileMode = 0o2000 // S_ISGID + ModeSticky FileMode = 0o1000 // S_ISVTX + + ModeType FileMode = 0xF000 // S_IFMT + ModeNamedPipe FileMode = 0x1000 // S_IFIFO + ModeCharDevice FileMode = 0x2000 // S_IFCHR + ModeDir FileMode = 0x4000 // S_IFDIR + ModeDevice FileMode = 0x6000 // S_IFBLK + ModeRegular FileMode = 0x8000 // S_IFREG + ModeSymlink FileMode = 0xA000 // S_IFLNK + ModeSocket FileMode = 0xC000 // S_IFSOCK +) + +// IsDir reports whether m describes a directory. +// That is, it tests for m.Type() == ModeDir. 
+func (m FileMode) IsDir() bool { + return (m & ModeType) == ModeDir +} + +// IsRegular reports whether m describes a regular file. +// That is, it tests for m.Type() == ModeRegular +func (m FileMode) IsRegular() bool { + return (m & ModeType) == ModeRegular +} + +// Perm returns the POSIX permission bits in m (m & ModePerm). +func (m FileMode) Perm() FileMode { + return (m & ModePerm) +} + +// Type returns the type bits in m (m & ModeType). +func (m FileMode) Type() FileMode { + return (m & ModeType) +} + +// String returns a `-rwxrwxrwx` style string representing the `ls -l` POSIX permissions string. +func (m FileMode) String() string { + var buf [10]byte + + switch m.Type() { + case ModeRegular: + buf[0] = '-' + case ModeDir: + buf[0] = 'd' + case ModeSymlink: + buf[0] = 'l' + case ModeDevice: + buf[0] = 'b' + case ModeCharDevice: + buf[0] = 'c' + case ModeNamedPipe: + buf[0] = 'p' + case ModeSocket: + buf[0] = 's' + default: + buf[0] = '?' + } + + const rwx = "rwxrwxrwx" + for i, c := range rwx { + if m&(1<>24), byte(v>>16), byte(v>>8), byte(v)) } @@ -19,8 +37,68 @@ func marshalString(b []byte, v string) []byte { return append(marshalUint32(b, uint32(len(v))), v...) } +func marshalFileInfo(b []byte, fi os.FileInfo) []byte { + // attributes variable struct, and also variable per protocol version + // spec version 3 attributes: + // uint32 flags + // uint64 size present only if flag SSH_FILEXFER_ATTR_SIZE + // uint32 uid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 gid present only if flag SSH_FILEXFER_ATTR_UIDGID + // uint32 permissions present only if flag SSH_FILEXFER_ATTR_PERMISSIONS + // uint32 atime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 mtime present only if flag SSH_FILEXFER_ACMODTIME + // uint32 extended_count present only if flag SSH_FILEXFER_ATTR_EXTENDED + // string extended_type + // string extended_data + // ... 
more extended data (extended_type - extended_data pairs), + // so that number of pairs equals extended_count + + flags, fileStat := fileStatFromInfo(fi) + + b = marshalUint32(b, flags) + + return marshalFileStat(b, flags, fileStat) +} + +func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte { + if flags&sshFileXferAttrSize != 0 { + b = marshalUint64(b, fileStat.Size) + } + if flags&sshFileXferAttrUIDGID != 0 { + b = marshalUint32(b, fileStat.UID) + b = marshalUint32(b, fileStat.GID) + } + if flags&sshFileXferAttrPermissions != 0 { + b = marshalUint32(b, fileStat.Mode) + } + if flags&sshFileXferAttrACmodTime != 0 { + b = marshalUint32(b, fileStat.Atime) + b = marshalUint32(b, fileStat.Mtime) + } + + if flags&sshFileXferAttrExtended != 0 { + b = marshalUint32(b, uint32(len(fileStat.Extended))) + + for _, attr := range fileStat.Extended { + b = marshalString(b, attr.ExtType) + b = marshalString(b, attr.ExtData) + } + } + + return b +} + +func marshalStatus(b []byte, err StatusError) []byte { + b = marshalUint32(b, err.Code) + b = marshalString(b, err.msg) + b = marshalString(b, err.lang) + return b +} + func marshal(b []byte, v interface{}) []byte { switch v := v.(type) { + case nil: + return b case uint8: return append(b, v) case uint32: @@ -29,16 +107,20 @@ func marshal(b []byte, v interface{}) []byte { return marshalUint64(b, v) case string: return marshalString(b, v) + case []byte: + return append(b, v...) 
+ case os.FileInfo: + return marshalFileInfo(b, v) default: switch d := reflect.ValueOf(v); d.Kind() { case reflect.Struct: for i, n := 0, d.NumField(); i < n; i++ { - b = append(marshal(b, d.Field(i).Interface())) + b = marshal(b, d.Field(i).Interface()) } return b case reflect.Slice: for i, n := 0, d.Len(); i < n; i++ { - b = append(marshal(b, d.Index(i).Interface())) + b = marshal(b, d.Index(i).Interface()) } return b default: @@ -52,46 +134,250 @@ func unmarshalUint32(b []byte) (uint32, []byte) { return v, b[4:] } +func unmarshalUint32Safe(b []byte) (uint32, []byte, error) { + var v uint32 + if len(b) < 4 { + return 0, nil, errShortPacket + } + v, b = unmarshalUint32(b) + return v, b, nil +} + func unmarshalUint64(b []byte) (uint64, []byte) { h, b := unmarshalUint32(b) l, b := unmarshalUint32(b) return uint64(h)<<32 | uint64(l), b } +func unmarshalUint64Safe(b []byte) (uint64, []byte, error) { + var v uint64 + if len(b) < 8 { + return 0, nil, errShortPacket + } + v, b = unmarshalUint64(b) + return v, b, nil +} + func unmarshalString(b []byte) (string, []byte) { n, b := unmarshalUint32(b) return string(b[:n]), b[n:] } -// sendPacket marshals p according to RFC 4234. 
+func unmarshalStringSafe(b []byte) (string, []byte, error) { + n, b, err := unmarshalUint32Safe(b) + if err != nil { + return "", nil, err + } + if int64(n) > int64(len(b)) { + return "", nil, errShortPacket + } + return string(b[:n]), b[n:], nil +} -func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { - bb, err := m.MarshalBinary() +func unmarshalAttrs(b []byte) (*FileStat, []byte, error) { + flags, b, err := unmarshalUint32Safe(b) if err != nil { - return fmt.Errorf("marshal2(%#v): binary marshaller failed", err) + return nil, b, err + } + return unmarshalFileStat(flags, b) +} + +func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) { + var fs FileStat + var err error + + if flags&sshFileXferAttrSize == sshFileXferAttrSize { + fs.Size, b, err = unmarshalUint64Safe(b) + if err != nil { + return nil, b, err + } + } + if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { + fs.UID, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + fs.GID, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + } + if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { + fs.Mode, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + } + if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime { + fs.Atime, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + fs.Mtime, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } } - l := uint32(len(bb)) - hdr := []byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)} - debug("send packet %T, len: %v", m, l) - _, err = w.Write(hdr) + if flags&sshFileXferAttrExtended == sshFileXferAttrExtended { + var count uint32 + count, b, err = unmarshalUint32Safe(b) + if err != nil { + return nil, b, err + } + + ext := make([]StatExtended, count) + for i := uint32(0); i < count; i++ { + var typ string + var data string + typ, b, err = unmarshalStringSafe(b) + if err != nil { + return 
nil, b, err + } + data, b, err = unmarshalStringSafe(b) + if err != nil { + return nil, b, err + } + ext[i] = StatExtended{ + ExtType: typ, + ExtData: data, + } + } + fs.Extended = ext + } + return &fs, b, nil +} + +func unmarshalStatus(id uint32, data []byte) error { + sid, data := unmarshalUint32(data) + if sid != id { + return &unexpectedIDErr{id, sid} + } + code, data := unmarshalUint32(data) + msg, data, _ := unmarshalStringSafe(data) + lang, _, _ := unmarshalStringSafe(data) + return &StatusError{ + Code: code, + msg: msg, + lang: lang, + } +} + +type packetMarshaler interface { + marshalPacket() (header, payload []byte, err error) +} + +func marshalPacket(m encoding.BinaryMarshaler) (header, payload []byte, err error) { + if m, ok := m.(packetMarshaler); ok { + return m.marshalPacket() + } + + header, err = m.MarshalBinary() + return +} + +// sendPacket marshals p according to RFC 4234. +func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { + header, payload, err := marshalPacket(m) if err != nil { - return err + return fmt.Errorf("binary marshaller failed: %w", err) } - _, err = w.Write(bb) - return err + + length := len(header) + len(payload) - 4 // subtract the uint32(length) from the start + if debugDumpTxPacketBytes { + debug("send packet: %s %d bytes %x%x", fxp(header[4]), length, header[5:], payload) + } else if debugDumpTxPacket { + debug("send packet: %s %d bytes", fxp(header[4]), length) + } + + binary.BigEndian.PutUint32(header[:4], uint32(length)) + + if _, err := w.Write(header); err != nil { + return fmt.Errorf("failed to send packet: %w", err) + } + + if len(payload) > 0 { + if _, err := w.Write(payload); err != nil { + return fmt.Errorf("failed to send packet payload: %w", err) + } + } + + return nil } -func recvPacket(r io.Reader) (uint8, []byte, error) { - var b = []byte{0, 0, 0, 0} - if _, err := io.ReadFull(r, b); err != nil { - return 0, nil, err +func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (fxp, []byte, 
error) { + var b []byte + if alloc != nil { + b = alloc.GetPage(orderID) + } else { + b = make([]byte, 4) } - l, _ := unmarshalUint32(b) - b = make([]byte, l) - if _, err := io.ReadFull(r, b); err != nil { - return 0, nil, err + + if n, err := io.ReadFull(r, b[:4]); err != nil { + if err == io.EOF { + return 0, nil, err + } + + return 0, nil, fmt.Errorf("error reading packet length: %d of 4: %w", n, err) } - return b[0], b[1:], nil + + length, _ := unmarshalUint32(b) + if length > maxMsgLength { + debug("recv packet %d bytes too long", length) + return 0, nil, errLongPacket + } + if length == 0 { + debug("recv packet of 0 bytes too short") + return 0, nil, errShortPacket + } + + if alloc == nil { + b = make([]byte, length) + } + + n, err := io.ReadFull(r, b[:length]) + b = b[:n] + + if err != nil { + debug("recv packet error: %d of %d bytes: %x", n, length, b) + + // ReadFull only returns EOF if it has read no bytes. + // In this case, that means a partial packet, and thus unexpected. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + if n == 0 { + return 0, nil, fmt.Errorf("error reading packet body: %d of %d: %w", n, length, err) + } + + return 0, nil, fmt.Errorf("error reading packet body: %d of %d: (%s) %w", n, length, fxp(b[0]), err) + } + + typ, payload := fxp(b[0]), b[1:n] + + if debugDumpRxPacketBytes { + debug("recv packet: %s %d bytes %x", typ, length, payload) + } else if debugDumpRxPacket { + debug("recv packet: %s %d bytes", typ, length) + } + + return typ, payload, nil +} + +type extensionPair struct { + Name string + Data string +} + +func unmarshalExtensionPair(b []byte) (extensionPair, []byte, error) { + var ep extensionPair + var err error + ep.Name, b, err = unmarshalStringSafe(b) + if err != nil { + return ep, b, err + } + ep.Data, b, err = unmarshalStringSafe(b) + return ep, b, err } // Here starts the definition of packets along with their MarshalBinary @@ -101,231 +387,1036 @@ func recvPacket(r io.Reader) (uint8, []byte, error) { type 
sshFxInitPacket struct { Version uint32 - Extensions []struct { - Name, Data string + Extensions []extensionPair +} + +func (p *sshFxInitPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version) + for _, e := range p.Extensions { + l += 4 + len(e.Name) + 4 + len(e.Data) + } + + b := make([]byte, 4, l) + b = append(b, sshFxpInit) + b = marshalUint32(b, p.Version) + + for _, e := range p.Extensions { + b = marshalString(b, e.Name) + b = marshalString(b, e.Data) } + + return b, nil } -func (p sshFxInitPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 // byte + uint32 +func (p *sshFxInitPacket) UnmarshalBinary(b []byte) error { + var err error + if p.Version, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + for len(b) > 0 { + var ep extensionPair + ep, b, err = unmarshalExtensionPair(b) + if err != nil { + return err + } + p.Extensions = append(p.Extensions, ep) + } + return nil +} + +type sshFxVersionPacket struct { + Version uint32 + Extensions []sshExtensionPair +} + +type sshExtensionPair struct { + Name, Data string +} + +func (p *sshFxVersionPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(version) for _, e := range p.Extensions { l += 4 + len(e.Name) + 4 + len(e.Data) } - b := make([]byte, 0, l) - b = append(b, ssh_FXP_INIT) + b := make([]byte, 4, l) + b = append(b, sshFxpVersion) b = marshalUint32(b, p.Version) + for _, e := range p.Extensions { b = marshalString(b, e.Name) b = marshalString(b, e.Data) } + return b, nil } -func marshalIdString(packetType byte, id uint32, str string) ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func marshalIDStringPacket(packetType byte, id uint32, str string) ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(str) - b := make([]byte, 0, l) + b := make([]byte, 4, l) b = append(b, packetType) b = marshalUint32(b, id) b = marshalString(b, str) + return b, nil } 
+func unmarshalIDString(b []byte, id *uint32, str *string) error { + var err error + *id, b, err = unmarshalUint32Safe(b) + if err != nil { + return err + } + *str, _, err = unmarshalStringSafe(b) + return err +} + type sshFxpReaddirPacket struct { - Id uint32 + ID uint32 Handle string } -func (p sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_READDIR, p.Id, p.Handle) +func (p *sshFxpReaddirPacket) id() uint32 { return p.ID } + +func (p *sshFxpReaddirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpReaddir, p.ID, p.Handle) +} + +func (p *sshFxpReaddirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) } type sshFxpOpendirPacket struct { - Id uint32 + ID uint32 Path string } -func (p sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_OPENDIR, p.Id, p.Path) +func (p *sshFxpOpendirPacket) id() uint32 { return p.ID } + +func (p *sshFxpOpendirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpOpendir, p.ID, p.Path) +} + +func (p *sshFxpOpendirPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) } type sshFxpLstatPacket struct { - Id uint32 + ID uint32 Path string } -func (p sshFxpLstatPacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_LSTAT, p.Id, p.Path) +func (p *sshFxpLstatPacket) id() uint32 { return p.ID } + +func (p *sshFxpLstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpLstat, p.ID, p.Path) +} + +func (p *sshFxpLstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpStatPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpStatPacket) id() uint32 { return p.ID } + +func (p *sshFxpStatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpStat, p.ID, p.Path) +} + +func (p *sshFxpStatPacket) UnmarshalBinary(b []byte) 
error { + return unmarshalIDString(b, &p.ID, &p.Path) } type sshFxpFstatPacket struct { - Id uint32 + ID uint32 Handle string } -func (p sshFxpFstatPacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_FSTAT, p.Id, p.Handle) +func (p *sshFxpFstatPacket) id() uint32 { return p.ID } + +func (p *sshFxpFstatPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpFstat, p.ID, p.Handle) +} + +func (p *sshFxpFstatPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) } type sshFxpClosePacket struct { - Id uint32 + ID uint32 Handle string } -func (p sshFxpClosePacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_CLOSE, p.Id, p.Handle) +func (p *sshFxpClosePacket) id() uint32 { return p.ID } + +func (p *sshFxpClosePacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpClose, p.ID, p.Handle) +} + +func (p *sshFxpClosePacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Handle) } type sshFxpRemovePacket struct { - Id uint32 + ID uint32 Filename string } -func (p sshFxpRemovePacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_REMOVE, p.Id, p.Filename) +func (p *sshFxpRemovePacket) id() uint32 { return p.ID } + +func (p *sshFxpRemovePacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRemove, p.ID, p.Filename) +} + +func (p *sshFxpRemovePacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Filename) } type sshFxpRmdirPacket struct { - Id uint32 + ID uint32 Path string } -func (p sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_RMDIR, p.Id, p.Path) +func (p *sshFxpRmdirPacket) id() uint32 { return p.ID } + +func (p *sshFxpRmdirPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRmdir, p.ID, p.Path) +} + +func (p *sshFxpRmdirPacket) UnmarshalBinary(b []byte) error { + return 
unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpSymlinkPacket struct { + ID uint32 + + // The order of the arguments to the SSH_FXP_SYMLINK method was inadvertently reversed. + // Unfortunately, the reversal was not noticed until the server was widely deployed. + // Covered in Section 4.1 of https://github.com/openssh/openssh-portable/blob/master/PROTOCOL + + Targetpath string + Linkpath string +} + +func (p *sshFxpSymlinkPacket) id() uint32 { return p.ID } + +func (p *sshFxpSymlinkPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Targetpath) + + 4 + len(p.Linkpath) + + b := make([]byte, 4, l) + b = append(b, sshFxpSymlink) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Targetpath) + b = marshalString(b, p.Linkpath) + + return b, nil +} + +func (p *sshFxpSymlinkPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Targetpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Linkpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpHardlinkPacket struct { + ID uint32 + Oldpath string + Newpath string +} + +func (p *sshFxpHardlinkPacket) id() uint32 { return p.ID } + +func (p *sshFxpHardlinkPacket) MarshalBinary() ([]byte, error) { + const ext = "hardlink@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Oldpath) + + 4 + len(p.Newpath) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Oldpath) + b = marshalString(b, p.Newpath) + + return b, nil } type sshFxpReadlinkPacket struct { - Id uint32 + ID uint32 Path string } -func (p sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { - return marshalIdString(ssh_FXP_READLINK, p.Id, p.Path) +func (p *sshFxpReadlinkPacket) id() uint32 { return 
p.ID } + +func (p *sshFxpReadlinkPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpReadlink, p.ID, p.Path) +} + +func (p *sshFxpReadlinkPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpRealpathPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpRealpathPacket) id() uint32 { return p.ID } + +func (p *sshFxpRealpathPacket) MarshalBinary() ([]byte, error) { + return marshalIDStringPacket(sshFxpRealpath, p.ID, p.Path) +} + +func (p *sshFxpRealpathPacket) UnmarshalBinary(b []byte) error { + return unmarshalIDString(b, &p.ID, &p.Path) +} + +type sshFxpNameAttr struct { + Name string + LongName string + Attrs []interface{} +} + +func (p *sshFxpNameAttr) MarshalBinary() ([]byte, error) { + var b []byte + b = marshalString(b, p.Name) + b = marshalString(b, p.LongName) + for _, attr := range p.Attrs { + b = marshal(b, attr) + } + return b, nil +} + +type sshFxpNamePacket struct { + ID uint32 + NameAttrs []*sshFxpNameAttr +} + +func (p *sshFxpNamePacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpName) + b = marshalUint32(b, p.ID) + b = marshalUint32(b, uint32(len(p.NameAttrs))) + + var payload []byte + for _, na := range p.NameAttrs { + ab, err := na.MarshalBinary() + if err != nil { + return nil, nil, err + } + + payload = append(payload, ab...) 
+ } + + return b, payload, nil +} + +func (p *sshFxpNamePacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err } type sshFxpOpenPacket struct { - Id uint32 + ID uint32 Path string Pflags uint32 - Flags uint32 // ignored + Flags uint32 + Attrs interface{} } -func (p sshFxpOpenPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + +func (p *sshFxpOpenPacket) id() uint32 { return p.ID } + +func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 + 4 - b := make([]byte, 0, l) - b = append(b, ssh_FXP_OPEN) - b = marshalUint32(b, p.Id) + b := make([]byte, 4, l) + b = append(b, sshFxpOpen) + b = marshalUint32(b, p.ID) b = marshalString(b, p.Path) b = marshalUint32(b, p.Pflags) b = marshalUint32(b, p.Flags) - return b, nil + + switch attrs := p.Attrs.(type) { + case []byte: + return b, attrs, nil // may as well short-circuit this case. + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. 
+ return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } + + return b, marshal(nil, p.Attrs), nil +} + +func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs, nil + case []byte: + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err + default: + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) + } } type sshFxpReadPacket struct { - Id uint32 - Handle string - Offset uint64 + ID uint32 Len uint32 + Offset uint64 + Handle string } -func (p sshFxpReadPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpReadPacket) id() uint32 { return p.ID } + +func (p *sshFxpReadPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Handle) + 8 + 4 // uint64 + uint32 - b := make([]byte, 0, l) - b = append(b, ssh_FXP_READ) - b = marshalUint32(b, p.Id) + b := make([]byte, 4, l) + b = append(b, sshFxpRead) + b = marshalUint32(b, p.ID) b = marshalString(b, p.Handle) b = marshalUint64(b, p.Offset) b = marshalUint32(b, p.Len) + return b, nil } +func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = 
unmarshalStringSafe(b); err != nil { + return err + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return err + } else if p.Len, _, err = unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + +// We need to allocate bigger slices with extra capacity to avoid a re-allocation in sshFxpDataPacket.MarshalBinary +// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length) +const dataHeaderLen = 4 + 1 + 4 + 4 + +func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte { + dataLen := p.Len + if dataLen > maxTxPacket { + dataLen = maxTxPacket + } + + if alloc != nil { + // GetPage returns a slice with capacity = maxMsgLength; this is enough to avoid new allocations in + // sshFxpDataPacket.MarshalBinary + return alloc.GetPage(orderID)[:dataLen] + } + + // allocate with extra space for the header + return make([]byte, dataLen, dataLen+dataHeaderLen) +} + type sshFxpRenamePacket struct { - Id uint32 + ID uint32 Oldpath string Newpath string } -func (p sshFxpRenamePacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpRenamePacket) id() uint32 { return p.ID } + +func (p *sshFxpRenamePacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Oldpath) + 4 + len(p.Newpath) - b := make([]byte, 0, l) - b = append(b, ssh_FXP_RENAME) - b = marshalUint32(b, p.Id) + b := make([]byte, 4, l) + b = append(b, sshFxpRename) + b = marshalUint32(b, p.ID) b = marshalString(b, p.Oldpath) b = marshalString(b, p.Newpath) + + return b, nil +} + +func (p *sshFxpRenamePacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpPosixRenamePacket 
struct { + ID uint32 + Oldpath string + Newpath string +} + +func (p *sshFxpPosixRenamePacket) id() uint32 { return p.ID } + +func (p *sshFxpPosixRenamePacket) MarshalBinary() ([]byte, error) { + const ext = "posix-rename@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Oldpath) + + 4 + len(p.Newpath) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Oldpath) + b = marshalString(b, p.Newpath) + return b, nil } type sshFxpWritePacket struct { - Id uint32 - Handle string - Offset uint64 + ID uint32 Length uint32 + Offset uint64 + Handle string Data []byte } -func (s sshFxpWritePacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 - 4 + len(s.Handle) + - 8 + 4 + // uint64 + uint32 - len(s.Data) +func (p *sshFxpWritePacket) id() uint32 { return p.ID } - b := make([]byte, 0, l) - b = append(b, ssh_FXP_WRITE) - b = marshalUint32(b, s.Id) - b = marshalString(b, s.Handle) - b = marshalUint64(b, s.Offset) - b = marshalUint32(b, s.Length) - b = append(b, s.Data...) 
- return b, nil +func (p *sshFxpWritePacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Handle) + + 8 + // uint64 + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpWrite) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint64(b, p.Offset) + b = marshalUint32(b, p.Length) + + return b, p.Data, nil +} + +func (p *sshFxpWritePacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpWritePacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Offset, b, err = unmarshalUint64Safe(b); err != nil { + return err + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if uint32(len(b)) < p.Length { + return errShortPacket + } + + p.Data = b[:p.Length] + return nil } type sshFxpMkdirPacket struct { - Id uint32 - Path string + ID uint32 Flags uint32 // ignored + Path string } -func (p sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +func (p *sshFxpMkdirPacket) id() uint32 { return p.ID } + +func (p *sshFxpMkdirPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + 4 // uint32 - b := make([]byte, 0, l) - b = append(b, ssh_FXP_MKDIR) - b = marshalUint32(b, p.Id) + b := make([]byte, 4, l) + b = append(b, sshFxpMkdir) + b = marshalUint32(b, p.ID) b = marshalString(b, p.Path) b = marshalUint32(b, p.Flags) + return b, nil } +func (p *sshFxpMkdirPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, _, err = 
unmarshalUint32Safe(b); err != nil { + return err + } + return nil +} + type sshFxpSetstatPacket struct { - Id uint32 - Path string + ID uint32 Flags uint32 + Path string Attrs interface{} } -func (p sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { - l := 1 + 4 + // type(byte) + uint32 +type sshFxpFsetstatPacket struct { + ID uint32 + Flags uint32 + Handle string + Attrs interface{} +} + +func (p *sshFxpSetstatPacket) id() uint32 { return p.ID } +func (p *sshFxpFsetstatPacket) id() uint32 { return p.ID } + +func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) 4 + len(p.Path) + - 4 // uint32 + uint64 + 4 // uint32 - b := make([]byte, 0, l) - b = append(b, ssh_FXP_SETSTAT) - b = marshalUint32(b, p.Id) + b := make([]byte, 4, l) + b = append(b, sshFxpSetstat) + b = marshalUint32(b, p.ID) b = marshalString(b, p.Path) b = marshalUint32(b, p.Flags) - b = marshal(b, p.Attrs) + + switch attrs := p.Attrs.(type) { + case []byte: + return b, attrs, nil // may as well short-circuit this case. + case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } + + return b, marshal(nil, p.Attrs), nil +} + +func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Handle) + + 4 // uint32 + + b := make([]byte, 4, l) + b = append(b, sshFxpFsetstat) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + b = marshalUint32(b, p.Flags) + + switch attrs := p.Attrs.(type) { + case []byte: + return b, attrs, nil // may as well short-circuit this case. 
+ case os.FileInfo: + _, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet. + return b, marshalFileStat(nil, p.Flags, fs), nil + case *FileStat: + return b, marshalFileStat(nil, p.Flags, attrs), nil + } + + return b, marshal(nil, p.Attrs), nil +} + +func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Path, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs, nil + case []byte: + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err + default: + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) + } +} + +func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Handle, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil { + return err + } + p.Attrs = b + return nil +} + +func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) { + switch attrs := p.Attrs.(type) { + case *FileStat: + return attrs, nil + case []byte: + fs, _, err := unmarshalFileStat(flags, attrs) + return fs, err + default: + return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs) + } +} + +type sshFxpHandlePacket struct { + ID uint32 + Handle string +} + +func (p *sshFxpHandlePacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 
4 + len(p.Handle) + + b := make([]byte, 4, l) + b = append(b, sshFxpHandle) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Handle) + + return b, nil +} + +type sshFxpStatusPacket struct { + ID uint32 + StatusError +} + +func (p *sshFxpStatusPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + 4 + len(p.StatusError.msg) + + 4 + len(p.StatusError.lang) + + b := make([]byte, 4, l) + b = append(b, sshFxpStatus) + b = marshalUint32(b, p.ID) + b = marshalStatus(b, p.StatusError) + + return b, nil +} + +type sshFxpDataPacket struct { + ID uint32 + Length uint32 + Data []byte +} + +func (p *sshFxpDataPacket) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + + b := make([]byte, 4, l) + b = append(b, sshFxpData) + b = marshalUint32(b, p.ID) + b = marshalUint32(b, p.Length) + + return b, p.Data, nil +} + +// MarshalBinary encodes the receiver into a binary form and returns the result. +// To avoid a new allocation the Data slice must have a capacity >= Length + dataHeaderLen +// +// This is hand-coded rather than just append(header, payload...), +// in order to try and reuse the r.Data backing store in the packet. +func (p *sshFxpDataPacket) MarshalBinary() ([]byte, error) { + b := append(p.Data, make([]byte, dataHeaderLen)...) 
+ copy(b[dataHeaderLen:], p.Data[:p.Length]) + // b[0:4] will be overwritten with the length in sendPacket + b[4] = sshFxpData + binary.BigEndian.PutUint32(b[5:9], p.ID) + binary.BigEndian.PutUint32(b[9:13], p.Length) + return b, nil +} + +func (p *sshFxpDataPacket) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.Length, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if uint32(len(b)) < p.Length { + return errShortPacket + } + + p.Data = b[:p.Length] + return nil +} + +type sshFxpStatvfsPacket struct { + ID uint32 + Path string +} + +func (p *sshFxpStatvfsPacket) id() uint32 { return p.ID } + +func (p *sshFxpStatvfsPacket) MarshalBinary() ([]byte, error) { + const ext = "statvfs@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Path) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Path) + + return b, nil +} + +// A StatVFS contains statistics about a filesystem. +type StatVFS struct { + ID uint32 + Bsize uint64 /* file system block size */ + Frsize uint64 /* fundamental fs block size */ + Blocks uint64 /* number of blocks (unit f_frsize) */ + Bfree uint64 /* free blocks in file system */ + Bavail uint64 /* free blocks for non-root */ + Files uint64 /* total file inodes */ + Ffree uint64 /* free file inodes */ + Favail uint64 /* free file inodes for non-root */ + Fsid uint64 /* file system id */ + Flag uint64 /* bit mask of f_flag values */ + Namemax uint64 /* maximum filename length */ +} + +// TotalSpace calculates the amount of total space in a filesystem. +func (p *StatVFS) TotalSpace() uint64 { + return p.Frsize * p.Blocks +} + +// FreeSpace calculates the amount of free space in a filesystem. 
+func (p *StatVFS) FreeSpace() uint64 { + return p.Frsize * p.Bfree +} + +// marshalPacket converts to ssh_FXP_EXTENDED_REPLY packet binary format +func (p *StatVFS) marshalPacket() ([]byte, []byte, error) { + header := []byte{0, 0, 0, 0, sshFxpExtendedReply} + + var buf bytes.Buffer + err := binary.Write(&buf, binary.BigEndian, p) + + return header, buf.Bytes(), err +} + +// MarshalBinary encodes the StatVFS as an SSH_FXP_EXTENDED_REPLY packet. +func (p *StatVFS) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +type sshFxpFsyncPacket struct { + ID uint32 + Handle string +} + +func (p *sshFxpFsyncPacket) id() uint32 { return p.ID } + +func (p *sshFxpFsyncPacket) MarshalBinary() ([]byte, error) { + const ext = "fsync@openssh.com" + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(ext) + + 4 + len(p.Handle) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, ext) + b = marshalString(b, p.Handle) + return b, nil } + +type sshFxpExtendedPacket struct { + ID uint32 + ExtendedRequest string + SpecificPacket interface { + serverRespondablePacket + readonly() bool + } +} + +func (p *sshFxpExtendedPacket) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacket) readonly() bool { + if p.SpecificPacket == nil { + return true + } + return p.SpecificPacket.readonly() +} + +func (p *sshFxpExtendedPacket) respond(svr *Server) responsePacket { + if p.SpecificPacket == nil { + return statusFromError(p.ID, nil) + } + return p.SpecificPacket.respond(svr) +} + +func (p *sshFxpExtendedPacket) UnmarshalBinary(b []byte) error { + var err error + bOrig := b + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, _, err = unmarshalStringSafe(b); err != nil { + return err + } + + // specific unmarshalling + switch p.ExtendedRequest { + case "statvfs@openssh.com": + p.SpecificPacket = 
&sshFxpExtendedPacketStatVFS{} + case "posix-rename@openssh.com": + p.SpecificPacket = &sshFxpExtendedPacketPosixRename{} + case "hardlink@openssh.com": + p.SpecificPacket = &sshFxpExtendedPacketHardlink{} + default: + return fmt.Errorf("packet type %v: %w", p.SpecificPacket, errUnknownExtendedPacket) + } + + return p.SpecificPacket.UnmarshalBinary(bOrig) +} + +type sshFxpExtendedPacketStatVFS struct { + ID uint32 + ExtendedRequest string + Path string +} + +func (p *sshFxpExtendedPacketStatVFS) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketStatVFS) readonly() bool { return true } +func (p *sshFxpExtendedPacketStatVFS) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Path, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +type sshFxpExtendedPacketPosixRename struct { + ID uint32 + ExtendedRequest string + Oldpath string + Newpath string +} + +func (p *sshFxpExtendedPacketPosixRename) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketPosixRename) readonly() bool { return false } +func (p *sshFxpExtendedPacketPosixRename) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +func (p *sshFxpExtendedPacketPosixRename) respond(s *Server) responsePacket { + err := os.Rename(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath)) + return statusFromError(p.ID, err) +} + +type sshFxpExtendedPacketHardlink struct { + ID uint32 + ExtendedRequest string + Oldpath string + Newpath string +} + +// 
https://github.com/openssh/openssh-portable/blob/master/PROTOCOL +func (p *sshFxpExtendedPacketHardlink) id() uint32 { return p.ID } +func (p *sshFxpExtendedPacketHardlink) readonly() bool { return true } +func (p *sshFxpExtendedPacketHardlink) UnmarshalBinary(b []byte) error { + var err error + if p.ID, b, err = unmarshalUint32Safe(b); err != nil { + return err + } else if p.ExtendedRequest, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Oldpath, b, err = unmarshalStringSafe(b); err != nil { + return err + } else if p.Newpath, _, err = unmarshalStringSafe(b); err != nil { + return err + } + return nil +} + +func (p *sshFxpExtendedPacketHardlink) respond(s *Server) responsePacket { + err := os.Link(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath)) + return statusFromError(p.ID, err) +} diff --git a/packet_test.go b/packet_test.go index 80a1ebf0..59a1b21c 100644 --- a/packet_test.go +++ b/packet_test.go @@ -3,120 +3,175 @@ package sftp import ( "bytes" "encoding" + "errors" + "io/ioutil" "os" + "reflect" "testing" ) -var marshalUint32Tests = []struct { - v uint32 - want []byte -}{ - {1, []byte{0, 0, 0, 1}}, - {256, []byte{0, 0, 1, 0}}, - {^uint32(0), []byte{255, 255, 255, 255}}, -} - func TestMarshalUint32(t *testing.T) { - for _, tt := range marshalUint32Tests { + var tests = []struct { + v uint32 + want []byte + }{ + {0, []byte{0, 0, 0, 0}}, + {42, []byte{0, 0, 0, 42}}, + {42 << 8, []byte{0, 0, 42, 0}}, + {42 << 16, []byte{0, 42, 0, 0}}, + {42 << 24, []byte{42, 0, 0, 0}}, + {^uint32(0), []byte{255, 255, 255, 255}}, + } + + for _, tt := range tests { got := marshalUint32(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshalUint32(%d): want %v, got %v", tt.v, tt.want, got) + t.Errorf("marshalUint32(%d) = %#v, want %#v", tt.v, got, tt.want) } } } -var marshalUint64Tests = []struct { - v uint64 - want []byte -}{ - {1, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}}, - {256, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0}}, - 
{^uint64(0), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, - {1 << 32, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, -} - func TestMarshalUint64(t *testing.T) { - for _, tt := range marshalUint64Tests { + var tests = []struct { + v uint64 + want []byte + }{ + {0, []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {42, []byte{0, 0, 0, 0, 0, 0, 0, 42}}, + {42 << 8, []byte{0, 0, 0, 0, 0, 0, 42, 0}}, + {42 << 16, []byte{0, 0, 0, 0, 0, 42, 0, 0}}, + {42 << 24, []byte{0, 0, 0, 0, 42, 0, 0, 0}}, + {42 << 32, []byte{0, 0, 0, 42, 0, 0, 0, 0}}, + {42 << 40, []byte{0, 0, 42, 0, 0, 0, 0, 0}}, + {42 << 48, []byte{0, 42, 0, 0, 0, 0, 0, 0}}, + {42 << 56, []byte{42, 0, 0, 0, 0, 0, 0, 0}}, + {^uint64(0), []byte{255, 255, 255, 255, 255, 255, 255, 255}}, + } + + for _, tt := range tests { got := marshalUint64(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshalUint64(%d): want %#v, got %#v", tt.v, tt.want, got) + t.Errorf("marshalUint64(%d) = %#v, want %#v", tt.v, got, tt.want) } } } -var marshalStringTests = []struct { - v string - want []byte -}{ - {"", []byte{0, 0, 0, 0}}, - {"/foo", []byte{0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f}}, -} - func TestMarshalString(t *testing.T) { - for _, tt := range marshalStringTests { + var tests = []struct { + v string + want []byte + }{ + {"", []byte{0, 0, 0, 0}}, + {"/", []byte{0x0, 0x0, 0x0, 0x01, '/'}}, + {"/foo", []byte{0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o'}}, + {"\x00bar", []byte{0x0, 0x0, 0x0, 0x4, 0, 'b', 'a', 'r'}}, + {"b\x00ar", []byte{0x0, 0x0, 0x0, 0x4, 'b', 0, 'a', 'r'}}, + {"ba\x00r", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 0, 'r'}}, + {"bar\x00", []byte{0x0, 0x0, 0x0, 0x4, 'b', 'a', 'r', 0}}, + } + + for _, tt := range tests { got := marshalString(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshalString(%q): want %#v, got %#v", tt.v, tt.want, got) + t.Errorf("marshalString(%q) = %#v, want %#v", tt.v, got, tt.want) } } } -var marshalTests = []struct { - v interface{} - want []byte -}{ - {uint8(1), []byte{1}}, - 
{byte(1), []byte{1}}, - {uint32(1), []byte{0, 0, 0, 1}}, - {uint64(1), []byte{0, 0, 0, 0, 0, 0, 0, 1}}, - {"foo", []byte{0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f}}, - {[]uint32{1, 2, 3, 4}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x4}}, -} - func TestMarshal(t *testing.T) { - for _, tt := range marshalTests { + type Struct struct { + X, Y, Z uint32 + } + + var tests = []struct { + v interface{} + want []byte + }{ + {uint8(42), []byte{42}}, + {uint32(42 << 8), []byte{0, 0, 42, 0}}, + {uint64(42 << 32), []byte{0, 0, 0, 42, 0, 0, 0, 0}}, + {"foo", []byte{0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o'}}, + {Struct{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}}, + {[]uint32{1, 2, 3}, []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3}}, + } + + for _, tt := range tests { got := marshal(nil, tt.v) if !bytes.Equal(tt.want, got) { - t.Errorf("marshal(%v): want %#v, got %#v", tt.v, tt.want, got) + t.Errorf("marshal(%#v) = %#v, want %#v", tt.v, got, tt.want) } } } -var unmarshalUint32Tests = []struct { - b []byte - want uint32 - rest []byte -}{ - {[]byte{0, 0, 0, 0}, 0, nil}, - {[]byte{0, 0, 1, 0}, 256, nil}, - {[]byte{255, 0, 0, 255}, 4278190335, nil}, -} - func TestUnmarshalUint32(t *testing.T) { - for _, tt := range unmarshalUint32Tests { - got, rest := unmarshalUint32(tt.b) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("unmarshalUint32(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) - } + testBuffer := []byte{ + 0, 0, 0, 0, + 0, 0, 0, 42, + 0, 0, 42, 0, + 0, 42, 0, 0, + 42, 0, 0, 0, + 255, 0, 0, 254, } -} -var unmarshalUint64Tests = []struct { - b []byte - want uint64 - rest []byte -}{ - {[]byte{0, 0, 0, 0, 0, 0, 0, 0}, 0, nil}, - {[]byte{0, 0, 0, 0, 0, 0, 1, 0}, 256, nil}, - {[]byte{255, 0, 0, 0, 0, 0, 0, 255}, 18374686479671623935, nil}, + var wants = []uint32{ + 0, + 42, + 42 << 8, + 42 << 16, + 42 << 24, + 255<<24 | 254, + } + + var i int + for len(testBuffer) 
> 0 { + got, rest := unmarshalUint32(testBuffer) + + if got != wants[i] { + t.Fatalf("unmarshalUint32(%#v) = %d, want %d", testBuffer[:4], got, wants[i]) + } + + i++ + testBuffer = rest + } } func TestUnmarshalUint64(t *testing.T) { - for _, tt := range unmarshalUint64Tests { - got, rest := unmarshalUint64(tt.b) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("unmarshalUint64(%v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) + testBuffer := []byte{ + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 42, + 0, 0, 0, 0, 0, 0, 42, 0, + 0, 0, 0, 0, 0, 42, 0, 0, + 0, 0, 0, 0, 42, 0, 0, 0, + 0, 0, 0, 42, 0, 0, 0, 0, + 0, 0, 42, 0, 0, 0, 0, 0, + 0, 42, 0, 0, 0, 0, 0, 0, + 42, 0, 0, 0, 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 254, + } + + var wants = []uint64{ + 0, + 42, + 42 << 8, + 42 << 16, + 42 << 24, + 42 << 32, + 42 << 40, + 42 << 48, + 42 << 56, + 255<<56 | 254, + } + + var i int + for len(testBuffer) > 0 { + got, rest := unmarshalUint64(testBuffer) + + if got != wants[i] { + t.Fatalf("unmarshalUint64(%#v) = %d, want %d", testBuffer[:8], got, wants[i]) } + + i++ + testBuffer = rest } } @@ -130,132 +185,476 @@ var unmarshalStringTests = []struct { } func TestUnmarshalString(t *testing.T) { - for _, tt := range unmarshalStringTests { - got, rest := unmarshalString(tt.b) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("unmarshalUint64(%v): want %q, %#v, got %q, %#v", tt.b, tt.want, tt.rest, got, rest) + testBuffer := []byte{ + 0, 0, 0, 0, + 0, 0, 0, 1, '/', + 0, 0, 0, 4, '/', 'f', 'o', 'o', + 0, 0, 0, 4, 0, 'b', 'a', 'r', + 0, 0, 0, 4, 'b', 0, 'a', 'r', + 0, 0, 0, 4, 'b', 'a', 0, 'r', + 0, 0, 0, 4, 'b', 'a', 'r', 0, + } + + var wants = []string{ + "", + "/", + "/foo", + "\x00bar", + "b\x00ar", + "ba\x00r", + "bar\x00", + } + + var i int + for len(testBuffer) > 0 { + got, rest := unmarshalString(testBuffer) + + if got != wants[i] { + t.Fatalf("unmarshalUint64(%#v...) 
= %q, want %q", testBuffer[:4], got, wants[i]) } + + i++ + testBuffer = rest } } -var sendPacketTests = []struct { - p encoding.BinaryMarshaler - want []byte -}{ - {sshFxInitPacket{ - Version: 3, - Extensions: []struct{ Name, Data string }{ - {"posix-rename@openssh.com", "1"}, +func TestUnmarshalAttrs(t *testing.T) { + var tests = []struct { + b []byte + want *FileStat + }{ + { + b: []byte{0x00, 0x00, 0x00, 0x00}, + want: &FileStat{}, }, - }, []byte{0x0, 0x0, 0x0, 0x26, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, - - {sshFxpOpenPacket{ - Id: 1, - Path: "/foo", - Pflags: flags(os.O_RDONLY), - }, []byte{0x0, 0x0, 0x0, 0x15, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0}}, - - {sshFxpWritePacket{ - Id: 124, - Handle: "foo", - Offset: 13, - Length: uint32(len([]byte("bar"))), - Data: []byte("bar"), - }, []byte{0x0, 0x0, 0x0, 0x1b, 0x6, 0x0, 0x0, 0x0, 0x7c, 0x0, 0x0, 0x0, 0x3, 0x66, 0x6f, 0x6f, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, 0x0, 0x0, 0x0, 0x3, 0x62, 0x61, 0x72}}, - - {sshFxpSetstatPacket{ - Id: 31, - Path: "/bar", - Flags: flags(os.O_WRONLY), - Attrs: struct { - Uid uint32 - Gid uint32 - }{1000, 100}, - }, []byte{0x0, 0x0, 0x0, 0x19, 0x9, 0x0, 0x0, 0x0, 0x1f, 0x0, 0x0, 0x0, 0x4, 0x2f, 0x62, 0x61, 0x72, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x3, 0xe8, 0x0, 0x0, 0x0, 0x64}}, -} + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + }, + want: &FileStat{ + Size: 20, + }, + }, + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + 0x00, 0x00, 0x01, 0xA4, + }, + want: &FileStat{ + Size: 20, + Mode: 0644, + }, + }, + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | 
sshFileXferAttrPermissions | sshFileXferAttrUIDGID), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + 0x00, 0x00, 0x03, 0xE8, + 0x00, 0x00, 0x03, 0xE9, + 0x00, 0x00, 0x01, 0xA4, + }, + want: &FileStat{ + Size: 20, + Mode: 0644, + UID: 1000, + GID: 1001, + }, + }, + { + b: []byte{ + 0x00, 0x00, 0x00, byte(sshFileXferAttrSize | sshFileXferAttrPermissions | sshFileXferAttrUIDGID | sshFileXferAttrACmodTime), + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20, + 0x00, 0x00, 0x03, 0xE8, + 0x00, 0x00, 0x03, 0xE9, + 0x00, 0x00, 0x01, 0xA4, + 0x00, 0x00, 0x00, 42, + 0x00, 0x00, 0x00, 13, + }, + want: &FileStat{ + Size: 20, + Mode: 0644, + UID: 1000, + GID: 1001, + Atime: 42, + Mtime: 13, + }, + }, + } -func TestSendPacket(t *testing.T) { - for _, tt := range sendPacketTests { - var w bytes.Buffer - sendPacket(&w, tt.p) - if got := w.Bytes(); !bytes.Equal(tt.want, got) { - t.Errorf("sendPacket(%v): want %#v, got %#v", tt.p, tt.want, got) + for _, tt := range tests { + got, _, err := unmarshalAttrs(tt.b) + if err != nil { + t.Fatal("unexpected error:", err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unmarshalAttrs(% X):\n- got: %#v\n- want: %#v", tt.b, got, tt.want) } } } -func sp(p encoding.BinaryMarshaler) []byte { - var w bytes.Buffer - sendPacket(&w, p) - return w.Bytes() +func TestUnmarshalStatus(t *testing.T) { + var requestID uint32 = 1 + + id := marshalUint32(nil, requestID) + idCode := marshalUint32(id, sshFxFailure) + idCodeMsg := marshalString(idCode, "err msg") + idCodeMsgLang := marshalString(idCodeMsg, "lang tag") + + var tests = []struct { + desc string + reqID uint32 + status []byte + want error + }{ + { + desc: "well-formed status", + status: idCodeMsgLang, + want: &StatusError{ + Code: sshFxFailure, + msg: "err msg", + lang: "lang tag", + }, + }, + { + desc: "missing language tag", + status: idCodeMsg, + want: &StatusError{ + Code: sshFxFailure, + msg: "err msg", + }, + }, + { + desc: "missing error message and language tag", + status: idCode, + 
want: &StatusError{ + Code: sshFxFailure, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + got := unmarshalStatus(1, tt.status) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unmarshalStatus(1, % X):\n- got: %#v\n- want: %#v", tt.status, got, tt.want) + } + }) + } + + got := unmarshalStatus(2, idCodeMsgLang) + want := &unexpectedIDErr{ + want: 2, + got: 1, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("unmarshalStatus(2, % X):\n- got: %#v\n- want: %#v", idCodeMsgLang, got, want) + } } -var recvPacketTests = []struct { - b []byte - want uint8 - rest []byte -}{ - {sp(sshFxInitPacket{ - Version: 3, - Extensions: []struct{ Name, Data string }{ - {"posix-rename@openssh.com", "1"}, +func TestSendPacket(t *testing.T) { + var tests = []struct { + packet encoding.BinaryMarshaler + want []byte + }{ + { + packet: &sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x26, + 0x1, + 0x0, 0x0, 0x0, 0x3, + 0x0, 0x0, 0x0, 0x18, + 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x0, 0x0, 0x0, 0x1, + '1', + }, + }, + { + packet: &sshFxpOpenPacket{ + ID: 1, + Path: "/foo", + Pflags: toPflags(os.O_RDONLY), + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x15, + 0x3, + 0x0, 0x0, 0x0, 0x1, + 0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o', + 0x0, 0x0, 0x0, 0x1, + 0x0, 0x0, 0x0, 0x0, + }, + }, + { + packet: &sshFxpOpenPacket{ + ID: 3, + Path: "/foo", + Pflags: toPflags(os.O_WRONLY | os.O_CREATE | os.O_TRUNC), + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o755, + }, + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x19, + 0x3, + 0x0, 0x0, 0x0, 0x3, + 0x0, 0x0, 0x0, 0x4, '/', 'f', 'o', 'o', + 0x0, 0x0, 0x0, 0x1a, + 0x0, 0x0, 0x0, 0x4, + 0x0, 0x0, 0x1, 0xed, + }, + }, + { + packet: &sshFxpWritePacket{ + ID: 124, + Handle: "foo", + Offset: 13, + Length: uint32(len("bar")), + Data: []byte("bar"), 
+ }, + want: []byte{ + 0x0, 0x0, 0x0, 0x1b, + 0x6, + 0x0, 0x0, 0x0, 0x7c, + 0x0, 0x0, 0x0, 0x3, 'f', 'o', 'o', + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd, + 0x0, 0x0, 0x0, 0x3, 'b', 'a', 'r', + }, }, - }), ssh_FXP_INIT, []byte{0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x18, 0x70, 0x6f, 0x73, 0x69, 0x78, 0x2d, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x40, 0x6f, 0x70, 0x65, 0x6e, 0x73, 0x73, 0x68, 0x2e, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x0, 0x1, 0x31}}, + { + packet: &sshFxpSetstatPacket{ + ID: 31, + Path: "/bar", + Flags: sshFileXferAttrUIDGID, + Attrs: &FileStat{ + UID: 1000, + GID: 100, + }, + }, + want: []byte{ + 0x0, 0x0, 0x0, 0x19, + 0x9, + 0x0, 0x0, 0x0, 0x1f, + 0x0, 0x0, 0x0, 0x4, '/', 'b', 'a', 'r', + 0x0, 0x0, 0x0, 0x2, + 0x0, 0x0, 0x3, 0xe8, + 0x0, 0x0, 0x0, 0x64, + }, + }, + } + + for _, tt := range tests { + b := new(bytes.Buffer) + sendPacket(b, tt.packet) + if got := b.Bytes(); !bytes.Equal(tt.want, got) { + t.Errorf("sendPacket(%v): got %x want %x", tt.packet, tt.want, got) + } + } +} + +func sp(data encoding.BinaryMarshaler) []byte { + b := new(bytes.Buffer) + sendPacket(b, data) + return b.Bytes() } func TestRecvPacket(t *testing.T) { + var recvPacketTests = []struct { + b []byte + + want fxp + body []byte + wantErr error + }{ + { + b: sp(&sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }), + want: sshFxpInit, + body: []byte{ + 0x0, 0x0, 0x0, 0x3, + 0x0, 0x0, 0x0, 0x18, + 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm', + 0x0, 0x0, 0x0, 0x01, + '1', + }, + }, + { + b: []byte{ + 0x0, 0x0, 0x0, 0x0, + }, + wantErr: errShortPacket, + }, + { + b: []byte{ + 0xff, 0xff, 0xff, 0xff, + }, + wantErr: errLongPacket, + }, + } + for _, tt := range recvPacketTests { r := bytes.NewReader(tt.b) - got, rest, _ := recvPacket(r) - if got != tt.want || !bytes.Equal(rest, tt.rest) { - t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, 
tt.rest, got, rest) + + got, body, err := recvPacket(r, nil, 0) + if tt.wantErr == nil { + if err != nil { + t.Fatalf("recvPacket(%#v): unexpected error: %v", tt.b, err) + } + } else { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("recvPacket(%#v) = %v, want %v", tt.b, err, tt.wantErr) + } + } + + if got != tt.want { + t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, got, tt.want) + } + + if !bytes.Equal(body, tt.body) { + t.Errorf("recvPacket(%#v) = %#v, want %#v", tt.b, body, tt.body) } } } -func BenchmarkMarshalInit(b *testing.B) { - for i := 0; i < b.N; i++ { - sp(sshFxInitPacket{ - Version: 3, - Extensions: []struct{ Name, Data string }{ - {"posix-rename@openssh.com", "1"}, - }, - }) +func TestSSHFxpOpenPacketreadonly(t *testing.T) { + var tests = []struct { + pflags uint32 + ok bool + }{ + { + pflags: sshFxfRead, + ok: true, + }, + { + pflags: sshFxfWrite, + ok: false, + }, + { + pflags: sshFxfRead | sshFxfWrite, + ok: false, + }, + } + + for _, tt := range tests { + p := &sshFxpOpenPacket{ + Pflags: tt.pflags, + } + + if want, got := tt.ok, p.readonly(); want != got { + t.Errorf("unexpected value for p.readonly(): want: %v, got: %v", + want, got) + } } } -func BenchmarkMarshalOpen(b *testing.B) { +func TestSSHFxpOpenPackethasPflags(t *testing.T) { + var tests = []struct { + desc string + haveFlags uint32 + testFlags []uint32 + ok bool + }{ + { + desc: "have read, test against write", + haveFlags: sshFxfRead, + testFlags: []uint32{sshFxfWrite}, + ok: false, + }, + { + desc: "have write, test against read", + haveFlags: sshFxfWrite, + testFlags: []uint32{sshFxfRead}, + ok: false, + }, + { + desc: "have read+write, test against read", + haveFlags: sshFxfRead | sshFxfWrite, + testFlags: []uint32{sshFxfRead}, + ok: true, + }, + { + desc: "have read+write, test against write", + haveFlags: sshFxfRead | sshFxfWrite, + testFlags: []uint32{sshFxfWrite}, + ok: true, + }, + { + desc: "have read+write, test against read+write", + haveFlags: sshFxfRead | sshFxfWrite, + 
testFlags: []uint32{sshFxfRead, sshFxfWrite}, + ok: true, + }, + } + + for _, tt := range tests { + t.Log(tt.desc) + + p := &sshFxpOpenPacket{ + Pflags: tt.haveFlags, + } + + if want, got := tt.ok, p.hasPflags(tt.testFlags...); want != got { + t.Errorf("unexpected value for p.hasPflags(%#v): want: %v, got: %v", + tt.testFlags, want, got) + } + } +} + +func benchMarshal(b *testing.B, packet encoding.BinaryMarshaler) { + b.ResetTimer() + for i := 0; i < b.N; i++ { - sp(sshFxpOpenPacket{ - Id: 1, - Path: "/home/test/some/random/path", - Pflags: flags(os.O_RDONLY), - }) + sendPacket(ioutil.Discard, packet) } } +func BenchmarkMarshalInit(b *testing.B) { + benchMarshal(b, &sshFxInitPacket{ + Version: 3, + Extensions: []extensionPair{ + {"posix-rename@openssh.com", "1"}, + }, + }) +} + +func BenchmarkMarshalOpen(b *testing.B) { + benchMarshal(b, &sshFxpOpenPacket{ + ID: 1, + Path: "/home/test/some/random/path", + Pflags: toPflags(os.O_RDONLY), + }) +} + func BenchmarkMarshalWriteWorstCase(b *testing.B) { data := make([]byte, 32*1024) - for i := 0; i < b.N; i++ { - sp(sshFxpWritePacket{ - Id: 1, - Handle: "someopaquehandle", - Offset: 0, - Length: uint32(len(data)), - Data: data, - }) - } + + benchMarshal(b, &sshFxpWritePacket{ + ID: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) } func BenchmarkMarshalWrite1k(b *testing.B) { - data := make([]byte, 1024) - for i := 0; i < b.N; i++ { - sp(sshFxpWritePacket{ - Id: 1, - Handle: "someopaquehandle", - Offset: 0, - Length: uint32(len(data)), - Data: data, - }) - } + data := make([]byte, 1025) + + benchMarshal(b, &sshFxpWritePacket{ + ID: 1, + Handle: "someopaquehandle", + Offset: 0, + Length: uint32(len(data)), + Data: data, + }) } diff --git a/pool.go b/pool.go new file mode 100644 index 00000000..36126290 --- /dev/null +++ b/pool.go @@ -0,0 +1,79 @@ +package sftp + +// bufPool provides a pool of byte-slices to be reused in various parts of the package. 
+// It is safe to use concurrently through a pointer. +type bufPool struct { + ch chan []byte + blen int +} + +func newBufPool(depth, bufLen int) *bufPool { + return &bufPool{ + ch: make(chan []byte, depth), + blen: bufLen, + } +} + +func (p *bufPool) Get() []byte { + if p.blen <= 0 { + panic("bufPool: new buffer creation length must be greater than zero") + } + + for { + select { + case b := <-p.ch: + if cap(b) < p.blen { + // just in case: throw away any buffer with insufficient capacity. + continue + } + + return b[:p.blen] + + default: + return make([]byte, p.blen) + } + } +} + +func (p *bufPool) Put(b []byte) { + if p == nil { + // functional default: no reuse. + return + } + + if cap(b) < p.blen || cap(b) > p.blen*2 { + // DO NOT reuse buffers with insufficient capacity. + // This could cause panics when resizing to p.blen. + + // DO NOT reuse buffers with excessive capacity. + // This could cause memory leaks. + return + } + + select { + case p.ch <- b: + default: + } +} + +type resChanPool chan chan result + +func newResChanPool(depth int) resChanPool { + return make(chan chan result, depth) +} + +func (p resChanPool) Get() chan result { + select { + case ch := <-p: + return ch + default: + return make(chan result, 1) + } +} + +func (p resChanPool) Put(ch chan result) { + select { + case p <- ch: + default: + } +} diff --git a/release.go b/release.go index b695528f..9ecedc44 100644 --- a/release.go +++ b/release.go @@ -1,3 +1,4 @@ +//go:build !debug // +build !debug package sftp diff --git a/request-attrs.go b/request-attrs.go new file mode 100644 index 00000000..476c5651 --- /dev/null +++ b/request-attrs.go @@ -0,0 +1,57 @@ +package sftp + +// Methods on the Request object to make working with the Flags bitmasks and +// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write +// request and AttrFlags() and Attributes() when working with SetStat requests. + +// FileOpenFlags defines Open and Write Flags. 
Correlate directly with with os.OpenFile flags +// (https://golang.org/pkg/os/#pkg-constants). +type FileOpenFlags struct { + Read, Write, Append, Creat, Trunc, Excl bool +} + +func newFileOpenFlags(flags uint32) FileOpenFlags { + return FileOpenFlags{ + Read: flags&sshFxfRead != 0, + Write: flags&sshFxfWrite != 0, + Append: flags&sshFxfAppend != 0, + Creat: flags&sshFxfCreat != 0, + Trunc: flags&sshFxfTrunc != 0, + Excl: flags&sshFxfExcl != 0, + } +} + +// Pflags converts the bitmap/uint32 from SFTP Open packet pflag values, +// into a FileOpenFlags struct with booleans set for flags set in bitmap. +func (r *Request) Pflags() FileOpenFlags { + return newFileOpenFlags(r.Flags) +} + +// FileAttrFlags that indicate whether SFTP file attributes were passed. When a flag is +// true the corresponding attribute should be available from the FileStat +// object returned by Attributes method. Used with SetStat. +type FileAttrFlags struct { + Size, UidGid, Permissions, Acmodtime bool +} + +func newFileAttrFlags(flags uint32) FileAttrFlags { + return FileAttrFlags{ + Size: (flags & sshFileXferAttrSize) != 0, + UidGid: (flags & sshFileXferAttrUIDGID) != 0, + Permissions: (flags & sshFileXferAttrPermissions) != 0, + Acmodtime: (flags & sshFileXferAttrACmodTime) != 0, + } +} + +// AttrFlags returns a FileAttrFlags boolean struct based on the +// bitmap/uint32 file attribute flags from the SFTP packaet. +func (r *Request) AttrFlags() FileAttrFlags { + return newFileAttrFlags(r.Flags) +} + +// Attributes parses file attributes byte blob and return them in a +// FileStat object. 
+func (r *Request) Attributes() *FileStat { + fs, _, _ := unmarshalFileStat(r.Flags, r.Attrs) + return fs +} diff --git a/request-attrs_test.go b/request-attrs_test.go new file mode 100644 index 00000000..b1b559b8 --- /dev/null +++ b/request-attrs_test.go @@ -0,0 +1,70 @@ +package sftp + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequestPflags(t *testing.T) { + pflags := newFileOpenFlags(sshFxfRead | sshFxfWrite | sshFxfAppend) + assert.True(t, pflags.Read) + assert.True(t, pflags.Write) + assert.True(t, pflags.Append) + assert.False(t, pflags.Creat) + assert.False(t, pflags.Trunc) + assert.False(t, pflags.Excl) +} + +func TestRequestAflags(t *testing.T) { + aflags := newFileAttrFlags( + sshFileXferAttrSize | sshFileXferAttrUIDGID) + assert.True(t, aflags.Size) + assert.True(t, aflags.UidGid) + assert.False(t, aflags.Acmodtime) + assert.False(t, aflags.Permissions) +} + +func TestRequestAttributes(t *testing.T) { + // UID/GID + fa := FileStat{UID: 1, GID: 2} + fl := uint32(sshFileXferAttrUIDGID) + at := []byte{} + at = marshalUint32(at, 1) + at = marshalUint32(at, 2) + testFs, _, err := unmarshalFileStat(fl, at) + require.NoError(t, err) + assert.Equal(t, fa, *testFs) + // Size and Mode + fa = FileStat{Mode: 0700, Size: 99} + fl = uint32(sshFileXferAttrSize | sshFileXferAttrPermissions) + at = []byte{} + at = marshalUint64(at, 99) + at = marshalUint32(at, 0700) + testFs, _, err = unmarshalFileStat(fl, at) + require.NoError(t, err) + assert.Equal(t, fa, *testFs) + // FileMode + assert.True(t, testFs.FileMode().IsRegular()) + assert.False(t, testFs.FileMode().IsDir()) + assert.Equal(t, testFs.FileMode().Perm(), os.FileMode(0700).Perm()) +} + +func TestRequestAttributesEmpty(t *testing.T) { + fs, b, err := unmarshalFileStat(sshFileXferAttrAll, []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // size + 0x00, 0x00, 0x00, 0x00, // mode + 0x00, 0x00, 0x00, 0x00, // mtime + 0x00, 
0x00, 0x00, 0x00, // atime + 0x00, 0x00, 0x00, 0x00, // uid + 0x00, 0x00, 0x00, 0x00, // gid + 0x00, 0x00, 0x00, 0x00, // extended_count + }) + require.NoError(t, err) + assert.Equal(t, &FileStat{ + Extended: []StatExtended{}, + }, fs) + assert.Empty(t, b) +} diff --git a/request-errors.go b/request-errors.go new file mode 100644 index 00000000..6505b5c7 --- /dev/null +++ b/request-errors.go @@ -0,0 +1,54 @@ +package sftp + +type fxerr uint32 + +// Error types that match the SFTP's SSH_FXP_STATUS codes. Gives you more +// direct control of the errors being sent vs. letting the library work them +// out from the standard os/io errors. +const ( + ErrSSHFxOk = fxerr(sshFxOk) + ErrSSHFxEOF = fxerr(sshFxEOF) + ErrSSHFxNoSuchFile = fxerr(sshFxNoSuchFile) + ErrSSHFxPermissionDenied = fxerr(sshFxPermissionDenied) + ErrSSHFxFailure = fxerr(sshFxFailure) + ErrSSHFxBadMessage = fxerr(sshFxBadMessage) + ErrSSHFxNoConnection = fxerr(sshFxNoConnection) + ErrSSHFxConnectionLost = fxerr(sshFxConnectionLost) + ErrSSHFxOpUnsupported = fxerr(sshFxOPUnsupported) +) + +// Deprecated error types, these are aliases for the new ones, please use the new ones directly +const ( + ErrSshFxOk = ErrSSHFxOk + ErrSshFxEof = ErrSSHFxEOF + ErrSshFxNoSuchFile = ErrSSHFxNoSuchFile + ErrSshFxPermissionDenied = ErrSSHFxPermissionDenied + ErrSshFxFailure = ErrSSHFxFailure + ErrSshFxBadMessage = ErrSSHFxBadMessage + ErrSshFxNoConnection = ErrSSHFxNoConnection + ErrSshFxConnectionLost = ErrSSHFxConnectionLost + ErrSshFxOpUnsupported = ErrSSHFxOpUnsupported +) + +func (e fxerr) Error() string { + switch e { + case ErrSSHFxOk: + return "OK" + case ErrSSHFxEOF: + return "EOF" + case ErrSSHFxNoSuchFile: + return "no such file" + case ErrSSHFxPermissionDenied: + return "permission denied" + case ErrSSHFxBadMessage: + return "bad message" + case ErrSSHFxNoConnection: + return "no connection" + case ErrSSHFxConnectionLost: + return "connection lost" + case ErrSSHFxOpUnsupported: + return "operation unsupported" 
+ default: + return "failure" + } +} diff --git a/request-example.go b/request-example.go new file mode 100644 index 00000000..f003e711 --- /dev/null +++ b/request-example.go @@ -0,0 +1,647 @@ +package sftp + +// This serves as an example of how to implement the request server handler as +// well as a dummy backend for testing. It implements an in-memory backend that +// works as a very simple filesystem with simple flat key-value lookup system. + +import ( + "errors" + "io" + "os" + "path" + "sort" + "strings" + "sync" + "syscall" + "time" +) + +const maxSymlinkFollows = 5 + +var errTooManySymlinks = errors.New("too many symbolic links") + +// InMemHandler returns a Handlers object with the test handlers. +func InMemHandler() Handlers { + root := &root{ + rootFile: &memFile{name: "/", modtime: time.Now(), isdir: true}, + files: make(map[string]*memFile), + } + return Handlers{root, root, root, root} +} + +// Example Handlers +func (fs *root) Fileread(r *Request) (io.ReaderAt, error) { + flags := r.Pflags() + if !flags.Read { + // sanity check + return nil, os.ErrInvalid + } + + return fs.OpenFile(r) +} + +func (fs *root) Filewrite(r *Request) (io.WriterAt, error) { + flags := r.Pflags() + if !flags.Write { + // sanity check + return nil, os.ErrInvalid + } + + return fs.OpenFile(r) +} + +func (fs *root) OpenFile(r *Request) (WriterAtReaderAt, error) { + if fs.mockErr != nil { + return nil, fs.mockErr + } + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + + fs.mu.Lock() + defer fs.mu.Unlock() + + return fs.openfile(r.Filepath, r.Flags) +} + +func (fs *root) putfile(pathname string, file *memFile) error { + pathname, err := fs.canonName(pathname) + if err != nil { + return err + } + + if !strings.HasPrefix(pathname, "/") { + return os.ErrInvalid + } + + if _, err := fs.lfetch(pathname); err != os.ErrNotExist { + return os.ErrExist + } + + file.name = pathname + fs.files[pathname] = file + + return nil +} + +func (fs *root) openfile(pathname 
string, flags uint32) (*memFile, error) { + pflags := newFileOpenFlags(flags) + + file, err := fs.fetch(pathname) + if err == os.ErrNotExist { + if !pflags.Creat { + return nil, os.ErrNotExist + } + + var count int + // You can create files through dangling symlinks. + link, err := fs.lfetch(pathname) + for err == nil && link.symlink != "" { + if pflags.Excl { + // unless you also passed in O_EXCL + return nil, os.ErrInvalid + } + + if count++; count > maxSymlinkFollows { + return nil, errTooManySymlinks + } + + pathname = link.symlink + link, err = fs.lfetch(pathname) + } + + file := &memFile{ + modtime: time.Now(), + } + + if err := fs.putfile(pathname, file); err != nil { + return nil, err + } + + return file, nil + } + + if err != nil { + return nil, err + } + + if pflags.Creat && pflags.Excl { + return nil, os.ErrExist + } + + if file.IsDir() { + return nil, os.ErrInvalid + } + + if pflags.Trunc { + if err := file.Truncate(0); err != nil { + return nil, err + } + } + + return file, nil +} + +func (fs *root) Filecmd(r *Request) error { + if fs.mockErr != nil { + return fs.mockErr + } + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + + fs.mu.Lock() + defer fs.mu.Unlock() + + switch r.Method { + case "Setstat": + file, err := fs.openfile(r.Filepath, sshFxfWrite) + if err != nil { + return err + } + + if r.AttrFlags().Size { + return file.Truncate(int64(r.Attributes().Size)) + } + + return nil + + case "Rename": + // SFTP-v2: "It is an error if there already exists a file with the name specified by newpath." + // This varies from the POSIX specification, which allows limited replacement of target files. + if fs.exists(r.Target) { + return os.ErrExist + } + + return fs.rename(r.Filepath, r.Target) + + case "Rmdir": + return fs.rmdir(r.Filepath) + + case "Remove": + // IEEE 1003.1 remove explicitly can unlink files and remove empty directories. 
+ // We use instead here the semantics of unlink, which is allowed to be restricted against directories. + return fs.unlink(r.Filepath) + + case "Mkdir": + return fs.mkdir(r.Filepath) + + case "Link": + return fs.link(r.Filepath, r.Target) + + case "Symlink": + // NOTE: r.Filepath is the target, and r.Target is the linkpath. + return fs.symlink(r.Filepath, r.Target) + } + + return errors.New("unsupported") +} + +func (fs *root) rename(oldpath, newpath string) error { + file, err := fs.lfetch(oldpath) + if err != nil { + return err + } + + newpath, err = fs.canonName(newpath) + if err != nil { + return err + } + + if !strings.HasPrefix(newpath, "/") { + return os.ErrInvalid + } + + target, err := fs.lfetch(newpath) + if err != os.ErrNotExist { + if target == file { + // IEEE 1003.1: if oldpath and newpath are the same directory entry, + // then return no error, and perform no further action. + return nil + } + + switch { + case file.IsDir(): + // IEEE 1003.1: if oldpath is a directory, and newpath exists, + // then newpath must be a directory, and empty. + // It is to be removed prior to rename. + if err := fs.rmdir(newpath); err != nil { + return err + } + + case target.IsDir(): + // IEEE 1003.1: if oldpath is not a directory, and newpath exists, + // then newpath may not be a directory. 
+ return syscall.EISDIR + } + } + + fs.files[newpath] = file + + if file.IsDir() { + dirprefix := file.name + "/" + + for name, file := range fs.files { + if strings.HasPrefix(name, dirprefix) { + newname := path.Join(newpath, strings.TrimPrefix(name, dirprefix)) + + fs.files[newname] = file + file.name = newname + delete(fs.files, name) + } + } + } + + file.name = newpath + delete(fs.files, oldpath) + + return nil +} + +func (fs *root) PosixRename(r *Request) error { + if fs.mockErr != nil { + return fs.mockErr + } + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + + fs.mu.Lock() + defer fs.mu.Unlock() + + return fs.rename(r.Filepath, r.Target) +} + +func (fs *root) StatVFS(r *Request) (*StatVFS, error) { + if fs.mockErr != nil { + return nil, fs.mockErr + } + + return getStatVFSForPath(r.Filepath) +} + +func (fs *root) mkdir(pathname string) error { + dir := &memFile{ + modtime: time.Now(), + isdir: true, + } + + return fs.putfile(pathname, dir) +} + +func (fs *root) rmdir(pathname string) error { + // IEEE 1003.1: If pathname is a symlink, then rmdir should fail with ENOTDIR. + dir, err := fs.lfetch(pathname) + if err != nil { + return err + } + + if !dir.IsDir() { + return syscall.ENOTDIR + } + + // use the dir‘s internal name not the pathname we passed in. + // the dir.name is always the canonical name of a directory. + pathname = dir.name + + for name := range fs.files { + if path.Dir(name) == pathname { + return errors.New("directory not empty") + } + } + + delete(fs.files, pathname) + + return nil +} + +func (fs *root) link(oldpath, newpath string) error { + file, err := fs.lfetch(oldpath) + if err != nil { + return err + } + + if file.IsDir() { + return errors.New("hard link not allowed for directory") + } + + return fs.putfile(newpath, file) +} + +// symlink() creates a symbolic link named `linkpath` which contains the string `target`. +// NOTE! 
This would be called with `symlink(req.Filepath, req.Target)` due to different semantics. +func (fs *root) symlink(target, linkpath string) error { + link := &memFile{ + modtime: time.Now(), + symlink: target, + } + + return fs.putfile(linkpath, link) +} + +func (fs *root) unlink(pathname string) error { + // does not follow symlinks! + file, err := fs.lfetch(pathname) + if err != nil { + return err + } + + if file.IsDir() { + // IEEE 1003.1: implementations may opt out of allowing the unlinking of directories. + // SFTP-v2: SSH_FXP_REMOVE may not remove directories. + return os.ErrInvalid + } + + // DO NOT use the file’s internal name. + // because of hard-links files cannot have a single canonical name. + delete(fs.files, pathname) + + return nil +} + +type listerat []os.FileInfo + +// Modeled after strings.Reader's ReadAt() implementation +func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) { + var n int + if offset >= int64(len(f)) { + return 0, io.EOF + } + n = copy(ls, f[offset:]) + if n < len(ls) { + return n, io.EOF + } + return n, nil +} + +func (fs *root) Filelist(r *Request) (ListerAt, error) { + if fs.mockErr != nil { + return nil, fs.mockErr + } + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + + fs.mu.Lock() + defer fs.mu.Unlock() + + switch r.Method { + case "List": + files, err := fs.readdir(r.Filepath) + if err != nil { + return nil, err + } + return listerat(files), nil + + case "Stat": + file, err := fs.fetch(r.Filepath) + if err != nil { + return nil, err + } + return listerat{file}, nil + } + + return nil, errors.New("unsupported") +} + +func (fs *root) readdir(pathname string) ([]os.FileInfo, error) { + dir, err := fs.fetch(pathname) + if err != nil { + return nil, err + } + + if !dir.IsDir() { + return nil, syscall.ENOTDIR + } + + var files []os.FileInfo + + for name, file := range fs.files { + if path.Dir(name) == dir.name { + files = append(files, file) + } + } + + sort.Slice(files, func(i, j 
int) bool { return files[i].Name() < files[j].Name() }) + + return files, nil +} + +func (fs *root) Readlink(pathname string) (string, error) { + file, err := fs.lfetch(pathname) + if err != nil { + return "", err + } + + if file.symlink == "" { + return "", os.ErrInvalid + } + + return file.symlink, nil +} + +// implements LstatFileLister interface +func (fs *root) Lstat(r *Request) (ListerAt, error) { + if fs.mockErr != nil { + return nil, fs.mockErr + } + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + + fs.mu.Lock() + defer fs.mu.Unlock() + + file, err := fs.lfetch(r.Filepath) + if err != nil { + return nil, err + } + return listerat{file}, nil +} + +// In memory file-system-y thing that the Handlers live on +type root struct { + rootFile *memFile + mockErr error + + mu sync.Mutex + files map[string]*memFile +} + +// Set a mocked error that the next handler call will return. +// Set to nil to reset for no error. +func (fs *root) returnErr(err error) { + fs.mockErr = err +} + +func (fs *root) lfetch(path string) (*memFile, error) { + if path == "/" { + return fs.rootFile, nil + } + + file, ok := fs.files[path] + if file == nil { + if ok { + delete(fs.files, path) + } + + return nil, os.ErrNotExist + } + + return file, nil +} + +// canonName returns the “canonical” name of a file, that is: +// if the directory of the pathname is a symlink, it follows that symlink to the valid directory name. +// this is relatively easy, since `dir.name` will be the only valid canonical path for a directory. 
+func (fs *root) canonName(pathname string) (string, error) { + dirname, filename := path.Dir(pathname), path.Base(pathname) + + dir, err := fs.fetch(dirname) + if err != nil { + return "", err + } + + if !dir.IsDir() { + return "", syscall.ENOTDIR + } + + return path.Join(dir.name, filename), nil +} + +func (fs *root) exists(path string) bool { + path, err := fs.canonName(path) + if err != nil { + return false + } + + _, err = fs.lfetch(path) + + return err != os.ErrNotExist +} + +func (fs *root) fetch(pathname string) (*memFile, error) { + file, err := fs.lfetch(pathname) + if err != nil { + return nil, err + } + + var count int + for file.symlink != "" { + if count++; count > maxSymlinkFollows { + return nil, errTooManySymlinks + } + + linkTarget := file.symlink + if !path.IsAbs(linkTarget) { + linkTarget = path.Join(path.Dir(file.name), linkTarget) + } + + file, err = fs.lfetch(linkTarget) + if err != nil { + return nil, err + } + } + + return file, nil +} + +// Implements os.FileInfo, io.ReaderAt and io.WriterAt interfaces. +// These are the 3 interfaces necessary for the Handlers. +// Implements the optional interface TransferError. +type memFile struct { + name string + modtime time.Time + symlink string + isdir bool + + mu sync.RWMutex + content []byte + err error +} + +// These are helper functions, they must be called while holding the memFile.mu mutex +func (f *memFile) size() int64 { return int64(len(f.content)) } +func (f *memFile) grow(n int64) { f.content = append(f.content, make([]byte, n)...) 
} + +// Have memFile fulfill os.FileInfo interface +func (f *memFile) Name() string { return path.Base(f.name) } +func (f *memFile) Size() int64 { + f.mu.Lock() + defer f.mu.Unlock() + + return f.size() +} +func (f *memFile) Mode() os.FileMode { + if f.isdir { + return os.FileMode(0755) | os.ModeDir + } + if f.symlink != "" { + return os.FileMode(0777) | os.ModeSymlink + } + return os.FileMode(0644) +} +func (f *memFile) ModTime() time.Time { return f.modtime } +func (f *memFile) IsDir() bool { return f.isdir } +func (f *memFile) Sys() interface{} { + return fakeFileInfoSys() +} + +func (f *memFile) ReadAt(b []byte, off int64) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.err != nil { + return 0, f.err + } + + if off < 0 { + return 0, errors.New("memFile.ReadAt: negative offset") + } + + if off >= f.size() { + return 0, io.EOF + } + + n := copy(b, f.content[off:]) + if n < len(b) { + return n, io.EOF + } + + return n, nil +} + +func (f *memFile) WriteAt(b []byte, off int64) (int, error) { + // fmt.Println(string(p), off) + // mimic write delays, should be optional + time.Sleep(time.Microsecond * time.Duration(len(b))) + + f.mu.Lock() + defer f.mu.Unlock() + + if f.err != nil { + return 0, f.err + } + + grow := int64(len(b)) + off - f.size() + if grow > 0 { + f.grow(grow) + } + + return copy(f.content[off:], b), nil +} + +func (f *memFile) Truncate(size int64) error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.err != nil { + return f.err + } + + grow := size - f.size() + if grow <= 0 { + f.content = f.content[:size] + } else { + f.grow(grow) + } + + return nil +} + +func (f *memFile) TransferError(err error) { + f.mu.Lock() + defer f.mu.Unlock() + + f.err = err +} diff --git a/request-interfaces.go b/request-interfaces.go new file mode 100644 index 00000000..13e7577e --- /dev/null +++ b/request-interfaces.go @@ -0,0 +1,159 @@ +package sftp + +import ( + "io" + "os" +) + +// WriterAtReaderAt defines the interface to return when a file is to +// be 
opened for reading and writing +type WriterAtReaderAt interface { + io.WriterAt + io.ReaderAt +} + +// Interfaces are differentiated based on required returned values. +// All input arguments are to be pulled from Request (the only arg). + +// The Handler interfaces all take the Request object as its only argument. +// All the data you should need to handle the call are in the Request object. +// The request.Method attribute is initially the most important one as it +// determines which Handler gets called. + +// FileReader should return an io.ReaderAt for the filepath +// Note in cases of an error, the error text will be sent to the client. +// Called for Methods: Get +type FileReader interface { + Fileread(*Request) (io.ReaderAt, error) +} + +// FileWriter should return an io.WriterAt for the filepath. +// +// The request server code will call Close() on the returned io.WriterAt +// object if an io.Closer type assertion succeeds. +// Note in cases of an error, the error text will be sent to the client. +// Note when receiving an Append flag it is important to not open files using +// O_APPEND if you plan to use WriteAt, as they conflict. +// Called for Methods: Put, Open +type FileWriter interface { + Filewrite(*Request) (io.WriterAt, error) +} + +// OpenFileWriter is a FileWriter that implements the generic OpenFile method. +// You need to implement this optional interface if you want to be able +// to read and write from/to the same handle. +// Called for Methods: Open +type OpenFileWriter interface { + FileWriter + OpenFile(*Request) (WriterAtReaderAt, error) +} + +// FileCmder should return an error +// Note in cases of an error, the error text will be sent to the client. +// Called for Methods: Setstat, Rename, Rmdir, Mkdir, Link, Symlink, Remove +type FileCmder interface { + Filecmd(*Request) error +} + +// PosixRenameFileCmder is a FileCmder that implements the PosixRename method. 
+// If this interface is implemented PosixRename requests will call it
+// otherwise they will be handled in the same way as Rename
+type PosixRenameFileCmder interface {
+	FileCmder
+	PosixRename(*Request) error
+}
+
+// StatVFSFileCmder is a FileCmder that implements the StatVFS method.
+// You need to implement this interface if you want to handle statvfs requests.
+// Please also be sure that the statvfs@openssh.com extension is enabled
+type StatVFSFileCmder interface {
+	FileCmder
+	StatVFS(*Request) (*StatVFS, error)
+}
+
+// FileLister should return an object that fulfils the ListerAt interface
+// Note in cases of an error, the error text will be sent to the client.
+// Called for Methods: List, Stat, Readlink
+//
+// Since Filelist returns an os.FileInfo, this can make it non-ideal for implementing Readlink.
+// This is because the Name receiver method defined by that interface defines that it should only return the base name.
+// However, Readlink is required to be capable of returning essentially any arbitrary valid path relative or absolute.
+// In order to implement this more expressive requirement, implement [ReadlinkFileLister] which will then be used instead.
+type FileLister interface {
+	Filelist(*Request) (ListerAt, error)
+}
+
+// LstatFileLister is a FileLister that implements the Lstat method.
+// If this interface is implemented Lstat requests will call it
+// otherwise they will be handled in the same way as Stat
+type LstatFileLister interface {
+	FileLister
+	Lstat(*Request) (ListerAt, error)
+}
+
+// RealPathFileLister is a FileLister that implements the Realpath method.
+// The built-in RealPath implementation does not resolve symbolic links.
+// By implementing this interface you can customize the returned path
+// and, for example, resolve symbolic links if needed for your use case.
+// You have to return an absolute POSIX path.
+//
+// Up to v1.13.5 the signature for the RealPath method was:
+//
+//	# RealPath(string) string
+//
+// we have added a legacyRealPathFileLister that implements the old method
+// to ensure that your code does not break.
+// You should use the new method signature to avoid future issues
+type RealPathFileLister interface {
+	FileLister
+	RealPath(string) (string, error)
+}
+
+// ReadlinkFileLister is a FileLister that implements the Readlink method.
+// By implementing the Readlink method, it is possible to return any arbitrary valid path relative or absolute.
+// This allows giving a better response than via the default FileLister (which is limited to os.FileInfo, whose Name method should only return the base name of a file)
+type ReadlinkFileLister interface {
+	FileLister
+	Readlink(string) (string, error)
+}
+
+// This interface is here for backward compatibility only
+type legacyRealPathFileLister interface {
+	FileLister
+	RealPath(string) string
+}
+
+// NameLookupFileLister is a FileLister that implements the LookupUserName and LookupGroupName methods.
+// If this interface is implemented, then longname ls formatting will use these to convert usernames and groupnames.
+type NameLookupFileLister interface {
+	FileLister
+	LookupUserName(string) string
+	LookupGroupName(string) string
+}
+
+// ListerAt does for file lists what io.ReaderAt does for files, i.e. a []os.FileInfo buffer is passed to the ListAt function
+// and the entries that are populated in the buffer will be passed to the client.
+//
+// ListAt should return the number of entries copied and an io.EOF error if at end of list.
+// This is testable by comparing how many you copied to how many could be copied (eg. n < len(ls) below).
+// The copy() builtin is best for the copying.
+//
+// Uid and gid information will on unix systems be retrieved from [os.FileInfo.Sys]
+// if this function returns a [syscall.Stat_t] when called on a populated entry.
+// Alternatively, if the entry implements [FileInfoUidGid], it will be used for uid and gid information. +// +// If a populated entry implements [FileInfoExtendedData], extended attributes will also be returned to the client. +// +// The request server code will call Close() on ListerAt if an io.Closer type assertion succeeds. +// +// Note in cases of an error, the error text will be sent to the client. +type ListerAt interface { + ListAt([]os.FileInfo, int64) (int, error) +} + +// TransferError is an optional interface that readerAt and writerAt +// can implement to be notified about the error causing Serve() to exit +// with the request still open +type TransferError interface { + TransferError(err error) +} diff --git a/request-plan9.go b/request-plan9.go new file mode 100644 index 00000000..38f91bcd --- /dev/null +++ b/request-plan9.go @@ -0,0 +1,16 @@ +//go:build plan9 +// +build plan9 + +package sftp + +import ( + "syscall" +) + +func fakeFileInfoSys() interface{} { + return &syscall.Dir{} +} + +func testOsSys(sys interface{}) error { + return nil +} diff --git a/request-readme.md b/request-readme.md new file mode 100644 index 00000000..f8b81f3a --- /dev/null +++ b/request-readme.md @@ -0,0 +1,53 @@ +# Request Based SFTP API + +The request based API allows for custom backends in a way similar to the http +package. In order to create a backend you need to implement 4 handler +interfaces; one for reading, one for writing, one for misc commands and one for +listing files. Each has 1 required method and in each case those methods take +the Request as the only parameter and they each return something different. +These 4 interfaces are enough to handle all the SFTP traffic in a simplified +manner. + +The Request structure has 5 public fields which you will deal with. 
+
+- Method (string) - string name of incoming call
+- Filepath (string) - POSIX path of file to act on
+- Flags (uint32) - 32bit bitmask value of file open/create flags
+- Attrs ([]byte) - byte string of file attribute data
+- Target (string) - target path for renames and sym-links
+
+Below are the methods and a brief description of what they need to do.
+
+### Fileread(*Request) (io.Reader, error)
+
+Handler for "Get" method and returns an io.Reader for the file which the server
+then sends to the client.
+
+### Filewrite(*Request) (io.Writer, error)
+
+Handler for "Put" method and returns an io.Writer for the file which the server
+then writes the uploaded file to. The file opening "pflags" are currently
+preserved in the Request.Flags field as a 32bit bitmask value. See the [SFTP
+spec](https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt#section-6.3) for
+details.
+
+### Filecmd(*Request) error
+
+Handles "SetStat", "Rename", "Rmdir", "Mkdir" and "Symlink" methods. Makes the
+appropriate changes and returns nil for success or a filesystem-like error
+(eg. os.ErrNotExist). The attributes are currently propagated in their raw form
+([]byte) and will need to be unmarshalled to be useful. See the respond method
+on sshFxpSetstatPacket for an example of how you might want to do this.
+
+### Fileinfo(*Request) ([]os.FileInfo, error)
+
+Handles "List", "Stat", "Readlink" methods. Gathers/creates FileInfo structs
+with the data on the files and returns in a list (list of 1 for Stat and
+Readlink).
+
+
+## TODO
+
+- Add support for API users to see trace/debugging info of what is going on
+inside SFTP server.
+- Unmarshal the file attributes into a structure on the Request object.
diff --git a/request-server.go b/request-server.go new file mode 100644 index 00000000..08c24d7f --- /dev/null +++ b/request-server.go @@ -0,0 +1,355 @@ +package sftp + +import ( + "context" + "errors" + "io" + "path" + "path/filepath" + "strconv" + "sync" +) + +const defaultMaxTxPacket uint32 = 1 << 15 + +// Handlers contains the 4 SFTP server request handlers. +type Handlers struct { + FileGet FileReader + FilePut FileWriter + FileCmd FileCmder + FileList FileLister +} + +// RequestServer abstracts the sftp protocol with an http request-like protocol +type RequestServer struct { + Handlers Handlers + + *serverConn + pktMgr *packetManager + + startDirectory string + maxTxPacket uint32 + + mu sync.RWMutex + handleCount int + openRequests map[string]*Request +} + +// A RequestServerOption is a function which applies configuration to a RequestServer. +type RequestServerOption func(*RequestServer) + +// WithRSAllocator enable the allocator. +// After processing a packet we keep in memory the allocated slices +// and we reuse them for new packets. +// The allocator is experimental +func WithRSAllocator() RequestServerOption { + return func(rs *RequestServer) { + alloc := newAllocator() + rs.pktMgr.alloc = alloc + rs.conn.alloc = alloc + } +} + +// WithStartDirectory sets a start directory to use as base for relative paths. +// If unset the default is "/" +func WithStartDirectory(startDirectory string) RequestServerOption { + return func(rs *RequestServer) { + rs.startDirectory = cleanPath(startDirectory) + } +} + +// WithRSMaxTxPacket sets the maximum size of the payload returned to the client, +// measured in bytes. The default value is 32768 bytes, and this option +// can only be used to increase it. Setting this option to a larger value +// should be safe, because the client decides the size of the requested payload. +// +// The default maximum packet size is 32768 bytes. 
+func WithRSMaxTxPacket(size uint32) RequestServerOption { + return func(rs *RequestServer) { + if size < defaultMaxTxPacket { + return + } + + rs.maxTxPacket = size + } +} + +// NewRequestServer creates/allocates/returns new RequestServer. +// Normally there will be one server per user-session. +func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer { + svrConn := &serverConn{ + conn: conn{ + Reader: rwc, + WriteCloser: rwc, + }, + } + rs := &RequestServer{ + Handlers: h, + + serverConn: svrConn, + pktMgr: newPktMgr(svrConn), + + startDirectory: "/", + maxTxPacket: defaultMaxTxPacket, + + openRequests: make(map[string]*Request), + } + + for _, o := range options { + o(rs) + } + return rs +} + +// New Open packet/Request +func (rs *RequestServer) nextRequest(r *Request) string { + rs.mu.Lock() + defer rs.mu.Unlock() + + rs.handleCount++ + + r.handle = strconv.Itoa(rs.handleCount) + rs.openRequests[r.handle] = r + + return r.handle +} + +// Returns Request from openRequests, bool is false if it is missing. +// +// The Requests in openRequests work essentially as open file descriptors that +// you can do different things with. What you are doing with it are denoted by +// the first packet of that type (read/write/etc). 
+func (rs *RequestServer) getRequest(handle string) (*Request, bool) { + rs.mu.RLock() + defer rs.mu.RUnlock() + + r, ok := rs.openRequests[handle] + return r, ok +} + +// Close the Request and clear from openRequests map +func (rs *RequestServer) closeRequest(handle string) error { + rs.mu.Lock() + defer rs.mu.Unlock() + + if r, ok := rs.openRequests[handle]; ok { + delete(rs.openRequests, handle) + return r.close() + } + + return EBADF +} + +// Close the read/write/closer to trigger exiting the main server loop +func (rs *RequestServer) Close() error { return rs.conn.Close() } + +func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error { + defer close(pktChan) // shuts down sftpServerWorkers + + var err error + var pkt requestPacket + var pktType fxp + var pktBytes []byte + + for { + pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID()) + if err != nil { + // we don't care about releasing allocated pages here, the server will quit and the allocator freed + return err + } + + pkt, err = makePacket(rxPacket{pktType, pktBytes}) + if err != nil { + switch { + case errors.Is(err, errUnknownExtendedPacket): + // do nothing + default: + debug("makePacket err: %v", err) + rs.conn.Close() // shuts down recvPacket + return err + } + } + + pktChan <- rs.pktMgr.newOrderedRequest(pkt) + } +} + +// Serve requests for user session +func (rs *RequestServer) Serve() error { + defer func() { + if rs.pktMgr.alloc != nil { + rs.pktMgr.alloc.Free() + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + runWorker := func(ch chan orderedRequest) { + wg.Add(1) + go func() { + defer wg.Done() + if err := rs.packetWorker(ctx, ch); err != nil { + rs.conn.Close() // shuts down recvPacket + } + }() + } + pktChan := rs.pktMgr.workerChan(runWorker) + + err := rs.serveLoop(pktChan) + + wg.Wait() // wait for all workers to exit + + rs.mu.Lock() + defer rs.mu.Unlock() + + // make sure all open 
requests are properly closed + // (eg. possible on dropped connections, client crashes, etc.) + for handle, req := range rs.openRequests { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + req.transferError(err) + + delete(rs.openRequests, handle) + req.close() + } + + return err +} + +func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedRequest) error { + for pkt := range pktChan { + orderID := pkt.orderID() + if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok { + if epkt.SpecificPacket != nil { + pkt.requestPacket = epkt.SpecificPacket + } + } + + var rpkt responsePacket + switch pkt := pkt.requestPacket.(type) { + case *sshFxInitPacket: + rpkt = &sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions} + case *sshFxpClosePacket: + handle := pkt.getHandle() + rpkt = statusFromError(pkt.ID, rs.closeRequest(handle)) + case *sshFxpRealpathPacket: + var realPath string + var err error + + switch pather := rs.Handlers.FileList.(type) { + case RealPathFileLister: + realPath, err = pather.RealPath(pkt.getPath()) + case legacyRealPathFileLister: + realPath = pather.RealPath(pkt.getPath()) + default: + realPath = cleanPathWithBase(rs.startDirectory, pkt.getPath()) + } + if err != nil { + rpkt = statusFromError(pkt.ID, err) + } else { + rpkt = cleanPacketPath(pkt, realPath) + } + case *sshFxpOpendirPacket: + request := requestFromPacket(ctx, pkt, rs.startDirectory) + handle := rs.nextRequest(request) + rpkt = request.opendir(rs.Handlers, pkt) + if _, ok := rpkt.(*sshFxpHandlePacket); !ok { + // if we return an error we have to remove the handle from the active ones + rs.closeRequest(handle) + } + case *sshFxpOpenPacket: + request := requestFromPacket(ctx, pkt, rs.startDirectory) + handle := rs.nextRequest(request) + rpkt = request.open(rs.Handlers, pkt) + if _, ok := rpkt.(*sshFxpHandlePacket); !ok { + // if we return an error we have to remove the handle from the active ones + rs.closeRequest(handle) + } + case 
*sshFxpFstatPacket: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt.ID, EBADF) + } else { + request = &Request{ + Method: "Stat", + Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath), + } + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) + } + case *sshFxpFsetstatPacket: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt.ID, EBADF) + } else { + request = &Request{ + Method: "Setstat", + Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath), + } + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) + } + case *sshFxpExtendedPacketPosixRename: + request := &Request{ + Method: "PosixRename", + Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath), + Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath), + } + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) + case *sshFxpExtendedPacketStatVFS: + request := &Request{ + Method: "StatVFS", + Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path), + } + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) + case hasHandle: + handle := pkt.getHandle() + request, ok := rs.getRequest(handle) + if !ok { + rpkt = statusFromError(pkt.id(), EBADF) + } else { + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) + } + case hasPath: + request := requestFromPacket(ctx, pkt, rs.startDirectory) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket) + request.close() + default: + rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported) + } + + rs.pktMgr.readyPacket( + rs.pktMgr.newOrderedResponse(rpkt, orderID)) + } + return nil +} + +// clean and return name packet for file +func cleanPacketPath(pkt *sshFxpRealpathPacket, realPath string) responsePacket { + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: 
[]*sshFxpNameAttr{ + { + Name: realPath, + LongName: realPath, + Attrs: emptyFileStat, + }, + }, + } +} + +// Makes sure we have a clean POSIX (/) absolute path to work with +func cleanPath(p string) string { + return cleanPathWithBase("/", p) +} + +func cleanPathWithBase(base, p string) string { + p = filepath.ToSlash(filepath.Clean(p)) + if !path.IsAbs(p) { + return path.Join(base, p) + } + return p +} diff --git a/request-server_test.go b/request-server_test.go new file mode 100644 index 00000000..93011a57 --- /dev/null +++ b/request-server_test.go @@ -0,0 +1,1047 @@ +package sftp + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "path" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var _ = fmt.Print + +type csPair struct { + cli *Client + svr *RequestServer + svrResult chan error +} + +// these must be closed in order, else client.Close will hang +func (cs csPair) Close() { + cs.svr.Close() + cs.cli.Close() + os.Remove(sock) +} + +func (cs csPair) testHandler() *root { + return cs.svr.Handlers.FileGet.(*root) +} + +const sock = "/tmp/rstest.sock" + +func clientRequestServerPairWithHandlers(t *testing.T, handlers Handlers, options ...RequestServerOption) *csPair { + skipIfWindows(t) + skipIfPlan9(t) + + ready := make(chan struct{}) + canReturn := make(chan struct{}) + os.Remove(sock) // either this or signal handling + pair := &csPair{ + svrResult: make(chan error, 1), + } + + var server *RequestServer + go func() { + l, err := net.Listen("unix", sock) + if err != nil { + // neither assert nor t.Fatal reliably exit before Accept errors + panic(err) + } + + close(ready) + + fd, err := l.Accept() + require.NoError(t, err) + + if *testAllocator { + options = append(options, WithRSAllocator()) + } + + server = NewRequestServer(fd, handlers, options...) 
+ close(canReturn) + + err = server.Serve() + pair.svrResult <- err + }() + + <-ready + defer os.Remove(sock) + + c, err := net.Dial("unix", sock) + require.NoError(t, err) + + client, err := NewClientPipe(c, c) + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } + + <-canReturn + pair.svr = server + pair.cli = client + return pair +} + +func clientRequestServerPair(t *testing.T, options ...RequestServerOption) *csPair { + return clientRequestServerPairWithHandlers(t, InMemHandler(), options...) +} + +func checkRequestServerAllocator(t *testing.T, p *csPair) { + if p.svr.pktMgr.alloc == nil { + return + } + checkAllocatorBeforeServerClose(t, p.svr.pktMgr.alloc) + p.Close() + checkAllocatorAfterServerClose(t, p.svr.pktMgr.alloc) +} + +// after adding logging, maybe check log to make sure packet handling +// was split over more than one worker +func TestRequestSplitWrite(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + w, err := p.cli.Create("/foo") + require.NoError(t, err) + p.cli.maxPacket = 3 // force it to send in small chunks + contents := "one two three four five six seven eight nine ten" + w.Write([]byte(contents)) + w.Close() + r := p.testHandler() + f, err := r.fetch("http://github.com/foo") + require.NoError(t, err) + assert.Equal(t, contents, string(f.content)) + checkRequestServerAllocator(t, p) +} + +func TestRequestCache(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + foo := NewRequest("", "foo") + foo.ctx, foo.cancelCtx = context.WithCancel(context.Background()) + bar := NewRequest("", "bar") + fh := p.svr.nextRequest(foo) + bh := p.svr.nextRequest(bar) + assert.Len(t, p.svr.openRequests, 2) + _foo, ok := p.svr.getRequest(fh) + assert.Equal(t, foo.Method, _foo.Method) + assert.Equal(t, foo.Filepath, _foo.Filepath) + assert.Equal(t, foo.Target, _foo.Target) + assert.Equal(t, foo.Flags, _foo.Flags) + assert.Equal(t, foo.Attrs, _foo.Attrs) + assert.Equal(t, foo.state, _foo.state) + 
assert.NotNil(t, _foo.ctx) + assert.Equal(t, _foo.Context().Err(), nil, "context is still valid") + assert.True(t, ok) + _, ok = p.svr.getRequest("zed") + assert.False(t, ok) + p.svr.closeRequest(fh) + assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled") + p.svr.closeRequest(bh) + assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) +} + +func TestRequestCacheState(t *testing.T) { + // test operation that uses open/close + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + assert.Len(t, p.svr.openRequests, 0) + // test operation that doesn't open/close + err = p.cli.Remove("/foo") + assert.NoError(t, err) + assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) +} + +func putTestFile(cli *Client, path, content string) (int, error) { + w, err := cli.Create(path) + if err != nil { + return 0, err + } + defer w.Close() + + return w.Write([]byte(content)) +} + +func getTestFile(cli *Client, path string) ([]byte, error) { + r, err := cli.Open(path) + if err != nil { + return nil, err + } + defer r.Close() + + return ioutil.ReadAll(r) +} + +func TestRequestWrite(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + n, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + assert.Equal(t, 5, n) + r := p.testHandler() + f, err := r.fetch("http://github.com/foo") + require.NoError(t, err) + assert.False(t, f.isdir) + assert.Equal(t, f.content, []byte("hello")) + checkRequestServerAllocator(t, p) +} + +func TestRequestWriteEmpty(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + n, err := putTestFile(p.cli, "/foo", "") + require.NoError(t, err) + assert.Equal(t, 0, n) + r := p.testHandler() + f, err := r.fetch("http://github.com/foo") + require.NoError(t, err) + assert.False(t, f.isdir) + assert.Len(t, f.content, 0) + // lets test with an error + 
r.returnErr(os.ErrInvalid) + n, err = putTestFile(p.cli, "/bar", "") + require.Error(t, err) + r.returnErr(nil) + assert.Equal(t, 0, n) + checkRequestServerAllocator(t, p) +} + +func TestRequestFilename(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + r := p.testHandler() + f, err := r.fetch("http://github.com/foo") + require.NoError(t, err) + assert.Equal(t, f.Name(), "foo") + _, err = r.fetch("http://github.com/bar") + assert.Error(t, err) + checkRequestServerAllocator(t, p) +} + +func TestRequestJustRead(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + rf, err := p.cli.Open("/foo") + require.NoError(t, err) + defer rf.Close() + contents := make([]byte, 5) + n, err := rf.Read(contents) + if err != nil && err != io.EOF { + t.Fatalf("err: %v", err) + } + assert.Equal(t, 5, n) + assert.Equal(t, "hello", string(contents[0:5])) + checkRequestServerAllocator(t, p) +} + +func TestRequestOpenFail(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + rf, err := p.cli.Open("/foo") + assert.Exactly(t, os.ErrNotExist, err) + assert.Nil(t, rf) + // if we return an error the sftp client will not close the handle + // ensure that we close it ourself + assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) +} + +func TestRequestCreate(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + fh, err := p.cli.Create("foo") + require.NoError(t, err) + err = fh.Close() + assert.NoError(t, err) + checkRequestServerAllocator(t, p) +} + +func TestRequestReadAndWrite(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + file, err := p.cli.OpenFile("/foo", os.O_RDWR|os.O_CREATE) + require.NoError(t, err) + defer file.Close() + + n, err := file.Write([]byte("hello")) + require.NoError(t, err) + assert.Equal(t, 5, n) + + 
buf := make([]byte, 4) + n, err = file.ReadAt(buf, 1) + require.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, []byte{'e', 'l', 'l', 'o'}, buf) + + checkRequestServerAllocator(t, p) +} + +func TestOpenFileExclusive(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + // first open should work + file, err := p.cli.OpenFile("/foo", os.O_RDWR|os.O_CREATE|os.O_EXCL) + require.NoError(t, err) + file.Close() + + // second open should return error + _, err = p.cli.OpenFile("/foo", os.O_RDWR|os.O_CREATE|os.O_EXCL) + assert.Error(t, err) + + checkRequestServerAllocator(t, p) +} + +func TestOpenFileExclusiveNoSymlinkFollowing(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + // make a directory + err := p.cli.Mkdir("/foo") + require.NoError(t, err) + + // make a symlink to that directory + err = p.cli.Symlink("/foo", "/foo2") + require.NoError(t, err) + + // with O_EXCL, we can follow directory symlinks + file, err := p.cli.OpenFile("/foo2/bar", os.O_RDWR|os.O_CREATE|os.O_EXCL) + require.NoError(t, err) + err = file.Close() + require.NoError(t, err) + + // we should have created the file above; and this create should fail. + _, err = p.cli.OpenFile("/foo/bar", os.O_RDWR|os.O_CREATE|os.O_EXCL) + require.Error(t, err) + + // create a dangling symlink + err = p.cli.Symlink("/notexist", "/bar") + require.NoError(t, err) + + // opening a dangling symlink with O_CREATE and O_EXCL should fail, regardless of target not existing. 
+ _, err = p.cli.OpenFile("/bar", os.O_RDWR|os.O_CREATE|os.O_EXCL) + require.Error(t, err) + + checkRequestServerAllocator(t, p) +} + +func TestRequestMkdir(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + err := p.cli.Mkdir("/foo") + require.NoError(t, err) + r := p.testHandler() + f, err := r.fetch("http://github.com/foo") + require.NoError(t, err) + assert.True(t, f.IsDir()) + checkRequestServerAllocator(t, p) +} + +func TestRequestRemove(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + r := p.testHandler() + _, err = r.fetch("http://github.com/foo") + assert.NoError(t, err) + err = p.cli.Remove("/foo") + assert.NoError(t, err) + _, err = r.fetch("http://github.com/foo") + assert.Equal(t, err, os.ErrNotExist) + checkRequestServerAllocator(t, p) +} + +func TestRequestRename(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + content, err := getTestFile(p.cli, "/foo") + require.NoError(t, err) + require.Equal(t, []byte("hello"), content) + + err = p.cli.Rename("/foo", "/bar") + require.NoError(t, err) + + // file contents are now at /bar + content, err = getTestFile(p.cli, "/bar") + require.NoError(t, err) + require.Equal(t, []byte("hello"), content) + + // /foo no longer exists + _, err = getTestFile(p.cli, "/foo") + require.Error(t, err) + + _, err = putTestFile(p.cli, "/baz", "goodbye") + require.NoError(t, err) + content, err = getTestFile(p.cli, "/baz") + require.NoError(t, err) + require.Equal(t, []byte("goodbye"), content) + + // SFTP-v2: SSH_FXP_RENAME may not overwrite existing files. 
+	err = p.cli.Rename("/bar", "/baz")
+	require.Error(t, err)
+
+	// /bar and /baz are unchanged
+	content, err = getTestFile(p.cli, "/bar")
+	require.NoError(t, err)
+	require.Equal(t, []byte("hello"), content)
+	content, err = getTestFile(p.cli, "/baz")
+	require.NoError(t, err)
+	require.Equal(t, []byte("goodbye"), content)
+
+	// posix-rename@openssh.com extension allows overwriting existing files.
+	err = p.cli.PosixRename("/bar", "/baz")
+	require.NoError(t, err)
+
+	// /baz now has the contents of /bar
+	content, err = getTestFile(p.cli, "/baz")
+	require.NoError(t, err)
+	require.Equal(t, []byte("hello"), content)
+
+	// /bar no longer exists
+	_, err = getTestFile(p.cli, "/bar")
+	require.Error(t, err)
+
+	checkRequestServerAllocator(t, p)
+}
+
+func TestRequestRenameFail(t *testing.T) {
+	p := clientRequestServerPair(t)
+	defer p.Close()
+	_, err := putTestFile(p.cli, "/foo", "hello")
+	require.NoError(t, err)
+	_, err = putTestFile(p.cli, "/bar", "goodbye")
+	require.NoError(t, err)
+	err = p.cli.Rename("/foo", "/bar")
+	assert.IsType(t, &StatusError{}, err)
+	checkRequestServerAllocator(t, p)
+}
+
+func TestRequestStat(t *testing.T) {
+	p := clientRequestServerPair(t)
+	defer p.Close()
+	_, err := putTestFile(p.cli, "/foo", "hello")
+	require.NoError(t, err)
+	fi, err := p.cli.Stat("/foo")
+	require.NoError(t, err)
+	assert.Equal(t, "foo", fi.Name())
+	assert.Equal(t, int64(5), fi.Size())
+	assert.Equal(t, os.FileMode(0644), fi.Mode())
+	assert.NoError(t, testOsSys(fi.Sys()))
+	checkRequestServerAllocator(t, p)
+}
+
+// NOTE: Setstat is a noop in the request server tests, but we want to test
+// that it does nothing without crapping out.
+func TestRequestSetstat(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + mode := os.FileMode(0644) + err = p.cli.Chmod("/foo", mode) + require.NoError(t, err) + fi, err := p.cli.Stat("/foo") + require.NoError(t, err) + assert.Equal(t, "foo", fi.Name()) + assert.Equal(t, int64(5), fi.Size()) + assert.Equal(t, os.FileMode(0644), fi.Mode()) + assert.NoError(t, testOsSys(fi.Sys())) + checkRequestServerAllocator(t, p) +} + +func TestRequestFstat(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + fp, err := p.cli.Open("/foo") + require.NoError(t, err) + fi, err := fp.Stat() + require.NoError(t, err) + assert.Equal(t, "foo", fi.Name()) + assert.Equal(t, int64(5), fi.Size()) + assert.Equal(t, os.FileMode(0644), fi.Mode()) + assert.NoError(t, testOsSys(fi.Sys())) + checkRequestServerAllocator(t, p) +} + +func TestRequestFsetstat(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + fp, err := p.cli.OpenFile("/foo", os.O_WRONLY) + require.NoError(t, err) + err = fp.Truncate(2) + require.NoError(t, err) + fi, err := fp.Stat() + require.NoError(t, err) + assert.Equal(t, fi.Name(), "foo") + assert.Equal(t, fi.Size(), int64(2)) + err = fp.Truncate(5) + require.NoError(t, err) + fi, err = fp.Stat() + require.NoError(t, err) + assert.Equal(t, fi.Name(), "foo") + assert.Equal(t, fi.Size(), int64(5)) + err = fp.Close() + assert.NoError(t, err) + rf, err := p.cli.Open("/foo") + assert.NoError(t, err) + defer rf.Close() + contents := make([]byte, 20) + n, err := rf.Read(contents) + assert.EqualError(t, err, io.EOF.Error()) + assert.Equal(t, 5, n) + assert.Equal(t, []byte{'h', 'e', 0, 0, 0}, contents[0:n]) + checkRequestServerAllocator(t, p) +} + +func TestRequestStatFail(t *testing.T) { + p := 
clientRequestServerPair(t) + defer p.Close() + fi, err := p.cli.Stat("/foo") + assert.Nil(t, fi) + assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) +} + +func TestRequestLstat(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + err = p.cli.Symlink("/foo", "/bar") + require.NoError(t, err) + fi, err := p.cli.Lstat("/bar") + require.NoError(t, err) + assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink) + checkRequestServerAllocator(t, p) +} + +func TestRequestLink(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + + err = p.cli.Link("/foo", "/bar") + require.NoError(t, err) + + content, err := getTestFile(p.cli, "/bar") + assert.NoError(t, err) + assert.Equal(t, []byte("hello"), content) + + checkRequestServerAllocator(t, p) +} + +func TestRequestLinkFail(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + err := p.cli.Link("/foo", "/bar") + t.Log(err) + assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) +} + +func TestRequestSymlink(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + const CONTENT_FOO = "hello" + const CONTENT_DIR_FILE_TXT = "file" + const CONTENT_SUB_FILE_TXT = "file-in-sub" + + // prepare all files + _, err := putTestFile(p.cli, "/foo", CONTENT_FOO) + require.NoError(t, err) + err = p.cli.Mkdir("/dir") + require.NoError(t, err) + err = p.cli.Mkdir("/dir/sub") + require.NoError(t, err) + _, err = putTestFile(p.cli, "/dir/file.txt", CONTENT_DIR_FILE_TXT) + require.NoError(t, err) + _, err = putTestFile(p.cli, "/dir/sub/file-in-sub.txt", CONTENT_SUB_FILE_TXT) + require.NoError(t, err) + + type symlink struct { + name string // this is the filename of the symbolic link + target string // this is the file or directory the link points to + + //for testing + expectsNotExist bool + 
expectedFileContent string + } + + symlinks := []symlink{ + {name: "/bar", target: "/foo", expectedFileContent: CONTENT_FOO}, + {name: "/baz", target: "/bar", expectedFileContent: CONTENT_FOO}, + {name: "/link-to-non-existent-file", target: "non-existent-file", expectsNotExist: true}, + {name: "/dir/rel-link.txt", target: "file.txt", expectedFileContent: CONTENT_DIR_FILE_TXT}, + {name: "/dir/abs-link.txt", target: "/dir/file.txt", expectedFileContent: CONTENT_DIR_FILE_TXT}, + {name: "/dir/rel-subdir-link.txt", target: "sub/file-in-sub.txt", expectedFileContent: CONTENT_SUB_FILE_TXT}, + {name: "/dir/abs-subdir-link.txt", target: "/dir/sub/file-in-sub.txt", expectedFileContent: CONTENT_SUB_FILE_TXT}, + {name: "/dir/sub/parentdir-link.txt", target: "../file.txt", expectedFileContent: CONTENT_DIR_FILE_TXT}, + } + + for _, s := range symlinks { + err := p.cli.Symlink(s.target, s.name) + require.NoError(t, err, "Creating symlink %q with target %q failed", s.name, s.target) + + rl, err := p.cli.ReadLink(s.name) + require.NoError(t, err, "ReadLink(%q) failed", s.name) + require.Equal(t, s.target, rl, "Unexpected result when reading symlink %q", s.name) + } + + // test fetching via symlink + r := p.testHandler() + + for _, s := range symlinks { + fi, err := r.lfetch(s.name) + require.NoError(t, err, "lfetch(%q) failed", s.name) + require.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink, "Expected %q to be a symlink but it is not.", s.name) + + content, err := getTestFile(p.cli, s.name) + if s.expectsNotExist { + require.True(t, os.IsNotExist(err), "Reading symlink %q expected os.ErrNotExist", s.name) + } else { + require.NoError(t, err, "getTestFile(%q) failed", s.name) + require.Equal(t, []byte(s.expectedFileContent), content, "Reading symlink %q returned unexpected content", s.name) + } + } + + checkRequestServerAllocator(t, p) +} + +func TestRequestSymlinkLoop(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + err := p.cli.Symlink("/foo", "/bar") + 
require.NoError(t, err) + err = p.cli.Symlink("/bar", "/baz") + require.NoError(t, err) + err = p.cli.Symlink("/baz", "/foo") + require.NoError(t, err) + + // test should fail if we reach this point + timer := time.NewTimer(1 * time.Second) + defer timer.Stop() + + var content []byte + + done := make(chan struct{}) + go func() { + defer close(done) + + content, err = getTestFile(p.cli, "/bar") + }() + + select { + case <-timer.C: + t.Fatal("symlink loop following timed out") + return // just to let the compiler be absolutely sure + + case <-done: + } + + assert.Error(t, err) + assert.Len(t, content, 0) + + checkRequestServerAllocator(t, p) +} + +func TestRequestSymlinkDanglingFiles(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + // dangling links are ok. We will use "/foo" later. + err := p.cli.Symlink("/foo", "/bar") + require.NoError(t, err) + + // creating a symlink in a non-existent directory should fail. + err = p.cli.Symlink("/dangle", "/foo/bar") + require.Error(t, err) + + // creating a symlink under a dangling symlink should fail. + err = p.cli.Symlink("/dangle", "/bar/bar") + require.Error(t, err) + + // opening a dangling link without O_CREATE should fail with os.IsNotExist == true + _, err = p.cli.OpenFile("/bar", os.O_RDONLY) + require.True(t, os.IsNotExist(err)) + + // overwriting a symlink is not allowed. + err = p.cli.Symlink("/dangle", "/bar") + require.Error(t, err) + + // double symlink + err = p.cli.Symlink("/bar", "/baz") + require.NoError(t, err) + + // opening a dangling link with O_CREATE should work. + _, err = putTestFile(p.cli, "/baz", "hello") + require.NoError(t, err) + + // dangling link creation should create the target file itself. + content, err := getTestFile(p.cli, "/foo") + require.NoError(t, err) + assert.Equal(t, []byte("hello"), content) + + // creating a symlink under a non-directory file should fail. 
+ err = p.cli.Symlink("/dangle", "/foo/bar") + assert.Error(t, err) + + checkRequestServerAllocator(t, p) +} + +func TestRequestSymlinkDanglingDirectories(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + // dangling links are ok. We will use "/foo" later. + err := p.cli.Symlink("/foo", "/bar") + require.NoError(t, err) + + // reading from a dangling symlink should fail. + _, err = p.cli.ReadDir("/bar") + require.True(t, os.IsNotExist(err)) + + // making a directory on a dangling symlink SHOULD NOT work. + err = p.cli.Mkdir("/bar") + require.Error(t, err) + + // ok, now make directory, so we can test make files through the symlink. + err = p.cli.Mkdir("/foo") + require.NoError(t, err) + + // should be able to make a file in that symlinked directory. + _, err = putTestFile(p.cli, "/bar/baz", "hello") + require.NoError(t, err) + + // dangling directory creation should create the target directory itself. + content, err := getTestFile(p.cli, "/foo/baz") + assert.NoError(t, err) + assert.Equal(t, []byte("hello"), content) + + checkRequestServerAllocator(t, p) +} + +func TestRequestReadlink(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + _, err := putTestFile(p.cli, "/foo", "hello") + require.NoError(t, err) + err = p.cli.Symlink("/foo", "/bar") + require.NoError(t, err) + + rl, err := p.cli.ReadLink("/bar") + assert.NoError(t, err) + assert.Equal(t, "/foo", rl) + + _, err = p.cli.ReadLink("/foo") + assert.Error(t, err, "Readlink on non-symlink should fail") + + _, err = p.cli.ReadLink("/does-not-exist") + assert.Error(t, err, "Readlink on non-existent file should fail") + + checkRequestServerAllocator(t, p) +} + +func TestRequestReaddir(t *testing.T) { + p := clientRequestServerPair(t) + MaxFilelist = 22 // make not divisible by our test amount (100) + defer p.Close() + for i := 0; i < 100; i++ { + fname := fmt.Sprintf("/foo_%02d", i) + _, err := putTestFile(p.cli, fname, fname) + if err != nil { + t.Fatal("expected no error, 
got:", err) + } + } + _, err := p.cli.ReadDir("/foo_01") + if runtime.GOOS == "zos" { + assert.Equal(t, &StatusError{Code: sshFxFailure, + msg: " /foo_01: EDC5135I Not a directory."}, err) + } else { + assert.Equal(t, &StatusError{Code: sshFxFailure, + msg: " /foo_01: not a directory"}, err) + } + _, err = p.cli.ReadDir("/does_not_exist") + assert.Equal(t, os.ErrNotExist, err) + di, err := p.cli.ReadDir("/") + require.NoError(t, err) + require.Len(t, di, 100) + names := []string{di[18].Name(), di[81].Name()} + assert.Equal(t, []string{"foo_18", "foo_81"}, names) + assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) +} + +type testListerAtCloser struct { + isClosed bool +} + +func (l *testListerAtCloser) ListAt([]os.FileInfo, int64) (int, error) { + return 0, io.EOF +} + +func (l *testListerAtCloser) Close() error { + l.isClosed = true + return nil +} + +func TestRequestServerListerAtCloser(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + handle, err := p.cli.opendir(context.Background(), "/") + require.NoError(t, err) + require.Len(t, p.svr.openRequests, 1) + req, ok := p.svr.getRequest(handle) + require.True(t, ok) + listerAt := &testListerAtCloser{} + req.setListerAt(listerAt) + assert.NotNil(t, req.state.getListerAt()) + err = p.cli.close(handle) + assert.NoError(t, err) + require.Len(t, p.svr.openRequests, 0) + assert.True(t, listerAt.isClosed) +} + +func TestRequestStatVFS(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skip("StatVFS is implemented on linux and darwin") + } + + p := clientRequestServerPair(t) + defer p.Close() + + _, ok := p.cli.HasExtension("statvfs@openssh.com") + require.True(t, ok, "request server doesn't list statvfs extension") + vfs, err := p.cli.StatVFS("/") + require.NoError(t, err) + expected, err := getStatVFSForPath("/") + require.NoError(t, err) + require.NotEqual(t, 0, expected.ID) + // check some stats + require.Equal(t, expected.Bavail, vfs.Bavail) + 
require.Equal(t, expected.Bfree, vfs.Bfree) + require.Equal(t, expected.Blocks, vfs.Blocks) + + checkRequestServerAllocator(t, p) +} + +func TestRequestStatVFSError(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skip("StatVFS is implemented on linux and darwin") + } + + p := clientRequestServerPair(t) + defer p.Close() + + _, err := p.cli.StatVFS("a missing path") + require.Error(t, err) + require.True(t, os.IsNotExist(err)) + + checkRequestServerAllocator(t, p) +} + +func TestRequestStartDirOption(t *testing.T) { + startDir := "/start/dir" + p := clientRequestServerPair(t, WithStartDirectory(startDir)) + defer p.Close() + + // create the start directory + err := p.cli.MkdirAll(startDir) + require.NoError(t, err) + // the working directory must be the defined start directory + wd, err := p.cli.Getwd() + require.NoError(t, err) + require.Equal(t, startDir, wd) + // upload a file using a relative path, it must be uploaded to the start directory + fileName := "file.txt" + _, err = putTestFile(p.cli, fileName, "") + require.NoError(t, err) + // we must be able to stat the file using both a relative and an absolute path + for _, filePath := range []string{fileName, path.Join(startDir, fileName)} { + fi, err := p.cli.Stat(filePath) + require.NoError(t, err) + assert.Equal(t, fileName, fi.Name()) + } + // list dir contents using a relative path + entries, err := p.cli.ReadDir(".") + assert.NoError(t, err) + assert.Len(t, entries, 1) + // delete the file using a relative path + err = p.cli.Remove(fileName) + assert.NoError(t, err) +} + +func TestCleanDisconnect(t *testing.T) { + p := clientRequestServerPair(t) + defer p.Close() + + err := p.cli.conn.Close() + require.NoError(t, err) + // server must return io.EOF after a clean client close + // with no pending open requests + err = <-p.svrResult + require.EqualError(t, err, io.EOF.Error()) + checkRequestServerAllocator(t, p) +} + +func TestUncleanDisconnect(t *testing.T) { + p := 
clientRequestServerPair(t) + defer p.Close() + + foo := NewRequest("", "foo") + p.svr.nextRequest(foo) + err := p.cli.conn.Close() + require.NoError(t, err) + // the foo request above is still open after the client disconnects + // so the server will convert io.EOF to io.ErrUnexpectedEOF + err = <-p.svrResult + require.EqualError(t, err, io.ErrUnexpectedEOF.Error()) + checkRequestServerAllocator(t, p) +} + +func TestRealPath(t *testing.T) { + startDir := "/startdir" + // the default InMemHandler does not implement the RealPathFileLister interface + // so we are using the builtin implementation here + p := clientRequestServerPair(t, WithStartDirectory(startDir)) + defer p.Close() + + realPath, err := p.cli.RealPath(".") + require.NoError(t, err) + assert.Equal(t, startDir, realPath) + realPath, err = p.cli.RealPath("/") + require.NoError(t, err) + assert.Equal(t, "/", realPath) + realPath, err = p.cli.RealPath("..") + require.NoError(t, err) + assert.Equal(t, "/", realPath) + realPath, err = p.cli.RealPath("../../..") + require.NoError(t, err) + assert.Equal(t, "/", realPath) + // test a relative path + realPath, err = p.cli.RealPath("relpath") + require.NoError(t, err) + assert.Equal(t, path.Join(startDir, "relpath"), realPath) +} + +// In memory file-system which implements RealPathFileLister +type rootWithRealPather struct { + root +} + +// implements RealpathFileLister interface +func (fs *rootWithRealPather) RealPath(p string) (string, error) { + if fs.mockErr != nil { + return "", fs.mockErr + } + return cleanPath(p), nil +} + +func TestRealPathFileLister(t *testing.T) { + root := &rootWithRealPather{ + root: root{ + rootFile: &memFile{name: "/", modtime: time.Now(), isdir: true}, + files: make(map[string]*memFile), + }, + } + handlers := Handlers{root, root, root, root} + p := clientRequestServerPairWithHandlers(t, handlers) + defer p.Close() + + realPath, err := p.cli.RealPath(".") + require.NoError(t, err) + assert.Equal(t, "/", realPath) + realPath, err = 
p.cli.RealPath("relpath") + require.NoError(t, err) + assert.Equal(t, "/relpath", realPath) + // test an error + root.returnErr(ErrSSHFxPermissionDenied) + _, err = p.cli.RealPath("/") + require.ErrorIs(t, err, os.ErrPermission) +} + +// In memory file-system which implements legacyRealPathFileLister +type rootWithLegacyRealPather struct { + root +} + +// implements RealpathFileLister interface +func (fs *rootWithLegacyRealPather) RealPath(p string) string { + return cleanPath(p) +} + +func TestLegacyRealPathFileLister(t *testing.T) { + root := &rootWithLegacyRealPather{ + root: root{ + rootFile: &memFile{name: "/", modtime: time.Now(), isdir: true}, + files: make(map[string]*memFile), + }, + } + handlers := Handlers{root, root, root, root} + p := clientRequestServerPairWithHandlers(t, handlers) + defer p.Close() + + realPath, err := p.cli.RealPath(".") + require.NoError(t, err) + assert.Equal(t, "/", realPath) + realPath, err = p.cli.RealPath("..") + require.NoError(t, err) + assert.Equal(t, "/", realPath) + realPath, err = p.cli.RealPath("relpath") + require.NoError(t, err) + assert.Equal(t, "/relpath", realPath) +} + +func TestCleanPath(t *testing.T) { + assert.Equal(t, "/", cleanPath("/")) + assert.Equal(t, "/", cleanPath(".")) + assert.Equal(t, "/", cleanPath("")) + assert.Equal(t, "/", cleanPath("/.")) + assert.Equal(t, "/", cleanPath("/a/..")) + assert.Equal(t, "/a/c", cleanPath("/a/b/../c")) + assert.Equal(t, "/a/c", cleanPath("/a/b/../c/")) + assert.Equal(t, "/a", cleanPath("/a/b/..")) + assert.Equal(t, "/a/b/c", cleanPath("/a/b/c")) + assert.Equal(t, "/", cleanPath("//")) + assert.Equal(t, "/a", cleanPath("/a/")) + assert.Equal(t, "/a", cleanPath("a/")) + assert.Equal(t, "/a/b/c", cleanPath("/a//b//c/")) + + // filepath.ToSlash does not touch \ as char on unix systems + // so os.PathSeparator is used for windows compatible tests + bslash := string(os.PathSeparator) + assert.Equal(t, "/", cleanPath(bslash)) + assert.Equal(t, "/", cleanPath(bslash+bslash)) 
+ assert.Equal(t, "/a", cleanPath(bslash+"a"+bslash)) + assert.Equal(t, "/a", cleanPath("a"+bslash)) + assert.Equal(t, "/a/b/c", + cleanPath(bslash+"a"+bslash+bslash+"b"+bslash+bslash+"c"+bslash)) + assert.Equal(t, "/C:/a", cleanPath("C:"+bslash+"a")) +} diff --git a/request-unix.go b/request-unix.go new file mode 100644 index 00000000..e3e037d6 --- /dev/null +++ b/request-unix.go @@ -0,0 +1,24 @@ +//go:build !windows && !plan9 +// +build !windows,!plan9 + +package sftp + +import ( + "errors" + "syscall" +) + +func fakeFileInfoSys() interface{} { + return &syscall.Stat_t{Uid: 65534, Gid: 65534} +} + +func testOsSys(sys interface{}) error { + fstat := sys.(*FileStat) + if fstat.UID != uint32(65534) { + return errors.New("Uid failed to match") + } + if fstat.GID != uint32(65534) { + return errors.New("Gid failed to match") + } + return nil +} diff --git a/request.go b/request.go new file mode 100644 index 00000000..e7c47a9c --- /dev/null +++ b/request.go @@ -0,0 +1,670 @@ +package sftp + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + "sync" + "syscall" +) + +// MaxFilelist is the max number of files to return in a readdir batch. +var MaxFilelist int64 = 100 + +// state encapsulates the reader/writer/readdir from handlers. +type state struct { + mu sync.RWMutex + + writerAt io.WriterAt + readerAt io.ReaderAt + writerAtReaderAt WriterAtReaderAt + listerAt ListerAt + lsoffset int64 +} + +// copy returns a shallow copy the state. +// This is broken out to specific fields, +// because we have to copy around the mutex in state. 
+func (s *state) copy() state { + s.mu.RLock() + defer s.mu.RUnlock() + + return state{ + writerAt: s.writerAt, + readerAt: s.readerAt, + writerAtReaderAt: s.writerAtReaderAt, + listerAt: s.listerAt, + lsoffset: s.lsoffset, + } +} + +func (s *state) setReaderAt(rd io.ReaderAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.readerAt = rd +} + +func (s *state) getReaderAt() io.ReaderAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.readerAt +} + +func (s *state) setWriterAt(rd io.WriterAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.writerAt = rd +} + +func (s *state) getWriterAt() io.WriterAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.writerAt +} + +func (s *state) setWriterAtReaderAt(rw WriterAtReaderAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.writerAtReaderAt = rw +} + +func (s *state) getWriterAtReaderAt() WriterAtReaderAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.writerAtReaderAt +} + +func (s *state) getAllReaderWriters() (io.ReaderAt, io.WriterAt, WriterAtReaderAt) { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.readerAt, s.writerAt, s.writerAtReaderAt +} + +// Returns current offset for file list +func (s *state) lsNext() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.lsoffset +} + +// Increases next offset +func (s *state) lsInc(offset int64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.lsoffset += offset +} + +// manage file read/write state +func (s *state) setListerAt(la ListerAt) { + s.mu.Lock() + defer s.mu.Unlock() + + s.listerAt = la +} + +func (s *state) getListerAt() ListerAt { + s.mu.RLock() + defer s.mu.RUnlock() + + return s.listerAt +} + +func (s *state) closeListerAt() error { + s.mu.Lock() + defer s.mu.Unlock() + + var err error + + if s.listerAt != nil { + if c, ok := s.listerAt.(io.Closer); ok { + err = c.Close() + } + s.listerAt = nil + } + + return err +} + +// Request contains the data and state for the incoming service request. 
+type Request struct { + // Get, Put, Setstat, Stat, Rename, Remove + // Rmdir, Mkdir, List, Readlink, Link, Symlink + Method string + Filepath string + Flags uint32 + Attrs []byte // convert to sub-struct + Target string // for renames and sym-links + handle string + + // reader/writer/readdir from handlers + state + + // context lasts duration of request + ctx context.Context + cancelCtx context.CancelFunc +} + +// NewRequest creates a new Request object. +func NewRequest(method, path string) *Request { + return &Request{ + Method: method, + Filepath: cleanPath(path), + } +} + +// copy returns a shallow copy of existing request. +// This is broken out to specific fields, +// because we have to copy around the mutex in state. +func (r *Request) copy() *Request { + return &Request{ + Method: r.Method, + Filepath: r.Filepath, + Flags: r.Flags, + Attrs: r.Attrs, + Target: r.Target, + handle: r.handle, + + state: r.state.copy(), + + ctx: r.ctx, + cancelCtx: r.cancelCtx, + } +} + +// New Request initialized based on packet data +func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Request { + request := &Request{ + Method: requestMethod(pkt), + Filepath: cleanPathWithBase(baseDir, pkt.getPath()), + } + request.ctx, request.cancelCtx = context.WithCancel(ctx) + + switch p := pkt.(type) { + case *sshFxpOpenPacket: + request.Flags = p.Pflags + request.Attrs = p.Attrs.([]byte) + case *sshFxpSetstatPacket: + request.Flags = p.Flags + request.Attrs = p.Attrs.([]byte) + case *sshFxpRenamePacket: + request.Target = cleanPathWithBase(baseDir, p.Newpath) + case *sshFxpSymlinkPacket: + // NOTE: given a POSIX compliant signature: symlink(target, linkpath string) + // this makes Request.Target the linkpath, and Request.Filepath the target. 
+ request.Target = cleanPathWithBase(baseDir, p.Linkpath) + request.Filepath = p.Targetpath + case *sshFxpExtendedPacketHardlink: + request.Target = cleanPathWithBase(baseDir, p.Newpath) + } + return request +} + +// Context returns the request's context. To change the context, +// use WithContext. +// +// The returned context is always non-nil; it defaults to the +// background context. +// +// For incoming server requests, the context is canceled when the +// request is complete or the client's connection closes. +func (r *Request) Context() context.Context { + if r.ctx != nil { + return r.ctx + } + return context.Background() +} + +// WithContext returns a copy of r with its context changed to ctx. +// The provided ctx must be non-nil. +func (r *Request) WithContext(ctx context.Context) *Request { + if ctx == nil { + panic("nil context") + } + r2 := r.copy() + r2.ctx = ctx + r2.cancelCtx = nil + return r2 +} + +// Close reader/writer if possible +func (r *Request) close() error { + defer func() { + if r.cancelCtx != nil { + r.cancelCtx() + } + }() + + err := r.state.closeListerAt() + + rd, wr, rw := r.getAllReaderWriters() + + // Close errors on a Writer are far more likely to be the important one. + // As they can be information that there was a loss of data. 
+ if c, ok := wr.(io.Closer); ok { + if err2 := c.Close(); err == nil { + // update error if it is still nil + err = err2 + } + } + + if c, ok := rw.(io.Closer); ok { + if err2 := c.Close(); err == nil { + // update error if it is still nil + err = err2 + + r.setWriterAtReaderAt(nil) + } + } + + if c, ok := rd.(io.Closer); ok { + if err2 := c.Close(); err == nil { + // update error if it is still nil + err = err2 + } + } + + return err +} + +// Notify transfer error if any +func (r *Request) transferError(err error) { + if err == nil { + return + } + + rd, wr, rw := r.getAllReaderWriters() + + if t, ok := wr.(TransferError); ok { + t.TransferError(err) + } + + if t, ok := rw.(TransferError); ok { + t.TransferError(err) + } + + if t, ok := rd.(TransferError); ok { + t.TransferError(err) + } +} + +// called from worker to handle packet/request +func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { + switch r.Method { + case "Get": + return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket) + case "Put": + return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket) + case "Open": + return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket) + case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS": + return filecmd(handlers.FileCmd, r, pkt) + case "List": + return filelist(handlers.FileList, r, pkt) + case "Stat", "Lstat": + return filestat(handlers.FileList, r, pkt) + case "Readlink": + if readlinkFileLister, ok := handlers.FileList.(ReadlinkFileLister); ok { + return readlink(readlinkFileLister, r, pkt) + } + return filestat(handlers.FileList, r, pkt) + default: + return statusFromError(pkt.id(), fmt.Errorf("unexpected method: %s", r.Method)) + } +} + +// Additional initialization for Open packets +func (r *Request) open(h Handlers, pkt requestPacket) responsePacket { + flags := r.Pflags() + + id := pkt.id() + + 
switch { + case flags.Write, flags.Append, flags.Creat, flags.Trunc: + if flags.Read { + if openFileWriter, ok := h.FilePut.(OpenFileWriter); ok { + r.Method = "Open" + rw, err := openFileWriter.OpenFile(r) + if err != nil { + return statusFromError(id, err) + } + + r.setWriterAtReaderAt(rw) + + return &sshFxpHandlePacket{ + ID: id, + Handle: r.handle, + } + } + } + + r.Method = "Put" + wr, err := h.FilePut.Filewrite(r) + if err != nil { + return statusFromError(id, err) + } + + r.setWriterAt(wr) + + case flags.Read: + r.Method = "Get" + rd, err := h.FileGet.Fileread(r) + if err != nil { + return statusFromError(id, err) + } + + r.setReaderAt(rd) + + default: + return statusFromError(id, errors.New("bad file flags")) + } + + return &sshFxpHandlePacket{ + ID: id, + Handle: r.handle, + } +} + +func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { + r.Method = "List" + la, err := h.FileList.Filelist(r) + if err != nil { + return statusFromError(pkt.id(), wrapPathError(r.Filepath, err)) + } + + r.setListerAt(la) + + return &sshFxpHandlePacket{ + ID: pkt.id(), + Handle: r.handle, + } +} + +// wrap FileReader handler +func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { + rd := r.getReaderAt() + if rd == nil { + return statusFromError(pkt.id(), errors.New("unexpected read packet")) + } + + data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket) + + n, err := rd.ReadAt(data, offset) + // only return EOF error if no data left to read + if err != nil && (err != io.EOF || n == 0) { + return statusFromError(pkt.id(), err) + } + + return &sshFxpDataPacket{ + ID: pkt.id(), + Length: uint32(n), + Data: data[:n], + } +} + +// wrap FileWriter handler +func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { + wr := r.getWriterAt() + if wr == nil { + return statusFromError(pkt.id(), errors.New("unexpected 
write packet")) + } + + data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket) + + _, err := wr.WriteAt(data, offset) + return statusFromError(pkt.id(), err) +} + +// wrap OpenFileWriter handler +func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket { + rw := r.getWriterAtReaderAt() + if rw == nil { + return statusFromError(pkt.id(), errors.New("unexpected write and read packet")) + } + + switch p := pkt.(type) { + case *sshFxpReadPacket: + data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset) + + n, err := rw.ReadAt(data, offset) + // only return EOF error if no data left to read + if err != nil && (err != io.EOF || n == 0) { + return statusFromError(pkt.id(), err) + } + + return &sshFxpDataPacket{ + ID: pkt.id(), + Length: uint32(n), + Data: data[:n], + } + + case *sshFxpWritePacket: + data, offset := p.Data, int64(p.Offset) + + _, err := rw.WriteAt(data, offset) + return statusFromError(pkt.id(), err) + + default: + return statusFromError(pkt.id(), errors.New("unexpected packet type for read or write")) + } +} + +// file data for additional read/write packets +func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) { + switch p := p.(type) { + case *sshFxpReadPacket: + return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len + case *sshFxpWritePacket: + return p.Data, int64(p.Offset), p.Length + } + return +} + +// wrap FileCmder handler +func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket { + switch p := pkt.(type) { + case *sshFxpFsetstatPacket: + r.Flags = p.Flags + r.Attrs = p.Attrs.([]byte) + } + + switch r.Method { + case "PosixRename": + if posixRenamer, ok := h.(PosixRenameFileCmder); ok { + err := posixRenamer.PosixRename(r) + return statusFromError(pkt.id(), err) + } + + // PosixRenameFileCmder not implemented handle this request as a 
Rename + r.Method = "Rename" + err := h.Filecmd(r) + return statusFromError(pkt.id(), err) + + case "StatVFS": + if statVFSCmdr, ok := h.(StatVFSFileCmder); ok { + stat, err := statVFSCmdr.StatVFS(r) + if err != nil { + return statusFromError(pkt.id(), err) + } + stat.ID = pkt.id() + return stat + } + + return statusFromError(pkt.id(), ErrSSHFxOpUnsupported) + } + + err := h.Filecmd(r) + return statusFromError(pkt.id(), err) +} + +// wrap FileLister handler +func filelist(h FileLister, r *Request, pkt requestPacket) responsePacket { + lister := r.getListerAt() + if lister == nil { + return statusFromError(pkt.id(), errors.New("unexpected dir packet")) + } + + offset := r.lsNext() + finfo := make([]os.FileInfo, MaxFilelist) + n, err := lister.ListAt(finfo, offset) + r.lsInc(int64(n)) + // ignore EOF as we only return it when there are no results + finfo = finfo[:n] // avoid need for nil tests below + + switch r.Method { + case "List": + if err != nil && (err != io.EOF || n == 0) { + return statusFromError(pkt.id(), err) + } + + nameAttrs := make([]*sshFxpNameAttr, 0, len(finfo)) + + // If the type conversion fails, we get untyped `nil`, + // which is handled by not looking up any names. 
+ idLookup, _ := h.(NameLookupFileLister) + + for _, fi := range finfo { + nameAttrs = append(nameAttrs, &sshFxpNameAttr{ + Name: fi.Name(), + LongName: runLs(idLookup, fi), + Attrs: []interface{}{fi}, + }) + } + + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: nameAttrs, + } + + default: + err = fmt.Errorf("unexpected method: %s", r.Method) + return statusFromError(pkt.id(), err) + } +} + +func filestat(h FileLister, r *Request, pkt requestPacket) responsePacket { + var lister ListerAt + var err error + + if r.Method == "Lstat" { + if lstatFileLister, ok := h.(LstatFileLister); ok { + lister, err = lstatFileLister.Lstat(r) + } else { + // LstatFileLister not implemented handle this request as a Stat + r.Method = "Stat" + lister, err = h.Filelist(r) + } + } else { + lister, err = h.Filelist(r) + } + if err != nil { + return statusFromError(pkt.id(), err) + } + finfo := make([]os.FileInfo, 1) + n, err := lister.ListAt(finfo, 0) + finfo = finfo[:n] // avoid need for nil tests below + + switch r.Method { + case "Stat", "Lstat": + if err != nil && err != io.EOF { + return statusFromError(pkt.id(), err) + } + if n == 0 { + err = &os.PathError{ + Op: strings.ToLower(r.Method), + Path: r.Filepath, + Err: syscall.ENOENT, + } + return statusFromError(pkt.id(), err) + } + return &sshFxpStatResponse{ + ID: pkt.id(), + info: finfo[0], + } + case "Readlink": + if err != nil && err != io.EOF { + return statusFromError(pkt.id(), err) + } + if n == 0 { + err = &os.PathError{ + Op: "readlink", + Path: r.Filepath, + Err: syscall.ENOENT, + } + return statusFromError(pkt.id(), err) + } + filename := finfo[0].Name() + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: []*sshFxpNameAttr{ + { + Name: filename, + LongName: filename, + Attrs: emptyFileStat, + }, + }, + } + default: + err = fmt.Errorf("unexpected method: %s", r.Method) + return statusFromError(pkt.id(), err) + } +} + +func readlink(readlinkFileLister ReadlinkFileLister, r *Request, pkt requestPacket) 
responsePacket { + resolved, err := readlinkFileLister.Readlink(r.Filepath) + if err != nil { + return statusFromError(pkt.id(), err) + } + return &sshFxpNamePacket{ + ID: pkt.id(), + NameAttrs: []*sshFxpNameAttr{ + { + Name: resolved, + LongName: resolved, + Attrs: emptyFileStat, + }, + }, + } +} + +// init attributes of request object from packet data +func requestMethod(p requestPacket) (method string) { + switch p.(type) { + case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket: + // set in open() above + case *sshFxpOpendirPacket, *sshFxpReaddirPacket: + // set in opendir() above + case *sshFxpSetstatPacket, *sshFxpFsetstatPacket: + method = "Setstat" + case *sshFxpRenamePacket: + method = "Rename" + case *sshFxpSymlinkPacket: + method = "Symlink" + case *sshFxpRemovePacket: + method = "Remove" + case *sshFxpStatPacket, *sshFxpFstatPacket: + method = "Stat" + case *sshFxpLstatPacket: + method = "Lstat" + case *sshFxpRmdirPacket: + method = "Rmdir" + case *sshFxpReadlinkPacket: + method = "Readlink" + case *sshFxpMkdirPacket: + method = "Mkdir" + case *sshFxpExtendedPacketHardlink: + method = "Link" + } + return method +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 00000000..807833aa --- /dev/null +++ b/request_test.go @@ -0,0 +1,249 @@ +package sftp + +import ( + "bytes" + "errors" + "io" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testHandler struct { + filecontents []byte // dummy contents + output io.WriterAt // dummy file out + err error // dummy error, should be file related +} + +func (t *testHandler) Fileread(r *Request) (io.ReaderAt, error) { + if t.err != nil { + return nil, t.err + } + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + return bytes.NewReader(t.filecontents), nil +} + +func (t *testHandler) Filewrite(r *Request) (io.WriterAt, error) { + if t.err != nil { + return nil, t.err + } + _ = r.WithContext(r.Context()) // initialize context for deadlock 
testing + return io.WriterAt(t.output), nil +} + +func (t *testHandler) Filecmd(r *Request) error { + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + return t.err +} + +func (t *testHandler) Filelist(r *Request) (ListerAt, error) { + if t.err != nil { + return nil, t.err + } + _ = r.WithContext(r.Context()) // initialize context for deadlock testing + f, err := os.Open(r.Filepath) + if err != nil { + return nil, err + } + fi, err := f.Stat() + if err != nil { + return nil, err + } + return listerat([]os.FileInfo{fi}), nil +} + +// make sure len(fakefile) == len(filecontents) +type fakefile [10]byte + +var filecontents = []byte("file-data.") + +// XXX need new for creating test requests that supports Open-ing +func testRequest(method string) *Request { + var flags uint32 + switch method { + case "Get": + flags = flags | sshFxfRead + case "Put": + flags = flags | sshFxfWrite + } + request := &Request{ + Filepath: "./request_test.go", + Method: method, + Attrs: []byte("foo"), + Flags: flags, + Target: "foo", + } + return request +} + +func (ff *fakefile) WriteAt(p []byte, off int64) (int, error) { + n := copy(ff[off:], p) + return n, nil +} + +func (ff fakefile) string() string { + b := make([]byte, len(ff)) + copy(b, ff[:]) + return string(b) +} + +func newTestHandlers() Handlers { + handler := &testHandler{ + filecontents: filecontents, + output: &fakefile{}, + err: nil, + } + return Handlers{ + FileGet: handler, + FilePut: handler, + FileCmd: handler, + FileList: handler, + } +} + +func (h Handlers) getOutString() string { + handler := h.FilePut.(*testHandler) + return handler.output.(*fakefile).string() +} + +var errTest = errors.New("test error") + +func (h *Handlers) returnError(err error) { + handler := h.FilePut.(*testHandler) + handler.err = err +} + +func getStatusMsg(p interface{}) string { + pkt := p.(*sshFxpStatusPacket) + return pkt.StatusError.msg +} +func checkOkStatus(t *testing.T, p interface{}) { + pkt := 
p.(*sshFxpStatusPacket) + assert.Equal(t, pkt.StatusError.Code, uint32(sshFxOk), + "sshFxpStatusPacket not OK\n", pkt.StatusError.msg) +} + +// fake/test packet +type fakePacket struct { + myid uint32 + handle string +} + +func (f fakePacket) id() uint32 { + return f.myid +} + +func (f fakePacket) getHandle() string { + return f.handle +} +func (fakePacket) UnmarshalBinary(d []byte) error { return nil } + +// XXX can't just set method to Get, need to use Open to setup Get/Put +func TestRequestGet(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Get") + pkt := fakePacket{myid: 1} + request.open(handlers, pkt) + // req.length is 5, so we test reads in 5 byte chunks + for i, txt := range []string{"file-", "data."} { + pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a", + Offset: uint64(i * 5), Len: 5} + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + dpkt := rpkt.(*sshFxpDataPacket) + assert.Equal(t, dpkt.id(), uint32(i)) + assert.Equal(t, string(dpkt.Data), txt) + } +} + +func TestRequestCustomError(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Stat") + pkt := fakePacket{myid: 1} + cmdErr := errors.New("stat not supported") + handlers.returnError(cmdErr) + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr)) +} + +// XXX can't just set method to Get, need to use Open to setup Get/Put +func TestRequestPut(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Put") + request.state.writerAt, _ = handlers.FilePut.Filewrite(request) + pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5, + Data: []byte("file-")} + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + checkOkStatus(t, rpkt) + pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5, + Data: []byte("data.")} + rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + checkOkStatus(t, rpkt) + assert.Equal(t, 
"file-data.", handlers.getOutString()) +} + +func TestRequestCmdr(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Mkdir") + pkt := fakePacket{myid: 1} + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + checkOkStatus(t, rpkt) + + handlers.returnError(errTest) + rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest)) +} + +func TestRequestInfoStat(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Stat") + pkt := fakePacket{myid: 1} + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + spkt, ok := rpkt.(*sshFxpStatResponse) + assert.True(t, ok) + assert.Equal(t, spkt.info.Name(), "request_test.go") +} + +func TestRequestInfoList(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("List") + request.handle = "1" + pkt := fakePacket{myid: 1} + rpkt := request.opendir(handlers, pkt) + hpkt, ok := rpkt.(*sshFxpHandlePacket) + if assert.True(t, ok) { + assert.Equal(t, hpkt.Handle, "1") + } + pkt = fakePacket{myid: 2} + request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) +} +func TestRequestInfoReadlink(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Readlink") + pkt := fakePacket{myid: 1} + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + npkt, ok := rpkt.(*sshFxpNamePacket) + if assert.True(t, ok) { + assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0]) + assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go") + } +} + +func TestOpendirHandleReuse(t *testing.T) { + handlers := newTestHandlers() + request := testRequest("Stat") + request.handle = "1" + pkt := fakePacket{myid: 1} + rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + assert.IsType(t, &sshFxpStatResponse{}, rpkt) + + request.Method = "List" + pkt = fakePacket{myid: 2} + rpkt = request.opendir(handlers, pkt) + if assert.IsType(t, &sshFxpHandlePacket{}, rpkt) { + hpkt := 
rpkt.(*sshFxpHandlePacket) + assert.Equal(t, hpkt.Handle, "1") + } + rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket) + assert.IsType(t, &sshFxpNamePacket{}, rpkt) +} diff --git a/request_windows.go b/request_windows.go new file mode 100644 index 00000000..bd1d6864 --- /dev/null +++ b/request_windows.go @@ -0,0 +1,13 @@ +package sftp + +import ( + "syscall" +) + +func fakeFileInfoSys() interface{} { + return syscall.Win32FileAttributeData{} +} + +func testOsSys(sys interface{}) error { + return nil +} diff --git a/server.go b/server.go new file mode 100644 index 00000000..7735c42c --- /dev/null +++ b/server.go @@ -0,0 +1,657 @@ +package sftp + +// sftp server counterpart + +import ( + "encoding" + "errors" + "fmt" + "io" + "io/fs" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "sync" + "syscall" + "time" +) + +const ( + // SftpServerWorkerCount defines the number of workers for the SFTP server + SftpServerWorkerCount = 8 +) + +type file interface { + Stat() (os.FileInfo, error) + ReadAt(b []byte, off int64) (int, error) + WriteAt(b []byte, off int64) (int, error) + Readdir(int) ([]os.FileInfo, error) + Name() string + Truncate(int64) error + Chmod(mode fs.FileMode) error + Chown(uid, gid int) error + Close() error +} + +// Server is an SSH File Transfer Protocol (sftp) server. +// This is intended to provide the sftp subsystem to an ssh server daemon. 
+// This implementation currently supports most of sftp server protocol version 3, +// as specified at https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt +type Server struct { + *serverConn + debugStream io.Writer + readOnly bool + pktMgr *packetManager + openFiles map[string]file + openFilesLock sync.RWMutex + handleCount int + workDir string + winRoot bool + maxTxPacket uint32 +} + +func (svr *Server) nextHandle(f file) string { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + svr.handleCount++ + handle := strconv.Itoa(svr.handleCount) + svr.openFiles[handle] = f + return handle +} + +func (svr *Server) closeHandle(handle string) error { + svr.openFilesLock.Lock() + defer svr.openFilesLock.Unlock() + if f, ok := svr.openFiles[handle]; ok { + delete(svr.openFiles, handle) + return f.Close() + } + + return EBADF +} + +func (svr *Server) getHandle(handle string) (file, bool) { + svr.openFilesLock.RLock() + defer svr.openFilesLock.RUnlock() + f, ok := svr.openFiles[handle] + return f, ok +} + +type serverRespondablePacket interface { + encoding.BinaryUnmarshaler + id() uint32 + respond(svr *Server) responsePacket +} + +// NewServer creates a new Server instance around the provided streams, serving +// content from the root of the filesystem. Optionally, ServerOption +// functions may be specified to further configure the Server. +// +// A subsequent call to Serve() is required to begin serving files over SFTP. +func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) { + svrConn := &serverConn{ + conn: conn{ + Reader: rwc, + WriteCloser: rwc, + }, + } + s := &Server{ + serverConn: svrConn, + debugStream: ioutil.Discard, + pktMgr: newPktMgr(svrConn), + openFiles: make(map[string]file), + maxTxPacket: defaultMaxTxPacket, + } + + for _, o := range options { + if err := o(s); err != nil { + return nil, err + } + } + + return s, nil +} + +// A ServerOption is a function which applies configuration to a Server. 
+type ServerOption func(*Server) error
+
+// WithDebug enables Server debugging output to the supplied io.Writer.
+func WithDebug(w io.Writer) ServerOption {
+	return func(s *Server) error {
+		s.debugStream = w
+		return nil
+	}
+}
+
+// ReadOnly configures a Server to serve files in read-only mode.
+func ReadOnly() ServerOption {
+	return func(s *Server) error {
+		s.readOnly = true
+		return nil
+	}
+}
+
+// WindowsRootEnumeratesDrives configures a Server to serve a virtual '/' for windows that lists all drives.
+func WindowsRootEnumeratesDrives() ServerOption {
+	return func(s *Server) error {
+		s.winRoot = true
+		return nil
+	}
+}
+
+// WithAllocator enables the allocator.
+// After processing a packet we keep in memory the allocated slices
+// and we reuse them for new packets.
+// The allocator is experimental.
+func WithAllocator() ServerOption {
+	return func(s *Server) error {
+		alloc := newAllocator()
+		s.pktMgr.alloc = alloc
+		s.conn.alloc = alloc
+		return nil
+	}
+}
+
+// WithServerWorkingDirectory sets a working directory to use as base
+// for relative paths.
+// If unset the default is current working directory (os.Getwd).
+func WithServerWorkingDirectory(workDir string) ServerOption {
+	return func(s *Server) error {
+		s.workDir = cleanPath(workDir)
+		return nil
+	}
+}
+
+// WithMaxTxPacket sets the maximum size of the payload returned to the client,
+// measured in bytes. The default value is 32768 bytes, and this option
+// can only be used to increase it. Setting this option to a larger value
+// should be safe, because the client decides the size of the requested payload.
+//
+// The default maximum packet size is 32768 bytes. 
+func WithMaxTxPacket(size uint32) ServerOption { + return func(s *Server) error { + if size < defaultMaxTxPacket { + return errors.New("size must be greater than or equal to 32768") + } + + s.maxTxPacket = size + + return nil + } +} + +type rxPacket struct { + pktType fxp + pktBytes []byte +} + +// Up to N parallel servers +func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { + for pkt := range pktChan { + // readonly checks + readonly := true + switch pkt := pkt.requestPacket.(type) { + case notReadOnly: + readonly = false + case *sshFxpOpenPacket: + readonly = pkt.readonly() + case *sshFxpExtendedPacket: + readonly = pkt.readonly() + } + + // If server is operating read-only and a write operation is requested, + // return permission denied + if !readonly && svr.readOnly { + svr.pktMgr.readyPacket( + svr.pktMgr.newOrderedResponse(statusFromError(pkt.id(), syscall.EPERM), pkt.orderID()), + ) + continue + } + + if err := handlePacket(svr, pkt); err != nil { + return err + } + } + return nil +} + +func handlePacket(s *Server, p orderedRequest) error { + var rpkt responsePacket + orderID := p.orderID() + switch p := p.requestPacket.(type) { + case *sshFxInitPacket: + rpkt = &sshFxVersionPacket{ + Version: sftpProtocolVersion, + Extensions: sftpExtensions, + } + case *sshFxpStatPacket: + // stat the requested file + info, err := os.Stat(s.toLocalPath(p.Path)) + rpkt = &sshFxpStatResponse{ + ID: p.ID, + info: info, + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpLstatPacket: + // stat the requested file + info, err := s.lstat(s.toLocalPath(p.Path)) + rpkt = &sshFxpStatResponse{ + ID: p.ID, + info: info, + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpFstatPacket: + f, ok := s.getHandle(p.Handle) + var err error = EBADF + var info os.FileInfo + if ok { + info, err = f.Stat() + rpkt = &sshFxpStatResponse{ + ID: p.ID, + info: info, + } + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } 
+ case *sshFxpMkdirPacket: + // TODO FIXME: ignore flags field + err := os.Mkdir(s.toLocalPath(p.Path), 0o755) + rpkt = statusFromError(p.ID, err) + case *sshFxpRmdirPacket: + err := os.Remove(s.toLocalPath(p.Path)) + rpkt = statusFromError(p.ID, err) + case *sshFxpRemovePacket: + err := os.Remove(s.toLocalPath(p.Filename)) + rpkt = statusFromError(p.ID, err) + case *sshFxpRenamePacket: + err := os.Rename(s.toLocalPath(p.Oldpath), s.toLocalPath(p.Newpath)) + rpkt = statusFromError(p.ID, err) + case *sshFxpSymlinkPacket: + err := os.Symlink(s.toLocalPath(p.Targetpath), s.toLocalPath(p.Linkpath)) + rpkt = statusFromError(p.ID, err) + case *sshFxpClosePacket: + rpkt = statusFromError(p.ID, s.closeHandle(p.Handle)) + case *sshFxpReadlinkPacket: + f, err := os.Readlink(s.toLocalPath(p.Path)) + rpkt = &sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []*sshFxpNameAttr{ + { + Name: f, + LongName: f, + Attrs: emptyFileStat, + }, + }, + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpRealpathPacket: + f, err := filepath.Abs(s.toLocalPath(p.Path)) + f = cleanPath(f) + rpkt = &sshFxpNamePacket{ + ID: p.ID, + NameAttrs: []*sshFxpNameAttr{ + { + Name: f, + LongName: f, + Attrs: emptyFileStat, + }, + }, + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + case *sshFxpOpendirPacket: + lp := s.toLocalPath(p.Path) + + if stat, err := s.stat(lp); err != nil { + rpkt = statusFromError(p.ID, err) + } else if !stat.IsDir() { + rpkt = statusFromError(p.ID, &os.PathError{ + Path: lp, Err: syscall.ENOTDIR, + }) + } else { + rpkt = (&sshFxpOpenPacket{ + ID: p.ID, + Path: p.Path, + Pflags: sshFxfRead, + }).respond(s) + } + case *sshFxpReadPacket: + var err error = EBADF + f, ok := s.getHandle(p.Handle) + if ok { + err = nil + data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket) + n, _err := f.ReadAt(data, int64(p.Offset)) + if _err != nil && (_err != io.EOF || n == 0) { + err = _err + } + rpkt = &sshFxpDataPacket{ + ID: p.ID, + Length: uint32(n), + 
Data: data[:n], + // do not use data[:n:n] here to clamp the capacity, we allocated extra capacity above to avoid reallocations + } + } + if err != nil { + rpkt = statusFromError(p.ID, err) + } + + case *sshFxpWritePacket: + f, ok := s.getHandle(p.Handle) + var err error = EBADF + if ok { + _, err = f.WriteAt(p.Data, int64(p.Offset)) + } + rpkt = statusFromError(p.ID, err) + case *sshFxpExtendedPacket: + if p.SpecificPacket == nil { + rpkt = statusFromError(p.ID, ErrSSHFxOpUnsupported) + } else { + rpkt = p.respond(s) + } + case serverRespondablePacket: + rpkt = p.respond(s) + default: + return fmt.Errorf("unexpected packet type %T", p) + } + + s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, orderID)) + return nil +} + +// Serve serves SFTP connections until the streams stop or the SFTP subsystem +// is stopped. It returns nil if the server exits cleanly. +func (svr *Server) Serve() error { + defer func() { + if svr.pktMgr.alloc != nil { + svr.pktMgr.alloc.Free() + } + }() + var wg sync.WaitGroup + runWorker := func(ch chan orderedRequest) { + wg.Add(1) + go func() { + defer wg.Done() + if err := svr.sftpServerWorker(ch); err != nil { + svr.conn.Close() // shuts down recvPacket + } + }() + } + pktChan := svr.pktMgr.workerChan(runWorker) + + var err error + var pkt requestPacket + var pktType fxp + var pktBytes []byte + for { + pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID()) + if err != nil { + // Check whether the connection terminated cleanly in-between packets. 
+ if err == io.EOF { + err = nil + } + // we don't care about releasing allocated pages here, the server will quit and the allocator freed + break + } + + pkt, err = makePacket(rxPacket{pktType, pktBytes}) + if err != nil { + switch { + case errors.Is(err, errUnknownExtendedPacket): + //if err := svr.serverConn.sendError(pkt, ErrSshFxOpUnsupported); err != nil { + // debug("failed to send err packet: %v", err) + // svr.conn.Close() // shuts down recvPacket + // break + //} + default: + debug("makePacket err: %v", err) + svr.conn.Close() // shuts down recvPacket + break + } + } + + pktChan <- svr.pktMgr.newOrderedRequest(pkt) + } + + close(pktChan) // shuts down sftpServerWorkers + wg.Wait() // wait for all workers to exit + + // close any still-open files + for handle, file := range svr.openFiles { + fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Name()) + file.Close() + } + return err // error from recvPacket +} + +type ider interface { + id() uint32 +} + +// The init packet has no ID, so we just return a zero-value ID +func (p *sshFxInitPacket) id() uint32 { return 0 } + +type sshFxpStatResponse struct { + ID uint32 + info os.FileInfo +} + +func (p *sshFxpStatResponse) marshalPacket() ([]byte, []byte, error) { + l := 4 + 1 + 4 // uint32(length) + byte(type) + uint32(id) + + b := make([]byte, 4, l) + b = append(b, sshFxpAttrs) + b = marshalUint32(b, p.ID) + + var payload []byte + payload = marshalFileInfo(payload, p.info) + + return b, payload, nil +} + +func (p *sshFxpStatResponse) MarshalBinary() ([]byte, error) { + header, payload, err := p.marshalPacket() + return append(header, payload...), err +} + +var emptyFileStat = []interface{}{uint32(0)} + +func (p *sshFxpOpenPacket) readonly() bool { + return !p.hasPflags(sshFxfWrite) +} + +func (p *sshFxpOpenPacket) hasPflags(flags ...uint32) bool { + for _, f := range flags { + if p.Pflags&f == 0 { + return false + } + } + return true +} + +func (p *sshFxpOpenPacket) 
respond(svr *Server) responsePacket { + var osFlags int + if p.hasPflags(sshFxfRead, sshFxfWrite) { + osFlags |= os.O_RDWR + } else if p.hasPflags(sshFxfWrite) { + osFlags |= os.O_WRONLY + } else if p.hasPflags(sshFxfRead) { + osFlags |= os.O_RDONLY + } else { + // how are they opening? + return statusFromError(p.ID, syscall.EINVAL) + } + + // Don't use O_APPEND flag as it conflicts with WriteAt. + // The sshFxfAppend flag is a no-op here as the client sends the offsets. + + if p.hasPflags(sshFxfCreat) { + osFlags |= os.O_CREATE + } + if p.hasPflags(sshFxfTrunc) { + osFlags |= os.O_TRUNC + } + if p.hasPflags(sshFxfExcl) { + osFlags |= os.O_EXCL + } + + mode := os.FileMode(0o644) + // Like OpenSSH, we only handle permissions here, and only when the file is being created. + // Otherwise, the permissions are ignored. + if p.Flags&sshFileXferAttrPermissions != 0 { + fs, err := p.unmarshalFileStat(p.Flags) + if err != nil { + return statusFromError(p.ID, err) + } + mode = fs.FileMode() & os.ModePerm + } + + f, err := svr.openfile(svr.toLocalPath(p.Path), osFlags, mode) + if err != nil { + return statusFromError(p.ID, err) + } + + handle := svr.nextHandle(f) + return &sshFxpHandlePacket{ID: p.ID, Handle: handle} +} + +func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { + f, ok := svr.getHandle(p.Handle) + if !ok { + return statusFromError(p.ID, EBADF) + } + + dirents, err := f.Readdir(128) + if err != nil { + return statusFromError(p.ID, err) + } + + idLookup := osIDLookup{} + + ret := &sshFxpNamePacket{ID: p.ID} + for _, dirent := range dirents { + ret.NameAttrs = append(ret.NameAttrs, &sshFxpNameAttr{ + Name: dirent.Name(), + LongName: runLs(idLookup, dirent), + Attrs: []interface{}{dirent}, + }) + } + return ret +} + +func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { + path := svr.toLocalPath(p.Path) + + debug("setstat name %q", path) + + fs, err := p.unmarshalFileStat(p.Flags) + + if err == nil && (p.Flags&sshFileXferAttrSize) != 0 { 
+ err = os.Truncate(path, int64(fs.Size)) + } + if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 { + err = os.Chmod(path, fs.FileMode()) + } + if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 { + err = os.Chown(path, int(fs.UID), int(fs.GID)) + } + if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 { + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) + } + + return statusFromError(p.ID, err) +} + +func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { + f, ok := svr.getHandle(p.Handle) + if !ok { + return statusFromError(p.ID, EBADF) + } + + path := f.Name() + + debug("fsetstat name %q", path) + + fs, err := p.unmarshalFileStat(p.Flags) + + if err == nil && (p.Flags&sshFileXferAttrSize) != 0 { + err = f.Truncate(int64(fs.Size)) + } + if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 { + err = f.Chmod(fs.FileMode()) + } + if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 { + err = f.Chown(int(fs.UID), int(fs.GID)) + } + if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 { + type chtimer interface { + Chtimes(atime, mtime time.Time) error + } + + switch f := interface{}(f).(type) { + case chtimer: + // future-compatible, for when/if *os.File supports Chtimes. 
+ err = f.Chtimes(fs.AccessTime(), fs.ModTime()) + default: + err = os.Chtimes(path, fs.AccessTime(), fs.ModTime()) + } + } + + return statusFromError(p.ID, err) +} + +func statusFromError(id uint32, err error) *sshFxpStatusPacket { + ret := &sshFxpStatusPacket{ + ID: id, + StatusError: StatusError{ + // sshFXOk = 0 + // sshFXEOF = 1 + // sshFXNoSuchFile = 2 ENOENT + // sshFXPermissionDenied = 3 + // sshFXFailure = 4 + // sshFXBadMessage = 5 + // sshFXNoConnection = 6 + // sshFXConnectionLost = 7 + // sshFXOPUnsupported = 8 + Code: sshFxOk, + }, + } + if err == nil { + return ret + } + + debug("statusFromError: error is %T %#v", err, err) + ret.StatusError.Code = sshFxFailure + ret.StatusError.msg = err.Error() + + if os.IsNotExist(err) { + ret.StatusError.Code = sshFxNoSuchFile + return ret + } + if code, ok := translateSyscallError(err); ok { + ret.StatusError.Code = code + return ret + } + + if errors.Is(err, io.EOF) { + ret.StatusError.Code = sshFxEOF + return ret + } + + var e fxerr + if errors.As(err, &e) { + ret.StatusError.Code = uint32(e) + return ret + } + + return ret +} diff --git a/server_integration_test.go b/server_integration_test.go new file mode 100644 index 00000000..398ea865 --- /dev/null +++ b/server_integration_test.go @@ -0,0 +1,939 @@ +package sftp + +// sftp server integration tests +// enable with -integration +// example invokation (darwin): gofmt -w `find . 
-name \*.go` && (cd server_standalone/ ; go build -tags debug) && go test -tags debug github.com/pkg/sftp -integration -v -sftp /usr/libexec/sftp-server -run ServerCompareSubsystems + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + crand "crypto/rand" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "flag" + "fmt" + "io/ioutil" + "math/rand" + "net" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "testing" + "time" + + "github.com/kr/fs" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" +) + +func TestMain(m *testing.M) { + sftpClientLocation, _ := exec.LookPath("sftp") + testSftpClientBin = flag.String("sftp_client", sftpClientLocation, "location of the sftp client binary") + + lookSFTPServer := []string{ + "/usr/libexec/sftp-server", + "/usr/lib/openssh/sftp-server", + "/usr/lib/ssh/sftp-server", + "C:\\Program Files\\Git\\usr\\lib\\ssh\\sftp-server.exe", + } + sftpServer, _ := exec.LookPath("sftp-server") + if len(sftpServer) == 0 { + for _, location := range lookSFTPServer { + if _, err := os.Stat(location); err == nil { + sftpServer = location + break + } + } + } + testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary") + flag.Parse() + + os.Exit(m.Run()) +} + +func skipIfWindows(t testing.TB) { + if runtime.GOOS == "windows" { + t.Skip("skipping test on windows") + } +} + +func skipIfPlan9(t testing.TB) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test on plan9") + } +} + +var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") +var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") +var testAllocator = flag.Bool("allocator", false, "perform tests using the allocator") +var testSftp *string + +var testSftpClientBin *string +var sshServerDebugStream = ioutil.Discard +var sftpServerDebugStream = ioutil.Discard +var sftpClientDebugStream = 
ioutil.Discard + +const ( + GolangSFTP = true + OpenSSHSFTP = false +) + +var ( + hostPrivateKeySigner ssh.Signer + privKey = []byte(` +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEArhp7SqFnXVZAgWREL9Ogs+miy4IU/m0vmdkoK6M97G9NX/Pj +wf8I/3/ynxmcArbt8Rc4JgkjT2uxx/NqR0yN42N1PjO5Czu0dms1PSqcKIJdeUBV +7gdrKSm9Co4d2vwfQp5mg47eG4w63pz7Drk9+VIyi9YiYH4bve7WnGDswn4ycvYZ +slV5kKnjlfCdPig+g5P7yQYud0cDWVwyA0+kxvL6H3Ip+Fu8rLDZn4/P1WlFAIuc +PAf4uEKDGGmC2URowi5eesYR7f6GN/HnBs2776laNlAVXZUmYTUfOGagwLsEkx8x +XdNqntfbs2MOOoK+myJrNtcB9pCrM0H6um19uQIDAQABAoIBABkWr9WdVKvalgkP +TdQmhu3mKRNyd1wCl+1voZ5IM9Ayac/98UAvZDiNU4Uhx52MhtVLJ0gz4Oa8+i16 +IkKMAZZW6ro/8dZwkBzQbieWUFJ2Fso2PyvB3etcnGU8/Yhk9IxBDzy+BbuqhYE2 +1ebVQtz+v1HvVZzaD11bYYm/Xd7Y28QREVfFen30Q/v3dv7dOteDE/RgDS8Czz7w +jMW32Q8JL5grz7zPkMK39BLXsTcSYcaasT2ParROhGJZDmbgd3l33zKCVc1zcj9B +SA47QljGd09Tys958WWHgtj2o7bp9v1Ufs4LnyKgzrB80WX1ovaSQKvd5THTLchO +kLIhUAECgYEA2doGXy9wMBmTn/hjiVvggR1aKiBwUpnB87Hn5xCMgoECVhFZlT6l +WmZe7R2klbtG1aYlw+y+uzHhoVDAJW9AUSV8qoDUwbRXvBVlp+In5wIqJ+VjfivK +zgIfzomL5NvDz37cvPmzqIeySTowEfbQyq7CUQSoDtE9H97E2wWZhDkCgYEAzJdJ +k+NSFoTkHhfD3L0xCDHpRV3gvaOeew8524fVtVUq53X8m91ng4AX1r74dCUYwwiF +gqTtSSJfx2iH1xKnNq28M9uKg7wOrCKrRqNPnYUO3LehZEC7rwUr26z4iJDHjjoB +uBcS7nw0LJ+0Zeg1IF+aIdZGV3MrAKnrzWPixYECgYBsffX6ZWebrMEmQ89eUtFF +u9ZxcGI/4K8ErC7vlgBD5ffB4TYZ627xzFWuBLs4jmHCeNIJ9tct5rOVYN+wRO1k +/CRPzYUnSqb+1jEgILL6istvvv+DkE+ZtNkeRMXUndWwel94BWsBnUKe0UmrSJ3G +sq23J3iCmJW2T3z+DpXbkQKBgQCK+LUVDNPE0i42NsRnm+fDfkvLP7Kafpr3Umdl +tMY474o+QYn+wg0/aPJIf9463rwMNyyhirBX/k57IIktUdFdtfPicd2MEGETElWv +nN1GzYxD50Rs2f/jKisZhEwqT9YNyV9DkgDdGGdEbJNYqbv0qpwDIg8T9foe8E1p +bdErgQKBgAt290I3L316cdxIQTkJh1DlScN/unFffITwu127WMr28Jt3mq3cZpuM +Aecey/eEKCj+Rlas5NDYKsB18QIuAw+qqWyq0LAKLiAvP1965Rkc4PLScl3MgJtO +QYa37FK0p8NcDeUuF86zXBVutwS5nJLchHhKfd590ks57OROtm29 +-----END RSA PRIVATE KEY----- +`) +) + +func init() { + var err error + hostPrivateKeySigner, err = ssh.ParsePrivateKey(privKey) + if err != nil { + panic(err) + } +} + +func keyAuth(conn 
ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + } + return permissions, nil +} + +func pwAuth(conn ssh.ConnMetadata, pw []byte) (*ssh.Permissions, error) { + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + } + return permissions, nil +} + +func basicServerConfig() *ssh.ServerConfig { + config := ssh.ServerConfig{ + Config: ssh.Config{ + MACs: []string{"hmac-sha1"}, + }, + PasswordCallback: pwAuth, + PublicKeyCallback: keyAuth, + } + config.AddHostKey(hostPrivateKeySigner) + return &config +} + +type sshServer struct { + useSubsystem bool + conn net.Conn + config *ssh.ServerConfig + sshConn *ssh.ServerConn + newChans <-chan ssh.NewChannel + newReqs <-chan *ssh.Request +} + +func sshServerFromConn(conn net.Conn, useSubsystem bool, config *ssh.ServerConfig) (*sshServer, error) { + // From a standard TCP connection to an encrypted SSH connection + sshConn, newChans, newReqs, err := ssh.NewServerConn(conn, config) + if err != nil { + return nil, err + } + + svr := &sshServer{useSubsystem, conn, config, sshConn, newChans, newReqs} + svr.listenChannels() + return svr, nil +} + +func (svr *sshServer) Wait() error { + return svr.sshConn.Wait() +} + +func (svr *sshServer) Close() error { + return svr.sshConn.Close() +} + +func (svr *sshServer) listenChannels() { + go func() { + for chanReq := range svr.newChans { + go svr.handleChanReq(chanReq) + } + }() + go func() { + for req := range svr.newReqs { + go svr.handleReq(req) + } + }() +} + +func (svr *sshServer) handleReq(req *ssh.Request) { + switch req.Type { + default: + rejectRequest(req) + } +} + +type sshChannelServer struct { + svr *sshServer + chanReq ssh.NewChannel + ch ssh.Channel + newReqs <-chan *ssh.Request +} + +type sshSessionChannelServer struct { + *sshChannelServer + env []string +} + +func (svr *sshServer) 
handleChanReq(chanReq ssh.NewChannel) { + fmt.Fprintf(sshServerDebugStream, "channel request: %v, extra: '%v'\n", chanReq.ChannelType(), hex.EncodeToString(chanReq.ExtraData())) + switch chanReq.ChannelType() { + case "session": + if ch, reqs, err := chanReq.Accept(); err != nil { + fmt.Fprintf(sshServerDebugStream, "fail to accept channel request: %v\n", err) + chanReq.Reject(ssh.ResourceShortage, "channel accept failure") + } else { + chsvr := &sshSessionChannelServer{ + sshChannelServer: &sshChannelServer{svr, chanReq, ch, reqs}, + env: append([]string{}, os.Environ()...), + } + chsvr.handle() + } + default: + chanReq.Reject(ssh.UnknownChannelType, "channel type is not a session") + } +} + +func (chsvr *sshSessionChannelServer) handle() { + // should maybe do something here... + go chsvr.handleReqs() +} + +func (chsvr *sshSessionChannelServer) handleReqs() { + for req := range chsvr.newReqs { + chsvr.handleReq(req) + } + fmt.Fprintf(sshServerDebugStream, "ssh server session channel complete\n") +} + +func (chsvr *sshSessionChannelServer) handleReq(req *ssh.Request) { + switch req.Type { + case "env": + chsvr.handleEnv(req) + case "subsystem": + chsvr.handleSubsystem(req) + default: + rejectRequest(req) + } +} + +func rejectRequest(req *ssh.Request) error { + fmt.Fprintf(sshServerDebugStream, "ssh rejecting request, type: %s\n", req.Type) + err := req.Reply(false, []byte{}) + if err != nil { + fmt.Fprintf(sshServerDebugStream, "ssh request reply had error: %v\n", err) + } + return err +} + +func rejectRequestUnmarshalError(req *ssh.Request, s interface{}, err error) error { + fmt.Fprintf(sshServerDebugStream, "ssh request unmarshaling error, type '%T': %v\n", s, err) + rejectRequest(req) + return err +} + +// env request form: +type sshEnvRequest struct { + Envvar string + Value string +} + +func (chsvr *sshSessionChannelServer) handleEnv(req *ssh.Request) error { + envReq := &sshEnvRequest{} + if err := ssh.Unmarshal(req.Payload, envReq); err != nil { + return 
rejectRequestUnmarshalError(req, envReq, err) + } + req.Reply(true, nil) + + found := false + for i, envstr := range chsvr.env { + if strings.HasPrefix(envstr, envReq.Envvar+"=") { + found = true + chsvr.env[i] = envReq.Envvar + "=" + envReq.Value + } + } + if !found { + chsvr.env = append(chsvr.env, envReq.Envvar+"="+envReq.Value) + } + + return nil +} + +// Payload: int: command size, string: command +type sshSubsystemRequest struct { + Name string +} + +type sshSubsystemExitStatus struct { + Status uint32 +} + +func (chsvr *sshSessionChannelServer) handleSubsystem(req *ssh.Request) error { + defer func() { + err1 := chsvr.ch.CloseWrite() + err2 := chsvr.ch.Close() + fmt.Fprintf(sshServerDebugStream, "ssh server subsystem request complete, err: %v %v\n", err1, err2) + }() + + subsystemReq := &sshSubsystemRequest{} + if err := ssh.Unmarshal(req.Payload, subsystemReq); err != nil { + return rejectRequestUnmarshalError(req, subsystemReq, err) + } + + // reply to the ssh client + + // no idea if this is actually correct spec-wise. + // just enough for an sftp server to start. 
+ if subsystemReq.Name != "sftp" { + return req.Reply(false, nil) + } + + req.Reply(true, nil) + + if !chsvr.svr.useSubsystem { + // use the openssh sftp server backend; this is to test the ssh code, not the sftp code, + // or is used for comparison between our sftp subsystem and the openssh sftp subsystem + cmd := exec.Command(*testSftp, "-e", "-l", "DEBUG") // log to stderr + cmd.Stdin = chsvr.ch + cmd.Stdout = chsvr.ch + cmd.Stderr = sftpServerDebugStream + if err := cmd.Start(); err != nil { + return err + } + return cmd.Wait() + } + + sftpServer, err := NewServer( + chsvr.ch, + WithDebug(sftpServerDebugStream), + ) + if err != nil { + return err + } + + // wait for the session to close + runErr := sftpServer.Serve() + exitStatus := uint32(1) + if runErr == nil { + exitStatus = uint32(0) + } + + _, exitStatusErr := chsvr.ch.SendRequest("exit-status", false, ssh.Marshal(sshSubsystemExitStatus{exitStatus})) + return exitStatusErr +} + +// starts an ssh server to test. returns: host string and port +func testServer(t *testing.T, useSubsystem bool, readonly bool) (func(), string, int) { + t.Helper() + + if !*testIntegration { + t.Skip("skipping integration test") + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + host, portStr, err := net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + t.Fatal(err) + } + + shutdown := make(chan struct{}) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + select { + case <-shutdown: + default: + t.Error("ssh server socket closed:", err) + } + return + } + + go func() { + defer conn.Close() + + sshSvr, err := sshServerFromConn(conn, useSubsystem, basicServerConfig()) + if err != nil { + t.Error(err) + return + } + + _ = sshSvr.Wait() + }() + } + }() + + return func() { close(shutdown); listener.Close() }, host, port +} + +func makeDummyKey() (string, error) { + priv, err := 
ecdsa.GenerateKey(elliptic.P256(), crand.Reader) + if err != nil { + return "", fmt.Errorf("cannot generate key: %w", err) + } + der, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return "", fmt.Errorf("cannot marshal key: %w", err) + } + block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: der} + f, err := ioutil.TempFile("", "sftp-test-key-") + if err != nil { + return "", fmt.Errorf("cannot create temp file: %w", err) + } + defer func() { + if f != nil { + _ = f.Close() + _ = os.Remove(f.Name()) + } + }() + if err := pem.Encode(f, block); err != nil { + return "", fmt.Errorf("cannot write key: %w", err) + } + if err := f.Close(); err != nil { + return "", fmt.Errorf("error closing key file: %w", err) + } + path := f.Name() + f = nil + return path, nil +} + +type execError struct { + path string + stderr string + err error +} + +func (e *execError) Error() string { + return fmt.Sprintf("%s: %v: %s", e.path, e.err, e.stderr) +} + +func (e *execError) Unwrap() error { + return e.err +} + +func (e *execError) Cause() error { + return e.err +} + +func runSftpClient(t *testing.T, script string, path string, host string, port int) (string, error) { + // if sftp client binary is unavailable, skip test + if _, err := os.Stat(*testSftpClientBin); err != nil { + t.Skip("sftp client binary unavailable") + } + + // make a dummy key so we don't rely on ssh-agent + dummyKey, err := makeDummyKey() + if err != nil { + return "", err + } + defer os.Remove(dummyKey) + + cmd := exec.Command( + *testSftpClientBin, + // "-vvvv", + "-b", "-", + "-o", "StrictHostKeyChecking=no", + "-o", "LogLevel=ERROR", + "-o", "UserKnownHostsFile /dev/null", + // do not trigger ssh-agent prompting + "-o", "IdentityFile="+dummyKey, + "-o", "IdentitiesOnly=yes", + "-P", fmt.Sprintf("%d", port), fmt.Sprintf("%s:%s", host, path), + ) + + cmd.Stdin = strings.NewReader(script) + + stdout := new(bytes.Buffer) + cmd.Stdout = stdout + + stderr := new(bytes.Buffer) + cmd.Stderr = stderr + + if err := 
cmd.Start(); err != nil { + return "", err + } + + if err := cmd.Wait(); err != nil { + return stdout.String(), &execError{ + path: cmd.Path, + stderr: stderr.String(), + err: err, + } + } + + return stdout.String(), nil +} + +// assert.Eventually seems to have a data race on macOS with go 1.14 so replace it with this simpler function +func waitForCondition(t *testing.T, condition func() bool) { + start := time.Now() + tick := 10 * time.Millisecond + waitFor := 100 * time.Millisecond + for !condition() { + time.Sleep(tick) + if time.Since(start) > waitFor { + break + } + } + assert.True(t, condition()) +} + +func checkAllocatorBeforeServerClose(t *testing.T, alloc *allocator) { + if alloc != nil { + // before closing the server we are, generally, waiting for new packets in recvPacket and we have a page allocated. + // Sometime the sendPacket returns some milliseconds after the client receives the response, and so we have 2 + // allocated pages here, so wait some milliseconds. To avoid crashes we must be sure to not release the pages + // too soon. + waitForCondition(t, func() bool { return alloc.countUsedPages() <= 1 }) + } +} + +func checkAllocatorAfterServerClose(t *testing.T, alloc *allocator) { + if alloc != nil { + // wait for the server cleanup + waitForCondition(t, func() bool { return alloc.countUsedPages() == 0 }) + waitForCondition(t, func() bool { return alloc.countAvailablePages() == 0 }) + } +} + +func TestServerCompareSubsystems(t *testing.T) { + if runtime.GOOS == "windows" { + // TODO (puellanivis): not sure how to fix this, the OpenSSH SFTP implementation closes immediately. 
+ t.Skip() + } + + shutdownGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdownGo() + + shutdownOp, hostOp, portOp := testServer(t, OpenSSHSFTP, READONLY) + defer shutdownOp() + + script := ` +ls / +ls -l / +ls /dev/ +ls -l /dev/ +ls -l /etc/ +ls -l /bin/ +ls -l /usr/bin/ +` + outputGo, err := runSftpClient(t, script, "/", hostGo, portGo) + if err != nil { + t.Fatal(err) + } + + outputOp, err := runSftpClient(t, script, "/", hostOp, portOp) + if err != nil { + t.Fatal(err) + } + + newlineRegex := regexp.MustCompile(`\r*\n`) + spaceRegex := regexp.MustCompile(`\s+`) + outputGoLines := newlineRegex.Split(outputGo, -1) + outputOpLines := newlineRegex.Split(outputOp, -1) + + if len(outputGoLines) != len(outputOpLines) { + t.Fatalf("output line count differs, go = %d, openssh = %d", len(outputGoLines), len(outputOpLines)) + } + + for i, goLine := range outputGoLines { + opLine := outputOpLines[i] + bad := false + if goLine != opLine { + goWords := spaceRegex.Split(goLine, -1) + opWords := spaceRegex.Split(opLine, -1) + // some fields are allowed to be different.. + // during testing as processes are created/destroyed. 
+ for j, goWord := range goWords { + if j >= len(opWords) { + bad = true + break + } + opWord := opWords[j] + if goWord != opWord { + switch j { + case 1, 2, 3, 7: + // words[1] as the link count for directories like proc is unstable + // words[2] and [3] as these are users & groups + // words[7] as timestamps on dirs can vary for things like /tmp + case 8: + // words[8] can either have full path or just the filename + bad = !strings.HasSuffix(opWord, "/"+goWord) + default: + bad = true + } + } + } + } + + if bad { + t.Errorf("outputs differ\n go: %q\nopenssh: %q\n", goLine, opLine) + } + } +} + +var rng = rand.New(rand.NewSource(time.Now().Unix())) + +func randData(length int) []byte { + data := make([]byte, length) + for i := 0; i < length; i++ { + data[i] = byte(rng.Uint32()) + } + return data +} + +func randName() string { + return "sftp." + hex.EncodeToString(randData(16)) +} + +func TestServerMkdirRmdir(t *testing.T) { + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + tmpDir := "/tmp/" + randName() + defer os.RemoveAll(tmpDir) + + // mkdir remote + if _, err := runSftpClient(t, "mkdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + // directory should now exist + if _, err := os.Stat(tmpDir); err != nil { + t.Fatal(err) + } + + // now remove the directory + if _, err := runSftpClient(t, "rmdir "+tmpDir, "/", hostGo, portGo); err != nil { + t.Fatal(err) + } + + if _, err := os.Stat(tmpDir); err == nil { + t.Fatal("should have error after deleting the directory") + } +} + +func TestServerLink(t *testing.T) { + skipIfWindows(t) // No hard links on windows. 
+ shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + tmpFileLocalData := randData(999) + + linkdest := "/tmp/" + randName() + defer os.RemoveAll(linkdest) + if err := ioutil.WriteFile(linkdest, tmpFileLocalData, 0644); err != nil { + t.Fatal(err) + } + + link := "/tmp/" + randName() + defer os.RemoveAll(link) + + // now create a hard link within the new directory + if output, err := runSftpClient(t, fmt.Sprintf("ln %s %s", linkdest, link), "/", hostGo, portGo); err != nil { + t.Fatalf("failed: %v %v", err, string(output)) + } + + // file should now exist and be the same size as linkdest + if stat, err := os.Lstat(link); err != nil { + t.Fatal(err) + } else if int(stat.Size()) != len(tmpFileLocalData) { + t.Fatalf("wrong size: %v", len(tmpFileLocalData)) + } +} + +func TestServerSymlink(t *testing.T) { + skipIfWindows(t) // No symlinks on windows. + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + link := "/tmp/" + randName() + defer os.RemoveAll(link) + + // now create a symbolic link within the new directory + if output, err := runSftpClient(t, "symlink /bin/sh "+link, "/", hostGo, portGo); err != nil { + t.Fatalf("failed: %v %v", err, string(output)) + } + + // symlink should now exist + if stat, err := os.Lstat(link); err != nil { + t.Fatal(err) + } else if (stat.Mode() & os.ModeSymlink) != os.ModeSymlink { + t.Fatalf("is not a symlink: %v", stat.Mode()) + } +} + +func TestServerPut(t *testing.T) { + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + tmpFileLocal := "/tmp/" + randName() + tmpFileRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpFileLocal) + defer os.RemoveAll(tmpFileRemote) + + t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote) + + // create a file with random contents. 
This will be the local file pushed to the server + tmpFileLocalData := randData(10 * 1024 * 1024) + if err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644); err != nil { + t.Fatal(err) + } + + // sftp the file to the server + if output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // tmpFile2 should now exist, with the same contents + if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil { + t.Fatal(err) + } else if string(tmpFileLocalData) != string(tmpFileRemoteData) { + t.Fatal("contents of file incorrect after put") + } +} + +func TestServerResume(t *testing.T) { + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + tmpFileLocal := "/tmp/" + randName() + tmpFileRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpFileLocal) + defer os.RemoveAll(tmpFileRemote) + + t.Logf("put: local %v remote %v", tmpFileLocal, tmpFileRemote) + + // create a local file with random contents to be pushed to the server + tmpFileLocalData := randData(2 * 1024 * 1024) + // only write half the data to simulate a split upload + half := 1024 * 1024 + err := ioutil.WriteFile(tmpFileLocal, tmpFileLocalData[:half], 0644) + if err != nil { + t.Fatal(err) + } + + // sftp the first half of the file to the server + output, err := runSftpClient(t, "put "+tmpFileLocal+" "+tmpFileRemote, + "/", hostGo, portGo) + if err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // write the full file out + err = ioutil.WriteFile(tmpFileLocal, tmpFileLocalData, 0644) + if err != nil { + t.Fatal(err) + } + // re-sftp the full file with the append flag set + output, err = runSftpClient(t, "put -a "+tmpFileLocal+" "+tmpFileRemote, + "/", hostGo, portGo) + if err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // tmpFileRemote should now exist, with the 
same contents + if tmpFileRemoteData, err := ioutil.ReadFile(tmpFileRemote); err != nil { + t.Fatal(err) + } else if string(tmpFileLocalData) != string(tmpFileRemoteData) { + t.Fatal("contents of file incorrect after put") + } +} + +func TestServerGet(t *testing.T) { + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + tmpFileLocal := "/tmp/" + randName() + tmpFileRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpFileLocal) + defer os.RemoveAll(tmpFileRemote) + + t.Logf("get: local %v remote %v", tmpFileLocal, tmpFileRemote) + + // create a file with random contents. This will be the remote file pulled from the server + tmpFileRemoteData := randData(10 * 1024 * 1024) + if err := ioutil.WriteFile(tmpFileRemote, tmpFileRemoteData, 0644); err != nil { + t.Fatal(err) + } + + // sftp the file to the server + if output, err := runSftpClient(t, "get "+tmpFileRemote+" "+tmpFileLocal, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + // tmpFile2 should now exist, with the same contents + if tmpFileLocalData, err := ioutil.ReadFile(tmpFileLocal); err != nil { + t.Fatal(err) + } else if string(tmpFileLocalData) != string(tmpFileRemoteData) { + t.Fatal("contents of file incorrect after put") + } +} + +func compareDirectoriesRecursive(t *testing.T, aroot, broot string) { + walker := fs.Walk(aroot) + for walker.Step() { + if err := walker.Err(); err != nil { + t.Fatal(err) + } + // find paths + aPath := walker.Path() + aRel, err := filepath.Rel(aroot, aPath) + if err != nil { + t.Fatalf("could not find relative path for %v: %v", aPath, err) + } + bPath := filepath.Join(broot, aRel) + + if aRel == "." { + continue + } + + //t.Logf("comparing: %v a: %v b %v", aRel, aPath, bPath) + + // if a is a link, the sftp recursive copy won't have copied it. 
ignore + aLink, err := os.Lstat(aPath) + if err != nil { + t.Fatalf("could not lstat %v: %v", aPath, err) + } + if aLink.Mode()&os.ModeSymlink != 0 { + continue + } + + // stat the files + aFile, err := os.Stat(aPath) + if err != nil { + t.Fatalf("could not stat %v: %v", aPath, err) + } + bFile, err := os.Stat(bPath) + if err != nil { + t.Fatalf("could not stat %v: %v", bPath, err) + } + + // compare stats, with some leniency for the timestamp + if aFile.Mode() != bFile.Mode() { + t.Fatalf("modes different for %v: %v vs %v", aRel, aFile.Mode(), bFile.Mode()) + } + if !aFile.IsDir() { + if aFile.Size() != bFile.Size() { + t.Fatalf("sizes different for %v: %v vs %v", aRel, aFile.Size(), bFile.Size()) + } + } + timeDiff := aFile.ModTime().Sub(bFile.ModTime()) + if timeDiff > time.Second || timeDiff < -time.Second { + t.Fatalf("mtimes different for %v: %v vs %v", aRel, aFile.ModTime(), bFile.ModTime()) + } + + // compare contents + if !aFile.IsDir() { + if aContents, err := ioutil.ReadFile(aPath); err != nil { + t.Fatal(err) + } else if bContents, err := ioutil.ReadFile(bPath); err != nil { + t.Fatal(err) + } else if string(aContents) != string(bContents) { + t.Fatalf("contents different for %v", aRel) + } + } + } +} + +func TestServerPutRecursive(t *testing.T) { + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + dirLocal, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + tmpDirRemote := "/tmp/" + randName() + defer os.RemoveAll(tmpDirRemote) + + t.Logf("put recursive: local %v remote %v", dirLocal, tmpDirRemote) + + // On windows, the client copies the contents of the directory, not the directory itself. 
+ winFix := "" + if runtime.GOOS == "windows" { + winFix = "/" + filepath.Base(dirLocal) + } //*/ + + // push this directory (source code etc) recursively to the server + if output, err := runSftpClient(t, "mkdir "+tmpDirRemote+"\r\nput -R -p "+dirLocal+" "+tmpDirRemote+winFix, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + compareDirectoriesRecursive(t, dirLocal, filepath.Join(tmpDirRemote, filepath.Base(dirLocal))) +} + +func TestServerGetRecursive(t *testing.T) { + shutdown, hostGo, portGo := testServer(t, GolangSFTP, READONLY) + defer shutdown() + + dirRemote, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + tmpDirLocal := "/tmp/" + randName() + defer os.RemoveAll(tmpDirLocal) + + t.Logf("get recursive: local %v remote %v", tmpDirLocal, dirRemote) + + // On windows, the client copies the contents of the directory, not the directory itself. + winFix := "" + if runtime.GOOS == "windows" { + winFix = "/" + filepath.Base(dirRemote) + } + + // pull this directory (source code etc) recursively from the server + if output, err := runSftpClient(t, "lmkdir "+tmpDirLocal+"\r\nget -R -p "+dirRemote+" "+tmpDirLocal+winFix, "/", hostGo, portGo); err != nil { + t.Fatalf("runSftpClient failed: %v, output\n%v\n", err, output) + } + + compareDirectoriesRecursive(t, dirRemote, filepath.Join(tmpDirLocal, filepath.Base(dirRemote))) +} diff --git a/server_nowindows_test.go b/server_nowindows_test.go new file mode 100644 index 00000000..d38dbba3 --- /dev/null +++ b/server_nowindows_test.go @@ -0,0 +1,87 @@ +//go:build !windows +// +build !windows + +package sftp + +import ( + "testing" +) + +func TestServer_toLocalPath(t *testing.T) { + tests := []struct { + name string + withWorkDir string + p string + want string + }{ + { + name: "empty path with no workdir", + p: "", + want: "", + }, + { + name: "relative path with no workdir", + p: "file", + want: "file", + }, + { + name: "absolute path with no workdir", + p: 
"/file", + want: "/file", + }, + { + name: "workdir and empty path", + withWorkDir: "/home/user", + p: "", + want: "/home/user", + }, + { + name: "workdir and relative path", + withWorkDir: "/home/user", + p: "file", + want: "/home/user/file", + }, + { + name: "workdir and relative path with .", + withWorkDir: "/home/user", + p: ".", + want: "/home/user", + }, + { + name: "workdir and relative path with . and file", + withWorkDir: "/home/user", + p: "./file", + want: "/home/user/file", + }, + { + name: "workdir and absolute path", + withWorkDir: "/home/user", + p: "/file", + want: "/file", + }, + { + name: "workdir and non-unixy path prefixes workdir", + withWorkDir: "/home/user", + p: "C:\\file", + // This may look like a bug but it is the result of passing + // invalid input (a non-unixy path) to the server. + want: "/home/user/C:\\file", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We don't need to initialize the Server further to test + // toLocalPath behavior. + s := &Server{} + if tt.withWorkDir != "" { + if err := WithServerWorkingDirectory(tt.withWorkDir)(s); err != nil { + t.Fatal(err) + } + } + + if got := s.toLocalPath(tt.p); got != tt.want { + t.Errorf("Server.toLocalPath() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/server_plan9.go b/server_plan9.go new file mode 100644 index 00000000..4e8ed067 --- /dev/null +++ b/server_plan9.go @@ -0,0 +1,27 @@ +package sftp + +import ( + "path" + "path/filepath" +) + +func (s *Server) toLocalPath(p string) string { + if s.workDir != "" && !path.IsAbs(p) { + p = path.Join(s.workDir, p) + } + + lp := filepath.FromSlash(p) + + if path.IsAbs(p) { + tmp := lp[1:] + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes is absolute, + // then we have a filepath encoded with a prefix '/'. + // e.g. 
"/#s/boot" to "#s/boot" + return tmp + } + } + + return lp +} diff --git a/server_posix.go b/server_posix.go new file mode 100644 index 00000000..c07d70a0 --- /dev/null +++ b/server_posix.go @@ -0,0 +1,21 @@ +//go:build !windows +// +build !windows + +package sftp + +import ( + "io/fs" + "os" +) + +func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) { + return os.OpenFile(path, flag, mode) +} + +func (s *Server) lstat(name string) (os.FileInfo, error) { + return os.Lstat(name) +} + +func (s *Server) stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} diff --git a/server_standalone/main.go b/server_standalone/main.go new file mode 100644 index 00000000..0b8e102a --- /dev/null +++ b/server_standalone/main.go @@ -0,0 +1,52 @@ +package main + +// small wrapper around sftp server that allows it to be used as a separate process subsystem call by the ssh server. +// in practice this will statically link; however this allows unit testing from the sftp client. 
+ +import ( + "flag" + "fmt" + "io" + "io/ioutil" + "os" + + "github.com/pkg/sftp" +) + +func main() { + var ( + readOnly bool + debugStderr bool + debugLevel string + options []sftp.ServerOption + ) + + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.StringVar(&debugLevel, "l", "none", "debug level (ignored)") + flag.Parse() + + debugStream := ioutil.Discard + if debugStderr { + debugStream = os.Stderr + } + options = append(options, sftp.WithDebug(debugStream)) + + if readOnly { + options = append(options, sftp.ReadOnly()) + } + + svr, _ := sftp.NewServer( + struct { + io.Reader + io.WriteCloser + }{os.Stdin, + os.Stdout, + }, + options..., + ) + if err := svr.Serve(); err != nil { + fmt.Fprintf(debugStream, "sftp server completed with error: %v", err) + os.Exit(1) + } +} diff --git a/server_statvfs_darwin.go b/server_statvfs_darwin.go new file mode 100644 index 00000000..8c01dac5 --- /dev/null +++ b/server_statvfs_darwin.go @@ -0,0 +1,21 @@ +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Bsize), // fragment size is a linux thing; use block size here + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Fsid: uint64(uint64(stat.Fsid.Val[1])<<32 | uint64(stat.Fsid.Val[0])), // endianness? + Flag: uint64(stat.Flags), // assuming POSIX? 
+ Namemax: 1024, // man 2 statfs shows: #define MAXPATHLEN 1024 + }, nil +} diff --git a/server_statvfs_impl.go b/server_statvfs_impl.go new file mode 100644 index 00000000..a5470798 --- /dev/null +++ b/server_statvfs_impl.go @@ -0,0 +1,30 @@ +//go:build darwin || linux +// +build darwin linux + +// fill in statvfs structure with OS specific values +// Statfs_t is different per-kernel, and only exists on some unixes (not Solaris for instance) + +package sftp + +import ( + "syscall" +) + +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + retPkt, err := getStatVFSForPath(p.Path) + if err != nil { + return statusFromError(p.ID, err) + } + retPkt.ID = p.ID + + return retPkt +} + +func getStatVFSForPath(name string) (*StatVFS, error) { + var stat syscall.Statfs_t + if err := syscall.Statfs(name, &stat); err != nil { + return nil, err + } + + return statvfsFromStatfst(&stat) +} diff --git a/server_statvfs_linux.go b/server_statvfs_linux.go new file mode 100644 index 00000000..615c4157 --- /dev/null +++ b/server_statvfs_linux.go @@ -0,0 +1,23 @@ +//go:build linux +// +build linux + +package sftp + +import ( + "syscall" +) + +func statvfsFromStatfst(stat *syscall.Statfs_t) (*StatVFS, error) { + return &StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Frsize), + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Flag: uint64(stat.Flags), // assuming POSIX? 
+ Namemax: uint64(stat.Namelen), + }, nil +} diff --git a/server_statvfs_plan9.go b/server_statvfs_plan9.go new file mode 100644 index 00000000..e71a27d3 --- /dev/null +++ b/server_statvfs_plan9.go @@ -0,0 +1,13 @@ +package sftp + +import ( + "syscall" +) + +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + return statusFromError(p.ID, syscall.EPLAN9) +} + +func getStatVFSForPath(name string) (*StatVFS, error) { + return nil, syscall.EPLAN9 +} diff --git a/server_statvfs_stubs.go b/server_statvfs_stubs.go new file mode 100644 index 00000000..dd4705bb --- /dev/null +++ b/server_statvfs_stubs.go @@ -0,0 +1,16 @@ +//go:build !darwin && !linux && !plan9 +// +build !darwin,!linux,!plan9 + +package sftp + +import ( + "syscall" +) + +func (p *sshFxpExtendedPacketStatVFS) respond(svr *Server) responsePacket { + return statusFromError(p.ID, syscall.ENOTSUP) +} + +func getStatVFSForPath(name string) (*StatVFS, error) { + return nil, syscall.ENOTSUP +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 00000000..5bac81eb --- /dev/null +++ b/server_test.go @@ -0,0 +1,335 @@ +package sftp + +import ( + "bytes" + "context" + "errors" + "io" + "os" + "path" + "runtime" + "sync" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func clientServerPair(t *testing.T) (*Client, *Server) { + cr, sw := io.Pipe() + sr, cw := io.Pipe() + var options []ServerOption + if *testAllocator { + options = append(options, WithAllocator()) + } + server, err := NewServer(struct { + io.Reader + io.WriteCloser + }{sr, sw}, options...) 
+ if err != nil { + t.Fatal(err) + } + go server.Serve() + client, err := NewClientPipe(cr, cw) + if err != nil { + t.Fatalf("%+v\n", err) + } + return client, server +} + +type sshFxpTestBadExtendedPacket struct { + ID uint32 + Extension string + Data string +} + +func (p sshFxpTestBadExtendedPacket) id() uint32 { return p.ID } + +func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) { + l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) + 4 + len(p.Extension) + + 4 + len(p.Data) + + b := make([]byte, 4, l) + b = append(b, sshFxpExtended) + b = marshalUint32(b, p.ID) + b = marshalString(b, p.Extension) + b = marshalString(b, p.Data) + + return b, nil +} + +func checkServerAllocator(t *testing.T, server *Server) { + if server.pktMgr.alloc == nil { + return + } + checkAllocatorBeforeServerClose(t, server.pktMgr.alloc) + server.Close() + checkAllocatorAfterServerClose(t, server.pktMgr.alloc) +} + +// test that errors are sent back when we request an invalid extended packet operation +// this validates the following rfc draft is followed https://tools.ietf.org/html/draft-ietf-secsh-filexfer-extensions-00 +func TestInvalidExtendedPacket(t *testing.T) { + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + badPacket := sshFxpTestBadExtendedPacket{client.nextID(), "thisDoesn'tExist", "foobar"} + typ, data, err := client.clientConn.sendPacket(context.Background(), nil, badPacket) + if err != nil { + t.Fatalf("unexpected error from sendPacket: %s", err) + } + if typ != sshFxpStatus { + t.Fatalf("received non-FPX_STATUS packet: %v", typ) + } + + err = unmarshalStatus(badPacket.id(), data) + statusErr, ok := err.(*StatusError) + if !ok { + t.Fatal("failed to convert error from unmarshalStatus to *StatusError") + } + if statusErr.Code != sshFxOPUnsupported { + t.Errorf("statusErr.Code => %d, wanted %d", statusErr.Code, sshFxOPUnsupported) + } + checkServerAllocator(t, server) +} + +// test that server handles 
 concurrent requests correctly +func TestConcurrentRequests(t *testing.T) { + skipIfWindows(t) + var filename string + switch runtime.GOOS { + case "plan9": + filename = "/lib/ndb/local" + case "zos": + filename = "/etc/.shrc" + default: + filename = "/etc/passwd" + } + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + concurrency := 2 + var wg sync.WaitGroup + wg.Add(concurrency) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + for j := 0; j < 1024; j++ { + f, err := client.Open(filename) + if err != nil { + t.Errorf("failed to open file: %v", err) + continue + } + if err := f.Close(); err != nil { + t.Errorf("failed t close file: %v", err) + } + } + }() + } + wg.Wait() + checkServerAllocator(t, server) +} + +// Test error conversion +func TestStatusFromError(t *testing.T) { + type test struct { + err error + pkt *sshFxpStatusPacket + } + tpkt := func(id, code uint32) *sshFxpStatusPacket { + return &sshFxpStatusPacket{ + ID: id, + StatusError: StatusError{Code: code}, + } + } + testCases := []test{ + {syscall.ENOENT, tpkt(1, sshFxNoSuchFile)}, + {&os.PathError{Err: syscall.ENOENT}, + tpkt(2, sshFxNoSuchFile)}, + {&os.PathError{Err: errors.New("foo")}, tpkt(3, sshFxFailure)}, + {ErrSSHFxEOF, tpkt(4, sshFxEOF)}, + {ErrSSHFxOpUnsupported, tpkt(5, sshFxOPUnsupported)}, + {io.EOF, tpkt(6, sshFxEOF)}, + {os.ErrNotExist, tpkt(7, sshFxNoSuchFile)}, + } + for _, tc := range testCases { + tc.pkt.StatusError.msg = tc.err.Error() + assert.Equal(t, tc.pkt, statusFromError(tc.pkt.ID, tc.err)) + } +} + +// This was written to test a race b/w open immediately followed by a stat. +// Previous to this the Open would trigger the use of a worker pool, then the +// stat packet would come in and hit the pool and return faster than the open +// (returning a file-not-found error). +// The below by itself wouldn't trigger the race however, I needed to add a +// small sleep in the openpacket code to trigger the issue. 
I wanted to add a +// way to inject that in the code but right now there is no good place for it. +// I'm thinking after I convert the server into a request-server backend I +// might be able to do something with the runWorker method passed into the +// packet manager. But with the 2 implementations of the server it just doesn't +// fit well right now. +func TestOpenStatRace(t *testing.T) { + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + // openpacket finishes too fast to trigger race in tests + // need to add a small sleep on server to openpackets somehow + tmppath := path.Join(os.TempDir(), "stat_race") + pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) + ch := make(chan result, 3) + id1 := client.nextID() + id2 := client.nextID() + client.dispatchRequest(ch, &sshFxpOpenPacket{ + ID: id1, + Path: tmppath, + Pflags: pflags, + }) + client.dispatchRequest(ch, &sshFxpLstatPacket{ + ID: id2, + Path: tmppath, + }) + testreply := func(id uint32) { + r := <-ch + require.NoError(t, r.err) + switch r.typ { + case sshFxpAttrs, sshFxpHandle: // ignore + case sshFxpStatus: + err := normaliseError(unmarshalStatus(id, r.data)) + assert.NoError(t, err, "race hit, stat before open") + default: + t.Fatal("unexpected type:", r.typ) + } + } + testreply(id1) + testreply(id2) + os.Remove(tmppath) + checkServerAllocator(t, server) +} + +func TestOpenWithPermissions(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + skipIfWindows(t) + + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + tmppath := path.Join(os.TempDir(), "open_permissions") + defer os.Remove(tmppath) + + pflags := toPflags(os.O_RDWR | os.O_CREATE | os.O_TRUNC) + + id1 := client.nextID() + id2 := client.nextID() + + typ, data, err := client.sendPacket(ctx, nil, &sshFxpOpenPacket{ + ID: id1, + Path: tmppath, + Pflags: pflags, + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 
0o745, + }, + }) + if err != nil { + t.Fatal("unexpected error:", err) + } + switch typ { + case sshFxpHandle: + // do nothing, we can just leave the handle open. + case sshFxpStatus: + t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id1, data))) + default: + t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ)) + } + + stat, err := os.Stat(tmppath) + if err != nil { + t.Fatal("unexpected error:", err) + } + + if stat.Mode()&os.ModePerm != 0o745 { + t.Errorf("stat.Mode() = %v was expecting 0o745", stat.Mode()) + } + + // Existing files should not have their permissions changed. + typ, data, err = client.sendPacket(ctx, nil, &sshFxpOpenPacket{ + ID: id2, + Path: tmppath, + Pflags: pflags, + Flags: sshFileXferAttrPermissions, + Attrs: &FileStat{ + Mode: 0o755, + }, + }) + if err != nil { + t.Fatal("unexpected error:", err) + } + switch typ { + case sshFxpHandle: + // do nothing, we can just leave the handle open. + case sshFxpStatus: + t.Fatal("unexpected status:", normaliseError(unmarshalStatus(id2, data))) + default: + t.Fatal("unpexpected packet type:", unimplementedPacketErr(typ)) + } + + if stat.Mode()&os.ModePerm != 0o745 { + t.Errorf("stat.Mode() = %v, was expecting unchanged 0o745", stat.Mode()) + } + + checkServerAllocator(t, server) +} + +// Ensure that proper error codes are returned for non existent files, such +// that they are mapped back to a 'not exists' error on the client side. +func TestStatNonExistent(t *testing.T) { + client, server := clientServerPair(t) + defer client.Close() + defer server.Close() + + for _, file := range []string{"/doesnotexist", "/doesnotexist/a/b"} { + _, err := client.Stat(file) + if !os.IsNotExist(err) { + t.Errorf("expected 'does not exist' err for file %q. 
got: %v", file, err) + } + } +} + +func TestServerWithBrokenClient(t *testing.T) { + validInit := sp(&sshFxInitPacket{Version: 3}) + brokenOpen := sp(&sshFxpOpenPacket{Path: "foo"}) + brokenOpen = brokenOpen[:len(brokenOpen)-2] + + for _, clientInput := range [][]byte{ + // Packet length zero (never valid). This used to crash the server. + {0, 0, 0, 0}, + append(validInit, 0, 0, 0, 0), + + // Client hangs up mid-packet. + append(validInit, brokenOpen...), + } { + srv, err := NewServer(struct { + io.Reader + io.WriteCloser + }{ + bytes.NewReader(clientInput), + &sink{}, + }) + require.NoError(t, err) + + err = srv.Serve() + assert.Error(t, err) + srv.Close() + } +} diff --git a/server_unix.go b/server_unix.go new file mode 100644 index 00000000..495b397c --- /dev/null +++ b/server_unix.go @@ -0,0 +1,16 @@ +//go:build !windows && !plan9 +// +build !windows,!plan9 + +package sftp + +import ( + "path" +) + +func (s *Server) toLocalPath(p string) string { + if s.workDir != "" && !path.IsAbs(p) { + p = path.Join(s.workDir, p) + } + + return p +} diff --git a/server_windows.go b/server_windows.go new file mode 100644 index 00000000..e940dba1 --- /dev/null +++ b/server_windows.go @@ -0,0 +1,193 @@ +package sftp + +import ( + "fmt" + "io" + "io/fs" + "os" + "path" + "path/filepath" + "time" + + "golang.org/x/sys/windows" +) + +func (s *Server) toLocalPath(p string) string { + if s.workDir != "" && !path.IsAbs(p) { + p = path.Join(s.workDir, p) + } + + lp := filepath.FromSlash(p) + + if path.IsAbs(p) { // starts with '/' + if len(p) == 1 && s.winRoot { + return `\\.\` // for openfile + } + + tmp := lp + for len(tmp) > 0 && tmp[0] == '\\' { + tmp = tmp[1:] + } + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes is absolute, + // then we have a filepath encoded with a prefix '/'. + // e.g. 
"/C:/Windows" to "C:\\Windows" + return tmp + } + + tmp += "\\" + + if filepath.IsAbs(tmp) { + // If the FromSlash without any starting slashes but with extra end slash is absolute, + // then we have a filepath encoded with a prefix '/' and a dropped '/' at the end. + // e.g. "/C:" to "C:\\" + return tmp + } + + if s.winRoot { + // Make it so that "/Windows" is not found, and "/c:/Windows" has to be used + return `\\.\` + tmp + } + } + + return lp +} + +func bitsToDrives(bitmap uint32) []string { + var drive rune = 'a' + var drives []string + + for bitmap != 0 && drive <= 'z' { + if bitmap&1 == 1 { + drives = append(drives, string(drive)+":") + } + drive++ + bitmap >>= 1 + } + + return drives +} + +func getDrives() ([]string, error) { + mask, err := windows.GetLogicalDrives() + if err != nil { + return nil, fmt.Errorf("GetLogicalDrives: %w", err) + } + return bitsToDrives(mask), nil +} + +type driveInfo struct { + fs.FileInfo + name string +} + +func (i *driveInfo) Name() string { + return i.name // since the Name() returned from a os.Stat("C:\\") is "\\" +} + +type winRoot struct { + drives []string +} + +func newWinRoot() (*winRoot, error) { + drives, err := getDrives() + if err != nil { + return nil, err + } + return &winRoot{ + drives: drives, + }, nil +} + +func (f *winRoot) Readdir(n int) ([]os.FileInfo, error) { + drives := f.drives + if n > 0 && len(drives) > n { + drives = drives[:n] + } + f.drives = f.drives[len(drives):] + if len(drives) == 0 { + return nil, io.EOF + } + + var infos []os.FileInfo + for _, drive := range drives { + fi, err := os.Stat(drive + `\`) + if err != nil { + return nil, err + } + + di := &driveInfo{ + FileInfo: fi, + name: drive, + } + infos = append(infos, di) + } + + return infos, nil +} + +func (f *winRoot) Stat() (os.FileInfo, error) { + return rootFileInfo, nil +} +func (f *winRoot) ReadAt(b []byte, off int64) (int, error) { + return 0, os.ErrPermission +} +func (f *winRoot) WriteAt(b []byte, off int64) (int, error) { + 
return 0, os.ErrPermission +} +func (f *winRoot) Name() string { + return "/" +} +func (f *winRoot) Truncate(int64) error { + return os.ErrPermission +} +func (f *winRoot) Chmod(mode fs.FileMode) error { + return os.ErrPermission +} +func (f *winRoot) Chown(uid, gid int) error { + return os.ErrPermission +} +func (f *winRoot) Close() error { + f.drives = nil + return nil +} + +func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) { + if path == `\\.\` && s.winRoot { + return newWinRoot() + } + return os.OpenFile(path, flag, mode) +} + +type winRootFileInfo struct { + name string + modTime time.Time +} + +func (w *winRootFileInfo) Name() string { return w.name } +func (w *winRootFileInfo) Size() int64 { return 0 } +func (w *winRootFileInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } // read+execute for all +func (w *winRootFileInfo) ModTime() time.Time { return w.modTime } +func (w *winRootFileInfo) IsDir() bool { return true } +func (w *winRootFileInfo) Sys() interface{} { return nil } + +// Create a new root FileInfo +var rootFileInfo = &winRootFileInfo{ + name: "/", + modTime: time.Now(), +} + +func (s *Server) lstat(name string) (os.FileInfo, error) { + if name == `\\.\` && s.winRoot { + return rootFileInfo, nil + } + return os.Lstat(name) +} + +func (s *Server) stat(name string) (os.FileInfo, error) { + if name == `\\.\` && s.winRoot { + return rootFileInfo, nil + } + return os.Stat(name) +} diff --git a/server_windows_test.go b/server_windows_test.go new file mode 100644 index 00000000..ca9ed027 --- /dev/null +++ b/server_windows_test.go @@ -0,0 +1,84 @@ +package sftp + +import ( + "testing" +) + +func TestServer_toLocalPath(t *testing.T) { + tests := []struct { + name string + withWorkDir string + p string + want string + }{ + { + name: "empty path with no workdir", + p: "", + want: "", + }, + { + name: "relative path with no workdir", + p: "file", + want: "file", + }, + { + name: "absolute path with no workdir", + p: "/file", 
+ want: "\\file", + }, + { + name: "workdir and empty path", + withWorkDir: "C:\\Users\\User", + p: "", + want: "C:\\Users\\User", + }, + { + name: "workdir and relative path", + withWorkDir: "C:\\Users\\User", + p: "file", + want: "C:\\Users\\User\\file", + }, + { + name: "workdir and relative path with .", + withWorkDir: "C:\\Users\\User", + p: ".", + want: "C:\\Users\\User", + }, + { + name: "workdir and relative path with . and file", + withWorkDir: "C:\\Users\\User", + p: "./file", + want: "C:\\Users\\User\\file", + }, + { + name: "workdir and absolute path", + withWorkDir: "C:\\Users\\User", + p: "/C:/file", + want: "C:\\file", + }, + { + name: "workdir and non-unixy path prefixes workdir", + withWorkDir: "C:\\Users\\User", + p: "C:\\file", + // This may look like a bug but it is the result of passing + // invalid input (a non-unixy path) to the server. + want: "C:\\Users\\User\\C:\\file", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We don't need to initialize the Server further to test + // toLocalPath behavior. 
+ s := &Server{} + if tt.withWorkDir != "" { + if err := WithServerWorkingDirectory(tt.withWorkDir)(s); err != nil { + t.Fatal(err) + } + } + + if got := s.toLocalPath(tt.p); got != tt.want { + t.Errorf("Server.toLocalPath() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/sftp.go b/sftp.go index 934684c2..1e698bb2 100644 --- a/sftp.go +++ b/sftp.go @@ -7,113 +7,149 @@ import ( ) const ( - ssh_FXP_INIT = 1 - ssh_FXP_VERSION = 2 - ssh_FXP_OPEN = 3 - ssh_FXP_CLOSE = 4 - ssh_FXP_READ = 5 - ssh_FXP_WRITE = 6 - ssh_FXP_LSTAT = 7 - ssh_FXP_FSTAT = 8 - ssh_FXP_SETSTAT = 9 - ssh_FXP_FSETSTAT = 10 - ssh_FXP_OPENDIR = 11 - ssh_FXP_READDIR = 12 - ssh_FXP_REMOVE = 13 - ssh_FXP_MKDIR = 14 - ssh_FXP_RMDIR = 15 - ssh_FXP_REALPATH = 16 - ssh_FXP_STAT = 17 - ssh_FXP_RENAME = 18 - ssh_FXP_READLINK = 19 - ssh_FXP_SYMLINK = 20 - ssh_FXP_STATUS = 101 - ssh_FXP_HANDLE = 102 - ssh_FXP_DATA = 103 - ssh_FXP_NAME = 104 - ssh_FXP_ATTRS = 105 - ssh_FXP_EXTENDED = 200 - ssh_FXP_EXTENDED_REPLY = 201 + sshFxpInit = 1 + sshFxpVersion = 2 + sshFxpOpen = 3 + sshFxpClose = 4 + sshFxpRead = 5 + sshFxpWrite = 6 + sshFxpLstat = 7 + sshFxpFstat = 8 + sshFxpSetstat = 9 + sshFxpFsetstat = 10 + sshFxpOpendir = 11 + sshFxpReaddir = 12 + sshFxpRemove = 13 + sshFxpMkdir = 14 + sshFxpRmdir = 15 + sshFxpRealpath = 16 + sshFxpStat = 17 + sshFxpRename = 18 + sshFxpReadlink = 19 + sshFxpSymlink = 20 + sshFxpStatus = 101 + sshFxpHandle = 102 + sshFxpData = 103 + sshFxpName = 104 + sshFxpAttrs = 105 + sshFxpExtended = 200 + sshFxpExtendedReply = 201 ) const ( - ssh_FX_OK = 0 - ssh_FX_EOF = 1 - ssh_FX_NO_SUCH_FILE = 2 - ssh_FX_PERMISSION_DENIED = 3 - ssh_FX_FAILURE = 4 - ssh_FX_BAD_MESSAGE = 5 - ssh_FX_NO_CONNECTION = 6 - ssh_FX_CONNECTION_LOST = 7 - ssh_FX_OP_UNSUPPORTED = 8 + sshFxOk = 0 + sshFxEOF = 1 + sshFxNoSuchFile = 2 + sshFxPermissionDenied = 3 + sshFxFailure = 4 + sshFxBadMessage = 5 + sshFxNoConnection = 6 + sshFxConnectionLost = 7 + sshFxOPUnsupported = 8 + + // see draft-ietf-secsh-filexfer-13 
+ // https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1 + sshFxInvalidHandle = 9 + sshFxNoSuchPath = 10 + sshFxFileAlreadyExists = 11 + sshFxWriteProtect = 12 + sshFxNoMedia = 13 + sshFxNoSpaceOnFilesystem = 14 + sshFxQuotaExceeded = 15 + sshFxUnknownPrincipal = 16 + sshFxLockConflict = 17 + sshFxDirNotEmpty = 18 + sshFxNotADirectory = 19 + sshFxInvalidFilename = 20 + sshFxLinkLoop = 21 + sshFxCannotDelete = 22 + sshFxInvalidParameter = 23 + sshFxFileIsADirectory = 24 + sshFxByteRangeLockConflict = 25 + sshFxByteRangeLockRefused = 26 + sshFxDeletePending = 27 + sshFxFileCorrupt = 28 + sshFxOwnerInvalid = 29 + sshFxGroupInvalid = 30 + sshFxNoMatchingByteRangeLock = 31 ) const ( - ssh_FXF_READ = 0x00000001 - ssh_FXF_WRITE = 0x00000002 - ssh_FXF_APPEND = 0x00000004 - ssh_FXF_CREAT = 0x00000008 - ssh_FXF_TRUNC = 0x00000010 - ssh_FXF_EXCL = 0x00000020 + sshFxfRead = 0x00000001 + sshFxfWrite = 0x00000002 + sshFxfAppend = 0x00000004 + sshFxfCreat = 0x00000008 + sshFxfTrunc = 0x00000010 + sshFxfExcl = 0x00000020 +) + +var ( + // supportedSFTPExtensions defines the supported extensions + supportedSFTPExtensions = []sshExtensionPair{ + {"hardlink@openssh.com", "1"}, + {"posix-rename@openssh.com", "1"}, + {"statvfs@openssh.com", "2"}, + } + sftpExtensions = supportedSFTPExtensions ) type fxp uint8 func (f fxp) String() string { switch f { - case ssh_FXP_INIT: + case sshFxpInit: return "SSH_FXP_INIT" - case ssh_FXP_VERSION: + case sshFxpVersion: return "SSH_FXP_VERSION" - case ssh_FXP_OPEN: + case sshFxpOpen: return "SSH_FXP_OPEN" - case ssh_FXP_CLOSE: + case sshFxpClose: return "SSH_FXP_CLOSE" - case ssh_FXP_READ: + case sshFxpRead: return "SSH_FXP_READ" - case ssh_FXP_WRITE: + case sshFxpWrite: return "SSH_FXP_WRITE" - case ssh_FXP_LSTAT: + case sshFxpLstat: return "SSH_FXP_LSTAT" - case ssh_FXP_FSTAT: + case sshFxpFstat: return "SSH_FXP_FSTAT" - case ssh_FXP_SETSTAT: + case sshFxpSetstat: return "SSH_FXP_SETSTAT" - case ssh_FXP_FSETSTAT: + case 
sshFxpFsetstat: return "SSH_FXP_FSETSTAT" - case ssh_FXP_OPENDIR: + case sshFxpOpendir: return "SSH_FXP_OPENDIR" - case ssh_FXP_READDIR: + case sshFxpReaddir: return "SSH_FXP_READDIR" - case ssh_FXP_REMOVE: + case sshFxpRemove: return "SSH_FXP_REMOVE" - case ssh_FXP_MKDIR: + case sshFxpMkdir: return "SSH_FXP_MKDIR" - case ssh_FXP_RMDIR: + case sshFxpRmdir: return "SSH_FXP_RMDIR" - case ssh_FXP_REALPATH: + case sshFxpRealpath: return "SSH_FXP_REALPATH" - case ssh_FXP_STAT: + case sshFxpStat: return "SSH_FXP_STAT" - case ssh_FXP_RENAME: + case sshFxpRename: return "SSH_FXP_RENAME" - case ssh_FXP_READLINK: + case sshFxpReadlink: return "SSH_FXP_READLINK" - case ssh_FXP_SYMLINK: + case sshFxpSymlink: return "SSH_FXP_SYMLINK" - case ssh_FXP_STATUS: + case sshFxpStatus: return "SSH_FXP_STATUS" - case ssh_FXP_HANDLE: + case sshFxpHandle: return "SSH_FXP_HANDLE" - case ssh_FXP_DATA: + case sshFxpData: return "SSH_FXP_DATA" - case ssh_FXP_NAME: + case sshFxpName: return "SSH_FXP_NAME" - case ssh_FXP_ATTRS: + case sshFxpAttrs: return "SSH_FXP_ATTRS" - case ssh_FXP_EXTENDED: + case sshFxpExtended: return "SSH_FXP_EXTENDED" - case ssh_FXP_EXTENDED_REPLY: + case sshFxpExtendedReply: return "SSH_FXP_EXTENDED_REPLY" default: return "unknown" @@ -124,23 +160,23 @@ type fx uint8 func (f fx) String() string { switch f { - case ssh_FX_OK: + case sshFxOk: return "SSH_FX_OK" - case ssh_FX_EOF: + case sshFxEOF: return "SSH_FX_EOF" - case ssh_FX_NO_SUCH_FILE: + case sshFxNoSuchFile: return "SSH_FX_NO_SUCH_FILE" - case ssh_FX_PERMISSION_DENIED: + case sshFxPermissionDenied: return "SSH_FX_PERMISSION_DENIED" - case ssh_FX_FAILURE: + case sshFxFailure: return "SSH_FX_FAILURE" - case ssh_FX_BAD_MESSAGE: + case sshFxBadMessage: return "SSH_FX_BAD_MESSAGE" - case ssh_FX_NO_CONNECTION: + case sshFxNoConnection: return "SSH_FX_NO_CONNECTION" - case ssh_FX_CONNECTION_LOST: + case sshFxConnectionLost: return "SSH_FX_CONNECTION_LOST" - case ssh_FX_OP_UNSUPPORTED: + case sshFxOPUnsupported: return 
"SSH_FX_OP_UNSUPPORTED" default: return "unknown" @@ -148,29 +184,29 @@ func (f fx) String() string { } type unexpectedPacketErr struct { - want, got uint8 + want, got fxp } func (u *unexpectedPacketErr) Error() string { - return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", fxp(u.want), fxp(u.got)) + return fmt.Sprintf("sftp: unexpected packet: want %v, got %v", u.want, u.got) } -func unimplementedPacketErr(u uint8) error { - return fmt.Errorf("sftp: unimplemented packet type: got %v", fxp(u)) +func unimplementedPacketErr(u fxp) error { + return fmt.Errorf("sftp: unimplemented packet type: got %v", u) } -type unexpectedIdErr struct{ want, got uint32 } +type unexpectedIDErr struct{ want, got uint32 } -func (u *unexpectedIdErr) Error() string { - return fmt.Sprintf("sftp: unexpected id: want %v, got %v", u.want, u.got) +func (u *unexpectedIDErr) Error() string { + return fmt.Sprintf("sftp: unexpected id: want %d, got %d", u.want, u.got) } func unimplementedSeekWhence(whence int) error { - return fmt.Errorf("sftp: unimplemented seek whence %v", whence) + return fmt.Errorf("sftp: unimplemented seek whence %d", whence) } func unexpectedCount(want, got uint32) error { - return fmt.Errorf("sftp: unexpected count: want %v, got %v", want, got) + return fmt.Errorf("sftp: unexpected count: want %d, got %d", want, got) } type unexpectedVersionErr struct{ want, got uint32 } @@ -179,9 +215,44 @@ func (u *unexpectedVersionErr) Error() string { return fmt.Sprintf("sftp: unexpected server version: want %v, got %v", u.want, u.got) } +// A StatusError is returned when an SFTP operation fails, and provides +// additional information about the failure. 
type StatusError struct { Code uint32 msg, lang string } -func (s *StatusError) Error() string { return fmt.Sprintf("sftp: %q (%v)", s.msg, fx(s.Code)) } +func (s *StatusError) Error() string { + return fmt.Sprintf("sftp: %q (%v)", s.msg, fx(s.Code)) +} + +// FxCode returns the error code typed to match against the exported codes +func (s *StatusError) FxCode() fxerr { + return fxerr(s.Code) +} + +func getSupportedExtensionByName(extensionName string) (sshExtensionPair, error) { + for _, supportedExtension := range supportedSFTPExtensions { + if supportedExtension.Name == extensionName { + return supportedExtension, nil + } + } + return sshExtensionPair{}, fmt.Errorf("unsupported extension: %s", extensionName) +} + +// SetSFTPExtensions allows to customize the supported server extensions. +// See the variable supportedSFTPExtensions for supported extensions. +// This method accepts a slice of sshExtensionPair names for example 'hardlink@openssh.com'. +// If an invalid extension is given an error will be returned and nothing will be changed +func SetSFTPExtensions(extensions ...string) error { + tempExtensions := []sshExtensionPair{} + for _, extension := range extensions { + sftpExtension, err := getSupportedExtensionByName(extension) + if err != nil { + return err + } + tempExtensions = append(tempExtensions, sftpExtension) + } + sftpExtensions = tempExtensions + return nil +} diff --git a/sftp_test.go b/sftp_test.go new file mode 100644 index 00000000..18eed5e7 --- /dev/null +++ b/sftp_test.go @@ -0,0 +1,78 @@ +package sftp + +import ( + "errors" + "fmt" + "io" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrFxCode(t *testing.T) { + table := []struct { + err error + fx fxerr + }{ + {err: errors.New("random error"), fx: ErrSSHFxFailure}, + {err: EBADF, fx: ErrSSHFxFailure}, + {err: syscall.ENOENT, fx: ErrSSHFxNoSuchFile}, + {err: syscall.EPERM, fx: ErrSSHFxPermissionDenied}, + {err: io.EOF, fx: ErrSSHFxEOF}, + {err: 
fmt.Errorf("wrapped permission denied error: %w", ErrSSHFxPermissionDenied), fx: ErrSSHFxPermissionDenied}, + {err: fmt.Errorf("wrapped op unsupported error: %w", ErrSSHFxOpUnsupported), fx: ErrSSHFxOpUnsupported}, + } + for _, tt := range table { + statusErr := statusFromError(1, tt.err).StatusError + assert.Equal(t, statusErr.FxCode(), tt.fx) + } +} + +func TestSupportedExtensions(t *testing.T) { + for _, supportedExtension := range supportedSFTPExtensions { + _, err := getSupportedExtensionByName(supportedExtension.Name) + assert.NoError(t, err) + } + _, err := getSupportedExtensionByName("invalid@example.com") + assert.Error(t, err) +} + +func TestExtensions(t *testing.T) { + var supportedExtensions []string + for _, supportedExtension := range supportedSFTPExtensions { + supportedExtensions = append(supportedExtensions, supportedExtension.Name) + } + + testSFTPExtensions := []string{"hardlink@openssh.com"} + expectedSFTPExtensions := []sshExtensionPair{ + {"hardlink@openssh.com", "1"}, + } + err := SetSFTPExtensions(testSFTPExtensions...) + assert.NoError(t, err) + assert.Equal(t, expectedSFTPExtensions, sftpExtensions) + + invalidSFTPExtensions := []string{"invalid@example.com"} + err = SetSFTPExtensions(invalidSFTPExtensions...) + assert.Error(t, err) + assert.Equal(t, expectedSFTPExtensions, sftpExtensions) + + emptySFTPExtensions := []string{} + expectedSFTPExtensions = []sshExtensionPair{} + err = SetSFTPExtensions(emptySFTPExtensions...) + assert.NoError(t, err) + assert.Equal(t, expectedSFTPExtensions, sftpExtensions) + + // if we only have an invalid extension nothing will be modified. + invalidSFTPExtensions = []string{ + "hardlink@openssh.com", + "invalid@example.com", + } + err = SetSFTPExtensions(invalidSFTPExtensions...) + assert.Error(t, err) + assert.Equal(t, expectedSFTPExtensions, sftpExtensions) + + err = SetSFTPExtensions(supportedExtensions...) 
+ assert.NoError(t, err) + assert.Equal(t, supportedSFTPExtensions, sftpExtensions) +} diff --git a/stat.go b/stat.go new file mode 100644 index 00000000..2bb2c137 --- /dev/null +++ b/stat.go @@ -0,0 +1,94 @@ +package sftp + +import ( + "os" + + sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer" +) + +// isRegular returns true if the mode describes a regular file. +func isRegular(mode uint32) bool { + return sshfx.FileMode(mode)&sshfx.ModeType == sshfx.ModeRegular +} + +// toFileMode converts sftp filemode bits to the os.FileMode specification +func toFileMode(mode uint32) os.FileMode { + var fm = os.FileMode(mode & 0777) + + switch sshfx.FileMode(mode) & sshfx.ModeType { + case sshfx.ModeDevice: + fm |= os.ModeDevice + case sshfx.ModeCharDevice: + fm |= os.ModeDevice | os.ModeCharDevice + case sshfx.ModeDir: + fm |= os.ModeDir + case sshfx.ModeNamedPipe: + fm |= os.ModeNamedPipe + case sshfx.ModeSymlink: + fm |= os.ModeSymlink + case sshfx.ModeRegular: + // nothing to do + case sshfx.ModeSocket: + fm |= os.ModeSocket + } + + if sshfx.FileMode(mode)&sshfx.ModeSetUID != 0 { + fm |= os.ModeSetuid + } + if sshfx.FileMode(mode)&sshfx.ModeSetGID != 0 { + fm |= os.ModeSetgid + } + if sshfx.FileMode(mode)&sshfx.ModeSticky != 0 { + fm |= os.ModeSticky + } + + return fm +} + +// fromFileMode converts from the os.FileMode specification to sftp filemode bits +func fromFileMode(mode os.FileMode) uint32 { + ret := sshfx.FileMode(mode & os.ModePerm) + + switch mode & os.ModeType { + case os.ModeDevice | os.ModeCharDevice: + ret |= sshfx.ModeCharDevice + case os.ModeDevice: + ret |= sshfx.ModeDevice + case os.ModeDir: + ret |= sshfx.ModeDir + case os.ModeNamedPipe: + ret |= sshfx.ModeNamedPipe + case os.ModeSymlink: + ret |= sshfx.ModeSymlink + case 0: + ret |= sshfx.ModeRegular + case os.ModeSocket: + ret |= sshfx.ModeSocket + } + + if mode&os.ModeSetuid != 0 { + ret |= sshfx.ModeSetUID + } + if mode&os.ModeSetgid != 0 { + ret |= sshfx.ModeSetGID + } + if 
mode&os.ModeSticky != 0 { + ret |= sshfx.ModeSticky + } + + return uint32(ret) +} + +const ( + s_ISUID = uint32(sshfx.ModeSetUID) + s_ISGID = uint32(sshfx.ModeSetGID) + s_ISVTX = uint32(sshfx.ModeSticky) +) + +// S_IFMT is a legacy export, and was brought in to support GOOS environments whose sysconfig.S_IFMT may be different from the value used internally by SFTP standards. +// There should be no reason why you need to import it, or use it, but unexporting it could cause code to break in a way that cannot be readily fixed. +// As such, we continue to export this value as the value used in the SFTP standard. +// +// Deprecated: Remove use of this value, and avoid any future use as well. +// There is no alternative provided, you should never need to access this value. +const S_IFMT = uint32(sshfx.ModeType) diff --git a/wercker.yml b/wercker.yml deleted file mode 100644 index 41d2c528..00000000 --- a/wercker.yml +++ /dev/null @@ -1 +0,0 @@ -box: wercker/golang