Update dependencies
This commit is contained in:
87
vendor/github.com/tailscale/certstore/.appveyor.yml
generated
vendored
Normal file
87
vendor/github.com/tailscale/certstore/.appveyor.yml
generated
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
skip_branch_with_pr: true
|
||||
os: "Visual Studio 2015"
|
||||
|
||||
build: off
|
||||
|
||||
environment:
|
||||
matrix:
|
||||
- # Default Go, x64
|
||||
ARCH: x64
|
||||
MSYS2_ARCH: x86_64
|
||||
MSYS2_DIR: msys64
|
||||
MSYSTEM: MINGW64
|
||||
GOPATH: c:\gopath
|
||||
GOROOT: c:\go
|
||||
GOARCH: amd64
|
||||
EXTLD: x86_64-w64-mingw32-gcc
|
||||
- # Default Go, x86
|
||||
ARCH: x86
|
||||
MSYS2_ARCH: i686
|
||||
MSYS2_DIR: msys64
|
||||
MSYSTEM: MINGW32
|
||||
GOPATH: c:\gopath
|
||||
GOROOT: c:\go-x86
|
||||
GOARCH: 386
|
||||
EXTLD: i686-w64-mingw32-gcc
|
||||
- # Go 1.12, x64
|
||||
ARCH: x64
|
||||
MSYS2_ARCH: x86_64
|
||||
MSYS2_DIR: msys64
|
||||
MSYSTEM: MINGW64
|
||||
GOPATH: c:\gopath
|
||||
GOROOT: c:\go112
|
||||
GOARCH: amd64
|
||||
EXTLD: x86_64-w64-mingw32-gcc
|
||||
- # Go 1.11, x64
|
||||
ARCH: x64
|
||||
MSYS2_ARCH: x86_64
|
||||
MSYS2_DIR: msys64
|
||||
MSYSTEM: MINGW64
|
||||
GOPATH: c:\gopath
|
||||
GOROOT: c:\go111
|
||||
GOARCH: amd64
|
||||
EXTLD: x86_64-w64-mingw32-gcc
|
||||
- # Go 1.10, x64
|
||||
ARCH: x64
|
||||
MSYS2_ARCH: x86_64
|
||||
MSYS2_DIR: msys64
|
||||
MSYSTEM: MINGW64
|
||||
GOPATH: c:\gopath
|
||||
GOROOT: c:\go110
|
||||
GOARCH: amd64
|
||||
EXTLD: x86_64-w64-mingw32-gcc
|
||||
- # Go 1.9, x64
|
||||
ARCH: x64
|
||||
MSYS2_ARCH: x86_64
|
||||
MSYS2_DIR: msys64
|
||||
MSYSTEM: MINGW64
|
||||
GOPATH: c:\gopath
|
||||
GOROOT: c:\go19
|
||||
GOARCH: amd64
|
||||
EXTLD: x86_64-w64-mingw32-gcc
|
||||
- # Go 1.9, x64
|
||||
ARCH: x64
|
||||
MSYS2_ARCH: x86_64
|
||||
MSYS2_DIR: msys64
|
||||
MSYSTEM: MINGW64
|
||||
GOPATH: c:\gopath
|
||||
GOROOT: c:\go18
|
||||
GOARCH: amd64
|
||||
EXTLD: x86_64-w64-mingw32-gcc
|
||||
|
||||
clone_folder: C:\gopath\src\github.com\github\certstore
|
||||
|
||||
before_test:
|
||||
# Ensure CGO is enabled
|
||||
- set CGO_ENABLED=1
|
||||
# Go paths
|
||||
- set PATH=%GOROOT%\bin;C:\%GOPATH%\bin;%PATH%
|
||||
# MSYS paths
|
||||
- set PATH=C:\%MSYS2_DIR%\%MSYSTEM%\bin;C:\%MSYS2_DIR%\usr\bin;%PATH%
|
||||
# Install build deps
|
||||
- bash -lc "for n in `seq 1 3`; do pacman --noconfirm -S mingw-w64-%MSYS2_ARCH%-libtool && break || sleep 15; done"
|
||||
# Install Go deps
|
||||
- go get -t -v ./...
|
||||
|
||||
test_script:
|
||||
- go test -v ./...
|
||||
21
vendor/github.com/tailscale/certstore/LICENSE.md
generated
vendored
Normal file
21
vendor/github.com/tailscale/certstore/LICENSE.md
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2017 Ben Toews.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
98
vendor/github.com/tailscale/certstore/README.md
generated
vendored
Normal file
98
vendor/github.com/tailscale/certstore/README.md
generated
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
# certstore
|
||||
|
||||
Certstore is a Go library for accessing user identities stored in platform certificate stores. On Windows and macOS, certstore can enumerate user identities and sign messages with their private keys.
|
||||
|
||||
## Fork from original module
|
||||
|
||||
This is a fork from the [cyolosecurity fork](https://github.com/cyolosecurity/certstore/)
|
||||
of the the original [certstore module](https://github.com/github/certstore/). The cyolosecurity
|
||||
fork adds some functionality that we require (RSA-PSS support and ability to use the machine
|
||||
certificate store instead of the current user certificate store). However, the cyolosecurity
|
||||
fork did not update the module name, so we are unable to import it directly, thus leading to
|
||||
the Tailscale fork.
|
||||
|
||||
As of this writing, the [RSA-PSS PR](https://github.com/github/certstore/pull/18) has been
|
||||
under review for several months and the machine certificate store change has not yet been sent
|
||||
to the github maintainer for review. Ideally these changes will make their way back into the
|
||||
original module and we can
|
||||
[discontinue this fork](https://github.com/tailscale/tailscale/issues/2005) at some point in
|
||||
the future.
|
||||
|
||||
## Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
|
||||
"github.com/github/certstore"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sig, err := signWithMyIdentity("Ben Toews", "hello, world!")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(hex.EncodeToString(sig))
|
||||
}
|
||||
|
||||
func signWithMyIdentity(cn, msg string) ([]byte, error) {
|
||||
// Open the certificate store for use. This must be Close()'ed once you're
|
||||
// finished with the store and any identities it contains.
|
||||
store, err := certstore.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Get an Identity slice, containing every identity in the store. Each of
|
||||
// these must be Close()'ed when you're done with them.
|
||||
idents, err := store.Identities()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate through the identities, looking for the one we want.
|
||||
var me certstore.Identity
|
||||
for _, ident := range idents {
|
||||
defer ident.Close()
|
||||
|
||||
crt, errr := ident.Certificate()
|
||||
if errr != nil {
|
||||
return nil, errr
|
||||
}
|
||||
|
||||
if crt.Subject.CommonName == "Ben Toews" {
|
||||
me = ident
|
||||
}
|
||||
}
|
||||
|
||||
if me == nil {
|
||||
return nil, errors.New("Couldn't find my identity")
|
||||
}
|
||||
|
||||
// Get a crypto.Signer for the identity.
|
||||
signer, err := me.Signer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Digest and sign our message.
|
||||
digest := sha256.Sum256([]byte(msg))
|
||||
signature, err := signer.Sign(rand.Reader, digest[:], crypto.SHA256)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
```
|
||||
71
vendor/github.com/tailscale/certstore/certstore.go
generated
vendored
Normal file
71
vendor/github.com/tailscale/certstore/certstore.go
generated
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
package certstore
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnsupportedHash is returned by Signer.Sign() when the provided hash
|
||||
// algorithm isn't supported.
|
||||
ErrUnsupportedHash = errors.New("unsupported hash algorithm")
|
||||
)
|
||||
|
||||
// StoreLocation defines the store location to look certificates in.
|
||||
type StoreLocation int
|
||||
|
||||
const (
|
||||
// User is the user scoped certificate store. "CURRENT_USER" on Windows or
|
||||
// "login" on MacOS
|
||||
User StoreLocation = iota
|
||||
// System is the system scoped certificate store. "LOCAL_MACHINE" on Windows
|
||||
// or "System" on MacOS
|
||||
System
|
||||
)
|
||||
|
||||
// StorePermission defines the store permission to open the store with
|
||||
type StorePermission int
|
||||
|
||||
const (
|
||||
// ReadOnly is the store permission allows reading and using certificates
|
||||
ReadOnly StorePermission = iota
|
||||
// ReadWrite is the store permission that allows importing certificates
|
||||
ReadWrite
|
||||
)
|
||||
|
||||
// Open opens the system's certificate store.
|
||||
func Open(location StoreLocation, permissions ...StorePermission) (Store, error) {
|
||||
return openStore(location, permissions...)
|
||||
}
|
||||
|
||||
// Store represents the system's certificate store.
|
||||
type Store interface {
|
||||
// Identities gets a list of identities from the store.
|
||||
Identities() ([]Identity, error)
|
||||
|
||||
// Import imports a PKCS#12 (PFX) blob containing a certificate and private
|
||||
// key.
|
||||
Import(data []byte, password string) error
|
||||
|
||||
// Close closes the store.
|
||||
Close()
|
||||
}
|
||||
|
||||
// Identity is a X.509 certificate and its corresponding private key.
|
||||
type Identity interface {
|
||||
// Certificate gets the identity's certificate.
|
||||
Certificate() (*x509.Certificate, error)
|
||||
|
||||
// CertificateChain attempts to get the identity's full certificate chain.
|
||||
CertificateChain() ([]*x509.Certificate, error)
|
||||
|
||||
// Signer gets a crypto.Signer that uses the identity's private key.
|
||||
Signer() (crypto.Signer, error)
|
||||
|
||||
// Delete deletes this identity from the system.
|
||||
Delete() error
|
||||
|
||||
// Close any manually managed memory held by the Identity.
|
||||
Close()
|
||||
}
|
||||
497
vendor/github.com/tailscale/certstore/certstore_darwin.go
generated
vendored
Normal file
497
vendor/github.com/tailscale/certstore/certstore_darwin.go
generated
vendored
Normal file
@@ -0,0 +1,497 @@
|
||||
package certstore
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -x objective-c
|
||||
#cgo LDFLAGS: -framework CoreFoundation -framework Security
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
#include <Security/Security.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// work around https://golang.org/doc/go1.10#cgo
|
||||
// in go>=1.10 CFTypeRefs are translated to uintptrs instead of pointers.
|
||||
var (
|
||||
nilCFDictionaryRef C.CFDictionaryRef
|
||||
nilSecCertificateRef C.SecCertificateRef
|
||||
nilCFArrayRef C.CFArrayRef
|
||||
nilCFDataRef C.CFDataRef
|
||||
nilCFErrorRef C.CFErrorRef
|
||||
nilCFStringRef C.CFStringRef
|
||||
nilSecIdentityRef C.SecIdentityRef
|
||||
nilSecKeyRef C.SecKeyRef
|
||||
nilCFAllocatorRef C.CFAllocatorRef
|
||||
)
|
||||
|
||||
// macStore is a bogus type. We have to explicitly open/close the store on
|
||||
// windows, so we provide those methods here too.
|
||||
type macStore int
|
||||
|
||||
// openStore is a function for opening a macStore.
|
||||
func openStore(_ StoreLocation, _ ...StorePermission) (macStore, error) {
|
||||
return macStore(0), nil
|
||||
}
|
||||
|
||||
// Identities implements the Store interface.
|
||||
func (s macStore) Identities() ([]Identity, error) {
|
||||
query := mapToCFDictionary(map[C.CFTypeRef]C.CFTypeRef{
|
||||
C.CFTypeRef(C.kSecClass): C.CFTypeRef(C.kSecClassIdentity),
|
||||
C.CFTypeRef(C.kSecReturnRef): C.CFTypeRef(C.kCFBooleanTrue),
|
||||
C.CFTypeRef(C.kSecMatchLimit): C.CFTypeRef(C.kSecMatchLimitAll),
|
||||
})
|
||||
if query == nilCFDictionaryRef {
|
||||
return nil, errors.New("error creating CFDictionary")
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(query))
|
||||
|
||||
var absResult C.CFTypeRef
|
||||
if err := osStatusError(C.SecItemCopyMatching(query, &absResult)); err != nil {
|
||||
if err == errSecItemNotFound {
|
||||
return []Identity{}, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(absResult))
|
||||
|
||||
// don't need to release aryResult since the abstract result is released above.
|
||||
aryResult := C.CFArrayRef(absResult)
|
||||
|
||||
// identRefs aren't owned by us initially. newMacIdentity retains them.
|
||||
n := C.CFArrayGetCount(aryResult)
|
||||
identRefs := make([]C.CFTypeRef, n)
|
||||
C.CFArrayGetValues(aryResult, C.CFRange{0, n}, (*unsafe.Pointer)(unsafe.Pointer(&identRefs[0])))
|
||||
|
||||
idents := make([]Identity, 0, n)
|
||||
for _, identRef := range identRefs {
|
||||
idents = append(idents, newMacIdentity(C.SecIdentityRef(identRef)))
|
||||
}
|
||||
|
||||
return idents, nil
|
||||
}
|
||||
|
||||
// Import implements the Store interface.
|
||||
func (s macStore) Import(data []byte, password string) error {
|
||||
cdata, err := bytesToCFData(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(cdata))
|
||||
|
||||
cpass := stringToCFString(password)
|
||||
defer C.CFRelease(C.CFTypeRef(cpass))
|
||||
|
||||
cops := mapToCFDictionary(map[C.CFTypeRef]C.CFTypeRef{
|
||||
C.CFTypeRef(C.kSecImportExportPassphrase): C.CFTypeRef(cpass),
|
||||
})
|
||||
if cops == nilCFDictionaryRef {
|
||||
return errors.New("error creating CFDictionary")
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(cops))
|
||||
|
||||
var cret C.CFArrayRef
|
||||
if err := osStatusError(C.SecPKCS12Import(cdata, cops, &cret)); err != nil {
|
||||
return err
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(cret))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements the Store interface.
|
||||
func (s macStore) Close() {}
|
||||
|
||||
// macIdentity implements the Identity interface.
|
||||
type macIdentity struct {
|
||||
ref C.SecIdentityRef
|
||||
kref C.SecKeyRef
|
||||
cref C.SecCertificateRef
|
||||
crt *x509.Certificate
|
||||
chain []*x509.Certificate
|
||||
}
|
||||
|
||||
func newMacIdentity(ref C.SecIdentityRef) *macIdentity {
|
||||
C.CFRetain(C.CFTypeRef(ref))
|
||||
return &macIdentity{ref: ref}
|
||||
}
|
||||
|
||||
// Certificate implements the Identity interface.
|
||||
func (i *macIdentity) Certificate() (*x509.Certificate, error) {
|
||||
certRef, err := i.getCertRef()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
crt, err := exportCertRef(certRef)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i.crt = crt
|
||||
|
||||
return i.crt, nil
|
||||
}
|
||||
|
||||
// CertificateChain implements the Identity interface.
|
||||
func (i *macIdentity) CertificateChain() ([]*x509.Certificate, error) {
|
||||
if i.chain != nil {
|
||||
return i.chain, nil
|
||||
}
|
||||
|
||||
certRef, err := i.getCertRef()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policy := C.SecPolicyCreateSSL(0, nilCFStringRef)
|
||||
|
||||
var trustRef C.SecTrustRef
|
||||
if err := osStatusError(C.SecTrustCreateWithCertificates(C.CFTypeRef(certRef), C.CFTypeRef(policy), &trustRef)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(trustRef))
|
||||
|
||||
var status C.SecTrustResultType
|
||||
if err := osStatusError(C.SecTrustEvaluate(trustRef, &status)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
nchain = C.SecTrustGetCertificateCount(trustRef)
|
||||
chain = make([]*x509.Certificate, 0, int(nchain))
|
||||
)
|
||||
|
||||
for i := C.CFIndex(0); i < nchain; i++ {
|
||||
// TODO: do we need to release these?
|
||||
chainCertref := C.SecTrustGetCertificateAtIndex(trustRef, i)
|
||||
if chainCertref == nilSecCertificateRef {
|
||||
return nil, errors.New("nil certificate in chain")
|
||||
}
|
||||
|
||||
chainCert, err := exportCertRef(chainCertref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chain = append(chain, chainCert)
|
||||
}
|
||||
|
||||
i.chain = chain
|
||||
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// Signer implements the Identity interface.
|
||||
func (i *macIdentity) Signer() (crypto.Signer, error) {
|
||||
// pre-load the certificate so Public() is less likely to return nil
|
||||
// unexpectedly.
|
||||
if _, err := i.Certificate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// Delete implements the Identity interface.
|
||||
func (i *macIdentity) Delete() error {
|
||||
itemList := []C.SecIdentityRef{i.ref}
|
||||
itemListPtr := (*unsafe.Pointer)(unsafe.Pointer(&itemList[0]))
|
||||
citemList := C.CFArrayCreate(nilCFAllocatorRef, itemListPtr, 1, nil)
|
||||
if citemList == nilCFArrayRef {
|
||||
return errors.New("error creating CFArray")
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(citemList))
|
||||
|
||||
query := mapToCFDictionary(map[C.CFTypeRef]C.CFTypeRef{
|
||||
C.CFTypeRef(C.kSecClass): C.CFTypeRef(C.kSecClassIdentity),
|
||||
C.CFTypeRef(C.kSecMatchItemList): C.CFTypeRef(citemList),
|
||||
})
|
||||
if query == nilCFDictionaryRef {
|
||||
return errors.New("error creating CFDictionary")
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(query))
|
||||
|
||||
if err := osStatusError(C.SecItemDelete(query)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements the Identity interface.
|
||||
func (i *macIdentity) Close() {
|
||||
if i.ref != nilSecIdentityRef {
|
||||
C.CFRelease(C.CFTypeRef(i.ref))
|
||||
i.ref = nilSecIdentityRef
|
||||
}
|
||||
|
||||
if i.kref != nilSecKeyRef {
|
||||
C.CFRelease(C.CFTypeRef(i.kref))
|
||||
i.kref = nilSecKeyRef
|
||||
}
|
||||
|
||||
if i.cref != nilSecCertificateRef {
|
||||
C.CFRelease(C.CFTypeRef(i.cref))
|
||||
i.cref = nilSecCertificateRef
|
||||
}
|
||||
}
|
||||
|
||||
// Public implements the crypto.Signer interface.
|
||||
func (i *macIdentity) Public() crypto.PublicKey {
|
||||
cert, err := i.Certificate()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cert.PublicKey
|
||||
}
|
||||
|
||||
// Sign implements the crypto.Signer interface.
|
||||
func (i *macIdentity) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||
hash := opts.HashFunc()
|
||||
|
||||
if len(digest) != hash.Size() {
|
||||
return nil, errors.New("bad digest for hash")
|
||||
}
|
||||
|
||||
kref, err := i.getKeyRef()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cdigest, err := bytesToCFData(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(cdigest))
|
||||
|
||||
algo, err := i.getAlgo(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// sign the digest
|
||||
var cerr C.CFErrorRef
|
||||
csig := C.SecKeyCreateSignature(kref, algo, cdigest, &cerr)
|
||||
|
||||
if err := cfErrorError(cerr); err != nil {
|
||||
defer C.CFRelease(C.CFTypeRef(cerr))
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if csig == nilCFDataRef {
|
||||
return nil, errors.New("nil signature from SecKeyCreateSignature")
|
||||
}
|
||||
|
||||
defer C.CFRelease(C.CFTypeRef(csig))
|
||||
|
||||
sig := cfDataToBytes(csig)
|
||||
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// getAlgo decides which algorithm to use with this key type for the given hash.
|
||||
func (i *macIdentity) getAlgo(opts crypto.SignerOpts) (algo C.SecKeyAlgorithm, err error) {
|
||||
hash := opts.HashFunc()
|
||||
var crt *x509.Certificate
|
||||
if crt, err = i.Certificate(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch crt.PublicKey.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
switch hash {
|
||||
case crypto.SHA256:
|
||||
algo = C.kSecKeyAlgorithmECDSASignatureDigestX962SHA256
|
||||
case crypto.SHA384:
|
||||
algo = C.kSecKeyAlgorithmECDSASignatureDigestX962SHA384
|
||||
case crypto.SHA512:
|
||||
algo = C.kSecKeyAlgorithmECDSASignatureDigestX962SHA512
|
||||
default:
|
||||
err = ErrUnsupportedHash
|
||||
}
|
||||
case *rsa.PublicKey:
|
||||
if _, ok := opts.(*rsa.PSSOptions); ok {
|
||||
switch hash {
|
||||
case crypto.SHA256:
|
||||
algo = C.kSecKeyAlgorithmRSASignatureDigestPSSSHA256
|
||||
case crypto.SHA384:
|
||||
algo = C.kSecKeyAlgorithmRSASignatureDigestPSSSHA384
|
||||
case crypto.SHA512:
|
||||
algo = C.kSecKeyAlgorithmRSASignatureDigestPSSSHA512
|
||||
default:
|
||||
err = ErrUnsupportedHash
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch hash {
|
||||
case crypto.SHA256:
|
||||
algo = C.kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA256
|
||||
case crypto.SHA384:
|
||||
algo = C.kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA384
|
||||
case crypto.SHA512:
|
||||
algo = C.kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA512
|
||||
default:
|
||||
err = ErrUnsupportedHash
|
||||
}
|
||||
default:
|
||||
err = errors.New("unsupported key type")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// getKeyRef gets the SecKeyRef for this identity's pricate key.
|
||||
func (i *macIdentity) getKeyRef() (C.SecKeyRef, error) {
|
||||
if i.kref != nilSecKeyRef {
|
||||
return i.kref, nil
|
||||
}
|
||||
|
||||
var keyRef C.SecKeyRef
|
||||
if err := osStatusError(C.SecIdentityCopyPrivateKey(i.ref, &keyRef)); err != nil {
|
||||
return nilSecKeyRef, err
|
||||
}
|
||||
|
||||
i.kref = keyRef
|
||||
|
||||
return i.kref, nil
|
||||
}
|
||||
|
||||
// getCertRef gets the SecCertificateRef for this identity's certificate.
|
||||
func (i *macIdentity) getCertRef() (C.SecCertificateRef, error) {
|
||||
if i.cref != nilSecCertificateRef {
|
||||
return i.cref, nil
|
||||
}
|
||||
|
||||
var certRef C.SecCertificateRef
|
||||
if err := osStatusError(C.SecIdentityCopyCertificate(i.ref, &certRef)); err != nil {
|
||||
return nilSecCertificateRef, err
|
||||
}
|
||||
|
||||
i.cref = certRef
|
||||
|
||||
return i.cref, nil
|
||||
}
|
||||
|
||||
// exportCertRef gets a *x509.Certificate for the given SecCertificateRef.
|
||||
func exportCertRef(certRef C.SecCertificateRef) (*x509.Certificate, error) {
|
||||
derRef := C.SecCertificateCopyData(certRef)
|
||||
if derRef == nilCFDataRef {
|
||||
return nil, errors.New("error getting certificate from identity")
|
||||
}
|
||||
defer C.CFRelease(C.CFTypeRef(derRef))
|
||||
|
||||
der := cfDataToBytes(derRef)
|
||||
crt, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return crt, nil
|
||||
}
|
||||
|
||||
// stringToCFString converts a Go string to a CFStringRef.
|
||||
func stringToCFString(gostr string) C.CFStringRef {
|
||||
cstr := C.CString(gostr)
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
|
||||
return C.CFStringCreateWithCString(nilCFAllocatorRef, cstr, C.kCFStringEncodingUTF8)
|
||||
}
|
||||
|
||||
// mapToCFDictionary converts a Go map[C.CFTypeRef]C.CFTypeRef to a
|
||||
// CFDictionaryRef.
|
||||
func mapToCFDictionary(gomap map[C.CFTypeRef]C.CFTypeRef) C.CFDictionaryRef {
|
||||
var (
|
||||
n = len(gomap)
|
||||
keys = make([]unsafe.Pointer, 0, n)
|
||||
values = make([]unsafe.Pointer, 0, n)
|
||||
)
|
||||
|
||||
for k, v := range gomap {
|
||||
keys = append(keys, unsafe.Pointer(k))
|
||||
values = append(values, unsafe.Pointer(v))
|
||||
}
|
||||
|
||||
return C.CFDictionaryCreate(nilCFAllocatorRef, &keys[0], &values[0], C.CFIndex(n), nil, nil)
|
||||
}
|
||||
|
||||
// cfDataToBytes converts a CFDataRef to a Go byte slice.
|
||||
func cfDataToBytes(cfdata C.CFDataRef) []byte {
|
||||
nBytes := C.CFDataGetLength(cfdata)
|
||||
bytesPtr := C.CFDataGetBytePtr(cfdata)
|
||||
return C.GoBytes(unsafe.Pointer(bytesPtr), C.int(nBytes))
|
||||
}
|
||||
|
||||
// bytesToCFData converts a Go byte slice to a CFDataRef.
|
||||
func bytesToCFData(gobytes []byte) (C.CFDataRef, error) {
|
||||
var (
|
||||
cptr = (*C.UInt8)(nil)
|
||||
clen = C.CFIndex(len(gobytes))
|
||||
)
|
||||
|
||||
if len(gobytes) > 0 {
|
||||
cptr = (*C.UInt8)(&gobytes[0])
|
||||
}
|
||||
|
||||
cdata := C.CFDataCreate(nilCFAllocatorRef, cptr, clen)
|
||||
if cdata == nilCFDataRef {
|
||||
return nilCFDataRef, errors.New("error creatin cfdata")
|
||||
}
|
||||
|
||||
return cdata, nil
|
||||
}
|
||||
|
||||
// osStatus wraps a C.OSStatus
|
||||
type osStatus C.OSStatus
|
||||
|
||||
const (
|
||||
errSecItemNotFound = osStatus(C.errSecItemNotFound)
|
||||
)
|
||||
|
||||
// osStatusError returns an error for an OSStatus unless it is errSecSuccess.
|
||||
func osStatusError(s C.OSStatus) error {
|
||||
if s == C.errSecSuccess {
|
||||
return nil
|
||||
}
|
||||
|
||||
return osStatus(s)
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (s osStatus) Error() string {
|
||||
return fmt.Sprintf("OSStatus %d", s)
|
||||
}
|
||||
|
||||
// cfErrorError returns an error for a CFErrorRef unless it is nil.
|
||||
func cfErrorError(cerr C.CFErrorRef) error {
|
||||
if cerr == nilCFErrorRef {
|
||||
return nil
|
||||
}
|
||||
|
||||
code := int(C.CFErrorGetCode(cerr))
|
||||
|
||||
if cdescription := C.CFErrorCopyDescription(cerr); cdescription != nilCFStringRef {
|
||||
defer C.CFRelease(C.CFTypeRef(cdescription))
|
||||
|
||||
if cstr := C.CFStringGetCStringPtr(cdescription, C.kCFStringEncodingUTF8); cstr != nil {
|
||||
str := C.GoString(cstr)
|
||||
|
||||
return fmt.Errorf("CFError %d (%s)", code, str)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return fmt.Errorf("CFError %d", code)
|
||||
}
|
||||
14
vendor/github.com/tailscale/certstore/certstore_linux.go
generated
vendored
Normal file
14
vendor/github.com/tailscale/certstore/certstore_linux.go
generated
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
package certstore
|
||||
|
||||
import "errors"
|
||||
|
||||
// This will hopefully give a compiler error that will hint at the fact that
|
||||
// this package isn't designed to work on Linux.
|
||||
func init() {
|
||||
CERTSTORE_DOESNT_WORK_ON_LINIX
|
||||
}
|
||||
|
||||
// Implement this function, just to silence other compiler errors.
|
||||
func openStore(location StoreLocation, permissions ...StorePermission) (Store, error) {
|
||||
return nil, errors.New("certstore only works on macOS and Windows")
|
||||
}
|
||||
639
vendor/github.com/tailscale/certstore/certstore_windows.go
generated
vendored
Normal file
639
vendor/github.com/tailscale/certstore/certstore_windows.go
generated
vendored
Normal file
@@ -0,0 +1,639 @@
|
||||
package certstore
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const myStoreName = "MY"
|
||||
|
||||
var (
|
||||
// winAPIFlag specifies the flags that should be passed to
|
||||
// CryptAcquireCertificatePrivateKey. This impacts whether the CryptoAPI or CNG
|
||||
// API will be used.
|
||||
//
|
||||
// Possible values are:
|
||||
//
|
||||
// 0 — Only use CryptoAPI.
|
||||
// windows.CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG — Prefer CryptoAPI.
|
||||
// windows.CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG — Prefer CNG.
|
||||
// windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG — Only use CNG.
|
||||
winAPIFlag uint32 = windows.CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG
|
||||
)
|
||||
|
||||
// winStore is a wrapper around a certStoreHandle.
|
||||
type winStore struct {
|
||||
store certStoreHandle
|
||||
}
|
||||
|
||||
// openStore opens the current user's personal cert store.
|
||||
func openStore(location StoreLocation, permissions ...StorePermission) (*winStore, error) {
|
||||
storeName, err := windows.UTF16PtrFromString(myStoreName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var flags uint32
|
||||
switch location {
|
||||
case User:
|
||||
flags |= windows.CERT_SYSTEM_STORE_CURRENT_USER
|
||||
case System:
|
||||
flags |= windows.CERT_SYSTEM_STORE_LOCAL_MACHINE
|
||||
}
|
||||
|
||||
for _, p := range permissions {
|
||||
switch p {
|
||||
case ReadOnly:
|
||||
flags |= windows.CERT_STORE_READONLY_FLAG
|
||||
}
|
||||
}
|
||||
|
||||
store, err := windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM_W, 0, 0, flags, uintptr(unsafe.Pointer(storeName)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open system cert store: %w", err)
|
||||
}
|
||||
|
||||
return &winStore{certStoreHandle(store)}, nil
|
||||
}
|
||||
|
||||
// Identities implements the Store interface.
|
||||
func (s *winStore) Identities() ([]Identity, error) {
|
||||
var (
|
||||
err error
|
||||
idents []Identity
|
||||
|
||||
// CertFindChainInStore parameters
|
||||
encoding = uint32(windows.X509_ASN_ENCODING)
|
||||
flags = uint32(windows.CERT_CHAIN_FIND_BY_ISSUER_CACHE_ONLY_FLAG | windows.CERT_CHAIN_FIND_BY_ISSUER_CACHE_ONLY_URL_FLAG)
|
||||
findType = uint32(windows.CERT_CHAIN_FIND_BY_ISSUER)
|
||||
params = &windows.CertChainFindByIssuerPara{Size: uint32(unsafe.Sizeof(windows.CertChainFindByIssuerPara{}))}
|
||||
chainCtx *windows.CertChainContext
|
||||
)
|
||||
|
||||
for {
|
||||
chainCtx, err = windows.CertFindChainInStore(windows.Handle(s.store), encoding, flags, findType, unsafe.Pointer(params), chainCtx)
|
||||
if errors.Is(err, syscall.Errno(windows.CRYPT_E_NOT_FOUND)) {
|
||||
return idents, nil
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to iterate certs in store: %w", err)
|
||||
break
|
||||
}
|
||||
if chainCtx.ChainCount < 1 {
|
||||
err = errors.New("bad chain")
|
||||
break
|
||||
}
|
||||
|
||||
// not sure why this isn't 1 << 29
|
||||
const maxPointerArray = 1 << 28
|
||||
|
||||
// rgpChain is actually an array, but we only care about the first one.
|
||||
simpleChain := *chainCtx.Chains
|
||||
if simpleChain.NumElements < 1 || simpleChain.NumElements > maxPointerArray {
|
||||
err = errors.New("bad chain")
|
||||
break
|
||||
}
|
||||
|
||||
chainElts := unsafe.Slice(simpleChain.Elements, simpleChain.NumElements)
|
||||
|
||||
// Build chain of certificates from each elt's certificate context.
|
||||
chain := make([]*windows.CertContext, len(chainElts))
|
||||
for j := range chainElts {
|
||||
chain[j] = chainElts[j].CertContext
|
||||
}
|
||||
|
||||
idents = append(idents, newWinIdentity(chain))
|
||||
}
|
||||
|
||||
for _, ident := range idents {
|
||||
ident.Close()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Import implements the Store interface.
|
||||
func (s *winStore) Import(data []byte, password string) error {
|
||||
cpw, err := windows.UTF16PtrFromString(password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("converting password: %w", err)
|
||||
}
|
||||
|
||||
pfx := &windows.CryptDataBlob{
|
||||
Size: uint32(len(data)),
|
||||
Data: unsafe.SliceData(data),
|
||||
}
|
||||
|
||||
var flags uint32
|
||||
flags = windows.CRYPT_USER_KEYSET
|
||||
|
||||
// import into preferred KSP
|
||||
if winAPIFlag&windows.CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG > 0 {
|
||||
flags |= windows.PKCS12_PREFER_CNG_KSP
|
||||
} else if winAPIFlag&windows.CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG > 0 {
|
||||
flags |= windows.PKCS12_ALWAYS_CNG_KSP
|
||||
}
|
||||
|
||||
store, err := windows.PFXImportCertStore(pfx, cpw, flags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import PFX cert store: %w", err)
|
||||
}
|
||||
defer windows.CertCloseStore(store, windows.CERT_CLOSE_STORE_FORCE_FLAG)
|
||||
|
||||
var (
|
||||
ctx *windows.CertContext
|
||||
encoding = uint32(windows.X509_ASN_ENCODING | windows.PKCS_7_ASN_ENCODING)
|
||||
)
|
||||
|
||||
for {
|
||||
// iterate through certs in temporary store
|
||||
ctx, err = windows.CertFindCertificateInStore(store, encoding, 0, windows.CERT_FIND_ANY, nil, ctx)
|
||||
if errors.Is(err, syscall.Errno(windows.CRYPT_E_NOT_FOUND)) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to iterate certs in store: %w", err)
|
||||
}
|
||||
|
||||
// Copy the cert to the system store.
|
||||
if err := windows.CertAddCertificateContextToStore(windows.Handle(s.store), ctx, windows.CERT_STORE_ADD_REPLACE_EXISTING, nil); err != nil {
|
||||
return fmt.Errorf("failed to add imported certificate to %s store: %w", myStoreName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements the Store interface.
|
||||
func (s *winStore) Close() {
|
||||
windows.CertCloseStore(windows.Handle(s.store), 0)
|
||||
s.store = 0
|
||||
}
|
||||
|
||||
// winIdentity implements the Identity interface.
|
||||
type winIdentity struct {
|
||||
chain []*windows.CertContext
|
||||
signer *winPrivateKey
|
||||
}
|
||||
|
||||
func newWinIdentity(chain []*windows.CertContext) *winIdentity {
|
||||
for _, ctx := range chain {
|
||||
windows.CertDuplicateCertificateContext(ctx)
|
||||
}
|
||||
|
||||
return &winIdentity{chain: chain}
|
||||
}
|
||||
|
||||
// Certificate implements the Identity interface.
|
||||
func (i *winIdentity) Certificate() (*x509.Certificate, error) {
|
||||
return exportCertCtx(i.chain[0])
|
||||
}
|
||||
|
||||
// CertificateChain implements the Identity interface.
|
||||
func (i *winIdentity) CertificateChain() ([]*x509.Certificate, error) {
|
||||
var (
|
||||
certs = make([]*x509.Certificate, len(i.chain))
|
||||
err error
|
||||
)
|
||||
|
||||
for j := range i.chain {
|
||||
if certs[j], err = exportCertCtx(i.chain[j]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// Signer implements the Identity interface.
|
||||
func (i *winIdentity) Signer() (crypto.Signer, error) {
|
||||
return i.getPrivateKey()
|
||||
}
|
||||
|
||||
// getPrivateKey gets this identity's private *winPrivateKey.
|
||||
func (i *winIdentity) getPrivateKey() (*winPrivateKey, error) {
|
||||
if i.signer != nil {
|
||||
return i.signer, nil
|
||||
}
|
||||
|
||||
cert, err := i.Certificate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get identity certificate: %w", err)
|
||||
}
|
||||
|
||||
signer, err := newWinPrivateKey(i.chain[0], cert.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load identity private key: %w", err)
|
||||
}
|
||||
|
||||
i.signer = signer
|
||||
|
||||
return i.signer, nil
|
||||
}
|
||||
|
||||
// Delete implements the Identity interface.
|
||||
func (i *winIdentity) Delete() error {
|
||||
// duplicate cert context, since CertDeleteCertificateFromStore will free it.
|
||||
deleteCtx := windows.CertDuplicateCertificateContext(i.chain[0])
|
||||
|
||||
// try deleting cert
|
||||
if err := windows.CertDeleteCertificateFromStore(deleteCtx); err != nil {
|
||||
return fmt.Errorf("failed to delete certificate from store: %w", err)
|
||||
}
|
||||
|
||||
// try deleting private key
|
||||
wpk, err := i.getPrivateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get identity private key: %w", err)
|
||||
}
|
||||
|
||||
if err := wpk.Delete(); err != nil {
|
||||
return fmt.Errorf("failed to delete identity private key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements the Identity interface.
|
||||
func (i *winIdentity) Close() {
|
||||
if i.signer != nil {
|
||||
i.signer.Close()
|
||||
i.signer = nil
|
||||
}
|
||||
|
||||
for _, ctx := range i.chain {
|
||||
_ = windows.CertFreeCertificateContext(ctx)
|
||||
i.chain = nil
|
||||
}
|
||||
}
|
||||
|
||||
// winPrivateKey is a wrapper around a HCRYPTPROV_OR_NCRYPT_KEY_HANDLE.
|
||||
type winPrivateKey struct {
|
||||
publicKey crypto.PublicKey
|
||||
|
||||
// CryptoAPI fields
|
||||
capiProv cryptProviderHandle
|
||||
|
||||
// CNG fields
|
||||
cngHandle nCryptKeyHandle
|
||||
keySpec uint32
|
||||
}
|
||||
|
||||
// newWinPrivateKey gets a *winPrivateKey for the given certificate.
|
||||
func newWinPrivateKey(certCtx *windows.CertContext, publicKey crypto.PublicKey) (*winPrivateKey, error) {
|
||||
var (
|
||||
provOrKey windows.Handle // HCRYPTPROV_OR_NCRYPT_KEY_HANDLE
|
||||
keySpec uint32
|
||||
mustFree bool
|
||||
)
|
||||
|
||||
if publicKey == nil {
|
||||
return nil, errors.New("nil public key")
|
||||
}
|
||||
|
||||
// Get a handle for the found private key.
|
||||
if err := windows.CryptAcquireCertificatePrivateKey(certCtx, winAPIFlag, nil, &provOrKey, &keySpec, &mustFree); err != nil {
|
||||
return nil, fmt.Errorf("failed to get private key for certificate: %w", err)
|
||||
}
|
||||
|
||||
if !mustFree {
|
||||
// This shouldn't happen since we're not asking for cached keys.
|
||||
return nil, errors.New("CryptAcquireCertificatePrivateKey set mustFree")
|
||||
}
|
||||
|
||||
if keySpec == windows.CERT_NCRYPT_KEY_SPEC {
|
||||
return &winPrivateKey{
|
||||
publicKey: publicKey,
|
||||
cngHandle: nCryptKeyHandle(provOrKey),
|
||||
}, nil
|
||||
} else {
|
||||
return &winPrivateKey{
|
||||
publicKey: publicKey,
|
||||
capiProv: cryptProviderHandle(provOrKey),
|
||||
keySpec: keySpec,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Public implements the crypto.Signer interface.
|
||||
func (wpk *winPrivateKey) Public() crypto.PublicKey {
|
||||
return wpk.publicKey
|
||||
}
|
||||
|
||||
// Sign implements the crypto.Signer interface.
|
||||
func (wpk *winPrivateKey) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||
if wpk.capiProv != 0 {
|
||||
return wpk.capiSignHash(opts, digest)
|
||||
} else if wpk.cngHandle != 0 {
|
||||
return wpk.cngSignHash(opts, digest)
|
||||
} else {
|
||||
return nil, errors.New("bad private key")
|
||||
}
|
||||
}
|
||||
|
||||
type bCryptPKCS1PaddingInfo struct {
|
||||
// Algorithm is the name of the hashing algorithm to use, in UTF16. Use one
|
||||
// of the BCRYPT_*_ALGORITHM constants.
|
||||
Algorithm *uint16
|
||||
}
|
||||
|
||||
type bCryptPSSPaddingInfo struct {
|
||||
// Algorithm is the name of the hashing algorithm to use, in UTF16. Use one
|
||||
// of the BCRYPT_*_ALGORITHM constants.
|
||||
Algorithm *uint16
|
||||
// SaltLength is the size, in bytes, of the random salt to use for the
|
||||
// padding.
|
||||
SaltLength uint32
|
||||
}
|
||||
|
||||
// cngSignHash signs a digest using the CNG APIs.
|
||||
func (wpk *winPrivateKey) cngSignHash(opts crypto.SignerOpts, digest []byte) ([]byte, error) {
|
||||
hash := opts.HashFunc()
|
||||
if len(digest) != hash.Size() {
|
||||
return nil, errors.New("cngSignHash: bad digest for hash")
|
||||
}
|
||||
|
||||
var (
|
||||
// input
|
||||
padPtr unsafe.Pointer
|
||||
digestPtr = unsafe.SliceData(digest)
|
||||
digestLen = uint32(len(digest))
|
||||
flags uint32
|
||||
|
||||
// output
|
||||
sigLen uint32
|
||||
)
|
||||
|
||||
// setup pkcs1v1.5 padding for RSA
|
||||
if _, isRSA := wpk.publicKey.(*rsa.PublicKey); isRSA {
|
||||
var alg string
|
||||
switch hash {
|
||||
case crypto.SHA256:
|
||||
alg = BCRYPT_SHA256_ALGORITHM
|
||||
case crypto.SHA384:
|
||||
alg = BCRYPT_SHA384_ALGORITHM
|
||||
case crypto.SHA512:
|
||||
alg = BCRYPT_SHA512_ALGORITHM
|
||||
default:
|
||||
return nil, fmt.Errorf("cngSignHash: converting from Go algorithm: %w", ErrUnsupportedHash)
|
||||
}
|
||||
|
||||
algName, err := windows.UTF16PtrFromString(alg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pssOpts, ok := opts.(*rsa.PSSOptions); ok {
|
||||
saltLen := pssOpts.SaltLength
|
||||
if saltLen == rsa.PSSSaltLengthEqualsHash {
|
||||
saltLen = len(digest)
|
||||
} else if saltLen == rsa.PSSSaltLengthAuto {
|
||||
saltLen = hash.Size()
|
||||
}
|
||||
padPtr = unsafe.Pointer(&bCryptPSSPaddingInfo{
|
||||
Algorithm: algName,
|
||||
SaltLength: uint32(saltLen),
|
||||
})
|
||||
flags |= BCRYPT_PAD_PSS
|
||||
} else {
|
||||
padPtr = unsafe.Pointer(&bCryptPKCS1PaddingInfo{
|
||||
Algorithm: algName,
|
||||
})
|
||||
flags |= BCRYPT_PAD_PKCS1
|
||||
}
|
||||
}
|
||||
|
||||
// get signature length
|
||||
if err := nCryptSignHash(wpk.cngHandle, padPtr, digestPtr, digestLen, nil, 0, &sigLen, flags); err != nil {
|
||||
return nil, fmt.Errorf("cngSignHash: failed to get signature length: %w", err)
|
||||
}
|
||||
|
||||
// get signature
|
||||
sig := make([]byte, sigLen)
|
||||
if err := nCryptSignHash(wpk.cngHandle, padPtr, digestPtr, digestLen, unsafe.SliceData(sig), sigLen, &sigLen, flags); err != nil {
|
||||
return nil, fmt.Errorf("cngSignHash: failed to sign digest: %w", err)
|
||||
}
|
||||
|
||||
// CNG returns a raw ECDSA signature, but we wan't ASN.1 DER encoding.
|
||||
if _, isEC := wpk.publicKey.(*ecdsa.PublicKey); isEC {
|
||||
if len(sig)%2 != 0 {
|
||||
return nil, errors.New("cngSignHash: bad ecdsa signature from CNG")
|
||||
}
|
||||
|
||||
type ecdsaSignature struct {
|
||||
R, S *big.Int
|
||||
}
|
||||
|
||||
r := new(big.Int).SetBytes(sig[:len(sig)/2])
|
||||
s := new(big.Int).SetBytes(sig[len(sig)/2:])
|
||||
|
||||
encoded, err := asn1.Marshal(ecdsaSignature{r, s})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cngSignHash: failed to ASN.1 encode EC signature: %w", err)
|
||||
}
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// capiSignHash signs a digest using the CryptoAPI APIs.
|
||||
func (wpk *winPrivateKey) capiSignHash(opts crypto.SignerOpts, digest []byte) ([]byte, error) {
|
||||
if _, ok := opts.(*rsa.PSSOptions); ok {
|
||||
return nil, fmt.Errorf("capiSignHash: CAPI does not support PSS padding, %w", ErrUnsupportedHash)
|
||||
}
|
||||
|
||||
hash := opts.HashFunc()
|
||||
if len(digest) != hash.Size() {
|
||||
return nil, errors.New("capiSignHash: bad digest for hash")
|
||||
}
|
||||
|
||||
// Figure out which CryptoAPI hash algorithm we're using.
|
||||
var hash_alg cryptAlgorithm
|
||||
|
||||
switch hash {
|
||||
case crypto.SHA256:
|
||||
hash_alg = CALG_SHA_256
|
||||
case crypto.SHA384:
|
||||
hash_alg = CALG_SHA_384
|
||||
case crypto.SHA512:
|
||||
hash_alg = CALG_SHA_512
|
||||
default:
|
||||
return nil, fmt.Errorf("capiSignHash: converting from Go algorithm: %w", ErrUnsupportedHash)
|
||||
}
|
||||
|
||||
// Instantiate a CryptoAPI hash object.
|
||||
var chash cryptHashHandle
|
||||
|
||||
if err := cryptCreateHash(wpk.capiProv, hash_alg, 0, 0, &chash); err != nil {
|
||||
if errors.Is(err, syscall.Errno(windows.NTE_BAD_ALGID)) {
|
||||
err = ErrUnsupportedHash
|
||||
}
|
||||
return nil, fmt.Errorf("capiSignHash: cryptCreateHash: %w", err)
|
||||
}
|
||||
defer cryptDestroyHash(chash)
|
||||
|
||||
// Make sure the hash size matches.
|
||||
var (
|
||||
hashSize uint32
|
||||
hashSizeLen = uint32(unsafe.Sizeof(hashSize))
|
||||
)
|
||||
|
||||
if err := cryptGetHashParam(chash, HP_HASHSIZE, unsafe.Pointer(&hashSize), &hashSizeLen, 0); err != nil {
|
||||
return nil, fmt.Errorf("capiSignHash: failed to get hash size: %w", err)
|
||||
}
|
||||
|
||||
if hash.Size() != int(hashSize) {
|
||||
return nil, errors.New("capiSignHash: invalid CryptoAPI hash")
|
||||
}
|
||||
|
||||
// Put our digest into the hash object.
|
||||
if err := cryptSetHashParam(chash, HP_HASHVAL, unsafe.Pointer(unsafe.SliceData(digest)), 0); err != nil {
|
||||
return nil, fmt.Errorf("capiSignHash: failed to set hash digest: %w", err)
|
||||
}
|
||||
|
||||
// Get signature length.
|
||||
var sigLen uint32
|
||||
|
||||
if err := cryptSignHash(chash, wpk.keySpec, nil, 0, nil, &sigLen); err != nil {
|
||||
return nil, fmt.Errorf("capiSignHash: failed to get signature length: %w", err)
|
||||
}
|
||||
|
||||
// Get signature
|
||||
sig := make([]byte, int(sigLen))
|
||||
|
||||
if err := cryptSignHash(chash, wpk.keySpec, nil, 0, unsafe.SliceData(sig), &sigLen); err != nil {
|
||||
return nil, fmt.Errorf("capiSignHash: failed to sign digest: %w", err)
|
||||
}
|
||||
|
||||
// Signature is little endian, but we want big endian. Reverse it.
|
||||
for i := len(sig)/2 - 1; i >= 0; i-- {
|
||||
opp := len(sig) - 1 - i
|
||||
sig[i], sig[opp] = sig[opp], sig[i]
|
||||
}
|
||||
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
func (wpk *winPrivateKey) Delete() error {
|
||||
if wpk.cngHandle != 0 {
|
||||
// Delete CNG key
|
||||
if err := nCryptDeleteKey(wpk.cngHandle, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if wpk.capiProv != 0 {
|
||||
// Delete CryptoAPI key
|
||||
var (
|
||||
containerName string
|
||||
providerName string
|
||||
providerType uint32
|
||||
err error
|
||||
)
|
||||
|
||||
containerName, err = wpk.getProviderStringParam(PP_CONTAINER)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get PP_CONTAINER: %w", err)
|
||||
}
|
||||
|
||||
providerName, err = wpk.getProviderStringParam(PP_NAME)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get PP_NAME: %w", err)
|
||||
}
|
||||
|
||||
providerType, err = wpk.getProviderUint32Param(PP_PROVTYPE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get PP_PROVTYPE: %w", err)
|
||||
}
|
||||
|
||||
// use CRYPT_SILENT too?
|
||||
var prov windows.Handle
|
||||
if err := windows.CryptAcquireContext(&prov, windows.StringToUTF16Ptr(containerName), windows.StringToUTF16Ptr(providerName),
|
||||
providerType, windows.CRYPT_DELETEKEYSET); err != nil {
|
||||
return fmt.Errorf("failed to delete key set (container=%q, name=%q, provtype=%d): %w",
|
||||
containerName, providerName, providerType, err)
|
||||
}
|
||||
wpk.capiProv = 0
|
||||
} else {
|
||||
return errors.New("bad private key")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProviderParam gets a parameter about a provider.
|
||||
func (wpk *winPrivateKey) getProviderParam(param uint32) ([]byte, error) {
|
||||
var dataLen uint32
|
||||
if err := cryptGetProvParam(wpk.capiProv, param, nil, &dataLen, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to get provider parameter size: %w", err)
|
||||
}
|
||||
|
||||
data := make([]byte, dataLen)
|
||||
if err := cryptGetProvParam(wpk.capiProv, param, unsafe.SliceData(data), &dataLen, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to get provider parameter: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// getProviderStringParam gets a parameter about a provider. The parameter is a zero-terminated char string.
|
||||
func (wpk *winPrivateKey) getProviderStringParam(param uint32) (string, error) {
|
||||
val, err := wpk.getProviderParam(param)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return windows.ByteSliceToString(val), nil
|
||||
}
|
||||
|
||||
// getProviderUint32Param gets a parameter about a provider. The parameter is a uint32.
|
||||
func (wpk *winPrivateKey) getProviderUint32Param(param uint32) (uint32, error) {
|
||||
val, err := wpk.getProviderParam(param)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(val) != 4 {
|
||||
return 0, errors.New("uint32 data length is not 32 bits")
|
||||
}
|
||||
if cpu.IsBigEndian {
|
||||
return binary.BigEndian.Uint32(val), nil
|
||||
}
|
||||
return binary.LittleEndian.Uint32(val), nil
|
||||
}
|
||||
|
||||
// Close closes this winPrivateKey.
|
||||
func (wpk *winPrivateKey) Close() {
|
||||
if wpk.cngHandle != 0 {
|
||||
_ = nCryptFreeObject(windows.Handle(wpk.cngHandle))
|
||||
wpk.cngHandle = 0
|
||||
}
|
||||
|
||||
if wpk.capiProv != 0 {
|
||||
_ = windows.CryptReleaseContext(windows.Handle(wpk.capiProv), 0)
|
||||
wpk.capiProv = 0
|
||||
}
|
||||
}
|
||||
|
||||
// exportCertCtx exports a *windows.CertContext as an *x509.Certificate.
|
||||
func exportCertCtx(ctx *windows.CertContext) (*x509.Certificate, error) {
|
||||
der := unsafe.Slice(ctx.EncodedCert, ctx.Length)
|
||||
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("certificate parsing failed: %w", err)
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
79
vendor/github.com/tailscale/certstore/crypt_strings_windows.go
generated
vendored
Normal file
79
vendor/github.com/tailscale/certstore/crypt_strings_windows.go
generated
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
package certstore
|
||||
|
||||
const (
|
||||
// NCRYPT Object Property Names
|
||||
NCRYPT_ALGORITHM_GROUP_PROPERTY = "Algorithm Group"
|
||||
NCRYPT_ALGORITHM_PROPERTY = "Algorithm Name"
|
||||
NCRYPT_BLOCK_LENGTH_PROPERTY = "Block Length"
|
||||
NCRYPT_CERTIFICATE_PROPERTY = "SmartCardKeyCertificate"
|
||||
NCRYPT_DH_PARAMETERS_PROPERTY = BCRYPT_DH_PARAMETERS
|
||||
BCRYPT_DH_PARAMETERS = "DHParameters"
|
||||
NCRYPT_EXPORT_POLICY_PROPERTY = "Export Policy"
|
||||
NCRYPT_IMPL_TYPE_PROPERTY = "Impl Type"
|
||||
NCRYPT_KEY_TYPE_PROPERTY = "Key Type"
|
||||
NCRYPT_KEY_USAGE_PROPERTY = "Key Usage"
|
||||
NCRYPT_LAST_MODIFIED_PROPERTY = "Modified"
|
||||
NCRYPT_LENGTH_PROPERTY = "Length"
|
||||
NCRYPT_LENGTHS_PROPERTY = "Lengths"
|
||||
NCRYPT_MAX_NAME_LENGTH_PROPERTY = "Max Name Length"
|
||||
NCRYPT_NAME_PROPERTY = "Name"
|
||||
NCRYPT_PIN_PROMPT_PROPERTY = "SmartCardPinPrompt"
|
||||
NCRYPT_PIN_PROPERTY = "SmartCardPin"
|
||||
NCRYPT_PROVIDER_HANDLE_PROPERTY = "Provider Handle"
|
||||
NCRYPT_READER_PROPERTY = "SmartCardReader"
|
||||
NCRYPT_ROOT_CERTSTORE_PROPERTY = "SmartcardRootCertStore"
|
||||
NCRYPT_SECURE_PIN_PROPERTY = "SmartCardSecurePin"
|
||||
NCRYPT_SECURITY_DESCR_PROPERTY = "Security Descr"
|
||||
NCRYPT_SECURITY_DESCR_SUPPORT_PROPERTY = "Security Descr Support"
|
||||
NCRYPT_SMARTCARD_GUID_PROPERTY = "SmartCardGuid"
|
||||
NCRYPT_UI_POLICY_PROPERTY = "UI Policy"
|
||||
NCRYPT_UNIQUE_NAME_PROPERTY = "Unique Name"
|
||||
NCRYPT_USE_CONTEXT_PROPERTY = "Use Context"
|
||||
NCRYPT_USE_COUNT_ENABLED_PROPERTY = "Enabled Use Count"
|
||||
NCRYPT_USE_COUNT_PROPERTY = "Use Count"
|
||||
NCRYPT_USER_CERTSTORE_PROPERTY = "SmartCardUserCertStore"
|
||||
NCRYPT_VERSION_PROPERTY = "Version"
|
||||
NCRYPT_WINDOW_HANDLE_PROPERTY = "HWND Handle"
|
||||
|
||||
// BCRYPT BLOB Types
|
||||
BCRYPT_DH_PRIVATE_BLOB = "DHPRIVATEBLOB"
|
||||
BCRYPT_DH_PUBLIC_BLOB = "DHPUBLICBLOB"
|
||||
BCRYPT_DSA_PRIVATE_BLOB = "DSAPRIVATEBLOB"
|
||||
BCRYPT_DSA_PUBLIC_BLOB = "DSAPUBLICBLOB"
|
||||
BCRYPT_ECCPRIVATE_BLOB = "ECCPRIVATEBLOB"
|
||||
BCRYPT_ECCPUBLIC_BLOB = "ECCPUBLICBLOB"
|
||||
BCRYPT_PRIVATE_KEY_BLOB = "PRIVATEBLOB"
|
||||
BCRYPT_PUBLIC_KEY_BLOB = "PUBLICBLOB"
|
||||
BCRYPT_RSAFULLPRIVATE_BLOB = "RSAFULLPRIVATEBLOB"
|
||||
BCRYPT_RSAPRIVATE_BLOB = "RSAPRIVATEBLOB"
|
||||
BCRYPT_RSAPUBLIC_BLOB = "RSAPUBLICBLOB"
|
||||
|
||||
// BCRYPT Algorithm Names
|
||||
BCRYPT_3DES_ALGORITHM = "3DES"
|
||||
BCRYPT_AES_ALGORITHM = "AES"
|
||||
BCRYPT_DES_ALGORITHM = "DES"
|
||||
BCRYPT_DSA_ALGORITHM = "DSA"
|
||||
BCRYPT_ECDH_P256_ALGORITHM = "ECDH_P256"
|
||||
BCRYPT_ECDH_P384_ALGORITHM = "ECDH_P384"
|
||||
BCRYPT_ECDSA_P256_ALGORITHM = "ECDSA_P256"
|
||||
BCRYPT_ECDSA_P384_ALGORITHM = "ECDSA_P384"
|
||||
BCRYPT_ECDSA_P521_ALGORITHM = "ECDSA_P521"
|
||||
BCRYPT_MD2_ALGORITHM = "MD2"
|
||||
BCRYPT_MD4_ALGORITHM = "MD4"
|
||||
BCRYPT_MD5_ALGORITHM = "MD5"
|
||||
BCRYPT_RC2_ALGORITHM = "RC2"
|
||||
BCRYPT_RC4_ALGORITHM = "RC4"
|
||||
BCRYPT_RNG_ALGORITHM = "RNG"
|
||||
BCRYPT_RSA_ALGORITHM = "RSA"
|
||||
BCRYPT_RSA_SIGN_ALGORITHM = "RSA_SIGN"
|
||||
BCRYPT_SHA1_ALGORITHM = "SHA1"
|
||||
BCRYPT_SHA256_ALGORITHM = "SHA256"
|
||||
BCRYPT_SHA384_ALGORITHM = "SHA384"
|
||||
BCRYPT_SHA512_ALGORITHM = "SHA512"
|
||||
BCRYPT_SP800108_CTR_HMAC_ALGORITHM = "SP800_108_CTR_HMAC"
|
||||
BCRYPT_SP80056A_CONCAT_ALGORITHM = "SP800_56A_CONCAT"
|
||||
BCRYPT_PBKDF2_ALGORITHM = "PBKDF2"
|
||||
BCRYPT_ECDSA_ALGORITHM = "ECDSA"
|
||||
BCRYPT_ECDH_ALGORITHM = "ECDH"
|
||||
BCRYPT_XTS_AES_ALGORITHM = "XTS-AES"
|
||||
)
|
||||
67
vendor/github.com/tailscale/certstore/syscall_windows.go
generated
vendored
Normal file
67
vendor/github.com/tailscale/certstore/syscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
package certstore
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
||||
|
||||
type certStoreHandle windows.Handle
|
||||
|
||||
// CryptoAPI APIs from wincrypt.h
|
||||
|
||||
type cryptProviderHandle windows.Handle // HCRYPTPROV
|
||||
type cryptHashHandle windows.Handle // HCRYPTHASH
|
||||
type cryptKeyHandle windows.Handle // HCRYPTKEY
|
||||
type cryptAlgorithm uint32 // ALG_ID
|
||||
|
||||
const (
|
||||
// Values for ALG_ID
|
||||
CALG_SHA_256 cryptAlgorithm = 0x0000800c
|
||||
CALG_SHA_384 cryptAlgorithm = 0x0000800d
|
||||
CALG_SHA_512 cryptAlgorithm = 0x0000800e
|
||||
|
||||
// Values of the dwParam for CryptGetHashParam and/or CryptSetHashParam
|
||||
HP_ALGID = 0x1
|
||||
HP_HASHVAL = 0x2
|
||||
HP_HASHSIZE = 0x4
|
||||
HP_HMAC_INFO = 0x5
|
||||
|
||||
// Values of the dwParam for CryptGetProvParam and/or CryptSetProvParam
|
||||
PP_NAME = 4
|
||||
PP_CONTAINER = 6
|
||||
PP_PROVTYPE = 16
|
||||
)
|
||||
|
||||
// https://learn.microsoft.com/en-ca/windows/win32/api/wincrypt/nf-wincrypt-cryptcreatehash
|
||||
//sys cryptCreateHash(hProv cryptProviderHandle, Algid cryptAlgorithm, hKey cryptKeyHandle, dwFlags uint32, phHash *cryptHashHandle) (err error) = advapi32.CryptCreateHash
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-cryptdestroyhash
|
||||
//sys cryptDestroyHash(hHash cryptHashHandle) (err error) = advapi32.CryptDestroyHash
|
||||
// https://learn.microsoft.com/en-ca/windows/win32/api/wincrypt/nf-wincrypt-cryptgethashparam
|
||||
//sys cryptGetHashParam(hHash cryptHashHandle, dwParam uint32, pbData unsafe.Pointer, pdwDataLen *uint32, dwFlags uint32) (err error) = advapi32.CryptGetHashParam
|
||||
// https://learn.microsoft.com/en-ca/windows/win32/api/wincrypt/nf-wincrypt-cryptsethashparam
|
||||
//sys cryptSetHashParam(hHash cryptHashHandle, dwParam uint32, pbData unsafe.Pointer, dwFlags uint32) (err error) = advapi32.CryptSetHashParam
|
||||
// https://learn.microsoft.com/en-ca/windows/win32/api/wincrypt/nf-wincrypt-cryptsignhashw
|
||||
//sys cryptSignHash(hHash cryptHashHandle, dwKeySpec uint32, szDescription *uint16, dwFlags uint32, pbSignature *byte, pdwSigLen *uint32) (err error) = advapi32.CryptSignHashW
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-cryptgetprovparam
|
||||
//sys cryptGetProvParam(hProv cryptProviderHandle, dwParam uint32, pbData *byte, pdwDataLen *uint32, dwFlags uint32) (err error) = advapi32.CryptGetProvParam
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-cryptsetprovparam
|
||||
//sys cryptSetProvParam(hProv cryptProviderHandle, dwParam uint32, pbData *byte, dwFlags uint32) (err error) = advapi32.CryptSetProvParam
|
||||
|
||||
// CNG APIs from bcrypt.h and ncrypt.h.
|
||||
const (
|
||||
// Flags for NCryptSignHash and others
|
||||
BCRYPT_PAD_NONE = 0x00000001
|
||||
BCRYPT_PAD_PKCS1 = 0x00000002
|
||||
BCRYPT_PAD_OAEP = 0x00000004
|
||||
BCRYPT_PAD_PSS = 0x00000008
|
||||
)
|
||||
|
||||
type nCryptKeyHandle windows.Handle
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/ncrypt/nf-ncrypt-ncryptsignhash
|
||||
//sys nCryptSignHash(hKey nCryptKeyHandle, pPaddingInfo unsafe.Pointer, pbHashValue *byte, cbHashValue uint32, pbSignature *byte, cbSignature uint32, pcbResult *uint32, dwFlags uint32) (ret error) = ncrypt.NCryptSignHash
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/ncrypt/nf-ncrypt-ncryptdeletekey
|
||||
//sys nCryptDeleteKey(hKey nCryptKeyHandle, dwFlags uint32) (ret error) = ncrypt.NCryptDeleteKey
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/ncrypt/nf-ncrypt-ncryptfreeobject
|
||||
//sys nCryptFreeObject(hObject windows.Handle) (ret error) = ncrypt.NCryptFreeObject
|
||||
134
vendor/github.com/tailscale/certstore/zsyscall_windows.go
generated
vendored
Normal file
134
vendor/github.com/tailscale/certstore/zsyscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,134 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package certstore
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
modncrypt = windows.NewLazySystemDLL("ncrypt.dll")
|
||||
|
||||
procCryptCreateHash = modadvapi32.NewProc("CryptCreateHash")
|
||||
procCryptDestroyHash = modadvapi32.NewProc("CryptDestroyHash")
|
||||
procCryptGetHashParam = modadvapi32.NewProc("CryptGetHashParam")
|
||||
procCryptGetProvParam = modadvapi32.NewProc("CryptGetProvParam")
|
||||
procCryptSetHashParam = modadvapi32.NewProc("CryptSetHashParam")
|
||||
procCryptSetProvParam = modadvapi32.NewProc("CryptSetProvParam")
|
||||
procCryptSignHashW = modadvapi32.NewProc("CryptSignHashW")
|
||||
procNCryptDeleteKey = modncrypt.NewProc("NCryptDeleteKey")
|
||||
procNCryptFreeObject = modncrypt.NewProc("NCryptFreeObject")
|
||||
procNCryptSignHash = modncrypt.NewProc("NCryptSignHash")
|
||||
)
|
||||
|
||||
func cryptCreateHash(hProv cryptProviderHandle, Algid cryptAlgorithm, hKey cryptKeyHandle, dwFlags uint32, phHash *cryptHashHandle) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procCryptCreateHash.Addr(), 5, uintptr(hProv), uintptr(Algid), uintptr(hKey), uintptr(dwFlags), uintptr(unsafe.Pointer(phHash)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cryptDestroyHash(hHash cryptHashHandle) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procCryptDestroyHash.Addr(), 1, uintptr(hHash), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cryptGetHashParam(hHash cryptHashHandle, dwParam uint32, pbData unsafe.Pointer, pdwDataLen *uint32, dwFlags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procCryptGetHashParam.Addr(), 5, uintptr(hHash), uintptr(dwParam), uintptr(pbData), uintptr(unsafe.Pointer(pdwDataLen)), uintptr(dwFlags), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cryptGetProvParam(hProv cryptProviderHandle, dwParam uint32, pbData *byte, pdwDataLen *uint32, dwFlags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procCryptGetProvParam.Addr(), 5, uintptr(hProv), uintptr(dwParam), uintptr(unsafe.Pointer(pbData)), uintptr(unsafe.Pointer(pdwDataLen)), uintptr(dwFlags), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cryptSetHashParam(hHash cryptHashHandle, dwParam uint32, pbData unsafe.Pointer, dwFlags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procCryptSetHashParam.Addr(), 4, uintptr(hHash), uintptr(dwParam), uintptr(pbData), uintptr(dwFlags), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cryptSetProvParam(hProv cryptProviderHandle, dwParam uint32, pbData *byte, dwFlags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procCryptSetProvParam.Addr(), 4, uintptr(hProv), uintptr(dwParam), uintptr(unsafe.Pointer(pbData)), uintptr(dwFlags), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cryptSignHash(hHash cryptHashHandle, dwKeySpec uint32, szDescription *uint16, dwFlags uint32, pbSignature *byte, pdwSigLen *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procCryptSignHashW.Addr(), 6, uintptr(hHash), uintptr(dwKeySpec), uintptr(unsafe.Pointer(szDescription)), uintptr(dwFlags), uintptr(unsafe.Pointer(pbSignature)), uintptr(unsafe.Pointer(pdwSigLen)))
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nCryptDeleteKey(hKey nCryptKeyHandle, dwFlags uint32) (ret error) {
|
||||
r0, _, _ := syscall.Syscall(procNCryptDeleteKey.Addr(), 2, uintptr(hKey), uintptr(dwFlags), 0)
|
||||
if r0 != 0 {
|
||||
ret = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nCryptFreeObject(hObject windows.Handle) (ret error) {
|
||||
r0, _, _ := syscall.Syscall(procNCryptFreeObject.Addr(), 1, uintptr(hObject), 0, 0)
|
||||
if r0 != 0 {
|
||||
ret = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nCryptSignHash(hKey nCryptKeyHandle, pPaddingInfo unsafe.Pointer, pbHashValue *byte, cbHashValue uint32, pbSignature *byte, cbSignature uint32, pcbResult *uint32, dwFlags uint32) (ret error) {
|
||||
r0, _, _ := syscall.Syscall9(procNCryptSignHash.Addr(), 8, uintptr(hKey), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(pbHashValue)), uintptr(cbHashValue), uintptr(unsafe.Pointer(pbSignature)), uintptr(cbSignature), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0)
|
||||
if r0 != 0 {
|
||||
ret = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
1
vendor/github.com/tailscale/go-winio/.gitattributes
generated
vendored
Normal file
1
vendor/github.com/tailscale/go-winio/.gitattributes
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
* text=auto eol=lf
|
||||
10
vendor/github.com/tailscale/go-winio/.gitignore
generated
vendored
Normal file
10
vendor/github.com/tailscale/go-winio/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
.vscode/
|
||||
|
||||
*.exe
|
||||
|
||||
# testing
|
||||
testdata
|
||||
|
||||
# go workspaces
|
||||
go.work
|
||||
go.work.sum
|
||||
150
vendor/github.com/tailscale/go-winio/.golangci.yml
generated
vendored
Normal file
150
vendor/github.com/tailscale/go-winio/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,150 @@
|
||||
run:
|
||||
skip-dirs:
|
||||
- pkg/etw/sample
|
||||
|
||||
linters:
|
||||
enable:
|
||||
# style
|
||||
- containedctx # struct contains a context
|
||||
- dupl # duplicate code
|
||||
- errname # erorrs are named correctly
|
||||
- nolintlint # "//nolint" directives are properly explained
|
||||
- revive # golint replacement
|
||||
- unconvert # unnecessary conversions
|
||||
- wastedassign
|
||||
|
||||
# bugs, performance, unused, etc ...
|
||||
- contextcheck # function uses a non-inherited context
|
||||
- errorlint # errors not wrapped for 1.13
|
||||
- exhaustive # check exhaustiveness of enum switch statements
|
||||
- gofmt # files are gofmt'ed
|
||||
- gosec # security
|
||||
- nilerr # returns nil even with non-nil error
|
||||
- thelper # test helpers without t.Helper()
|
||||
- unparam # unused function params
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
# err is very often shadowed in nested scopes
|
||||
- linters:
|
||||
- govet
|
||||
text: '^shadow: declaration of "err" shadows declaration'
|
||||
|
||||
# ignore long lines for skip autogen directives
|
||||
- linters:
|
||||
- revive
|
||||
text: "^line-length-limit: "
|
||||
source: "^//(go:generate|sys) "
|
||||
|
||||
#TODO: remove after upgrading to go1.18
|
||||
# ignore comment spacing for nolint and sys directives
|
||||
- linters:
|
||||
- revive
|
||||
text: "^comment-spacings: no space between comment delimiter and comment text"
|
||||
source: "//(cspell:|nolint:|sys |todo)"
|
||||
|
||||
# not on go 1.18 yet, so no any
|
||||
- linters:
|
||||
- revive
|
||||
text: "^use-any: since GO 1.18 'interface{}' can be replaced by 'any'"
|
||||
|
||||
# allow unjustified ignores of error checks in defer statements
|
||||
- linters:
|
||||
- nolintlint
|
||||
text: "^directive `//nolint:errcheck` should provide explanation"
|
||||
source: '^\s*defer '
|
||||
|
||||
# allow unjustified ignores of error lints for io.EOF
|
||||
- linters:
|
||||
- nolintlint
|
||||
text: "^directive `//nolint:errorlint` should provide explanation"
|
||||
source: '[=|!]= io.EOF'
|
||||
|
||||
|
||||
linters-settings:
|
||||
exhaustive:
|
||||
default-signifies-exhaustive: true
|
||||
govet:
|
||||
enable-all: true
|
||||
disable:
|
||||
# struct order is often for Win32 compat
|
||||
# also, ignore pointer bytes/GC issues for now until performance becomes an issue
|
||||
- fieldalignment
|
||||
check-shadowing: true
|
||||
nolintlint:
|
||||
allow-leading-space: false
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
revive:
|
||||
# revive is more configurable than static check, so likely the preferred alternative to static-check
|
||||
# (once the perf issue is solved: https://github.com/golangci/golangci-lint/issues/2997)
|
||||
enable-all-rules:
|
||||
true
|
||||
# https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md
|
||||
rules:
|
||||
# rules with required arguments
|
||||
- name: argument-limit
|
||||
disabled: true
|
||||
- name: banned-characters
|
||||
disabled: true
|
||||
- name: cognitive-complexity
|
||||
disabled: true
|
||||
- name: cyclomatic
|
||||
disabled: true
|
||||
- name: file-header
|
||||
disabled: true
|
||||
- name: function-length
|
||||
disabled: true
|
||||
- name: function-result-limit
|
||||
disabled: true
|
||||
- name: max-public-structs
|
||||
disabled: true
|
||||
# geneally annoying rules
|
||||
- name: add-constant # complains about any and all strings and integers
|
||||
disabled: true
|
||||
- name: confusing-naming # we frequently use "Foo()" and "foo()" together
|
||||
disabled: true
|
||||
- name: flag-parameter # excessive, and a common idiom we use
|
||||
disabled: true
|
||||
- name: unhandled-error # warns over common fmt.Print* and io.Close; rely on errcheck instead
|
||||
disabled: true
|
||||
# general config
|
||||
- name: line-length-limit
|
||||
arguments:
|
||||
- 140
|
||||
- name: var-naming
|
||||
arguments:
|
||||
- []
|
||||
- - CID
|
||||
- CRI
|
||||
- CTRD
|
||||
- DACL
|
||||
- DLL
|
||||
- DOS
|
||||
- ETW
|
||||
- FSCTL
|
||||
- GCS
|
||||
- GMSA
|
||||
- HCS
|
||||
- HV
|
||||
- IO
|
||||
- LCOW
|
||||
- LDAP
|
||||
- LPAC
|
||||
- LTSC
|
||||
- MMIO
|
||||
- NT
|
||||
- OCI
|
||||
- PMEM
|
||||
- PWSH
|
||||
- RX
|
||||
- SACl
|
||||
- SID
|
||||
- SMB
|
||||
- TX
|
||||
- VHD
|
||||
- VHDX
|
||||
- VMID
|
||||
- VPCI
|
||||
- WCOW
|
||||
- WIM
|
||||
1
vendor/github.com/tailscale/go-winio/CODEOWNERS
generated
vendored
Normal file
1
vendor/github.com/tailscale/go-winio/CODEOWNERS
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
* @microsoft/containerplat
|
||||
22
vendor/github.com/tailscale/go-winio/LICENSE
generated
vendored
Normal file
22
vendor/github.com/tailscale/go-winio/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Microsoft
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
95
vendor/github.com/tailscale/go-winio/README.md
generated
vendored
Normal file
95
vendor/github.com/tailscale/go-winio/README.md
generated
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
# go-winio
|
||||
|
||||
This repository contains utilities for efficiently performing Win32 IO operations in
|
||||
Go. Currently, this is focused on accessing named pipes and other file handles, and
|
||||
for using named pipes as a net transport.
|
||||
|
||||
This code relies on IO completion ports to avoid blocking IO on system threads, allowing Go
|
||||
to reuse the thread to schedule another goroutine. This limits support to Windows Vista and
|
||||
newer operating systems. This is similar to the implementation of network sockets in Go's net
|
||||
package.
|
||||
|
||||
Please see the LICENSE file for licensing information.
|
||||
|
||||
## Fork from original module
|
||||
|
||||
This is a fork from the [original module](https://github.com/microsoft/go-winio).
|
||||
|
||||
If Tailscale's PRs merge upstream, we will ideally then be able to discontinue this fork.
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions.
|
||||
Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that
|
||||
you have the right to, and actually do, grant us the rights to use your contribution.
|
||||
For details, visit [Microsoft CLA](https://cla.microsoft.com).
|
||||
|
||||
When you submit a pull request, a CLA-bot will automatically determine whether you need to
|
||||
provide a CLA and decorate the PR appropriately (e.g., label, comment).
|
||||
Simply follow the instructions provided by the bot.
|
||||
You will only need to do this once across all repos using our CLA.
|
||||
|
||||
Additionally, the pull request pipeline requires the following steps to be performed before
|
||||
mergining.
|
||||
|
||||
### Code Sign-Off
|
||||
|
||||
We require that contributors sign their commits using [`git commit --signoff`][git-commit-s]
|
||||
to certify they either authored the work themselves or otherwise have permission to use it in this project.
|
||||
|
||||
A range of commits can be signed off using [`git rebase --signoff`][git-rebase-s].
|
||||
|
||||
Please see [the developer certificate](https://developercertificate.org) for more info,
|
||||
as well as to make sure that you can attest to the rules listed.
|
||||
Our CI uses the DCO Github app to ensure that all commits in a given PR are signed-off.
|
||||
|
||||
### Linting
|
||||
|
||||
Code must pass a linting stage, which uses [`golangci-lint`][lint].
|
||||
The linting settings are stored in [`.golangci.yaml`](./.golangci.yaml), and can be run
|
||||
automatically with VSCode by adding the following to your workspace or folder settings:
|
||||
|
||||
```json
|
||||
"go.lintTool": "golangci-lint",
|
||||
"go.lintOnSave": "package",
|
||||
```
|
||||
|
||||
Additional editor [integrations options are also available][lint-ide].
|
||||
|
||||
Alternatively, `golangci-lint` can be [installed locally][lint-install] and run from the repo root:
|
||||
|
||||
```shell
|
||||
# use . or specify a path to only lint a package
|
||||
# to show all lint errors, use flags "--max-issues-per-linter=0 --max-same-issues=0"
|
||||
> golangci-lint run ./...
|
||||
```
|
||||
|
||||
### Go Generate
|
||||
|
||||
The pipeline checks that auto-generated code, via `go generate`, are up to date.
|
||||
|
||||
This can be done for the entire repo:
|
||||
|
||||
```shell
|
||||
> go generate ./...
|
||||
```
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Special Thanks
|
||||
|
||||
Thanks to [natefinch][natefinch] for the inspiration for this library.
|
||||
See [npipe](https://github.com/natefinch/npipe) for another named pipe implementation.
|
||||
|
||||
[lint]: https://golangci-lint.run/
|
||||
[lint-ide]: https://golangci-lint.run/usage/integrations/#editor-integration
|
||||
[lint-install]: https://golangci-lint.run/usage/install/#local-installation
|
||||
|
||||
[git-commit-s]: https://git-scm.com/docs/git-commit#Documentation/git-commit.txt--s
|
||||
[git-rebase-s]: https://git-scm.com/docs/git-rebase#Documentation/git-rebase.txt---signoff
|
||||
|
||||
[natefinch]: https://github.com/natefinch
|
||||
41
vendor/github.com/tailscale/go-winio/SECURITY.md
generated
vendored
Normal file
41
vendor/github.com/tailscale/go-winio/SECURITY.md
generated
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->
|
||||
|
||||
## Security
|
||||
|
||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||
|
||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||
|
||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
|
||||
|
||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
|
||||
|
||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
|
||||
|
||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||
|
||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||
* Full paths of source file(s) related to the manifestation of the issue
|
||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||
* Any special configuration required to reproduce the issue
|
||||
* Step-by-step instructions to reproduce the issue
|
||||
* Proof-of-concept or exploit code (if possible)
|
||||
* Impact of the issue, including how an attacker might exploit the issue
|
||||
|
||||
This information will help us triage your report more quickly.
|
||||
|
||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
|
||||
|
||||
## Preferred Languages
|
||||
|
||||
We prefer all communications to be in English.
|
||||
|
||||
## Policy
|
||||
|
||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
|
||||
|
||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
||||
287
vendor/github.com/tailscale/go-winio/backup.go
generated
vendored
Normal file
287
vendor/github.com/tailscale/go-winio/backup.go
generated
vendored
Normal file
@@ -0,0 +1,287 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"unicode/utf16"
|
||||
|
||||
"github.com/tailscale/go-winio/internal/fs"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
|
||||
//sys backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite
|
||||
|
||||
const (
|
||||
BackupData = uint32(iota + 1)
|
||||
BackupEaData
|
||||
BackupSecurity
|
||||
BackupAlternateData
|
||||
BackupLink
|
||||
BackupPropertyData
|
||||
BackupObjectId //revive:disable-line:var-naming ID, not Id
|
||||
BackupReparseData
|
||||
BackupSparseBlock
|
||||
BackupTxfsData
|
||||
)
|
||||
|
||||
const (
|
||||
StreamSparseAttributes = uint32(8)
|
||||
)
|
||||
|
||||
//nolint:revive // var-naming: ALL_CAPS
|
||||
const (
|
||||
WRITE_DAC = windows.WRITE_DAC
|
||||
WRITE_OWNER = windows.WRITE_OWNER
|
||||
ACCESS_SYSTEM_SECURITY = windows.ACCESS_SYSTEM_SECURITY
|
||||
)
|
||||
|
||||
// BackupHeader represents a backup stream of a file.
|
||||
type BackupHeader struct {
|
||||
//revive:disable-next-line:var-naming ID, not Id
|
||||
Id uint32 // The backup stream ID
|
||||
Attributes uint32 // Stream attributes
|
||||
Size int64 // The size of the stream in bytes
|
||||
Name string // The name of the stream (for BackupAlternateData only).
|
||||
Offset int64 // The offset of the stream in the file (for BackupSparseBlock only).
|
||||
}
|
||||
|
||||
type win32StreamID struct {
|
||||
StreamID uint32
|
||||
Attributes uint32
|
||||
Size uint64
|
||||
NameSize uint32
|
||||
}
|
||||
|
||||
// BackupStreamReader reads from a stream produced by the BackupRead Win32 API and produces a series
|
||||
// of BackupHeader values.
|
||||
type BackupStreamReader struct {
|
||||
r io.Reader
|
||||
bytesLeft int64
|
||||
}
|
||||
|
||||
// NewBackupStreamReader produces a BackupStreamReader from any io.Reader.
|
||||
func NewBackupStreamReader(r io.Reader) *BackupStreamReader {
|
||||
return &BackupStreamReader{r, 0}
|
||||
}
|
||||
|
||||
// Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if
|
||||
// it was not completely read.
|
||||
func (r *BackupStreamReader) Next() (*BackupHeader, error) {
|
||||
if r.bytesLeft > 0 { //nolint:nestif // todo: flatten this
|
||||
if s, ok := r.r.(io.Seeker); ok {
|
||||
// Make sure Seek on io.SeekCurrent sometimes succeeds
|
||||
// before trying the actual seek.
|
||||
if _, err := s.Seek(0, io.SeekCurrent); err == nil {
|
||||
if _, err = s.Seek(r.bytesLeft, io.SeekCurrent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.bytesLeft = 0
|
||||
}
|
||||
}
|
||||
if _, err := io.Copy(io.Discard, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var wsi win32StreamID
|
||||
if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr := &BackupHeader{
|
||||
Id: wsi.StreamID,
|
||||
Attributes: wsi.Attributes,
|
||||
Size: int64(wsi.Size),
|
||||
}
|
||||
if wsi.NameSize != 0 {
|
||||
name := make([]uint16, int(wsi.NameSize/2))
|
||||
if err := binary.Read(r.r, binary.LittleEndian, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr.Name = windows.UTF16ToString(name)
|
||||
}
|
||||
if wsi.StreamID == BackupSparseBlock {
|
||||
if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr.Size -= 8
|
||||
}
|
||||
r.bytesLeft = hdr.Size
|
||||
return hdr, nil
|
||||
}
|
||||
|
||||
// Read reads from the current backup stream.
|
||||
func (r *BackupStreamReader) Read(b []byte) (int, error) {
|
||||
if r.bytesLeft == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if int64(len(b)) > r.bytesLeft {
|
||||
b = b[:r.bytesLeft]
|
||||
}
|
||||
n, err := r.r.Read(b)
|
||||
r.bytesLeft -= int64(n)
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
} else if r.bytesLeft == 0 && err == nil {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// BackupStreamWriter writes a stream compatible with the BackupWrite Win32 API.
|
||||
type BackupStreamWriter struct {
|
||||
w io.Writer
|
||||
bytesLeft int64
|
||||
}
|
||||
|
||||
// NewBackupStreamWriter produces a BackupStreamWriter on top of an io.Writer.
|
||||
func NewBackupStreamWriter(w io.Writer) *BackupStreamWriter {
|
||||
return &BackupStreamWriter{w, 0}
|
||||
}
|
||||
|
||||
// WriteHeader writes the next backup stream header and prepares for calls to Write().
|
||||
func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error {
|
||||
if w.bytesLeft != 0 {
|
||||
return fmt.Errorf("missing %d bytes", w.bytesLeft)
|
||||
}
|
||||
name := utf16.Encode([]rune(hdr.Name))
|
||||
wsi := win32StreamID{
|
||||
StreamID: hdr.Id,
|
||||
Attributes: hdr.Attributes,
|
||||
Size: uint64(hdr.Size),
|
||||
NameSize: uint32(len(name) * 2),
|
||||
}
|
||||
if hdr.Id == BackupSparseBlock {
|
||||
// Include space for the int64 block offset
|
||||
wsi.Size += 8
|
||||
}
|
||||
if err := binary.Write(w.w, binary.LittleEndian, &wsi); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(name) != 0 {
|
||||
if err := binary.Write(w.w, binary.LittleEndian, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if hdr.Id == BackupSparseBlock {
|
||||
if err := binary.Write(w.w, binary.LittleEndian, hdr.Offset); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.bytesLeft = hdr.Size
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes to the current backup stream.
|
||||
func (w *BackupStreamWriter) Write(b []byte) (int, error) {
|
||||
if w.bytesLeft < int64(len(b)) {
|
||||
return 0, fmt.Errorf("too many bytes by %d", int64(len(b))-w.bytesLeft)
|
||||
}
|
||||
n, err := w.w.Write(b)
|
||||
w.bytesLeft -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// BackupFileReader provides an io.ReadCloser interface on top of the BackupRead Win32 API.
|
||||
type BackupFileReader struct {
|
||||
f *os.File
|
||||
includeSecurity bool
|
||||
ctx uintptr
|
||||
}
|
||||
|
||||
// NewBackupFileReader returns a new BackupFileReader from a file handle. If includeSecurity is true,
|
||||
// Read will attempt to read the security descriptor of the file.
|
||||
func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader {
|
||||
r := &BackupFileReader{f, includeSecurity, 0}
|
||||
return r
|
||||
}
|
||||
|
||||
// Read reads a backup stream from the file by calling the Win32 API BackupRead().
|
||||
func (r *BackupFileReader) Read(b []byte) (int, error) {
|
||||
var bytesRead uint32
|
||||
err := backupRead(windows.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(r.f)
|
||||
if bytesRead == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return int(bytesRead), nil
|
||||
}
|
||||
|
||||
// Close frees Win32 resources associated with the BackupFileReader. It does not close
|
||||
// the underlying file.
|
||||
func (r *BackupFileReader) Close() error {
|
||||
if r.ctx != 0 {
|
||||
_ = backupRead(windows.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
|
||||
runtime.KeepAlive(r.f)
|
||||
r.ctx = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackupFileWriter provides an io.WriteCloser interface on top of the BackupWrite Win32 API.
|
||||
type BackupFileWriter struct {
|
||||
f *os.File
|
||||
includeSecurity bool
|
||||
ctx uintptr
|
||||
}
|
||||
|
||||
// NewBackupFileWriter returns a new BackupFileWriter from a file handle. If includeSecurity is true,
|
||||
// Write() will attempt to restore the security descriptor from the stream.
|
||||
func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter {
|
||||
w := &BackupFileWriter{f, includeSecurity, 0}
|
||||
return w
|
||||
}
|
||||
|
||||
// Write restores a portion of the file using the provided backup stream.
|
||||
func (w *BackupFileWriter) Write(b []byte) (int, error) {
|
||||
var bytesWritten uint32
|
||||
err := backupWrite(windows.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(w.f)
|
||||
if int(bytesWritten) != len(b) {
|
||||
return int(bytesWritten), errors.New("not all bytes could be written")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Close frees Win32 resources associated with the BackupFileWriter. It does not
|
||||
// close the underlying file.
|
||||
func (w *BackupFileWriter) Close() error {
|
||||
if w.ctx != 0 {
|
||||
_ = backupWrite(windows.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
|
||||
runtime.KeepAlive(w.f)
|
||||
w.ctx = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenForBackup opens a file or directory, potentially skipping access checks if the backup
|
||||
// or restore privileges have been acquired.
|
||||
//
|
||||
// If the file opened was a directory, it cannot be used with Readdir().
|
||||
func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) {
|
||||
h, err := fs.CreateFile(path,
|
||||
fs.AccessMask(access),
|
||||
fs.FileShareMode(share),
|
||||
nil,
|
||||
fs.FileCreationDisposition(createmode),
|
||||
fs.FILE_FLAG_BACKUP_SEMANTICS|fs.FILE_FLAG_OPEN_REPARSE_POINT,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
err = &os.PathError{Op: "open", Path: path, Err: err}
|
||||
return nil, err
|
||||
}
|
||||
return os.NewFile(uintptr(h), path), nil
|
||||
}
|
||||
22
vendor/github.com/tailscale/go-winio/doc.go
generated
vendored
Normal file
22
vendor/github.com/tailscale/go-winio/doc.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
// This package provides utilities for efficiently performing Win32 IO operations in Go.
|
||||
// Currently, this package is provides support for genreal IO and management of
|
||||
// - named pipes
|
||||
// - files
|
||||
// - [Hyper-V sockets]
|
||||
//
|
||||
// This code is similar to Go's [net] package, and uses IO completion ports to avoid
|
||||
// blocking IO on system threads, allowing Go to reuse the thread to schedule other goroutines.
|
||||
//
|
||||
// This limits support to Windows Vista and newer operating systems.
|
||||
//
|
||||
// Additionally, this package provides support for:
|
||||
// - creating and managing GUIDs
|
||||
// - writing to [ETW]
|
||||
// - opening and manageing VHDs
|
||||
// - parsing [Windows Image files]
|
||||
// - auto-generating Win32 API code
|
||||
//
|
||||
// [Hyper-V sockets]: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service
|
||||
// [ETW]: https://docs.microsoft.com/en-us/windows-hardware/drivers/devtest/event-tracing-for-windows--etw-
|
||||
// [Windows Image files]: https://docs.microsoft.com/en-us/windows-hardware/manufacture/desktop/work-with-windows-images
|
||||
package winio
|
||||
137
vendor/github.com/tailscale/go-winio/ea.go
generated
vendored
Normal file
137
vendor/github.com/tailscale/go-winio/ea.go
generated
vendored
Normal file
@@ -0,0 +1,137 @@
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type fileFullEaInformation struct {
|
||||
NextEntryOffset uint32
|
||||
Flags uint8
|
||||
NameLength uint8
|
||||
ValueLength uint16
|
||||
}
|
||||
|
||||
var (
|
||||
fileFullEaInformationSize = binary.Size(&fileFullEaInformation{})
|
||||
|
||||
errInvalidEaBuffer = errors.New("invalid extended attribute buffer")
|
||||
errEaNameTooLarge = errors.New("extended attribute name too large")
|
||||
errEaValueTooLarge = errors.New("extended attribute value too large")
|
||||
)
|
||||
|
||||
// ExtendedAttribute represents a single Windows EA.
|
||||
type ExtendedAttribute struct {
|
||||
Name string
|
||||
Value []byte
|
||||
Flags uint8
|
||||
}
|
||||
|
||||
func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
|
||||
var info fileFullEaInformation
|
||||
err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
err = errInvalidEaBuffer
|
||||
return ea, nb, err
|
||||
}
|
||||
|
||||
nameOffset := fileFullEaInformationSize
|
||||
nameLen := int(info.NameLength)
|
||||
valueOffset := nameOffset + int(info.NameLength) + 1
|
||||
valueLen := int(info.ValueLength)
|
||||
nextOffset := int(info.NextEntryOffset)
|
||||
if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
|
||||
err = errInvalidEaBuffer
|
||||
return ea, nb, err
|
||||
}
|
||||
|
||||
ea.Name = string(b[nameOffset : nameOffset+nameLen])
|
||||
ea.Value = b[valueOffset : valueOffset+valueLen]
|
||||
ea.Flags = info.Flags
|
||||
if info.NextEntryOffset != 0 {
|
||||
nb = b[info.NextEntryOffset:]
|
||||
}
|
||||
return ea, nb, err
|
||||
}
|
||||
|
||||
// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
|
||||
// buffer retrieved from BackupRead, ZwQueryEaFile, etc.
|
||||
func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
|
||||
for len(b) != 0 {
|
||||
ea, nb, err := parseEa(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eas = append(eas, ea)
|
||||
b = nb
|
||||
}
|
||||
return eas, err
|
||||
}
|
||||
|
||||
func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {
|
||||
if int(uint8(len(ea.Name))) != len(ea.Name) {
|
||||
return errEaNameTooLarge
|
||||
}
|
||||
if int(uint16(len(ea.Value))) != len(ea.Value) {
|
||||
return errEaValueTooLarge
|
||||
}
|
||||
entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value))
|
||||
withPadding := (entrySize + 3) &^ 3
|
||||
nextOffset := uint32(0)
|
||||
if !last {
|
||||
nextOffset = withPadding
|
||||
}
|
||||
info := fileFullEaInformation{
|
||||
NextEntryOffset: nextOffset,
|
||||
Flags: ea.Flags,
|
||||
NameLength: uint8(len(ea.Name)),
|
||||
ValueLength: uint16(len(ea.Value)),
|
||||
}
|
||||
|
||||
err := binary.Write(buf, binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte(ea.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = buf.WriteByte(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write(ea.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION
|
||||
// buffer for use with BackupWrite, ZwSetEaFile, etc.
|
||||
func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
for i := range eas {
|
||||
last := false
|
||||
if i == len(eas)-1 {
|
||||
last = true
|
||||
}
|
||||
|
||||
err := writeEa(&buf, &eas[i], last)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
338
vendor/github.com/tailscale/go-winio/file.go
generated
vendored
Normal file
338
vendor/github.com/tailscale/go-winio/file.go
generated
vendored
Normal file
@@ -0,0 +1,338 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
|
||||
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
|
||||
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||
|
||||
//todo (go1.19): switch to [atomic.Bool]
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||
|
||||
//revive:disable-next-line:predeclared Keep "new" to maintain consistency with "atomic" pkg
|
||||
func (b *atomicBool) swap(new bool) bool {
|
||||
var newInt int32
|
||||
if new {
|
||||
newInt = 1
|
||||
}
|
||||
return atomic.SwapInt32((*int32)(b), newInt) == 1
|
||||
}
|
||||
|
||||
var (
|
||||
ErrFileClosed = errors.New("file has already been closed")
|
||||
ErrTimeout = &timeoutError{}
|
||||
)
|
||||
|
||||
type timeoutError struct{}
|
||||
|
||||
func (*timeoutError) Error() string { return "i/o timeout" }
|
||||
func (*timeoutError) Timeout() bool { return true }
|
||||
func (*timeoutError) Temporary() bool { return true }
|
||||
|
||||
type timeoutChan chan struct{}
|
||||
|
||||
var ioInitOnce sync.Once
|
||||
var ioCompletionPort windows.Handle
|
||||
|
||||
// ioResult contains the result of an asynchronous IO operation.
|
||||
type ioResult struct {
|
||||
bytes uint32
|
||||
err error
|
||||
}
|
||||
|
||||
// ioOperation represents an outstanding asynchronous Win32 IO.
|
||||
type ioOperation struct {
|
||||
o windows.Overlapped
|
||||
ch chan ioResult
|
||||
}
|
||||
|
||||
func initIO() {
|
||||
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ioCompletionPort = h
|
||||
go ioCompletionProcessor(h)
|
||||
}
|
||||
|
||||
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||
type win32File struct {
|
||||
handle windows.Handle
|
||||
wg sync.WaitGroup
|
||||
wgLock sync.RWMutex
|
||||
closing atomicBool
|
||||
socket bool
|
||||
readDeadline deadlineHandler
|
||||
writeDeadline deadlineHandler
|
||||
}
|
||||
|
||||
type deadlineHandler struct {
|
||||
setLock sync.Mutex
|
||||
channel timeoutChan
|
||||
channelLock sync.RWMutex
|
||||
timer *time.Timer
|
||||
timedout atomicBool
|
||||
}
|
||||
|
||||
// makeWin32File makes a new win32File from an existing file handle.
|
||||
func makeWin32File(h windows.Handle) (*win32File, error) {
|
||||
f := &win32File{handle: h}
|
||||
ioInitOnce.Do(initIO)
|
||||
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = setFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.readDeadline.channel = make(timeoutChan)
|
||||
f.writeDeadline.channel = make(timeoutChan)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Deprecated: use NewOpenFile instead.
|
||||
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
||||
return NewOpenFile(windows.Handle(h))
|
||||
}
|
||||
|
||||
func NewOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
|
||||
// If we return the result of makeWin32File directly, it can result in an
|
||||
// interface-wrapped nil, rather than a nil interface value.
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// closeHandle closes the resources associated with a Win32 handle.
|
||||
func (f *win32File) closeHandle() {
|
||||
f.wgLock.Lock()
|
||||
// Atomically set that we are closing, releasing the resources only once.
|
||||
if !f.closing.swap(true) {
|
||||
f.wgLock.Unlock()
|
||||
// cancel all IO and wait for it to complete
|
||||
_ = cancelIoEx(f.handle, nil)
|
||||
f.wg.Wait()
|
||||
// at this point, no new IO can start
|
||||
windows.Close(f.handle)
|
||||
f.handle = 0
|
||||
} else {
|
||||
f.wgLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes a win32File.
|
||||
func (f *win32File) Close() error {
|
||||
f.closeHandle()
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsClosed checks if the file has been closed.
|
||||
func (f *win32File) IsClosed() bool {
|
||||
return f.closing.isSet()
|
||||
}
|
||||
|
||||
// prepareIO prepares for a new IO operation.
|
||||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||
func (f *win32File) prepareIO() (*ioOperation, error) {
|
||||
f.wgLock.RLock()
|
||||
if f.closing.isSet() {
|
||||
f.wgLock.RUnlock()
|
||||
return nil, ErrFileClosed
|
||||
}
|
||||
f.wg.Add(1)
|
||||
f.wgLock.RUnlock()
|
||||
c := &ioOperation{}
|
||||
c.ch = make(chan ioResult)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// ioCompletionProcessor processes completed async IOs forever.
|
||||
func ioCompletionProcessor(h windows.Handle) {
|
||||
for {
|
||||
var bytes uint32
|
||||
var key uintptr
|
||||
var op *ioOperation
|
||||
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
|
||||
if op == nil {
|
||||
panic(err)
|
||||
}
|
||||
op.ch <- ioResult{bytes, err}
|
||||
}
|
||||
}
|
||||
|
||||
// todo: helsaawy - create an asyncIO version that takes a context
|
||||
|
||||
// asyncIO processes the return value from ReadFile or WriteFile, blocking until
|
||||
// the operation has actually completed.
|
||||
func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||
if err != windows.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
|
||||
return int(bytes), err
|
||||
}
|
||||
|
||||
if f.closing.isSet() {
|
||||
_ = cancelIoEx(f.handle, &c.o)
|
||||
}
|
||||
|
||||
var timeout timeoutChan
|
||||
if d != nil {
|
||||
d.channelLock.Lock()
|
||||
timeout = d.channel
|
||||
d.channelLock.Unlock()
|
||||
}
|
||||
|
||||
var r ioResult
|
||||
select {
|
||||
case r = <-c.ch:
|
||||
err = r.err
|
||||
if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
|
||||
if f.closing.isSet() {
|
||||
err = ErrFileClosed
|
||||
}
|
||||
} else if err != nil && f.socket {
|
||||
// err is from Win32. Query the overlapped structure to get the winsock error.
|
||||
var bytes, flags uint32
|
||||
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
||||
}
|
||||
case <-timeout:
|
||||
_ = cancelIoEx(f.handle, &c.o)
|
||||
r = <-c.ch
|
||||
err = r.err
|
||||
if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
|
||||
err = ErrTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// runtime.KeepAlive is needed, as c is passed via native
|
||||
// code to ioCompletionProcessor, c must remain alive
|
||||
// until the channel read is complete.
|
||||
// todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive?
|
||||
runtime.KeepAlive(c)
|
||||
return int(r.bytes), err
|
||||
}
|
||||
|
||||
// Read reads from a file handle.
|
||||
func (f *win32File) Read(b []byte) (int, error) {
|
||||
c, err := f.prepareIO()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.readDeadline.timedout.isSet() {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
|
||||
var bytes uint32
|
||||
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
||||
n, err := f.asyncIO(c, &f.readDeadline, bytes, err)
|
||||
runtime.KeepAlive(b)
|
||||
|
||||
// Handle EOF conditions.
|
||||
if err == nil && n == 0 && len(b) != 0 {
|
||||
return 0, io.EOF
|
||||
} else if err == windows.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
|
||||
return 0, io.EOF
|
||||
} else {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes to a file handle.
|
||||
func (f *win32File) Write(b []byte) (int, error) {
|
||||
c, err := f.prepareIO()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.writeDeadline.timedout.isSet() {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
|
||||
var bytes uint32
|
||||
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
||||
n, err := f.asyncIO(c, &f.writeDeadline, bytes, err)
|
||||
runtime.KeepAlive(b)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (f *win32File) SetReadDeadline(deadline time.Time) error {
|
||||
return f.readDeadline.set(deadline)
|
||||
}
|
||||
|
||||
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
|
||||
return f.writeDeadline.set(deadline)
|
||||
}
|
||||
|
||||
func (f *win32File) Flush() error {
|
||||
return windows.FlushFileBuffers(f.handle)
|
||||
}
|
||||
|
||||
func (f *win32File) Fd() uintptr {
|
||||
return uintptr(f.handle)
|
||||
}
|
||||
|
||||
func (d *deadlineHandler) set(deadline time.Time) error {
|
||||
d.setLock.Lock()
|
||||
defer d.setLock.Unlock()
|
||||
|
||||
if d.timer != nil {
|
||||
if !d.timer.Stop() {
|
||||
<-d.channel
|
||||
}
|
||||
d.timer = nil
|
||||
}
|
||||
d.timedout.setFalse()
|
||||
|
||||
select {
|
||||
case <-d.channel:
|
||||
d.channelLock.Lock()
|
||||
d.channel = make(chan struct{})
|
||||
d.channelLock.Unlock()
|
||||
default:
|
||||
}
|
||||
|
||||
if deadline.IsZero() {
|
||||
return nil
|
||||
}
|
||||
|
||||
timeoutIO := func() {
|
||||
d.timedout.setTrue()
|
||||
close(d.channel)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
duration := deadline.Sub(now)
|
||||
if deadline.After(now) {
|
||||
// Deadline is in the future, set a timer to wait
|
||||
d.timer = time.AfterFunc(duration, timeoutIO)
|
||||
} else {
|
||||
// Deadline is in the past. Cancel all pending IO now.
|
||||
timeoutIO()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
92
vendor/github.com/tailscale/go-winio/fileinfo.go
generated
vendored
Normal file
92
vendor/github.com/tailscale/go-winio/fileinfo.go
generated
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// FileBasicInfo contains file access time and file attributes information.
|
||||
type FileBasicInfo struct {
|
||||
CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime
|
||||
FileAttributes uint32
|
||||
_ uint32 // padding
|
||||
}
|
||||
|
||||
// GetFileBasicInfo retrieves times and attributes for a file.
|
||||
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
|
||||
bi := &FileBasicInfo{}
|
||||
if err := windows.GetFileInformationByHandleEx(
|
||||
windows.Handle(f.Fd()),
|
||||
windows.FileBasicInfo,
|
||||
(*byte)(unsafe.Pointer(bi)),
|
||||
uint32(unsafe.Sizeof(*bi)),
|
||||
); err != nil {
|
||||
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return bi, nil
|
||||
}
|
||||
|
||||
// SetFileBasicInfo sets times and attributes for a file.
|
||||
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
|
||||
if err := windows.SetFileInformationByHandle(
|
||||
windows.Handle(f.Fd()),
|
||||
windows.FileBasicInfo,
|
||||
(*byte)(unsafe.Pointer(bi)),
|
||||
uint32(unsafe.Sizeof(*bi)),
|
||||
); err != nil {
|
||||
return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileStandardInfo contains extended information for the file.
|
||||
// FILE_STANDARD_INFO in WinBase.h
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_standard_info
|
||||
type FileStandardInfo struct {
|
||||
AllocationSize, EndOfFile int64
|
||||
NumberOfLinks uint32
|
||||
DeletePending, Directory bool
|
||||
}
|
||||
|
||||
// GetFileStandardInfo retrieves ended information for the file.
|
||||
func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) {
|
||||
si := &FileStandardInfo{}
|
||||
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()),
|
||||
windows.FileStandardInfo,
|
||||
(*byte)(unsafe.Pointer(si)),
|
||||
uint32(unsafe.Sizeof(*si))); err != nil {
|
||||
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return si, nil
|
||||
}
|
||||
|
||||
// FileIDInfo contains the volume serial number and file ID for a file. This pair should be
|
||||
// unique on a system.
|
||||
type FileIDInfo struct {
|
||||
VolumeSerialNumber uint64
|
||||
FileID [16]byte
|
||||
}
|
||||
|
||||
// GetFileID retrieves the unique (volume, file ID) pair for a file.
|
||||
func GetFileID(f *os.File) (*FileIDInfo, error) {
|
||||
fileID := &FileIDInfo{}
|
||||
if err := windows.GetFileInformationByHandleEx(
|
||||
windows.Handle(f.Fd()),
|
||||
windows.FileIdInfo,
|
||||
(*byte)(unsafe.Pointer(fileID)),
|
||||
uint32(unsafe.Sizeof(*fileID)),
|
||||
); err != nil {
|
||||
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return fileID, nil
|
||||
}
|
||||
574
vendor/github.com/tailscale/go-winio/hvsock.go
generated
vendored
Normal file
574
vendor/github.com/tailscale/go-winio/hvsock.go
generated
vendored
Normal file
@@ -0,0 +1,574 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/tailscale/go-winio/internal/socket"
|
||||
"github.com/tailscale/go-winio/pkg/guid"
|
||||
)
|
||||
|
||||
const afHVSock = 34 // AF_HYPERV
|
||||
|
||||
// Well known Service and VM IDs
|
||||
// https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
|
||||
|
||||
// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
|
||||
func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
|
||||
return guid.GUID{}
|
||||
}
|
||||
|
||||
// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
|
||||
func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff
|
||||
return guid.GUID{
|
||||
Data1: 0xffffffff,
|
||||
Data2: 0xffff,
|
||||
Data3: 0xffff,
|
||||
Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
|
||||
func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
|
||||
return guid.GUID{
|
||||
Data1: 0xe0e16197,
|
||||
Data2: 0xdd56,
|
||||
Data3: 0x4a10,
|
||||
Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDSiloHost is the address of a silo's host partition:
|
||||
// - The silo host of a hosted silo is the utility VM.
|
||||
// - The silo host of a silo on a physical host is the physical host.
|
||||
func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
|
||||
return guid.GUID{
|
||||
Data1: 0x36bd0c5c,
|
||||
Data2: 0x7276,
|
||||
Data3: 0x4223,
|
||||
Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
|
||||
func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
|
||||
return guid.GUID{
|
||||
Data1: 0x90db8b89,
|
||||
Data2: 0xd35,
|
||||
Data3: 0x4f79,
|
||||
Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
|
||||
// Listening on this VmId accepts connection from:
|
||||
// - Inside silos: silo host partition.
|
||||
// - Inside hosted silo: host of the VM.
|
||||
// - Inside VM: VM host.
|
||||
// - Physical host: Not supported.
|
||||
func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
|
||||
return guid.GUID{
|
||||
Data1: 0xa42e7cda,
|
||||
Data2: 0xd03f,
|
||||
Data3: 0x480c,
|
||||
Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
|
||||
}
|
||||
}
|
||||
|
||||
// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
|
||||
func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
|
||||
return guid.GUID{
|
||||
Data2: 0xfacb,
|
||||
Data3: 0x11e6,
|
||||
Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
|
||||
}
|
||||
}
|
||||
|
||||
// An HvsockAddr is an address for a AF_HYPERV socket.
|
||||
type HvsockAddr struct {
|
||||
VMID guid.GUID
|
||||
ServiceID guid.GUID
|
||||
}
|
||||
|
||||
type rawHvsockAddr struct {
|
||||
Family uint16
|
||||
_ uint16
|
||||
VMID guid.GUID
|
||||
ServiceID guid.GUID
|
||||
}
|
||||
|
||||
var _ socket.RawSockaddr = &rawHvsockAddr{}
|
||||
|
||||
// Network returns the address's network name, "hvsock".
|
||||
func (*HvsockAddr) Network() string {
|
||||
return "hvsock"
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) String() string {
|
||||
return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
|
||||
}
|
||||
|
||||
// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
|
||||
func VsockServiceID(port uint32) guid.GUID {
|
||||
g := hvsockVsockServiceTemplate() // make a copy
|
||||
g.Data1 = port
|
||||
return g
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) raw() rawHvsockAddr {
|
||||
return rawHvsockAddr{
|
||||
Family: afHVSock,
|
||||
VMID: addr.VMID,
|
||||
ServiceID: addr.ServiceID,
|
||||
}
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
|
||||
addr.VMID = raw.VMID
|
||||
addr.ServiceID = raw.ServiceID
|
||||
}
|
||||
|
||||
// Sockaddr returns a pointer to and the size of this struct.
|
||||
//
|
||||
// Implements the [socket.RawSockaddr] interface, and allows use in
|
||||
// [socket.Bind] and [socket.ConnectEx].
|
||||
func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
|
||||
return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
|
||||
}
|
||||
|
||||
// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
|
||||
func (r *rawHvsockAddr) FromBytes(b []byte) error {
|
||||
n := int(unsafe.Sizeof(rawHvsockAddr{}))
|
||||
|
||||
if len(b) < n {
|
||||
return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
|
||||
}
|
||||
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
|
||||
if r.Family != afHVSock {
|
||||
return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HvsockListener is a socket listener for the AF_HYPERV address family.
|
||||
type HvsockListener struct {
|
||||
sock *win32File
|
||||
addr HvsockAddr
|
||||
}
|
||||
|
||||
var _ net.Listener = &HvsockListener{}
|
||||
|
||||
// HvsockConn is a connected socket of the AF_HYPERV address family.
|
||||
type HvsockConn struct {
|
||||
sock *win32File
|
||||
local, remote HvsockAddr
|
||||
}
|
||||
|
||||
var _ net.Conn = &HvsockConn{}
|
||||
|
||||
func newHVSocket() (*win32File, error) {
|
||||
fd, err := windows.Socket(afHVSock, windows.SOCK_STREAM, 1)
|
||||
if err != nil {
|
||||
return nil, os.NewSyscallError("socket", err)
|
||||
}
|
||||
f, err := makeWin32File(fd)
|
||||
if err != nil {
|
||||
windows.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
f.socket = true
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// ListenHvsock listens for connections on the specified hvsock address.
|
||||
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
|
||||
l := &HvsockListener{addr: *addr}
|
||||
sock, err := newHVSocket()
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", err)
|
||||
}
|
||||
sa := addr.raw()
|
||||
err = socket.Bind(sock.handle, &sa)
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", os.NewSyscallError("socket", err))
|
||||
}
|
||||
err = windows.Listen(sock.handle, 16)
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", os.NewSyscallError("listen", err))
|
||||
}
|
||||
return &HvsockListener{sock: sock, addr: *addr}, nil
|
||||
}
|
||||
|
||||
func (l *HvsockListener) opErr(op string, err error) error {
|
||||
return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *HvsockListener) Addr() net.Addr {
|
||||
return &l.addr
|
||||
}
|
||||
|
||||
// Accept waits for the next connection and returns it.
|
||||
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
|
||||
sock, err := newHVSocket()
|
||||
if err != nil {
|
||||
return nil, l.opErr("accept", err)
|
||||
}
|
||||
defer func() {
|
||||
if sock != nil {
|
||||
sock.Close()
|
||||
}
|
||||
}()
|
||||
c, err := l.sock.prepareIO()
|
||||
if err != nil {
|
||||
return nil, l.opErr("accept", err)
|
||||
}
|
||||
defer l.sock.wg.Done()
|
||||
|
||||
// AcceptEx, per documentation, requires an extra 16 bytes per address.
|
||||
//
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
|
||||
const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
|
||||
var addrbuf [addrlen * 2]byte
|
||||
|
||||
var bytes uint32
|
||||
err = windows.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
|
||||
if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
|
||||
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
|
||||
}
|
||||
|
||||
conn := &HvsockConn{
|
||||
sock: sock,
|
||||
}
|
||||
// The local address returned in the AcceptEx buffer is the same as the Listener socket's
|
||||
// address. However, the service GUID reported by GetSockName is different from the Listeners
|
||||
// socket, and is sometimes the same as the local address of the socket that dialed the
|
||||
// address, with the service GUID.Data1 incremented, but othertimes is different.
|
||||
// todo: does the local address matter? is the listener's address or the actual address appropriate?
|
||||
conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
|
||||
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
|
||||
|
||||
// initialize the accepted socket and update its properties with those of the listening socket
|
||||
if err = windows.Setsockopt(sock.handle,
|
||||
windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
|
||||
(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
|
||||
return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
|
||||
}
|
||||
|
||||
sock = nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close closes the listener, causing any pending Accept calls to fail.
|
||||
func (l *HvsockListener) Close() error {
|
||||
return l.sock.Close()
|
||||
}
|
||||
|
||||
// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
|
||||
type HvsockDialer struct {
|
||||
// Deadline is the time the Dial operation must connect before erroring.
|
||||
Deadline time.Time
|
||||
|
||||
// Retries is the number of additional connects to try if the connection times out, is refused,
|
||||
// or the host is unreachable
|
||||
Retries uint
|
||||
|
||||
// RetryWait is the time to wait after a connection error to retry
|
||||
RetryWait time.Duration
|
||||
|
||||
rt *time.Timer // redial wait timer
|
||||
}
|
||||
|
||||
// Dial the Hyper-V socket at addr.
|
||||
//
|
||||
// See [HvsockDialer.Dial] for more information.
|
||||
func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
|
||||
return (&HvsockDialer{}).Dial(ctx, addr)
|
||||
}
|
||||
|
||||
// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
|
||||
// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
|
||||
// retries.
|
||||
//
|
||||
// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
|
||||
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
|
||||
op := "dial"
|
||||
// create the conn early to use opErr()
|
||||
conn = &HvsockConn{
|
||||
remote: *addr,
|
||||
}
|
||||
|
||||
if !d.Deadline.IsZero() {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(ctx, d.Deadline)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// preemptive timeout/cancellation check
|
||||
if err = ctx.Err(); err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
|
||||
sock, err := newHVSocket()
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
defer func() {
|
||||
if sock != nil {
|
||||
sock.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
sa := addr.raw()
|
||||
err = socket.Bind(sock.handle, &sa)
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("bind", err))
|
||||
}
|
||||
|
||||
c, err := sock.prepareIO()
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
defer sock.wg.Done()
|
||||
var bytes uint32
|
||||
for i := uint(0); i <= d.Retries; i++ {
|
||||
err = socket.ConnectEx(
|
||||
sock.handle,
|
||||
&sa,
|
||||
nil, // sendBuf
|
||||
0, // sendDataLen
|
||||
&bytes,
|
||||
(*windows.Overlapped)(unsafe.Pointer(&c.o)))
|
||||
_, err = sock.asyncIO(c, nil, bytes, err)
|
||||
if i < d.Retries && canRedial(err) {
|
||||
if err = d.redialWait(ctx); err == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
|
||||
}
|
||||
|
||||
// update the connection properties, so shutdown can be used
|
||||
if err = windows.Setsockopt(
|
||||
sock.handle,
|
||||
windows.SOL_SOCKET,
|
||||
windows.SO_UPDATE_CONNECT_CONTEXT,
|
||||
nil, // optvalue
|
||||
0, // optlen
|
||||
); err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
|
||||
}
|
||||
|
||||
// get the local name
|
||||
var sal rawHvsockAddr
|
||||
err = socket.GetSockName(sock.handle, &sal)
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
|
||||
}
|
||||
conn.local.fromRaw(&sal)
|
||||
|
||||
// one last check for timeout, since asyncIO doesn't check the context
|
||||
if err = ctx.Err(); err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
|
||||
conn.sock = sock
|
||||
sock = nil
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// redialWait waits before attempting to redial, resetting the timer as appropriate.
|
||||
func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
|
||||
if d.RetryWait == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.rt == nil {
|
||||
d.rt = time.NewTimer(d.RetryWait)
|
||||
} else {
|
||||
// should already be stopped and drained
|
||||
d.rt.Reset(d.RetryWait)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-d.rt.C:
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop and drain the timer
|
||||
if !d.rt.Stop() {
|
||||
<-d.rt.C
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// assumes error is a plain, unwrapped windows.Errno provided by direct syscall.
|
||||
func canRedial(err error) bool {
|
||||
//nolint:errorlint // guaranteed to be an Errno
|
||||
switch err {
|
||||
case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
|
||||
windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) opErr(op string, err error) error {
|
||||
// translate from "file closed" to "socket closed"
|
||||
if errors.Is(err, ErrFileClosed) {
|
||||
err = socket.ErrSocketClosed
|
||||
}
|
||||
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) Read(b []byte) (int, error) {
|
||||
c, err := conn.sock.prepareIO()
|
||||
if err != nil {
|
||||
return 0, conn.opErr("read", err)
|
||||
}
|
||||
defer conn.sock.wg.Done()
|
||||
buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||
var flags, bytes uint32
|
||||
err = windows.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
|
||||
n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
|
||||
if err != nil {
|
||||
var eno windows.Errno
|
||||
if errors.As(err, &eno) {
|
||||
err = os.NewSyscallError("wsarecv", eno)
|
||||
}
|
||||
return 0, conn.opErr("read", err)
|
||||
} else if n == 0 {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) Write(b []byte) (int, error) {
|
||||
t := 0
|
||||
for len(b) != 0 {
|
||||
n, err := conn.write(b)
|
||||
if err != nil {
|
||||
return t + n, err
|
||||
}
|
||||
t += n
|
||||
b = b[n:]
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) write(b []byte) (int, error) {
|
||||
c, err := conn.sock.prepareIO()
|
||||
if err != nil {
|
||||
return 0, conn.opErr("write", err)
|
||||
}
|
||||
defer conn.sock.wg.Done()
|
||||
buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||
var bytes uint32
|
||||
err = windows.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
|
||||
n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
|
||||
if err != nil {
|
||||
var eno windows.Errno
|
||||
if errors.As(err, &eno) {
|
||||
err = os.NewSyscallError("wsasend", eno)
|
||||
}
|
||||
return 0, conn.opErr("write", err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close closes the socket connection, failing any pending read or write calls.
|
||||
func (conn *HvsockConn) Close() error {
|
||||
return conn.sock.Close()
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) IsClosed() bool {
|
||||
return conn.sock.IsClosed()
|
||||
}
|
||||
|
||||
// shutdown disables sending or receiving on a socket.
|
||||
func (conn *HvsockConn) shutdown(how int) error {
|
||||
if conn.IsClosed() {
|
||||
return socket.ErrSocketClosed
|
||||
}
|
||||
|
||||
err := windows.Shutdown(conn.sock.handle, how)
|
||||
if err != nil {
|
||||
// If the connection was closed, shutdowns fail with "not connected"
|
||||
if errors.Is(err, windows.WSAENOTCONN) ||
|
||||
errors.Is(err, windows.WSAESHUTDOWN) {
|
||||
err = socket.ErrSocketClosed
|
||||
}
|
||||
return os.NewSyscallError("shutdown", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseRead shuts down the read end of the socket, preventing future read operations.
|
||||
func (conn *HvsockConn) CloseRead() error {
|
||||
err := conn.shutdown(windows.SHUT_RD)
|
||||
if err != nil {
|
||||
return conn.opErr("closeread", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWrite shuts down the write end of the socket, preventing future write operations and
|
||||
// notifying the other endpoint that no more data will be written.
|
||||
func (conn *HvsockConn) CloseWrite() error {
|
||||
err := conn.shutdown(windows.SHUT_WR)
|
||||
if err != nil {
|
||||
return conn.opErr("closewrite", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the connection.
|
||||
func (conn *HvsockConn) LocalAddr() net.Addr {
|
||||
return &conn.local
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address of the connection.
|
||||
func (conn *HvsockConn) RemoteAddr() net.Addr {
|
||||
return &conn.remote
|
||||
}
|
||||
|
||||
// SetDeadline implements the net.Conn SetDeadline method.
|
||||
func (conn *HvsockConn) SetDeadline(t time.Time) error {
|
||||
// todo: implement `SetDeadline` for `win32File`
|
||||
if err := conn.SetReadDeadline(t); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
if err := conn.SetWriteDeadline(t); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the net.Conn SetReadDeadline method.
|
||||
func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
|
||||
return conn.sock.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
|
||||
func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
|
||||
return conn.sock.SetWriteDeadline(t)
|
||||
}
|
||||
2
vendor/github.com/tailscale/go-winio/internal/fs/doc.go
generated
vendored
Normal file
2
vendor/github.com/tailscale/go-winio/internal/fs/doc.go
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
// This package contains Win32 filesystem functionality.
|
||||
package fs
|
||||
262
vendor/github.com/tailscale/go-winio/internal/fs/fs.go
generated
vendored
Normal file
262
vendor/github.com/tailscale/go-winio/internal/fs/fs.go
generated
vendored
Normal file
@@ -0,0 +1,262 @@
|
||||
//go:build windows
|
||||
|
||||
package fs
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/tailscale/go-winio/internal/stringbuffer"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/tailscale/go-winio/tools/mkwinsyscall -output zsyscall_windows.go fs.go
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
|
||||
//sys CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
|
||||
|
||||
const NullHandle windows.Handle = 0
|
||||
|
||||
// AccessMask defines standard, specific, and generic rights.
|
||||
//
|
||||
// Used with CreateFile and NtCreateFile (and co.).
|
||||
//
|
||||
// Bitmask:
|
||||
// 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1
|
||||
// 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
|
||||
// +---------------+---------------+-------------------------------+
|
||||
// |G|G|G|G|Resvd|A| StandardRights| SpecificRights |
|
||||
// |R|W|E|A| |S| | |
|
||||
// +-+-------------+---------------+-------------------------------+
|
||||
//
|
||||
// GR Generic Read
|
||||
// GW Generic Write
|
||||
// GE Generic Exectue
|
||||
// GA Generic All
|
||||
// Resvd Reserved
|
||||
// AS Access Security System
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/secauthz/access-mask
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/fileio/file-access-rights-constants
|
||||
type AccessMask = windows.ACCESS_MASK
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
// Not actually any.
|
||||
//
|
||||
// For CreateFile: "query certain metadata such as file, directory, or device attributes without accessing that file or device"
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew#parameters
|
||||
FILE_ANY_ACCESS AccessMask = 0
|
||||
|
||||
GENERIC_READ AccessMask = 0x8000_0000
|
||||
GENERIC_WRITE AccessMask = 0x4000_0000
|
||||
GENERIC_EXECUTE AccessMask = 0x2000_0000
|
||||
GENERIC_ALL AccessMask = 0x1000_0000
|
||||
ACCESS_SYSTEM_SECURITY AccessMask = 0x0100_0000
|
||||
|
||||
// Specific Object Access
|
||||
// from ntioapi.h
|
||||
|
||||
FILE_READ_DATA AccessMask = (0x0001) // file & pipe
|
||||
FILE_LIST_DIRECTORY AccessMask = (0x0001) // directory
|
||||
|
||||
FILE_WRITE_DATA AccessMask = (0x0002) // file & pipe
|
||||
FILE_ADD_FILE AccessMask = (0x0002) // directory
|
||||
|
||||
FILE_APPEND_DATA AccessMask = (0x0004) // file
|
||||
FILE_ADD_SUBDIRECTORY AccessMask = (0x0004) // directory
|
||||
FILE_CREATE_PIPE_INSTANCE AccessMask = (0x0004) // named pipe
|
||||
|
||||
FILE_READ_EA AccessMask = (0x0008) // file & directory
|
||||
FILE_READ_PROPERTIES AccessMask = FILE_READ_EA
|
||||
|
||||
FILE_WRITE_EA AccessMask = (0x0010) // file & directory
|
||||
FILE_WRITE_PROPERTIES AccessMask = FILE_WRITE_EA
|
||||
|
||||
FILE_EXECUTE AccessMask = (0x0020) // file
|
||||
FILE_TRAVERSE AccessMask = (0x0020) // directory
|
||||
|
||||
FILE_DELETE_CHILD AccessMask = (0x0040) // directory
|
||||
|
||||
FILE_READ_ATTRIBUTES AccessMask = (0x0080) // all
|
||||
|
||||
FILE_WRITE_ATTRIBUTES AccessMask = (0x0100) // all
|
||||
|
||||
FILE_ALL_ACCESS AccessMask = (STANDARD_RIGHTS_REQUIRED | SYNCHRONIZE | 0x1FF)
|
||||
FILE_GENERIC_READ AccessMask = (STANDARD_RIGHTS_READ | FILE_READ_DATA | FILE_READ_ATTRIBUTES | FILE_READ_EA | SYNCHRONIZE)
|
||||
FILE_GENERIC_WRITE AccessMask = (STANDARD_RIGHTS_WRITE | FILE_WRITE_DATA | FILE_WRITE_ATTRIBUTES | FILE_WRITE_EA | FILE_APPEND_DATA | SYNCHRONIZE)
|
||||
FILE_GENERIC_EXECUTE AccessMask = (STANDARD_RIGHTS_EXECUTE | FILE_READ_ATTRIBUTES | FILE_EXECUTE | SYNCHRONIZE)
|
||||
|
||||
SPECIFIC_RIGHTS_ALL AccessMask = 0x0000FFFF
|
||||
|
||||
// Standard Access
|
||||
// from ntseapi.h
|
||||
|
||||
DELETE AccessMask = 0x0001_0000
|
||||
READ_CONTROL AccessMask = 0x0002_0000
|
||||
WRITE_DAC AccessMask = 0x0004_0000
|
||||
WRITE_OWNER AccessMask = 0x0008_0000
|
||||
SYNCHRONIZE AccessMask = 0x0010_0000
|
||||
|
||||
STANDARD_RIGHTS_REQUIRED AccessMask = 0x000F_0000
|
||||
|
||||
STANDARD_RIGHTS_READ AccessMask = READ_CONTROL
|
||||
STANDARD_RIGHTS_WRITE AccessMask = READ_CONTROL
|
||||
STANDARD_RIGHTS_EXECUTE AccessMask = READ_CONTROL
|
||||
|
||||
STANDARD_RIGHTS_ALL AccessMask = 0x001F_0000
|
||||
)
|
||||
|
||||
type FileShareMode uint32
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
FILE_SHARE_NONE FileShareMode = 0x00
|
||||
FILE_SHARE_READ FileShareMode = 0x01
|
||||
FILE_SHARE_WRITE FileShareMode = 0x02
|
||||
FILE_SHARE_DELETE FileShareMode = 0x04
|
||||
FILE_SHARE_VALID_FLAGS FileShareMode = 0x07
|
||||
)
|
||||
|
||||
type FileCreationDisposition uint32
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
// from winbase.h
|
||||
|
||||
CREATE_NEW FileCreationDisposition = 0x01
|
||||
CREATE_ALWAYS FileCreationDisposition = 0x02
|
||||
OPEN_EXISTING FileCreationDisposition = 0x03
|
||||
OPEN_ALWAYS FileCreationDisposition = 0x04
|
||||
TRUNCATE_EXISTING FileCreationDisposition = 0x05
|
||||
)
|
||||
|
||||
// Create disposition values for NtCreate*
|
||||
type NTFileCreationDisposition uint32
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
// From ntioapi.h
|
||||
|
||||
FILE_SUPERSEDE NTFileCreationDisposition = 0x00
|
||||
FILE_OPEN NTFileCreationDisposition = 0x01
|
||||
FILE_CREATE NTFileCreationDisposition = 0x02
|
||||
FILE_OPEN_IF NTFileCreationDisposition = 0x03
|
||||
FILE_OVERWRITE NTFileCreationDisposition = 0x04
|
||||
FILE_OVERWRITE_IF NTFileCreationDisposition = 0x05
|
||||
FILE_MAXIMUM_DISPOSITION NTFileCreationDisposition = 0x05
|
||||
)
|
||||
|
||||
// CreateFile and co. take flags or attributes together as one parameter.
|
||||
// Define alias until we can use generics to allow both
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants
|
||||
type FileFlagOrAttribute uint32
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
// from winnt.h
|
||||
|
||||
FILE_FLAG_WRITE_THROUGH FileFlagOrAttribute = 0x8000_0000
|
||||
FILE_FLAG_OVERLAPPED FileFlagOrAttribute = 0x4000_0000
|
||||
FILE_FLAG_NO_BUFFERING FileFlagOrAttribute = 0x2000_0000
|
||||
FILE_FLAG_RANDOM_ACCESS FileFlagOrAttribute = 0x1000_0000
|
||||
FILE_FLAG_SEQUENTIAL_SCAN FileFlagOrAttribute = 0x0800_0000
|
||||
FILE_FLAG_DELETE_ON_CLOSE FileFlagOrAttribute = 0x0400_0000
|
||||
FILE_FLAG_BACKUP_SEMANTICS FileFlagOrAttribute = 0x0200_0000
|
||||
FILE_FLAG_POSIX_SEMANTICS FileFlagOrAttribute = 0x0100_0000
|
||||
FILE_FLAG_OPEN_REPARSE_POINT FileFlagOrAttribute = 0x0020_0000
|
||||
FILE_FLAG_OPEN_NO_RECALL FileFlagOrAttribute = 0x0010_0000
|
||||
FILE_FLAG_FIRST_PIPE_INSTANCE FileFlagOrAttribute = 0x0008_0000
|
||||
)
|
||||
|
||||
// NtCreate* functions take a dedicated CreateOptions parameter.
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/Winternl/nf-winternl-ntcreatefile
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/devnotes/nt-create-named-pipe-file
|
||||
type NTCreateOptions uint32
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
// From ntioapi.h
|
||||
|
||||
FILE_DIRECTORY_FILE NTCreateOptions = 0x0000_0001
|
||||
FILE_WRITE_THROUGH NTCreateOptions = 0x0000_0002
|
||||
FILE_SEQUENTIAL_ONLY NTCreateOptions = 0x0000_0004
|
||||
FILE_NO_INTERMEDIATE_BUFFERING NTCreateOptions = 0x0000_0008
|
||||
|
||||
FILE_SYNCHRONOUS_IO_ALERT NTCreateOptions = 0x0000_0010
|
||||
FILE_SYNCHRONOUS_IO_NONALERT NTCreateOptions = 0x0000_0020
|
||||
FILE_NON_DIRECTORY_FILE NTCreateOptions = 0x0000_0040
|
||||
FILE_CREATE_TREE_CONNECTION NTCreateOptions = 0x0000_0080
|
||||
|
||||
FILE_COMPLETE_IF_OPLOCKED NTCreateOptions = 0x0000_0100
|
||||
FILE_NO_EA_KNOWLEDGE NTCreateOptions = 0x0000_0200
|
||||
FILE_DISABLE_TUNNELING NTCreateOptions = 0x0000_0400
|
||||
FILE_RANDOM_ACCESS NTCreateOptions = 0x0000_0800
|
||||
|
||||
FILE_DELETE_ON_CLOSE NTCreateOptions = 0x0000_1000
|
||||
FILE_OPEN_BY_FILE_ID NTCreateOptions = 0x0000_2000
|
||||
FILE_OPEN_FOR_BACKUP_INTENT NTCreateOptions = 0x0000_4000
|
||||
FILE_NO_COMPRESSION NTCreateOptions = 0x0000_8000
|
||||
)
|
||||
|
||||
type FileSQSFlag = FileFlagOrAttribute
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
// from winbase.h
|
||||
|
||||
SECURITY_ANONYMOUS FileSQSFlag = FileSQSFlag(SecurityAnonymous << 16)
|
||||
SECURITY_IDENTIFICATION FileSQSFlag = FileSQSFlag(SecurityIdentification << 16)
|
||||
SECURITY_IMPERSONATION FileSQSFlag = FileSQSFlag(SecurityImpersonation << 16)
|
||||
SECURITY_DELEGATION FileSQSFlag = FileSQSFlag(SecurityDelegation << 16)
|
||||
|
||||
SECURITY_SQOS_PRESENT FileSQSFlag = 0x0010_0000
|
||||
SECURITY_VALID_SQOS_FLAGS FileSQSFlag = 0x001F_0000
|
||||
)
|
||||
|
||||
// GetFinalPathNameByHandle flags
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew#parameters
|
||||
type GetFinalPathFlag uint32
|
||||
|
||||
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
|
||||
const (
|
||||
GetFinalPathDefaultFlag GetFinalPathFlag = 0x0
|
||||
|
||||
FILE_NAME_NORMALIZED GetFinalPathFlag = 0x0
|
||||
FILE_NAME_OPENED GetFinalPathFlag = 0x8
|
||||
|
||||
VOLUME_NAME_DOS GetFinalPathFlag = 0x0
|
||||
VOLUME_NAME_GUID GetFinalPathFlag = 0x1
|
||||
VOLUME_NAME_NT GetFinalPathFlag = 0x2
|
||||
VOLUME_NAME_NONE GetFinalPathFlag = 0x4
|
||||
)
|
||||
|
||||
// getFinalPathNameByHandle facilitates calling the Windows API GetFinalPathNameByHandle
|
||||
// with the given handle and flags. It transparently takes care of creating a buffer of the
|
||||
// correct size for the call.
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfinalpathnamebyhandlew
|
||||
func GetFinalPathNameByHandle(h windows.Handle, flags GetFinalPathFlag) (string, error) {
|
||||
b := stringbuffer.NewWString()
|
||||
//TODO: can loop infinitely if Win32 keeps returning the same (or a larger) n?
|
||||
for {
|
||||
n, err := windows.GetFinalPathNameByHandle(h, b.Pointer(), b.Cap(), uint32(flags))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// If the buffer wasn't large enough, n will be the total size needed (including null terminator).
|
||||
// Resize and try again.
|
||||
if n > b.Cap() {
|
||||
b.ResizeTo(n)
|
||||
continue
|
||||
}
|
||||
// If the buffer is large enough, n will be the size not including the null terminator.
|
||||
// Convert to a Go string and return.
|
||||
return b.String(), nil
|
||||
}
|
||||
}
|
||||
12
vendor/github.com/tailscale/go-winio/internal/fs/security.go
generated
vendored
Normal file
12
vendor/github.com/tailscale/go-winio/internal/fs/security.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
package fs
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level
|
||||
type SecurityImpersonationLevel int32 // C default enums underlying type is `int`, which is Go `int32`
|
||||
|
||||
// Impersonation levels
|
||||
const (
|
||||
SecurityAnonymous SecurityImpersonationLevel = 0
|
||||
SecurityIdentification SecurityImpersonationLevel = 1
|
||||
SecurityImpersonation SecurityImpersonationLevel = 2
|
||||
SecurityDelegation SecurityImpersonationLevel = 3
|
||||
)
|
||||
64
vendor/github.com/tailscale/go-winio/internal/fs/zsyscall_windows.go
generated
vendored
Normal file
64
vendor/github.com/tailscale/go-winio/internal/fs/zsyscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
//go:build windows
|
||||
|
||||
// Code generated by 'go generate' using "github.com/tailscale/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||
|
||||
package fs
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
|
||||
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
||||
)
|
||||
|
||||
func CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _CreateFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
||||
}
|
||||
|
||||
func _CreateFile(name *uint16, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
||||
handle = windows.Handle(r0)
|
||||
if handle == windows.InvalidHandle {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
20
vendor/github.com/tailscale/go-winio/internal/socket/rawaddr.go
generated
vendored
Normal file
20
vendor/github.com/tailscale/go-winio/internal/socket/rawaddr.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
package socket
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The
|
||||
// struct must meet the Win32 sockaddr requirements specified here:
|
||||
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
|
||||
//
|
||||
// Specifically, the struct size must be least larger than an int16 (unsigned short)
|
||||
// for the address family.
|
||||
type RawSockaddr interface {
|
||||
// Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing
|
||||
// for the RawSockaddr's data to be overwritten by syscalls (if necessary).
|
||||
//
|
||||
// It is the callers responsibility to validate that the values are valid; invalid
|
||||
// pointers or size can cause a panic.
|
||||
Sockaddr() (unsafe.Pointer, int32, error)
|
||||
}
|
||||
179
vendor/github.com/tailscale/go-winio/internal/socket/socket.go
generated
vendored
Normal file
179
vendor/github.com/tailscale/go-winio/internal/socket/socket.go
generated
vendored
Normal file
@@ -0,0 +1,179 @@
|
||||
//go:build windows
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/tailscale/go-winio/pkg/guid"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/tailscale/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
|
||||
|
||||
//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
|
||||
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
|
||||
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
|
||||
|
||||
const socketError = uintptr(^uint32(0))
|
||||
|
||||
var (
|
||||
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
|
||||
|
||||
ErrBufferSize = errors.New("buffer size")
|
||||
ErrAddrFamily = errors.New("address family")
|
||||
ErrInvalidPointer = errors.New("invalid pointer")
|
||||
ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
|
||||
)
|
||||
|
||||
// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
|
||||
|
||||
// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
|
||||
// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
|
||||
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
|
||||
ptr, l, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||
}
|
||||
|
||||
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
|
||||
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
|
||||
return getsockname(s, ptr, &l)
|
||||
}
|
||||
|
||||
// GetPeerName returns the remote address the socket is connected to.
|
||||
//
|
||||
// See [GetSockName] for more information.
|
||||
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
|
||||
ptr, l, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||
}
|
||||
|
||||
return getpeername(s, ptr, &l)
|
||||
}
|
||||
|
||||
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
|
||||
ptr, l, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||
}
|
||||
|
||||
return bind(s, ptr, l)
|
||||
}
|
||||
|
||||
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the
|
||||
// their sockaddr interface, so they cannot be used with HvsockAddr
|
||||
// Replicate functionality here from
|
||||
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
|
||||
|
||||
// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
|
||||
// runtime via a WSAIoctl call:
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
|
||||
|
||||
type runtimeFunc struct {
|
||||
id guid.GUID
|
||||
once sync.Once
|
||||
addr uintptr
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *runtimeFunc) Load() error {
|
||||
f.once.Do(func() {
|
||||
var s windows.Handle
|
||||
s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
|
||||
if f.err != nil {
|
||||
return
|
||||
}
|
||||
defer windows.CloseHandle(s) //nolint:errcheck
|
||||
|
||||
var n uint32
|
||||
f.err = windows.WSAIoctl(s,
|
||||
windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
|
||||
(*byte)(unsafe.Pointer(&f.id)),
|
||||
uint32(unsafe.Sizeof(f.id)),
|
||||
(*byte)(unsafe.Pointer(&f.addr)),
|
||||
uint32(unsafe.Sizeof(f.addr)),
|
||||
&n,
|
||||
nil, // overlapped
|
||||
0, // completionRoutine
|
||||
)
|
||||
})
|
||||
return f.err
|
||||
}
|
||||
|
||||
var (
|
||||
// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
|
||||
WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
|
||||
Data1: 0x25a207b9,
|
||||
Data2: 0xddf3,
|
||||
Data3: 0x4660,
|
||||
Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
|
||||
}
|
||||
|
||||
connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
|
||||
)
|
||||
|
||||
func ConnectEx(
|
||||
fd windows.Handle,
|
||||
rsa RawSockaddr,
|
||||
sendBuf *byte,
|
||||
sendDataLen uint32,
|
||||
bytesSent *uint32,
|
||||
overlapped *windows.Overlapped,
|
||||
) error {
|
||||
if err := connectExFunc.Load(); err != nil {
|
||||
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
|
||||
}
|
||||
ptr, n, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
|
||||
}
|
||||
|
||||
// BOOL LpfnConnectex(
|
||||
// [in] SOCKET s,
|
||||
// [in] const sockaddr *name,
|
||||
// [in] int namelen,
|
||||
// [in, optional] PVOID lpSendBuffer,
|
||||
// [in] DWORD dwSendDataLength,
|
||||
// [out] LPDWORD lpdwBytesSent,
|
||||
// [in] LPOVERLAPPED lpOverlapped
|
||||
// )
|
||||
|
||||
func connectEx(
|
||||
s windows.Handle,
|
||||
name unsafe.Pointer,
|
||||
namelen int32,
|
||||
sendBuf *byte,
|
||||
sendDataLen uint32,
|
||||
bytesSent *uint32,
|
||||
overlapped *windows.Overlapped,
|
||||
) (err error) {
|
||||
// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
|
||||
r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
|
||||
7,
|
||||
uintptr(s),
|
||||
uintptr(name),
|
||||
uintptr(namelen),
|
||||
uintptr(unsafe.Pointer(sendBuf)),
|
||||
uintptr(sendDataLen),
|
||||
uintptr(unsafe.Pointer(bytesSent)),
|
||||
uintptr(unsafe.Pointer(overlapped)),
|
||||
0,
|
||||
0)
|
||||
if r1 == 0 {
|
||||
if e1 != 0 {
|
||||
err = error(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
72
vendor/github.com/tailscale/go-winio/internal/socket/zsyscall_windows.go
generated
vendored
Normal file
72
vendor/github.com/tailscale/go-winio/internal/socket/zsyscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build windows
|
||||
|
||||
// Code generated by 'go generate' using "github.com/tailscale/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
procbind = modws2_32.NewProc("bind")
|
||||
procgetpeername = modws2_32.NewProc("getpeername")
|
||||
procgetsockname = modws2_32.NewProc("getsockname")
|
||||
)
|
||||
|
||||
func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen))
|
||||
if r1 == socketError {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procgetpeername.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
|
||||
if r1 == socketError {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procgetsockname.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
|
||||
if r1 == socketError {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
132
vendor/github.com/tailscale/go-winio/internal/stringbuffer/wstring.go
generated
vendored
Normal file
132
vendor/github.com/tailscale/go-winio/internal/stringbuffer/wstring.go
generated
vendored
Normal file
@@ -0,0 +1,132 @@
|
||||
package stringbuffer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unicode/utf16"
|
||||
)
|
||||
|
||||
// TODO: worth exporting and using in mkwinsyscall?
|
||||
|
||||
// Uint16BufferSize is the buffer size in the pool, chosen somewhat arbitrarily to accommodate
|
||||
// large path strings:
|
||||
// MAX_PATH (260) + size of volume GUID prefix (49) + null terminator = 310.
|
||||
const MinWStringCap = 310
|
||||
|
||||
// use *[]uint16 since []uint16 creates an extra allocation where the slice header
|
||||
// is copied to heap and then referenced via pointer in the interface header that sync.Pool
|
||||
// stores.
|
||||
var pathPool = sync.Pool{ // if go1.18+ adds Pool[T], use that to store []uint16 directly
|
||||
New: func() interface{} {
|
||||
b := make([]uint16, MinWStringCap)
|
||||
return &b
|
||||
},
|
||||
}
|
||||
|
||||
func newBuffer() []uint16 { return *(pathPool.Get().(*[]uint16)) }
|
||||
|
||||
// freeBuffer copies the slice header data, and puts a pointer to that in the pool.
|
||||
// This avoids taking a pointer to the slice header in WString, which can be set to nil.
|
||||
func freeBuffer(b []uint16) { pathPool.Put(&b) }
|
||||
|
||||
// WString is a wide string buffer ([]uint16) meant for storing UTF-16 encoded strings
|
||||
// for interacting with Win32 APIs.
|
||||
// Sizes are specified as uint32 and not int.
|
||||
//
|
||||
// It is not thread safe.
|
||||
type WString struct {
|
||||
// type-def allows casting to []uint16 directly, use struct to prevent that and allow adding fields in the future.
|
||||
|
||||
// raw buffer
|
||||
b []uint16
|
||||
}
|
||||
|
||||
// NewWString returns a [WString] allocated from a shared pool with an
|
||||
// initial capacity of at least [MinWStringCap].
|
||||
// Since the buffer may have been previously used, its contents are not guaranteed to be empty.
|
||||
//
|
||||
// The buffer should be freed via [WString.Free]
|
||||
func NewWString() *WString {
|
||||
return &WString{
|
||||
b: newBuffer(),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *WString) Free() {
|
||||
if b.empty() {
|
||||
return
|
||||
}
|
||||
freeBuffer(b.b)
|
||||
b.b = nil
|
||||
}
|
||||
|
||||
// ResizeTo grows the buffer to at least c and returns the new capacity, freeing the
|
||||
// previous buffer back into pool.
|
||||
func (b *WString) ResizeTo(c uint32) uint32 {
|
||||
// already sufficient (or n is 0)
|
||||
if c <= b.Cap() {
|
||||
return b.Cap()
|
||||
}
|
||||
|
||||
if c <= MinWStringCap {
|
||||
c = MinWStringCap
|
||||
}
|
||||
// allocate at-least double buffer size, as is done in [bytes.Buffer] and other places
|
||||
if c <= 2*b.Cap() {
|
||||
c = 2 * b.Cap()
|
||||
}
|
||||
|
||||
b2 := make([]uint16, c)
|
||||
if !b.empty() {
|
||||
copy(b2, b.b)
|
||||
freeBuffer(b.b)
|
||||
}
|
||||
b.b = b2
|
||||
return c
|
||||
}
|
||||
|
||||
// Buffer returns the underlying []uint16 buffer.
|
||||
func (b *WString) Buffer() []uint16 {
|
||||
if b.empty() {
|
||||
return nil
|
||||
}
|
||||
return b.b
|
||||
}
|
||||
|
||||
// Pointer returns a pointer to the first uint16 in the buffer.
|
||||
// If the [WString.Free] has already been called, the pointer will be nil.
|
||||
func (b *WString) Pointer() *uint16 {
|
||||
if b.empty() {
|
||||
return nil
|
||||
}
|
||||
return &b.b[0]
|
||||
}
|
||||
|
||||
// String returns the returns the UTF-8 encoding of the UTF-16 string in the buffer.
|
||||
//
|
||||
// It assumes that the data is null-terminated.
|
||||
func (b *WString) String() string {
|
||||
// Using [windows.UTF16ToString] would require importing "golang.org/x/sys/windows"
|
||||
// and would make this code Windows-only, which makes no sense.
|
||||
// So copy UTF16ToString code into here.
|
||||
// If other windows-specific code is added, switch to [windows.UTF16ToString]
|
||||
|
||||
s := b.b
|
||||
for i, v := range s {
|
||||
if v == 0 {
|
||||
s = s[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
return string(utf16.Decode(s))
|
||||
}
|
||||
|
||||
// Cap returns the underlying buffer capacity.
|
||||
func (b *WString) Cap() uint32 {
|
||||
if b.empty() {
|
||||
return 0
|
||||
}
|
||||
return b.cap()
|
||||
}
|
||||
|
||||
func (b *WString) cap() uint32 { return uint32(cap(b.b)) }
|
||||
func (b *WString) empty() bool { return b == nil || b.cap() == 0 }
|
||||
586
vendor/github.com/tailscale/go-winio/pipe.go
generated
vendored
Normal file
586
vendor/github.com/tailscale/go-winio/pipe.go
generated
vendored
Normal file
@@ -0,0 +1,586 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/tailscale/go-winio/internal/fs"
|
||||
)
|
||||
|
||||
//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
|
||||
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
|
||||
//sys disconnectNamedPipe(pipe windows.Handle) (err error) = DisconnectNamedPipe
|
||||
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
||||
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
||||
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile
|
||||
//sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
||||
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U
|
||||
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl
|
||||
|
||||
type PipeConn interface {
|
||||
net.Conn
|
||||
Disconnect() error
|
||||
Flush() error
|
||||
}
|
||||
|
||||
// type aliases for mkwinsyscall code
|
||||
type (
|
||||
ntAccessMask = fs.AccessMask
|
||||
ntFileShareMode = fs.FileShareMode
|
||||
ntFileCreationDisposition = fs.NTFileCreationDisposition
|
||||
ntFileOptions = fs.NTCreateOptions
|
||||
)
|
||||
|
||||
type ioStatusBlock struct {
|
||||
Status, Information uintptr
|
||||
}
|
||||
|
||||
// typedef struct _OBJECT_ATTRIBUTES {
|
||||
// ULONG Length;
|
||||
// HANDLE RootDirectory;
|
||||
// PUNICODE_STRING ObjectName;
|
||||
// ULONG Attributes;
|
||||
// PVOID SecurityDescriptor;
|
||||
// PVOID SecurityQualityOfService;
|
||||
// } OBJECT_ATTRIBUTES;
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/ntdef/ns-ntdef-_object_attributes
|
||||
type objectAttributes struct {
|
||||
Length uintptr
|
||||
RootDirectory uintptr
|
||||
ObjectName *unicodeString
|
||||
Attributes uintptr
|
||||
SecurityDescriptor *securityDescriptor
|
||||
SecurityQoS uintptr
|
||||
}
|
||||
|
||||
type unicodeString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer uintptr
|
||||
}
|
||||
|
||||
// typedef struct _SECURITY_DESCRIPTOR {
|
||||
// BYTE Revision;
|
||||
// BYTE Sbz1;
|
||||
// SECURITY_DESCRIPTOR_CONTROL Control;
|
||||
// PSID Owner;
|
||||
// PSID Group;
|
||||
// PACL Sacl;
|
||||
// PACL Dacl;
|
||||
// } SECURITY_DESCRIPTOR, *PISECURITY_DESCRIPTOR;
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-security_descriptor
|
||||
type securityDescriptor struct {
|
||||
Revision byte
|
||||
Sbz1 byte
|
||||
Control uint16
|
||||
Owner uintptr
|
||||
Group uintptr
|
||||
Sacl uintptr //revive:disable-line:var-naming SACL, not Sacl
|
||||
Dacl uintptr //revive:disable-line:var-naming DACL, not Dacl
|
||||
}
|
||||
|
||||
type ntStatus int32
|
||||
|
||||
func (status ntStatus) Err() error {
|
||||
if status >= 0 {
|
||||
return nil
|
||||
}
|
||||
return rtlNtStatusToDosError(status)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
|
||||
ErrPipeListenerClosed = net.ErrClosed
|
||||
|
||||
errPipeWriteClosed = errors.New("pipe has been closed for write")
|
||||
)
|
||||
|
||||
type win32Pipe struct {
|
||||
*win32File
|
||||
path string
|
||||
}
|
||||
|
||||
var _ PipeConn = (*win32Pipe)(nil)
|
||||
|
||||
type win32MessageBytePipe struct {
|
||||
win32Pipe
|
||||
writeClosed bool
|
||||
readEOF bool
|
||||
}
|
||||
|
||||
type pipeAddress string
|
||||
|
||||
func (f *win32Pipe) LocalAddr() net.Addr {
|
||||
return pipeAddress(f.path)
|
||||
}
|
||||
|
||||
func (f *win32Pipe) RemoteAddr() net.Addr {
|
||||
return pipeAddress(f.path)
|
||||
}
|
||||
|
||||
func (f *win32Pipe) SetDeadline(t time.Time) error {
|
||||
if err := f.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (f *win32Pipe) Disconnect() error {
|
||||
return disconnectNamedPipe(f.win32File.handle)
|
||||
}
|
||||
|
||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||
func (f *win32MessageBytePipe) CloseWrite() error {
|
||||
if f.writeClosed {
|
||||
return errPipeWriteClosed
|
||||
}
|
||||
err := f.win32File.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = f.win32File.Write(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.writeClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||
// they are used to implement CloseWrite().
|
||||
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
|
||||
if f.writeClosed {
|
||||
return 0, errPipeWriteClosed
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return f.win32File.Write(b)
|
||||
}
|
||||
|
||||
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
|
||||
// mode pipe will return io.EOF, as will all subsequent reads.
|
||||
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
||||
if f.readEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err := f.win32File.Read(b)
|
||||
if err == io.EOF { //nolint:errorlint
|
||||
// If this was the result of a zero-byte read, then
|
||||
// it is possible that the read was due to a zero-size
|
||||
// message. Since we are simulating CloseWrite with a
|
||||
// zero-byte message, ensure that all future Read() calls
|
||||
// also return EOF.
|
||||
f.readEOF = true
|
||||
} else if err == windows.ERROR_MORE_DATA { //nolint:errorlint // err is Errno
|
||||
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||
// and the message still has more bytes. Treat this as a success, since
|
||||
// this package presents all named pipes as byte streams.
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (pipeAddress) Network() string {
|
||||
return "pipe"
|
||||
}
|
||||
|
||||
func (s pipeAddress) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
||||
func tryDialPipe(ctx context.Context, path *string, access fs.AccessMask, impLevel PipeImpLevel) (windows.Handle, error) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return windows.Handle(0), ctx.Err()
|
||||
default:
|
||||
h, err := fs.CreateFile(*path,
|
||||
access,
|
||||
0, // mode
|
||||
nil, // security attributes
|
||||
fs.OPEN_EXISTING,
|
||||
fs.FILE_FLAG_OVERLAPPED|fs.SECURITY_SQOS_PRESENT|fs.FileSQSFlag(impLevel),
|
||||
0, // template file handle
|
||||
)
|
||||
if err == nil {
|
||||
return h, nil
|
||||
}
|
||||
if err != windows.ERROR_PIPE_BUSY { //nolint:errorlint // err is Errno
|
||||
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||
}
|
||||
// Wait 10 msec and try again. This is a rather simplistic
|
||||
// view, as we always try each 10 milliseconds.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||
// takes longer than the specified duration. If timeout is nil, then we use
|
||||
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
||||
var absTimeout time.Time
|
||||
if timeout != nil {
|
||||
absTimeout = time.Now().Add(*timeout)
|
||||
} else {
|
||||
absTimeout = time.Now().Add(2 * time.Second)
|
||||
}
|
||||
ctx, cancel := context.WithDeadline(context.Background(), absTimeout)
|
||||
defer cancel()
|
||||
conn, err := DialPipeContext(ctx, path)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||
// cancellation or timeout.
|
||||
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
|
||||
return DialPipeAccess(ctx, path, uint32(fs.GENERIC_READ|fs.GENERIC_WRITE))
|
||||
}
|
||||
|
||||
// PipeImpLevel is an enumeration of impersonation levels that may be set
|
||||
// when calling DialPipeAccessImpersonation.
|
||||
type PipeImpLevel uint32
|
||||
|
||||
const (
|
||||
PipeImpLevelAnonymous = PipeImpLevel(fs.SECURITY_ANONYMOUS)
|
||||
PipeImpLevelIdentification = PipeImpLevel(fs.SECURITY_IDENTIFICATION)
|
||||
PipeImpLevelImpersonation = PipeImpLevel(fs.SECURITY_IMPERSONATION)
|
||||
PipeImpLevelDelegation = PipeImpLevel(fs.SECURITY_DELEGATION)
|
||||
)
|
||||
|
||||
// DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx`
|
||||
// cancellation or timeout.
|
||||
func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) {
|
||||
return DialPipeAccessImpLevel(ctx, path, access, PipeImpLevelAnonymous)
|
||||
}
|
||||
|
||||
// DialPipeAccessImpLevel attempts to connect to a named pipe by `path` with
|
||||
// `access` at `impLevel` until `ctx` cancellation or timeout. The other
|
||||
// DialPipe* implementations use PipeImpLevelAnonymous.
|
||||
func DialPipeAccessImpLevel(ctx context.Context, path string, access uint32, impLevel PipeImpLevel) (net.Conn, error) {
|
||||
var err error
|
||||
var h windows.Handle
|
||||
h, err = tryDialPipe(ctx, &path, fs.AccessMask(access), impLevel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var flags uint32
|
||||
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
windows.Close(h)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the pipe is in message mode, return a message byte pipe, which
|
||||
// supports CloseWrite().
|
||||
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
|
||||
return &win32MessageBytePipe{
|
||||
win32Pipe: win32Pipe{win32File: f, path: path},
|
||||
}, nil
|
||||
}
|
||||
return &win32Pipe{win32File: f, path: path}, nil
|
||||
}
|
||||
|
||||
type acceptResponse struct {
|
||||
f *win32File
|
||||
err error
|
||||
}
|
||||
|
||||
type win32PipeListener struct {
|
||||
firstHandle windows.Handle
|
||||
path string
|
||||
config PipeConfig
|
||||
acceptCh chan (chan acceptResponse)
|
||||
closeCh chan int
|
||||
doneCh chan int
|
||||
}
|
||||
|
||||
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (windows.Handle, error) {
|
||||
path16, err := windows.UTF16FromString(path)
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
|
||||
var oa objectAttributes
|
||||
oa.Length = unsafe.Sizeof(oa)
|
||||
|
||||
var ntPath unicodeString
|
||||
if err := rtlDosPathNameToNtPathName(&path16[0],
|
||||
&ntPath,
|
||||
0,
|
||||
0,
|
||||
).Err(); err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
defer windows.LocalFree(windows.Handle(ntPath.Buffer)) //nolint:errcheck
|
||||
oa.ObjectName = &ntPath
|
||||
oa.Attributes = windows.OBJ_CASE_INSENSITIVE
|
||||
|
||||
// The security descriptor is only needed for the first pipe.
|
||||
if first {
|
||||
if sd != nil {
|
||||
//todo: does `sdb` need to be allocated on the heap, or can go allocate it?
|
||||
l := uint32(len(sd))
|
||||
sdb, err := windows.LocalAlloc(0, l)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("LocalAlloc for security descriptor with of length %d: %w", l, err)
|
||||
}
|
||||
defer windows.LocalFree(windows.Handle(sdb)) //nolint:errcheck
|
||||
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
|
||||
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
|
||||
} else {
|
||||
// Construct the default named pipe security descriptor.
|
||||
var dacl uintptr
|
||||
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
||||
return 0, fmt.Errorf("getting default named pipe ACL: %w", err)
|
||||
}
|
||||
defer windows.LocalFree(windows.Handle(dacl)) //nolint:errcheck
|
||||
|
||||
sdb := &securityDescriptor{
|
||||
Revision: 1,
|
||||
Control: windows.SE_DACL_PRESENT,
|
||||
Dacl: dacl,
|
||||
}
|
||||
oa.SecurityDescriptor = sdb
|
||||
}
|
||||
}
|
||||
|
||||
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
|
||||
if c.MessageMode {
|
||||
typ |= windows.FILE_PIPE_MESSAGE_TYPE
|
||||
}
|
||||
|
||||
disposition := fs.FILE_OPEN
|
||||
access := fs.GENERIC_READ | fs.GENERIC_WRITE | fs.SYNCHRONIZE
|
||||
if first {
|
||||
disposition = fs.FILE_CREATE
|
||||
// By not asking for read or write access, the named pipe file system
|
||||
// will put this pipe into an initially disconnected state, blocking
|
||||
// client connections until the next call with first == false.
|
||||
access = fs.SYNCHRONIZE
|
||||
}
|
||||
|
||||
timeout := int64(-50 * 10000) // 50ms
|
||||
|
||||
var (
|
||||
h windows.Handle
|
||||
iosb ioStatusBlock
|
||||
)
|
||||
err = ntCreateNamedPipeFile(&h,
|
||||
access,
|
||||
&oa,
|
||||
&iosb,
|
||||
fs.FILE_SHARE_READ|fs.FILE_SHARE_WRITE,
|
||||
disposition,
|
||||
0,
|
||||
typ,
|
||||
0,
|
||||
0,
|
||||
0xffffffff,
|
||||
uint32(c.InputBufferSize),
|
||||
uint32(c.OutputBufferSize),
|
||||
&timeout).Err()
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
|
||||
runtime.KeepAlive(ntPath)
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
||||
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
windows.Close(h)
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
|
||||
p, err := l.makeServerPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait for the client to connect.
|
||||
ch := make(chan error)
|
||||
go func(p *win32File) {
|
||||
ch <- connectPipe(p)
|
||||
}(p)
|
||||
|
||||
select {
|
||||
case err = <-ch:
|
||||
if err != nil {
|
||||
p.Close()
|
||||
p = nil
|
||||
}
|
||||
case <-l.closeCh:
|
||||
// Abort the connect request by closing the handle.
|
||||
p.Close()
|
||||
p = nil
|
||||
err = <-ch
|
||||
if err == nil || err == ErrFileClosed { //nolint:errorlint // err is Errno
|
||||
err = ErrPipeListenerClosed
|
||||
}
|
||||
}
|
||||
return p, err
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) listenerRoutine() {
|
||||
closed := false
|
||||
for !closed {
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
closed = true
|
||||
case responseCh := <-l.acceptCh:
|
||||
var (
|
||||
p *win32File
|
||||
err error
|
||||
)
|
||||
for {
|
||||
p, err = l.makeConnectedServerPipe()
|
||||
// If the connection was immediately closed by the client, try
|
||||
// again.
|
||||
if err != windows.ERROR_NO_DATA { //nolint:errorlint // err is Errno
|
||||
break
|
||||
}
|
||||
}
|
||||
responseCh <- acceptResponse{p, err}
|
||||
closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno
|
||||
}
|
||||
}
|
||||
windows.Close(l.firstHandle)
|
||||
l.firstHandle = 0
|
||||
// Notify Close() and Accept() callers that the handle has been closed.
|
||||
close(l.doneCh)
|
||||
}
|
||||
|
||||
// PipeConfig contain configuration for the pipe listener.
|
||||
type PipeConfig struct {
|
||||
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
|
||||
SecurityDescriptor string
|
||||
|
||||
// MessageMode determines whether the pipe is in byte or message mode. In either
|
||||
// case the pipe is read in byte mode by default. The only practical difference in
|
||||
// this implementation is that CloseWrite() is only supported for message mode pipes;
|
||||
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
|
||||
// transferred to the reader (and returned as io.EOF in this implementation)
|
||||
// when the pipe is in message mode.
|
||||
MessageMode bool
|
||||
|
||||
// InputBufferSize specifies the size of the input buffer, in bytes.
|
||||
InputBufferSize int32
|
||||
|
||||
// OutputBufferSize specifies the size of the output buffer, in bytes.
|
||||
OutputBufferSize int32
|
||||
}
|
||||
|
||||
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
||||
// The pipe must not already exist.
|
||||
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
||||
var (
|
||||
sd []byte
|
||||
err error
|
||||
)
|
||||
if c == nil {
|
||||
c = &PipeConfig{}
|
||||
}
|
||||
if c.SecurityDescriptor != "" {
|
||||
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
h, err := makeServerPipeHandle(path, sd, c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := &win32PipeListener{
|
||||
firstHandle: h,
|
||||
path: path,
|
||||
config: *c,
|
||||
acceptCh: make(chan (chan acceptResponse)),
|
||||
closeCh: make(chan int),
|
||||
doneCh: make(chan int),
|
||||
}
|
||||
go l.listenerRoutine()
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func connectPipe(p *win32File) error {
|
||||
c, err := p.prepareIO()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer p.wg.Done()
|
||||
|
||||
err = connectNamedPipe(p.handle, &c.o)
|
||||
_, err = p.asyncIO(c, nil, 0, err)
|
||||
if err != nil && err != windows.ERROR_PIPE_CONNECTED { //nolint:errorlint // err is Errno
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) Accept() (net.Conn, error) {
|
||||
ch := make(chan acceptResponse)
|
||||
select {
|
||||
case l.acceptCh <- ch:
|
||||
response := <-ch
|
||||
err := response.err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if l.config.MessageMode {
|
||||
return &win32MessageBytePipe{
|
||||
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
|
||||
}, nil
|
||||
}
|
||||
return &win32Pipe{win32File: response.f, path: l.path}, nil
|
||||
case <-l.doneCh:
|
||||
return nil, ErrPipeListenerClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) Close() error {
|
||||
select {
|
||||
case l.closeCh <- 1:
|
||||
<-l.doneCh
|
||||
case <-l.doneCh:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) Addr() net.Addr {
|
||||
return pipeAddress(l.path)
|
||||
}
|
||||
232
vendor/github.com/tailscale/go-winio/pkg/guid/guid.go
generated
vendored
Normal file
232
vendor/github.com/tailscale/go-winio/pkg/guid/guid.go
generated
vendored
Normal file
@@ -0,0 +1,232 @@
|
||||
// Package guid provides a GUID type. The backing structure for a GUID is
|
||||
// identical to that used by the golang.org/x/sys/windows GUID type.
|
||||
// There are two main binary encodings used for a GUID, the big-endian encoding,
|
||||
// and the Windows (mixed-endian) encoding. See here for details:
|
||||
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Encoding
|
||||
package guid
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1" //nolint:gosec // not used for secure application
|
||||
"encoding"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
//go:generate go run golang.org/x/tools/cmd/stringer -type=Variant -trimprefix=Variant -linecomment
|
||||
|
||||
// Variant specifies which GUID variant (or "type") of the GUID. It determines
|
||||
// how the entirety of the rest of the GUID is interpreted.
|
||||
type Variant uint8
|
||||
|
||||
// The variants specified by RFC 4122 section 4.1.1.
|
||||
const (
|
||||
// VariantUnknown specifies a GUID variant which does not conform to one of
|
||||
// the variant encodings specified in RFC 4122.
|
||||
VariantUnknown Variant = iota
|
||||
VariantNCS
|
||||
VariantRFC4122 // RFC 4122
|
||||
VariantMicrosoft
|
||||
VariantFuture
|
||||
)
|
||||
|
||||
// Version specifies how the bits in the GUID were generated. For instance, a
|
||||
// version 4 GUID is randomly generated, and a version 5 is generated from the
|
||||
// hash of an input string.
|
||||
type Version uint8
|
||||
|
||||
func (v Version) String() string {
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
}
|
||||
|
||||
var _ = (encoding.TextMarshaler)(GUID{})
|
||||
var _ = (encoding.TextUnmarshaler)(&GUID{})
|
||||
|
||||
// NewV4 returns a new version 4 (pseudorandom) GUID, as defined by RFC 4122.
|
||||
func NewV4() (GUID, error) {
|
||||
var b [16]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return GUID{}, err
|
||||
}
|
||||
|
||||
g := FromArray(b)
|
||||
g.setVersion(4) // Version 4 means randomly generated.
|
||||
g.setVariant(VariantRFC4122)
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// NewV5 returns a new version 5 (generated from a string via SHA-1 hashing)
|
||||
// GUID, as defined by RFC 4122. The RFC is unclear on the encoding of the name,
|
||||
// and the sample code treats it as a series of bytes, so we do the same here.
|
||||
//
|
||||
// Some implementations, such as those found on Windows, treat the name as a
|
||||
// big-endian UTF16 stream of bytes. If that is desired, the string can be
|
||||
// encoded as such before being passed to this function.
|
||||
func NewV5(namespace GUID, name []byte) (GUID, error) {
|
||||
b := sha1.New() //nolint:gosec // not used for secure application
|
||||
namespaceBytes := namespace.ToArray()
|
||||
b.Write(namespaceBytes[:])
|
||||
b.Write(name)
|
||||
|
||||
a := [16]byte{}
|
||||
copy(a[:], b.Sum(nil))
|
||||
|
||||
g := FromArray(a)
|
||||
g.setVersion(5) // Version 5 means generated from a string.
|
||||
g.setVariant(VariantRFC4122)
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func fromArray(b [16]byte, order binary.ByteOrder) GUID {
|
||||
var g GUID
|
||||
g.Data1 = order.Uint32(b[0:4])
|
||||
g.Data2 = order.Uint16(b[4:6])
|
||||
g.Data3 = order.Uint16(b[6:8])
|
||||
copy(g.Data4[:], b[8:16])
|
||||
return g
|
||||
}
|
||||
|
||||
func (g GUID) toArray(order binary.ByteOrder) [16]byte {
|
||||
b := [16]byte{}
|
||||
order.PutUint32(b[0:4], g.Data1)
|
||||
order.PutUint16(b[4:6], g.Data2)
|
||||
order.PutUint16(b[6:8], g.Data3)
|
||||
copy(b[8:16], g.Data4[:])
|
||||
return b
|
||||
}
|
||||
|
||||
// FromArray constructs a GUID from a big-endian encoding array of 16 bytes.
|
||||
func FromArray(b [16]byte) GUID {
|
||||
return fromArray(b, binary.BigEndian)
|
||||
}
|
||||
|
||||
// ToArray returns an array of 16 bytes representing the GUID in big-endian
|
||||
// encoding.
|
||||
func (g GUID) ToArray() [16]byte {
|
||||
return g.toArray(binary.BigEndian)
|
||||
}
|
||||
|
||||
// FromWindowsArray constructs a GUID from a Windows encoding array of bytes.
|
||||
func FromWindowsArray(b [16]byte) GUID {
|
||||
return fromArray(b, binary.LittleEndian)
|
||||
}
|
||||
|
||||
// ToWindowsArray returns an array of 16 bytes representing the GUID in Windows
|
||||
// encoding.
|
||||
func (g GUID) ToWindowsArray() [16]byte {
|
||||
return g.toArray(binary.LittleEndian)
|
||||
}
|
||||
|
||||
func (g GUID) String() string {
|
||||
return fmt.Sprintf(
|
||||
"%08x-%04x-%04x-%04x-%012x",
|
||||
g.Data1,
|
||||
g.Data2,
|
||||
g.Data3,
|
||||
g.Data4[:2],
|
||||
g.Data4[2:])
|
||||
}
|
||||
|
||||
// FromString parses a string containing a GUID and returns the GUID. The only
|
||||
// format currently supported is the `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`
|
||||
// format.
|
||||
func FromString(s string) (GUID, error) {
|
||||
if len(s) != 36 {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
|
||||
var g GUID
|
||||
|
||||
data1, err := strconv.ParseUint(s[0:8], 16, 32)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data1 = uint32(data1)
|
||||
|
||||
data2, err := strconv.ParseUint(s[9:13], 16, 16)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data2 = uint16(data2)
|
||||
|
||||
data3, err := strconv.ParseUint(s[14:18], 16, 16)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data3 = uint16(data3)
|
||||
|
||||
for i, x := range []int{19, 21, 24, 26, 28, 30, 32, 34} {
|
||||
v, err := strconv.ParseUint(s[x:x+2], 16, 8)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data4[i] = uint8(v)
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (g *GUID) setVariant(v Variant) {
|
||||
d := g.Data4[0]
|
||||
switch v {
|
||||
case VariantNCS:
|
||||
d = (d & 0x7f)
|
||||
case VariantRFC4122:
|
||||
d = (d & 0x3f) | 0x80
|
||||
case VariantMicrosoft:
|
||||
d = (d & 0x1f) | 0xc0
|
||||
case VariantFuture:
|
||||
d = (d & 0x0f) | 0xe0
|
||||
case VariantUnknown:
|
||||
fallthrough
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid variant: %d", v))
|
||||
}
|
||||
g.Data4[0] = d
|
||||
}
|
||||
|
||||
// Variant returns the GUID variant, as defined in RFC 4122.
|
||||
func (g GUID) Variant() Variant {
|
||||
b := g.Data4[0]
|
||||
if b&0x80 == 0 {
|
||||
return VariantNCS
|
||||
} else if b&0xc0 == 0x80 {
|
||||
return VariantRFC4122
|
||||
} else if b&0xe0 == 0xc0 {
|
||||
return VariantMicrosoft
|
||||
} else if b&0xe0 == 0xe0 {
|
||||
return VariantFuture
|
||||
}
|
||||
return VariantUnknown
|
||||
}
|
||||
|
||||
func (g *GUID) setVersion(v Version) {
|
||||
g.Data3 = (g.Data3 & 0x0fff) | (uint16(v) << 12)
|
||||
}
|
||||
|
||||
// Version returns the GUID version, as defined in RFC 4122.
|
||||
func (g GUID) Version() Version {
|
||||
return Version((g.Data3 & 0xF000) >> 12)
|
||||
}
|
||||
|
||||
// MarshalText returns the textual representation of the GUID.
|
||||
func (g GUID) MarshalText() ([]byte, error) {
|
||||
return []byte(g.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText takes the textual representation of a GUID, and unmarhals it
|
||||
// into this GUID.
|
||||
func (g *GUID) UnmarshalText(text []byte) error {
|
||||
g2, err := FromString(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*g = g2
|
||||
return nil
|
||||
}
|
||||
16
vendor/github.com/tailscale/go-winio/pkg/guid/guid_nonwindows.go
generated
vendored
Normal file
16
vendor/github.com/tailscale/go-winio/pkg/guid/guid_nonwindows.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package guid
|
||||
|
||||
// GUID represents a GUID/UUID. It has the same structure as
|
||||
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
|
||||
// that type. It is defined as its own type as that is only available to builds
|
||||
// targeted at `windows`. The representation matches that used by native Windows
|
||||
// code.
|
||||
type GUID struct {
|
||||
Data1 uint32
|
||||
Data2 uint16
|
||||
Data3 uint16
|
||||
Data4 [8]byte
|
||||
}
|
||||
13
vendor/github.com/tailscale/go-winio/pkg/guid/guid_windows.go
generated
vendored
Normal file
13
vendor/github.com/tailscale/go-winio/pkg/guid/guid_windows.go
generated
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package guid
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
// GUID represents a GUID/UUID. It has the same structure as
|
||||
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
|
||||
// that type. It is defined as its own type so that stringification and
|
||||
// marshaling can be supported. The representation matches that used by native
|
||||
// Windows code.
|
||||
type GUID windows.GUID
|
||||
27
vendor/github.com/tailscale/go-winio/pkg/guid/variant_string.go
generated
vendored
Normal file
27
vendor/github.com/tailscale/go-winio/pkg/guid/variant_string.go
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
// Code generated by "stringer -type=Variant -trimprefix=Variant -linecomment"; DO NOT EDIT.
|
||||
|
||||
package guid
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[VariantUnknown-0]
|
||||
_ = x[VariantNCS-1]
|
||||
_ = x[VariantRFC4122-2]
|
||||
_ = x[VariantMicrosoft-3]
|
||||
_ = x[VariantFuture-4]
|
||||
}
|
||||
|
||||
const _Variant_name = "UnknownNCSRFC 4122MicrosoftFuture"
|
||||
|
||||
var _Variant_index = [...]uint8{0, 7, 10, 18, 27, 33}
|
||||
|
||||
func (i Variant) String() string {
|
||||
if i >= Variant(len(_Variant_index)-1) {
|
||||
return "Variant(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _Variant_name[_Variant_index[i]:_Variant_index[i+1]]
|
||||
}
|
||||
196
vendor/github.com/tailscale/go-winio/privilege.go
generated
vendored
Normal file
196
vendor/github.com/tailscale/go-winio/privilege.go
generated
vendored
Normal file
@@ -0,0 +1,196 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"unicode/utf16"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges
|
||||
//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf
|
||||
//sys revertToSelf() (err error) = advapi32.RevertToSelf
|
||||
//sys openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken
|
||||
//sys getCurrentThread() (h windows.Handle) = GetCurrentThread
|
||||
//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW
|
||||
//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW
|
||||
//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW
|
||||
|
||||
const (
|
||||
//revive:disable-next-line:var-naming ALL_CAPS
|
||||
SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
|
||||
|
||||
//revive:disable-next-line:var-naming ALL_CAPS
|
||||
ERROR_NOT_ALL_ASSIGNED windows.Errno = windows.ERROR_NOT_ALL_ASSIGNED
|
||||
|
||||
SeBackupPrivilege = "SeBackupPrivilege"
|
||||
SeRestorePrivilege = "SeRestorePrivilege"
|
||||
SeSecurityPrivilege = "SeSecurityPrivilege"
|
||||
)
|
||||
|
||||
var (
|
||||
privNames = make(map[string]uint64)
|
||||
privNameMutex sync.Mutex
|
||||
)
|
||||
|
||||
// PrivilegeError represents an error enabling privileges.
|
||||
type PrivilegeError struct {
|
||||
privileges []uint64
|
||||
}
|
||||
|
||||
func (e *PrivilegeError) Error() string {
|
||||
s := "Could not enable privilege "
|
||||
if len(e.privileges) > 1 {
|
||||
s = "Could not enable privileges "
|
||||
}
|
||||
for i, p := range e.privileges {
|
||||
if i != 0 {
|
||||
s += ", "
|
||||
}
|
||||
s += `"`
|
||||
s += getPrivilegeName(p)
|
||||
s += `"`
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// RunWithPrivilege enables a single privilege for a function call.
|
||||
func RunWithPrivilege(name string, fn func() error) error {
|
||||
return RunWithPrivileges([]string{name}, fn)
|
||||
}
|
||||
|
||||
// RunWithPrivileges enables privileges for a function call.
|
||||
func RunWithPrivileges(names []string, fn func() error) error {
|
||||
privileges, err := mapPrivileges(names)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
token, err := newThreadToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseThreadToken(token)
|
||||
err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fn()
|
||||
}
|
||||
|
||||
func mapPrivileges(names []string) ([]uint64, error) {
|
||||
privileges := make([]uint64, 0, len(names))
|
||||
privNameMutex.Lock()
|
||||
defer privNameMutex.Unlock()
|
||||
for _, name := range names {
|
||||
p, ok := privNames[name]
|
||||
if !ok {
|
||||
err := lookupPrivilegeValue("", name, &p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
privNames[name] = p
|
||||
}
|
||||
privileges = append(privileges, p)
|
||||
}
|
||||
return privileges, nil
|
||||
}
|
||||
|
||||
// EnableProcessPrivileges enables privileges globally for the process.
|
||||
func EnableProcessPrivileges(names []string) error {
|
||||
return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
|
||||
}
|
||||
|
||||
// DisableProcessPrivileges disables privileges globally for the process.
|
||||
func DisableProcessPrivileges(names []string) error {
|
||||
return enableDisableProcessPrivilege(names, 0)
|
||||
}
|
||||
|
||||
func enableDisableProcessPrivilege(names []string, action uint32) error {
|
||||
privileges, err := mapPrivileges(names)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := windows.CurrentProcess()
|
||||
var token windows.Token
|
||||
err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer token.Close()
|
||||
return adjustPrivileges(token, privileges, action)
|
||||
}
|
||||
|
||||
func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
|
||||
var b bytes.Buffer
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
|
||||
for _, p := range privileges {
|
||||
_ = binary.Write(&b, binary.LittleEndian, p)
|
||||
_ = binary.Write(&b, binary.LittleEndian, action)
|
||||
}
|
||||
prevState := make([]byte, b.Len())
|
||||
reqSize := uint32(0)
|
||||
success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize)
|
||||
if !success {
|
||||
return err
|
||||
}
|
||||
if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno
|
||||
return &PrivilegeError{privileges}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPrivilegeName(luid uint64) string {
|
||||
var nameBuffer [256]uint16
|
||||
bufSize := uint32(len(nameBuffer))
|
||||
err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("<unknown privilege %d>", luid)
|
||||
}
|
||||
|
||||
var displayNameBuffer [256]uint16
|
||||
displayBufSize := uint32(len(displayNameBuffer))
|
||||
var langID uint32
|
||||
err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("<unknown privilege %s>", string(utf16.Decode(nameBuffer[:bufSize])))
|
||||
}
|
||||
|
||||
return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
|
||||
}
|
||||
|
||||
func newThreadToken() (windows.Token, error) {
|
||||
err := impersonateSelf(windows.SecurityImpersonation)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var token windows.Token
|
||||
err = openThreadToken(getCurrentThread(), windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, false, &token)
|
||||
if err != nil {
|
||||
rerr := revertToSelf()
|
||||
if rerr != nil {
|
||||
panic(rerr)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func releaseThreadToken(h windows.Token) {
|
||||
err := revertToSelf()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
h.Close()
|
||||
}
|
||||
131
vendor/github.com/tailscale/go-winio/reparse.go
generated
vendored
Normal file
131
vendor/github.com/tailscale/go-winio/reparse.go
generated
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf16"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
reparseTagMountPoint = 0xA0000003
|
||||
reparseTagSymlink = 0xA000000C
|
||||
)
|
||||
|
||||
type reparseDataBuffer struct {
|
||||
ReparseTag uint32
|
||||
ReparseDataLength uint16
|
||||
Reserved uint16
|
||||
SubstituteNameOffset uint16
|
||||
SubstituteNameLength uint16
|
||||
PrintNameOffset uint16
|
||||
PrintNameLength uint16
|
||||
}
|
||||
|
||||
// ReparsePoint describes a Win32 symlink or mount point.
|
||||
type ReparsePoint struct {
|
||||
Target string
|
||||
IsMountPoint bool
|
||||
}
|
||||
|
||||
// UnsupportedReparsePointError is returned when trying to decode a non-symlink or
|
||||
// mount point reparse point.
|
||||
type UnsupportedReparsePointError struct {
|
||||
Tag uint32
|
||||
}
|
||||
|
||||
func (e *UnsupportedReparsePointError) Error() string {
|
||||
return fmt.Sprintf("unsupported reparse point %x", e.Tag)
|
||||
}
|
||||
|
||||
// DecodeReparsePoint decodes a Win32 REPARSE_DATA_BUFFER structure containing either a symlink
|
||||
// or a mount point.
|
||||
func DecodeReparsePoint(b []byte) (*ReparsePoint, error) {
|
||||
tag := binary.LittleEndian.Uint32(b[0:4])
|
||||
return DecodeReparsePointData(tag, b[8:])
|
||||
}
|
||||
|
||||
func DecodeReparsePointData(tag uint32, b []byte) (*ReparsePoint, error) {
|
||||
isMountPoint := false
|
||||
switch tag {
|
||||
case reparseTagMountPoint:
|
||||
isMountPoint = true
|
||||
case reparseTagSymlink:
|
||||
default:
|
||||
return nil, &UnsupportedReparsePointError{tag}
|
||||
}
|
||||
nameOffset := 8 + binary.LittleEndian.Uint16(b[4:6])
|
||||
if !isMountPoint {
|
||||
nameOffset += 4
|
||||
}
|
||||
nameLength := binary.LittleEndian.Uint16(b[6:8])
|
||||
name := make([]uint16, nameLength/2)
|
||||
err := binary.Read(bytes.NewReader(b[nameOffset:nameOffset+nameLength]), binary.LittleEndian, &name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ReparsePoint{string(utf16.Decode(name)), isMountPoint}, nil
|
||||
}
|
||||
|
||||
func isDriveLetter(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
|
||||
}
|
||||
|
||||
// EncodeReparsePoint encodes a Win32 REPARSE_DATA_BUFFER structure describing a symlink or
|
||||
// mount point.
|
||||
func EncodeReparsePoint(rp *ReparsePoint) []byte {
|
||||
// Generate an NT path and determine if this is a relative path.
|
||||
var ntTarget string
|
||||
relative := false
|
||||
if strings.HasPrefix(rp.Target, `\\?\`) {
|
||||
ntTarget = `\??\` + rp.Target[4:]
|
||||
} else if strings.HasPrefix(rp.Target, `\\`) {
|
||||
ntTarget = `\??\UNC\` + rp.Target[2:]
|
||||
} else if len(rp.Target) >= 2 && isDriveLetter(rp.Target[0]) && rp.Target[1] == ':' {
|
||||
ntTarget = `\??\` + rp.Target
|
||||
} else {
|
||||
ntTarget = rp.Target
|
||||
relative = true
|
||||
}
|
||||
|
||||
// The paths must be NUL-terminated even though they are counted strings.
|
||||
target16 := utf16.Encode([]rune(rp.Target + "\x00"))
|
||||
ntTarget16 := utf16.Encode([]rune(ntTarget + "\x00"))
|
||||
|
||||
size := int(unsafe.Sizeof(reparseDataBuffer{})) - 8
|
||||
size += len(ntTarget16)*2 + len(target16)*2
|
||||
|
||||
tag := uint32(reparseTagMountPoint)
|
||||
if !rp.IsMountPoint {
|
||||
tag = reparseTagSymlink
|
||||
size += 4 // Add room for symlink flags
|
||||
}
|
||||
|
||||
data := reparseDataBuffer{
|
||||
ReparseTag: tag,
|
||||
ReparseDataLength: uint16(size),
|
||||
SubstituteNameOffset: 0,
|
||||
SubstituteNameLength: uint16((len(ntTarget16) - 1) * 2),
|
||||
PrintNameOffset: uint16(len(ntTarget16) * 2),
|
||||
PrintNameLength: uint16((len(target16) - 1) * 2),
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
_ = binary.Write(&b, binary.LittleEndian, &data)
|
||||
if !rp.IsMountPoint {
|
||||
flags := uint32(0)
|
||||
if relative {
|
||||
flags |= 1
|
||||
}
|
||||
_ = binary.Write(&b, binary.LittleEndian, flags)
|
||||
}
|
||||
|
||||
_ = binary.Write(&b, binary.LittleEndian, ntTarget16)
|
||||
_ = binary.Write(&b, binary.LittleEndian, target16)
|
||||
return b.Bytes()
|
||||
}
|
||||
133
vendor/github.com/tailscale/go-winio/sd.go
generated
vendored
Normal file
133
vendor/github.com/tailscale/go-winio/sd.go
generated
vendored
Normal file
@@ -0,0 +1,133 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW
|
||||
//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW
|
||||
//sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW
|
||||
//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW
|
||||
|
||||
type AccountLookupError struct {
|
||||
Name string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *AccountLookupError) Error() string {
|
||||
if e.Name == "" {
|
||||
return "lookup account: empty account name specified"
|
||||
}
|
||||
var s string
|
||||
switch {
|
||||
case errors.Is(e.Err, windows.ERROR_INVALID_SID):
|
||||
s = "the security ID structure is invalid"
|
||||
case errors.Is(e.Err, windows.ERROR_NONE_MAPPED):
|
||||
s = "not found"
|
||||
default:
|
||||
s = e.Err.Error()
|
||||
}
|
||||
return "lookup account " + e.Name + ": " + s
|
||||
}
|
||||
|
||||
func (e *AccountLookupError) Unwrap() error { return e.Err }
|
||||
|
||||
type SddlConversionError struct {
|
||||
Sddl string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *SddlConversionError) Error() string {
|
||||
return "convert " + e.Sddl + ": " + e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *SddlConversionError) Unwrap() error { return e.Err }
|
||||
|
||||
// LookupSidByName looks up the SID of an account by name
|
||||
//
|
||||
//revive:disable-next-line:var-naming SID, not Sid
|
||||
func LookupSidByName(name string) (sid string, err error) {
|
||||
if name == "" {
|
||||
return "", &AccountLookupError{name, windows.ERROR_NONE_MAPPED}
|
||||
}
|
||||
|
||||
var sidSize, sidNameUse, refDomainSize uint32
|
||||
err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse)
|
||||
if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
|
||||
return "", &AccountLookupError{name, err}
|
||||
}
|
||||
sidBuffer := make([]byte, sidSize)
|
||||
refDomainBuffer := make([]uint16, refDomainSize)
|
||||
err = lookupAccountName(nil, name, &sidBuffer[0], &sidSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{name, err}
|
||||
}
|
||||
var strBuffer *uint16
|
||||
err = convertSidToStringSid(&sidBuffer[0], &strBuffer)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{name, err}
|
||||
}
|
||||
sid = windows.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(strBuffer))[:])
|
||||
_, _ = windows.LocalFree(windows.Handle(unsafe.Pointer(strBuffer)))
|
||||
return sid, nil
|
||||
}
|
||||
|
||||
// LookupNameBySid looks up the name of an account by SID
|
||||
//
|
||||
//revive:disable-next-line:var-naming SID, not Sid
|
||||
func LookupNameBySid(sid string) (name string, err error) {
|
||||
if sid == "" {
|
||||
return "", &AccountLookupError{sid, windows.ERROR_NONE_MAPPED}
|
||||
}
|
||||
|
||||
sidBuffer, err := windows.UTF16PtrFromString(sid)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
|
||||
var sidPtr *byte
|
||||
if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil {
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
defer windows.LocalFree(windows.Handle(unsafe.Pointer(sidPtr))) //nolint:errcheck
|
||||
|
||||
var nameSize, refDomainSize, sidNameUse uint32
|
||||
err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse)
|
||||
if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
|
||||
nameBuffer := make([]uint16, nameSize)
|
||||
refDomainBuffer := make([]uint16, refDomainSize)
|
||||
err = lookupAccountSid(nil, sidPtr, &nameBuffer[0], &nameSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
|
||||
name = windows.UTF16ToString(nameBuffer)
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
||||
sd, err := windows.SecurityDescriptorFromString(sddl)
|
||||
if err != nil {
|
||||
return nil, &SddlConversionError{Sddl: sddl, Err: err}
|
||||
}
|
||||
b := unsafe.Slice((*byte)(unsafe.Pointer(sd)), sd.Length())
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func SecurityDescriptorToSddl(sd []byte) (string, error) {
|
||||
if l := int(unsafe.Sizeof(windows.SECURITY_DESCRIPTOR{})); len(sd) < l {
|
||||
return "", fmt.Errorf("SecurityDescriptor (%d) smaller than expected (%d): %w", len(sd), l, windows.ERROR_INCORRECT_SIZE)
|
||||
}
|
||||
s := (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(&sd[0]))
|
||||
return s.String(), nil
|
||||
}
|
||||
5
vendor/github.com/tailscale/go-winio/syscall.go
generated
vendored
Normal file
5
vendor/github.com/tailscale/go-winio/syscall.go
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build windows
|
||||
|
||||
package winio
|
||||
|
||||
//go:generate go run github.com/tailscale/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go
|
||||
381
vendor/github.com/tailscale/go-winio/zsyscall_windows.go
generated
vendored
Normal file
381
vendor/github.com/tailscale/go-winio/zsyscall_windows.go
generated
vendored
Normal file
@@ -0,0 +1,381 @@
|
||||
//go:build windows
|
||||
|
||||
// Code generated by 'go generate' using "github.com/tailscale/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
|
||||
procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW")
|
||||
procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW")
|
||||
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
|
||||
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
|
||||
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
|
||||
procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
|
||||
procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW")
|
||||
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
|
||||
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
|
||||
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
|
||||
procBackupRead = modkernel32.NewProc("BackupRead")
|
||||
procBackupWrite = modkernel32.NewProc("BackupWrite")
|
||||
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
||||
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
||||
procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe")
|
||||
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
|
||||
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
||||
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
||||
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
||||
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
||||
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
||||
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
||||
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
||||
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
||||
)
|
||||
|
||||
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
|
||||
var _p0 uint32
|
||||
if releaseAll {
|
||||
_p0 = 1
|
||||
}
|
||||
r0, _, e1 := syscall.Syscall6(procAdjustTokenPrivileges.Addr(), 6, uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
|
||||
success = r0 != 0
|
||||
if true {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertSidToStringSid(sid *byte, str **uint16) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procConvertSidToStringSidW.Addr(), 2, uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertStringSidToSid(str *uint16, sid **byte) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procConvertStringSidToSidW.Addr(), 2, uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func impersonateSelf(level uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procImpersonateSelf.Addr(), 1, uintptr(level), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(accountName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupAccountName(systemName, _p0, sid, sidSize, refDomain, refDomainSize, sidNameUse)
|
||||
}
|
||||
|
||||
func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall9(procLookupAccountNameW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall9(procLookupAccountSidW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupPrivilegeDisplayName(_p0, name, buffer, size, languageId)
|
||||
}
|
||||
|
||||
func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeDisplayNameW.Addr(), 5, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupPrivilegeName(_p0, luid, buffer, size)
|
||||
}
|
||||
|
||||
func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeNameW.Addr(), 4, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var _p1 *uint16
|
||||
_p1, err = syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupPrivilegeValue(_p0, _p1, luid)
|
||||
}
|
||||
|
||||
func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procLookupPrivilegeValueW.Addr(), 3, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid)))
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) {
|
||||
var _p0 uint32
|
||||
if openAsSelf {
|
||||
_p0 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall6(procOpenThreadToken.Addr(), 4, uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func revertToSelf() (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procRevertToSelf.Addr(), 0, 0, 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
|
||||
var _p0 *byte
|
||||
if len(b) > 0 {
|
||||
_p0 = &b[0]
|
||||
}
|
||||
var _p1 uint32
|
||||
if abort {
|
||||
_p1 = 1
|
||||
}
|
||||
var _p2 uint32
|
||||
if processSecurity {
|
||||
_p2 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall9(procBackupRead.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
|
||||
var _p0 *byte
|
||||
if len(b) > 0 {
|
||||
_p0 = &b[0]
|
||||
}
|
||||
var _p1 uint32
|
||||
if abort {
|
||||
_p1 = 1
|
||||
}
|
||||
var _p2 uint32
|
||||
if processSecurity {
|
||||
_p2 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall9(procBackupWrite.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
||||
newport = windows.Handle(r0)
|
||||
if newport == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
||||
}
|
||||
|
||||
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
||||
handle = windows.Handle(r0)
|
||||
if handle == windows.InvalidHandle {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func disconnectNamedPipe(pipe windows.Handle) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procDisconnectNamedPipe.Addr(), 1, uintptr(pipe), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getCurrentThread() (h windows.Handle) {
|
||||
r0, _, _ := syscall.Syscall(procGetCurrentThread.Addr(), 0, 0, 0, 0)
|
||||
h = windows.Handle(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) {
|
||||
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
||||
status = ntStatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) {
|
||||
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
|
||||
status = ntStatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) {
|
||||
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
|
||||
status = ntStatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlNtStatusToDosError(status ntStatus) (winerr error) {
|
||||
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
|
||||
if r0 != 0 {
|
||||
winerr = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
||||
var _p0 uint32
|
||||
if wait {
|
||||
_p0 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
27
vendor/github.com/tailscale/golang-x-crypto/LICENSE
generated
vendored
Normal file
27
vendor/github.com/tailscale/golang-x-crypto/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
22
vendor/github.com/tailscale/golang-x-crypto/PATENTS
generated
vendored
Normal file
22
vendor/github.com/tailscale/golang-x-crypto/PATENTS
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
Additional IP Rights Grant (Patents)
|
||||
|
||||
"This implementation" means the copyrightable works distributed by
|
||||
Google as part of the Go project.
|
||||
|
||||
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section)
|
||||
patent license to make, have made, use, offer to sell, sell, import,
|
||||
transfer and otherwise run, modify and propagate the contents of this
|
||||
implementation of Go, where such license applies only to those patent
|
||||
claims, both currently owned or controlled by Google and acquired in
|
||||
the future, licensable by Google that are necessarily infringed by this
|
||||
implementation of Go. This grant does not include claims that would be
|
||||
infringed only as a consequence of further modification of this
|
||||
implementation. If you or your agent or exclusive licensee institute or
|
||||
order or agree to the institution of patent litigation against any
|
||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
||||
that this implementation of Go or any code incorporated within this
|
||||
implementation of Go constitutes direct or contributory patent
|
||||
infringement, or inducement of patent infringement, then any patent
|
||||
rights granted to you under this License for this implementation of Go
|
||||
shall terminate as of the date such litigation is filed.
|
||||
861
vendor/github.com/tailscale/golang-x-crypto/acme/acme.go
generated
vendored
Normal file
861
vendor/github.com/tailscale/golang-x-crypto/acme/acme.go
generated
vendored
Normal file
@@ -0,0 +1,861 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package acme provides an implementation of the
|
||||
// Automatic Certificate Management Environment (ACME) spec,
|
||||
// most famously used by Let's Encrypt.
|
||||
//
|
||||
// The initial implementation of this package was based on an early version
|
||||
// of the spec. The current implementation supports only the modern
|
||||
// RFC 8555 but some of the old API surface remains for compatibility.
|
||||
// While code using the old API will still compile, it will return an error.
|
||||
// Note the deprecation comments to update your code.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc8555 for the spec.
|
||||
//
|
||||
// Most common scenarios will want to use autocert subdirectory instead,
|
||||
// which provides automatic access to certificates from Let's Encrypt
|
||||
// and any other ACME-based CA.
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
|
||||
LetsEncryptURL = "https://acme-v02.api.letsencrypt.org/directory"
|
||||
|
||||
// ALPNProto is the ALPN protocol name used by a CA server when validating
|
||||
// tls-alpn-01 challenges.
|
||||
//
|
||||
// Package users must ensure their servers can negotiate the ACME ALPN in
|
||||
// order for tls-alpn-01 challenge verifications to succeed.
|
||||
// See the crypto/tls package's Config.NextProtos field.
|
||||
ALPNProto = "acme-tls/1"
|
||||
)
|
||||
|
||||
// idPeACMEIdentifier is the OID for the ACME extension for the TLS-ALPN challenge.
|
||||
// https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-05#section-5.1
|
||||
var idPeACMEIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
|
||||
|
||||
const (
|
||||
maxChainLen = 5 // max depth and breadth of a certificate chain
|
||||
maxCertSize = 1 << 20 // max size of a certificate, in DER bytes
|
||||
// Used for decoding certs from application/pem-certificate-chain response,
|
||||
// the default when in RFC mode.
|
||||
maxCertChainSize = maxCertSize * maxChainLen
|
||||
|
||||
// Max number of collected nonces kept in memory.
|
||||
// Expect usual peak of 1 or 2.
|
||||
maxNonces = 100
|
||||
)
|
||||
|
||||
// Client is an ACME client.
|
||||
//
|
||||
// The only required field is Key. An example of creating a client with a new key
|
||||
// is as follows:
|
||||
//
|
||||
// key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// client := &Client{Key: key}
|
||||
type Client struct {
|
||||
// Key is the account key used to register with a CA and sign requests.
|
||||
// Key.Public() must return a *rsa.PublicKey or *ecdsa.PublicKey.
|
||||
//
|
||||
// The following algorithms are supported:
|
||||
// RS256, ES256, ES384 and ES512.
|
||||
// See RFC 7518 for more details about the algorithms.
|
||||
Key crypto.Signer
|
||||
|
||||
// HTTPClient optionally specifies an HTTP client to use
|
||||
// instead of http.DefaultClient.
|
||||
HTTPClient *http.Client
|
||||
|
||||
// DirectoryURL points to the CA directory endpoint.
|
||||
// If empty, LetsEncryptURL is used.
|
||||
// Mutating this value after a successful call of Client's Discover method
|
||||
// will have no effect.
|
||||
DirectoryURL string
|
||||
|
||||
// RetryBackoff computes the duration after which the nth retry of a failed request
|
||||
// should occur. The value of n for the first call on failure is 1.
|
||||
// The values of r and resp are the request and response of the last failed attempt.
|
||||
// If the returned value is negative or zero, no more retries are done and an error
|
||||
// is returned to the caller of the original method.
|
||||
//
|
||||
// Requests which result in a 4xx client error are not retried,
|
||||
// except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests.
|
||||
//
|
||||
// If RetryBackoff is nil, a truncated exponential backoff algorithm
|
||||
// with the ceiling of 10 seconds is used, where each subsequent retry n
|
||||
// is done after either ("Retry-After" + jitter) or (2^n seconds + jitter),
|
||||
// preferring the former if "Retry-After" header is found in the resp.
|
||||
// The jitter is a random value up to 1 second.
|
||||
RetryBackoff func(n int, r *http.Request, resp *http.Response) time.Duration
|
||||
|
||||
// UserAgent is prepended to the User-Agent header sent to the ACME server,
|
||||
// which by default is this package's name and version.
|
||||
//
|
||||
// Reusable libraries and tools in particular should set this value to be
|
||||
// identifiable by the server, in case they are causing issues.
|
||||
UserAgent string
|
||||
|
||||
cacheMu sync.Mutex
|
||||
dir *Directory // cached result of Client's Discover method
|
||||
// KID is the key identifier provided by the CA. If not provided it will be
|
||||
// retrieved from the CA by making a call to the registration endpoint.
|
||||
KID KeyID
|
||||
|
||||
noncesMu sync.Mutex
|
||||
nonces map[string]struct{} // nonces collected from previous responses
|
||||
}
|
||||
|
||||
// accountKID returns a key ID associated with c.Key, the account identity
|
||||
// provided by the CA during RFC based registration.
|
||||
// It assumes c.Discover has already been called.
|
||||
//
|
||||
// accountKID requires at most one network roundtrip.
|
||||
// It caches only successful result.
|
||||
//
|
||||
// When in pre-RFC mode or when c.getRegRFC responds with an error, accountKID
|
||||
// returns noKeyID.
|
||||
func (c *Client) accountKID(ctx context.Context) KeyID {
|
||||
c.cacheMu.Lock()
|
||||
defer c.cacheMu.Unlock()
|
||||
if c.KID != noKeyID {
|
||||
return c.KID
|
||||
}
|
||||
a, err := c.getRegRFC(ctx)
|
||||
if err != nil {
|
||||
return noKeyID
|
||||
}
|
||||
c.KID = KeyID(a.URI)
|
||||
return c.KID
|
||||
}
|
||||
|
||||
var errPreRFC = errors.New("acme: server does not support the RFC 8555 version of ACME")
|
||||
|
||||
// Discover performs ACME server discovery using c.DirectoryURL.
|
||||
//
|
||||
// It caches successful result. So, subsequent calls will not result in
|
||||
// a network round-trip. This also means mutating c.DirectoryURL after successful call
|
||||
// of this method will have no effect.
|
||||
func (c *Client) Discover(ctx context.Context) (Directory, error) {
|
||||
c.cacheMu.Lock()
|
||||
defer c.cacheMu.Unlock()
|
||||
if c.dir != nil {
|
||||
return *c.dir, nil
|
||||
}
|
||||
|
||||
res, err := c.get(ctx, c.directoryURL(), wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return Directory{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
c.addNonce(res.Header)
|
||||
|
||||
var v struct {
|
||||
Reg string `json:"newAccount"`
|
||||
Authz string `json:"newAuthz"`
|
||||
Order string `json:"newOrder"`
|
||||
Revoke string `json:"revokeCert"`
|
||||
Nonce string `json:"newNonce"`
|
||||
KeyChange string `json:"keyChange"`
|
||||
RenewalInfo string `json:"renewalInfo"`
|
||||
Meta struct {
|
||||
Terms string `json:"termsOfService"`
|
||||
Website string `json:"website"`
|
||||
CAA []string `json:"caaIdentities"`
|
||||
ExternalAcct bool `json:"externalAccountRequired"`
|
||||
}
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return Directory{}, err
|
||||
}
|
||||
if v.Order == "" {
|
||||
return Directory{}, errPreRFC
|
||||
}
|
||||
c.dir = &Directory{
|
||||
RegURL: v.Reg,
|
||||
AuthzURL: v.Authz,
|
||||
OrderURL: v.Order,
|
||||
RevokeURL: v.Revoke,
|
||||
NonceURL: v.Nonce,
|
||||
KeyChangeURL: v.KeyChange,
|
||||
RenewalInfoURL: v.RenewalInfo,
|
||||
Terms: v.Meta.Terms,
|
||||
Website: v.Meta.Website,
|
||||
CAA: v.Meta.CAA,
|
||||
ExternalAccountRequired: v.Meta.ExternalAcct,
|
||||
}
|
||||
return *c.dir, nil
|
||||
}
|
||||
|
||||
func (c *Client) directoryURL() string {
|
||||
if c.DirectoryURL != "" {
|
||||
return c.DirectoryURL
|
||||
}
|
||||
return LetsEncryptURL
|
||||
}
|
||||
|
||||
// CreateCert was part of the old version of ACME. It is incompatible with RFC 8555.
|
||||
//
|
||||
// Deprecated: this was for the pre-RFC 8555 version of ACME. Callers should use CreateOrderCert.
|
||||
func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) {
|
||||
return nil, "", errPreRFC
|
||||
}
|
||||
|
||||
// FetchCert retrieves already issued certificate from the given url, in DER format.
|
||||
// It retries the request until the certificate is successfully retrieved,
|
||||
// context is cancelled by the caller or an error response is received.
|
||||
//
|
||||
// If the bundle argument is true, the returned value also contains the CA (issuer)
|
||||
// certificate chain.
|
||||
//
|
||||
// FetchCert returns an error if the CA's response or chain was unreasonably large.
|
||||
// Callers are encouraged to parse the returned value to ensure the certificate is valid
|
||||
// and has expected features.
|
||||
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.fetchCertRFC(ctx, url, bundle)
|
||||
}
|
||||
|
||||
// RevokeCert revokes a previously issued certificate cert, provided in DER format.
|
||||
//
|
||||
// The key argument, used to sign the request, must be authorized
|
||||
// to revoke the certificate. It's up to the CA to decide which keys are authorized.
|
||||
// For instance, the key pair of the certificate may be authorized.
|
||||
// If the key is nil, c.Key is used instead.
|
||||
func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.revokeCertRFC(ctx, key, cert, reason)
|
||||
}
|
||||
|
||||
// FetchRenewalInfo retrieves the RenewalInfo from Directory.RenewalInfoURL.
|
||||
func (c *Client) FetchRenewalInfo(ctx context.Context, leaf []byte) (*RenewalInfo, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedLeaf, err := x509.ParseCertificate(leaf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing leaf certificate: %w", err)
|
||||
}
|
||||
|
||||
renewalURL, err := c.getRenewalURL(parsedLeaf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating renewal info URL: %w", err)
|
||||
}
|
||||
|
||||
res, err := c.get(ctx, renewalURL, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching renewal info: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var info RenewalInfo
|
||||
if err := json.NewDecoder(res.Body).Decode(&info); err != nil {
|
||||
return nil, fmt.Errorf("parsing renewal info response: %w", err)
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
func (c *Client) getRenewalURL(cert *x509.Certificate) (string, error) {
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-acme-ari-04.html#name-the-renewalinfo-resource
|
||||
// for how the request URL is built.
|
||||
url := c.dir.RenewalInfoURL
|
||||
if !strings.HasSuffix(url, "/") {
|
||||
url += "/"
|
||||
}
|
||||
aki := base64.RawURLEncoding.EncodeToString(cert.AuthorityKeyId)
|
||||
serial := base64.RawURLEncoding.EncodeToString(cert.SerialNumber.Bytes())
|
||||
return fmt.Sprintf("%s%s.%s", url, aki, serial), nil
|
||||
}
|
||||
|
||||
// AcceptTOS always returns true to indicate the acceptance of a CA's Terms of Service
|
||||
// during account registration. See Register method of Client for more details.
|
||||
func AcceptTOS(tosURL string) bool { return true }
|
||||
|
||||
// Register creates a new account with the CA using c.Key.
|
||||
// It returns the registered account. The account acct is not modified.
|
||||
//
|
||||
// The registration may require the caller to agree to the CA's Terms of Service (TOS).
|
||||
// If so, and the account has not indicated the acceptance of the terms (see Account for details),
|
||||
// Register calls prompt with a TOS URL provided by the CA. Prompt should report
|
||||
// whether the caller agrees to the terms. To always accept the terms, the caller can use AcceptTOS.
|
||||
//
|
||||
// When interfacing with an RFC-compliant CA, non-RFC 8555 fields of acct are ignored
|
||||
// and prompt is called if Directory's Terms field is non-zero.
|
||||
// Also see Error's Instance field for when a CA requires already registered accounts to agree
|
||||
// to an updated Terms of Service.
|
||||
func (c *Client) Register(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) {
|
||||
if c.Key == nil {
|
||||
return nil, errors.New("acme: client.Key must be set to Register")
|
||||
}
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.registerRFC(ctx, acct, prompt)
|
||||
}
|
||||
|
||||
// GetReg retrieves an existing account associated with c.Key.
|
||||
//
|
||||
// The url argument is a legacy artifact of the pre-RFC 8555 API
|
||||
// and is ignored.
|
||||
func (c *Client) GetReg(ctx context.Context, url string) (*Account, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.getRegRFC(ctx)
|
||||
}
|
||||
|
||||
// UpdateReg updates an existing registration.
|
||||
// It returns an updated account copy. The provided account is not modified.
|
||||
//
|
||||
// The account's URI is ignored and the account URL associated with
|
||||
// c.Key is used instead.
|
||||
func (c *Client) UpdateReg(ctx context.Context, acct *Account) (*Account, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.updateRegRFC(ctx, acct)
|
||||
}
|
||||
|
||||
// AccountKeyRollover attempts to transition a client's account key to a new key.
|
||||
// On success client's Key is updated which is not concurrency safe.
|
||||
// On failure an error will be returned.
|
||||
// The new key is already registered with the ACME provider if the following is true:
|
||||
// - error is of type acme.Error
|
||||
// - StatusCode should be 409 (Conflict)
|
||||
// - Location header will have the KID of the associated account
|
||||
//
|
||||
// More about account key rollover can be found at
|
||||
// https://tools.ietf.org/html/rfc8555#section-7.3.5.
|
||||
func (c *Client) AccountKeyRollover(ctx context.Context, newKey crypto.Signer) error {
|
||||
return c.accountKeyRollover(ctx, newKey)
|
||||
}
|
||||
|
||||
// Authorize performs the initial step in the pre-authorization flow,
|
||||
// as opposed to order-based flow.
|
||||
// The caller will then need to choose from and perform a set of returned
|
||||
// challenges using c.Accept in order to successfully complete authorization.
|
||||
//
|
||||
// Once complete, the caller can use AuthorizeOrder which the CA
|
||||
// should provision with the already satisfied authorization.
|
||||
// For pre-RFC CAs, the caller can proceed directly to requesting a certificate
|
||||
// using CreateCert method.
|
||||
//
|
||||
// If an authorization has been previously granted, the CA may return
|
||||
// a valid authorization which has its Status field set to StatusValid.
|
||||
//
|
||||
// More about pre-authorization can be found at
|
||||
// https://tools.ietf.org/html/rfc8555#section-7.4.1.
|
||||
func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) {
|
||||
return c.authorize(ctx, "dns", domain)
|
||||
}
|
||||
|
||||
// AuthorizeIP is the same as Authorize but requests IP address authorization.
|
||||
// Clients which successfully obtain such authorization may request to issue
|
||||
// a certificate for IP addresses.
|
||||
//
|
||||
// See the ACME spec extension for more details about IP address identifiers:
|
||||
// https://tools.ietf.org/html/draft-ietf-acme-ip.
|
||||
func (c *Client) AuthorizeIP(ctx context.Context, ipaddr string) (*Authorization, error) {
|
||||
return c.authorize(ctx, "ip", ipaddr)
|
||||
}
|
||||
|
||||
func (c *Client) authorize(ctx context.Context, typ, val string) (*Authorization, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type authzID struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
req := struct {
|
||||
Resource string `json:"resource"`
|
||||
Identifier authzID `json:"identifier"`
|
||||
}{
|
||||
Resource: "new-authz",
|
||||
Identifier: authzID{Type: typ, Value: val},
|
||||
}
|
||||
res, err := c.post(ctx, nil, c.dir.AuthzURL, req, wantStatus(http.StatusCreated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var v wireAuthz
|
||||
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return nil, fmt.Errorf("acme: invalid response: %v", err)
|
||||
}
|
||||
if v.Status != StatusPending && v.Status != StatusValid {
|
||||
return nil, fmt.Errorf("acme: unexpected status: %s", v.Status)
|
||||
}
|
||||
return v.authorization(res.Header.Get("Location")), nil
|
||||
}
|
||||
|
||||
// GetAuthorization retrieves an authorization identified by the given URL.
|
||||
//
|
||||
// If a caller needs to poll an authorization until its status is final,
|
||||
// see the WaitAuthorization method.
|
||||
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var v wireAuthz
|
||||
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return nil, fmt.Errorf("acme: invalid response: %v", err)
|
||||
}
|
||||
return v.authorization(url), nil
|
||||
}
|
||||
|
||||
// RevokeAuthorization relinquishes an existing authorization identified
|
||||
// by the given URL.
|
||||
// The url argument is an Authorization.URI value.
|
||||
//
|
||||
// If successful, the caller will be required to obtain a new authorization
|
||||
// using the Authorize or AuthorizeOrder methods before being able to request
|
||||
// a new certificate for the domain associated with the authorization.
|
||||
//
|
||||
// It does not revoke existing certificates.
|
||||
func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Resource string `json:"resource"`
|
||||
Status string `json:"status"`
|
||||
Delete bool `json:"delete"`
|
||||
}{
|
||||
Resource: "authz",
|
||||
Status: "deactivated",
|
||||
Delete: true,
|
||||
}
|
||||
res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// WaitAuthorization polls an authorization at the given URL
|
||||
// until it is in one of the final states, StatusValid or StatusInvalid,
|
||||
// the ACME CA responded with a 4xx error code, or the context is done.
|
||||
//
|
||||
// It returns a non-nil Authorization only if its Status is StatusValid.
|
||||
// In all other cases WaitAuthorization returns an error.
|
||||
// If the Status is StatusInvalid, the returned error is of type *AuthorizationError.
|
||||
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var raw wireAuthz
|
||||
err = json.NewDecoder(res.Body).Decode(&raw)
|
||||
res.Body.Close()
|
||||
switch {
|
||||
case err != nil:
|
||||
// Skip and retry.
|
||||
case raw.Status == StatusValid:
|
||||
return raw.authorization(url), nil
|
||||
case raw.Status == StatusInvalid:
|
||||
return nil, raw.error(url)
|
||||
}
|
||||
|
||||
// Exponential backoff is implemented in c.get above.
|
||||
// This is just to prevent continuously hitting the CA
|
||||
// while waiting for a final authorization status.
|
||||
d := retryAfter(res.Header.Get("Retry-After"))
|
||||
if d == 0 {
|
||||
// Given that the fastest challenges TLS-SNI and HTTP-01
|
||||
// require a CA to make at least 1 network round trip
|
||||
// and most likely persist a challenge state,
|
||||
// this default delay seems reasonable.
|
||||
d = time.Second
|
||||
}
|
||||
t := time.NewTimer(d)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return nil, ctx.Err()
|
||||
case <-t.C:
|
||||
// Retry.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetChallenge retrieves the current status of an challenge.
|
||||
//
|
||||
// A client typically polls a challenge status using this method.
|
||||
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
v := wireChallenge{URI: url}
|
||||
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return nil, fmt.Errorf("acme: invalid response: %v", err)
|
||||
}
|
||||
return v.challenge(), nil
|
||||
}
|
||||
|
||||
// Accept informs the server that the client accepts one of its challenges
|
||||
// previously obtained with c.Authorize.
|
||||
//
|
||||
// The server will then perform the validation asynchronously.
|
||||
func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := c.post(ctx, nil, chal.URI, json.RawMessage("{}"), wantStatus(
|
||||
http.StatusOK, // according to the spec
|
||||
http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md)
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var v wireChallenge
|
||||
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return nil, fmt.Errorf("acme: invalid response: %v", err)
|
||||
}
|
||||
return v.challenge(), nil
|
||||
}
|
||||
|
||||
// DNS01ChallengeRecord returns a DNS record value for a dns-01 challenge response.
|
||||
// A TXT record containing the returned value must be provisioned under
|
||||
// "_acme-challenge" name of the domain being validated.
|
||||
//
|
||||
// The token argument is a Challenge.Token value.
|
||||
func (c *Client) DNS01ChallengeRecord(token string) (string, error) {
|
||||
ka, err := keyAuth(c.Key.Public(), token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b := sha256.Sum256([]byte(ka))
|
||||
return base64.RawURLEncoding.EncodeToString(b[:]), nil
|
||||
}
|
||||
|
||||
// HTTP01ChallengeResponse returns the response for an http-01 challenge.
|
||||
// Servers should respond with the value to HTTP requests at the URL path
|
||||
// provided by HTTP01ChallengePath to validate the challenge and prove control
|
||||
// over a domain name.
|
||||
//
|
||||
// The token argument is a Challenge.Token value.
|
||||
func (c *Client) HTTP01ChallengeResponse(token string) (string, error) {
|
||||
return keyAuth(c.Key.Public(), token)
|
||||
}
|
||||
|
||||
// HTTP01ChallengePath returns the URL path at which the response for an http-01 challenge
|
||||
// should be provided by the servers.
|
||||
// The response value can be obtained with HTTP01ChallengeResponse.
|
||||
//
|
||||
// The token argument is a Challenge.Token value.
|
||||
func (c *Client) HTTP01ChallengePath(token string) string {
|
||||
return "/.well-known/acme-challenge/" + token
|
||||
}
|
||||
|
||||
// TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response.
|
||||
//
|
||||
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec.
|
||||
func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
|
||||
ka, err := keyAuth(c.Key.Public(), token)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, "", err
|
||||
}
|
||||
b := sha256.Sum256([]byte(ka))
|
||||
h := hex.EncodeToString(b[:])
|
||||
name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:])
|
||||
cert, err = tlsChallengeCert([]string{name}, opt)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, "", err
|
||||
}
|
||||
return cert, name, nil
|
||||
}
|
||||
|
||||
// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response.
|
||||
//
|
||||
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec.
|
||||
func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
|
||||
b := sha256.Sum256([]byte(token))
|
||||
h := hex.EncodeToString(b[:])
|
||||
sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:])
|
||||
|
||||
ka, err := keyAuth(c.Key.Public(), token)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, "", err
|
||||
}
|
||||
b = sha256.Sum256([]byte(ka))
|
||||
h = hex.EncodeToString(b[:])
|
||||
sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:])
|
||||
|
||||
cert, err = tlsChallengeCert([]string{sanA, sanB}, opt)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, "", err
|
||||
}
|
||||
return cert, sanA, nil
|
||||
}
|
||||
|
||||
// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response.
|
||||
// Servers can present the certificate to validate the challenge and prove control
|
||||
// over a domain name. For more details on TLS-ALPN-01 see
|
||||
// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3
|
||||
//
|
||||
// The token argument is a Challenge.Token value.
|
||||
// If a WithKey option is provided, its private part signs the returned cert,
|
||||
// and the public part is used to specify the signee.
|
||||
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
|
||||
//
|
||||
// The returned certificate is valid for the next 24 hours and must be presented only when
|
||||
// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol
|
||||
// has been specified.
|
||||
func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) {
|
||||
ka, err := keyAuth(c.Key.Public(), token)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
shasum := sha256.Sum256([]byte(ka))
|
||||
extValue, err := asn1.Marshal(shasum[:])
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
acmeExtension := pkix.Extension{
|
||||
Id: idPeACMEIdentifier,
|
||||
Critical: true,
|
||||
Value: extValue,
|
||||
}
|
||||
|
||||
tmpl := defaultTLSChallengeCertTemplate()
|
||||
|
||||
var newOpt []CertOption
|
||||
for _, o := range opt {
|
||||
switch o := o.(type) {
|
||||
case *certOptTemplate:
|
||||
t := *(*x509.Certificate)(o) // shallow copy is ok
|
||||
tmpl = &t
|
||||
default:
|
||||
newOpt = append(newOpt, o)
|
||||
}
|
||||
}
|
||||
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension)
|
||||
newOpt = append(newOpt, WithTemplate(tmpl))
|
||||
return tlsChallengeCert([]string{domain}, newOpt)
|
||||
}
|
||||
|
||||
// popNonce returns a nonce value previously stored with c.addNonce
|
||||
// or fetches a fresh one from c.dir.NonceURL.
|
||||
// If NonceURL is empty, it first tries c.directoryURL() and, failing that,
|
||||
// the provided url.
|
||||
func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
|
||||
c.noncesMu.Lock()
|
||||
defer c.noncesMu.Unlock()
|
||||
if len(c.nonces) == 0 {
|
||||
if c.dir != nil && c.dir.NonceURL != "" {
|
||||
return c.fetchNonce(ctx, c.dir.NonceURL)
|
||||
}
|
||||
dirURL := c.directoryURL()
|
||||
v, err := c.fetchNonce(ctx, dirURL)
|
||||
if err != nil && url != dirURL {
|
||||
v, err = c.fetchNonce(ctx, url)
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
var nonce string
|
||||
for nonce = range c.nonces {
|
||||
delete(c.nonces, nonce)
|
||||
break
|
||||
}
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
// clearNonces clears any stored nonces
|
||||
func (c *Client) clearNonces() {
|
||||
c.noncesMu.Lock()
|
||||
defer c.noncesMu.Unlock()
|
||||
c.nonces = make(map[string]struct{})
|
||||
}
|
||||
|
||||
// addNonce stores a nonce value found in h (if any) for future use.
|
||||
func (c *Client) addNonce(h http.Header) {
|
||||
v := nonceFromHeader(h)
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
c.noncesMu.Lock()
|
||||
defer c.noncesMu.Unlock()
|
||||
if len(c.nonces) >= maxNonces {
|
||||
return
|
||||
}
|
||||
if c.nonces == nil {
|
||||
c.nonces = make(map[string]struct{})
|
||||
}
|
||||
c.nonces[v] = struct{}{}
|
||||
}
|
||||
|
||||
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
|
||||
r, err := http.NewRequest("HEAD", url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := c.doNoRetry(ctx, r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
nonce := nonceFromHeader(resp.Header)
|
||||
if nonce == "" {
|
||||
if resp.StatusCode > 299 {
|
||||
return "", responseError(resp)
|
||||
}
|
||||
return "", errors.New("acme: nonce not found")
|
||||
}
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
func nonceFromHeader(h http.Header) string {
|
||||
return h.Get("Replay-Nonce")
|
||||
}
|
||||
|
||||
// linkHeader returns URI-Reference values of all Link headers
|
||||
// with relation-type rel.
|
||||
// See https://tools.ietf.org/html/rfc5988#section-5 for details.
|
||||
func linkHeader(h http.Header, rel string) []string {
|
||||
var links []string
|
||||
for _, v := range h["Link"] {
|
||||
parts := strings.Split(v, ";")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if !strings.HasPrefix(p, "rel=") {
|
||||
continue
|
||||
}
|
||||
if v := strings.Trim(p[4:], `"`); v == rel {
|
||||
links = append(links, strings.Trim(parts[0], "<>"))
|
||||
}
|
||||
}
|
||||
}
|
||||
return links
|
||||
}
|
||||
|
||||
// keyAuth generates a key authorization string for a given token.
|
||||
func keyAuth(pub crypto.PublicKey, token string) (string, error) {
|
||||
th, err := JWKThumbprint(pub)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", token, th), nil
|
||||
}
|
||||
|
||||
// defaultTLSChallengeCertTemplate is a template used to create challenge certs for TLS challenges.
|
||||
func defaultTLSChallengeCertTemplate() *x509.Certificate {
|
||||
return &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
}
|
||||
|
||||
// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges
|
||||
// with the given SANs and auto-generated public/private key pair.
|
||||
// The Subject Common Name is set to the first SAN to aid debugging.
|
||||
// To create a cert with a custom key pair, specify WithKey option.
|
||||
func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
|
||||
var key crypto.Signer
|
||||
tmpl := defaultTLSChallengeCertTemplate()
|
||||
for _, o := range opt {
|
||||
switch o := o.(type) {
|
||||
case *certOptKey:
|
||||
if key != nil {
|
||||
return tls.Certificate{}, errors.New("acme: duplicate key option")
|
||||
}
|
||||
key = o.key
|
||||
case *certOptTemplate:
|
||||
t := *(*x509.Certificate)(o) // shallow copy is ok
|
||||
tmpl = &t
|
||||
default:
|
||||
// package's fault, if we let this happen:
|
||||
panic(fmt.Sprintf("unsupported option type %T", o))
|
||||
}
|
||||
}
|
||||
if key == nil {
|
||||
var err error
|
||||
if key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader); err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
}
|
||||
tmpl.DNSNames = san
|
||||
if len(san) > 0 {
|
||||
tmpl.Subject.CommonName = san[0]
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
return tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodePEM returns b encoded as PEM with block of type typ.
|
||||
func encodePEM(typ string, b []byte) []byte {
|
||||
pb := &pem.Block{Type: typ, Bytes: b}
|
||||
return pem.EncodeToMemory(pb)
|
||||
}
|
||||
|
||||
// timeNow is time.Now, except in tests which can mess with it.
|
||||
var timeNow = time.Now
|
||||
325
vendor/github.com/tailscale/golang-x-crypto/acme/http.go
generated
vendored
Normal file
325
vendor/github.com/tailscale/golang-x-crypto/acme/http.go
generated
vendored
Normal file
@@ -0,0 +1,325 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// retryTimer encapsulates common logic for retrying unsuccessful requests.
|
||||
// It is not safe for concurrent use.
|
||||
type retryTimer struct {
|
||||
// backoffFn provides backoff delay sequence for retries.
|
||||
// See Client.RetryBackoff doc comment.
|
||||
backoffFn func(n int, r *http.Request, res *http.Response) time.Duration
|
||||
// n is the current retry attempt.
|
||||
n int
|
||||
}
|
||||
|
||||
func (t *retryTimer) inc() {
|
||||
t.n++
|
||||
}
|
||||
|
||||
// backoff pauses the current goroutine as described in Client.RetryBackoff.
|
||||
func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error {
|
||||
d := t.backoffFn(t.n, r, res)
|
||||
if d <= 0 {
|
||||
return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n)
|
||||
}
|
||||
wakeup := time.NewTimer(d)
|
||||
defer wakeup.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-wakeup.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) retryTimer() *retryTimer {
|
||||
f := c.RetryBackoff
|
||||
if f == nil {
|
||||
f = defaultBackoff
|
||||
}
|
||||
return &retryTimer{backoffFn: f}
|
||||
}
|
||||
|
||||
// defaultBackoff provides default Client.RetryBackoff implementation
|
||||
// using a truncated exponential backoff algorithm,
|
||||
// as described in Client.RetryBackoff.
|
||||
//
|
||||
// The n argument is always bounded between 1 and 30.
|
||||
// The returned value is always greater than 0.
|
||||
func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration {
|
||||
const max = 10 * time.Second
|
||||
var jitter time.Duration
|
||||
if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
|
||||
// Set the minimum to 1ms to avoid a case where
|
||||
// an invalid Retry-After value is parsed into 0 below,
|
||||
// resulting in the 0 returned value which would unintentionally
|
||||
// stop the retries.
|
||||
jitter = (1 + time.Duration(x.Int64())) * time.Millisecond
|
||||
}
|
||||
if v, ok := res.Header["Retry-After"]; ok {
|
||||
return retryAfter(v[0]) + jitter
|
||||
}
|
||||
|
||||
if n < 1 {
|
||||
n = 1
|
||||
}
|
||||
if n > 30 {
|
||||
n = 30
|
||||
}
|
||||
d := time.Duration(1<<uint(n-1))*time.Second + jitter
|
||||
if d > max {
|
||||
return max
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// retryAfter parses a Retry-After HTTP header value,
|
||||
// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
|
||||
// It returns zero value if v cannot be parsed.
|
||||
func retryAfter(v string) time.Duration {
|
||||
if i, err := strconv.Atoi(v); err == nil {
|
||||
return time.Duration(i) * time.Second
|
||||
}
|
||||
t, err := http.ParseTime(v)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return t.Sub(timeNow())
|
||||
}
|
||||
|
||||
// resOkay is a function that reports whether the provided response is okay.
|
||||
// It is expected to keep the response body unread.
|
||||
type resOkay func(*http.Response) bool
|
||||
|
||||
// wantStatus returns a function which reports whether the code
|
||||
// matches the status code of a response.
|
||||
func wantStatus(codes ...int) resOkay {
|
||||
return func(res *http.Response) bool {
|
||||
for _, code := range codes {
|
||||
if code == res.StatusCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// get issues an unsigned GET request to the specified URL.
|
||||
// It returns a non-error value only when ok reports true.
|
||||
//
|
||||
// get retries unsuccessful attempts according to c.RetryBackoff
|
||||
// until the context is done or a non-retriable error is received.
|
||||
func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) {
|
||||
retry := c.retryTimer()
|
||||
for {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := c.doNoRetry(ctx, req)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, err
|
||||
case ok(res):
|
||||
return res, nil
|
||||
case isRetriable(res.StatusCode):
|
||||
retry.inc()
|
||||
resErr := responseError(res)
|
||||
res.Body.Close()
|
||||
// Ignore the error value from retry.backoff
|
||||
// and return the one from last retry, as received from the CA.
|
||||
if retry.backoff(ctx, req, res) != nil {
|
||||
return nil, resErr
|
||||
}
|
||||
default:
|
||||
defer res.Body.Close()
|
||||
return nil, responseError(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// postAsGet is POST-as-GET, a replacement for GET in RFC 8555
|
||||
// as described in https://tools.ietf.org/html/rfc8555#section-6.3.
|
||||
// It makes a POST request in KID form with zero JWS payload.
|
||||
// See nopayload doc comments in jws.go.
|
||||
func (c *Client) postAsGet(ctx context.Context, url string, ok resOkay) (*http.Response, error) {
|
||||
return c.post(ctx, nil, url, noPayload, ok)
|
||||
}
|
||||
|
||||
// post issues a signed POST request in JWS format using the provided key
|
||||
// to the specified URL. If key is nil, c.Key is used instead.
|
||||
// It returns a non-error value only when ok reports true.
|
||||
//
|
||||
// post retries unsuccessful attempts according to c.RetryBackoff
|
||||
// until the context is done or a non-retriable error is received.
|
||||
// It uses postNoRetry to make individual requests.
|
||||
func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) {
|
||||
retry := c.retryTimer()
|
||||
for {
|
||||
res, req, err := c.postNoRetry(ctx, key, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok(res) {
|
||||
return res, nil
|
||||
}
|
||||
resErr := responseError(res)
|
||||
res.Body.Close()
|
||||
switch {
|
||||
// Check for bad nonce before isRetriable because it may have been returned
|
||||
// with an unretriable response code such as 400 Bad Request.
|
||||
case isBadNonce(resErr):
|
||||
// Consider any previously stored nonce values to be invalid.
|
||||
c.clearNonces()
|
||||
case !isRetriable(res.StatusCode):
|
||||
return nil, resErr
|
||||
}
|
||||
retry.inc()
|
||||
// Ignore the error value from retry.backoff
|
||||
// and return the one from last retry, as received from the CA.
|
||||
if err := retry.backoff(ctx, req, res); err != nil {
|
||||
return nil, resErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// postNoRetry signs the body with the given key and POSTs it to the provided url.
|
||||
// It is used by c.post to retry unsuccessful attempts.
|
||||
// The body argument must be JSON-serializable.
|
||||
//
|
||||
// If key argument is nil, c.Key is used to sign the request.
|
||||
// If key argument is nil and c.accountKID returns a non-zero keyID,
|
||||
// the request is sent in KID form. Otherwise, JWK form is used.
|
||||
//
|
||||
// In practice, when interfacing with RFC-compliant CAs most requests are sent in KID form
|
||||
// and JWK is used only when KID is unavailable: new account endpoint and certificate
|
||||
// revocation requests authenticated by a cert key.
|
||||
// See jwsEncodeJSON for other details.
|
||||
func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, *http.Request, error) {
|
||||
kid := noKeyID
|
||||
if key == nil {
|
||||
if c.Key == nil {
|
||||
return nil, nil, errors.New("acme: Client.Key must be populated to make POST requests")
|
||||
}
|
||||
key = c.Key
|
||||
kid = c.accountKID(ctx)
|
||||
}
|
||||
nonce, err := c.popNonce(ctx, url)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
b, err := jwsEncodeJSON(body, key, kid, nonce, url)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/jose+json")
|
||||
res, err := c.doNoRetry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
c.addNonce(res.Header)
|
||||
return res, req, nil
|
||||
}
|
||||
|
||||
// doNoRetry issues a request req, replacing its context (if any) with ctx.
|
||||
func (c *Client) doNoRetry(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("User-Agent", c.userAgent())
|
||||
res, err := c.httpClient().Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Prefer the unadorned context error.
|
||||
// (The acme package had tests assuming this, previously from ctxhttp's
|
||||
// behavior, predating net/http supporting contexts natively)
|
||||
// TODO(bradfitz): reconsider this in the future. But for now this
|
||||
// requires no test updates.
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *Client) httpClient() *http.Client {
|
||||
if c.HTTPClient != nil {
|
||||
return c.HTTPClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// packageVersion is the version of the module that contains this package, for
|
||||
// sending as part of the User-Agent header. It's set in version_go112.go.
|
||||
var packageVersion string
|
||||
|
||||
// userAgent returns the User-Agent header value. It includes the package name,
|
||||
// the module version (if available), and the c.UserAgent value (if set).
|
||||
func (c *Client) userAgent() string {
|
||||
ua := "golang.org/x/crypto/acme"
|
||||
if packageVersion != "" {
|
||||
ua += "@" + packageVersion
|
||||
}
|
||||
if c.UserAgent != "" {
|
||||
ua = c.UserAgent + " " + ua
|
||||
}
|
||||
return ua
|
||||
}
|
||||
|
||||
// isBadNonce reports whether err is an ACME "badnonce" error.
|
||||
func isBadNonce(err error) bool {
|
||||
// According to the spec badNonce is urn:ietf:params:acme:error:badNonce.
|
||||
// However, ACME servers in the wild return their versions of the error.
|
||||
// See https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4
|
||||
// and https://github.com/letsencrypt/boulder/blob/0e07eacb/docs/acme-divergences.md#section-66.
|
||||
ae, ok := err.(*Error)
|
||||
return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce")
|
||||
}
|
||||
|
||||
// isRetriable reports whether a request can be retried
|
||||
// based on the response status code.
|
||||
//
|
||||
// Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code.
|
||||
// Callers should parse the response and check with isBadNonce.
|
||||
func isRetriable(code int) bool {
|
||||
return code <= 399 || code >= 500 || code == http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
// responseError creates an error of Error type from resp.
|
||||
func responseError(resp *http.Response) error {
|
||||
// don't care if ReadAll returns an error:
|
||||
// json.Unmarshal will fail in that case anyway
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
e := &wireError{Status: resp.StatusCode}
|
||||
if err := json.Unmarshal(b, e); err != nil {
|
||||
// this is not a regular error response:
|
||||
// populate detail with anything we received,
|
||||
// e.Status will already contain HTTP response code value
|
||||
e.Detail = string(b)
|
||||
if e.Detail == "" {
|
||||
e.Detail = resp.Status
|
||||
}
|
||||
}
|
||||
return e.error(resp.Header)
|
||||
}
|
||||
257
vendor/github.com/tailscale/golang-x-crypto/acme/jws.go
generated
vendored
Normal file
257
vendor/github.com/tailscale/golang-x-crypto/acme/jws.go
generated
vendored
Normal file
@@ -0,0 +1,257 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
_ "crypto/sha512" // need for EC keys
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// KeyID is the account key identity provided by a CA during registration.
|
||||
type KeyID string
|
||||
|
||||
// noKeyID indicates that jwsEncodeJSON should compute and use JWK instead of a KID.
|
||||
// See jwsEncodeJSON for details.
|
||||
const noKeyID = KeyID("")
|
||||
|
||||
// noPayload indicates jwsEncodeJSON will encode zero-length octet string
|
||||
// in a JWS request. This is called POST-as-GET in RFC 8555 and is used to make
|
||||
// authenticated GET requests via POSTing with an empty payload.
|
||||
// See https://tools.ietf.org/html/rfc8555#section-6.3 for more details.
|
||||
const noPayload = ""
|
||||
|
||||
// noNonce indicates that the nonce should be omitted from the protected header.
|
||||
// See jwsEncodeJSON for details.
|
||||
const noNonce = ""
|
||||
|
||||
// jsonWebSignature can be easily serialized into a JWS following
|
||||
// https://tools.ietf.org/html/rfc7515#section-3.2.
|
||||
type jsonWebSignature struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
Sig string `json:"signature"`
|
||||
}
|
||||
|
||||
// jwsEncodeJSON signs claimset using provided key and a nonce.
|
||||
// The result is serialized in JSON format containing either kid or jwk
|
||||
// fields based on the provided KeyID value.
|
||||
//
|
||||
// The claimset is marshalled using json.Marshal unless it is a string.
|
||||
// In which case it is inserted directly into the message.
|
||||
//
|
||||
// If kid is non-empty, its quoted value is inserted in the protected header
|
||||
// as "kid" field value. Otherwise, JWK is computed using jwkEncode and inserted
|
||||
// as "jwk" field value. The "jwk" and "kid" fields are mutually exclusive.
|
||||
//
|
||||
// If nonce is non-empty, its quoted value is inserted in the protected header.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc7515#section-7.
|
||||
func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid KeyID, nonce, url string) ([]byte, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New("nil key")
|
||||
}
|
||||
alg, sha := jwsHasher(key.Public())
|
||||
if alg == "" || !sha.Available() {
|
||||
return nil, ErrUnsupportedKey
|
||||
}
|
||||
headers := struct {
|
||||
Alg string `json:"alg"`
|
||||
KID string `json:"kid,omitempty"`
|
||||
JWK json.RawMessage `json:"jwk,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
URL string `json:"url"`
|
||||
}{
|
||||
Alg: alg,
|
||||
Nonce: nonce,
|
||||
URL: url,
|
||||
}
|
||||
switch kid {
|
||||
case noKeyID:
|
||||
jwk, err := jwkEncode(key.Public())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers.JWK = json.RawMessage(jwk)
|
||||
default:
|
||||
headers.KID = string(kid)
|
||||
}
|
||||
phJSON, err := json.Marshal(headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
phead := base64.RawURLEncoding.EncodeToString([]byte(phJSON))
|
||||
var payload string
|
||||
if val, ok := claimset.(string); ok {
|
||||
payload = val
|
||||
} else {
|
||||
cs, err := json.Marshal(claimset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload = base64.RawURLEncoding.EncodeToString(cs)
|
||||
}
|
||||
hash := sha.New()
|
||||
hash.Write([]byte(phead + "." + payload))
|
||||
sig, err := jwsSign(key, sha, hash.Sum(nil))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc := jsonWebSignature{
|
||||
Protected: phead,
|
||||
Payload: payload,
|
||||
Sig: base64.RawURLEncoding.EncodeToString(sig),
|
||||
}
|
||||
return json.Marshal(&enc)
|
||||
}
|
||||
|
||||
// jwsWithMAC creates and signs a JWS using the given key and the HS256
|
||||
// algorithm. kid and url are included in the protected header. rawPayload
|
||||
// should not be base64-URL-encoded.
|
||||
func jwsWithMAC(key []byte, kid, url string, rawPayload []byte) (*jsonWebSignature, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, errors.New("acme: cannot sign JWS with an empty MAC key")
|
||||
}
|
||||
header := struct {
|
||||
Algorithm string `json:"alg"`
|
||||
KID string `json:"kid"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}{
|
||||
// Only HMAC-SHA256 is supported.
|
||||
Algorithm: "HS256",
|
||||
KID: kid,
|
||||
URL: url,
|
||||
}
|
||||
rawProtected, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
protected := base64.RawURLEncoding.EncodeToString(rawProtected)
|
||||
payload := base64.RawURLEncoding.EncodeToString(rawPayload)
|
||||
|
||||
h := hmac.New(sha256.New, key)
|
||||
if _, err := h.Write([]byte(protected + "." + payload)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mac := h.Sum(nil)
|
||||
|
||||
return &jsonWebSignature{
|
||||
Protected: protected,
|
||||
Payload: payload,
|
||||
Sig: base64.RawURLEncoding.EncodeToString(mac),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwkEncode encodes public part of an RSA or ECDSA key into a JWK.
|
||||
// The result is also suitable for creating a JWK thumbprint.
|
||||
// https://tools.ietf.org/html/rfc7517
|
||||
func jwkEncode(pub crypto.PublicKey) (string, error) {
|
||||
switch pub := pub.(type) {
|
||||
case *rsa.PublicKey:
|
||||
// https://tools.ietf.org/html/rfc7518#section-6.3.1
|
||||
n := pub.N
|
||||
e := big.NewInt(int64(pub.E))
|
||||
// Field order is important.
|
||||
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
|
||||
return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`,
|
||||
base64.RawURLEncoding.EncodeToString(e.Bytes()),
|
||||
base64.RawURLEncoding.EncodeToString(n.Bytes()),
|
||||
), nil
|
||||
case *ecdsa.PublicKey:
|
||||
// https://tools.ietf.org/html/rfc7518#section-6.2.1
|
||||
p := pub.Curve.Params()
|
||||
n := p.BitSize / 8
|
||||
if p.BitSize%8 != 0 {
|
||||
n++
|
||||
}
|
||||
x := pub.X.Bytes()
|
||||
if n > len(x) {
|
||||
x = append(make([]byte, n-len(x)), x...)
|
||||
}
|
||||
y := pub.Y.Bytes()
|
||||
if n > len(y) {
|
||||
y = append(make([]byte, n-len(y)), y...)
|
||||
}
|
||||
// Field order is important.
|
||||
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
|
||||
return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`,
|
||||
p.Name,
|
||||
base64.RawURLEncoding.EncodeToString(x),
|
||||
base64.RawURLEncoding.EncodeToString(y),
|
||||
), nil
|
||||
}
|
||||
return "", ErrUnsupportedKey
|
||||
}
|
||||
|
||||
// jwsSign signs the digest using the given key.
|
||||
// The hash is unused for ECDSA keys.
|
||||
func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) {
|
||||
switch pub := key.Public().(type) {
|
||||
case *rsa.PublicKey:
|
||||
return key.Sign(rand.Reader, digest, hash)
|
||||
case *ecdsa.PublicKey:
|
||||
sigASN1, err := key.Sign(rand.Reader, digest, hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rs struct{ R, S *big.Int }
|
||||
if _, err := asn1.Unmarshal(sigASN1, &rs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rb, sb := rs.R.Bytes(), rs.S.Bytes()
|
||||
size := pub.Params().BitSize / 8
|
||||
if size%8 > 0 {
|
||||
size++
|
||||
}
|
||||
sig := make([]byte, size*2)
|
||||
copy(sig[size-len(rb):], rb)
|
||||
copy(sig[size*2-len(sb):], sb)
|
||||
return sig, nil
|
||||
}
|
||||
return nil, ErrUnsupportedKey
|
||||
}
|
||||
|
||||
// jwsHasher indicates suitable JWS algorithm name and a hash function
|
||||
// to use for signing a digest with the provided key.
|
||||
// It returns ("", 0) if the key is not supported.
|
||||
func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) {
|
||||
switch pub := pub.(type) {
|
||||
case *rsa.PublicKey:
|
||||
return "RS256", crypto.SHA256
|
||||
case *ecdsa.PublicKey:
|
||||
switch pub.Params().Name {
|
||||
case "P-256":
|
||||
return "ES256", crypto.SHA256
|
||||
case "P-384":
|
||||
return "ES384", crypto.SHA384
|
||||
case "P-521":
|
||||
return "ES512", crypto.SHA512
|
||||
}
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// JWKThumbprint creates a JWK thumbprint out of pub
|
||||
// as specified in https://tools.ietf.org/html/rfc7638.
|
||||
func JWKThumbprint(pub crypto.PublicKey) (string, error) {
|
||||
jwk, err := jwkEncode(pub)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b := sha256.Sum256([]byte(jwk))
|
||||
return base64.RawURLEncoding.EncodeToString(b[:]), nil
|
||||
}
|
||||
476
vendor/github.com/tailscale/golang-x-crypto/acme/rfc8555.go
generated
vendored
Normal file
476
vendor/github.com/tailscale/golang-x-crypto/acme/rfc8555.go
generated
vendored
Normal file
@@ -0,0 +1,476 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DeactivateReg permanently disables an existing account associated with c.Key.
|
||||
// A deactivated account can no longer request certificate issuance or access
|
||||
// resources related to the account, such as orders or authorizations.
|
||||
//
|
||||
// It only works with CAs implementing RFC 8555.
|
||||
func (c *Client) DeactivateReg(ctx context.Context) error {
|
||||
if _, err := c.Discover(ctx); err != nil { // required by c.accountKID
|
||||
return err
|
||||
}
|
||||
url := string(c.accountKID(ctx))
|
||||
if url == "" {
|
||||
return ErrNoAccount
|
||||
}
|
||||
req := json.RawMessage(`{"status": "deactivated"}`)
|
||||
res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerRFC is equivalent to c.Register but for CAs implementing RFC 8555.
|
||||
// It expects c.Discover to have already been called.
|
||||
func (c *Client) registerRFC(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) {
|
||||
c.cacheMu.Lock() // guard c.kid access
|
||||
defer c.cacheMu.Unlock()
|
||||
|
||||
req := struct {
|
||||
TermsAgreed bool `json:"termsOfServiceAgreed,omitempty"`
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
ExternalAccountBinding *jsonWebSignature `json:"externalAccountBinding,omitempty"`
|
||||
}{
|
||||
Contact: acct.Contact,
|
||||
}
|
||||
if c.dir.Terms != "" {
|
||||
req.TermsAgreed = prompt(c.dir.Terms)
|
||||
}
|
||||
|
||||
// set 'externalAccountBinding' field if requested
|
||||
if acct.ExternalAccountBinding != nil {
|
||||
eabJWS, err := c.encodeExternalAccountBinding(acct.ExternalAccountBinding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acme: failed to encode external account binding: %v", err)
|
||||
}
|
||||
req.ExternalAccountBinding = eabJWS
|
||||
}
|
||||
|
||||
res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus(
|
||||
http.StatusOK, // account with this key already registered
|
||||
http.StatusCreated, // new account created
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
a, err := responseAccount(res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Cache Account URL even if we return an error to the caller.
|
||||
// It is by all means a valid and usable "kid" value for future requests.
|
||||
c.KID = KeyID(a.URI)
|
||||
if res.StatusCode == http.StatusOK {
|
||||
return nil, ErrAccountAlreadyExists
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// encodeExternalAccountBinding will encode an external account binding stanza
|
||||
// as described in https://tools.ietf.org/html/rfc8555#section-7.3.4.
|
||||
func (c *Client) encodeExternalAccountBinding(eab *ExternalAccountBinding) (*jsonWebSignature, error) {
|
||||
jwk, err := jwkEncode(c.Key.Public())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jwsWithMAC(eab.Key, eab.KID, c.dir.RegURL, []byte(jwk))
|
||||
}
|
||||
|
||||
// updateRegRFC is equivalent to c.UpdateReg but for CAs implementing RFC 8555.
|
||||
// It expects c.Discover to have already been called.
|
||||
func (c *Client) updateRegRFC(ctx context.Context, a *Account) (*Account, error) {
|
||||
url := string(c.accountKID(ctx))
|
||||
if url == "" {
|
||||
return nil, ErrNoAccount
|
||||
}
|
||||
req := struct {
|
||||
Contact []string `json:"contact,omitempty"`
|
||||
}{
|
||||
Contact: a.Contact,
|
||||
}
|
||||
res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
return responseAccount(res)
|
||||
}
|
||||
|
||||
// getRegRFC is equivalent to c.GetReg but for CAs implementing RFC 8555.
|
||||
// It expects c.Discover to have already been called.
|
||||
func (c *Client) getRegRFC(ctx context.Context) (*Account, error) {
|
||||
req := json.RawMessage(`{"onlyReturnExisting": true}`)
|
||||
res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus(http.StatusOK))
|
||||
if e, ok := err.(*Error); ok && e.ProblemType == "urn:ietf:params:acme:error:accountDoesNotExist" {
|
||||
return nil, ErrNoAccount
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
return responseAccount(res)
|
||||
}
|
||||
|
||||
func responseAccount(res *http.Response) (*Account, error) {
|
||||
var v struct {
|
||||
Status string
|
||||
Contact []string
|
||||
Orders string
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return nil, fmt.Errorf("acme: invalid account response: %v", err)
|
||||
}
|
||||
return &Account{
|
||||
URI: res.Header.Get("Location"),
|
||||
Status: v.Status,
|
||||
Contact: v.Contact,
|
||||
OrdersURL: v.Orders,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// accountKeyRollover attempts to perform account key rollover.
|
||||
// On success it will change client.Key to the new key.
|
||||
func (c *Client) accountKeyRollover(ctx context.Context, newKey crypto.Signer) error {
|
||||
dir, err := c.Discover(ctx) // Also required by c.accountKID
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kid := c.accountKID(ctx)
|
||||
if kid == noKeyID {
|
||||
return ErrNoAccount
|
||||
}
|
||||
oldKey, err := jwkEncode(c.Key.Public())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := struct {
|
||||
Account string `json:"account"`
|
||||
OldKey json.RawMessage `json:"oldKey"`
|
||||
}{
|
||||
Account: string(kid),
|
||||
OldKey: json.RawMessage(oldKey),
|
||||
}
|
||||
inner, err := jwsEncodeJSON(payload, newKey, noKeyID, noNonce, dir.KeyChangeURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := c.post(ctx, nil, dir.KeyChangeURL, base64.RawURLEncoding.EncodeToString(inner), wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
c.Key = newKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeOrder initiates the order-based application for certificate issuance,
|
||||
// as opposed to pre-authorization in Authorize.
|
||||
// It is only supported by CAs implementing RFC 8555.
|
||||
//
|
||||
// The caller then needs to fetch each authorization with GetAuthorization,
|
||||
// identify those with StatusPending status and fulfill a challenge using Accept.
|
||||
// Once all authorizations are satisfied, the caller will typically want to poll
|
||||
// order status using WaitOrder until it's in StatusReady state.
|
||||
// To finalize the order and obtain a certificate, the caller submits a CSR with CreateOrderCert.
|
||||
func (c *Client) AuthorizeOrder(ctx context.Context, id []AuthzID, opt ...OrderOption) (*Order, error) {
|
||||
dir, err := c.Discover(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := struct {
|
||||
Identifiers []wireAuthzID `json:"identifiers"`
|
||||
NotBefore string `json:"notBefore,omitempty"`
|
||||
NotAfter string `json:"notAfter,omitempty"`
|
||||
}{}
|
||||
for _, v := range id {
|
||||
req.Identifiers = append(req.Identifiers, wireAuthzID{
|
||||
Type: v.Type,
|
||||
Value: v.Value,
|
||||
})
|
||||
}
|
||||
for _, o := range opt {
|
||||
switch o := o.(type) {
|
||||
case orderNotBeforeOpt:
|
||||
req.NotBefore = time.Time(o).Format(time.RFC3339)
|
||||
case orderNotAfterOpt:
|
||||
req.NotAfter = time.Time(o).Format(time.RFC3339)
|
||||
default:
|
||||
// Package's fault if we let this happen.
|
||||
panic(fmt.Sprintf("unsupported order option type %T", o))
|
||||
}
|
||||
}
|
||||
|
||||
res, err := c.post(ctx, nil, dir.OrderURL, req, wantStatus(http.StatusCreated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
return responseOrder(res)
|
||||
}
|
||||
|
||||
// GetOrder retrives an order identified by the given URL.
|
||||
// For orders created with AuthorizeOrder, the url value is Order.URI.
|
||||
//
|
||||
// If a caller needs to poll an order until its status is final,
|
||||
// see the WaitOrder method.
|
||||
func (c *Client) GetOrder(ctx context.Context, url string) (*Order, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
return responseOrder(res)
|
||||
}
|
||||
|
||||
// WaitOrder polls an order from the given URL until it is in one of the final states,
|
||||
// StatusReady, StatusValid or StatusInvalid, the CA responded with a non-retryable error
|
||||
// or the context is done.
|
||||
//
|
||||
// It returns a non-nil Order only if its Status is StatusReady or StatusValid.
|
||||
// In all other cases WaitOrder returns an error.
|
||||
// If the Status is StatusInvalid, the returned error is of type *OrderError.
|
||||
func (c *Client) WaitOrder(ctx context.Context, url string) (*Order, error) {
|
||||
if _, err := c.Discover(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o, err := responseOrder(res)
|
||||
res.Body.Close()
|
||||
switch {
|
||||
case err != nil:
|
||||
// Skip and retry.
|
||||
case o.Status == StatusInvalid:
|
||||
return nil, &OrderError{OrderURL: o.URI, Status: o.Status}
|
||||
case o.Status == StatusReady || o.Status == StatusValid:
|
||||
return o, nil
|
||||
}
|
||||
|
||||
d := retryAfter(res.Header.Get("Retry-After"))
|
||||
if d == 0 {
|
||||
// Default retry-after.
|
||||
// Same reasoning as in WaitAuthorization.
|
||||
d = time.Second
|
||||
}
|
||||
t := time.NewTimer(d)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return nil, ctx.Err()
|
||||
case <-t.C:
|
||||
// Retry.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func responseOrder(res *http.Response) (*Order, error) {
|
||||
var v struct {
|
||||
Status string
|
||||
Expires time.Time
|
||||
Identifiers []wireAuthzID
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
Error *wireError
|
||||
Authorizations []string
|
||||
Finalize string
|
||||
Certificate string
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
|
||||
return nil, fmt.Errorf("acme: error reading order: %v", err)
|
||||
}
|
||||
o := &Order{
|
||||
URI: res.Header.Get("Location"),
|
||||
Status: v.Status,
|
||||
Expires: v.Expires,
|
||||
NotBefore: v.NotBefore,
|
||||
NotAfter: v.NotAfter,
|
||||
AuthzURLs: v.Authorizations,
|
||||
FinalizeURL: v.Finalize,
|
||||
CertURL: v.Certificate,
|
||||
}
|
||||
for _, id := range v.Identifiers {
|
||||
o.Identifiers = append(o.Identifiers, AuthzID{Type: id.Type, Value: id.Value})
|
||||
}
|
||||
if v.Error != nil {
|
||||
o.Error = v.Error.error(nil /* headers */)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// CreateOrderCert submits the CSR (Certificate Signing Request) to a CA at the specified URL.
|
||||
// The URL is the FinalizeURL field of an Order created with AuthorizeOrder.
|
||||
//
|
||||
// If the bundle argument is true, the returned value also contain the CA (issuer)
|
||||
// certificate chain. Otherwise, only a leaf certificate is returned.
|
||||
// The returned URL can be used to re-fetch the certificate using FetchCert.
|
||||
//
|
||||
// This method is only supported by CAs implementing RFC 8555. See CreateCert for pre-RFC CAs.
|
||||
//
|
||||
// CreateOrderCert returns an error if the CA's response is unreasonably large.
|
||||
// Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features.
|
||||
func (c *Client) CreateOrderCert(ctx context.Context, url string, csr []byte, bundle bool) (der [][]byte, certURL string, err error) {
|
||||
if _, err := c.Discover(ctx); err != nil { // required by c.accountKID
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// RFC describes this as "finalize order" request.
|
||||
req := struct {
|
||||
CSR string `json:"csr"`
|
||||
}{
|
||||
CSR: base64.RawURLEncoding.EncodeToString(csr),
|
||||
}
|
||||
res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
o, err := responseOrder(res)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// Wait for CA to issue the cert if they haven't.
|
||||
if o.Status != StatusValid {
|
||||
o, err = c.WaitOrder(ctx, o.URI)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
// The only acceptable status post finalize and WaitOrder is "valid".
|
||||
if o.Status != StatusValid {
|
||||
return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status}
|
||||
}
|
||||
crt, err := c.fetchCertRFC(ctx, o.CertURL, bundle)
|
||||
return crt, o.CertURL, err
|
||||
}
|
||||
|
||||
// fetchCertRFC downloads issued certificate from the given URL.
|
||||
// It expects the CA to respond with PEM-encoded certificate chain.
|
||||
//
|
||||
// The URL argument is the CertURL field of Order.
|
||||
func (c *Client) fetchCertRFC(ctx context.Context, url string, bundle bool) ([][]byte, error) {
|
||||
res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// Get all the bytes up to a sane maximum.
|
||||
// Account very roughly for base64 overhead.
|
||||
const max = maxCertChainSize + maxCertChainSize/33
|
||||
b, err := io.ReadAll(io.LimitReader(res.Body, max+1))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acme: fetch cert response stream: %v", err)
|
||||
}
|
||||
if len(b) > max {
|
||||
return nil, errors.New("acme: certificate chain is too big")
|
||||
}
|
||||
|
||||
// Decode PEM chain.
|
||||
var chain [][]byte
|
||||
for {
|
||||
var p *pem.Block
|
||||
p, b = pem.Decode(b)
|
||||
if p == nil {
|
||||
break
|
||||
}
|
||||
if p.Type != "CERTIFICATE" {
|
||||
return nil, fmt.Errorf("acme: invalid PEM cert type %q", p.Type)
|
||||
}
|
||||
|
||||
chain = append(chain, p.Bytes)
|
||||
if !bundle {
|
||||
return chain, nil
|
||||
}
|
||||
if len(chain) > maxChainLen {
|
||||
return nil, errors.New("acme: certificate chain is too long")
|
||||
}
|
||||
}
|
||||
if len(chain) == 0 {
|
||||
return nil, errors.New("acme: certificate chain is empty")
|
||||
}
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// sends a cert revocation request in either JWK form when key is non-nil or KID form otherwise.
|
||||
func (c *Client) revokeCertRFC(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error {
|
||||
req := &struct {
|
||||
Cert string `json:"certificate"`
|
||||
Reason int `json:"reason"`
|
||||
}{
|
||||
Cert: base64.RawURLEncoding.EncodeToString(cert),
|
||||
Reason: int(reason),
|
||||
}
|
||||
res, err := c.post(ctx, key, c.dir.RevokeURL, req, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
if isAlreadyRevoked(err) {
|
||||
// Assume it is not an error to revoke an already revoked cert.
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAlreadyRevoked(err error) bool {
|
||||
e, ok := err.(*Error)
|
||||
return ok && e.ProblemType == "urn:ietf:params:acme:error:alreadyRevoked"
|
||||
}
|
||||
|
||||
// ListCertAlternates retrieves any alternate certificate chain URLs for the
|
||||
// given certificate chain URL. These alternate URLs can be passed to FetchCert
|
||||
// in order to retrieve the alternate certificate chains.
|
||||
//
|
||||
// If there are no alternate issuer certificate chains, a nil slice will be
|
||||
// returned.
|
||||
func (c *Client) ListCertAlternates(ctx context.Context, url string) ([]string, error) {
|
||||
if _, err := c.Discover(ctx); err != nil { // required by c.accountKID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// We don't need the body but we need to discard it so we don't end up
|
||||
// preventing keep-alive
|
||||
if _, err := io.Copy(io.Discard, res.Body); err != nil {
|
||||
return nil, fmt.Errorf("acme: cert alternates response stream: %v", err)
|
||||
}
|
||||
alts := linkHeader(res.Header, "alternate")
|
||||
return alts, nil
|
||||
}
|
||||
632
vendor/github.com/tailscale/golang-x-crypto/acme/types.go
generated
vendored
Normal file
632
vendor/github.com/tailscale/golang-x-crypto/acme/types.go
generated
vendored
Normal file
@@ -0,0 +1,632 @@
|
||||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ACME status values of Account, Order, Authorization and Challenge objects.
|
||||
// See https://tools.ietf.org/html/rfc8555#section-7.1.6 for details.
|
||||
const (
|
||||
StatusDeactivated = "deactivated"
|
||||
StatusExpired = "expired"
|
||||
StatusInvalid = "invalid"
|
||||
StatusPending = "pending"
|
||||
StatusProcessing = "processing"
|
||||
StatusReady = "ready"
|
||||
StatusRevoked = "revoked"
|
||||
StatusUnknown = "unknown"
|
||||
StatusValid = "valid"
|
||||
)
|
||||
|
||||
// CRLReasonCode identifies the reason for a certificate revocation.
|
||||
type CRLReasonCode int
|
||||
|
||||
// CRL reason codes as defined in RFC 5280.
|
||||
const (
|
||||
CRLReasonUnspecified CRLReasonCode = 0
|
||||
CRLReasonKeyCompromise CRLReasonCode = 1
|
||||
CRLReasonCACompromise CRLReasonCode = 2
|
||||
CRLReasonAffiliationChanged CRLReasonCode = 3
|
||||
CRLReasonSuperseded CRLReasonCode = 4
|
||||
CRLReasonCessationOfOperation CRLReasonCode = 5
|
||||
CRLReasonCertificateHold CRLReasonCode = 6
|
||||
CRLReasonRemoveFromCRL CRLReasonCode = 8
|
||||
CRLReasonPrivilegeWithdrawn CRLReasonCode = 9
|
||||
CRLReasonAACompromise CRLReasonCode = 10
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnsupportedKey is returned when an unsupported key type is encountered.
|
||||
ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
|
||||
|
||||
// ErrAccountAlreadyExists indicates that the Client's key has already been registered
|
||||
// with the CA. It is returned by Register method.
|
||||
ErrAccountAlreadyExists = errors.New("acme: account already exists")
|
||||
|
||||
// ErrNoAccount indicates that the Client's key has not been registered with the CA.
|
||||
ErrNoAccount = errors.New("acme: account does not exist")
|
||||
)
|
||||
|
||||
// A Subproblem describes an ACME subproblem as reported in an Error.
|
||||
type Subproblem struct {
|
||||
// Type is a URI reference that identifies the problem type,
|
||||
// typically in a "urn:acme:error:xxx" form.
|
||||
Type string
|
||||
// Detail is a human-readable explanation specific to this occurrence of the problem.
|
||||
Detail string
|
||||
// Instance indicates a URL that the client should direct a human user to visit
|
||||
// in order for instructions on how to agree to the updated Terms of Service.
|
||||
// In such an event CA sets StatusCode to 403, Type to
|
||||
// "urn:ietf:params:acme:error:userActionRequired", and adds a Link header with relation
|
||||
// "terms-of-service" containing the latest TOS URL.
|
||||
Instance string
|
||||
// Identifier may contain the ACME identifier that the error is for.
|
||||
Identifier *AuthzID
|
||||
}
|
||||
|
||||
func (sp Subproblem) String() string {
|
||||
str := fmt.Sprintf("%s: ", sp.Type)
|
||||
if sp.Identifier != nil {
|
||||
str += fmt.Sprintf("[%s: %s] ", sp.Identifier.Type, sp.Identifier.Value)
|
||||
}
|
||||
str += sp.Detail
|
||||
return str
|
||||
}
|
||||
|
||||
// Error is an ACME error, defined in Problem Details for HTTP APIs doc
|
||||
// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem.
|
||||
type Error struct {
|
||||
// StatusCode is The HTTP status code generated by the origin server.
|
||||
StatusCode int
|
||||
// ProblemType is a URI reference that identifies the problem type,
|
||||
// typically in a "urn:acme:error:xxx" form.
|
||||
ProblemType string
|
||||
// Detail is a human-readable explanation specific to this occurrence of the problem.
|
||||
Detail string
|
||||
// Instance indicates a URL that the client should direct a human user to visit
|
||||
// in order for instructions on how to agree to the updated Terms of Service.
|
||||
// In such an event CA sets StatusCode to 403, ProblemType to
|
||||
// "urn:ietf:params:acme:error:userActionRequired" and a Link header with relation
|
||||
// "terms-of-service" containing the latest TOS URL.
|
||||
Instance string
|
||||
// Header is the original server error response headers.
|
||||
// It may be nil.
|
||||
Header http.Header
|
||||
// Subproblems may contain more detailed information about the individual problems
|
||||
// that caused the error. This field is only sent by RFC 8555 compatible ACME
|
||||
// servers. Defined in RFC 8555 Section 6.7.1.
|
||||
Subproblems []Subproblem
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
str := fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail)
|
||||
if len(e.Subproblems) > 0 {
|
||||
str += fmt.Sprintf("; subproblems:")
|
||||
for _, sp := range e.Subproblems {
|
||||
str += fmt.Sprintf("\n\t%s", sp)
|
||||
}
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// AuthorizationError indicates that an authorization for an identifier
|
||||
// did not succeed.
|
||||
// It contains all errors from Challenge items of the failed Authorization.
|
||||
type AuthorizationError struct {
|
||||
// URI uniquely identifies the failed Authorization.
|
||||
URI string
|
||||
|
||||
// Identifier is an AuthzID.Value of the failed Authorization.
|
||||
Identifier string
|
||||
|
||||
// Errors is a collection of non-nil error values of Challenge items
|
||||
// of the failed Authorization.
|
||||
Errors []error
|
||||
}
|
||||
|
||||
func (a *AuthorizationError) Error() string {
|
||||
e := make([]string, len(a.Errors))
|
||||
for i, err := range a.Errors {
|
||||
e[i] = err.Error()
|
||||
}
|
||||
|
||||
if a.Identifier != "" {
|
||||
return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; "))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("acme: authorization error: %s", strings.Join(e, "; "))
|
||||
}
|
||||
|
||||
// OrderError is returned from Client's order related methods.
|
||||
// It indicates the order is unusable and the clients should start over with
|
||||
// AuthorizeOrder.
|
||||
//
|
||||
// The clients can still fetch the order object from CA using GetOrder
|
||||
// to inspect its state.
|
||||
type OrderError struct {
|
||||
OrderURL string
|
||||
Status string
|
||||
}
|
||||
|
||||
func (oe *OrderError) Error() string {
|
||||
return fmt.Sprintf("acme: order %s status: %s", oe.OrderURL, oe.Status)
|
||||
}
|
||||
|
||||
// RateLimit reports whether err represents a rate limit error and
|
||||
// any Retry-After duration returned by the server.
|
||||
//
|
||||
// See the following for more details on rate limiting:
|
||||
// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6
|
||||
func RateLimit(err error) (time.Duration, bool) {
|
||||
e, ok := err.(*Error)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
// Some CA implementations may return incorrect values.
|
||||
// Use case-insensitive comparison.
|
||||
if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") {
|
||||
return 0, false
|
||||
}
|
||||
if e.Header == nil {
|
||||
return 0, true
|
||||
}
|
||||
return retryAfter(e.Header.Get("Retry-After")), true
|
||||
}
|
||||
|
||||
// Account is a user account. It is associated with a private key.
|
||||
// Non-RFC 8555 fields are empty when interfacing with a compliant CA.
|
||||
type Account struct {
|
||||
// URI is the account unique ID, which is also a URL used to retrieve
|
||||
// account data from the CA.
|
||||
// When interfacing with RFC 8555-compliant CAs, URI is the "kid" field
|
||||
// value in JWS signed requests.
|
||||
URI string
|
||||
|
||||
// Contact is a slice of contact info used during registration.
|
||||
// See https://tools.ietf.org/html/rfc8555#section-7.3 for supported
|
||||
// formats.
|
||||
Contact []string
|
||||
|
||||
// Status indicates current account status as returned by the CA.
|
||||
// Possible values are StatusValid, StatusDeactivated, and StatusRevoked.
|
||||
Status string
|
||||
|
||||
// OrdersURL is a URL from which a list of orders submitted by this account
|
||||
// can be fetched.
|
||||
OrdersURL string
|
||||
|
||||
// The terms user has agreed to.
|
||||
// A value not matching CurrentTerms indicates that the user hasn't agreed
|
||||
// to the actual Terms of Service of the CA.
|
||||
//
|
||||
// It is non-RFC 8555 compliant. Package users can store the ToS they agree to
|
||||
// during Client's Register call in the prompt callback function.
|
||||
AgreedTerms string
|
||||
|
||||
// Actual terms of a CA.
|
||||
//
|
||||
// It is non-RFC 8555 compliant. Use Directory's Terms field.
|
||||
// When a CA updates their terms and requires an account agreement,
|
||||
// a URL at which instructions to do so is available in Error's Instance field.
|
||||
CurrentTerms string
|
||||
|
||||
// Authz is the authorization URL used to initiate a new authz flow.
|
||||
//
|
||||
// It is non-RFC 8555 compliant. Use Directory's AuthzURL or OrderURL.
|
||||
Authz string
|
||||
|
||||
// Authorizations is a URI from which a list of authorizations
|
||||
// granted to this account can be fetched via a GET request.
|
||||
//
|
||||
// It is non-RFC 8555 compliant and is obsoleted by OrdersURL.
|
||||
Authorizations string
|
||||
|
||||
// Certificates is a URI from which a list of certificates
|
||||
// issued for this account can be fetched via a GET request.
|
||||
//
|
||||
// It is non-RFC 8555 compliant and is obsoleted by OrdersURL.
|
||||
Certificates string
|
||||
|
||||
// ExternalAccountBinding represents an arbitrary binding to an account of
|
||||
// the CA which the ACME server is tied to.
|
||||
// See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details.
|
||||
ExternalAccountBinding *ExternalAccountBinding
|
||||
}
|
||||
|
||||
// ExternalAccountBinding contains the data needed to form a request with
|
||||
// an external account binding.
|
||||
// See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details.
|
||||
type ExternalAccountBinding struct {
|
||||
// KID is the Key ID of the symmetric MAC key that the CA provides to
|
||||
// identify an external account from ACME.
|
||||
KID string
|
||||
|
||||
// Key is the bytes of the symmetric key that the CA provides to identify
|
||||
// the account. Key must correspond to the KID.
|
||||
Key []byte
|
||||
}
|
||||
|
||||
func (e *ExternalAccountBinding) String() string {
|
||||
return fmt.Sprintf("&{KID: %q, Key: redacted}", e.KID)
|
||||
}
|
||||
|
||||
// Directory is ACME server discovery data.
|
||||
// See https://tools.ietf.org/html/rfc8555#section-7.1.1 for more details.
|
||||
type Directory struct {
|
||||
// NonceURL indicates an endpoint where to fetch fresh nonce values from.
|
||||
NonceURL string
|
||||
|
||||
// RegURL is an account endpoint URL, allowing for creating new accounts.
|
||||
// Pre-RFC 8555 CAs also allow modifying existing accounts at this URL.
|
||||
RegURL string
|
||||
|
||||
// OrderURL is used to initiate the certificate issuance flow
|
||||
// as described in RFC 8555.
|
||||
OrderURL string
|
||||
|
||||
// AuthzURL is used to initiate identifier pre-authorization flow.
|
||||
// Empty string indicates the flow is unsupported by the CA.
|
||||
AuthzURL string
|
||||
|
||||
// CertURL is a new certificate issuance endpoint URL.
|
||||
// It is non-RFC 8555 compliant and is obsoleted by OrderURL.
|
||||
CertURL string
|
||||
|
||||
// RevokeURL is used to initiate a certificate revocation flow.
|
||||
RevokeURL string
|
||||
|
||||
// KeyChangeURL allows to perform account key rollover flow.
|
||||
KeyChangeURL string
|
||||
|
||||
// RenewalInfoURL allows to perform certificate renewal using the ACME
|
||||
// Renewal Information (ARI) Extension.
|
||||
RenewalInfoURL string
|
||||
|
||||
// Term is a URI identifying the current terms of service.
|
||||
Terms string
|
||||
|
||||
// Website is an HTTP or HTTPS URL locating a website
|
||||
// providing more information about the ACME server.
|
||||
Website string
|
||||
|
||||
// CAA consists of lowercase hostname elements, which the ACME server
|
||||
// recognises as referring to itself for the purposes of CAA record validation
|
||||
// as defined in RFC 6844.
|
||||
CAA []string
|
||||
|
||||
// ExternalAccountRequired indicates that the CA requires for all account-related
|
||||
// requests to include external account binding information.
|
||||
ExternalAccountRequired bool
|
||||
}
|
||||
|
||||
// Order represents a client's request for a certificate.
|
||||
// It tracks the request flow progress through to issuance.
|
||||
type Order struct {
|
||||
// URI uniquely identifies an order.
|
||||
URI string
|
||||
|
||||
// Status represents the current status of the order.
|
||||
// It indicates which action the client should take.
|
||||
//
|
||||
// Possible values are StatusPending, StatusReady, StatusProcessing, StatusValid and StatusInvalid.
|
||||
// Pending means the CA does not believe that the client has fulfilled the requirements.
|
||||
// Ready indicates that the client has fulfilled all the requirements and can submit a CSR
|
||||
// to obtain a certificate. This is done with Client's CreateOrderCert.
|
||||
// Processing means the certificate is being issued.
|
||||
// Valid indicates the CA has issued the certificate. It can be downloaded
|
||||
// from the Order's CertURL. This is done with Client's FetchCert.
|
||||
// Invalid means the certificate will not be issued. Users should consider this order
|
||||
// abandoned.
|
||||
Status string
|
||||
|
||||
// Expires is the timestamp after which CA considers this order invalid.
|
||||
Expires time.Time
|
||||
|
||||
// Identifiers contains all identifier objects which the order pertains to.
|
||||
Identifiers []AuthzID
|
||||
|
||||
// NotBefore is the requested value of the notBefore field in the certificate.
|
||||
NotBefore time.Time
|
||||
|
||||
// NotAfter is the requested value of the notAfter field in the certificate.
|
||||
NotAfter time.Time
|
||||
|
||||
// AuthzURLs represents authorizations to complete before a certificate
|
||||
// for identifiers specified in the order can be issued.
|
||||
// It also contains unexpired authorizations that the client has completed
|
||||
// in the past.
|
||||
//
|
||||
// Authorization objects can be fetched using Client's GetAuthorization method.
|
||||
//
|
||||
// The required authorizations are dictated by CA policies.
|
||||
// There may not be a 1:1 relationship between the identifiers and required authorizations.
|
||||
// Required authorizations can be identified by their StatusPending status.
|
||||
//
|
||||
// For orders in the StatusValid or StatusInvalid state these are the authorizations
|
||||
// which were completed.
|
||||
AuthzURLs []string
|
||||
|
||||
// FinalizeURL is the endpoint at which a CSR is submitted to obtain a certificate
|
||||
// once all the authorizations are satisfied.
|
||||
FinalizeURL string
|
||||
|
||||
// CertURL points to the certificate that has been issued in response to this order.
|
||||
CertURL string
|
||||
|
||||
// The error that occurred while processing the order as received from a CA, if any.
|
||||
Error *Error
|
||||
}
|
||||
|
||||
// OrderOption allows customizing Client.AuthorizeOrder call.
|
||||
type OrderOption interface {
|
||||
privateOrderOpt()
|
||||
}
|
||||
|
||||
// WithOrderNotBefore sets order's NotBefore field.
|
||||
func WithOrderNotBefore(t time.Time) OrderOption {
|
||||
return orderNotBeforeOpt(t)
|
||||
}
|
||||
|
||||
// WithOrderNotAfter sets order's NotAfter field.
|
||||
func WithOrderNotAfter(t time.Time) OrderOption {
|
||||
return orderNotAfterOpt(t)
|
||||
}
|
||||
|
||||
type orderNotBeforeOpt time.Time
|
||||
|
||||
func (orderNotBeforeOpt) privateOrderOpt() {}
|
||||
|
||||
type orderNotAfterOpt time.Time
|
||||
|
||||
func (orderNotAfterOpt) privateOrderOpt() {}
|
||||
|
||||
// Authorization encodes an authorization response.
|
||||
type Authorization struct {
|
||||
// URI uniquely identifies a authorization.
|
||||
URI string
|
||||
|
||||
// Status is the current status of an authorization.
|
||||
// Possible values are StatusPending, StatusValid, StatusInvalid, StatusDeactivated,
|
||||
// StatusExpired and StatusRevoked.
|
||||
Status string
|
||||
|
||||
// Identifier is what the account is authorized to represent.
|
||||
Identifier AuthzID
|
||||
|
||||
// The timestamp after which the CA considers the authorization invalid.
|
||||
Expires time.Time
|
||||
|
||||
// Wildcard is true for authorizations of a wildcard domain name.
|
||||
Wildcard bool
|
||||
|
||||
// Challenges that the client needs to fulfill in order to prove possession
|
||||
// of the identifier (for pending authorizations).
|
||||
// For valid authorizations, the challenge that was validated.
|
||||
// For invalid authorizations, the challenge that was attempted and failed.
|
||||
//
|
||||
// RFC 8555 compatible CAs require users to fuflfill only one of the challenges.
|
||||
Challenges []*Challenge
|
||||
|
||||
// A collection of sets of challenges, each of which would be sufficient
|
||||
// to prove possession of the identifier.
|
||||
// Clients must complete a set of challenges that covers at least one set.
|
||||
// Challenges are identified by their indices in the challenges array.
|
||||
// If this field is empty, the client needs to complete all challenges.
|
||||
//
|
||||
// This field is unused in RFC 8555.
|
||||
Combinations [][]int
|
||||
}
|
||||
|
||||
// AuthzID is an identifier that an account is authorized to represent.
|
||||
type AuthzID struct {
|
||||
Type string // The type of identifier, "dns" or "ip".
|
||||
Value string // The identifier itself, e.g. "example.org".
|
||||
}
|
||||
|
||||
// DomainIDs creates a slice of AuthzID with "dns" identifier type.
|
||||
func DomainIDs(names ...string) []AuthzID {
|
||||
a := make([]AuthzID, len(names))
|
||||
for i, v := range names {
|
||||
a[i] = AuthzID{Type: "dns", Value: v}
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// IPIDs creates a slice of AuthzID with "ip" identifier type.
|
||||
// Each element of addr is textual form of an address as defined
|
||||
// in RFC 1123 Section 2.1 for IPv4 and in RFC 5952 Section 4 for IPv6.
|
||||
func IPIDs(addr ...string) []AuthzID {
|
||||
a := make([]AuthzID, len(addr))
|
||||
for i, v := range addr {
|
||||
a[i] = AuthzID{Type: "ip", Value: v}
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// wireAuthzID is ACME JSON representation of authorization identifier objects.
|
||||
type wireAuthzID struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// wireAuthz is ACME JSON representation of Authorization objects.
|
||||
type wireAuthz struct {
|
||||
Identifier wireAuthzID
|
||||
Status string
|
||||
Expires time.Time
|
||||
Wildcard bool
|
||||
Challenges []wireChallenge
|
||||
Combinations [][]int
|
||||
Error *wireError
|
||||
}
|
||||
|
||||
func (z *wireAuthz) authorization(uri string) *Authorization {
|
||||
a := &Authorization{
|
||||
URI: uri,
|
||||
Status: z.Status,
|
||||
Identifier: AuthzID{Type: z.Identifier.Type, Value: z.Identifier.Value},
|
||||
Expires: z.Expires,
|
||||
Wildcard: z.Wildcard,
|
||||
Challenges: make([]*Challenge, len(z.Challenges)),
|
||||
Combinations: z.Combinations, // shallow copy
|
||||
}
|
||||
for i, v := range z.Challenges {
|
||||
a.Challenges[i] = v.challenge()
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (z *wireAuthz) error(uri string) *AuthorizationError {
|
||||
err := &AuthorizationError{
|
||||
URI: uri,
|
||||
Identifier: z.Identifier.Value,
|
||||
}
|
||||
|
||||
if z.Error != nil {
|
||||
err.Errors = append(err.Errors, z.Error.error(nil))
|
||||
}
|
||||
|
||||
for _, raw := range z.Challenges {
|
||||
if raw.Error != nil {
|
||||
err.Errors = append(err.Errors, raw.Error.error(nil))
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Challenge encodes a returned CA challenge.
|
||||
// Its Error field may be non-nil if the challenge is part of an Authorization
|
||||
// with StatusInvalid.
|
||||
type Challenge struct {
|
||||
// Type is the challenge type, e.g. "http-01", "tls-alpn-01", "dns-01".
|
||||
Type string
|
||||
|
||||
// URI is where a challenge response can be posted to.
|
||||
URI string
|
||||
|
||||
// Token is a random value that uniquely identifies the challenge.
|
||||
Token string
|
||||
|
||||
// Status identifies the status of this challenge.
|
||||
// In RFC 8555, possible values are StatusPending, StatusProcessing, StatusValid,
|
||||
// and StatusInvalid.
|
||||
Status string
|
||||
|
||||
// Validated is the time at which the CA validated this challenge.
|
||||
// Always zero value in pre-RFC 8555.
|
||||
Validated time.Time
|
||||
|
||||
// Error indicates the reason for an authorization failure
|
||||
// when this challenge was used.
|
||||
// The type of a non-nil value is *Error.
|
||||
Error error
|
||||
}
|
||||
|
||||
// wireChallenge is ACME JSON challenge representation.
|
||||
type wireChallenge struct {
|
||||
URL string `json:"url"` // RFC
|
||||
URI string `json:"uri"` // pre-RFC
|
||||
Type string
|
||||
Token string
|
||||
Status string
|
||||
Validated time.Time
|
||||
Error *wireError
|
||||
}
|
||||
|
||||
func (c *wireChallenge) challenge() *Challenge {
|
||||
v := &Challenge{
|
||||
URI: c.URL,
|
||||
Type: c.Type,
|
||||
Token: c.Token,
|
||||
Status: c.Status,
|
||||
}
|
||||
if v.URI == "" {
|
||||
v.URI = c.URI // c.URL was empty; use legacy
|
||||
}
|
||||
if v.Status == "" {
|
||||
v.Status = StatusPending
|
||||
}
|
||||
if c.Error != nil {
|
||||
v.Error = c.Error.error(nil)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// wireError is a subset of fields of the Problem Details object
|
||||
// as described in https://tools.ietf.org/html/rfc7807#section-3.1.
|
||||
type wireError struct {
|
||||
Status int
|
||||
Type string
|
||||
Detail string
|
||||
Instance string
|
||||
Subproblems []Subproblem
|
||||
}
|
||||
|
||||
func (e *wireError) error(h http.Header) *Error {
|
||||
err := &Error{
|
||||
StatusCode: e.Status,
|
||||
ProblemType: e.Type,
|
||||
Detail: e.Detail,
|
||||
Instance: e.Instance,
|
||||
Header: h,
|
||||
Subproblems: e.Subproblems,
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CertOption is an optional argument type for the TLS ChallengeCert methods for
|
||||
// customizing a temporary certificate for TLS-based challenges.
|
||||
type CertOption interface {
|
||||
privateCertOpt()
|
||||
}
|
||||
|
||||
// WithKey creates an option holding a private/public key pair.
|
||||
// The private part signs a certificate, and the public part represents the signee.
|
||||
func WithKey(key crypto.Signer) CertOption {
|
||||
return &certOptKey{key}
|
||||
}
|
||||
|
||||
type certOptKey struct {
|
||||
key crypto.Signer
|
||||
}
|
||||
|
||||
func (*certOptKey) privateCertOpt() {}
|
||||
|
||||
// WithTemplate creates an option for specifying a certificate template.
|
||||
// See x509.CreateCertificate for template usage details.
|
||||
//
|
||||
// In TLS ChallengeCert methods, the template is also used as parent,
|
||||
// resulting in a self-signed certificate.
|
||||
// The DNSNames field of t is always overwritten for tls-sni challenge certs.
|
||||
func WithTemplate(t *x509.Certificate) CertOption {
|
||||
return (*certOptTemplate)(t)
|
||||
}
|
||||
|
||||
type certOptTemplate x509.Certificate
|
||||
|
||||
func (*certOptTemplate) privateCertOpt() {}
|
||||
|
||||
// RenewalInfoWindow describes the time frame during which the ACME client
|
||||
// should attempt to renew, using the ACME Renewal Info Extension.
|
||||
type RenewalInfoWindow struct {
|
||||
Start time.Time `json:"start"`
|
||||
End time.Time `json:"end"`
|
||||
}
|
||||
|
||||
// RenewalInfo describes the suggested renewal window for a given certificate,
|
||||
// returned from an ACME server, using the ACME Renewal Info Extension.
|
||||
type RenewalInfo struct {
|
||||
SuggestedWindow RenewalInfoWindow `json:"suggestedWindow"`
|
||||
ExplanationURL string `json:"explanationURL"`
|
||||
}
|
||||
27
vendor/github.com/tailscale/golang-x-crypto/acme/version_go112.go
generated
vendored
Normal file
27
vendor/github.com/tailscale/golang-x-crypto/acme/version_go112.go
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build go1.12
|
||||
|
||||
package acme
|
||||
|
||||
import "runtime/debug"
|
||||
|
||||
func init() {
|
||||
// Set packageVersion if the binary was built in modules mode and x/crypto
|
||||
// was not replaced with a different module.
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for _, m := range info.Deps {
|
||||
if m.Path != "golang.org/x/crypto" {
|
||||
continue
|
||||
}
|
||||
if m.Replace == nil {
|
||||
packageVersion = m.Version
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
9
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/mac_noasm.go
generated
vendored
Normal file
9
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/mac_noasm.go
generated
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build (!amd64 && !ppc64le && !s390x) || !gc || purego
|
||||
|
||||
package poly1305
|
||||
|
||||
type mac struct{ macGeneric }
|
||||
99
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/poly1305.go
generated
vendored
Normal file
99
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/poly1305.go
generated
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package poly1305 implements Poly1305 one-time message authentication code as
|
||||
// specified in https://cr.yp.to/mac/poly1305-20050329.pdf.
|
||||
//
|
||||
// Poly1305 is a fast, one-time authentication function. It is infeasible for an
|
||||
// attacker to generate an authenticator for a message without the key. However, a
|
||||
// key must only be used for a single message. Authenticating two different
|
||||
// messages with the same key allows an attacker to forge authenticators for other
|
||||
// messages with the same key.
|
||||
//
|
||||
// Poly1305 was originally coupled with AES in order to make Poly1305-AES. AES was
|
||||
// used with a fixed key in order to generate one-time keys from an nonce.
|
||||
// However, in this package AES isn't used and the one-time key is specified
|
||||
// directly.
|
||||
package poly1305
|
||||
|
||||
import "crypto/subtle"
|
||||
|
||||
// TagSize is the size, in bytes, of a poly1305 authenticator.
|
||||
const TagSize = 16
|
||||
|
||||
// Sum generates an authenticator for msg using a one-time key and puts the
|
||||
// 16-byte result into out. Authenticating two different messages with the same
|
||||
// key allows an attacker to forge messages at will.
|
||||
func Sum(out *[16]byte, m []byte, key *[32]byte) {
|
||||
h := New(key)
|
||||
h.Write(m)
|
||||
h.Sum(out[:0])
|
||||
}
|
||||
|
||||
// Verify returns true if mac is a valid authenticator for m with the given key.
|
||||
func Verify(mac *[16]byte, m []byte, key *[32]byte) bool {
|
||||
var tmp [16]byte
|
||||
Sum(&tmp, m, key)
|
||||
return subtle.ConstantTimeCompare(tmp[:], mac[:]) == 1
|
||||
}
|
||||
|
||||
// New returns a new MAC computing an authentication
|
||||
// tag of all data written to it with the given key.
|
||||
// This allows writing the message progressively instead
|
||||
// of passing it as a single slice. Common users should use
|
||||
// the Sum function instead.
|
||||
//
|
||||
// The key must be unique for each message, as authenticating
|
||||
// two different messages with the same key allows an attacker
|
||||
// to forge messages at will.
|
||||
func New(key *[32]byte) *MAC {
|
||||
m := &MAC{}
|
||||
initialize(key, &m.macState)
|
||||
return m
|
||||
}
|
||||
|
||||
// MAC is an io.Writer computing an authentication tag
|
||||
// of the data written to it.
|
||||
//
|
||||
// MAC cannot be used like common hash.Hash implementations,
|
||||
// because using a poly1305 key twice breaks its security.
|
||||
// Therefore writing data to a running MAC after calling
|
||||
// Sum or Verify causes it to panic.
|
||||
type MAC struct {
|
||||
mac // platform-dependent implementation
|
||||
|
||||
finalized bool
|
||||
}
|
||||
|
||||
// Size returns the number of bytes Sum will return.
|
||||
func (h *MAC) Size() int { return TagSize }
|
||||
|
||||
// Write adds more data to the running message authentication code.
|
||||
// It never returns an error.
|
||||
//
|
||||
// It must not be called after the first call of Sum or Verify.
|
||||
func (h *MAC) Write(p []byte) (n int, err error) {
|
||||
if h.finalized {
|
||||
panic("poly1305: write to MAC after Sum or Verify")
|
||||
}
|
||||
return h.mac.Write(p)
|
||||
}
|
||||
|
||||
// Sum computes the authenticator of all data written to the
|
||||
// message authentication code.
|
||||
func (h *MAC) Sum(b []byte) []byte {
|
||||
var mac [TagSize]byte
|
||||
h.mac.Sum(&mac)
|
||||
h.finalized = true
|
||||
return append(b, mac[:]...)
|
||||
}
|
||||
|
||||
// Verify returns whether the authenticator of all data written to
|
||||
// the message authentication code matches the expected value.
|
||||
func (h *MAC) Verify(expected []byte) bool {
|
||||
var mac [TagSize]byte
|
||||
h.mac.Sum(&mac)
|
||||
h.finalized = true
|
||||
return subtle.ConstantTimeCompare(expected, mac[:]) == 1
|
||||
}
|
||||
47
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_amd64.go
generated
vendored
Normal file
47
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_amd64.go
generated
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gc && !purego
|
||||
|
||||
package poly1305
|
||||
|
||||
//go:noescape
|
||||
func update(state *macState, msg []byte)
|
||||
|
||||
// mac is a wrapper for macGeneric that redirects calls that would have gone to
|
||||
// updateGeneric to update.
|
||||
//
|
||||
// Its Write and Sum methods are otherwise identical to the macGeneric ones, but
|
||||
// using function pointers would carry a major performance cost.
|
||||
type mac struct{ macGeneric }
|
||||
|
||||
func (h *mac) Write(p []byte) (int, error) {
|
||||
nn := len(p)
|
||||
if h.offset > 0 {
|
||||
n := copy(h.buffer[h.offset:], p)
|
||||
if h.offset+n < TagSize {
|
||||
h.offset += n
|
||||
return nn, nil
|
||||
}
|
||||
p = p[n:]
|
||||
h.offset = 0
|
||||
update(&h.macState, h.buffer[:])
|
||||
}
|
||||
if n := len(p) - (len(p) % TagSize); n > 0 {
|
||||
update(&h.macState, p[:n])
|
||||
p = p[n:]
|
||||
}
|
||||
if len(p) > 0 {
|
||||
h.offset += copy(h.buffer[h.offset:], p)
|
||||
}
|
||||
return nn, nil
|
||||
}
|
||||
|
||||
func (h *mac) Sum(out *[16]byte) {
|
||||
state := h.macState
|
||||
if h.offset > 0 {
|
||||
update(&state, h.buffer[:h.offset])
|
||||
}
|
||||
finalize(out, &state.h, &state.s)
|
||||
}
|
||||
108
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_amd64.s
generated
vendored
Normal file
108
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_amd64.s
generated
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gc && !purego
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
#define POLY1305_ADD(msg, h0, h1, h2) \
|
||||
ADDQ 0(msg), h0; \
|
||||
ADCQ 8(msg), h1; \
|
||||
ADCQ $1, h2; \
|
||||
LEAQ 16(msg), msg
|
||||
|
||||
#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3) \
|
||||
MOVQ r0, AX; \
|
||||
MULQ h0; \
|
||||
MOVQ AX, t0; \
|
||||
MOVQ DX, t1; \
|
||||
MOVQ r0, AX; \
|
||||
MULQ h1; \
|
||||
ADDQ AX, t1; \
|
||||
ADCQ $0, DX; \
|
||||
MOVQ r0, t2; \
|
||||
IMULQ h2, t2; \
|
||||
ADDQ DX, t2; \
|
||||
\
|
||||
MOVQ r1, AX; \
|
||||
MULQ h0; \
|
||||
ADDQ AX, t1; \
|
||||
ADCQ $0, DX; \
|
||||
MOVQ DX, h0; \
|
||||
MOVQ r1, t3; \
|
||||
IMULQ h2, t3; \
|
||||
MOVQ r1, AX; \
|
||||
MULQ h1; \
|
||||
ADDQ AX, t2; \
|
||||
ADCQ DX, t3; \
|
||||
ADDQ h0, t2; \
|
||||
ADCQ $0, t3; \
|
||||
\
|
||||
MOVQ t0, h0; \
|
||||
MOVQ t1, h1; \
|
||||
MOVQ t2, h2; \
|
||||
ANDQ $3, h2; \
|
||||
MOVQ t2, t0; \
|
||||
ANDQ $0xFFFFFFFFFFFFFFFC, t0; \
|
||||
ADDQ t0, h0; \
|
||||
ADCQ t3, h1; \
|
||||
ADCQ $0, h2; \
|
||||
SHRQ $2, t3, t2; \
|
||||
SHRQ $2, t3; \
|
||||
ADDQ t2, h0; \
|
||||
ADCQ t3, h1; \
|
||||
ADCQ $0, h2
|
||||
|
||||
// func update(state *[7]uint64, msg []byte)
|
||||
TEXT ·update(SB), $0-32
|
||||
MOVQ state+0(FP), DI
|
||||
MOVQ msg_base+8(FP), SI
|
||||
MOVQ msg_len+16(FP), R15
|
||||
|
||||
MOVQ 0(DI), R8 // h0
|
||||
MOVQ 8(DI), R9 // h1
|
||||
MOVQ 16(DI), R10 // h2
|
||||
MOVQ 24(DI), R11 // r0
|
||||
MOVQ 32(DI), R12 // r1
|
||||
|
||||
CMPQ R15, $16
|
||||
JB bytes_between_0_and_15
|
||||
|
||||
loop:
|
||||
POLY1305_ADD(SI, R8, R9, R10)
|
||||
|
||||
multiply:
|
||||
POLY1305_MUL(R8, R9, R10, R11, R12, BX, CX, R13, R14)
|
||||
SUBQ $16, R15
|
||||
CMPQ R15, $16
|
||||
JAE loop
|
||||
|
||||
bytes_between_0_and_15:
|
||||
TESTQ R15, R15
|
||||
JZ done
|
||||
MOVQ $1, BX
|
||||
XORQ CX, CX
|
||||
XORQ R13, R13
|
||||
ADDQ R15, SI
|
||||
|
||||
flush_buffer:
|
||||
SHLQ $8, BX, CX
|
||||
SHLQ $8, BX
|
||||
MOVB -1(SI), R13
|
||||
XORQ R13, BX
|
||||
DECQ SI
|
||||
DECQ R15
|
||||
JNZ flush_buffer
|
||||
|
||||
ADDQ BX, R8
|
||||
ADCQ CX, R9
|
||||
ADCQ $0, R10
|
||||
MOVQ $16, R15
|
||||
JMP multiply
|
||||
|
||||
done:
|
||||
MOVQ R8, 0(DI)
|
||||
MOVQ R9, 8(DI)
|
||||
MOVQ R10, 16(DI)
|
||||
RET
|
||||
312
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_generic.go
generated
vendored
Normal file
312
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_generic.go
generated
vendored
Normal file
@@ -0,0 +1,312 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This file provides the generic implementation of Sum and MAC. Other files
|
||||
// might provide optimized assembly implementations of some of this code.
|
||||
|
||||
package poly1305
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
|
||||
// for a 64 bytes message is approximately
|
||||
//
|
||||
// s + m[0:16] * r⁴ + m[16:32] * r³ + m[32:48] * r² + m[48:64] * r mod 2¹³⁰ - 5
|
||||
//
|
||||
// for some secret r and s. It can be computed sequentially like
|
||||
//
|
||||
// for len(msg) > 0:
|
||||
// h += read(msg, 16)
|
||||
// h *= r
|
||||
// h %= 2¹³⁰ - 5
|
||||
// return h + s
|
||||
//
|
||||
// All the complexity is about doing performant constant-time math on numbers
|
||||
// larger than any available numeric type.
|
||||
|
||||
func sumGeneric(out *[TagSize]byte, msg []byte, key *[32]byte) {
|
||||
h := newMACGeneric(key)
|
||||
h.Write(msg)
|
||||
h.Sum(out)
|
||||
}
|
||||
|
||||
func newMACGeneric(key *[32]byte) macGeneric {
|
||||
m := macGeneric{}
|
||||
initialize(key, &m.macState)
|
||||
return m
|
||||
}
|
||||
|
||||
// macState holds numbers in saturated 64-bit little-endian limbs. That is,
|
||||
// the value of [x0, x1, x2] is x[0] + x[1] * 2⁶⁴ + x[2] * 2¹²⁸.
|
||||
type macState struct {
|
||||
// h is the main accumulator. It is to be interpreted modulo 2¹³⁰ - 5, but
|
||||
// can grow larger during and after rounds. It must, however, remain below
|
||||
// 2 * (2¹³⁰ - 5).
|
||||
h [3]uint64
|
||||
// r and s are the private key components.
|
||||
r [2]uint64
|
||||
s [2]uint64
|
||||
}
|
||||
|
||||
type macGeneric struct {
|
||||
macState
|
||||
|
||||
buffer [TagSize]byte
|
||||
offset int
|
||||
}
|
||||
|
||||
// Write splits the incoming message into TagSize chunks, and passes them to
|
||||
// update. It buffers incomplete chunks.
|
||||
func (h *macGeneric) Write(p []byte) (int, error) {
|
||||
nn := len(p)
|
||||
if h.offset > 0 {
|
||||
n := copy(h.buffer[h.offset:], p)
|
||||
if h.offset+n < TagSize {
|
||||
h.offset += n
|
||||
return nn, nil
|
||||
}
|
||||
p = p[n:]
|
||||
h.offset = 0
|
||||
updateGeneric(&h.macState, h.buffer[:])
|
||||
}
|
||||
if n := len(p) - (len(p) % TagSize); n > 0 {
|
||||
updateGeneric(&h.macState, p[:n])
|
||||
p = p[n:]
|
||||
}
|
||||
if len(p) > 0 {
|
||||
h.offset += copy(h.buffer[h.offset:], p)
|
||||
}
|
||||
return nn, nil
|
||||
}
|
||||
|
||||
// Sum flushes the last incomplete chunk from the buffer, if any, and generates
|
||||
// the MAC output. It does not modify its state, in order to allow for multiple
|
||||
// calls to Sum, even if no Write is allowed after Sum.
|
||||
func (h *macGeneric) Sum(out *[TagSize]byte) {
|
||||
state := h.macState
|
||||
if h.offset > 0 {
|
||||
updateGeneric(&state, h.buffer[:h.offset])
|
||||
}
|
||||
finalize(out, &state.h, &state.s)
|
||||
}
|
||||
|
||||
// [rMask0, rMask1] is the specified Poly1305 clamping mask in little-endian. It
|
||||
// clears some bits of the secret coefficient to make it possible to implement
|
||||
// multiplication more efficiently.
|
||||
const (
|
||||
rMask0 = 0x0FFFFFFC0FFFFFFF
|
||||
rMask1 = 0x0FFFFFFC0FFFFFFC
|
||||
)
|
||||
|
||||
// initialize loads the 256-bit key into the two 128-bit secret values r and s.
|
||||
func initialize(key *[32]byte, m *macState) {
|
||||
m.r[0] = binary.LittleEndian.Uint64(key[0:8]) & rMask0
|
||||
m.r[1] = binary.LittleEndian.Uint64(key[8:16]) & rMask1
|
||||
m.s[0] = binary.LittleEndian.Uint64(key[16:24])
|
||||
m.s[1] = binary.LittleEndian.Uint64(key[24:32])
|
||||
}
|
||||
|
||||
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
|
||||
// bits.Mul64 and bits.Add64 intrinsics.
|
||||
type uint128 struct {
|
||||
lo, hi uint64
|
||||
}
|
||||
|
||||
func mul64(a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(a, b)
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
func add128(a, b uint128) uint128 {
|
||||
lo, c := bits.Add64(a.lo, b.lo, 0)
|
||||
hi, c := bits.Add64(a.hi, b.hi, c)
|
||||
if c != 0 {
|
||||
panic("poly1305: unexpected overflow")
|
||||
}
|
||||
return uint128{lo, hi}
|
||||
}
|
||||
|
||||
func shiftRightBy2(a uint128) uint128 {
|
||||
a.lo = a.lo>>2 | (a.hi&3)<<62
|
||||
a.hi = a.hi >> 2
|
||||
return a
|
||||
}
|
||||
|
||||
// updateGeneric absorbs msg into the state.h accumulator. For each chunk m of
|
||||
// 128 bits of message, it computes
|
||||
//
|
||||
// h₊ = (h + m) * r mod 2¹³⁰ - 5
|
||||
//
|
||||
// If the msg length is not a multiple of TagSize, it assumes the last
|
||||
// incomplete chunk is the final one.
|
||||
func updateGeneric(state *macState, msg []byte) {
|
||||
h0, h1, h2 := state.h[0], state.h[1], state.h[2]
|
||||
r0, r1 := state.r[0], state.r[1]
|
||||
|
||||
for len(msg) > 0 {
|
||||
var c uint64
|
||||
|
||||
// For the first step, h + m, we use a chain of bits.Add64 intrinsics.
|
||||
// The resulting value of h might exceed 2¹³⁰ - 5, but will be partially
|
||||
// reduced at the end of the multiplication below.
|
||||
//
|
||||
// The spec requires us to set a bit just above the message size, not to
|
||||
// hide leading zeroes. For full chunks, that's 1 << 128, so we can just
|
||||
// add 1 to the most significant (2¹²⁸) limb, h2.
|
||||
if len(msg) >= TagSize {
|
||||
h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
|
||||
h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
|
||||
h2 += c + 1
|
||||
|
||||
msg = msg[TagSize:]
|
||||
} else {
|
||||
var buf [TagSize]byte
|
||||
copy(buf[:], msg)
|
||||
buf[len(msg)] = 1
|
||||
|
||||
h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
|
||||
h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
|
||||
h2 += c
|
||||
|
||||
msg = nil
|
||||
}
|
||||
|
||||
// Multiplication of big number limbs is similar to elementary school
|
||||
// columnar multiplication. Instead of digits, there are 64-bit limbs.
|
||||
//
|
||||
// We are multiplying a 3 limbs number, h, by a 2 limbs number, r.
|
||||
//
|
||||
// h2 h1 h0 x
|
||||
// r1 r0 =
|
||||
// ----------------
|
||||
// h2r0 h1r0 h0r0 <-- individual 128-bit products
|
||||
// + h2r1 h1r1 h0r1
|
||||
// ------------------------
|
||||
// m3 m2 m1 m0 <-- result in 128-bit overlapping limbs
|
||||
// ------------------------
|
||||
// m3.hi m2.hi m1.hi m0.hi <-- carry propagation
|
||||
// + m3.lo m2.lo m1.lo m0.lo
|
||||
// -------------------------------
|
||||
// t4 t3 t2 t1 t0 <-- final result in 64-bit limbs
|
||||
//
|
||||
// The main difference from pen-and-paper multiplication is that we do
|
||||
// carry propagation in a separate step, as if we wrote two digit sums
|
||||
// at first (the 128-bit limbs), and then carried the tens all at once.
|
||||
|
||||
h0r0 := mul64(h0, r0)
|
||||
h1r0 := mul64(h1, r0)
|
||||
h2r0 := mul64(h2, r0)
|
||||
h0r1 := mul64(h0, r1)
|
||||
h1r1 := mul64(h1, r1)
|
||||
h2r1 := mul64(h2, r1)
|
||||
|
||||
// Since h2 is known to be at most 7 (5 + 1 + 1), and r0 and r1 have their
|
||||
// top 4 bits cleared by rMask{0,1}, we know that their product is not going
|
||||
// to overflow 64 bits, so we can ignore the high part of the products.
|
||||
//
|
||||
// This also means that the product doesn't have a fifth limb (t4).
|
||||
if h2r0.hi != 0 {
|
||||
panic("poly1305: unexpected overflow")
|
||||
}
|
||||
if h2r1.hi != 0 {
|
||||
panic("poly1305: unexpected overflow")
|
||||
}
|
||||
|
||||
m0 := h0r0
|
||||
m1 := add128(h1r0, h0r1) // These two additions don't overflow thanks again
|
||||
m2 := add128(h2r0, h1r1) // to the 4 masked bits at the top of r0 and r1.
|
||||
m3 := h2r1
|
||||
|
||||
t0 := m0.lo
|
||||
t1, c := bits.Add64(m1.lo, m0.hi, 0)
|
||||
t2, c := bits.Add64(m2.lo, m1.hi, c)
|
||||
t3, _ := bits.Add64(m3.lo, m2.hi, c)
|
||||
|
||||
// Now we have the result as 4 64-bit limbs, and we need to reduce it
|
||||
// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
|
||||
// a cheap partial reduction according to the reduction identity
|
||||
//
|
||||
// c * 2¹³⁰ + n = c * 5 + n mod 2¹³⁰ - 5
|
||||
//
|
||||
// because 2¹³⁰ = 5 mod 2¹³⁰ - 5. Partial reduction since the result is
|
||||
// likely to be larger than 2¹³⁰ - 5, but still small enough to fit the
|
||||
// assumptions we make about h in the rest of the code.
|
||||
//
|
||||
// See also https://speakerdeck.com/gtank/engineering-prime-numbers?slide=23
|
||||
|
||||
// We split the final result at the 2¹³⁰ mark into h and cc, the carry.
|
||||
// Note that the carry bits are effectively shifted left by 2, in other
|
||||
// words, cc = c * 4 for the c in the reduction identity.
|
||||
h0, h1, h2 = t0, t1, t2&maskLow2Bits
|
||||
cc := uint128{t2 & maskNotLow2Bits, t3}
|
||||
|
||||
// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
|
||||
|
||||
h0, c = bits.Add64(h0, cc.lo, 0)
|
||||
h1, c = bits.Add64(h1, cc.hi, c)
|
||||
h2 += c
|
||||
|
||||
cc = shiftRightBy2(cc)
|
||||
|
||||
h0, c = bits.Add64(h0, cc.lo, 0)
|
||||
h1, c = bits.Add64(h1, cc.hi, c)
|
||||
h2 += c
|
||||
|
||||
// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
|
||||
//
|
||||
// 5 * 2¹²⁸ + (2¹²⁸ - 1) = 6 * 2¹²⁸ - 1
|
||||
}
|
||||
|
||||
state.h[0], state.h[1], state.h[2] = h0, h1, h2
|
||||
}
|
||||
|
||||
const (
|
||||
maskLow2Bits uint64 = 0x0000000000000003
|
||||
maskNotLow2Bits uint64 = ^maskLow2Bits
|
||||
)
|
||||
|
||||
// select64 returns x if v == 1 and y if v == 0, in constant time.
|
||||
func select64(v, x, y uint64) uint64 { return ^(v-1)&x | (v-1)&y }
|
||||
|
||||
// [p0, p1, p2] is 2¹³⁰ - 5 in little endian order.
|
||||
const (
|
||||
p0 = 0xFFFFFFFFFFFFFFFB
|
||||
p1 = 0xFFFFFFFFFFFFFFFF
|
||||
p2 = 0x0000000000000003
|
||||
)
|
||||
|
||||
// finalize completes the modular reduction of h and computes
|
||||
//
|
||||
// out = h + s mod 2¹²⁸
|
||||
func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
|
||||
h0, h1, h2 := h[0], h[1], h[2]
|
||||
|
||||
// After the partial reduction in updateGeneric, h might be more than
|
||||
// 2¹³⁰ - 5, but will be less than 2 * (2¹³⁰ - 5). To complete the reduction
|
||||
// in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the
|
||||
// result if the subtraction underflows, and t otherwise.
|
||||
|
||||
hMinusP0, b := bits.Sub64(h0, p0, 0)
|
||||
hMinusP1, b := bits.Sub64(h1, p1, b)
|
||||
_, b = bits.Sub64(h2, p2, b)
|
||||
|
||||
// h = h if h < p else h - p
|
||||
h0 = select64(b, h0, hMinusP0)
|
||||
h1 = select64(b, h1, hMinusP1)
|
||||
|
||||
// Finally, we compute the last Poly1305 step
|
||||
//
|
||||
// tag = h + s mod 2¹²⁸
|
||||
//
|
||||
// by just doing a wide addition with the 128 low bits of h and discarding
|
||||
// the overflow.
|
||||
h0, c := bits.Add64(h0, s[0], 0)
|
||||
h1, _ = bits.Add64(h1, s[1], c)
|
||||
|
||||
binary.LittleEndian.PutUint64(out[0:8], h0)
|
||||
binary.LittleEndian.PutUint64(out[8:16], h1)
|
||||
}
|
||||
47
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_ppc64le.go
generated
vendored
Normal file
47
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_ppc64le.go
generated
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gc && !purego
|
||||
|
||||
package poly1305
|
||||
|
||||
//go:noescape
|
||||
func update(state *macState, msg []byte)
|
||||
|
||||
// mac is a wrapper for macGeneric that redirects calls that would have gone to
|
||||
// updateGeneric to update.
|
||||
//
|
||||
// Its Write and Sum methods are otherwise identical to the macGeneric ones, but
|
||||
// using function pointers would carry a major performance cost.
|
||||
type mac struct{ macGeneric }
|
||||
|
||||
func (h *mac) Write(p []byte) (int, error) {
|
||||
nn := len(p)
|
||||
if h.offset > 0 {
|
||||
n := copy(h.buffer[h.offset:], p)
|
||||
if h.offset+n < TagSize {
|
||||
h.offset += n
|
||||
return nn, nil
|
||||
}
|
||||
p = p[n:]
|
||||
h.offset = 0
|
||||
update(&h.macState, h.buffer[:])
|
||||
}
|
||||
if n := len(p) - (len(p) % TagSize); n > 0 {
|
||||
update(&h.macState, p[:n])
|
||||
p = p[n:]
|
||||
}
|
||||
if len(p) > 0 {
|
||||
h.offset += copy(h.buffer[h.offset:], p)
|
||||
}
|
||||
return nn, nil
|
||||
}
|
||||
|
||||
func (h *mac) Sum(out *[16]byte) {
|
||||
state := h.macState
|
||||
if h.offset > 0 {
|
||||
update(&state, h.buffer[:h.offset])
|
||||
}
|
||||
finalize(out, &state.h, &state.s)
|
||||
}
|
||||
181
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_ppc64le.s
generated
vendored
Normal file
181
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_ppc64le.s
generated
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gc && !purego
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// This was ported from the amd64 implementation.
|
||||
|
||||
#define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \
|
||||
MOVD (msg), t0; \
|
||||
MOVD 8(msg), t1; \
|
||||
MOVD $1, t2; \
|
||||
ADDC t0, h0, h0; \
|
||||
ADDE t1, h1, h1; \
|
||||
ADDE t2, h2; \
|
||||
ADD $16, msg
|
||||
|
||||
#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \
|
||||
MULLD r0, h0, t0; \
|
||||
MULLD r0, h1, t4; \
|
||||
MULHDU r0, h0, t1; \
|
||||
MULHDU r0, h1, t5; \
|
||||
ADDC t4, t1, t1; \
|
||||
MULLD r0, h2, t2; \
|
||||
ADDZE t5; \
|
||||
MULHDU r1, h0, t4; \
|
||||
MULLD r1, h0, h0; \
|
||||
ADD t5, t2, t2; \
|
||||
ADDC h0, t1, t1; \
|
||||
MULLD h2, r1, t3; \
|
||||
ADDZE t4, h0; \
|
||||
MULHDU r1, h1, t5; \
|
||||
MULLD r1, h1, t4; \
|
||||
ADDC t4, t2, t2; \
|
||||
ADDE t5, t3, t3; \
|
||||
ADDC h0, t2, t2; \
|
||||
MOVD $-4, t4; \
|
||||
MOVD t0, h0; \
|
||||
MOVD t1, h1; \
|
||||
ADDZE t3; \
|
||||
ANDCC $3, t2, h2; \
|
||||
AND t2, t4, t0; \
|
||||
ADDC t0, h0, h0; \
|
||||
ADDE t3, h1, h1; \
|
||||
SLD $62, t3, t4; \
|
||||
SRD $2, t2; \
|
||||
ADDZE h2; \
|
||||
OR t4, t2, t2; \
|
||||
SRD $2, t3; \
|
||||
ADDC t2, h0, h0; \
|
||||
ADDE t3, h1, h1; \
|
||||
ADDZE h2
|
||||
|
||||
DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
|
||||
DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
|
||||
GLOBL ·poly1305Mask<>(SB), RODATA, $16
|
||||
|
||||
// func update(state *[7]uint64, msg []byte)
|
||||
TEXT ·update(SB), $0-32
|
||||
MOVD state+0(FP), R3
|
||||
MOVD msg_base+8(FP), R4
|
||||
MOVD msg_len+16(FP), R5
|
||||
|
||||
MOVD 0(R3), R8 // h0
|
||||
MOVD 8(R3), R9 // h1
|
||||
MOVD 16(R3), R10 // h2
|
||||
MOVD 24(R3), R11 // r0
|
||||
MOVD 32(R3), R12 // r1
|
||||
|
||||
CMP R5, $16
|
||||
BLT bytes_between_0_and_15
|
||||
|
||||
loop:
|
||||
POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22)
|
||||
|
||||
multiply:
|
||||
POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21)
|
||||
ADD $-16, R5
|
||||
CMP R5, $16
|
||||
BGE loop
|
||||
|
||||
bytes_between_0_and_15:
|
||||
CMP R5, $0
|
||||
BEQ done
|
||||
MOVD $0, R16 // h0
|
||||
MOVD $0, R17 // h1
|
||||
|
||||
flush_buffer:
|
||||
CMP R5, $8
|
||||
BLE just1
|
||||
|
||||
MOVD $8, R21
|
||||
SUB R21, R5, R21
|
||||
|
||||
// Greater than 8 -- load the rightmost remaining bytes in msg
|
||||
// and put into R17 (h1)
|
||||
MOVD (R4)(R21), R17
|
||||
MOVD $16, R22
|
||||
|
||||
// Find the offset to those bytes
|
||||
SUB R5, R22, R22
|
||||
SLD $3, R22
|
||||
|
||||
// Shift to get only the bytes in msg
|
||||
SRD R22, R17, R17
|
||||
|
||||
// Put 1 at high end
|
||||
MOVD $1, R23
|
||||
SLD $3, R21
|
||||
SLD R21, R23, R23
|
||||
OR R23, R17, R17
|
||||
|
||||
// Remainder is 8
|
||||
MOVD $8, R5
|
||||
|
||||
just1:
|
||||
CMP R5, $8
|
||||
BLT less8
|
||||
|
||||
// Exactly 8
|
||||
MOVD (R4), R16
|
||||
|
||||
CMP R17, $0
|
||||
|
||||
// Check if we've already set R17; if not
|
||||
// set 1 to indicate end of msg.
|
||||
BNE carry
|
||||
MOVD $1, R17
|
||||
BR carry
|
||||
|
||||
less8:
|
||||
MOVD $0, R16 // h0
|
||||
MOVD $0, R22 // shift count
|
||||
CMP R5, $4
|
||||
BLT less4
|
||||
MOVWZ (R4), R16
|
||||
ADD $4, R4
|
||||
ADD $-4, R5
|
||||
MOVD $32, R22
|
||||
|
||||
less4:
|
||||
CMP R5, $2
|
||||
BLT less2
|
||||
MOVHZ (R4), R21
|
||||
SLD R22, R21, R21
|
||||
OR R16, R21, R16
|
||||
ADD $16, R22
|
||||
ADD $-2, R5
|
||||
ADD $2, R4
|
||||
|
||||
less2:
|
||||
CMP R5, $0
|
||||
BEQ insert1
|
||||
MOVBZ (R4), R21
|
||||
SLD R22, R21, R21
|
||||
OR R16, R21, R16
|
||||
ADD $8, R22
|
||||
|
||||
insert1:
|
||||
// Insert 1 at end of msg
|
||||
MOVD $1, R21
|
||||
SLD R22, R21, R21
|
||||
OR R16, R21, R16
|
||||
|
||||
carry:
|
||||
// Add new values to h0, h1, h2
|
||||
ADDC R16, R8
|
||||
ADDE R17, R9
|
||||
ADDZE R10, R10
|
||||
MOVD $16, R5
|
||||
ADD R5, R4
|
||||
BR multiply
|
||||
|
||||
done:
|
||||
// Save h0, h1, h2 in state
|
||||
MOVD R8, 0(R3)
|
||||
MOVD R9, 8(R3)
|
||||
MOVD R10, 16(R3)
|
||||
RET
|
||||
76
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_s390x.go
generated
vendored
Normal file
76
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_s390x.go
generated
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gc && !purego
|
||||
|
||||
package poly1305
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
// updateVX is an assembly implementation of Poly1305 that uses vector
|
||||
// instructions. It must only be called if the vector facility (vx) is
|
||||
// available.
|
||||
//
|
||||
//go:noescape
|
||||
func updateVX(state *macState, msg []byte)
|
||||
|
||||
// mac is a replacement for macGeneric that uses a larger buffer and redirects
|
||||
// calls that would have gone to updateGeneric to updateVX if the vector
|
||||
// facility is installed.
|
||||
//
|
||||
// A larger buffer is required for good performance because the vector
|
||||
// implementation has a higher fixed cost per call than the generic
|
||||
// implementation.
|
||||
type mac struct {
|
||||
macState
|
||||
|
||||
buffer [16 * TagSize]byte // size must be a multiple of block size (16)
|
||||
offset int
|
||||
}
|
||||
|
||||
func (h *mac) Write(p []byte) (int, error) {
|
||||
nn := len(p)
|
||||
if h.offset > 0 {
|
||||
n := copy(h.buffer[h.offset:], p)
|
||||
if h.offset+n < len(h.buffer) {
|
||||
h.offset += n
|
||||
return nn, nil
|
||||
}
|
||||
p = p[n:]
|
||||
h.offset = 0
|
||||
if cpu.S390X.HasVX {
|
||||
updateVX(&h.macState, h.buffer[:])
|
||||
} else {
|
||||
updateGeneric(&h.macState, h.buffer[:])
|
||||
}
|
||||
}
|
||||
|
||||
tail := len(p) % len(h.buffer) // number of bytes to copy into buffer
|
||||
body := len(p) - tail // number of bytes to process now
|
||||
if body > 0 {
|
||||
if cpu.S390X.HasVX {
|
||||
updateVX(&h.macState, p[:body])
|
||||
} else {
|
||||
updateGeneric(&h.macState, p[:body])
|
||||
}
|
||||
}
|
||||
h.offset = copy(h.buffer[:], p[body:]) // copy tail bytes - can be 0
|
||||
return nn, nil
|
||||
}
|
||||
|
||||
func (h *mac) Sum(out *[TagSize]byte) {
|
||||
state := h.macState
|
||||
remainder := h.buffer[:h.offset]
|
||||
|
||||
// Use the generic implementation if we have 2 or fewer blocks left
|
||||
// to sum. The vector implementation has a higher startup time.
|
||||
if cpu.S390X.HasVX && len(remainder) > 2*TagSize {
|
||||
updateVX(&state, remainder)
|
||||
} else if len(remainder) > 0 {
|
||||
updateGeneric(&state, remainder)
|
||||
}
|
||||
finalize(out, &state.h, &state.s)
|
||||
}
|
||||
503
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_s390x.s
generated
vendored
Normal file
503
vendor/github.com/tailscale/golang-x-crypto/internal/poly1305/sum_s390x.s
generated
vendored
Normal file
@@ -0,0 +1,503 @@
|
||||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build gc && !purego
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
// This implementation of Poly1305 uses the vector facility (vx)
|
||||
// to process up to 2 blocks (32 bytes) per iteration using an
|
||||
// algorithm based on the one described in:
|
||||
//
|
||||
// NEON crypto, Daniel J. Bernstein & Peter Schwabe
|
||||
// https://cryptojedi.org/papers/neoncrypto-20120320.pdf
|
||||
//
|
||||
// This algorithm uses 5 26-bit limbs to represent a 130-bit
|
||||
// value. These limbs are, for the most part, zero extended and
|
||||
// placed into 64-bit vector register elements. Each vector
|
||||
// register is 128-bits wide and so holds 2 of these elements.
|
||||
// Using 26-bit limbs allows us plenty of headroom to accommodate
|
||||
// accumulations before and after multiplication without
|
||||
// overflowing either 32-bits (before multiplication) or 64-bits
|
||||
// (after multiplication).
|
||||
//
|
||||
// In order to parallelise the operations required to calculate
|
||||
// the sum we use two separate accumulators and then sum those
|
||||
// in an extra final step. For compatibility with the generic
|
||||
// implementation we perform this summation at the end of every
|
||||
// updateVX call.
|
||||
//
|
||||
// To use two accumulators we must multiply the message blocks
|
||||
// by r² rather than r. Only the final message block should be
|
||||
// multiplied by r.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// We want to calculate the sum (h) for a 64 byte message (m):
|
||||
//
|
||||
// h = m[0:16]r⁴ + m[16:32]r³ + m[32:48]r² + m[48:64]r
|
||||
//
|
||||
// To do this we split the calculation into the even indices
|
||||
// and odd indices of the message. These form our SIMD 'lanes':
|
||||
//
|
||||
// h = m[ 0:16]r⁴ + m[32:48]r² + <- lane 0
|
||||
// m[16:32]r³ + m[48:64]r <- lane 1
|
||||
//
|
||||
// To calculate this iteratively we refactor so that both lanes
|
||||
// are written in terms of r² and r:
|
||||
//
|
||||
// h = (m[ 0:16]r² + m[32:48])r² + <- lane 0
|
||||
// (m[16:32]r² + m[48:64])r <- lane 1
|
||||
// ^ ^
|
||||
// | coefficients for second iteration
|
||||
// coefficients for first iteration
|
||||
//
|
||||
// So in this case we would have two iterations. In the first
|
||||
// both lanes are multiplied by r². In the second only the
|
||||
// first lane is multiplied by r² and the second lane is
|
||||
// instead multiplied by r. This gives use the odd and even
|
||||
// powers of r that we need from the original equation.
|
||||
//
|
||||
// Notation:
|
||||
//
|
||||
// h - accumulator
|
||||
// r - key
|
||||
// m - message
|
||||
//
|
||||
// [a, b] - SIMD register holding two 64-bit values
|
||||
// [a, b, c, d] - SIMD register holding four 32-bit values
|
||||
// xᵢ[n] - limb n of variable x with bit width i
|
||||
//
|
||||
// Limbs are expressed in little endian order, so for 26-bit
|
||||
// limbs x₂₆[4] will be the most significant limb and x₂₆[0]
|
||||
// will be the least significant limb.
|
||||
|
||||
// masking constants
|
||||
#define MOD24 V0 // [0x0000000000ffffff, 0x0000000000ffffff] - mask low 24-bits
|
||||
#define MOD26 V1 // [0x0000000003ffffff, 0x0000000003ffffff] - mask low 26-bits
|
||||
|
||||
// expansion constants (see EXPAND macro)
|
||||
#define EX0 V2
|
||||
#define EX1 V3
|
||||
#define EX2 V4
|
||||
|
||||
// key (r², r or 1 depending on context)
|
||||
#define R_0 V5
|
||||
#define R_1 V6
|
||||
#define R_2 V7
|
||||
#define R_3 V8
|
||||
#define R_4 V9
|
||||
|
||||
// precalculated coefficients (5r², 5r or 0 depending on context)
|
||||
#define R5_1 V10
|
||||
#define R5_2 V11
|
||||
#define R5_3 V12
|
||||
#define R5_4 V13
|
||||
|
||||
// message block (m)
|
||||
#define M_0 V14
|
||||
#define M_1 V15
|
||||
#define M_2 V16
|
||||
#define M_3 V17
|
||||
#define M_4 V18
|
||||
|
||||
// accumulator (h)
|
||||
#define H_0 V19
|
||||
#define H_1 V20
|
||||
#define H_2 V21
|
||||
#define H_3 V22
|
||||
#define H_4 V23
|
||||
|
||||
// temporary registers (for short-lived values)
|
||||
#define T_0 V24
|
||||
#define T_1 V25
|
||||
#define T_2 V26
|
||||
#define T_3 V27
|
||||
#define T_4 V28
|
||||
|
||||
GLOBL ·constants<>(SB), RODATA, $0x30
|
||||
// EX0
|
||||
DATA ·constants<>+0x00(SB)/8, $0x0006050403020100
|
||||
DATA ·constants<>+0x08(SB)/8, $0x1016151413121110
|
||||
// EX1
|
||||
DATA ·constants<>+0x10(SB)/8, $0x060c0b0a09080706
|
||||
DATA ·constants<>+0x18(SB)/8, $0x161c1b1a19181716
|
||||
// EX2
|
||||
DATA ·constants<>+0x20(SB)/8, $0x0d0d0d0d0d0f0e0d
|
||||
DATA ·constants<>+0x28(SB)/8, $0x1d1d1d1d1d1f1e1d
|
||||
|
||||
// MULTIPLY multiplies each lane of f and g, partially reduced
|
||||
// modulo 2¹³⁰ - 5. The result, h, consists of partial products
|
||||
// in each lane that need to be reduced further to produce the
|
||||
// final result.
|
||||
//
|
||||
// h₁₃₀ = (f₁₃₀g₁₃₀) % 2¹³⁰ + (5f₁₃₀g₁₃₀) / 2¹³⁰
|
||||
//
|
||||
// Note that the multiplication by 5 of the high bits is
|
||||
// achieved by precalculating the multiplication of four of the
|
||||
// g coefficients by 5. These are g51-g54.
|
||||
#define MULTIPLY(f0, f1, f2, f3, f4, g0, g1, g2, g3, g4, g51, g52, g53, g54, h0, h1, h2, h3, h4) \
|
||||
VMLOF f0, g0, h0 \
|
||||
VMLOF f0, g3, h3 \
|
||||
VMLOF f0, g1, h1 \
|
||||
VMLOF f0, g4, h4 \
|
||||
VMLOF f0, g2, h2 \
|
||||
VMLOF f1, g54, T_0 \
|
||||
VMLOF f1, g2, T_3 \
|
||||
VMLOF f1, g0, T_1 \
|
||||
VMLOF f1, g3, T_4 \
|
||||
VMLOF f1, g1, T_2 \
|
||||
VMALOF f2, g53, h0, h0 \
|
||||
VMALOF f2, g1, h3, h3 \
|
||||
VMALOF f2, g54, h1, h1 \
|
||||
VMALOF f2, g2, h4, h4 \
|
||||
VMALOF f2, g0, h2, h2 \
|
||||
VMALOF f3, g52, T_0, T_0 \
|
||||
VMALOF f3, g0, T_3, T_3 \
|
||||
VMALOF f3, g53, T_1, T_1 \
|
||||
VMALOF f3, g1, T_4, T_4 \
|
||||
VMALOF f3, g54, T_2, T_2 \
|
||||
VMALOF f4, g51, h0, h0 \
|
||||
VMALOF f4, g54, h3, h3 \
|
||||
VMALOF f4, g52, h1, h1 \
|
||||
VMALOF f4, g0, h4, h4 \
|
||||
VMALOF f4, g53, h2, h2 \
|
||||
VAG T_0, h0, h0 \
|
||||
VAG T_3, h3, h3 \
|
||||
VAG T_1, h1, h1 \
|
||||
VAG T_4, h4, h4 \
|
||||
VAG T_2, h2, h2
|
||||
|
||||
// REDUCE performs the following carry operations in four
|
||||
// stages, as specified in Bernstein & Schwabe:
|
||||
//
|
||||
// 1: h₂₆[0]->h₂₆[1] h₂₆[3]->h₂₆[4]
|
||||
// 2: h₂₆[1]->h₂₆[2] h₂₆[4]->h₂₆[0]
|
||||
// 3: h₂₆[0]->h₂₆[1] h₂₆[2]->h₂₆[3]
|
||||
// 4: h₂₆[3]->h₂₆[4]
|
||||
//
|
||||
// The result is that all of the limbs are limited to 26-bits
|
||||
// except for h₂₆[1] and h₂₆[4] which are limited to 27-bits.
|
||||
//
|
||||
// Note that although each limb is aligned at 26-bit intervals
|
||||
// they may contain values that exceed 2²⁶ - 1, hence the need
|
||||
// to carry the excess bits in each limb.
|
||||
#define REDUCE(h0, h1, h2, h3, h4) \
|
||||
VESRLG $26, h0, T_0 \
|
||||
VESRLG $26, h3, T_1 \
|
||||
VN MOD26, h0, h0 \
|
||||
VN MOD26, h3, h3 \
|
||||
VAG T_0, h1, h1 \
|
||||
VAG T_1, h4, h4 \
|
||||
VESRLG $26, h1, T_2 \
|
||||
VESRLG $26, h4, T_3 \
|
||||
VN MOD26, h1, h1 \
|
||||
VN MOD26, h4, h4 \
|
||||
VESLG $2, T_3, T_4 \
|
||||
VAG T_3, T_4, T_4 \
|
||||
VAG T_2, h2, h2 \
|
||||
VAG T_4, h0, h0 \
|
||||
VESRLG $26, h2, T_0 \
|
||||
VESRLG $26, h0, T_1 \
|
||||
VN MOD26, h2, h2 \
|
||||
VN MOD26, h0, h0 \
|
||||
VAG T_0, h3, h3 \
|
||||
VAG T_1, h1, h1 \
|
||||
VESRLG $26, h3, T_2 \
|
||||
VN MOD26, h3, h3 \
|
||||
VAG T_2, h4, h4
|
||||
|
||||
// EXPAND splits the 128-bit little-endian values in0 and in1
|
||||
// into 26-bit big-endian limbs and places the results into
|
||||
// the first and second lane of d₂₆[0:4] respectively.
|
||||
//
|
||||
// The EX0, EX1 and EX2 constants are arrays of byte indices
|
||||
// for permutation. The permutation both reverses the bytes
|
||||
// in the input and ensures the bytes are copied into the
|
||||
// destination limb ready to be shifted into their final
|
||||
// position.
|
||||
#define EXPAND(in0, in1, d0, d1, d2, d3, d4) \
|
||||
VPERM in0, in1, EX0, d0 \
|
||||
VPERM in0, in1, EX1, d2 \
|
||||
VPERM in0, in1, EX2, d4 \
|
||||
VESRLG $26, d0, d1 \
|
||||
VESRLG $30, d2, d3 \
|
||||
VESRLG $4, d2, d2 \
|
||||
VN MOD26, d0, d0 \ // [in0₂₆[0], in1₂₆[0]]
|
||||
VN MOD26, d3, d3 \ // [in0₂₆[3], in1₂₆[3]]
|
||||
VN MOD26, d1, d1 \ // [in0₂₆[1], in1₂₆[1]]
|
||||
VN MOD24, d4, d4 \ // [in0₂₆[4], in1₂₆[4]]
|
||||
VN MOD26, d2, d2 // [in0₂₆[2], in1₂₆[2]]
|
||||
|
||||
// func updateVX(state *macState, msg []byte)
|
||||
TEXT ·updateVX(SB), NOSPLIT, $0
|
||||
MOVD state+0(FP), R1
|
||||
LMG msg+8(FP), R2, R3 // R2=msg_base, R3=msg_len
|
||||
|
||||
// load EX0, EX1 and EX2
|
||||
MOVD $·constants<>(SB), R5
|
||||
VLM (R5), EX0, EX2
|
||||
|
||||
// generate masks
|
||||
VGMG $(64-24), $63, MOD24 // [0x00ffffff, 0x00ffffff]
|
||||
VGMG $(64-26), $63, MOD26 // [0x03ffffff, 0x03ffffff]
|
||||
|
||||
// load h (accumulator) and r (key) from state
|
||||
VZERO T_1 // [0, 0]
|
||||
VL 0(R1), T_0 // [h₆₄[0], h₆₄[1]]
|
||||
VLEG $0, 16(R1), T_1 // [h₆₄[2], 0]
|
||||
VL 24(R1), T_2 // [r₆₄[0], r₆₄[1]]
|
||||
VPDI $0, T_0, T_2, T_3 // [h₆₄[0], r₆₄[0]]
|
||||
VPDI $5, T_0, T_2, T_4 // [h₆₄[1], r₆₄[1]]
|
||||
|
||||
// unpack h and r into 26-bit limbs
|
||||
// note: h₆₄[2] may have the low 3 bits set, so h₂₆[4] is a 27-bit value
|
||||
VN MOD26, T_3, H_0 // [h₂₆[0], r₂₆[0]]
|
||||
VZERO H_1 // [0, 0]
|
||||
VZERO H_3 // [0, 0]
|
||||
VGMG $(64-12-14), $(63-12), T_0 // [0x03fff000, 0x03fff000] - 26-bit mask with low 12 bits masked out
|
||||
VESLG $24, T_1, T_1 // [h₆₄[2]<<24, 0]
|
||||
VERIMG $-26&63, T_3, MOD26, H_1 // [h₂₆[1], r₂₆[1]]
|
||||
VESRLG $+52&63, T_3, H_2 // [h₂₆[2], r₂₆[2]] - low 12 bits only
|
||||
VERIMG $-14&63, T_4, MOD26, H_3 // [h₂₆[1], r₂₆[1]]
|
||||
VESRLG $40, T_4, H_4 // [h₂₆[4], r₂₆[4]] - low 24 bits only
|
||||
VERIMG $+12&63, T_4, T_0, H_2 // [h₂₆[2], r₂₆[2]] - complete
|
||||
VO T_1, H_4, H_4 // [h₂₆[4], r₂₆[4]] - complete
|
||||
|
||||
// replicate r across all 4 vector elements
|
||||
VREPF $3, H_0, R_0 // [r₂₆[0], r₂₆[0], r₂₆[0], r₂₆[0]]
|
||||
VREPF $3, H_1, R_1 // [r₂₆[1], r₂₆[1], r₂₆[1], r₂₆[1]]
|
||||
VREPF $3, H_2, R_2 // [r₂₆[2], r₂₆[2], r₂₆[2], r₂₆[2]]
|
||||
VREPF $3, H_3, R_3 // [r₂₆[3], r₂₆[3], r₂₆[3], r₂₆[3]]
|
||||
VREPF $3, H_4, R_4 // [r₂₆[4], r₂₆[4], r₂₆[4], r₂₆[4]]
|
||||
|
||||
// zero out lane 1 of h
|
||||
VLEIG $1, $0, H_0 // [h₂₆[0], 0]
|
||||
VLEIG $1, $0, H_1 // [h₂₆[1], 0]
|
||||
VLEIG $1, $0, H_2 // [h₂₆[2], 0]
|
||||
VLEIG $1, $0, H_3 // [h₂₆[3], 0]
|
||||
VLEIG $1, $0, H_4 // [h₂₆[4], 0]
|
||||
|
||||
// calculate 5r (ignore least significant limb)
|
||||
VREPIF $5, T_0
|
||||
VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r₂₆[1], 5r₂₆[1], 5r₂₆[1]]
|
||||
VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r₂₆[2], 5r₂₆[2], 5r₂₆[2]]
|
||||
VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r₂₆[3], 5r₂₆[3], 5r₂₆[3]]
|
||||
VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r₂₆[4], 5r₂₆[4], 5r₂₆[4]]
|
||||
|
||||
// skip r² calculation if we are only calculating one block
|
||||
CMPBLE R3, $16, skip
|
||||
|
||||
// calculate r²
|
||||
MULTIPLY(R_0, R_1, R_2, R_3, R_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, M_0, M_1, M_2, M_3, M_4)
|
||||
REDUCE(M_0, M_1, M_2, M_3, M_4)
|
||||
VGBM $0x0f0f, T_0
|
||||
VERIMG $0, M_0, T_0, R_0 // [r₂₆[0], r²₂₆[0], r₂₆[0], r²₂₆[0]]
|
||||
VERIMG $0, M_1, T_0, R_1 // [r₂₆[1], r²₂₆[1], r₂₆[1], r²₂₆[1]]
|
||||
VERIMG $0, M_2, T_0, R_2 // [r₂₆[2], r²₂₆[2], r₂₆[2], r²₂₆[2]]
|
||||
VERIMG $0, M_3, T_0, R_3 // [r₂₆[3], r²₂₆[3], r₂₆[3], r²₂₆[3]]
|
||||
VERIMG $0, M_4, T_0, R_4 // [r₂₆[4], r²₂₆[4], r₂₆[4], r²₂₆[4]]
|
||||
|
||||
// calculate 5r² (ignore least significant limb)
|
||||
VREPIF $5, T_0
|
||||
VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r²₂₆[1], 5r₂₆[1], 5r²₂₆[1]]
|
||||
VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r²₂₆[2], 5r₂₆[2], 5r²₂₆[2]]
|
||||
VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r²₂₆[3], 5r₂₆[3], 5r²₂₆[3]]
|
||||
VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r²₂₆[4], 5r₂₆[4], 5r²₂₆[4]]
|
||||
|
||||
loop:
|
||||
CMPBLE R3, $32, b2 // 2 or fewer blocks remaining, need to change key coefficients
|
||||
|
||||
// load next 2 blocks from message
|
||||
VLM (R2), T_0, T_1
|
||||
|
||||
// update message slice
|
||||
SUB $32, R3
|
||||
MOVD $32(R2), R2
|
||||
|
||||
// unpack message blocks into 26-bit big-endian limbs
|
||||
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
|
||||
|
||||
// add 2¹²⁸ to each message block value
|
||||
VLEIB $4, $1, M_4
|
||||
VLEIB $12, $1, M_4
|
||||
|
||||
multiply:
|
||||
// accumulate the incoming message
|
||||
VAG H_0, M_0, M_0
|
||||
VAG H_3, M_3, M_3
|
||||
VAG H_1, M_1, M_1
|
||||
VAG H_4, M_4, M_4
|
||||
VAG H_2, M_2, M_2
|
||||
|
||||
// multiply the accumulator by the key coefficient
|
||||
MULTIPLY(M_0, M_1, M_2, M_3, M_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, H_0, H_1, H_2, H_3, H_4)
|
||||
|
||||
// carry and partially reduce the partial products
|
||||
REDUCE(H_0, H_1, H_2, H_3, H_4)
|
||||
|
||||
CMPBNE R3, $0, loop
|
||||
|
||||
finish:
|
||||
// sum lane 0 and lane 1 and put the result in lane 1
|
||||
VZERO T_0
|
||||
VSUMQG H_0, T_0, H_0
|
||||
VSUMQG H_3, T_0, H_3
|
||||
VSUMQG H_1, T_0, H_1
|
||||
VSUMQG H_4, T_0, H_4
|
||||
VSUMQG H_2, T_0, H_2
|
||||
|
||||
// reduce again after summation
|
||||
// TODO(mundaym): there might be a more efficient way to do this
|
||||
// now that we only have 1 active lane. For example, we could
|
||||
// simultaneously pack the values as we reduce them.
|
||||
REDUCE(H_0, H_1, H_2, H_3, H_4)
|
||||
|
||||
// carry h[1] through to h[4] so that only h[4] can exceed 2²⁶ - 1
|
||||
// TODO(mundaym): in testing this final carry was unnecessary.
|
||||
// Needs a proof before it can be removed though.
|
||||
VESRLG $26, H_1, T_1
|
||||
VN MOD26, H_1, H_1
|
||||
VAQ T_1, H_2, H_2
|
||||
VESRLG $26, H_2, T_2
|
||||
VN MOD26, H_2, H_2
|
||||
VAQ T_2, H_3, H_3
|
||||
VESRLG $26, H_3, T_3
|
||||
VN MOD26, H_3, H_3
|
||||
VAQ T_3, H_4, H_4
|
||||
|
||||
// h is now < 2(2¹³⁰ - 5)
|
||||
// Pack each lane in h₂₆[0:4] into h₁₂₈[0:1].
|
||||
VESLG $26, H_1, H_1
|
||||
VESLG $26, H_3, H_3
|
||||
VO H_0, H_1, H_0
|
||||
VO H_2, H_3, H_2
|
||||
VESLG $4, H_2, H_2
|
||||
VLEIB $7, $48, H_1
|
||||
VSLB H_1, H_2, H_2
|
||||
VO H_0, H_2, H_0
|
||||
VLEIB $7, $104, H_1
|
||||
VSLB H_1, H_4, H_3
|
||||
VO H_3, H_0, H_0
|
||||
VLEIB $7, $24, H_1
|
||||
VSRLB H_1, H_4, H_1
|
||||
|
||||
// update state
|
||||
VSTEG $1, H_0, 0(R1)
|
||||
VSTEG $0, H_0, 8(R1)
|
||||
VSTEG $1, H_1, 16(R1)
|
||||
RET
|
||||
|
||||
b2: // 2 or fewer blocks remaining
|
||||
CMPBLE R3, $16, b1
|
||||
|
||||
// Load the 2 remaining blocks (17-32 bytes remaining).
|
||||
MOVD $-17(R3), R0 // index of final byte to load modulo 16
|
||||
VL (R2), T_0 // load full 16 byte block
|
||||
VLL R0, 16(R2), T_1 // load final (possibly partial) block and pad with zeros to 16 bytes
|
||||
|
||||
// The Poly1305 algorithm requires that a 1 bit be appended to
|
||||
// each message block. If the final block is less than 16 bytes
|
||||
// long then it is easiest to insert the 1 before the message
|
||||
// block is split into 26-bit limbs. If, on the other hand, the
|
||||
// final message block is 16 bytes long then we append the 1 bit
|
||||
// after expansion as normal.
|
||||
MOVBZ $1, R0
|
||||
MOVD $-16(R3), R3 // index of byte in last block to insert 1 at (could be 16)
|
||||
CMPBEQ R3, $16, 2(PC) // skip the insertion if the final block is 16 bytes long
|
||||
VLVGB R3, R0, T_1 // insert 1 into the byte at index R3
|
||||
|
||||
// Split both blocks into 26-bit limbs in the appropriate lanes.
|
||||
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
|
||||
|
||||
// Append a 1 byte to the end of the second to last block.
|
||||
VLEIB $4, $1, M_4
|
||||
|
||||
// Append a 1 byte to the end of the last block only if it is a
|
||||
// full 16 byte block.
|
||||
CMPBNE R3, $16, 2(PC)
|
||||
VLEIB $12, $1, M_4
|
||||
|
||||
// Finally, set up the coefficients for the final multiplication.
|
||||
// We have previously saved r and 5r in the 32-bit even indexes
|
||||
// of the R_[0-4] and R5_[1-4] coefficient registers.
|
||||
//
|
||||
// We want lane 0 to be multiplied by r² so that can be kept the
|
||||
// same. We want lane 1 to be multiplied by r so we need to move
|
||||
// the saved r value into the 32-bit odd index in lane 1 by
|
||||
// rotating the 64-bit lane by 32.
|
||||
VGBM $0x00ff, T_0 // [0, 0xffffffffffffffff] - mask lane 1 only
|
||||
VERIMG $32, R_0, T_0, R_0 // [_, r²₂₆[0], _, r₂₆[0]]
|
||||
VERIMG $32, R_1, T_0, R_1 // [_, r²₂₆[1], _, r₂₆[1]]
|
||||
VERIMG $32, R_2, T_0, R_2 // [_, r²₂₆[2], _, r₂₆[2]]
|
||||
VERIMG $32, R_3, T_0, R_3 // [_, r²₂₆[3], _, r₂₆[3]]
|
||||
VERIMG $32, R_4, T_0, R_4 // [_, r²₂₆[4], _, r₂₆[4]]
|
||||
VERIMG $32, R5_1, T_0, R5_1 // [_, 5r²₂₆[1], _, 5r₂₆[1]]
|
||||
VERIMG $32, R5_2, T_0, R5_2 // [_, 5r²₂₆[2], _, 5r₂₆[2]]
|
||||
VERIMG $32, R5_3, T_0, R5_3 // [_, 5r²₂₆[3], _, 5r₂₆[3]]
|
||||
VERIMG $32, R5_4, T_0, R5_4 // [_, 5r²₂₆[4], _, 5r₂₆[4]]
|
||||
|
||||
MOVD $0, R3
|
||||
BR multiply
|
||||
|
||||
skip:
|
||||
CMPBEQ R3, $0, finish
|
||||
|
||||
b1: // 1 block remaining
|
||||
|
||||
// Load the final block (1-16 bytes). This will be placed into
|
||||
// lane 0.
|
||||
MOVD $-1(R3), R0
|
||||
VLL R0, (R2), T_0 // pad to 16 bytes with zeros
|
||||
|
||||
// The Poly1305 algorithm requires that a 1 bit be appended to
|
||||
// each message block. If the final block is less than 16 bytes
|
||||
// long then it is easiest to insert the 1 before the message
|
||||
// block is split into 26-bit limbs. If, on the other hand, the
|
||||
// final message block is 16 bytes long then we append the 1 bit
|
||||
// after expansion as normal.
|
||||
MOVBZ $1, R0
|
||||
CMPBEQ R3, $16, 2(PC)
|
||||
VLVGB R3, R0, T_0
|
||||
|
||||
// Set the message block in lane 1 to the value 0 so that it
|
||||
// can be accumulated without affecting the final result.
|
||||
VZERO T_1
|
||||
|
||||
// Split the final message block into 26-bit limbs in lane 0.
|
||||
// Lane 1 will be contain 0.
|
||||
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
|
||||
|
||||
// Append a 1 byte to the end of the last block only if it is a
|
||||
// full 16 byte block.
|
||||
CMPBNE R3, $16, 2(PC)
|
||||
VLEIB $4, $1, M_4
|
||||
|
||||
// We have previously saved r and 5r in the 32-bit even indexes
|
||||
// of the R_[0-4] and R5_[1-4] coefficient registers.
|
||||
//
|
||||
// We want lane 0 to be multiplied by r so we need to move the
|
||||
// saved r value into the 32-bit odd index in lane 0. We want
|
||||
// lane 1 to be set to the value 1. This makes multiplication
|
||||
// a no-op. We do this by setting lane 1 in every register to 0
|
||||
// and then just setting the 32-bit index 3 in R_0 to 1.
|
||||
VZERO T_0
|
||||
MOVD $0, R0
|
||||
MOVD $0x10111213, R12
|
||||
VLVGP R12, R0, T_1 // [_, 0x10111213, _, 0x00000000]
|
||||
VPERM T_0, R_0, T_1, R_0 // [_, r₂₆[0], _, 0]
|
||||
VPERM T_0, R_1, T_1, R_1 // [_, r₂₆[1], _, 0]
|
||||
VPERM T_0, R_2, T_1, R_2 // [_, r₂₆[2], _, 0]
|
||||
VPERM T_0, R_3, T_1, R_3 // [_, r₂₆[3], _, 0]
|
||||
VPERM T_0, R_4, T_1, R_4 // [_, r₂₆[4], _, 0]
|
||||
VPERM T_0, R5_1, T_1, R5_1 // [_, 5r₂₆[1], _, 0]
|
||||
VPERM T_0, R5_2, T_1, R5_2 // [_, 5r₂₆[2], _, 0]
|
||||
VPERM T_0, R5_3, T_1, R5_3 // [_, 5r₂₆[3], _, 0]
|
||||
VPERM T_0, R5_4, T_1, R5_4 // [_, 5r₂₆[4], _, 0]
|
||||
|
||||
// Set the value of lane 1 to be 1.
|
||||
VLEIF $3, $1, R_0 // [_, r₂₆[0], _, 1]
|
||||
|
||||
MOVD $0, R3
|
||||
BR multiply
|
||||
97
vendor/github.com/tailscale/golang-x-crypto/ssh/buffer.go
generated
vendored
Normal file
97
vendor/github.com/tailscale/golang-x-crypto/ssh/buffer.go
generated
vendored
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// buffer provides a linked list buffer for data exchange
|
||||
// between producer and consumer. Theoretically the buffer is
|
||||
// of unlimited capacity as it does no allocation of its own.
|
||||
type buffer struct {
|
||||
// protects concurrent access to head, tail and closed
|
||||
*sync.Cond
|
||||
|
||||
head *element // the buffer that will be read first
|
||||
tail *element // the buffer that will be read last
|
||||
|
||||
closed bool
|
||||
}
|
||||
|
||||
// An element represents a single link in a linked list.
|
||||
type element struct {
|
||||
buf []byte
|
||||
next *element
|
||||
}
|
||||
|
||||
// newBuffer returns an empty buffer that is not closed.
|
||||
func newBuffer() *buffer {
|
||||
e := new(element)
|
||||
b := &buffer{
|
||||
Cond: newCond(),
|
||||
head: e,
|
||||
tail: e,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// write makes buf available for Read to receive.
|
||||
// buf must not be modified after the call to write.
|
||||
func (b *buffer) write(buf []byte) {
|
||||
b.Cond.L.Lock()
|
||||
e := &element{buf: buf}
|
||||
b.tail.next = e
|
||||
b.tail = e
|
||||
b.Cond.Signal()
|
||||
b.Cond.L.Unlock()
|
||||
}
|
||||
|
||||
// eof closes the buffer. Reads from the buffer once all
|
||||
// the data has been consumed will receive io.EOF.
|
||||
func (b *buffer) eof() {
|
||||
b.Cond.L.Lock()
|
||||
b.closed = true
|
||||
b.Cond.Signal()
|
||||
b.Cond.L.Unlock()
|
||||
}
|
||||
|
||||
// Read reads data from the internal buffer in buf. Reads will block
|
||||
// if no data is available, or until the buffer is closed.
|
||||
func (b *buffer) Read(buf []byte) (n int, err error) {
|
||||
b.Cond.L.Lock()
|
||||
defer b.Cond.L.Unlock()
|
||||
|
||||
for len(buf) > 0 {
|
||||
// if there is data in b.head, copy it
|
||||
if len(b.head.buf) > 0 {
|
||||
r := copy(buf, b.head.buf)
|
||||
buf, b.head.buf = buf[r:], b.head.buf[r:]
|
||||
n += r
|
||||
continue
|
||||
}
|
||||
// if there is a next buffer, make it the head
|
||||
if len(b.head.buf) == 0 && b.head != b.tail {
|
||||
b.head = b.head.next
|
||||
continue
|
||||
}
|
||||
|
||||
// if at least one byte has been copied, return
|
||||
if n > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// if nothing was read, and there is nothing outstanding
|
||||
// check to see if the buffer is closed.
|
||||
if b.closed {
|
||||
err = io.EOF
|
||||
break
|
||||
}
|
||||
// out of buffers, wait for producer
|
||||
b.Cond.Wait()
|
||||
}
|
||||
return
|
||||
}
|
||||
611
vendor/github.com/tailscale/golang-x-crypto/ssh/certs.go
generated
vendored
Normal file
611
vendor/github.com/tailscale/golang-x-crypto/ssh/certs.go
generated
vendored
Normal file
@@ -0,0 +1,611 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear
|
||||
// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms.
|
||||
// Unlike key algorithm names, these are not passed to AlgorithmSigner nor
|
||||
// returned by MultiAlgorithmSigner and don't appear in the Signature.Format
|
||||
// field.
|
||||
const (
|
||||
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
|
||||
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
|
||||
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
|
||||
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
|
||||
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
|
||||
CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
|
||||
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
|
||||
CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com"
|
||||
|
||||
// CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a
|
||||
// Certificate.Type (or PublicKey.Type), but only in
|
||||
// ClientConfig.HostKeyAlgorithms.
|
||||
CertAlgoRSASHA256v01 = "rsa-sha2-256-cert-v01@openssh.com"
|
||||
CertAlgoRSASHA512v01 = "rsa-sha2-512-cert-v01@openssh.com"
|
||||
)
|
||||
|
||||
const (
|
||||
// Deprecated: use CertAlgoRSAv01.
|
||||
CertSigAlgoRSAv01 = CertAlgoRSAv01
|
||||
// Deprecated: use CertAlgoRSASHA256v01.
|
||||
CertSigAlgoRSASHA2256v01 = CertAlgoRSASHA256v01
|
||||
// Deprecated: use CertAlgoRSASHA512v01.
|
||||
CertSigAlgoRSASHA2512v01 = CertAlgoRSASHA512v01
|
||||
)
|
||||
|
||||
// Certificate types distinguish between host and user
|
||||
// certificates. The values can be set in the CertType field of
|
||||
// Certificate.
|
||||
const (
|
||||
UserCert = 1
|
||||
HostCert = 2
|
||||
)
|
||||
|
||||
// Signature represents a cryptographic signature.
|
||||
type Signature struct {
|
||||
Format string
|
||||
Blob []byte
|
||||
Rest []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
|
||||
// a certificate does not expire.
|
||||
const CertTimeInfinity = 1<<64 - 1
|
||||
|
||||
// An Certificate represents an OpenSSH certificate as defined in
|
||||
// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the
|
||||
// PublicKey interface, so it can be unmarshaled using
|
||||
// ParsePublicKey.
|
||||
type Certificate struct {
|
||||
Nonce []byte
|
||||
Key PublicKey
|
||||
Serial uint64
|
||||
CertType uint32
|
||||
KeyId string
|
||||
ValidPrincipals []string
|
||||
ValidAfter uint64
|
||||
ValidBefore uint64
|
||||
Permissions
|
||||
Reserved []byte
|
||||
SignatureKey PublicKey
|
||||
Signature *Signature
|
||||
}
|
||||
|
||||
// genericCertData holds the key-independent part of the certificate data.
|
||||
// Overall, certificates contain an nonce, public key fields and
|
||||
// key-independent fields.
|
||||
type genericCertData struct {
|
||||
Serial uint64
|
||||
CertType uint32
|
||||
KeyId string
|
||||
ValidPrincipals []byte
|
||||
ValidAfter uint64
|
||||
ValidBefore uint64
|
||||
CriticalOptions []byte
|
||||
Extensions []byte
|
||||
Reserved []byte
|
||||
SignatureKey []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
func marshalStringList(namelist []string) []byte {
|
||||
var to []byte
|
||||
for _, name := range namelist {
|
||||
s := struct{ N string }{name}
|
||||
to = append(to, Marshal(&s)...)
|
||||
}
|
||||
return to
|
||||
}
|
||||
|
||||
type optionsTuple struct {
|
||||
Key string
|
||||
Value []byte
|
||||
}
|
||||
|
||||
type optionsTupleValue struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
// serialize a map of critical options or extensions
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty string value
|
||||
func marshalTuples(tups map[string]string) []byte {
|
||||
keys := make([]string, 0, len(tups))
|
||||
for key := range tups {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var ret []byte
|
||||
for _, key := range keys {
|
||||
s := optionsTuple{Key: key}
|
||||
if value := tups[key]; len(value) > 0 {
|
||||
s.Value = Marshal(&optionsTupleValue{value})
|
||||
}
|
||||
ret = append(ret, Marshal(&s)...)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty option value
|
||||
func parseTuples(in []byte) (map[string]string, error) {
|
||||
tups := map[string]string{}
|
||||
var lastKey string
|
||||
var haveLastKey bool
|
||||
|
||||
for len(in) > 0 {
|
||||
var key, val, extra []byte
|
||||
var ok bool
|
||||
|
||||
if key, in, ok = parseString(in); !ok {
|
||||
return nil, errShortRead
|
||||
}
|
||||
keyStr := string(key)
|
||||
// according to [PROTOCOL.certkeys], the names must be in
|
||||
// lexical order.
|
||||
if haveLastKey && keyStr <= lastKey {
|
||||
return nil, fmt.Errorf("ssh: certificate options are not in lexical order")
|
||||
}
|
||||
lastKey, haveLastKey = keyStr, true
|
||||
// the next field is a data field, which if non-empty has a string embedded
|
||||
if val, in, ok = parseString(in); !ok {
|
||||
return nil, errShortRead
|
||||
}
|
||||
if len(val) > 0 {
|
||||
val, extra, ok = parseString(val)
|
||||
if !ok {
|
||||
return nil, errShortRead
|
||||
}
|
||||
if len(extra) > 0 {
|
||||
return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value")
|
||||
}
|
||||
tups[keyStr] = string(val)
|
||||
} else {
|
||||
tups[keyStr] = ""
|
||||
}
|
||||
}
|
||||
return tups, nil
|
||||
}
|
||||
|
||||
func parseCert(in []byte, privAlgo string) (*Certificate, error) {
|
||||
nonce, rest, ok := parseString(in)
|
||||
if !ok {
|
||||
return nil, errShortRead
|
||||
}
|
||||
|
||||
key, rest, err := parsePubKey(rest, privAlgo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var g genericCertData
|
||||
if err := Unmarshal(rest, &g); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &Certificate{
|
||||
Nonce: nonce,
|
||||
Key: key,
|
||||
Serial: g.Serial,
|
||||
CertType: g.CertType,
|
||||
KeyId: g.KeyId,
|
||||
ValidAfter: g.ValidAfter,
|
||||
ValidBefore: g.ValidBefore,
|
||||
}
|
||||
|
||||
for principals := g.ValidPrincipals; len(principals) > 0; {
|
||||
principal, rest, ok := parseString(principals)
|
||||
if !ok {
|
||||
return nil, errShortRead
|
||||
}
|
||||
c.ValidPrincipals = append(c.ValidPrincipals, string(principal))
|
||||
principals = rest
|
||||
}
|
||||
|
||||
c.CriticalOptions, err = parseTuples(g.CriticalOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Extensions, err = parseTuples(g.Extensions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Reserved = g.Reserved
|
||||
k, err := ParsePublicKey(g.SignatureKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.SignatureKey = k
|
||||
c.Signature, rest, ok = parseSignatureBody(g.Signature)
|
||||
if !ok || len(rest) > 0 {
|
||||
return nil, errors.New("ssh: signature parse error")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type openSSHCertSigner struct {
|
||||
pub *Certificate
|
||||
signer Signer
|
||||
}
|
||||
|
||||
type algorithmOpenSSHCertSigner struct {
|
||||
*openSSHCertSigner
|
||||
algorithmSigner AlgorithmSigner
|
||||
}
|
||||
|
||||
// NewCertSigner returns a Signer that signs with the given Certificate, whose
|
||||
// private key is held by signer. It returns an error if the public key in cert
|
||||
// doesn't match the key used by signer.
|
||||
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
|
||||
if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
|
||||
return nil, errors.New("ssh: signer and cert have different public key")
|
||||
}
|
||||
|
||||
switch s := signer.(type) {
|
||||
case MultiAlgorithmSigner:
|
||||
return &multiAlgorithmSigner{
|
||||
AlgorithmSigner: &algorithmOpenSSHCertSigner{
|
||||
&openSSHCertSigner{cert, signer}, s},
|
||||
supportedAlgorithms: s.Algorithms(),
|
||||
}, nil
|
||||
case AlgorithmSigner:
|
||||
return &algorithmOpenSSHCertSigner{
|
||||
&openSSHCertSigner{cert, signer}, s}, nil
|
||||
default:
|
||||
return &openSSHCertSigner{cert, signer}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
|
||||
return s.signer.Sign(rand, data)
|
||||
}
|
||||
|
||||
func (s *openSSHCertSigner) PublicKey() PublicKey {
|
||||
return s.pub
|
||||
}
|
||||
|
||||
func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
|
||||
return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
|
||||
}
|
||||
|
||||
const sourceAddressCriticalOption = "source-address"
|
||||
|
||||
// CertChecker does the work of verifying a certificate. Its methods
|
||||
// can be plugged into ClientConfig.HostKeyCallback and
|
||||
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
|
||||
// minimally, the IsAuthority callback should be set.
|
||||
type CertChecker struct {
|
||||
// SupportedCriticalOptions lists the CriticalOptions that the
|
||||
// server application layer understands. These are only used
|
||||
// for user certificates.
|
||||
SupportedCriticalOptions []string
|
||||
|
||||
// IsUserAuthority should return true if the key is recognized as an
|
||||
// authority for the given user certificate. This allows for
|
||||
// certificates to be signed by other certificates. This must be set
|
||||
// if this CertChecker will be checking user certificates.
|
||||
IsUserAuthority func(auth PublicKey) bool
|
||||
|
||||
// IsHostAuthority should report whether the key is recognized as
|
||||
// an authority for this host. This allows for certificates to be
|
||||
// signed by other keys, and for those other keys to only be valid
|
||||
// signers for particular hostnames. This must be set if this
|
||||
// CertChecker will be checking host certificates.
|
||||
IsHostAuthority func(auth PublicKey, address string) bool
|
||||
|
||||
// Clock is used for verifying time stamps. If nil, time.Now
|
||||
// is used.
|
||||
Clock func() time.Time
|
||||
|
||||
// UserKeyFallback is called when CertChecker.Authenticate encounters a
|
||||
// public key that is not a certificate. It must implement validation
|
||||
// of user keys or else, if nil, all such keys are rejected.
|
||||
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
|
||||
|
||||
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
|
||||
// public key that is not a certificate. It must implement host key
|
||||
// validation or else, if nil, all such keys are rejected.
|
||||
HostKeyFallback HostKeyCallback
|
||||
|
||||
// IsRevoked is called for each certificate so that revocation checking
|
||||
// can be implemented. It should return true if the given certificate
|
||||
// is revoked and false otherwise. If nil, no certificates are
|
||||
// considered to have been revoked.
|
||||
IsRevoked func(cert *Certificate) bool
|
||||
}
|
||||
|
||||
// CheckHostKey checks a host key certificate. This method can be
|
||||
// plugged into ClientConfig.HostKeyCallback.
|
||||
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error {
|
||||
cert, ok := key.(*Certificate)
|
||||
if !ok {
|
||||
if c.HostKeyFallback != nil {
|
||||
return c.HostKeyFallback(addr, remote, key)
|
||||
}
|
||||
return errors.New("ssh: non-certificate host key")
|
||||
}
|
||||
if cert.CertType != HostCert {
|
||||
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
|
||||
}
|
||||
if !c.IsHostAuthority(cert.SignatureKey, addr) {
|
||||
return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
|
||||
}
|
||||
|
||||
hostname, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Pass hostname only as principal for host certificates (consistent with OpenSSH)
|
||||
return c.CheckCert(hostname, cert)
|
||||
}
|
||||
|
||||
// Authenticate checks a user certificate. Authenticate can be used as
|
||||
// a value for ServerConfig.PublicKeyCallback.
|
||||
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) {
|
||||
cert, ok := pubKey.(*Certificate)
|
||||
if !ok {
|
||||
if c.UserKeyFallback != nil {
|
||||
return c.UserKeyFallback(conn, pubKey)
|
||||
}
|
||||
return nil, errors.New("ssh: normal key pairs not accepted")
|
||||
}
|
||||
|
||||
if cert.CertType != UserCert {
|
||||
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
|
||||
}
|
||||
if !c.IsUserAuthority(cert.SignatureKey) {
|
||||
return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
|
||||
}
|
||||
|
||||
if err := c.CheckCert(conn.User(), cert); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cert.Permissions, nil
|
||||
}
|
||||
|
||||
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
|
||||
// the signature of the certificate.
|
||||
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
|
||||
if c.IsRevoked != nil && c.IsRevoked(cert) {
|
||||
return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
|
||||
}
|
||||
|
||||
for opt := range cert.CriticalOptions {
|
||||
// sourceAddressCriticalOption will be enforced by
|
||||
// serverAuthenticate
|
||||
if opt == sourceAddressCriticalOption {
|
||||
continue
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, supp := range c.SupportedCriticalOptions {
|
||||
if supp == opt {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cert.ValidPrincipals) > 0 {
|
||||
// By default, certs are valid for all users/hosts.
|
||||
found := false
|
||||
for _, p := range cert.ValidPrincipals {
|
||||
if p == principal {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals)
|
||||
}
|
||||
}
|
||||
|
||||
clock := c.Clock
|
||||
if clock == nil {
|
||||
clock = time.Now
|
||||
}
|
||||
|
||||
unixNow := clock().Unix()
|
||||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
|
||||
return fmt.Errorf("ssh: cert is not yet valid")
|
||||
}
|
||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) {
|
||||
return fmt.Errorf("ssh: cert has expired")
|
||||
}
|
||||
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
|
||||
return fmt.Errorf("ssh: certificate signature does not verify")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SignCert signs the certificate with an authority, setting the Nonce,
|
||||
// SignatureKey, and Signature fields. If the authority implements the
|
||||
// MultiAlgorithmSigner interface the first algorithm in the list is used. This
|
||||
// is useful if you want to sign with a specific algorithm.
|
||||
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
|
||||
c.Nonce = make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand, c.Nonce); err != nil {
|
||||
return err
|
||||
}
|
||||
c.SignatureKey = authority.PublicKey()
|
||||
|
||||
if v, ok := authority.(MultiAlgorithmSigner); ok {
|
||||
if len(v.Algorithms()) == 0 {
|
||||
return errors.New("the provided authority has no signature algorithm")
|
||||
}
|
||||
// Use the first algorithm in the list.
|
||||
sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Signature = sig
|
||||
return nil
|
||||
} else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA {
|
||||
// Default to KeyAlgoRSASHA512 for ssh-rsa signers.
|
||||
// TODO: consider using KeyAlgoRSASHA256 as default.
|
||||
sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Signature = sig
|
||||
return nil
|
||||
}
|
||||
|
||||
sig, err := authority.Sign(rand, c.bytesForSigning())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Signature = sig
|
||||
return nil
|
||||
}
|
||||
|
||||
// certKeyAlgoNames is a mapping from known certificate algorithm names to the
|
||||
// corresponding public key signature algorithm.
|
||||
//
|
||||
// This map must be kept in sync with the one in agent/client.go.
|
||||
var certKeyAlgoNames = map[string]string{
|
||||
CertAlgoRSAv01: KeyAlgoRSA,
|
||||
CertAlgoRSASHA256v01: KeyAlgoRSASHA256,
|
||||
CertAlgoRSASHA512v01: KeyAlgoRSASHA512,
|
||||
CertAlgoDSAv01: KeyAlgoDSA,
|
||||
CertAlgoECDSA256v01: KeyAlgoECDSA256,
|
||||
CertAlgoECDSA384v01: KeyAlgoECDSA384,
|
||||
CertAlgoECDSA521v01: KeyAlgoECDSA521,
|
||||
CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256,
|
||||
CertAlgoED25519v01: KeyAlgoED25519,
|
||||
CertAlgoSKED25519v01: KeyAlgoSKED25519,
|
||||
}
|
||||
|
||||
// underlyingAlgo returns the signature algorithm associated with algo (which is
|
||||
// an advertised or negotiated public key or host key algorithm). These are
|
||||
// usually the same, except for certificate algorithms.
|
||||
func underlyingAlgo(algo string) string {
|
||||
if a, ok := certKeyAlgoNames[algo]; ok {
|
||||
return a
|
||||
}
|
||||
return algo
|
||||
}
|
||||
|
||||
// certificateAlgo returns the certificate algorithms that uses the provided
|
||||
// underlying signature algorithm.
|
||||
func certificateAlgo(algo string) (certAlgo string, ok bool) {
|
||||
for certName, algoName := range certKeyAlgoNames {
|
||||
if algoName == algo {
|
||||
return certName, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (cert *Certificate) bytesForSigning() []byte {
|
||||
c2 := *cert
|
||||
c2.Signature = nil
|
||||
out := c2.Marshal()
|
||||
// Drop trailing signature length.
|
||||
return out[:len(out)-4]
|
||||
}
|
||||
|
||||
// Marshal serializes c into OpenSSH's wire format. It is part of the
|
||||
// PublicKey interface.
|
||||
func (c *Certificate) Marshal() []byte {
|
||||
generic := genericCertData{
|
||||
Serial: c.Serial,
|
||||
CertType: c.CertType,
|
||||
KeyId: c.KeyId,
|
||||
ValidPrincipals: marshalStringList(c.ValidPrincipals),
|
||||
ValidAfter: uint64(c.ValidAfter),
|
||||
ValidBefore: uint64(c.ValidBefore),
|
||||
CriticalOptions: marshalTuples(c.CriticalOptions),
|
||||
Extensions: marshalTuples(c.Extensions),
|
||||
Reserved: c.Reserved,
|
||||
SignatureKey: c.SignatureKey.Marshal(),
|
||||
}
|
||||
if c.Signature != nil {
|
||||
generic.Signature = Marshal(c.Signature)
|
||||
}
|
||||
genericBytes := Marshal(&generic)
|
||||
keyBytes := c.Key.Marshal()
|
||||
_, keyBytes, _ = parseString(keyBytes)
|
||||
prefix := Marshal(&struct {
|
||||
Name string
|
||||
Nonce []byte
|
||||
Key []byte `ssh:"rest"`
|
||||
}{c.Type(), c.Nonce, keyBytes})
|
||||
|
||||
result := make([]byte, 0, len(prefix)+len(genericBytes))
|
||||
result = append(result, prefix...)
|
||||
result = append(result, genericBytes...)
|
||||
return result
|
||||
}
|
||||
|
||||
// Type returns the certificate algorithm name. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Type() string {
|
||||
certName, ok := certificateAlgo(c.Key.Type())
|
||||
if !ok {
|
||||
panic("unknown certificate type for key type " + c.Key.Type())
|
||||
}
|
||||
return certName
|
||||
}
|
||||
|
||||
// Verify verifies a signature against the certificate's public
|
||||
// key. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Verify(data []byte, sig *Signature) error {
|
||||
return c.Key.Verify(data, sig)
|
||||
}
|
||||
|
||||
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) {
|
||||
format, in, ok := parseString(in)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
out = &Signature{
|
||||
Format: string(format),
|
||||
}
|
||||
|
||||
if out.Blob, in, ok = parseString(in); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch out.Format {
|
||||
case KeyAlgoSKECDSA256, CertAlgoSKECDSA256v01, KeyAlgoSKED25519, CertAlgoSKED25519v01:
|
||||
out.Rest = in
|
||||
return out, nil, ok
|
||||
}
|
||||
|
||||
return out, in, ok
|
||||
}
|
||||
|
||||
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) {
|
||||
sigBytes, rest, ok := parseString(in)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
out, trailing, ok := parseSignatureBody(sigBytes)
|
||||
if !ok || len(trailing) > 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
return
|
||||
}
|
||||
645
vendor/github.com/tailscale/golang-x-crypto/ssh/channel.go
generated
vendored
Normal file
645
vendor/github.com/tailscale/golang-x-crypto/ssh/channel.go
generated
vendored
Normal file
@@ -0,0 +1,645 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
minPacketLength = 9
|
||||
// channelMaxPacket contains the maximum number of bytes that will be
|
||||
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
||||
// the minimum.
|
||||
channelMaxPacket = 1 << 15
|
||||
// We follow OpenSSH here.
|
||||
channelWindowSize = 64 * channelMaxPacket
|
||||
)
|
||||
|
||||
// NewChannel represents an incoming request to a channel. It must either be
|
||||
// accepted for use by calling Accept, or rejected by calling Reject.
|
||||
type NewChannel interface {
|
||||
// Accept accepts the channel creation request. It returns the Channel
|
||||
// and a Go channel containing SSH requests. The Go channel must be
|
||||
// serviced otherwise the Channel will hang.
|
||||
Accept() (Channel, <-chan *Request, error)
|
||||
|
||||
// Reject rejects the channel creation request. After calling
|
||||
// this, no other methods on the Channel may be called.
|
||||
Reject(reason RejectionReason, message string) error
|
||||
|
||||
// ChannelType returns the type of the channel, as supplied by the
|
||||
// client.
|
||||
ChannelType() string
|
||||
|
||||
// ExtraData returns the arbitrary payload for this channel, as supplied
|
||||
// by the client. This data is specific to the channel type.
|
||||
ExtraData() []byte
|
||||
}
|
||||
|
||||
// A Channel is an ordered, reliable, flow-controlled, duplex stream
|
||||
// that is multiplexed over an SSH connection.
|
||||
type Channel interface {
|
||||
// Read reads up to len(data) bytes from the channel.
|
||||
Read(data []byte) (int, error)
|
||||
|
||||
// Write writes len(data) bytes to the channel.
|
||||
Write(data []byte) (int, error)
|
||||
|
||||
// Close signals end of channel use. No data may be sent after this
|
||||
// call.
|
||||
Close() error
|
||||
|
||||
// CloseWrite signals the end of sending in-band
|
||||
// data. Requests may still be sent, and the other side may
|
||||
// still send data
|
||||
CloseWrite() error
|
||||
|
||||
// SendRequest sends a channel request. If wantReply is true,
|
||||
// it will wait for a reply and return the result as a
|
||||
// boolean, otherwise the return value will be false. Channel
|
||||
// requests are out-of-band messages so they may be sent even
|
||||
// if the data stream is closed or blocked by flow control.
|
||||
// If the channel is closed before a reply is returned, io.EOF
|
||||
// is returned.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, error)
|
||||
|
||||
// Stderr returns an io.ReadWriter that writes to this channel
|
||||
// with the extended data type set to stderr. Stderr may
|
||||
// safely be read and written from a different goroutine than
|
||||
// Read and Write respectively.
|
||||
Stderr() io.ReadWriter
|
||||
}
|
||||
|
||||
// Request is a request sent outside of the normal stream of
|
||||
// data. Requests can either be specific to an SSH channel, or they
|
||||
// can be global.
|
||||
type Request struct {
|
||||
Type string
|
||||
WantReply bool
|
||||
Payload []byte
|
||||
|
||||
ch *channel
|
||||
mux *mux
|
||||
}
|
||||
|
||||
// Reply sends a response to a request. It must be called for all requests
|
||||
// where WantReply is true and is a no-op otherwise. The payload argument is
|
||||
// ignored for replies to channel-specific requests.
|
||||
func (r *Request) Reply(ok bool, payload []byte) error {
|
||||
if !r.WantReply {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.ch == nil {
|
||||
return r.mux.ackRequest(ok, payload)
|
||||
}
|
||||
|
||||
return r.ch.ackRequest(ok)
|
||||
}
|
||||
|
||||
// RejectionReason is an enumeration used when rejecting channel creation
|
||||
// requests. See RFC 4254, section 5.1.
|
||||
type RejectionReason uint32
|
||||
|
||||
const (
|
||||
Prohibited RejectionReason = iota + 1
|
||||
ConnectionFailed
|
||||
UnknownChannelType
|
||||
ResourceShortage
|
||||
)
|
||||
|
||||
// String converts the rejection reason to human readable form.
|
||||
func (r RejectionReason) String() string {
|
||||
switch r {
|
||||
case Prohibited:
|
||||
return "administratively prohibited"
|
||||
case ConnectionFailed:
|
||||
return "connect failed"
|
||||
case UnknownChannelType:
|
||||
return "unknown channel type"
|
||||
case ResourceShortage:
|
||||
return "resource shortage"
|
||||
}
|
||||
return fmt.Sprintf("unknown reason %d", int(r))
|
||||
}
|
||||
|
||||
func min(a uint32, b int) uint32 {
|
||||
if a < uint32(b) {
|
||||
return a
|
||||
}
|
||||
return uint32(b)
|
||||
}
|
||||
|
||||
type channelDirection uint8
|
||||
|
||||
const (
|
||||
channelInbound channelDirection = iota
|
||||
channelOutbound
|
||||
)
|
||||
|
||||
// channel is an implementation of the Channel interface that works
|
||||
// with the mux class.
|
||||
type channel struct {
|
||||
// R/O after creation
|
||||
chanType string
|
||||
extraData []byte
|
||||
localId, remoteId uint32
|
||||
|
||||
// maxIncomingPayload and maxRemotePayload are the maximum
|
||||
// payload sizes of normal and extended data packets for
|
||||
// receiving and sending, respectively. The wire packet will
|
||||
// be 9 or 13 bytes larger (excluding encryption overhead).
|
||||
maxIncomingPayload uint32
|
||||
maxRemotePayload uint32
|
||||
|
||||
mux *mux
|
||||
|
||||
// decided is set to true if an accept or reject message has been sent
|
||||
// (for outbound channels) or received (for inbound channels).
|
||||
decided bool
|
||||
|
||||
// direction contains either channelOutbound, for channels created
|
||||
// locally, or channelInbound, for channels created by the peer.
|
||||
direction channelDirection
|
||||
|
||||
// Pending internal channel messages.
|
||||
msg chan interface{}
|
||||
|
||||
// Since requests have no ID, there can be only one request
|
||||
// with WantReply=true outstanding. This lock is held by a
|
||||
// goroutine that has such an outgoing request pending.
|
||||
sentRequestMu sync.Mutex
|
||||
|
||||
incomingRequests chan *Request
|
||||
|
||||
sentEOF bool
|
||||
|
||||
// thread-safe data
|
||||
remoteWin window
|
||||
pending *buffer
|
||||
extPending *buffer
|
||||
|
||||
// windowMu protects myWindow, the flow-control window, and myConsumed,
|
||||
// the number of bytes consumed since we last increased myWindow
|
||||
windowMu sync.Mutex
|
||||
myWindow uint32
|
||||
myConsumed uint32
|
||||
|
||||
// writeMu serializes calls to mux.conn.writePacket() and
|
||||
// protects sentClose and packetPool. This mutex must be
|
||||
// different from windowMu, as writePacket can block if there
|
||||
// is a key exchange pending.
|
||||
writeMu sync.Mutex
|
||||
sentClose bool
|
||||
|
||||
// packetPool has a buffer for each extended channel ID to
|
||||
// save allocations during writes.
|
||||
packetPool map[uint32][]byte
|
||||
}
|
||||
|
||||
// writePacket sends a packet. If the packet is a channel close, it updates
|
||||
// sentClose. This method takes the lock c.writeMu.
|
||||
func (ch *channel) writePacket(packet []byte) error {
|
||||
ch.writeMu.Lock()
|
||||
if ch.sentClose {
|
||||
ch.writeMu.Unlock()
|
||||
return io.EOF
|
||||
}
|
||||
ch.sentClose = (packet[0] == msgChannelClose)
|
||||
err := ch.mux.conn.writePacket(packet)
|
||||
ch.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ch *channel) sendMessage(msg interface{}) error {
|
||||
if debugMux {
|
||||
log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
|
||||
}
|
||||
|
||||
p := Marshal(msg)
|
||||
binary.BigEndian.PutUint32(p[1:], ch.remoteId)
|
||||
return ch.writePacket(p)
|
||||
}
|
||||
|
||||
// WriteExtended writes data to a specific extended stream. These streams are
|
||||
// used, for example, for stderr.
|
||||
func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
|
||||
if ch.sentEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
|
||||
opCode := byte(msgChannelData)
|
||||
headerLength := uint32(9)
|
||||
if extendedCode > 0 {
|
||||
headerLength += 4
|
||||
opCode = msgChannelExtendedData
|
||||
}
|
||||
|
||||
ch.writeMu.Lock()
|
||||
packet := ch.packetPool[extendedCode]
|
||||
// We don't remove the buffer from packetPool, so
|
||||
// WriteExtended calls from different goroutines will be
|
||||
// flagged as errors by the race detector.
|
||||
ch.writeMu.Unlock()
|
||||
|
||||
for len(data) > 0 {
|
||||
space := min(ch.maxRemotePayload, len(data))
|
||||
if space, err = ch.remoteWin.reserve(space); err != nil {
|
||||
return n, err
|
||||
}
|
||||
if want := headerLength + space; uint32(cap(packet)) < want {
|
||||
packet = make([]byte, want)
|
||||
} else {
|
||||
packet = packet[:want]
|
||||
}
|
||||
|
||||
todo := data[:space]
|
||||
|
||||
packet[0] = opCode
|
||||
binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
|
||||
if extendedCode > 0 {
|
||||
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
|
||||
}
|
||||
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
|
||||
copy(packet[headerLength:], todo)
|
||||
if err = ch.writePacket(packet); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
n += len(todo)
|
||||
data = data[len(todo):]
|
||||
}
|
||||
|
||||
ch.writeMu.Lock()
|
||||
ch.packetPool[extendedCode] = packet
|
||||
ch.writeMu.Unlock()
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (ch *channel) handleData(packet []byte) error {
|
||||
headerLen := 9
|
||||
isExtendedData := packet[0] == msgChannelExtendedData
|
||||
if isExtendedData {
|
||||
headerLen = 13
|
||||
}
|
||||
if len(packet) < headerLen {
|
||||
// malformed data packet
|
||||
return parseError(packet[0])
|
||||
}
|
||||
|
||||
var extended uint32
|
||||
if isExtendedData {
|
||||
extended = binary.BigEndian.Uint32(packet[5:])
|
||||
}
|
||||
|
||||
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
|
||||
if length == 0 {
|
||||
return nil
|
||||
}
|
||||
if length > ch.maxIncomingPayload {
|
||||
// TODO(hanwen): should send Disconnect?
|
||||
return errors.New("ssh: incoming packet exceeds maximum payload size")
|
||||
}
|
||||
|
||||
data := packet[headerLen:]
|
||||
if length != uint32(len(data)) {
|
||||
return errors.New("ssh: wrong packet length")
|
||||
}
|
||||
|
||||
ch.windowMu.Lock()
|
||||
if ch.myWindow < length {
|
||||
ch.windowMu.Unlock()
|
||||
// TODO(hanwen): should send Disconnect with reason?
|
||||
return errors.New("ssh: remote side wrote too much")
|
||||
}
|
||||
ch.myWindow -= length
|
||||
ch.windowMu.Unlock()
|
||||
|
||||
if extended == 1 {
|
||||
ch.extPending.write(data)
|
||||
} else if extended > 0 {
|
||||
// discard other extended data.
|
||||
} else {
|
||||
ch.pending.write(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *channel) adjustWindow(adj uint32) error {
|
||||
c.windowMu.Lock()
|
||||
// Since myConsumed and myWindow are managed on our side, and can never
|
||||
// exceed the initial window setting, we don't worry about overflow.
|
||||
c.myConsumed += adj
|
||||
var sendAdj uint32
|
||||
if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
|
||||
(c.myWindow < channelWindowSize/2) {
|
||||
sendAdj = c.myConsumed
|
||||
c.myConsumed = 0
|
||||
c.myWindow += sendAdj
|
||||
}
|
||||
c.windowMu.Unlock()
|
||||
if sendAdj == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.sendMessage(windowAdjustMsg{
|
||||
AdditionalBytes: sendAdj,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
|
||||
switch extended {
|
||||
case 1:
|
||||
n, err = c.extPending.Read(data)
|
||||
case 0:
|
||||
n, err = c.pending.Read(data)
|
||||
default:
|
||||
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
err = c.adjustWindow(uint32(n))
|
||||
// sendWindowAdjust can return io.EOF if the remote
|
||||
// peer has closed the connection, however we want to
|
||||
// defer forwarding io.EOF to the caller of Read until
|
||||
// the buffer has been drained.
|
||||
if n > 0 && err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *channel) close() {
|
||||
c.pending.eof()
|
||||
c.extPending.eof()
|
||||
close(c.msg)
|
||||
close(c.incomingRequests)
|
||||
c.writeMu.Lock()
|
||||
// This is not necessary for a normal channel teardown, but if
|
||||
// there was another error, it is.
|
||||
c.sentClose = true
|
||||
c.writeMu.Unlock()
|
||||
// Unblock writers.
|
||||
c.remoteWin.close()
|
||||
}
|
||||
|
||||
// responseMessageReceived is called when a success or failure message is
|
||||
// received on a channel to check that such a message is reasonable for the
|
||||
// given channel.
|
||||
func (ch *channel) responseMessageReceived() error {
|
||||
if ch.direction == channelInbound {
|
||||
return errors.New("ssh: channel response message received on inbound channel")
|
||||
}
|
||||
if ch.decided {
|
||||
return errors.New("ssh: duplicate response received for channel")
|
||||
}
|
||||
ch.decided = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ch *channel) handlePacket(packet []byte) error {
|
||||
switch packet[0] {
|
||||
case msgChannelData, msgChannelExtendedData:
|
||||
return ch.handleData(packet)
|
||||
case msgChannelClose:
|
||||
ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
|
||||
ch.mux.chanList.remove(ch.localId)
|
||||
ch.close()
|
||||
return nil
|
||||
case msgChannelEOF:
|
||||
// RFC 4254 is mute on how EOF affects dataExt messages but
|
||||
// it is logical to signal EOF at the same time.
|
||||
ch.extPending.eof()
|
||||
ch.pending.eof()
|
||||
return nil
|
||||
}
|
||||
|
||||
decoded, err := decode(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := decoded.(type) {
|
||||
case *channelOpenFailureMsg:
|
||||
if err := ch.responseMessageReceived(); err != nil {
|
||||
return err
|
||||
}
|
||||
ch.mux.chanList.remove(msg.PeersID)
|
||||
ch.msg <- msg
|
||||
case *channelOpenConfirmMsg:
|
||||
if err := ch.responseMessageReceived(); err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
|
||||
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
|
||||
}
|
||||
ch.remoteId = msg.MyID
|
||||
ch.maxRemotePayload = msg.MaxPacketSize
|
||||
ch.remoteWin.add(msg.MyWindow)
|
||||
ch.msg <- msg
|
||||
case *windowAdjustMsg:
|
||||
if !ch.remoteWin.add(msg.AdditionalBytes) {
|
||||
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
|
||||
}
|
||||
case *channelRequestMsg:
|
||||
req := Request{
|
||||
Type: msg.Request,
|
||||
WantReply: msg.WantReply,
|
||||
Payload: msg.RequestSpecificData,
|
||||
ch: ch,
|
||||
}
|
||||
|
||||
ch.incomingRequests <- &req
|
||||
default:
|
||||
ch.msg <- msg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
|
||||
ch := &channel{
|
||||
remoteWin: window{Cond: newCond()},
|
||||
myWindow: channelWindowSize,
|
||||
pending: newBuffer(),
|
||||
extPending: newBuffer(),
|
||||
direction: direction,
|
||||
incomingRequests: make(chan *Request, chanSize),
|
||||
msg: make(chan interface{}, chanSize),
|
||||
chanType: chanType,
|
||||
extraData: extraData,
|
||||
mux: m,
|
||||
packetPool: make(map[uint32][]byte),
|
||||
}
|
||||
ch.localId = m.chanList.add(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
var errUndecided = errors.New("ssh: must Accept or Reject channel")
|
||||
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
|
||||
|
||||
type extChannel struct {
|
||||
code uint32
|
||||
ch *channel
|
||||
}
|
||||
|
||||
func (e *extChannel) Write(data []byte) (n int, err error) {
|
||||
return e.ch.WriteExtended(data, e.code)
|
||||
}
|
||||
|
||||
func (e *extChannel) Read(data []byte) (n int, err error) {
|
||||
return e.ch.ReadExtended(data, e.code)
|
||||
}
|
||||
|
||||
func (ch *channel) Accept() (Channel, <-chan *Request, error) {
|
||||
if ch.decided {
|
||||
return nil, nil, errDecidedAlready
|
||||
}
|
||||
ch.maxIncomingPayload = channelMaxPacket
|
||||
confirm := channelOpenConfirmMsg{
|
||||
PeersID: ch.remoteId,
|
||||
MyID: ch.localId,
|
||||
MyWindow: ch.myWindow,
|
||||
MaxPacketSize: ch.maxIncomingPayload,
|
||||
}
|
||||
ch.decided = true
|
||||
if err := ch.sendMessage(confirm); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ch, ch.incomingRequests, nil
|
||||
}
|
||||
|
||||
func (ch *channel) Reject(reason RejectionReason, message string) error {
|
||||
if ch.decided {
|
||||
return errDecidedAlready
|
||||
}
|
||||
reject := channelOpenFailureMsg{
|
||||
PeersID: ch.remoteId,
|
||||
Reason: reason,
|
||||
Message: message,
|
||||
Language: "en",
|
||||
}
|
||||
ch.decided = true
|
||||
return ch.sendMessage(reject)
|
||||
}
|
||||
|
||||
func (ch *channel) Read(data []byte) (int, error) {
|
||||
if !ch.decided {
|
||||
return 0, errUndecided
|
||||
}
|
||||
return ch.ReadExtended(data, 0)
|
||||
}
|
||||
|
||||
func (ch *channel) Write(data []byte) (int, error) {
|
||||
if !ch.decided {
|
||||
return 0, errUndecided
|
||||
}
|
||||
return ch.WriteExtended(data, 0)
|
||||
}
|
||||
|
||||
func (ch *channel) CloseWrite() error {
|
||||
if !ch.decided {
|
||||
return errUndecided
|
||||
}
|
||||
ch.sentEOF = true
|
||||
return ch.sendMessage(channelEOFMsg{
|
||||
PeersID: ch.remoteId})
|
||||
}
|
||||
|
||||
func (ch *channel) Close() error {
|
||||
if !ch.decided {
|
||||
return errUndecided
|
||||
}
|
||||
|
||||
return ch.sendMessage(channelCloseMsg{
|
||||
PeersID: ch.remoteId})
|
||||
}
|
||||
|
||||
// Extended returns an io.ReadWriter that sends and receives data on the given,
|
||||
// SSH extended stream. Such streams are used, for example, for stderr.
|
||||
func (ch *channel) Extended(code uint32) io.ReadWriter {
|
||||
if !ch.decided {
|
||||
return nil
|
||||
}
|
||||
return &extChannel{code, ch}
|
||||
}
|
||||
|
||||
func (ch *channel) Stderr() io.ReadWriter {
|
||||
return ch.Extended(1)
|
||||
}
|
||||
|
||||
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||
if !ch.decided {
|
||||
return false, errUndecided
|
||||
}
|
||||
|
||||
if wantReply {
|
||||
ch.sentRequestMu.Lock()
|
||||
defer ch.sentRequestMu.Unlock()
|
||||
}
|
||||
|
||||
msg := channelRequestMsg{
|
||||
PeersID: ch.remoteId,
|
||||
Request: name,
|
||||
WantReply: wantReply,
|
||||
RequestSpecificData: payload,
|
||||
}
|
||||
|
||||
if err := ch.sendMessage(msg); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if wantReply {
|
||||
m, ok := (<-ch.msg)
|
||||
if !ok {
|
||||
return false, io.EOF
|
||||
}
|
||||
switch m.(type) {
|
||||
case *channelRequestFailureMsg:
|
||||
return false, nil
|
||||
case *channelRequestSuccessMsg:
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ackRequest either sends an ack or nack to the channel request.
|
||||
func (ch *channel) ackRequest(ok bool) error {
|
||||
if !ch.decided {
|
||||
return errUndecided
|
||||
}
|
||||
|
||||
var msg interface{}
|
||||
if !ok {
|
||||
msg = channelRequestFailureMsg{
|
||||
PeersID: ch.remoteId,
|
||||
}
|
||||
} else {
|
||||
msg = channelRequestSuccessMsg{
|
||||
PeersID: ch.remoteId,
|
||||
}
|
||||
}
|
||||
return ch.sendMessage(msg)
|
||||
}
|
||||
|
||||
func (ch *channel) ChannelType() string {
|
||||
return ch.chanType
|
||||
}
|
||||
|
||||
func (ch *channel) ExtraData() []byte {
|
||||
return ch.extraData
|
||||
}
|
||||
789
vendor/github.com/tailscale/golang-x-crypto/ssh/cipher.go
generated
vendored
Normal file
789
vendor/github.com/tailscale/golang-x-crypto/ssh/cipher.go
generated
vendored
Normal file
@@ -0,0 +1,789 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/des"
|
||||
"crypto/rc4"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
|
||||
"github.com/tailscale/golang-x-crypto/internal/poly1305"
|
||||
"golang.org/x/crypto/chacha20"
|
||||
)
|
||||
|
||||
const (
|
||||
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
||||
|
||||
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
|
||||
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
|
||||
// indicates implementations SHOULD be able to handle larger packet sizes, but then
|
||||
// waffles on about reasonable limits.
|
||||
//
|
||||
// OpenSSH caps their maxPacket at 256kB so we choose to do
|
||||
// the same. maxPacket is also used to ensure that uint32
|
||||
// length fields do not overflow, so it should remain well
|
||||
// below 4G.
|
||||
maxPacket = 256 * 1024
|
||||
)
|
||||
|
||||
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
||||
// by the transport before the first key-exchange.
|
||||
type noneCipher struct{}
|
||||
|
||||
func (c noneCipher) XORKeyStream(dst, src []byte) {
|
||||
copy(dst, src)
|
||||
}
|
||||
|
||||
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cipher.NewCTR(c, iv), nil
|
||||
}
|
||||
|
||||
func newRC4(key, iv []byte) (cipher.Stream, error) {
|
||||
return rc4.NewCipher(key)
|
||||
}
|
||||
|
||||
type cipherMode struct {
|
||||
keySize int
|
||||
ivSize int
|
||||
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
|
||||
}
|
||||
|
||||
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
||||
return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
||||
stream, err := createFunc(key, iv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var streamDump []byte
|
||||
if skip > 0 {
|
||||
streamDump = make([]byte, 512)
|
||||
}
|
||||
|
||||
for remainingToDump := skip; remainingToDump > 0; {
|
||||
dumpThisTime := remainingToDump
|
||||
if dumpThisTime > len(streamDump) {
|
||||
dumpThisTime = len(streamDump)
|
||||
}
|
||||
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
|
||||
remainingToDump -= dumpThisTime
|
||||
}
|
||||
|
||||
mac := macModes[algs.MAC].new(macKey)
|
||||
return &streamPacketCipher{
|
||||
mac: mac,
|
||||
etm: macModes[algs.MAC].etm,
|
||||
macResult: make([]byte, mac.Size()),
|
||||
cipher: stream,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// cipherModes documents properties of supported ciphers. Ciphers not included
|
||||
// are not supported and will not be negotiated, even if explicitly requested in
|
||||
// ClientConfig.Crypto.Ciphers.
|
||||
var cipherModes = map[string]*cipherMode{
|
||||
// Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms
|
||||
// are defined in the order specified in the RFC.
|
||||
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
|
||||
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
|
||||
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
|
||||
|
||||
// Ciphers from RFC 4345, which introduces security-improved arcfour ciphers.
|
||||
// They are defined in the order specified in the RFC.
|
||||
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
|
||||
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
|
||||
|
||||
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
|
||||
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
|
||||
// RC4) has problems with weak keys, and should be used with caution."
|
||||
// RFC 4345 introduces improved versions of Arcfour.
|
||||
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
|
||||
|
||||
// AEAD ciphers
|
||||
gcm128CipherID: {16, 12, newGCMCipher},
|
||||
gcm256CipherID: {32, 12, newGCMCipher},
|
||||
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
|
||||
|
||||
// CBC mode is insecure and so is not included in the default config.
|
||||
// (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely
|
||||
// needed, it's possible to specify a custom Config to enable it.
|
||||
// You should expect that an active attacker can recover plaintext if
|
||||
// you do.
|
||||
aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
|
||||
|
||||
// 3des-cbc is insecure and is not included in the default
|
||||
// config.
|
||||
tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
|
||||
}
|
||||
|
||||
// prefixLen is the length of the packet prefix that contains the packet length
|
||||
// and number of padding bytes.
|
||||
const prefixLen = 5
|
||||
|
||||
// streamPacketCipher is a packetCipher using a stream cipher.
|
||||
type streamPacketCipher struct {
|
||||
mac hash.Hash
|
||||
cipher cipher.Stream
|
||||
etm bool
|
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
prefix [prefixLen]byte
|
||||
seqNumBytes [4]byte
|
||||
padding [2 * packetSizeMultiple]byte
|
||||
packetData []byte
|
||||
macResult []byte
|
||||
}
|
||||
|
||||
// readCipherPacket reads and decrypt a single packet from the reader argument.
|
||||
func (s *streamPacketCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
||||
if _, err := io.ReadFull(r, s.prefix[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var encryptedPaddingLength [1]byte
|
||||
if s.mac != nil && s.etm {
|
||||
copy(encryptedPaddingLength[:], s.prefix[4:5])
|
||||
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
|
||||
} else {
|
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
|
||||
}
|
||||
|
||||
length := binary.BigEndian.Uint32(s.prefix[0:4])
|
||||
paddingLength := uint32(s.prefix[4])
|
||||
|
||||
var macSize uint32
|
||||
if s.mac != nil {
|
||||
s.mac.Reset()
|
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
|
||||
s.mac.Write(s.seqNumBytes[:])
|
||||
if s.etm {
|
||||
s.mac.Write(s.prefix[:4])
|
||||
s.mac.Write(encryptedPaddingLength[:])
|
||||
} else {
|
||||
s.mac.Write(s.prefix[:])
|
||||
}
|
||||
macSize = uint32(s.mac.Size())
|
||||
}
|
||||
|
||||
if length <= paddingLength+1 {
|
||||
return nil, errors.New("ssh: invalid packet length, packet too small")
|
||||
}
|
||||
|
||||
if length > maxPacket {
|
||||
return nil, errors.New("ssh: invalid packet length, packet too large")
|
||||
}
|
||||
|
||||
// the maxPacket check above ensures that length-1+macSize
|
||||
// does not overflow.
|
||||
if uint32(cap(s.packetData)) < length-1+macSize {
|
||||
s.packetData = make([]byte, length-1+macSize)
|
||||
} else {
|
||||
s.packetData = s.packetData[:length-1+macSize]
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r, s.packetData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mac := s.packetData[length-1:]
|
||||
data := s.packetData[:length-1]
|
||||
|
||||
if s.mac != nil && s.etm {
|
||||
s.mac.Write(data)
|
||||
}
|
||||
|
||||
s.cipher.XORKeyStream(data, data)
|
||||
|
||||
if s.mac != nil {
|
||||
if !s.etm {
|
||||
s.mac.Write(data)
|
||||
}
|
||||
s.macResult = s.mac.Sum(s.macResult[:0])
|
||||
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
|
||||
return nil, errors.New("ssh: MAC failure")
|
||||
}
|
||||
}
|
||||
|
||||
return s.packetData[:length-paddingLength-1], nil
|
||||
}
|
||||
|
||||
// writeCipherPacket encrypts and sends a packet of data to the writer argument
|
||||
func (s *streamPacketCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
||||
if len(packet) > maxPacket {
|
||||
return errors.New("ssh: packet too large")
|
||||
}
|
||||
|
||||
aadlen := 0
|
||||
if s.mac != nil && s.etm {
|
||||
// packet length is not encrypted for EtM modes
|
||||
aadlen = 4
|
||||
}
|
||||
|
||||
paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple
|
||||
if paddingLength < 4 {
|
||||
paddingLength += packetSizeMultiple
|
||||
}
|
||||
|
||||
length := len(packet) + 1 + paddingLength
|
||||
binary.BigEndian.PutUint32(s.prefix[:], uint32(length))
|
||||
s.prefix[4] = byte(paddingLength)
|
||||
padding := s.padding[:paddingLength]
|
||||
if _, err := io.ReadFull(rand, padding); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.mac != nil {
|
||||
s.mac.Reset()
|
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
|
||||
s.mac.Write(s.seqNumBytes[:])
|
||||
|
||||
if s.etm {
|
||||
// For EtM algorithms, the packet length must stay unencrypted,
|
||||
// but the following data (padding length) must be encrypted
|
||||
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
|
||||
}
|
||||
|
||||
s.mac.Write(s.prefix[:])
|
||||
|
||||
if !s.etm {
|
||||
// For non-EtM algorithms, the algorithm is applied on unencrypted data
|
||||
s.mac.Write(packet)
|
||||
s.mac.Write(padding)
|
||||
}
|
||||
}
|
||||
|
||||
if !(s.mac != nil && s.etm) {
|
||||
// For EtM algorithms, the padding length has already been encrypted
|
||||
// and the packet length must remain unencrypted
|
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
|
||||
}
|
||||
|
||||
s.cipher.XORKeyStream(packet, packet)
|
||||
s.cipher.XORKeyStream(padding, padding)
|
||||
|
||||
if s.mac != nil && s.etm {
|
||||
// For EtM algorithms, packet and padding must be encrypted
|
||||
s.mac.Write(packet)
|
||||
s.mac.Write(padding)
|
||||
}
|
||||
|
||||
if _, err := w.Write(s.prefix[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(packet); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(padding); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.mac != nil {
|
||||
s.macResult = s.mac.Sum(s.macResult[:0])
|
||||
if _, err := w.Write(s.macResult); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type gcmCipher struct {
|
||||
aead cipher.AEAD
|
||||
prefix [4]byte
|
||||
iv []byte
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aead, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &gcmCipher{
|
||||
aead: aead,
|
||||
iv: iv,
|
||||
}, nil
|
||||
}
|
||||
|
||||
const gcmTagSize = 16
|
||||
|
||||
func (c *gcmCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
||||
// Pad out to multiple of 16 bytes. This is different from the
|
||||
// stream cipher because that encrypts the length too.
|
||||
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple)
|
||||
if padding < 4 {
|
||||
padding += packetSizeMultiple
|
||||
}
|
||||
|
||||
length := uint32(len(packet) + int(padding) + 1)
|
||||
binary.BigEndian.PutUint32(c.prefix[:], length)
|
||||
if _, err := w.Write(c.prefix[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cap(c.buf) < int(length) {
|
||||
c.buf = make([]byte, length)
|
||||
} else {
|
||||
c.buf = c.buf[:length]
|
||||
}
|
||||
|
||||
c.buf[0] = padding
|
||||
copy(c.buf[1:], packet)
|
||||
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil {
|
||||
return err
|
||||
}
|
||||
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:])
|
||||
if _, err := w.Write(c.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
c.incIV()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gcmCipher) incIV() {
|
||||
for i := 4 + 7; i >= 4; i-- {
|
||||
c.iv[i]++
|
||||
if c.iv[i] != 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
||||
if _, err := io.ReadFull(r, c.prefix[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := binary.BigEndian.Uint32(c.prefix[:])
|
||||
if length > maxPacket {
|
||||
return nil, errors.New("ssh: max packet length exceeded")
|
||||
}
|
||||
|
||||
if cap(c.buf) < int(length+gcmTagSize) {
|
||||
c.buf = make([]byte, length+gcmTagSize)
|
||||
} else {
|
||||
c.buf = c.buf[:length+gcmTagSize]
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r, c.buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.incIV()
|
||||
|
||||
if len(plain) == 0 {
|
||||
return nil, errors.New("ssh: empty packet")
|
||||
}
|
||||
|
||||
padding := plain[0]
|
||||
if padding < 4 {
|
||||
// padding is a byte, so it automatically satisfies
|
||||
// the maximum size, which is 255.
|
||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
|
||||
}
|
||||
|
||||
if int(padding+1) >= len(plain) {
|
||||
return nil, fmt.Errorf("ssh: padding %d too large", padding)
|
||||
}
|
||||
plain = plain[1 : length-uint32(padding)]
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
|
||||
type cbcCipher struct {
|
||||
mac hash.Hash
|
||||
macSize uint32
|
||||
decrypter cipher.BlockMode
|
||||
encrypter cipher.BlockMode
|
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
seqNumBytes [4]byte
|
||||
packetData []byte
|
||||
macResult []byte
|
||||
|
||||
// Amount of data we should still read to hide which
|
||||
// verification error triggered.
|
||||
oracleCamouflage uint32
|
||||
}
|
||||
|
||||
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
||||
cbc := &cbcCipher{
|
||||
mac: macModes[algs.MAC].new(macKey),
|
||||
decrypter: cipher.NewCBCDecrypter(c, iv),
|
||||
encrypter: cipher.NewCBCEncrypter(c, iv),
|
||||
packetData: make([]byte, 1024),
|
||||
}
|
||||
if cbc.mac != nil {
|
||||
cbc.macSize = uint32(cbc.mac.Size())
|
||||
}
|
||||
|
||||
return cbc, nil
|
||||
}
|
||||
|
||||
func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
||||
c, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cbc, nil
|
||||
}
|
||||
|
||||
func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
||||
c, err := des.NewTripleDESCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cbc, nil
|
||||
}
|
||||
|
||||
func maxUInt32(a, b int) uint32 {
|
||||
if a > b {
|
||||
return uint32(a)
|
||||
}
|
||||
return uint32(b)
|
||||
}
|
||||
|
||||
const (
|
||||
cbcMinPacketSizeMultiple = 8
|
||||
cbcMinPacketSize = 16
|
||||
cbcMinPaddingSize = 4
|
||||
)
|
||||
|
||||
// cbcError represents a verification error that may leak information.
|
||||
type cbcError string
|
||||
|
||||
func (e cbcError) Error() string { return string(e) }
|
||||
|
||||
func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
||||
p, err := c.readCipherPacketLeaky(seqNum, r)
|
||||
if err != nil {
|
||||
if _, ok := err.(cbcError); ok {
|
||||
// Verification error: read a fixed amount of
|
||||
// data, to make distinguishing between
|
||||
// failing MAC and failing length check more
|
||||
// difficult.
|
||||
io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
|
||||
}
|
||||
}
|
||||
return p, err
|
||||
}
|
||||
|
||||
func (c *cbcCipher) readCipherPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) {
|
||||
blockSize := c.decrypter.BlockSize()
|
||||
|
||||
// Read the header, which will include some of the subsequent data in the
|
||||
// case of block ciphers - this is copied back to the payload later.
|
||||
// How many bytes of payload/padding will be read with this first read.
|
||||
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize)
|
||||
firstBlock := c.packetData[:firstBlockLength]
|
||||
if _, err := io.ReadFull(r, firstBlock); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength
|
||||
|
||||
c.decrypter.CryptBlocks(firstBlock, firstBlock)
|
||||
length := binary.BigEndian.Uint32(firstBlock[:4])
|
||||
if length > maxPacket {
|
||||
return nil, cbcError("ssh: packet too large")
|
||||
}
|
||||
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) {
|
||||
// The minimum size of a packet is 16 (or the cipher block size, whichever
|
||||
// is larger) bytes.
|
||||
return nil, cbcError("ssh: packet too small")
|
||||
}
|
||||
// The length of the packet (including the length field but not the MAC) must
|
||||
// be a multiple of the block size or 8, whichever is larger.
|
||||
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 {
|
||||
return nil, cbcError("ssh: invalid packet length multiple")
|
||||
}
|
||||
|
||||
paddingLength := uint32(firstBlock[4])
|
||||
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 {
|
||||
return nil, cbcError("ssh: invalid packet length")
|
||||
}
|
||||
|
||||
// Positions within the c.packetData buffer:
|
||||
macStart := 4 + length
|
||||
paddingStart := macStart - paddingLength
|
||||
|
||||
// Entire packet size, starting before length, ending at end of mac.
|
||||
entirePacketSize := macStart + c.macSize
|
||||
|
||||
// Ensure c.packetData is large enough for the entire packet data.
|
||||
if uint32(cap(c.packetData)) < entirePacketSize {
|
||||
// Still need to upsize and copy, but this should be rare at runtime, only
|
||||
// on upsizing the packetData buffer.
|
||||
c.packetData = make([]byte, entirePacketSize)
|
||||
copy(c.packetData, firstBlock)
|
||||
} else {
|
||||
c.packetData = c.packetData[:entirePacketSize]
|
||||
}
|
||||
|
||||
n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.oracleCamouflage -= uint32(n)
|
||||
|
||||
remainingCrypted := c.packetData[firstBlockLength:macStart]
|
||||
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)
|
||||
|
||||
mac := c.packetData[macStart:]
|
||||
if c.mac != nil {
|
||||
c.mac.Reset()
|
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
|
||||
c.mac.Write(c.seqNumBytes[:])
|
||||
c.mac.Write(c.packetData[:macStart])
|
||||
c.macResult = c.mac.Sum(c.macResult[:0])
|
||||
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 {
|
||||
return nil, cbcError("ssh: MAC failure")
|
||||
}
|
||||
}
|
||||
|
||||
return c.packetData[prefixLen:paddingStart], nil
|
||||
}
|
||||
|
||||
func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
||||
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize())
|
||||
|
||||
// Length of encrypted portion of the packet (header, payload, padding).
|
||||
// Enforce minimum padding and packet size.
|
||||
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
|
||||
// Enforce block size.
|
||||
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
|
||||
|
||||
length := encLength - 4
|
||||
paddingLength := int(length) - (1 + len(packet))
|
||||
|
||||
// Overall buffer contains: header, payload, padding, mac.
|
||||
// Space for the MAC is reserved in the capacity but not the slice length.
|
||||
bufferSize := encLength + c.macSize
|
||||
if uint32(cap(c.packetData)) < bufferSize {
|
||||
c.packetData = make([]byte, encLength, bufferSize)
|
||||
} else {
|
||||
c.packetData = c.packetData[:encLength]
|
||||
}
|
||||
|
||||
p := c.packetData
|
||||
|
||||
// Packet header.
|
||||
binary.BigEndian.PutUint32(p, length)
|
||||
p = p[4:]
|
||||
p[0] = byte(paddingLength)
|
||||
|
||||
// Payload.
|
||||
p = p[1:]
|
||||
copy(p, packet)
|
||||
|
||||
// Padding.
|
||||
p = p[len(packet):]
|
||||
if _, err := io.ReadFull(rand, p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.mac != nil {
|
||||
c.mac.Reset()
|
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
|
||||
c.mac.Write(c.seqNumBytes[:])
|
||||
c.mac.Write(c.packetData)
|
||||
// The MAC is now appended into the capacity reserved for it earlier.
|
||||
c.packetData = c.mac.Sum(c.packetData)
|
||||
}
|
||||
|
||||
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength])
|
||||
|
||||
if _, err := w.Write(c.packetData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
|
||||
|
||||
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
|
||||
// AEAD, which is described here:
|
||||
//
|
||||
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
|
||||
//
|
||||
// the methods here also implement padding, which RFC 4253 Section 6
|
||||
// also requires of stream ciphers.
|
||||
type chacha20Poly1305Cipher struct {
|
||||
lengthKey [32]byte
|
||||
contentKey [32]byte
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
|
||||
if len(key) != 64 {
|
||||
panic(len(key))
|
||||
}
|
||||
|
||||
c := &chacha20Poly1305Cipher{
|
||||
buf: make([]byte, 256),
|
||||
}
|
||||
|
||||
copy(c.contentKey[:], key[:32])
|
||||
copy(c.lengthKey[:], key[32:])
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *chacha20Poly1305Cipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
||||
nonce := make([]byte, 12)
|
||||
binary.BigEndian.PutUint32(nonce[8:], seqNum)
|
||||
s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var polyKey, discardBuf [32]byte
|
||||
s.XORKeyStream(polyKey[:], polyKey[:])
|
||||
s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
|
||||
|
||||
encryptedLength := c.buf[:4]
|
||||
if _, err := io.ReadFull(r, encryptedLength); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var lenBytes [4]byte
|
||||
ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ls.XORKeyStream(lenBytes[:], encryptedLength)
|
||||
|
||||
length := binary.BigEndian.Uint32(lenBytes[:])
|
||||
if length > maxPacket {
|
||||
return nil, errors.New("ssh: invalid packet length, packet too large")
|
||||
}
|
||||
|
||||
contentEnd := 4 + length
|
||||
packetEnd := contentEnd + poly1305.TagSize
|
||||
if uint32(cap(c.buf)) < packetEnd {
|
||||
c.buf = make([]byte, packetEnd)
|
||||
copy(c.buf[:], encryptedLength)
|
||||
} else {
|
||||
c.buf = c.buf[:packetEnd]
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mac [poly1305.TagSize]byte
|
||||
copy(mac[:], c.buf[contentEnd:packetEnd])
|
||||
if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) {
|
||||
return nil, errors.New("ssh: MAC failure")
|
||||
}
|
||||
|
||||
plain := c.buf[4:contentEnd]
|
||||
s.XORKeyStream(plain, plain)
|
||||
|
||||
if len(plain) == 0 {
|
||||
return nil, errors.New("ssh: empty packet")
|
||||
}
|
||||
|
||||
padding := plain[0]
|
||||
if padding < 4 {
|
||||
// padding is a byte, so it automatically satisfies
|
||||
// the maximum size, which is 255.
|
||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
|
||||
}
|
||||
|
||||
if int(padding)+1 >= len(plain) {
|
||||
return nil, fmt.Errorf("ssh: padding %d too large", padding)
|
||||
}
|
||||
|
||||
plain = plain[1 : len(plain)-int(padding)]
|
||||
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
func (c *chacha20Poly1305Cipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
|
||||
nonce := make([]byte, 12)
|
||||
binary.BigEndian.PutUint32(nonce[8:], seqNum)
|
||||
s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var polyKey, discardBuf [32]byte
|
||||
s.XORKeyStream(polyKey[:], polyKey[:])
|
||||
s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
|
||||
|
||||
// There is no blocksize, so fall back to multiple of 8 byte
|
||||
// padding, as described in RFC 4253, Sec 6.
|
||||
const packetSizeMultiple = 8
|
||||
|
||||
padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple
|
||||
if padding < 4 {
|
||||
padding += packetSizeMultiple
|
||||
}
|
||||
|
||||
// size (4 bytes), padding (1), payload, padding, tag.
|
||||
totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize
|
||||
if cap(c.buf) < totalLength {
|
||||
c.buf = make([]byte, totalLength)
|
||||
} else {
|
||||
c.buf = c.buf[:totalLength]
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
|
||||
ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ls.XORKeyStream(c.buf, c.buf[:4])
|
||||
c.buf[4] = byte(padding)
|
||||
copy(c.buf[5:], payload)
|
||||
packetEnd := 5 + len(payload) + padding
|
||||
if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
|
||||
|
||||
var mac [poly1305.TagSize]byte
|
||||
poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)
|
||||
|
||||
copy(c.buf[packetEnd:], mac[:])
|
||||
|
||||
if _, err := w.Write(c.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
282
vendor/github.com/tailscale/golang-x-crypto/ssh/client.go
generated
vendored
Normal file
282
vendor/github.com/tailscale/golang-x-crypto/ssh/client.go
generated
vendored
Normal file
@@ -0,0 +1,282 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client implements a traditional SSH client that supports shells,
|
||||
// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
|
||||
type Client struct {
|
||||
Conn
|
||||
|
||||
handleForwardsOnce sync.Once // guards calling (*Client).handleForwards
|
||||
|
||||
forwards forwardList // forwarded tcpip connections from the remote side
|
||||
mu sync.Mutex
|
||||
channelHandlers map[string]chan NewChannel
|
||||
}
|
||||
|
||||
// HandleChannelOpen returns a channel on which NewChannel requests
|
||||
// for the given type are sent. If the type already is being handled,
|
||||
// nil is returned. The channel is closed when the connection is closed.
|
||||
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.channelHandlers == nil {
|
||||
// The SSH channel has been closed.
|
||||
c := make(chan NewChannel)
|
||||
close(c)
|
||||
return c
|
||||
}
|
||||
|
||||
ch := c.channelHandlers[channelType]
|
||||
if ch != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch = make(chan NewChannel, chanSize)
|
||||
c.channelHandlers[channelType] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
// NewClient creates a Client on top of the given connection.
|
||||
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
|
||||
conn := &Client{
|
||||
Conn: c,
|
||||
channelHandlers: make(map[string]chan NewChannel, 1),
|
||||
}
|
||||
|
||||
go conn.handleGlobalRequests(reqs)
|
||||
go conn.handleChannelOpens(chans)
|
||||
go func() {
|
||||
conn.Wait()
|
||||
conn.forwards.closeAll()
|
||||
}()
|
||||
return conn
|
||||
}
|
||||
|
||||
// NewClientConn establishes an authenticated SSH connection using c
|
||||
// as the underlying transport. The Request and NewChannel channels
|
||||
// must be serviced or the connection will hang.
|
||||
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
|
||||
fullConf := *config
|
||||
fullConf.SetDefaults()
|
||||
if fullConf.HostKeyCallback == nil {
|
||||
c.Close()
|
||||
return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
|
||||
}
|
||||
|
||||
conn := &connection{
|
||||
sshConn: sshConn{conn: c, user: fullConf.User},
|
||||
}
|
||||
|
||||
if err := conn.clientHandshake(addr, &fullConf); err != nil {
|
||||
c.Close()
|
||||
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
|
||||
}
|
||||
conn.mux = newMux(conn.transport)
|
||||
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
|
||||
}
|
||||
|
||||
// clientHandshake performs the client side key exchange. See RFC 4253 Section
|
||||
// 7.
|
||||
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
|
||||
if config.ClientVersion != "" {
|
||||
c.clientVersion = []byte(config.ClientVersion)
|
||||
} else {
|
||||
c.clientVersion = []byte(packageVersion)
|
||||
}
|
||||
var err error
|
||||
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.transport = newClientTransport(
|
||||
newTransport(c.sshConn.conn, config.Rand, true /* is client */),
|
||||
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
|
||||
if err := c.transport.waitSession(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.sessionID = c.transport.getSessionID()
|
||||
return c.clientAuthenticate(config)
|
||||
}
|
||||
|
||||
// verifyHostKeySignature verifies the host key obtained in the key exchange.
|
||||
// algo is the negotiated algorithm, and may be a certificate type.
|
||||
func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error {
|
||||
sig, rest, ok := parseSignatureBody(result.Signature)
|
||||
if len(rest) > 0 || !ok {
|
||||
return errors.New("ssh: signature parse error")
|
||||
}
|
||||
|
||||
if a := underlyingAlgo(algo); sig.Format != a {
|
||||
return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, a)
|
||||
}
|
||||
|
||||
return hostKey.Verify(result.H, sig)
|
||||
}
|
||||
|
||||
// NewSession opens a new Session for this client. (A session is a remote
|
||||
// execution of a program.)
|
||||
func (c *Client) NewSession() (*Session, error) {
|
||||
ch, in, err := c.OpenChannel("session", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSession(ch, in)
|
||||
}
|
||||
|
||||
func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
|
||||
for r := range incoming {
|
||||
// This handles keepalive messages and matches
|
||||
// the behaviour of OpenSSH.
|
||||
r.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// handleChannelOpens channel open messages from the remote side.
|
||||
func (c *Client) handleChannelOpens(in <-chan NewChannel) {
|
||||
for ch := range in {
|
||||
c.mu.Lock()
|
||||
handler := c.channelHandlers[ch.ChannelType()]
|
||||
c.mu.Unlock()
|
||||
|
||||
if handler != nil {
|
||||
handler <- ch
|
||||
} else {
|
||||
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType()))
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
for _, ch := range c.channelHandlers {
|
||||
close(ch)
|
||||
}
|
||||
c.channelHandlers = nil
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Dial starts a client connection to the given SSH server. It is a
|
||||
// convenience function that connects to the given network address,
|
||||
// initiates the SSH handshake, and then sets up a Client. For access
|
||||
// to incoming channels and requests, use net.Dial with NewClientConn
|
||||
// instead.
|
||||
func Dial(network, addr string, config *ClientConfig) (*Client, error) {
|
||||
conn, err := net.DialTimeout(network, addr, config.Timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, chans, reqs, err := NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewClient(c, chans, reqs), nil
|
||||
}
|
||||
|
||||
// HostKeyCallback is the function type used for verifying server
|
||||
// keys. A HostKeyCallback must return nil if the host key is OK, or
|
||||
// an error to reject it. It receives the hostname as passed to Dial
|
||||
// or NewClientConn. The remote address is the RemoteAddr of the
|
||||
// net.Conn underlying the SSH connection.
|
||||
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
|
||||
|
||||
// BannerCallback is the function type used for treat the banner sent by
|
||||
// the server. A BannerCallback receives the message sent by the remote server.
|
||||
type BannerCallback func(message string) error
|
||||
|
||||
// A ClientConfig structure is used to configure a Client. It must not be
|
||||
// modified after having been passed to an SSH function.
|
||||
type ClientConfig struct {
|
||||
// Config contains configuration that is shared between clients and
|
||||
// servers.
|
||||
Config
|
||||
|
||||
// User contains the username to authenticate as.
|
||||
User string
|
||||
|
||||
// Auth contains possible authentication methods to use with the
|
||||
// server. Only the first instance of a particular RFC 4252 method will
|
||||
// be used during authentication.
|
||||
Auth []AuthMethod
|
||||
|
||||
// HostKeyCallback is called during the cryptographic
|
||||
// handshake to validate the server's host key. The client
|
||||
// configuration must supply this callback for the connection
|
||||
// to succeed. The functions InsecureIgnoreHostKey or
|
||||
// FixedHostKey can be used for simplistic host key checks.
|
||||
HostKeyCallback HostKeyCallback
|
||||
|
||||
// BannerCallback is called during the SSH dance to display a custom
|
||||
// server's message. The client configuration can supply this callback to
|
||||
// handle it as wished. The function BannerDisplayStderr can be used for
|
||||
// simplistic display on Stderr.
|
||||
BannerCallback BannerCallback
|
||||
|
||||
// ClientVersion contains the version identification string that will
|
||||
// be used for the connection. If empty, a reasonable default is used.
|
||||
ClientVersion string
|
||||
|
||||
// HostKeyAlgorithms lists the public key algorithms that the client will
|
||||
// accept from the server for host key authentication, in order of
|
||||
// preference. If empty, a reasonable default is used. Any
|
||||
// string returned from a PublicKey.Type method may be used, or
|
||||
// any of the CertAlgo and KeyAlgo constants.
|
||||
HostKeyAlgorithms []string
|
||||
|
||||
// Timeout is the maximum amount of time for the TCP connection to establish.
|
||||
//
|
||||
// A Timeout of zero means no timeout.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// InsecureIgnoreHostKey returns a function that can be used for
|
||||
// ClientConfig.HostKeyCallback to accept any host key. It should
|
||||
// not be used for production code.
|
||||
func InsecureIgnoreHostKey() HostKeyCallback {
|
||||
return func(hostname string, remote net.Addr, key PublicKey) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type fixedHostKey struct {
|
||||
key PublicKey
|
||||
}
|
||||
|
||||
func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
|
||||
if f.key == nil {
|
||||
return fmt.Errorf("ssh: required host key was nil")
|
||||
}
|
||||
if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
|
||||
return fmt.Errorf("ssh: host key mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FixedHostKey returns a function for use in
|
||||
// ClientConfig.HostKeyCallback to accept only a specific host key.
|
||||
func FixedHostKey(key PublicKey) HostKeyCallback {
|
||||
hk := &fixedHostKey{key}
|
||||
return hk.check
|
||||
}
|
||||
|
||||
// BannerDisplayStderr returns a function that can be used for
|
||||
// ClientConfig.BannerCallback to display banners on os.Stderr.
|
||||
func BannerDisplayStderr() BannerCallback {
|
||||
return func(banner string) error {
|
||||
_, err := os.Stderr.WriteString(banner)
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
779
vendor/github.com/tailscale/golang-x-crypto/ssh/client_auth.go
generated
vendored
Normal file
779
vendor/github.com/tailscale/golang-x-crypto/ssh/client_auth.go
generated
vendored
Normal file
@@ -0,0 +1,779 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type authResult int
|
||||
|
||||
const (
|
||||
authFailure authResult = iota
|
||||
authPartialSuccess
|
||||
authSuccess
|
||||
)
|
||||
|
||||
// clientAuthenticate authenticates with the remote server. See RFC 4252.
|
||||
func (c *connection) clientAuthenticate(config *ClientConfig) error {
|
||||
// initiate user auth session
|
||||
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
|
||||
return err
|
||||
}
|
||||
packet, err := c.transport.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// The server may choose to send a SSH_MSG_EXT_INFO at this point (if we
|
||||
// advertised willingness to receive one, which we always do) or not. See
|
||||
// RFC 8308, Section 2.4.
|
||||
extensions := make(map[string][]byte)
|
||||
if len(packet) > 0 && packet[0] == msgExtInfo {
|
||||
var extInfo extInfoMsg
|
||||
if err := Unmarshal(packet, &extInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
payload := extInfo.Payload
|
||||
for i := uint32(0); i < extInfo.NumExtensions; i++ {
|
||||
name, rest, ok := parseString(payload)
|
||||
if !ok {
|
||||
return parseError(msgExtInfo)
|
||||
}
|
||||
value, rest, ok := parseString(rest)
|
||||
if !ok {
|
||||
return parseError(msgExtInfo)
|
||||
}
|
||||
extensions[string(name)] = value
|
||||
payload = rest
|
||||
}
|
||||
packet, err = c.transport.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
var serviceAccept serviceAcceptMsg
|
||||
if err := Unmarshal(packet, &serviceAccept); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// during the authentication phase the client first attempts the "none" method
|
||||
// then any untried methods suggested by the server.
|
||||
var tried []string
|
||||
var lastMethods []string
|
||||
|
||||
sessionID := c.transport.getSessionID()
|
||||
for auth := AuthMethod(new(noneAuth)); auth != nil; {
|
||||
ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
|
||||
if err != nil {
|
||||
// We return the error later if there is no other method left to
|
||||
// try.
|
||||
ok = authFailure
|
||||
}
|
||||
if ok == authSuccess {
|
||||
// success
|
||||
return nil
|
||||
} else if ok == authFailure {
|
||||
if m := auth.method(); !contains(tried, m) {
|
||||
tried = append(tried, m)
|
||||
}
|
||||
}
|
||||
if methods == nil {
|
||||
methods = lastMethods
|
||||
}
|
||||
lastMethods = methods
|
||||
|
||||
auth = nil
|
||||
|
||||
findNext:
|
||||
for _, a := range config.Auth {
|
||||
candidateMethod := a.method()
|
||||
if contains(tried, candidateMethod) {
|
||||
continue
|
||||
}
|
||||
for _, meth := range methods {
|
||||
if meth == candidateMethod {
|
||||
auth = a
|
||||
break findNext
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if auth == nil && err != nil {
|
||||
// We have an error and there are no other authentication methods to
|
||||
// try, so we return it.
|
||||
return err
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
|
||||
}
|
||||
|
||||
func contains(list []string, e string) bool {
|
||||
for _, s := range list {
|
||||
if s == e {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// An AuthMethod represents an instance of an RFC 4252 authentication method.
|
||||
type AuthMethod interface {
|
||||
// auth authenticates user over transport t.
|
||||
// Returns true if authentication is successful.
|
||||
// If authentication is not successful, a []string of alternative
|
||||
// method names is returned. If the slice is nil, it will be ignored
|
||||
// and the previous set of possible methods will be reused.
|
||||
auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)
|
||||
|
||||
// method returns the RFC 4252 method name.
|
||||
method() string
|
||||
}
|
||||
|
||||
// "none" authentication, RFC 4252 section 5.2.
|
||||
type noneAuth int
|
||||
|
||||
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
||||
if err := c.writePacket(Marshal(&userAuthRequestMsg{
|
||||
User: user,
|
||||
Service: serviceSSH,
|
||||
Method: "none",
|
||||
})); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
return handleAuthResponse(c)
|
||||
}
|
||||
|
||||
func (n *noneAuth) method() string {
|
||||
return "none"
|
||||
}
|
||||
|
||||
// passwordCallback is an AuthMethod that fetches the password through
|
||||
// a function call, e.g. by prompting the user.
|
||||
type passwordCallback func() (password string, err error)
|
||||
|
||||
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
||||
type passwordAuthMsg struct {
|
||||
User string `sshtype:"50"`
|
||||
Service string
|
||||
Method string
|
||||
Reply bool
|
||||
Password string
|
||||
}
|
||||
|
||||
pw, err := cb()
|
||||
// REVIEW NOTE: is there a need to support skipping a password attempt?
|
||||
// The program may only find out that the user doesn't have a password
|
||||
// when prompting.
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
if err := c.writePacket(Marshal(&passwordAuthMsg{
|
||||
User: user,
|
||||
Service: serviceSSH,
|
||||
Method: cb.method(),
|
||||
Reply: false,
|
||||
Password: pw,
|
||||
})); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
return handleAuthResponse(c)
|
||||
}
|
||||
|
||||
func (cb passwordCallback) method() string {
|
||||
return "password"
|
||||
}
|
||||
|
||||
// Password returns an AuthMethod using the given password.
|
||||
func Password(secret string) AuthMethod {
|
||||
return passwordCallback(func() (string, error) { return secret, nil })
|
||||
}
|
||||
|
||||
// PasswordCallback returns an AuthMethod that uses a callback for
|
||||
// fetching a password.
|
||||
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
|
||||
return passwordCallback(prompt)
|
||||
}
|
||||
|
||||
type publickeyAuthMsg struct {
|
||||
User string `sshtype:"50"`
|
||||
Service string
|
||||
Method string
|
||||
// HasSig indicates to the receiver packet that the auth request is signed and
|
||||
// should be used for authentication of the request.
|
||||
HasSig bool
|
||||
Algoname string
|
||||
PubKey []byte
|
||||
// Sig is tagged with "rest" so Marshal will exclude it during
|
||||
// validateKey
|
||||
Sig []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// publicKeyCallback is an AuthMethod that uses a set of key
|
||||
// pairs for authentication.
|
||||
type publicKeyCallback func() ([]Signer, error)
|
||||
|
||||
func (cb publicKeyCallback) method() string {
|
||||
return "publickey"
|
||||
}
|
||||
|
||||
func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
|
||||
var as MultiAlgorithmSigner
|
||||
keyFormat := signer.PublicKey().Type()
|
||||
|
||||
// If the signer implements MultiAlgorithmSigner we use the algorithms it
|
||||
// support, if it implements AlgorithmSigner we assume it supports all
|
||||
// algorithms, otherwise only the key format one.
|
||||
switch s := signer.(type) {
|
||||
case MultiAlgorithmSigner:
|
||||
as = s
|
||||
case AlgorithmSigner:
|
||||
as = &multiAlgorithmSigner{
|
||||
AlgorithmSigner: s,
|
||||
supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
|
||||
}
|
||||
default:
|
||||
as = &multiAlgorithmSigner{
|
||||
AlgorithmSigner: algorithmSignerWrapper{signer},
|
||||
supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
|
||||
}
|
||||
}
|
||||
|
||||
getFallbackAlgo := func() (string, error) {
|
||||
// Fallback to use if there is no "server-sig-algs" extension or a
|
||||
// common algorithm cannot be found. We use the public key format if the
|
||||
// MultiAlgorithmSigner supports it, otherwise we return an error.
|
||||
if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
|
||||
return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
|
||||
underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
|
||||
}
|
||||
return keyFormat, nil
|
||||
}
|
||||
|
||||
extPayload, ok := extensions["server-sig-algs"]
|
||||
if !ok {
|
||||
// If there is no "server-sig-algs" extension use the fallback
|
||||
// algorithm.
|
||||
algo, err := getFallbackAlgo()
|
||||
return as, algo, err
|
||||
}
|
||||
|
||||
// The server-sig-algs extension only carries underlying signature
|
||||
// algorithm, but we are trying to select a protocol-level public key
|
||||
// algorithm, which might be a certificate type. Extend the list of server
|
||||
// supported algorithms to include the corresponding certificate algorithms.
|
||||
serverAlgos := strings.Split(string(extPayload), ",")
|
||||
for _, algo := range serverAlgos {
|
||||
if certAlgo, ok := certificateAlgo(algo); ok {
|
||||
serverAlgos = append(serverAlgos, certAlgo)
|
||||
}
|
||||
}
|
||||
|
||||
// Filter algorithms based on those supported by MultiAlgorithmSigner.
|
||||
var keyAlgos []string
|
||||
for _, algo := range algorithmsForKeyFormat(keyFormat) {
|
||||
if contains(as.Algorithms(), underlyingAlgo(algo)) {
|
||||
keyAlgos = append(keyAlgos, algo)
|
||||
}
|
||||
}
|
||||
|
||||
algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
|
||||
if err != nil {
|
||||
// If there is no overlap, return the fallback algorithm to support
|
||||
// servers that fail to list all supported algorithms.
|
||||
algo, err := getFallbackAlgo()
|
||||
return as, algo, err
|
||||
}
|
||||
return as, algo, nil
|
||||
}
|
||||
|
||||
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
|
||||
// Authentication is performed by sending an enquiry to test if a key is
|
||||
// acceptable to the remote. If the key is acceptable, the client will
|
||||
// attempt to authenticate with the valid key. If not the client will repeat
|
||||
// the process with the remaining keys.
|
||||
|
||||
signers, err := cb()
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
var methods []string
|
||||
var errSigAlgo error
|
||||
|
||||
origSignersLen := len(signers)
|
||||
for idx := 0; idx < len(signers); idx++ {
|
||||
signer := signers[idx]
|
||||
pub := signer.PublicKey()
|
||||
as, algo, err := pickSignatureAlgorithm(signer, extensions)
|
||||
if err != nil && errSigAlgo == nil {
|
||||
// If we cannot negotiate a signature algorithm store the first
|
||||
// error so we can return it to provide a more meaningful message if
|
||||
// no other signers work.
|
||||
errSigAlgo = err
|
||||
continue
|
||||
}
|
||||
ok, err := validateKey(pub, algo, user, c)
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
// OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512
|
||||
// in the "server-sig-algs" extension but doesn't support these
|
||||
// algorithms for certificate authentication, so if the server rejects
|
||||
// the key try to use the obtained algorithm as if "server-sig-algs" had
|
||||
// not been implemented if supported from the algorithm signer.
|
||||
if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
|
||||
if contains(as.Algorithms(), KeyAlgoRSA) {
|
||||
// We retry using the compat algorithm after all signers have
|
||||
// been tried normally.
|
||||
signers = append(signers, &multiAlgorithmSigner{
|
||||
AlgorithmSigner: as,
|
||||
supportedAlgorithms: []string{KeyAlgoRSA},
|
||||
})
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
pubKey := pub.Marshal()
|
||||
data := buildDataSignedForAuth(session, userAuthRequestMsg{
|
||||
User: user,
|
||||
Service: serviceSSH,
|
||||
Method: cb.method(),
|
||||
}, algo, pubKey)
|
||||
sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
// manually wrap the serialized signature in a string
|
||||
s := Marshal(sign)
|
||||
sig := make([]byte, stringLength(len(s)))
|
||||
marshalString(sig, s)
|
||||
msg := publickeyAuthMsg{
|
||||
User: user,
|
||||
Service: serviceSSH,
|
||||
Method: cb.method(),
|
||||
HasSig: true,
|
||||
Algoname: algo,
|
||||
PubKey: pubKey,
|
||||
Sig: sig,
|
||||
}
|
||||
p := Marshal(&msg)
|
||||
if err := c.writePacket(p); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
var success authResult
|
||||
success, methods, err = handleAuthResponse(c)
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
// If authentication succeeds or the list of available methods does not
|
||||
// contain the "publickey" method, do not attempt to authenticate with any
|
||||
// other keys. According to RFC 4252 Section 7, the latter can occur when
|
||||
// additional authentication methods are required.
|
||||
if success == authSuccess || !contains(methods, cb.method()) {
|
||||
return success, methods, err
|
||||
}
|
||||
}
|
||||
|
||||
return authFailure, methods, errSigAlgo
|
||||
}
|
||||
|
||||
// validateKey validates the key provided is acceptable to the server.
|
||||
func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) {
|
||||
pubKey := key.Marshal()
|
||||
msg := publickeyAuthMsg{
|
||||
User: user,
|
||||
Service: serviceSSH,
|
||||
Method: "publickey",
|
||||
HasSig: false,
|
||||
Algoname: algo,
|
||||
PubKey: pubKey,
|
||||
}
|
||||
if err := c.writePacket(Marshal(&msg)); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return confirmKeyAck(key, algo, c)
|
||||
}
|
||||
|
||||
func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
|
||||
pubKey := key.Marshal()
|
||||
|
||||
for {
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
switch packet[0] {
|
||||
case msgUserAuthBanner:
|
||||
if err := handleBannerResponse(c, packet); err != nil {
|
||||
return false, err
|
||||
}
|
||||
case msgUserAuthPubKeyOk:
|
||||
var msg userAuthPubKeyOkMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if msg.Algo != algo || !bytes.Equal(msg.PubKey, pubKey) {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
case msgUserAuthFailure:
|
||||
return false, nil
|
||||
default:
|
||||
return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PublicKeys returns an AuthMethod that uses the given key
|
||||
// pairs.
|
||||
func PublicKeys(signers ...Signer) AuthMethod {
|
||||
return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
|
||||
}
|
||||
|
||||
// PublicKeysCallback returns an AuthMethod that runs the given
|
||||
// function to obtain a list of key pairs.
|
||||
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
|
||||
return publicKeyCallback(getSigners)
|
||||
}
|
||||
|
||||
// handleAuthResponse returns whether the preceding authentication request succeeded
|
||||
// along with a list of remaining authentication methods to try next and
|
||||
// an error if an unexpected response was received.
|
||||
func handleAuthResponse(c packetConn) (authResult, []string, error) {
|
||||
gotMsgExtInfo := false
|
||||
for {
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
switch packet[0] {
|
||||
case msgUserAuthBanner:
|
||||
if err := handleBannerResponse(c, packet); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
case msgExtInfo:
|
||||
// Ignore post-authentication RFC 8308 extensions, once.
|
||||
if gotMsgExtInfo {
|
||||
return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
|
||||
}
|
||||
gotMsgExtInfo = true
|
||||
case msgUserAuthFailure:
|
||||
var msg userAuthFailureMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
if msg.PartialSuccess {
|
||||
return authPartialSuccess, msg.Methods, nil
|
||||
}
|
||||
return authFailure, msg.Methods, nil
|
||||
case msgUserAuthSuccess:
|
||||
return authSuccess, nil, nil
|
||||
default:
|
||||
return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleBannerResponse(c packetConn, packet []byte) error {
|
||||
var msg userAuthBannerMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
transport, ok := c.(*handshakeTransport)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if transport.bannerCallback != nil {
|
||||
return transport.bannerCallback(msg.Message)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// KeyboardInteractiveChallenge should print questions, optionally
|
||||
// disabling echoing (e.g. for passwords), and return all the answers.
|
||||
// Challenge may be called multiple times in a single session. After
|
||||
// successful authentication, the server may send a challenge with no
|
||||
// questions, for which the name and instruction messages should be
|
||||
// printed. RFC 4256 section 3.3 details how the UI should behave for
|
||||
// both CLI and GUI environments.
|
||||
type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error)
|
||||
|
||||
// KeyboardInteractive returns an AuthMethod using a prompt/response
|
||||
// sequence controlled by the server.
|
||||
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
|
||||
return challenge
|
||||
}
|
||||
|
||||
func (cb KeyboardInteractiveChallenge) method() string {
|
||||
return "keyboard-interactive"
|
||||
}
|
||||
|
||||
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
||||
type initiateMsg struct {
|
||||
User string `sshtype:"50"`
|
||||
Service string
|
||||
Method string
|
||||
Language string
|
||||
Submethods string
|
||||
}
|
||||
|
||||
if err := c.writePacket(Marshal(&initiateMsg{
|
||||
User: user,
|
||||
Service: serviceSSH,
|
||||
Method: "keyboard-interactive",
|
||||
})); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
gotMsgExtInfo := false
|
||||
for {
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
// like handleAuthResponse, but with less options.
|
||||
switch packet[0] {
|
||||
case msgUserAuthBanner:
|
||||
if err := handleBannerResponse(c, packet); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
continue
|
||||
case msgExtInfo:
|
||||
// Ignore post-authentication RFC 8308 extensions, once.
|
||||
if gotMsgExtInfo {
|
||||
return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
|
||||
}
|
||||
gotMsgExtInfo = true
|
||||
continue
|
||||
case msgUserAuthInfoRequest:
|
||||
// OK
|
||||
case msgUserAuthFailure:
|
||||
var msg userAuthFailureMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
if msg.PartialSuccess {
|
||||
return authPartialSuccess, msg.Methods, nil
|
||||
}
|
||||
return authFailure, msg.Methods, nil
|
||||
case msgUserAuthSuccess:
|
||||
return authSuccess, nil, nil
|
||||
default:
|
||||
return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
|
||||
}
|
||||
|
||||
var msg userAuthInfoRequestMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
// Manually unpack the prompt/echo pairs.
|
||||
rest := msg.Prompts
|
||||
var prompts []string
|
||||
var echos []bool
|
||||
for i := 0; i < int(msg.NumPrompts); i++ {
|
||||
prompt, r, ok := parseString(rest)
|
||||
if !ok || len(r) == 0 {
|
||||
return authFailure, nil, errors.New("ssh: prompt format error")
|
||||
}
|
||||
prompts = append(prompts, string(prompt))
|
||||
echos = append(echos, r[0] != 0)
|
||||
rest = r[1:]
|
||||
}
|
||||
|
||||
if len(rest) != 0 {
|
||||
return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
|
||||
}
|
||||
|
||||
answers, err := cb(msg.Name, msg.Instruction, prompts, echos)
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
|
||||
if len(answers) != len(prompts) {
|
||||
return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts))
|
||||
}
|
||||
responseLength := 1 + 4
|
||||
for _, a := range answers {
|
||||
responseLength += stringLength(len(a))
|
||||
}
|
||||
serialized := make([]byte, responseLength)
|
||||
p := serialized
|
||||
p[0] = msgUserAuthInfoResponse
|
||||
p = p[1:]
|
||||
p = marshalUint32(p, uint32(len(answers)))
|
||||
for _, a := range answers {
|
||||
p = marshalString(p, []byte(a))
|
||||
}
|
||||
|
||||
if err := c.writePacket(serialized); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type retryableAuthMethod struct {
|
||||
authMethod AuthMethod
|
||||
maxTries int
|
||||
}
|
||||
|
||||
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) {
|
||||
for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
|
||||
ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions)
|
||||
if ok != authFailure || err != nil { // either success, partial success or error terminate
|
||||
return ok, methods, err
|
||||
}
|
||||
}
|
||||
return ok, methods, err
|
||||
}
|
||||
|
||||
func (r *retryableAuthMethod) method() string {
|
||||
return r.authMethod.method()
|
||||
}
|
||||
|
||||
// RetryableAuthMethod is a decorator for other auth methods enabling them to
|
||||
// be retried up to maxTries before considering that AuthMethod itself failed.
|
||||
// If maxTries is <= 0, will retry indefinitely
|
||||
//
|
||||
// This is useful for interactive clients using challenge/response type
|
||||
// authentication (e.g. Keyboard-Interactive, Password, etc) where the user
|
||||
// could mistype their response resulting in the server issuing a
|
||||
// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4
|
||||
// [keyboard-interactive]); Without this decorator, the non-retryable
|
||||
// AuthMethod would be removed from future consideration, and never tried again
|
||||
// (and so the user would never be able to retry their entry).
|
||||
func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod {
|
||||
return &retryableAuthMethod{authMethod: auth, maxTries: maxTries}
|
||||
}
|
||||
|
||||
// GSSAPIWithMICAuthMethod is an AuthMethod with "gssapi-with-mic" authentication.
|
||||
// See RFC 4462 section 3
|
||||
// gssAPIClient is implementation of the GSSAPIClient interface, see the definition of the interface for details.
|
||||
// target is the server host you want to log in to.
|
||||
func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod {
|
||||
if gssAPIClient == nil {
|
||||
panic("gss-api client must be not nil with enable gssapi-with-mic")
|
||||
}
|
||||
return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target}
|
||||
}
|
||||
|
||||
type gssAPIWithMICCallback struct {
|
||||
gssAPIClient GSSAPIClient
|
||||
target string
|
||||
}
|
||||
|
||||
func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
||||
m := &userAuthRequestMsg{
|
||||
User: user,
|
||||
Service: serviceSSH,
|
||||
Method: g.method(),
|
||||
}
|
||||
// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST.
|
||||
// See RFC 4462 section 3.2.
|
||||
m.Payload = appendU32(m.Payload, 1)
|
||||
m.Payload = appendString(m.Payload, string(krb5OID))
|
||||
if err := c.writePacket(Marshal(m)); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
// The server responds to the SSH_MSG_USERAUTH_REQUEST with either an
|
||||
// SSH_MSG_USERAUTH_FAILURE if none of the mechanisms are supported or
|
||||
// with an SSH_MSG_USERAUTH_GSSAPI_RESPONSE.
|
||||
// See RFC 4462 section 3.3.
|
||||
// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,so I don't want to check
|
||||
// selected mech if it is valid.
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
userAuthGSSAPIResp := &userAuthGSSAPIResponse{}
|
||||
if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
// Start the loop into the exchange token.
|
||||
// See RFC 4462 section 3.4.
|
||||
var token []byte
|
||||
defer g.gssAPIClient.DeleteSecContext()
|
||||
for {
|
||||
// Initiates the establishment of a security context between the application and a remote peer.
|
||||
nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false)
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
if len(nextToken) > 0 {
|
||||
if err := c.writePacket(Marshal(&userAuthGSSAPIToken{
|
||||
Token: nextToken,
|
||||
})); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
}
|
||||
if !needContinue {
|
||||
break
|
||||
}
|
||||
packet, err = c.readPacket()
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
switch packet[0] {
|
||||
case msgUserAuthFailure:
|
||||
var msg userAuthFailureMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
if msg.PartialSuccess {
|
||||
return authPartialSuccess, msg.Methods, nil
|
||||
}
|
||||
return authFailure, msg.Methods, nil
|
||||
case msgUserAuthGSSAPIError:
|
||||
userAuthGSSAPIErrorResp := &userAuthGSSAPIError{}
|
||||
if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+
|
||||
"Major Status: %d\n"+
|
||||
"Minor Status: %d\n"+
|
||||
"Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus,
|
||||
userAuthGSSAPIErrorResp.Message)
|
||||
case msgUserAuthGSSAPIToken:
|
||||
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
|
||||
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
token = userAuthGSSAPITokenReq.Token
|
||||
}
|
||||
}
|
||||
// Binding Encryption Keys.
|
||||
// See RFC 4462 section 3.5.
|
||||
micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic")
|
||||
micToken, err := g.gssAPIClient.GetMIC(micField)
|
||||
if err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{
|
||||
MIC: micToken,
|
||||
})); err != nil {
|
||||
return authFailure, nil, err
|
||||
}
|
||||
return handleAuthResponse(c)
|
||||
}
|
||||
|
||||
func (g *gssAPIWithMICCallback) method() string {
|
||||
return "gssapi-with-mic"
|
||||
}
|
||||
476
vendor/github.com/tailscale/golang-x-crypto/ssh/common.go
generated
vendored
Normal file
476
vendor/github.com/tailscale/golang-x-crypto/ssh/common.go
generated
vendored
Normal file
@@ -0,0 +1,476 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
_ "crypto/sha1"
|
||||
_ "crypto/sha256"
|
||||
_ "crypto/sha512"
|
||||
)
|
||||
|
||||
// These are string constants in the SSH protocol.
|
||||
const (
|
||||
compressionNone = "none"
|
||||
serviceUserAuth = "ssh-userauth"
|
||||
serviceSSH = "ssh-connection"
|
||||
)
|
||||
|
||||
// supportedCiphers lists ciphers we support but might not recommend.
|
||||
var supportedCiphers = []string{
|
||||
"aes128-ctr", "aes192-ctr", "aes256-ctr",
|
||||
"aes128-gcm@openssh.com", gcm256CipherID,
|
||||
chacha20Poly1305ID,
|
||||
"arcfour256", "arcfour128", "arcfour",
|
||||
aes128cbcID,
|
||||
tripledescbcID,
|
||||
}
|
||||
|
||||
// preferredCiphers specifies the default preference for ciphers.
|
||||
var preferredCiphers = []string{
|
||||
"aes128-gcm@openssh.com", gcm256CipherID,
|
||||
chacha20Poly1305ID,
|
||||
"aes128-ctr", "aes192-ctr", "aes256-ctr",
|
||||
}
|
||||
|
||||
// supportedKexAlgos specifies the supported key-exchange algorithms in
|
||||
// preference order.
|
||||
var supportedKexAlgos = []string{
|
||||
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
|
||||
// P384 and P521 are not constant-time yet, but since we don't
|
||||
// reuse ephemeral keys, using them for ECDH should be OK.
|
||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
|
||||
kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1,
|
||||
kexAlgoDH1SHA1,
|
||||
}
|
||||
|
||||
// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
|
||||
// for the server half.
|
||||
var serverForbiddenKexAlgos = map[string]struct{}{
|
||||
kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests
|
||||
kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
|
||||
}
|
||||
|
||||
// preferredKexAlgos specifies the default preference for key-exchange
|
||||
// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm
|
||||
// is disabled by default because it is a bit slower than the others.
|
||||
var preferredKexAlgos = []string{
|
||||
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
|
||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
|
||||
kexAlgoDH14SHA256, kexAlgoDH14SHA1,
|
||||
}
|
||||
|
||||
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
|
||||
// of authenticating servers) in preference order.
|
||||
var supportedHostKeyAlgos = []string{
|
||||
CertAlgoRSASHA256v01, CertAlgoRSASHA512v01,
|
||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
|
||||
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
|
||||
|
||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
|
||||
KeyAlgoRSASHA256, KeyAlgoRSASHA512,
|
||||
KeyAlgoRSA, KeyAlgoDSA,
|
||||
|
||||
KeyAlgoED25519,
|
||||
}
|
||||
|
||||
// supportedMACs specifies a default set of MAC algorithms in preference order.
|
||||
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
|
||||
// because they have reached the end of their useful life.
|
||||
var supportedMACs = []string{
|
||||
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96",
|
||||
}
|
||||
|
||||
var supportedCompressions = []string{compressionNone}
|
||||
|
||||
// hashFuncs keeps the mapping of supported signature algorithms to their
|
||||
// respective hashes needed for signing and verification.
|
||||
var hashFuncs = map[string]crypto.Hash{
|
||||
KeyAlgoRSA: crypto.SHA1,
|
||||
KeyAlgoRSASHA256: crypto.SHA256,
|
||||
KeyAlgoRSASHA512: crypto.SHA512,
|
||||
KeyAlgoDSA: crypto.SHA1,
|
||||
KeyAlgoECDSA256: crypto.SHA256,
|
||||
KeyAlgoECDSA384: crypto.SHA384,
|
||||
KeyAlgoECDSA521: crypto.SHA512,
|
||||
// KeyAlgoED25519 doesn't pre-hash.
|
||||
KeyAlgoSKECDSA256: crypto.SHA256,
|
||||
KeyAlgoSKED25519: crypto.SHA256,
|
||||
}
|
||||
|
||||
// algorithmsForKeyFormat returns the supported signature algorithms for a given
|
||||
// public key format (PublicKey.Type), in order of preference. See RFC 8332,
|
||||
// Section 2. See also the note in sendKexInit on backwards compatibility.
|
||||
func algorithmsForKeyFormat(keyFormat string) []string {
|
||||
switch keyFormat {
|
||||
case KeyAlgoRSA:
|
||||
return []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}
|
||||
case CertAlgoRSAv01:
|
||||
return []string{CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, CertAlgoRSAv01}
|
||||
default:
|
||||
return []string{keyFormat}
|
||||
}
|
||||
}
|
||||
|
||||
// isRSA returns whether algo is a supported RSA algorithm, including certificate
|
||||
// algorithms.
|
||||
func isRSA(algo string) bool {
|
||||
algos := algorithmsForKeyFormat(KeyAlgoRSA)
|
||||
return contains(algos, underlyingAlgo(algo))
|
||||
}
|
||||
|
||||
func isRSACert(algo string) bool {
|
||||
_, ok := certKeyAlgoNames[algo]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return isRSA(algo)
|
||||
}
|
||||
|
||||
// supportedPubKeyAuthAlgos specifies the supported client public key
|
||||
// authentication algorithms. Note that this doesn't include certificate types
|
||||
// since those use the underlying algorithm. This list is sent to the client if
|
||||
// it supports the server-sig-algs extension. Order is irrelevant.
|
||||
var supportedPubKeyAuthAlgos = []string{
|
||||
KeyAlgoED25519,
|
||||
KeyAlgoSKED25519, KeyAlgoSKECDSA256,
|
||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
|
||||
KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
|
||||
KeyAlgoDSA,
|
||||
}
|
||||
|
||||
// unexpectedMessageError results when the SSH message that we received didn't
|
||||
// match what we wanted.
|
||||
func unexpectedMessageError(expected, got uint8) error {
|
||||
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected)
|
||||
}
|
||||
|
||||
// parseError results from a malformed SSH message.
|
||||
func parseError(tag uint8) error {
|
||||
return fmt.Errorf("ssh: parse error in message type %d", tag)
|
||||
}
|
||||
|
||||
func findCommon(what string, client []string, server []string) (common string, err error) {
|
||||
for _, c := range client {
|
||||
for _, s := range server {
|
||||
if c == s {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
|
||||
}
|
||||
|
||||
// directionAlgorithms records algorithm choices in one direction (either read or write)
|
||||
type directionAlgorithms struct {
|
||||
Cipher string
|
||||
MAC string
|
||||
Compression string
|
||||
}
|
||||
|
||||
// rekeyBytes returns a rekeying intervals in bytes.
|
||||
func (a *directionAlgorithms) rekeyBytes() int64 {
|
||||
// According to RFC 4344 block ciphers should rekey after
|
||||
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
|
||||
// 128.
|
||||
switch a.Cipher {
|
||||
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
|
||||
return 16 * (1 << 32)
|
||||
|
||||
}
|
||||
|
||||
// For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data.
|
||||
return 1 << 30
|
||||
}
|
||||
|
||||
var aeadCiphers = map[string]bool{
|
||||
gcm128CipherID: true,
|
||||
gcm256CipherID: true,
|
||||
chacha20Poly1305ID: true,
|
||||
}
|
||||
|
||||
type algorithms struct {
|
||||
kex string
|
||||
hostKey string
|
||||
w directionAlgorithms
|
||||
r directionAlgorithms
|
||||
}
|
||||
|
||||
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
|
||||
result := &algorithms{}
|
||||
|
||||
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stoc, ctos := &result.w, &result.r
|
||||
if isClient {
|
||||
ctos, stoc = stoc, ctos
|
||||
}
|
||||
|
||||
ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !aeadCiphers[ctos.Cipher] {
|
||||
ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !aeadCiphers[stoc.Cipher] {
|
||||
stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// If rekeythreshold is too small, we can't make any progress sending
|
||||
// stuff.
|
||||
const minRekeyThreshold uint64 = 256
|
||||
|
||||
// Config contains configuration data common to both ServerConfig and
|
||||
// ClientConfig.
|
||||
type Config struct {
|
||||
// Rand provides the source of entropy for cryptographic
|
||||
// primitives. If Rand is nil, the cryptographic random reader
|
||||
// in package crypto/rand will be used.
|
||||
Rand io.Reader
|
||||
|
||||
// The maximum number of bytes sent or received after which a
|
||||
// new key is negotiated. It must be at least 256. If
|
||||
// unspecified, a size suitable for the chosen cipher is used.
|
||||
RekeyThreshold uint64
|
||||
|
||||
// The allowed key exchanges algorithms. If unspecified then a default set
|
||||
// of algorithms is used. Unsupported values are silently ignored.
|
||||
KeyExchanges []string
|
||||
|
||||
// The allowed cipher algorithms. If unspecified then a sensible default is
|
||||
// used. Unsupported values are silently ignored.
|
||||
Ciphers []string
|
||||
|
||||
// The allowed MAC algorithms. If unspecified then a sensible default is
|
||||
// used. Unsupported values are silently ignored.
|
||||
MACs []string
|
||||
}
|
||||
|
||||
// SetDefaults sets sensible values for unset fields in config. This is
|
||||
// exported for testing: Configs passed to SSH functions are copied and have
|
||||
// default values set automatically.
|
||||
func (c *Config) SetDefaults() {
|
||||
if c.Rand == nil {
|
||||
c.Rand = rand.Reader
|
||||
}
|
||||
if c.Ciphers == nil {
|
||||
c.Ciphers = preferredCiphers
|
||||
}
|
||||
var ciphers []string
|
||||
for _, c := range c.Ciphers {
|
||||
if cipherModes[c] != nil {
|
||||
// Ignore the cipher if we have no cipherModes definition.
|
||||
ciphers = append(ciphers, c)
|
||||
}
|
||||
}
|
||||
c.Ciphers = ciphers
|
||||
|
||||
if c.KeyExchanges == nil {
|
||||
c.KeyExchanges = preferredKexAlgos
|
||||
}
|
||||
var kexs []string
|
||||
for _, k := range c.KeyExchanges {
|
||||
if kexAlgoMap[k] != nil {
|
||||
// Ignore the KEX if we have no kexAlgoMap definition.
|
||||
kexs = append(kexs, k)
|
||||
}
|
||||
}
|
||||
c.KeyExchanges = kexs
|
||||
|
||||
if c.MACs == nil {
|
||||
c.MACs = supportedMACs
|
||||
}
|
||||
var macs []string
|
||||
for _, m := range c.MACs {
|
||||
if macModes[m] != nil {
|
||||
// Ignore the MAC if we have no macModes definition.
|
||||
macs = append(macs, m)
|
||||
}
|
||||
}
|
||||
c.MACs = macs
|
||||
|
||||
if c.RekeyThreshold == 0 {
|
||||
// cipher specific default
|
||||
} else if c.RekeyThreshold < minRekeyThreshold {
|
||||
c.RekeyThreshold = minRekeyThreshold
|
||||
} else if c.RekeyThreshold >= math.MaxInt64 {
|
||||
// Avoid weirdness if somebody uses -1 as a threshold.
|
||||
c.RekeyThreshold = math.MaxInt64
|
||||
}
|
||||
}
|
||||
|
||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
||||
// possession of a private key. See RFC 4252, section 7. algo is the advertised
|
||||
// algorithm, and may be a certificate type.
|
||||
func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte {
|
||||
data := struct {
|
||||
Session []byte
|
||||
Type byte
|
||||
User string
|
||||
Service string
|
||||
Method string
|
||||
Sign bool
|
||||
Algo string
|
||||
PubKey []byte
|
||||
}{
|
||||
sessionID,
|
||||
msgUserAuthRequest,
|
||||
req.User,
|
||||
req.Service,
|
||||
req.Method,
|
||||
true,
|
||||
algo,
|
||||
pubKey,
|
||||
}
|
||||
return Marshal(data)
|
||||
}
|
||||
|
||||
func appendU16(buf []byte, n uint16) []byte {
|
||||
return append(buf, byte(n>>8), byte(n))
|
||||
}
|
||||
|
||||
func appendU32(buf []byte, n uint32) []byte {
|
||||
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
||||
}
|
||||
|
||||
func appendU64(buf []byte, n uint64) []byte {
|
||||
return append(buf,
|
||||
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32),
|
||||
byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
||||
}
|
||||
|
||||
func appendInt(buf []byte, n int) []byte {
|
||||
return appendU32(buf, uint32(n))
|
||||
}
|
||||
|
||||
func appendString(buf []byte, s string) []byte {
|
||||
buf = appendU32(buf, uint32(len(s)))
|
||||
buf = append(buf, s...)
|
||||
return buf
|
||||
}
|
||||
|
||||
func appendBool(buf []byte, b bool) []byte {
|
||||
if b {
|
||||
return append(buf, 1)
|
||||
}
|
||||
return append(buf, 0)
|
||||
}
|
||||
|
||||
// newCond is a helper to hide the fact that there is no usable zero
|
||||
// value for sync.Cond.
|
||||
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) }
|
||||
|
||||
// window represents the buffer available to clients
|
||||
// wishing to write to a channel.
|
||||
type window struct {
|
||||
*sync.Cond
|
||||
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
||||
writeWaiters int
|
||||
closed bool
|
||||
}
|
||||
|
||||
// add adds win to the amount of window available
|
||||
// for consumers.
|
||||
func (w *window) add(win uint32) bool {
|
||||
// a zero sized window adjust is a noop.
|
||||
if win == 0 {
|
||||
return true
|
||||
}
|
||||
w.L.Lock()
|
||||
if w.win+win < win {
|
||||
w.L.Unlock()
|
||||
return false
|
||||
}
|
||||
w.win += win
|
||||
// It is unusual that multiple goroutines would be attempting to reserve
|
||||
// window space, but not guaranteed. Use broadcast to notify all waiters
|
||||
// that additional window is available.
|
||||
w.Broadcast()
|
||||
w.L.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
// close sets the window to closed, so all reservations fail
|
||||
// immediately.
|
||||
func (w *window) close() {
|
||||
w.L.Lock()
|
||||
w.closed = true
|
||||
w.Broadcast()
|
||||
w.L.Unlock()
|
||||
}
|
||||
|
||||
// reserve reserves win from the available window capacity.
|
||||
// If no capacity remains, reserve will block. reserve may
|
||||
// return less than requested.
|
||||
func (w *window) reserve(win uint32) (uint32, error) {
|
||||
var err error
|
||||
w.L.Lock()
|
||||
w.writeWaiters++
|
||||
w.Broadcast()
|
||||
for w.win == 0 && !w.closed {
|
||||
w.Wait()
|
||||
}
|
||||
w.writeWaiters--
|
||||
if w.win < win {
|
||||
win = w.win
|
||||
}
|
||||
w.win -= win
|
||||
if w.closed {
|
||||
err = io.EOF
|
||||
}
|
||||
w.L.Unlock()
|
||||
return win, err
|
||||
}
|
||||
|
||||
// waitWriterBlocked waits until some goroutine is blocked for further
|
||||
// writes. It is used in tests only.
|
||||
func (w *window) waitWriterBlocked() {
|
||||
w.Cond.L.Lock()
|
||||
for w.writeWaiters == 0 {
|
||||
w.Cond.Wait()
|
||||
}
|
||||
w.Cond.L.Unlock()
|
||||
}
|
||||
148
vendor/github.com/tailscale/golang-x-crypto/ssh/connection.go
generated
vendored
Normal file
148
vendor/github.com/tailscale/golang-x-crypto/ssh/connection.go
generated
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// OpenChannelError is returned if the other side rejects an
|
||||
// OpenChannel request.
|
||||
type OpenChannelError struct {
|
||||
Reason RejectionReason
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *OpenChannelError) Error() string {
|
||||
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
|
||||
}
|
||||
|
||||
// ConnMetadata holds metadata for the connection.
|
||||
type ConnMetadata interface {
|
||||
// User returns the user ID for this connection.
|
||||
User() string
|
||||
|
||||
// SessionID returns the session hash, also denoted by H.
|
||||
SessionID() []byte
|
||||
|
||||
// ClientVersion returns the client's version string as hashed
|
||||
// into the session ID.
|
||||
ClientVersion() []byte
|
||||
|
||||
// ServerVersion returns the server's version string as hashed
|
||||
// into the session ID.
|
||||
ServerVersion() []byte
|
||||
|
||||
// RemoteAddr returns the remote address for this connection.
|
||||
RemoteAddr() net.Addr
|
||||
|
||||
// LocalAddr returns the local address for this connection.
|
||||
LocalAddr() net.Addr
|
||||
|
||||
// SendAuthBanner sends a banner to the client. This is useful
|
||||
// for sending message during auth. It is only valid to call this
|
||||
// during the auth phase.
|
||||
SendAuthBanner(banner string) error
|
||||
}
|
||||
|
||||
// Conn represents an SSH connection for both server and client roles.
|
||||
// Conn is the basis for implementing an application layer, such
|
||||
// as ClientConn, which implements the traditional shell access for
|
||||
// clients.
|
||||
type Conn interface {
|
||||
ConnMetadata
|
||||
|
||||
// SendRequest sends a global request, and returns the
|
||||
// reply. If wantReply is true, it returns the response status
|
||||
// and payload. See also RFC 4254, section 4.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
|
||||
|
||||
// OpenChannel tries to open an channel. If the request is
|
||||
// rejected, it returns *OpenChannelError. On success it returns
|
||||
// the SSH Channel and a Go channel for incoming, out-of-band
|
||||
// requests. The Go channel must be serviced, or the
|
||||
// connection will hang.
|
||||
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error)
|
||||
|
||||
// Close closes the underlying network connection
|
||||
Close() error
|
||||
|
||||
// Wait blocks until the connection has shut down, and returns the
|
||||
// error causing the shutdown.
|
||||
Wait() error
|
||||
|
||||
// TODO(hanwen): consider exposing:
|
||||
// RequestKeyChange
|
||||
// Disconnect
|
||||
}
|
||||
|
||||
// DiscardRequests consumes and rejects all requests from the
|
||||
// passed-in channel.
|
||||
func DiscardRequests(in <-chan *Request) {
|
||||
for req := range in {
|
||||
if req.WantReply {
|
||||
req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A connection represents an incoming connection.
|
||||
type connection struct {
|
||||
transport *handshakeTransport
|
||||
sshConn
|
||||
|
||||
// The connection protocol.
|
||||
*mux
|
||||
}
|
||||
|
||||
func (c *connection) Close() error {
|
||||
return c.sshConn.conn.Close()
|
||||
}
|
||||
|
||||
// sshConn provides net.Conn metadata, but disallows direct reads and
|
||||
// writes.
|
||||
type sshConn struct {
|
||||
conn net.Conn
|
||||
|
||||
user string
|
||||
sessionID []byte
|
||||
clientVersion []byte
|
||||
serverVersion []byte
|
||||
}
|
||||
|
||||
func dup(src []byte) []byte {
|
||||
dst := make([]byte, len(src))
|
||||
copy(dst, src)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (c *sshConn) User() string {
|
||||
return c.user
|
||||
}
|
||||
|
||||
func (c *sshConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *sshConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *sshConn) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *sshConn) SessionID() []byte {
|
||||
return dup(c.sessionID)
|
||||
}
|
||||
|
||||
func (c *sshConn) ClientVersion() []byte {
|
||||
return dup(c.clientVersion)
|
||||
}
|
||||
|
||||
func (c *sshConn) ServerVersion() []byte {
|
||||
return dup(c.serverVersion)
|
||||
}
|
||||
23
vendor/github.com/tailscale/golang-x-crypto/ssh/doc.go
generated
vendored
Normal file
23
vendor/github.com/tailscale/golang-x-crypto/ssh/doc.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package ssh implements an SSH client and server.
|
||||
|
||||
SSH is a transport security protocol, an authentication protocol and a
|
||||
family of application protocols. The most typical application level
|
||||
protocol is a remote shell and this is specifically implemented. However,
|
||||
the multiplexed nature of SSH is exposed to users that wish to support
|
||||
others.
|
||||
|
||||
References:
|
||||
|
||||
[PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD
|
||||
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
|
||||
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
||||
|
||||
This package does not fall under the stability promise of the Go language itself,
|
||||
so its API may be changed when pressing needs arise.
|
||||
*/
|
||||
package ssh // import "golang.org/x/crypto/ssh"
|
||||
806
vendor/github.com/tailscale/golang-x-crypto/ssh/handshake.go
generated
vendored
Normal file
806
vendor/github.com/tailscale/golang-x-crypto/ssh/handshake.go
generated
vendored
Normal file
@@ -0,0 +1,806 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// debugHandshake, if set, prints messages sent and received. Key
|
||||
// exchange messages are printed as if DH were used, so the debug
|
||||
// messages are wrong when using ECDH.
|
||||
const debugHandshake = false
|
||||
|
||||
// chanSize sets the amount of buffering SSH connections. This is
|
||||
// primarily for testing: setting chanSize=0 uncovers deadlocks more
|
||||
// quickly.
|
||||
const chanSize = 16
|
||||
|
||||
// keyingTransport is a packet based transport that supports key
|
||||
// changes. It need not be thread-safe. It should pass through
|
||||
// msgNewKeys in both directions.
|
||||
type keyingTransport interface {
|
||||
packetConn
|
||||
|
||||
// prepareKeyChange sets up a key change. The key change for a
|
||||
// direction will be effected if a msgNewKeys message is sent
|
||||
// or received.
|
||||
prepareKeyChange(*algorithms, *kexResult) error
|
||||
|
||||
// setStrictMode sets the strict KEX mode, notably triggering
|
||||
// sequence number resets on sending or receiving msgNewKeys.
|
||||
// If the sequence number is already > 1 when setStrictMode
|
||||
// is called, an error is returned.
|
||||
setStrictMode() error
|
||||
|
||||
// setInitialKEXDone indicates to the transport that the initial key exchange
|
||||
// was completed
|
||||
setInitialKEXDone()
|
||||
}
|
||||
|
||||
// handshakeTransport implements rekeying on top of a keyingTransport
|
||||
// and offers a thread-safe writePacket() interface.
|
||||
type handshakeTransport struct {
|
||||
conn keyingTransport
|
||||
config *Config
|
||||
|
||||
serverVersion []byte
|
||||
clientVersion []byte
|
||||
|
||||
// hostKeys is non-empty if we are the server. In that case,
|
||||
// it contains all host keys that can be used to sign the
|
||||
// connection.
|
||||
hostKeys []Signer
|
||||
|
||||
// publicKeyAuthAlgorithms is non-empty if we are the server. In that case,
|
||||
// it contains the supported client public key authentication algorithms.
|
||||
publicKeyAuthAlgorithms []string
|
||||
|
||||
// hostKeyAlgorithms is non-empty if we are the client. In that case,
|
||||
// we accept these key types from the server as host key.
|
||||
hostKeyAlgorithms []string
|
||||
|
||||
// On read error, incoming is closed, and readError is set.
|
||||
incoming chan []byte
|
||||
readError error
|
||||
|
||||
mu sync.Mutex
|
||||
writeError error
|
||||
sentInitPacket []byte
|
||||
sentInitMsg *kexInitMsg
|
||||
pendingPackets [][]byte // Used when a key exchange is in progress.
|
||||
writePacketsLeft uint32
|
||||
writeBytesLeft int64
|
||||
|
||||
// If the read loop wants to schedule a kex, it pings this
|
||||
// channel, and the write loop will send out a kex
|
||||
// message.
|
||||
requestKex chan struct{}
|
||||
|
||||
// If the other side requests or confirms a kex, its kexInit
|
||||
// packet is sent here for the write loop to find it.
|
||||
startKex chan *pendingKex
|
||||
kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits
|
||||
|
||||
// data for host key checking
|
||||
hostKeyCallback HostKeyCallback
|
||||
dialAddress string
|
||||
remoteAddr net.Addr
|
||||
|
||||
// bannerCallback is non-empty if we are the client and it has been set in
|
||||
// ClientConfig. In that case it is called during the user authentication
|
||||
// dance to handle a custom server's message.
|
||||
bannerCallback BannerCallback
|
||||
|
||||
// Algorithms agreed in the last key exchange.
|
||||
algorithms *algorithms
|
||||
|
||||
// Counters exclusively owned by readLoop.
|
||||
readPacketsLeft uint32
|
||||
readBytesLeft int64
|
||||
|
||||
// The session ID or nil if first kex did not complete yet.
|
||||
sessionID []byte
|
||||
|
||||
// strictMode indicates if the other side of the handshake indicated
|
||||
// that we should be following the strict KEX protocol restrictions.
|
||||
strictMode bool
|
||||
}
|
||||
|
||||
type pendingKex struct {
|
||||
otherInit []byte
|
||||
done chan error
|
||||
}
|
||||
|
||||
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
|
||||
t := &handshakeTransport{
|
||||
conn: conn,
|
||||
serverVersion: serverVersion,
|
||||
clientVersion: clientVersion,
|
||||
incoming: make(chan []byte, chanSize),
|
||||
requestKex: make(chan struct{}, 1),
|
||||
startKex: make(chan *pendingKex),
|
||||
kexLoopDone: make(chan struct{}),
|
||||
|
||||
config: config,
|
||||
}
|
||||
t.resetReadThresholds()
|
||||
t.resetWriteThresholds()
|
||||
|
||||
// We always start with a mandatory key exchange.
|
||||
t.requestKex <- struct{}{}
|
||||
return t
|
||||
}
|
||||
|
||||
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
|
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
|
||||
t.dialAddress = dialAddr
|
||||
t.remoteAddr = addr
|
||||
t.hostKeyCallback = config.HostKeyCallback
|
||||
t.bannerCallback = config.BannerCallback
|
||||
if config.HostKeyAlgorithms != nil {
|
||||
t.hostKeyAlgorithms = config.HostKeyAlgorithms
|
||||
} else {
|
||||
t.hostKeyAlgorithms = supportedHostKeyAlgos
|
||||
}
|
||||
go t.readLoop()
|
||||
go t.kexLoop()
|
||||
return t
|
||||
}
|
||||
|
||||
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
|
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
|
||||
t.hostKeys = config.hostKeys
|
||||
t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms
|
||||
go t.readLoop()
|
||||
go t.kexLoop()
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) getSessionID() []byte {
|
||||
return t.sessionID
|
||||
}
|
||||
|
||||
// waitSession waits for the session to be established. This should be
|
||||
// the first thing to call after instantiating handshakeTransport.
|
||||
func (t *handshakeTransport) waitSession() error {
|
||||
p, err := t.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if p[0] != msgNewKeys {
|
||||
return fmt.Errorf("ssh: first packet should be msgNewKeys")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) id() string {
|
||||
if len(t.hostKeys) > 0 {
|
||||
return "server"
|
||||
}
|
||||
return "client"
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) printPacket(p []byte, write bool) {
|
||||
action := "got"
|
||||
if write {
|
||||
action = "sent"
|
||||
}
|
||||
|
||||
if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
|
||||
log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
|
||||
} else {
|
||||
msg, err := decode(p)
|
||||
log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) readPacket() ([]byte, error) {
|
||||
p, ok := <-t.incoming
|
||||
if !ok {
|
||||
return nil, t.readError
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) readLoop() {
|
||||
first := true
|
||||
for {
|
||||
p, err := t.readOnePacket(first)
|
||||
first = false
|
||||
if err != nil {
|
||||
t.readError = err
|
||||
close(t.incoming)
|
||||
break
|
||||
}
|
||||
// If this is the first kex, and strict KEX mode is enabled,
|
||||
// we don't ignore any messages, as they may be used to manipulate
|
||||
// the packet sequence numbers.
|
||||
if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
|
||||
continue
|
||||
}
|
||||
t.incoming <- p
|
||||
}
|
||||
|
||||
// Stop writers too.
|
||||
t.recordWriteError(t.readError)
|
||||
|
||||
// Unblock the writer should it wait for this.
|
||||
close(t.startKex)
|
||||
|
||||
// Don't close t.requestKex; it's also written to from writePacket.
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) pushPacket(p []byte) error {
|
||||
if debugHandshake {
|
||||
t.printPacket(p, true)
|
||||
}
|
||||
return t.conn.writePacket(p)
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) getWriteError() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.writeError
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) recordWriteError(err error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.writeError == nil && err != nil {
|
||||
t.writeError = err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) requestKeyExchange() {
|
||||
select {
|
||||
case t.requestKex <- struct{}{}:
|
||||
default:
|
||||
// something already requested a kex, so do nothing.
|
||||
}
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) resetWriteThresholds() {
|
||||
t.writePacketsLeft = packetRekeyThreshold
|
||||
if t.config.RekeyThreshold > 0 {
|
||||
t.writeBytesLeft = int64(t.config.RekeyThreshold)
|
||||
} else if t.algorithms != nil {
|
||||
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
|
||||
} else {
|
||||
t.writeBytesLeft = 1 << 30
|
||||
}
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) kexLoop() {
|
||||
|
||||
write:
|
||||
for t.getWriteError() == nil {
|
||||
var request *pendingKex
|
||||
var sent bool
|
||||
|
||||
for request == nil || !sent {
|
||||
var ok bool
|
||||
select {
|
||||
case request, ok = <-t.startKex:
|
||||
if !ok {
|
||||
break write
|
||||
}
|
||||
case <-t.requestKex:
|
||||
break
|
||||
}
|
||||
|
||||
if !sent {
|
||||
if err := t.sendKexInit(); err != nil {
|
||||
t.recordWriteError(err)
|
||||
break
|
||||
}
|
||||
sent = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := t.getWriteError(); err != nil {
|
||||
if request != nil {
|
||||
request.done <- err
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// We're not servicing t.requestKex, but that is OK:
|
||||
// we never block on sending to t.requestKex.
|
||||
|
||||
// We're not servicing t.startKex, but the remote end
|
||||
// has just sent us a kexInitMsg, so it can't send
|
||||
// another key change request, until we close the done
|
||||
// channel on the pendingKex request.
|
||||
|
||||
err := t.enterKeyExchange(request.otherInit)
|
||||
|
||||
t.mu.Lock()
|
||||
t.writeError = err
|
||||
t.sentInitPacket = nil
|
||||
t.sentInitMsg = nil
|
||||
|
||||
t.resetWriteThresholds()
|
||||
|
||||
// we have completed the key exchange. Since the
|
||||
// reader is still blocked, it is safe to clear out
|
||||
// the requestKex channel. This avoids the situation
|
||||
// where: 1) we consumed our own request for the
|
||||
// initial kex, and 2) the kex from the remote side
|
||||
// caused another send on the requestKex channel,
|
||||
clear:
|
||||
for {
|
||||
select {
|
||||
case <-t.requestKex:
|
||||
//
|
||||
default:
|
||||
break clear
|
||||
}
|
||||
}
|
||||
|
||||
request.done <- t.writeError
|
||||
|
||||
// kex finished. Push packets that we received while
|
||||
// the kex was in progress. Don't look at t.startKex
|
||||
// and don't increment writtenSinceKex: if we trigger
|
||||
// another kex while we are still busy with the last
|
||||
// one, things will become very confusing.
|
||||
for _, p := range t.pendingPackets {
|
||||
t.writeError = t.pushPacket(p)
|
||||
if t.writeError != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
t.pendingPackets = t.pendingPackets[:0]
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// Unblock reader.
|
||||
t.conn.Close()
|
||||
|
||||
// drain startKex channel. We don't service t.requestKex
|
||||
// because nobody does blocking sends there.
|
||||
for request := range t.startKex {
|
||||
request.done <- t.getWriteError()
|
||||
}
|
||||
|
||||
// Mark that the loop is done so that Close can return.
|
||||
close(t.kexLoopDone)
|
||||
}
|
||||
|
||||
// The protocol uses uint32 for packet counters, so we can't let them
|
||||
// reach 1<<32. We will actually read and write more packets than
|
||||
// this, though: the other side may send more packets, and after we
|
||||
// hit this limit on writing we will send a few more packets for the
|
||||
// key exchange itself.
|
||||
const packetRekeyThreshold = (1 << 31)
|
||||
|
||||
func (t *handshakeTransport) resetReadThresholds() {
|
||||
t.readPacketsLeft = packetRekeyThreshold
|
||||
if t.config.RekeyThreshold > 0 {
|
||||
t.readBytesLeft = int64(t.config.RekeyThreshold)
|
||||
} else if t.algorithms != nil {
|
||||
t.readBytesLeft = t.algorithms.r.rekeyBytes()
|
||||
} else {
|
||||
t.readBytesLeft = 1 << 30
|
||||
}
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
|
||||
p, err := t.conn.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.readPacketsLeft > 0 {
|
||||
t.readPacketsLeft--
|
||||
} else {
|
||||
t.requestKeyExchange()
|
||||
}
|
||||
|
||||
if t.readBytesLeft > 0 {
|
||||
t.readBytesLeft -= int64(len(p))
|
||||
} else {
|
||||
t.requestKeyExchange()
|
||||
}
|
||||
|
||||
if debugHandshake {
|
||||
t.printPacket(p, false)
|
||||
}
|
||||
|
||||
if first && p[0] != msgKexInit {
|
||||
return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
|
||||
}
|
||||
|
||||
if p[0] != msgKexInit {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
firstKex := t.sessionID == nil
|
||||
|
||||
kex := pendingKex{
|
||||
done: make(chan error, 1),
|
||||
otherInit: p,
|
||||
}
|
||||
t.startKex <- &kex
|
||||
err = <-kex.done
|
||||
|
||||
if debugHandshake {
|
||||
log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.resetReadThresholds()
|
||||
|
||||
// By default, a key exchange is hidden from higher layers by
|
||||
// translating it into msgIgnore.
|
||||
successPacket := []byte{msgIgnore}
|
||||
if firstKex {
|
||||
// sendKexInit() for the first kex waits for
|
||||
// msgNewKeys so the authentication process is
|
||||
// guaranteed to happen over an encrypted transport.
|
||||
successPacket = []byte{msgNewKeys}
|
||||
}
|
||||
|
||||
return successPacket, nil
|
||||
}
|
||||
|
||||
const (
|
||||
kexStrictClient = "kex-strict-c-v00@openssh.com"
|
||||
kexStrictServer = "kex-strict-s-v00@openssh.com"
|
||||
)
|
||||
|
||||
// sendKexInit sends a key change message.
|
||||
func (t *handshakeTransport) sendKexInit() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.sentInitMsg != nil {
|
||||
// kexInits may be sent either in response to the other side,
|
||||
// or because our side wants to initiate a key change, so we
|
||||
// may have already sent a kexInit. In that case, don't send a
|
||||
// second kexInit.
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := &kexInitMsg{
|
||||
CiphersClientServer: t.config.Ciphers,
|
||||
CiphersServerClient: t.config.Ciphers,
|
||||
MACsClientServer: t.config.MACs,
|
||||
MACsServerClient: t.config.MACs,
|
||||
CompressionClientServer: supportedCompressions,
|
||||
CompressionServerClient: supportedCompressions,
|
||||
}
|
||||
io.ReadFull(rand.Reader, msg.Cookie[:])
|
||||
|
||||
// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
|
||||
// and possibly to add the ext-info extension algorithm. Since the slice may be the
|
||||
// user owned KeyExchanges, we create our own slice in order to avoid using user
|
||||
// owned memory by mistake.
|
||||
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
|
||||
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
|
||||
|
||||
isServer := len(t.hostKeys) > 0
|
||||
if isServer {
|
||||
for _, k := range t.hostKeys {
|
||||
// If k is a MultiAlgorithmSigner, we restrict the signature
|
||||
// algorithms. If k is a AlgorithmSigner, presume it supports all
|
||||
// signature algorithms associated with the key format. If k is not
|
||||
// an AlgorithmSigner, we can only assume it only supports the
|
||||
// algorithms that matches the key format. (This means that Sign
|
||||
// can't pick a different default).
|
||||
keyFormat := k.PublicKey().Type()
|
||||
|
||||
switch s := k.(type) {
|
||||
case MultiAlgorithmSigner:
|
||||
for _, algo := range algorithmsForKeyFormat(keyFormat) {
|
||||
if contains(s.Algorithms(), underlyingAlgo(algo)) {
|
||||
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
|
||||
}
|
||||
}
|
||||
case AlgorithmSigner:
|
||||
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
|
||||
default:
|
||||
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
|
||||
}
|
||||
}
|
||||
|
||||
if t.sessionID == nil {
|
||||
msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
|
||||
}
|
||||
} else {
|
||||
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
|
||||
|
||||
// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
|
||||
// algorithms the server supports for public key authentication. See RFC
|
||||
// 8308, Section 2.1.
|
||||
//
|
||||
// We also send the strict KEX mode extension algorithm, in order to opt
|
||||
// into the strict KEX mode.
|
||||
if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
|
||||
msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
|
||||
msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
packet := Marshal(msg)
|
||||
|
||||
// writePacket destroys the contents, so save a copy.
|
||||
packetCopy := make([]byte, len(packet))
|
||||
copy(packetCopy, packet)
|
||||
|
||||
if err := t.pushPacket(packetCopy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.sentInitMsg = msg
|
||||
t.sentInitPacket = packet
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) writePacket(p []byte) error {
|
||||
switch p[0] {
|
||||
case msgKexInit:
|
||||
return errors.New("ssh: only handshakeTransport can send kexInit")
|
||||
case msgNewKeys:
|
||||
return errors.New("ssh: only handshakeTransport can send newKeys")
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.writeError != nil {
|
||||
return t.writeError
|
||||
}
|
||||
|
||||
if t.sentInitMsg != nil {
|
||||
// Copy the packet so the writer can reuse the buffer.
|
||||
cp := make([]byte, len(p))
|
||||
copy(cp, p)
|
||||
t.pendingPackets = append(t.pendingPackets, cp)
|
||||
return nil
|
||||
}
|
||||
|
||||
if t.writeBytesLeft > 0 {
|
||||
t.writeBytesLeft -= int64(len(p))
|
||||
} else {
|
||||
t.requestKeyExchange()
|
||||
}
|
||||
|
||||
if t.writePacketsLeft > 0 {
|
||||
t.writePacketsLeft--
|
||||
} else {
|
||||
t.requestKeyExchange()
|
||||
}
|
||||
|
||||
if err := t.pushPacket(p); err != nil {
|
||||
t.writeError = err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) Close() error {
|
||||
// Close the connection. This should cause the readLoop goroutine to wake up
|
||||
// and close t.startKex, which will shut down kexLoop if running.
|
||||
err := t.conn.Close()
|
||||
|
||||
// Wait for the kexLoop goroutine to complete.
|
||||
// At that point we know that the readLoop goroutine is complete too,
|
||||
// because kexLoop itself waits for readLoop to close the startKex channel.
|
||||
<-t.kexLoopDone
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
|
||||
if debugHandshake {
|
||||
log.Printf("%s entered key exchange", t.id())
|
||||
}
|
||||
|
||||
otherInit := &kexInitMsg{}
|
||||
if err := Unmarshal(otherInitPacket, otherInit); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
magics := handshakeMagics{
|
||||
clientVersion: t.clientVersion,
|
||||
serverVersion: t.serverVersion,
|
||||
clientKexInit: otherInitPacket,
|
||||
serverKexInit: t.sentInitPacket,
|
||||
}
|
||||
|
||||
clientInit := otherInit
|
||||
serverInit := t.sentInitMsg
|
||||
isClient := len(t.hostKeys) == 0
|
||||
if isClient {
|
||||
clientInit, serverInit = serverInit, clientInit
|
||||
|
||||
magics.clientKexInit = t.sentInitPacket
|
||||
magics.serverKexInit = otherInitPacket
|
||||
}
|
||||
|
||||
var err error
|
||||
t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
|
||||
t.strictMode = true
|
||||
if err := t.conn.setStrictMode(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// We don't send FirstKexFollows, but we handle receiving it.
|
||||
//
|
||||
// RFC 4253 section 7 defines the kex and the agreement method for
|
||||
// first_kex_packet_follows. It states that the guessed packet
|
||||
// should be ignored if the "kex algorithm and/or the host
|
||||
// key algorithm is guessed wrong (server and client have
|
||||
// different preferred algorithm), or if any of the other
|
||||
// algorithms cannot be agreed upon". The other algorithms have
|
||||
// already been checked above so the kex algorithm and host key
|
||||
// algorithm are checked here.
|
||||
if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
|
||||
// other side sent a kex message for the wrong algorithm,
|
||||
// which we have to ignore.
|
||||
if _, err := t.conn.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
kex, ok := kexAlgoMap[t.algorithms.kex]
|
||||
if !ok {
|
||||
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
|
||||
}
|
||||
|
||||
var result *kexResult
|
||||
if len(t.hostKeys) > 0 {
|
||||
result, err = t.server(kex, &magics)
|
||||
} else {
|
||||
result, err = t.client(kex, &magics)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
firstKeyExchange := t.sessionID == nil
|
||||
if firstKeyExchange {
|
||||
t.sessionID = result.H
|
||||
}
|
||||
result.SessionID = t.sessionID
|
||||
|
||||
if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
|
||||
// message with the server-sig-algs extension if the client supports it. See
|
||||
// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
|
||||
if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
|
||||
supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
|
||||
extInfo := &extInfoMsg{
|
||||
NumExtensions: 2,
|
||||
Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
|
||||
}
|
||||
extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
|
||||
extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
|
||||
extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
|
||||
extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
|
||||
extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
|
||||
extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
|
||||
extInfo.Payload = appendInt(extInfo.Payload, 1)
|
||||
extInfo.Payload = append(extInfo.Payload, "0"...)
|
||||
if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if packet, err := t.conn.readPacket(); err != nil {
|
||||
return err
|
||||
} else if packet[0] != msgNewKeys {
|
||||
return unexpectedMessageError(msgNewKeys, packet[0])
|
||||
}
|
||||
|
||||
if firstKeyExchange {
|
||||
// Indicates to the transport that the first key exchange is completed
|
||||
// after receiving SSH_MSG_NEWKEYS.
|
||||
t.conn.setInitialKEXDone()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// algorithmSignerWrapper is an AlgorithmSigner that only supports the default
|
||||
// key format algorithm.
|
||||
//
|
||||
// This is technically a violation of the AlgorithmSigner interface, but it
|
||||
// should be unreachable given where we use this. Anyway, at least it returns an
|
||||
// error instead of panicing or producing an incorrect signature.
|
||||
type algorithmSignerWrapper struct {
|
||||
Signer
|
||||
}
|
||||
|
||||
func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
|
||||
if algorithm != underlyingAlgo(a.PublicKey().Type()) {
|
||||
return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
|
||||
}
|
||||
return a.Sign(rand, data)
|
||||
}
|
||||
|
||||
func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
|
||||
for _, k := range hostKeys {
|
||||
if s, ok := k.(MultiAlgorithmSigner); ok {
|
||||
if !contains(s.Algorithms(), underlyingAlgo(algo)) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if algo == k.PublicKey().Type() {
|
||||
return algorithmSignerWrapper{k}
|
||||
}
|
||||
|
||||
k, ok := k.(AlgorithmSigner)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) {
|
||||
if algo == a {
|
||||
return k
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
|
||||
hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
|
||||
if hostKey == nil {
|
||||
return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
|
||||
}
|
||||
|
||||
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
|
||||
result, err := kex.Client(t.conn, t.config.Rand, magics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostKey, err := ParsePublicKey(result.HostKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
93
vendor/github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go
generated
vendored
Normal file
93
vendor/github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go
generated
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package bcrypt_pbkdf implements bcrypt_pbkdf(3) from OpenBSD.
|
||||
//
|
||||
// See https://flak.tedunangst.com/post/bcrypt-pbkdf and
|
||||
// https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/lib/libutil/bcrypt_pbkdf.c.
|
||||
package bcrypt_pbkdf
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"errors"
|
||||
"golang.org/x/crypto/blowfish"
|
||||
)
|
||||
|
||||
const blockSize = 32
|
||||
|
||||
// Key derives a key from the password, salt and rounds count, returning a
|
||||
// []byte of length keyLen that can be used as cryptographic key.
|
||||
func Key(password, salt []byte, rounds, keyLen int) ([]byte, error) {
|
||||
if rounds < 1 {
|
||||
return nil, errors.New("bcrypt_pbkdf: number of rounds is too small")
|
||||
}
|
||||
if len(password) == 0 {
|
||||
return nil, errors.New("bcrypt_pbkdf: empty password")
|
||||
}
|
||||
if len(salt) == 0 || len(salt) > 1<<20 {
|
||||
return nil, errors.New("bcrypt_pbkdf: bad salt length")
|
||||
}
|
||||
if keyLen > 1024 {
|
||||
return nil, errors.New("bcrypt_pbkdf: keyLen is too large")
|
||||
}
|
||||
|
||||
numBlocks := (keyLen + blockSize - 1) / blockSize
|
||||
key := make([]byte, numBlocks*blockSize)
|
||||
|
||||
h := sha512.New()
|
||||
h.Write(password)
|
||||
shapass := h.Sum(nil)
|
||||
|
||||
shasalt := make([]byte, 0, sha512.Size)
|
||||
cnt, tmp := make([]byte, 4), make([]byte, blockSize)
|
||||
for block := 1; block <= numBlocks; block++ {
|
||||
h.Reset()
|
||||
h.Write(salt)
|
||||
cnt[0] = byte(block >> 24)
|
||||
cnt[1] = byte(block >> 16)
|
||||
cnt[2] = byte(block >> 8)
|
||||
cnt[3] = byte(block)
|
||||
h.Write(cnt)
|
||||
bcryptHash(tmp, shapass, h.Sum(shasalt))
|
||||
|
||||
out := make([]byte, blockSize)
|
||||
copy(out, tmp)
|
||||
for i := 2; i <= rounds; i++ {
|
||||
h.Reset()
|
||||
h.Write(tmp)
|
||||
bcryptHash(tmp, shapass, h.Sum(shasalt))
|
||||
for j := 0; j < len(out); j++ {
|
||||
out[j] ^= tmp[j]
|
||||
}
|
||||
}
|
||||
|
||||
for i, v := range out {
|
||||
key[i*numBlocks+(block-1)] = v
|
||||
}
|
||||
}
|
||||
return key[:keyLen], nil
|
||||
}
|
||||
|
||||
var magic = []byte("OxychromaticBlowfishSwatDynamite")
|
||||
|
||||
func bcryptHash(out, shapass, shasalt []byte) {
|
||||
c, err := blowfish.NewSaltedCipher(shapass, shasalt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for i := 0; i < 64; i++ {
|
||||
blowfish.ExpandKey(shasalt, c)
|
||||
blowfish.ExpandKey(shapass, c)
|
||||
}
|
||||
copy(out, magic)
|
||||
for i := 0; i < 32; i += 8 {
|
||||
for j := 0; j < 64; j++ {
|
||||
c.Encrypt(out[i:i+8], out[i:i+8])
|
||||
}
|
||||
}
|
||||
// Swap bytes due to different endianness.
|
||||
for i := 0; i < 32; i += 4 {
|
||||
out[i+3], out[i+2], out[i+1], out[i] = out[i], out[i+1], out[i+2], out[i+3]
|
||||
}
|
||||
}
|
||||
786
vendor/github.com/tailscale/golang-x-crypto/ssh/kex.go
generated
vendored
Normal file
786
vendor/github.com/tailscale/golang-x-crypto/ssh/kex.go
generated
vendored
Normal file
@@ -0,0 +1,786 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
const (
|
||||
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
|
||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
|
||||
kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256"
|
||||
kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512"
|
||||
kexAlgoECDH256 = "ecdh-sha2-nistp256"
|
||||
kexAlgoECDH384 = "ecdh-sha2-nistp384"
|
||||
kexAlgoECDH521 = "ecdh-sha2-nistp521"
|
||||
kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org"
|
||||
kexAlgoCurve25519SHA256 = "curve25519-sha256"
|
||||
|
||||
// For the following kex only the client half contains a production
|
||||
// ready implementation. The server half only consists of a minimal
|
||||
// implementation to satisfy the automated tests.
|
||||
kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1"
|
||||
kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
|
||||
)
|
||||
|
||||
// kexResult captures the outcome of a key exchange.
|
||||
type kexResult struct {
|
||||
// Session hash. See also RFC 4253, section 8.
|
||||
H []byte
|
||||
|
||||
// Shared secret. See also RFC 4253, section 8.
|
||||
K []byte
|
||||
|
||||
// Host key as hashed into H.
|
||||
HostKey []byte
|
||||
|
||||
// Signature of H.
|
||||
Signature []byte
|
||||
|
||||
// A cryptographic hash function that matches the security
|
||||
// level of the key exchange algorithm. It is used for
|
||||
// calculating H, and for deriving keys from H and K.
|
||||
Hash crypto.Hash
|
||||
|
||||
// The session ID, which is the first H computed. This is used
|
||||
// to derive key material inside the transport.
|
||||
SessionID []byte
|
||||
}
|
||||
|
||||
// handshakeMagics contains data that is always included in the
|
||||
// session hash.
|
||||
type handshakeMagics struct {
|
||||
clientVersion, serverVersion []byte
|
||||
clientKexInit, serverKexInit []byte
|
||||
}
|
||||
|
||||
func (m *handshakeMagics) write(w io.Writer) {
|
||||
writeString(w, m.clientVersion)
|
||||
writeString(w, m.serverVersion)
|
||||
writeString(w, m.clientKexInit)
|
||||
writeString(w, m.serverKexInit)
|
||||
}
|
||||
|
||||
// kexAlgorithm abstracts different key exchange algorithms.
|
||||
type kexAlgorithm interface {
|
||||
// Server runs server-side key agreement, signing the result
|
||||
// with a hostkey. algo is the negotiated algorithm, and may
|
||||
// be a certificate type.
|
||||
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s AlgorithmSigner, algo string) (*kexResult, error)
|
||||
|
||||
// Client runs the client-side key agreement. Caller is
|
||||
// responsible for verifying the host key signature.
|
||||
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
|
||||
}
|
||||
|
||||
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
||||
type dhGroup struct {
|
||||
g, p, pMinus1 *big.Int
|
||||
hashFunc crypto.Hash
|
||||
}
|
||||
|
||||
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
|
||||
if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 {
|
||||
return nil, errors.New("ssh: DH parameter out of bounds")
|
||||
}
|
||||
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
|
||||
}
|
||||
|
||||
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
||||
var x *big.Int
|
||||
for {
|
||||
var err error
|
||||
if x, err = rand.Int(randSource, group.pMinus1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if x.Sign() > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
X := new(big.Int).Exp(group.g, x, group.p)
|
||||
kexDHInit := kexDHInitMsg{
|
||||
X: X,
|
||||
}
|
||||
if err := c.writePacket(Marshal(&kexDHInit)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var kexDHReply kexDHReplyMsg
|
||||
if err = Unmarshal(packet, &kexDHReply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ki, err := group.diffieHellman(kexDHReply.Y, x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h := group.hashFunc.New()
|
||||
magics.write(h)
|
||||
writeString(h, kexDHReply.HostKey)
|
||||
writeInt(h, X)
|
||||
writeInt(h, kexDHReply.Y)
|
||||
K := make([]byte, intLength(ki))
|
||||
marshalInt(K, ki)
|
||||
h.Write(K)
|
||||
|
||||
return &kexResult{
|
||||
H: h.Sum(nil),
|
||||
K: K,
|
||||
HostKey: kexDHReply.HostKey,
|
||||
Signature: kexDHReply.Signature,
|
||||
Hash: group.hashFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var kexDHInit kexDHInitMsg
|
||||
if err = Unmarshal(packet, &kexDHInit); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var y *big.Int
|
||||
for {
|
||||
if y, err = rand.Int(randSource, group.pMinus1); err != nil {
|
||||
return
|
||||
}
|
||||
if y.Sign() > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
Y := new(big.Int).Exp(group.g, y, group.p)
|
||||
ki, err := group.diffieHellman(kexDHInit.X, y)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal()
|
||||
|
||||
h := group.hashFunc.New()
|
||||
magics.write(h)
|
||||
writeString(h, hostKeyBytes)
|
||||
writeInt(h, kexDHInit.X)
|
||||
writeInt(h, Y)
|
||||
|
||||
K := make([]byte, intLength(ki))
|
||||
marshalInt(K, ki)
|
||||
h.Write(K)
|
||||
|
||||
H := h.Sum(nil)
|
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, randSource, H, algo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kexDHReply := kexDHReplyMsg{
|
||||
HostKey: hostKeyBytes,
|
||||
Y: Y,
|
||||
Signature: sig,
|
||||
}
|
||||
packet = Marshal(&kexDHReply)
|
||||
|
||||
err = c.writePacket(packet)
|
||||
return &kexResult{
|
||||
H: H,
|
||||
K: K,
|
||||
HostKey: hostKeyBytes,
|
||||
Signature: sig,
|
||||
Hash: group.hashFunc,
|
||||
}, err
|
||||
}
|
||||
|
||||
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
|
||||
// described in RFC 5656, section 4.
|
||||
type ecdh struct {
|
||||
curve elliptic.Curve
|
||||
}
|
||||
|
||||
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kexInit := kexECDHInitMsg{
|
||||
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
|
||||
}
|
||||
|
||||
serialized := Marshal(&kexInit)
|
||||
if err := c.writePacket(serialized); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reply kexECDHReplyMsg
|
||||
if err = Unmarshal(packet, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes())
|
||||
|
||||
h := ecHash(kex.curve).New()
|
||||
magics.write(h)
|
||||
writeString(h, reply.HostKey)
|
||||
writeString(h, kexInit.ClientPubKey)
|
||||
writeString(h, reply.EphemeralPubKey)
|
||||
K := make([]byte, intLength(secret))
|
||||
marshalInt(K, secret)
|
||||
h.Write(K)
|
||||
|
||||
return &kexResult{
|
||||
H: h.Sum(nil),
|
||||
K: K,
|
||||
HostKey: reply.HostKey,
|
||||
Signature: reply.Signature,
|
||||
Hash: ecHash(kex.curve),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unmarshalECKey parses and checks an EC key.
|
||||
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) {
|
||||
x, y = elliptic.Unmarshal(curve, pubkey)
|
||||
if x == nil {
|
||||
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure")
|
||||
}
|
||||
if !validateECPublicKey(curve, x, y) {
|
||||
return nil, nil, errors.New("ssh: public key not on curve")
|
||||
}
|
||||
return x, y, nil
|
||||
}
|
||||
|
||||
// validateECPublicKey checks that the point is a valid public key for
|
||||
// the given curve. See [SEC1], 3.2.2
|
||||
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
|
||||
if x.Sign() == 0 && y.Sign() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if x.Cmp(curve.Params().P) >= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if y.Cmp(curve.Params().P) >= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !curve.IsOnCurve(x, y) {
|
||||
return false
|
||||
}
|
||||
|
||||
// We don't check if N * PubKey == 0, since
|
||||
//
|
||||
// - the NIST curves have cofactor = 1, so this is implicit.
|
||||
// (We don't foresee an implementation that supports non NIST
|
||||
// curves)
|
||||
//
|
||||
// - for ephemeral keys, we don't need to worry about small
|
||||
// subgroup attacks.
|
||||
return true
|
||||
}
|
||||
|
||||
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var kexECDHInit kexECDHInitMsg
|
||||
if err = Unmarshal(packet, &kexECDHInit); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We could cache this key across multiple users/multiple
|
||||
// connection attempts, but the benefit is small. OpenSSH
|
||||
// generates a new key for each incoming connection.
|
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal()
|
||||
|
||||
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
|
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
|
||||
|
||||
h := ecHash(kex.curve).New()
|
||||
magics.write(h)
|
||||
writeString(h, hostKeyBytes)
|
||||
writeString(h, kexECDHInit.ClientPubKey)
|
||||
writeString(h, serializedEphKey)
|
||||
|
||||
K := make([]byte, intLength(secret))
|
||||
marshalInt(K, secret)
|
||||
h.Write(K)
|
||||
|
||||
H := h.Sum(nil)
|
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, rand, H, algo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply := kexECDHReplyMsg{
|
||||
EphemeralPubKey: serializedEphKey,
|
||||
HostKey: hostKeyBytes,
|
||||
Signature: sig,
|
||||
}
|
||||
|
||||
serialized := Marshal(&reply)
|
||||
if err := c.writePacket(serialized); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &kexResult{
|
||||
H: H,
|
||||
K: K,
|
||||
HostKey: reply.HostKey,
|
||||
Signature: sig,
|
||||
Hash: ecHash(kex.curve),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ecHash returns the hash to match the given elliptic curve, see RFC
|
||||
// 5656, section 6.2.1
|
||||
func ecHash(curve elliptic.Curve) crypto.Hash {
|
||||
bitSize := curve.Params().BitSize
|
||||
switch {
|
||||
case bitSize <= 256:
|
||||
return crypto.SHA256
|
||||
case bitSize <= 384:
|
||||
return crypto.SHA384
|
||||
}
|
||||
return crypto.SHA512
|
||||
}
|
||||
|
||||
var kexAlgoMap = map[string]kexAlgorithm{}
|
||||
|
||||
func init() {
|
||||
// This is the group called diffie-hellman-group1-sha1 in
|
||||
// RFC 4253 and Oakley Group 2 in RFC 2409.
|
||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
|
||||
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
|
||||
g: new(big.Int).SetInt64(2),
|
||||
p: p,
|
||||
pMinus1: new(big.Int).Sub(p, bigOne),
|
||||
hashFunc: crypto.SHA1,
|
||||
}
|
||||
|
||||
// This are the groups called diffie-hellman-group14-sha1 and
|
||||
// diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268,
|
||||
// and Oakley Group 14 in RFC 3526.
|
||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
|
||||
group14 := &dhGroup{
|
||||
g: new(big.Int).SetInt64(2),
|
||||
p: p,
|
||||
pMinus1: new(big.Int).Sub(p, bigOne),
|
||||
}
|
||||
|
||||
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
|
||||
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
|
||||
hashFunc: crypto.SHA1,
|
||||
}
|
||||
kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{
|
||||
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
|
||||
hashFunc: crypto.SHA256,
|
||||
}
|
||||
|
||||
// This is the group called diffie-hellman-group16-sha512 in RFC
|
||||
// 8268 and Oakley Group 16 in RFC 3526.
|
||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16)
|
||||
|
||||
kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{
|
||||
g: new(big.Int).SetInt64(2),
|
||||
p: p,
|
||||
pMinus1: new(big.Int).Sub(p, bigOne),
|
||||
hashFunc: crypto.SHA512,
|
||||
}
|
||||
|
||||
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
|
||||
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
|
||||
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
|
||||
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
|
||||
kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{}
|
||||
kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
|
||||
kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
|
||||
}
|
||||
|
||||
// curve25519sha256 implements the curve25519-sha256 (formerly known as
|
||||
// curve25519-sha256@libssh.org) key exchange method, as described in RFC 8731.
|
||||
type curve25519sha256 struct{}
|
||||
|
||||
type curve25519KeyPair struct {
|
||||
priv [32]byte
|
||||
pub [32]byte
|
||||
}
|
||||
|
||||
func (kp *curve25519KeyPair) generate(rand io.Reader) error {
|
||||
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
curve25519.ScalarBaseMult(&kp.pub, &kp.priv)
|
||||
return nil
|
||||
}
|
||||
|
||||
// curve25519Zeros is just an array of 32 zero bytes so that we have something
|
||||
// convenient to compare against in order to reject curve25519 points with the
|
||||
// wrong order.
|
||||
var curve25519Zeros [32]byte
|
||||
|
||||
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
||||
var kp curve25519KeyPair
|
||||
if err := kp.generate(rand); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reply kexECDHReplyMsg
|
||||
if err = Unmarshal(packet, &reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(reply.EphemeralPubKey) != 32 {
|
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
|
||||
}
|
||||
|
||||
var servPub, secret [32]byte
|
||||
copy(servPub[:], reply.EphemeralPubKey)
|
||||
curve25519.ScalarMult(&secret, &kp.priv, &servPub)
|
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
|
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
|
||||
}
|
||||
|
||||
h := crypto.SHA256.New()
|
||||
magics.write(h)
|
||||
writeString(h, reply.HostKey)
|
||||
writeString(h, kp.pub[:])
|
||||
writeString(h, reply.EphemeralPubKey)
|
||||
|
||||
ki := new(big.Int).SetBytes(secret[:])
|
||||
K := make([]byte, intLength(ki))
|
||||
marshalInt(K, ki)
|
||||
h.Write(K)
|
||||
|
||||
return &kexResult{
|
||||
H: h.Sum(nil),
|
||||
K: K,
|
||||
HostKey: reply.HostKey,
|
||||
Signature: reply.Signature,
|
||||
Hash: crypto.SHA256,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var kexInit kexECDHInitMsg
|
||||
if err = Unmarshal(packet, &kexInit); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(kexInit.ClientPubKey) != 32 {
|
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
|
||||
}
|
||||
|
||||
var kp curve25519KeyPair
|
||||
if err := kp.generate(rand); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var clientPub, secret [32]byte
|
||||
copy(clientPub[:], kexInit.ClientPubKey)
|
||||
curve25519.ScalarMult(&secret, &kp.priv, &clientPub)
|
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
|
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
|
||||
}
|
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal()
|
||||
|
||||
h := crypto.SHA256.New()
|
||||
magics.write(h)
|
||||
writeString(h, hostKeyBytes)
|
||||
writeString(h, kexInit.ClientPubKey)
|
||||
writeString(h, kp.pub[:])
|
||||
|
||||
ki := new(big.Int).SetBytes(secret[:])
|
||||
K := make([]byte, intLength(ki))
|
||||
marshalInt(K, ki)
|
||||
h.Write(K)
|
||||
|
||||
H := h.Sum(nil)
|
||||
|
||||
sig, err := signAndMarshal(priv, rand, H, algo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply := kexECDHReplyMsg{
|
||||
EphemeralPubKey: kp.pub[:],
|
||||
HostKey: hostKeyBytes,
|
||||
Signature: sig,
|
||||
}
|
||||
if err := c.writePacket(Marshal(&reply)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &kexResult{
|
||||
H: H,
|
||||
K: K,
|
||||
HostKey: hostKeyBytes,
|
||||
Signature: sig,
|
||||
Hash: crypto.SHA256,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and
|
||||
// diffie-hellman-group-exchange-sha256 key agreement protocols,
|
||||
// as described in RFC 4419
|
||||
type dhGEXSHA struct {
|
||||
hashFunc crypto.Hash
|
||||
}
|
||||
|
||||
const (
|
||||
dhGroupExchangeMinimumBits = 2048
|
||||
dhGroupExchangePreferredBits = 2048
|
||||
dhGroupExchangeMaximumBits = 8192
|
||||
)
|
||||
|
||||
func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
||||
// Send GexRequest
|
||||
kexDHGexRequest := kexDHGexRequestMsg{
|
||||
MinBits: dhGroupExchangeMinimumBits,
|
||||
PreferedBits: dhGroupExchangePreferredBits,
|
||||
MaxBits: dhGroupExchangeMaximumBits,
|
||||
}
|
||||
if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Receive GexGroup
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var msg kexDHGexGroupMsg
|
||||
if err = Unmarshal(packet, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits
|
||||
if msg.P.BitLen() < dhGroupExchangeMinimumBits || msg.P.BitLen() > dhGroupExchangeMaximumBits {
|
||||
return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", msg.P.BitLen())
|
||||
}
|
||||
|
||||
// Check if g is safe by verifying that 1 < g < p-1
|
||||
pMinusOne := new(big.Int).Sub(msg.P, bigOne)
|
||||
if msg.G.Cmp(bigOne) <= 0 || msg.G.Cmp(pMinusOne) >= 0 {
|
||||
return nil, fmt.Errorf("ssh: server provided gex g is not safe")
|
||||
}
|
||||
|
||||
// Send GexInit
|
||||
pHalf := new(big.Int).Rsh(msg.P, 1)
|
||||
x, err := rand.Int(randSource, pHalf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
X := new(big.Int).Exp(msg.G, x, msg.P)
|
||||
kexDHGexInit := kexDHGexInitMsg{
|
||||
X: X,
|
||||
}
|
||||
if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Receive GexReply
|
||||
packet, err = c.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var kexDHGexReply kexDHGexReplyMsg
|
||||
if err = Unmarshal(packet, &kexDHGexReply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if kexDHGexReply.Y.Cmp(bigOne) <= 0 || kexDHGexReply.Y.Cmp(pMinusOne) >= 0 {
|
||||
return nil, errors.New("ssh: DH parameter out of bounds")
|
||||
}
|
||||
kInt := new(big.Int).Exp(kexDHGexReply.Y, x, msg.P)
|
||||
|
||||
// Check if k is safe by verifying that k > 1 and k < p - 1
|
||||
if kInt.Cmp(bigOne) <= 0 || kInt.Cmp(pMinusOne) >= 0 {
|
||||
return nil, fmt.Errorf("ssh: derived k is not safe")
|
||||
}
|
||||
|
||||
h := gex.hashFunc.New()
|
||||
magics.write(h)
|
||||
writeString(h, kexDHGexReply.HostKey)
|
||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
|
||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
|
||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
|
||||
writeInt(h, msg.P)
|
||||
writeInt(h, msg.G)
|
||||
writeInt(h, X)
|
||||
writeInt(h, kexDHGexReply.Y)
|
||||
K := make([]byte, intLength(kInt))
|
||||
marshalInt(K, kInt)
|
||||
h.Write(K)
|
||||
|
||||
return &kexResult{
|
||||
H: h.Sum(nil),
|
||||
K: K,
|
||||
HostKey: kexDHGexReply.HostKey,
|
||||
Signature: kexDHGexReply.Signature,
|
||||
Hash: gex.hashFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
|
||||
//
|
||||
// This is a minimal implementation to satisfy the automated tests.
|
||||
func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
||||
// Receive GexRequest
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var kexDHGexRequest kexDHGexRequestMsg
|
||||
if err = Unmarshal(packet, &kexDHGexRequest); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send GexGroup
|
||||
// This is the group called diffie-hellman-group14-sha1 in RFC
|
||||
// 4253 and Oakley Group 14 in RFC 3526.
|
||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
|
||||
g := big.NewInt(2)
|
||||
|
||||
msg := &kexDHGexGroupMsg{
|
||||
P: p,
|
||||
G: g,
|
||||
}
|
||||
if err := c.writePacket(Marshal(msg)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Receive GexInit
|
||||
packet, err = c.readPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var kexDHGexInit kexDHGexInitMsg
|
||||
if err = Unmarshal(packet, &kexDHGexInit); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pHalf := new(big.Int).Rsh(p, 1)
|
||||
|
||||
y, err := rand.Int(randSource, pHalf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
Y := new(big.Int).Exp(g, y, p)
|
||||
|
||||
pMinusOne := new(big.Int).Sub(p, bigOne)
|
||||
if kexDHGexInit.X.Cmp(bigOne) <= 0 || kexDHGexInit.X.Cmp(pMinusOne) >= 0 {
|
||||
return nil, errors.New("ssh: DH parameter out of bounds")
|
||||
}
|
||||
kInt := new(big.Int).Exp(kexDHGexInit.X, y, p)
|
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal()
|
||||
|
||||
h := gex.hashFunc.New()
|
||||
magics.write(h)
|
||||
writeString(h, hostKeyBytes)
|
||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
|
||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
|
||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
|
||||
writeInt(h, p)
|
||||
writeInt(h, g)
|
||||
writeInt(h, kexDHGexInit.X)
|
||||
writeInt(h, Y)
|
||||
|
||||
K := make([]byte, intLength(kInt))
|
||||
marshalInt(K, kInt)
|
||||
h.Write(K)
|
||||
|
||||
H := h.Sum(nil)
|
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, randSource, H, algo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kexDHGexReply := kexDHGexReplyMsg{
|
||||
HostKey: hostKeyBytes,
|
||||
Y: Y,
|
||||
Signature: sig,
|
||||
}
|
||||
packet = Marshal(&kexDHGexReply)
|
||||
|
||||
err = c.writePacket(packet)
|
||||
|
||||
return &kexResult{
|
||||
H: H,
|
||||
K: K,
|
||||
HostKey: hostKeyBytes,
|
||||
Signature: sig,
|
||||
Hash: gex.hashFunc,
|
||||
}, err
|
||||
}
|
||||
1728
vendor/github.com/tailscale/golang-x-crypto/ssh/keys.go
generated
vendored
Normal file
1728
vendor/github.com/tailscale/golang-x-crypto/ssh/keys.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
68
vendor/github.com/tailscale/golang-x-crypto/ssh/mac.go
generated
vendored
Normal file
68
vendor/github.com/tailscale/golang-x-crypto/ssh/mac.go
generated
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
// Message authentication support
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"hash"
|
||||
)
|
||||
|
||||
type macMode struct {
|
||||
keySize int
|
||||
etm bool
|
||||
new func(key []byte) hash.Hash
|
||||
}
|
||||
|
||||
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
||||
// a given size.
|
||||
type truncatingMAC struct {
|
||||
length int
|
||||
hmac hash.Hash
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Write(data []byte) (int, error) {
|
||||
return t.hmac.Write(data)
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Sum(in []byte) []byte {
|
||||
out := t.hmac.Sum(in)
|
||||
return out[:len(in)+t.length]
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Reset() {
|
||||
t.hmac.Reset()
|
||||
}
|
||||
|
||||
func (t truncatingMAC) Size() int {
|
||||
return t.length
|
||||
}
|
||||
|
||||
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
|
||||
|
||||
var macModes = map[string]*macMode{
|
||||
"hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash {
|
||||
return hmac.New(sha512.New, key)
|
||||
}},
|
||||
"hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
|
||||
return hmac.New(sha256.New, key)
|
||||
}},
|
||||
"hmac-sha2-512": {64, false, func(key []byte) hash.Hash {
|
||||
return hmac.New(sha512.New, key)
|
||||
}},
|
||||
"hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
|
||||
return hmac.New(sha256.New, key)
|
||||
}},
|
||||
"hmac-sha1": {20, false, func(key []byte) hash.Hash {
|
||||
return hmac.New(sha1.New, key)
|
||||
}},
|
||||
"hmac-sha1-96": {20, false, func(key []byte) hash.Hash {
|
||||
return truncatingMAC{12, hmac.New(sha1.New, key)}
|
||||
}},
|
||||
}
|
||||
891
vendor/github.com/tailscale/golang-x-crypto/ssh/messages.go
generated
vendored
Normal file
891
vendor/github.com/tailscale/golang-x-crypto/ssh/messages.go
generated
vendored
Normal file
@@ -0,0 +1,891 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// These are SSH message type numbers. They are scattered around several
|
||||
// documents but many were taken from [SSH-PARAMETERS].
|
||||
const (
|
||||
msgIgnore = 2
|
||||
msgUnimplemented = 3
|
||||
msgDebug = 4
|
||||
msgNewKeys = 21
|
||||
)
|
||||
|
||||
// SSH messages:
|
||||
//
|
||||
// These structures mirror the wire format of the corresponding SSH messages.
|
||||
// They are marshaled using reflection with the marshal and unmarshal functions
|
||||
// in this file. The only wrinkle is that a final member of type []byte with a
|
||||
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
|
||||
|
||||
// See RFC 4253, section 11.1.
|
||||
const msgDisconnect = 1
|
||||
|
||||
// disconnectMsg is the message that signals a disconnect. It is also
|
||||
// the error type returned from mux.Wait()
|
||||
type disconnectMsg struct {
|
||||
Reason uint32 `sshtype:"1"`
|
||||
Message string
|
||||
Language string
|
||||
}
|
||||
|
||||
func (d *disconnectMsg) Error() string {
|
||||
return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
|
||||
}
|
||||
|
||||
// See RFC 4253, section 7.1.
|
||||
const msgKexInit = 20
|
||||
|
||||
type kexInitMsg struct {
|
||||
Cookie [16]byte `sshtype:"20"`
|
||||
KexAlgos []string
|
||||
ServerHostKeyAlgos []string
|
||||
CiphersClientServer []string
|
||||
CiphersServerClient []string
|
||||
MACsClientServer []string
|
||||
MACsServerClient []string
|
||||
CompressionClientServer []string
|
||||
CompressionServerClient []string
|
||||
LanguagesClientServer []string
|
||||
LanguagesServerClient []string
|
||||
FirstKexFollows bool
|
||||
Reserved uint32
|
||||
}
|
||||
|
||||
// See RFC 4253, section 8.
|
||||
|
||||
// Diffie-Hellman
|
||||
const msgKexDHInit = 30
|
||||
|
||||
type kexDHInitMsg struct {
|
||||
X *big.Int `sshtype:"30"`
|
||||
}
|
||||
|
||||
const msgKexECDHInit = 30
|
||||
|
||||
type kexECDHInitMsg struct {
|
||||
ClientPubKey []byte `sshtype:"30"`
|
||||
}
|
||||
|
||||
const msgKexECDHReply = 31
|
||||
|
||||
type kexECDHReplyMsg struct {
|
||||
HostKey []byte `sshtype:"31"`
|
||||
EphemeralPubKey []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
const msgKexDHReply = 31
|
||||
|
||||
type kexDHReplyMsg struct {
|
||||
HostKey []byte `sshtype:"31"`
|
||||
Y *big.Int
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
// See RFC 4419, section 5.
|
||||
const msgKexDHGexGroup = 31
|
||||
|
||||
type kexDHGexGroupMsg struct {
|
||||
P *big.Int `sshtype:"31"`
|
||||
G *big.Int
|
||||
}
|
||||
|
||||
const msgKexDHGexInit = 32
|
||||
|
||||
type kexDHGexInitMsg struct {
|
||||
X *big.Int `sshtype:"32"`
|
||||
}
|
||||
|
||||
const msgKexDHGexReply = 33
|
||||
|
||||
type kexDHGexReplyMsg struct {
|
||||
HostKey []byte `sshtype:"33"`
|
||||
Y *big.Int
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
const msgKexDHGexRequest = 34
|
||||
|
||||
type kexDHGexRequestMsg struct {
|
||||
MinBits uint32 `sshtype:"34"`
|
||||
PreferedBits uint32
|
||||
MaxBits uint32
|
||||
}
|
||||
|
||||
// See RFC 4253, section 10.
|
||||
const msgServiceRequest = 5
|
||||
|
||||
type serviceRequestMsg struct {
|
||||
Service string `sshtype:"5"`
|
||||
}
|
||||
|
||||
// See RFC 4253, section 10.
|
||||
const msgServiceAccept = 6
|
||||
|
||||
type serviceAcceptMsg struct {
|
||||
Service string `sshtype:"6"`
|
||||
}
|
||||
|
||||
// See RFC 8308, section 2.3
|
||||
const msgExtInfo = 7
|
||||
|
||||
type extInfoMsg struct {
|
||||
NumExtensions uint32 `sshtype:"7"`
|
||||
Payload []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4252, section 5.
|
||||
const msgUserAuthRequest = 50
|
||||
|
||||
type userAuthRequestMsg struct {
|
||||
User string `sshtype:"50"`
|
||||
Service string
|
||||
Method string
|
||||
Payload []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// Used for debug printouts of packets.
|
||||
type userAuthSuccessMsg struct {
|
||||
}
|
||||
|
||||
// See RFC 4252, section 5.1
|
||||
const msgUserAuthFailure = 51
|
||||
|
||||
type userAuthFailureMsg struct {
|
||||
Methods []string `sshtype:"51"`
|
||||
PartialSuccess bool
|
||||
}
|
||||
|
||||
// See RFC 4252, section 5.1
|
||||
const msgUserAuthSuccess = 52
|
||||
|
||||
// See RFC 4252, section 5.4
|
||||
const msgUserAuthBanner = 53
|
||||
|
||||
type userAuthBannerMsg struct {
|
||||
Message string `sshtype:"53"`
|
||||
// unused, but required to allow message parsing
|
||||
Language string
|
||||
}
|
||||
|
||||
// See RFC 4256, section 3.2
|
||||
const msgUserAuthInfoRequest = 60
|
||||
const msgUserAuthInfoResponse = 61
|
||||
|
||||
type userAuthInfoRequestMsg struct {
|
||||
Name string `sshtype:"60"`
|
||||
Instruction string
|
||||
Language string
|
||||
NumPrompts uint32
|
||||
Prompts []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpen = 90
|
||||
|
||||
type channelOpenMsg struct {
|
||||
ChanType string `sshtype:"90"`
|
||||
PeersID uint32
|
||||
PeersWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
const msgChannelExtendedData = 95
|
||||
const msgChannelData = 94
|
||||
|
||||
// Used for debug print outs of packets.
|
||||
type channelDataMsg struct {
|
||||
PeersID uint32 `sshtype:"94"`
|
||||
Length uint32
|
||||
Rest []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenConfirm = 91
|
||||
|
||||
type channelOpenConfirmMsg struct {
|
||||
PeersID uint32 `sshtype:"91"`
|
||||
MyID uint32
|
||||
MyWindow uint32
|
||||
MaxPacketSize uint32
|
||||
TypeSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.1.
|
||||
const msgChannelOpenFailure = 92
|
||||
|
||||
type channelOpenFailureMsg struct {
|
||||
PeersID uint32 `sshtype:"92"`
|
||||
Reason RejectionReason
|
||||
Message string
|
||||
Language string
|
||||
}
|
||||
|
||||
const msgChannelRequest = 98
|
||||
|
||||
type channelRequestMsg struct {
|
||||
PeersID uint32 `sshtype:"98"`
|
||||
Request string
|
||||
WantReply bool
|
||||
RequestSpecificData []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelSuccess = 99
|
||||
|
||||
type channelRequestSuccessMsg struct {
|
||||
PeersID uint32 `sshtype:"99"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.4.
|
||||
const msgChannelFailure = 100
|
||||
|
||||
type channelRequestFailureMsg struct {
|
||||
PeersID uint32 `sshtype:"100"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelClose = 97
|
||||
|
||||
type channelCloseMsg struct {
|
||||
PeersID uint32 `sshtype:"97"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.3
|
||||
const msgChannelEOF = 96
|
||||
|
||||
type channelEOFMsg struct {
|
||||
PeersID uint32 `sshtype:"96"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgGlobalRequest = 80
|
||||
|
||||
type globalRequestMsg struct {
|
||||
Type string `sshtype:"80"`
|
||||
WantReply bool
|
||||
Data []byte `ssh:"rest"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgRequestSuccess = 81
|
||||
|
||||
type globalRequestSuccessMsg struct {
|
||||
Data []byte `ssh:"rest" sshtype:"81"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 4
|
||||
const msgRequestFailure = 82
|
||||
|
||||
type globalRequestFailureMsg struct {
|
||||
Data []byte `ssh:"rest" sshtype:"82"`
|
||||
}
|
||||
|
||||
// See RFC 4254, section 5.2
|
||||
const msgChannelWindowAdjust = 93
|
||||
|
||||
type windowAdjustMsg struct {
|
||||
PeersID uint32 `sshtype:"93"`
|
||||
AdditionalBytes uint32
|
||||
}
|
||||
|
||||
// See RFC 4252, section 7
|
||||
const msgUserAuthPubKeyOk = 60
|
||||
|
||||
type userAuthPubKeyOkMsg struct {
|
||||
Algo string `sshtype:"60"`
|
||||
PubKey []byte
|
||||
}
|
||||
|
||||
// See RFC 4462, section 3
|
||||
const msgUserAuthGSSAPIResponse = 60
|
||||
|
||||
type userAuthGSSAPIResponse struct {
|
||||
SupportMech []byte `sshtype:"60"`
|
||||
}
|
||||
|
||||
const msgUserAuthGSSAPIToken = 61
|
||||
|
||||
type userAuthGSSAPIToken struct {
|
||||
Token []byte `sshtype:"61"`
|
||||
}
|
||||
|
||||
const msgUserAuthGSSAPIMIC = 66
|
||||
|
||||
type userAuthGSSAPIMIC struct {
|
||||
MIC []byte `sshtype:"66"`
|
||||
}
|
||||
|
||||
// See RFC 4462, section 3.9
|
||||
const msgUserAuthGSSAPIErrTok = 64
|
||||
|
||||
type userAuthGSSAPIErrTok struct {
|
||||
ErrorToken []byte `sshtype:"64"`
|
||||
}
|
||||
|
||||
// See RFC 4462, section 3.8
|
||||
const msgUserAuthGSSAPIError = 65
|
||||
|
||||
type userAuthGSSAPIError struct {
|
||||
MajorStatus uint32 `sshtype:"65"`
|
||||
MinorStatus uint32
|
||||
Message string
|
||||
LanguageTag string
|
||||
}
|
||||
|
||||
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
|
||||
const msgPing = 192
|
||||
|
||||
type pingMsg struct {
|
||||
Data string `sshtype:"192"`
|
||||
}
|
||||
|
||||
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
|
||||
const msgPong = 193
|
||||
|
||||
type pongMsg struct {
|
||||
Data string `sshtype:"193"`
|
||||
}
|
||||
|
||||
// typeTags returns the possible type bytes for the given reflect.Type, which
|
||||
// should be a struct. The possible values are separated by a '|' character.
|
||||
func typeTags(structType reflect.Type) (tags []byte) {
|
||||
tagStr := structType.Field(0).Tag.Get("sshtype")
|
||||
|
||||
for _, tag := range strings.Split(tagStr, "|") {
|
||||
i, err := strconv.Atoi(tag)
|
||||
if err == nil {
|
||||
tags = append(tags, byte(i))
|
||||
}
|
||||
}
|
||||
|
||||
return tags
|
||||
}
|
||||
|
||||
func fieldError(t reflect.Type, field int, problem string) error {
|
||||
if problem != "" {
|
||||
problem = ": " + problem
|
||||
}
|
||||
return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem)
|
||||
}
|
||||
|
||||
var errShortRead = errors.New("ssh: short read")
|
||||
|
||||
// Unmarshal parses data in SSH wire format into a structure. The out
|
||||
// argument should be a pointer to struct. If the first member of the
|
||||
// struct has the "sshtype" tag set to a '|'-separated set of numbers
|
||||
// in decimal, the packet must start with one of those numbers. In
|
||||
// case of error, Unmarshal returns a ParseError or
|
||||
// UnexpectedMessageError.
|
||||
func Unmarshal(data []byte, out interface{}) error {
|
||||
v := reflect.ValueOf(out).Elem()
|
||||
structType := v.Type()
|
||||
expectedTypes := typeTags(structType)
|
||||
|
||||
var expectedType byte
|
||||
if len(expectedTypes) > 0 {
|
||||
expectedType = expectedTypes[0]
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return parseError(expectedType)
|
||||
}
|
||||
|
||||
if len(expectedTypes) > 0 {
|
||||
goodType := false
|
||||
for _, e := range expectedTypes {
|
||||
if e > 0 && data[0] == e {
|
||||
goodType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !goodType {
|
||||
return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes)
|
||||
}
|
||||
data = data[1:]
|
||||
}
|
||||
|
||||
var ok bool
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
t := field.Type()
|
||||
switch t.Kind() {
|
||||
case reflect.Bool:
|
||||
if len(data) < 1 {
|
||||
return errShortRead
|
||||
}
|
||||
field.SetBool(data[0] != 0)
|
||||
data = data[1:]
|
||||
case reflect.Array:
|
||||
if t.Elem().Kind() != reflect.Uint8 {
|
||||
return fieldError(structType, i, "array of unsupported type")
|
||||
}
|
||||
if len(data) < t.Len() {
|
||||
return errShortRead
|
||||
}
|
||||
for j, n := 0, t.Len(); j < n; j++ {
|
||||
field.Index(j).Set(reflect.ValueOf(data[j]))
|
||||
}
|
||||
data = data[t.Len():]
|
||||
case reflect.Uint64:
|
||||
var u64 uint64
|
||||
if u64, data, ok = parseUint64(data); !ok {
|
||||
return errShortRead
|
||||
}
|
||||
field.SetUint(u64)
|
||||
case reflect.Uint32:
|
||||
var u32 uint32
|
||||
if u32, data, ok = parseUint32(data); !ok {
|
||||
return errShortRead
|
||||
}
|
||||
field.SetUint(uint64(u32))
|
||||
case reflect.Uint8:
|
||||
if len(data) < 1 {
|
||||
return errShortRead
|
||||
}
|
||||
field.SetUint(uint64(data[0]))
|
||||
data = data[1:]
|
||||
case reflect.String:
|
||||
var s []byte
|
||||
if s, data, ok = parseString(data); !ok {
|
||||
return fieldError(structType, i, "")
|
||||
}
|
||||
field.SetString(string(s))
|
||||
case reflect.Slice:
|
||||
switch t.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
if structType.Field(i).Tag.Get("ssh") == "rest" {
|
||||
field.Set(reflect.ValueOf(data))
|
||||
data = nil
|
||||
} else {
|
||||
var s []byte
|
||||
if s, data, ok = parseString(data); !ok {
|
||||
return errShortRead
|
||||
}
|
||||
field.Set(reflect.ValueOf(s))
|
||||
}
|
||||
case reflect.String:
|
||||
var nl []string
|
||||
if nl, data, ok = parseNameList(data); !ok {
|
||||
return errShortRead
|
||||
}
|
||||
field.Set(reflect.ValueOf(nl))
|
||||
default:
|
||||
return fieldError(structType, i, "slice of unsupported type")
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if t == bigIntType {
|
||||
var n *big.Int
|
||||
if n, data, ok = parseInt(data); !ok {
|
||||
return errShortRead
|
||||
}
|
||||
field.Set(reflect.ValueOf(n))
|
||||
} else {
|
||||
return fieldError(structType, i, "pointer to unsupported type")
|
||||
}
|
||||
default:
|
||||
return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t))
|
||||
}
|
||||
}
|
||||
|
||||
if len(data) != 0 {
|
||||
return parseError(expectedType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal serializes the message in msg to SSH wire format. The msg
|
||||
// argument should be a struct or pointer to struct. If the first
|
||||
// member has the "sshtype" tag set to a number in decimal, that
|
||||
// number is prepended to the result. If the last of member has the
|
||||
// "ssh" tag set to "rest", its contents are appended to the output.
|
||||
func Marshal(msg interface{}) []byte {
|
||||
out := make([]byte, 0, 64)
|
||||
return marshalStruct(out, msg)
|
||||
}
|
||||
|
||||
func marshalStruct(out []byte, msg interface{}) []byte {
|
||||
v := reflect.Indirect(reflect.ValueOf(msg))
|
||||
msgTypes := typeTags(v.Type())
|
||||
if len(msgTypes) > 0 {
|
||||
out = append(out, msgTypes[0])
|
||||
}
|
||||
|
||||
for i, n := 0, v.NumField(); i < n; i++ {
|
||||
field := v.Field(i)
|
||||
switch t := field.Type(); t.Kind() {
|
||||
case reflect.Bool:
|
||||
var v uint8
|
||||
if field.Bool() {
|
||||
v = 1
|
||||
}
|
||||
out = append(out, v)
|
||||
case reflect.Array:
|
||||
if t.Elem().Kind() != reflect.Uint8 {
|
||||
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
|
||||
}
|
||||
for j, l := 0, t.Len(); j < l; j++ {
|
||||
out = append(out, uint8(field.Index(j).Uint()))
|
||||
}
|
||||
case reflect.Uint32:
|
||||
out = appendU32(out, uint32(field.Uint()))
|
||||
case reflect.Uint64:
|
||||
out = appendU64(out, uint64(field.Uint()))
|
||||
case reflect.Uint8:
|
||||
out = append(out, uint8(field.Uint()))
|
||||
case reflect.String:
|
||||
s := field.String()
|
||||
out = appendInt(out, len(s))
|
||||
out = append(out, s...)
|
||||
case reflect.Slice:
|
||||
switch t.Elem().Kind() {
|
||||
case reflect.Uint8:
|
||||
if v.Type().Field(i).Tag.Get("ssh") != "rest" {
|
||||
out = appendInt(out, field.Len())
|
||||
}
|
||||
out = append(out, field.Bytes()...)
|
||||
case reflect.String:
|
||||
offset := len(out)
|
||||
out = appendU32(out, 0)
|
||||
if n := field.Len(); n > 0 {
|
||||
for j := 0; j < n; j++ {
|
||||
f := field.Index(j)
|
||||
if j != 0 {
|
||||
out = append(out, ',')
|
||||
}
|
||||
out = append(out, f.String()...)
|
||||
}
|
||||
// overwrite length value
|
||||
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if t == bigIntType {
|
||||
var n *big.Int
|
||||
nValue := reflect.ValueOf(&n)
|
||||
nValue.Elem().Set(field)
|
||||
needed := intLength(n)
|
||||
oldLength := len(out)
|
||||
|
||||
if cap(out)-len(out) < needed {
|
||||
newOut := make([]byte, len(out), 2*(len(out)+needed))
|
||||
copy(newOut, out)
|
||||
out = newOut
|
||||
}
|
||||
out = out[:oldLength+needed]
|
||||
marshalInt(out[oldLength:], n)
|
||||
} else {
|
||||
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var bigOne = big.NewInt(1)
|
||||
|
||||
func parseString(in []byte) (out, rest []byte, ok bool) {
|
||||
if len(in) < 4 {
|
||||
return
|
||||
}
|
||||
length := binary.BigEndian.Uint32(in)
|
||||
in = in[4:]
|
||||
if uint32(len(in)) < length {
|
||||
return
|
||||
}
|
||||
out = in[:length]
|
||||
rest = in[length:]
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
comma = []byte{','}
|
||||
emptyNameList = []string{}
|
||||
)
|
||||
|
||||
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
|
||||
contents, rest, ok := parseString(in)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if len(contents) == 0 {
|
||||
out = emptyNameList
|
||||
return
|
||||
}
|
||||
parts := bytes.Split(contents, comma)
|
||||
out = make([]string, len(parts))
|
||||
for i, part := range parts {
|
||||
out[i] = string(part)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
|
||||
contents, rest, ok := parseString(in)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
out = new(big.Int)
|
||||
|
||||
if len(contents) > 0 && contents[0]&0x80 == 0x80 {
|
||||
// This is a negative number
|
||||
notBytes := make([]byte, len(contents))
|
||||
for i := range notBytes {
|
||||
notBytes[i] = ^contents[i]
|
||||
}
|
||||
out.SetBytes(notBytes)
|
||||
out.Add(out, bigOne)
|
||||
out.Neg(out)
|
||||
} else {
|
||||
// Positive number
|
||||
out.SetBytes(contents)
|
||||
}
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func parseUint32(in []byte) (uint32, []byte, bool) {
|
||||
if len(in) < 4 {
|
||||
return 0, nil, false
|
||||
}
|
||||
return binary.BigEndian.Uint32(in), in[4:], true
|
||||
}
|
||||
|
||||
func parseUint64(in []byte) (uint64, []byte, bool) {
|
||||
if len(in) < 8 {
|
||||
return 0, nil, false
|
||||
}
|
||||
return binary.BigEndian.Uint64(in), in[8:], true
|
||||
}
|
||||
|
||||
func intLength(n *big.Int) int {
|
||||
length := 4 /* length bytes */
|
||||
if n.Sign() < 0 {
|
||||
nMinus1 := new(big.Int).Neg(n)
|
||||
nMinus1.Sub(nMinus1, bigOne)
|
||||
bitLen := nMinus1.BitLen()
|
||||
if bitLen%8 == 0 {
|
||||
// The number will need 0xff padding
|
||||
length++
|
||||
}
|
||||
length += (bitLen + 7) / 8
|
||||
} else if n.Sign() == 0 {
|
||||
// A zero is the zero length string
|
||||
} else {
|
||||
bitLen := n.BitLen()
|
||||
if bitLen%8 == 0 {
|
||||
// The number will need 0x00 padding
|
||||
length++
|
||||
}
|
||||
length += (bitLen + 7) / 8
|
||||
}
|
||||
|
||||
return length
|
||||
}
|
||||
|
||||
func marshalUint32(to []byte, n uint32) []byte {
|
||||
binary.BigEndian.PutUint32(to, n)
|
||||
return to[4:]
|
||||
}
|
||||
|
||||
func marshalUint64(to []byte, n uint64) []byte {
|
||||
binary.BigEndian.PutUint64(to, n)
|
||||
return to[8:]
|
||||
}
|
||||
|
||||
func marshalInt(to []byte, n *big.Int) []byte {
|
||||
lengthBytes := to
|
||||
to = to[4:]
|
||||
length := 0
|
||||
|
||||
if n.Sign() < 0 {
|
||||
// A negative number has to be converted to two's-complement
|
||||
// form. So we'll subtract 1 and invert. If the
|
||||
// most-significant-bit isn't set then we'll need to pad the
|
||||
// beginning with 0xff in order to keep the number negative.
|
||||
nMinus1 := new(big.Int).Neg(n)
|
||||
nMinus1.Sub(nMinus1, bigOne)
|
||||
bytes := nMinus1.Bytes()
|
||||
for i := range bytes {
|
||||
bytes[i] ^= 0xff
|
||||
}
|
||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
|
||||
to[0] = 0xff
|
||||
to = to[1:]
|
||||
length++
|
||||
}
|
||||
nBytes := copy(to, bytes)
|
||||
to = to[nBytes:]
|
||||
length += nBytes
|
||||
} else if n.Sign() == 0 {
|
||||
// A zero is the zero length string
|
||||
} else {
|
||||
bytes := n.Bytes()
|
||||
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
|
||||
// We'll have to pad this with a 0x00 in order to
|
||||
// stop it looking like a negative number.
|
||||
to[0] = 0
|
||||
to = to[1:]
|
||||
length++
|
||||
}
|
||||
nBytes := copy(to, bytes)
|
||||
to = to[nBytes:]
|
||||
length += nBytes
|
||||
}
|
||||
|
||||
lengthBytes[0] = byte(length >> 24)
|
||||
lengthBytes[1] = byte(length >> 16)
|
||||
lengthBytes[2] = byte(length >> 8)
|
||||
lengthBytes[3] = byte(length)
|
||||
return to
|
||||
}
|
||||
|
||||
func writeInt(w io.Writer, n *big.Int) {
|
||||
length := intLength(n)
|
||||
buf := make([]byte, length)
|
||||
marshalInt(buf, n)
|
||||
w.Write(buf)
|
||||
}
|
||||
|
||||
func writeString(w io.Writer, s []byte) {
|
||||
var lengthBytes [4]byte
|
||||
lengthBytes[0] = byte(len(s) >> 24)
|
||||
lengthBytes[1] = byte(len(s) >> 16)
|
||||
lengthBytes[2] = byte(len(s) >> 8)
|
||||
lengthBytes[3] = byte(len(s))
|
||||
w.Write(lengthBytes[:])
|
||||
w.Write(s)
|
||||
}
|
||||
|
||||
func stringLength(n int) int {
|
||||
return 4 + n
|
||||
}
|
||||
|
||||
func marshalString(to []byte, s []byte) []byte {
|
||||
to[0] = byte(len(s) >> 24)
|
||||
to[1] = byte(len(s) >> 16)
|
||||
to[2] = byte(len(s) >> 8)
|
||||
to[3] = byte(len(s))
|
||||
to = to[4:]
|
||||
copy(to, s)
|
||||
return to[len(s):]
|
||||
}
|
||||
|
||||
var bigIntType = reflect.TypeOf((*big.Int)(nil))
|
||||
|
||||
// Decode a packet into its corresponding message.
|
||||
func decode(packet []byte) (interface{}, error) {
|
||||
var msg interface{}
|
||||
switch packet[0] {
|
||||
case msgDisconnect:
|
||||
msg = new(disconnectMsg)
|
||||
case msgServiceRequest:
|
||||
msg = new(serviceRequestMsg)
|
||||
case msgServiceAccept:
|
||||
msg = new(serviceAcceptMsg)
|
||||
case msgExtInfo:
|
||||
msg = new(extInfoMsg)
|
||||
case msgKexInit:
|
||||
msg = new(kexInitMsg)
|
||||
case msgKexDHInit:
|
||||
msg = new(kexDHInitMsg)
|
||||
case msgKexDHReply:
|
||||
msg = new(kexDHReplyMsg)
|
||||
case msgUserAuthRequest:
|
||||
msg = new(userAuthRequestMsg)
|
||||
case msgUserAuthSuccess:
|
||||
return new(userAuthSuccessMsg), nil
|
||||
case msgUserAuthFailure:
|
||||
msg = new(userAuthFailureMsg)
|
||||
case msgUserAuthPubKeyOk:
|
||||
msg = new(userAuthPubKeyOkMsg)
|
||||
case msgGlobalRequest:
|
||||
msg = new(globalRequestMsg)
|
||||
case msgRequestSuccess:
|
||||
msg = new(globalRequestSuccessMsg)
|
||||
case msgRequestFailure:
|
||||
msg = new(globalRequestFailureMsg)
|
||||
case msgChannelOpen:
|
||||
msg = new(channelOpenMsg)
|
||||
case msgChannelData:
|
||||
msg = new(channelDataMsg)
|
||||
case msgChannelOpenConfirm:
|
||||
msg = new(channelOpenConfirmMsg)
|
||||
case msgChannelOpenFailure:
|
||||
msg = new(channelOpenFailureMsg)
|
||||
case msgChannelWindowAdjust:
|
||||
msg = new(windowAdjustMsg)
|
||||
case msgChannelEOF:
|
||||
msg = new(channelEOFMsg)
|
||||
case msgChannelClose:
|
||||
msg = new(channelCloseMsg)
|
||||
case msgChannelRequest:
|
||||
msg = new(channelRequestMsg)
|
||||
case msgChannelSuccess:
|
||||
msg = new(channelRequestSuccessMsg)
|
||||
case msgChannelFailure:
|
||||
msg = new(channelRequestFailureMsg)
|
||||
case msgUserAuthGSSAPIToken:
|
||||
msg = new(userAuthGSSAPIToken)
|
||||
case msgUserAuthGSSAPIMIC:
|
||||
msg = new(userAuthGSSAPIMIC)
|
||||
case msgUserAuthGSSAPIErrTok:
|
||||
msg = new(userAuthGSSAPIErrTok)
|
||||
case msgUserAuthGSSAPIError:
|
||||
msg = new(userAuthGSSAPIError)
|
||||
default:
|
||||
return nil, unexpectedMessageError(0, packet[0])
|
||||
}
|
||||
if err := Unmarshal(packet, msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
var packetTypeNames = map[byte]string{
|
||||
msgDisconnect: "disconnectMsg",
|
||||
msgServiceRequest: "serviceRequestMsg",
|
||||
msgServiceAccept: "serviceAcceptMsg",
|
||||
msgExtInfo: "extInfoMsg",
|
||||
msgKexInit: "kexInitMsg",
|
||||
msgKexDHInit: "kexDHInitMsg",
|
||||
msgKexDHReply: "kexDHReplyMsg",
|
||||
msgUserAuthRequest: "userAuthRequestMsg",
|
||||
msgUserAuthSuccess: "userAuthSuccessMsg",
|
||||
msgUserAuthFailure: "userAuthFailureMsg",
|
||||
msgUserAuthPubKeyOk: "userAuthPubKeyOkMsg",
|
||||
msgGlobalRequest: "globalRequestMsg",
|
||||
msgRequestSuccess: "globalRequestSuccessMsg",
|
||||
msgRequestFailure: "globalRequestFailureMsg",
|
||||
msgChannelOpen: "channelOpenMsg",
|
||||
msgChannelData: "channelDataMsg",
|
||||
msgChannelOpenConfirm: "channelOpenConfirmMsg",
|
||||
msgChannelOpenFailure: "channelOpenFailureMsg",
|
||||
msgChannelWindowAdjust: "windowAdjustMsg",
|
||||
msgChannelEOF: "channelEOFMsg",
|
||||
msgChannelClose: "channelCloseMsg",
|
||||
msgChannelRequest: "channelRequestMsg",
|
||||
msgChannelSuccess: "channelRequestSuccessMsg",
|
||||
msgChannelFailure: "channelRequestFailureMsg",
|
||||
}
|
||||
357
vendor/github.com/tailscale/golang-x-crypto/ssh/mux.go
generated
vendored
Normal file
357
vendor/github.com/tailscale/golang-x-crypto/ssh/mux.go
generated
vendored
Normal file
@@ -0,0 +1,357 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// debugMux, if set, causes messages in the connection protocol to be
|
||||
// logged.
|
||||
const debugMux = false
|
||||
|
||||
// chanList is a thread safe channel list.
|
||||
type chanList struct {
|
||||
// protects concurrent access to chans
|
||||
sync.Mutex
|
||||
|
||||
// chans are indexed by the local id of the channel, which the
|
||||
// other side should send in the PeersId field.
|
||||
chans []*channel
|
||||
|
||||
// This is a debugging aid: it offsets all IDs by this
|
||||
// amount. This helps distinguish otherwise identical
|
||||
// server/client muxes
|
||||
offset uint32
|
||||
}
|
||||
|
||||
// Assigns a channel ID to the given channel.
|
||||
func (c *chanList) add(ch *channel) uint32 {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
for i := range c.chans {
|
||||
if c.chans[i] == nil {
|
||||
c.chans[i] = ch
|
||||
return uint32(i) + c.offset
|
||||
}
|
||||
}
|
||||
c.chans = append(c.chans, ch)
|
||||
return uint32(len(c.chans)-1) + c.offset
|
||||
}
|
||||
|
||||
// getChan returns the channel for the given ID.
|
||||
func (c *chanList) getChan(id uint32) *channel {
|
||||
id -= c.offset
|
||||
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if id < uint32(len(c.chans)) {
|
||||
return c.chans[id]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *chanList) remove(id uint32) {
|
||||
id -= c.offset
|
||||
c.Lock()
|
||||
if id < uint32(len(c.chans)) {
|
||||
c.chans[id] = nil
|
||||
}
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
// dropAll forgets all channels it knows, returning them in a slice.
|
||||
func (c *chanList) dropAll() []*channel {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
var r []*channel
|
||||
|
||||
for _, ch := range c.chans {
|
||||
if ch == nil {
|
||||
continue
|
||||
}
|
||||
r = append(r, ch)
|
||||
}
|
||||
c.chans = nil
|
||||
return r
|
||||
}
|
||||
|
||||
// mux represents the state for the SSH connection protocol, which
|
||||
// multiplexes many channels onto a single packet transport.
|
||||
type mux struct {
|
||||
conn packetConn
|
||||
chanList chanList
|
||||
|
||||
incomingChannels chan NewChannel
|
||||
|
||||
globalSentMu sync.Mutex
|
||||
globalResponses chan interface{}
|
||||
incomingRequests chan *Request
|
||||
|
||||
errCond *sync.Cond
|
||||
err error
|
||||
}
|
||||
|
||||
// When debugging, each new chanList instantiation has a different
|
||||
// offset.
|
||||
var globalOff uint32
|
||||
|
||||
func (m *mux) Wait() error {
|
||||
m.errCond.L.Lock()
|
||||
defer m.errCond.L.Unlock()
|
||||
for m.err == nil {
|
||||
m.errCond.Wait()
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
// newMux returns a mux that runs over the given connection.
|
||||
func newMux(p packetConn) *mux {
|
||||
m := &mux{
|
||||
conn: p,
|
||||
incomingChannels: make(chan NewChannel, chanSize),
|
||||
globalResponses: make(chan interface{}, 1),
|
||||
incomingRequests: make(chan *Request, chanSize),
|
||||
errCond: newCond(),
|
||||
}
|
||||
if debugMux {
|
||||
m.chanList.offset = atomic.AddUint32(&globalOff, 1)
|
||||
}
|
||||
|
||||
go m.loop()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mux) sendMessage(msg interface{}) error {
|
||||
p := Marshal(msg)
|
||||
if debugMux {
|
||||
log.Printf("send global(%d): %#v", m.chanList.offset, msg)
|
||||
}
|
||||
return m.conn.writePacket(p)
|
||||
}
|
||||
|
||||
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
|
||||
if wantReply {
|
||||
m.globalSentMu.Lock()
|
||||
defer m.globalSentMu.Unlock()
|
||||
}
|
||||
|
||||
if err := m.sendMessage(globalRequestMsg{
|
||||
Type: name,
|
||||
WantReply: wantReply,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
if !wantReply {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
msg, ok := <-m.globalResponses
|
||||
if !ok {
|
||||
return false, nil, io.EOF
|
||||
}
|
||||
switch msg := msg.(type) {
|
||||
case *globalRequestFailureMsg:
|
||||
return false, msg.Data, nil
|
||||
case *globalRequestSuccessMsg:
|
||||
return true, msg.Data, nil
|
||||
default:
|
||||
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// ackRequest must be called after processing a global request that
|
||||
// has WantReply set.
|
||||
func (m *mux) ackRequest(ok bool, data []byte) error {
|
||||
if ok {
|
||||
return m.sendMessage(globalRequestSuccessMsg{Data: data})
|
||||
}
|
||||
return m.sendMessage(globalRequestFailureMsg{Data: data})
|
||||
}
|
||||
|
||||
func (m *mux) Close() error {
|
||||
return m.conn.Close()
|
||||
}
|
||||
|
||||
// loop runs the connection machine. It will process packets until an
|
||||
// error is encountered. To synchronize on loop exit, use mux.Wait.
|
||||
func (m *mux) loop() {
|
||||
var err error
|
||||
for err == nil {
|
||||
err = m.onePacket()
|
||||
}
|
||||
|
||||
for _, ch := range m.chanList.dropAll() {
|
||||
ch.close()
|
||||
}
|
||||
|
||||
close(m.incomingChannels)
|
||||
close(m.incomingRequests)
|
||||
close(m.globalResponses)
|
||||
|
||||
m.conn.Close()
|
||||
|
||||
m.errCond.L.Lock()
|
||||
m.err = err
|
||||
m.errCond.Broadcast()
|
||||
m.errCond.L.Unlock()
|
||||
|
||||
if debugMux {
|
||||
log.Println("loop exit", err)
|
||||
}
|
||||
}
|
||||
|
||||
// onePacket reads and processes one packet.
|
||||
func (m *mux) onePacket() error {
|
||||
packet, err := m.conn.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if debugMux {
|
||||
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
|
||||
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
|
||||
} else {
|
||||
p, _ := decode(packet)
|
||||
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
|
||||
}
|
||||
}
|
||||
|
||||
switch packet[0] {
|
||||
case msgChannelOpen:
|
||||
return m.handleChannelOpen(packet)
|
||||
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
|
||||
return m.handleGlobalPacket(packet)
|
||||
case msgPing:
|
||||
var msg pingMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
|
||||
}
|
||||
return m.sendMessage(pongMsg(msg))
|
||||
}
|
||||
|
||||
// assume a channel packet.
|
||||
if len(packet) < 5 {
|
||||
return parseError(packet[0])
|
||||
}
|
||||
id := binary.BigEndian.Uint32(packet[1:])
|
||||
ch := m.chanList.getChan(id)
|
||||
if ch == nil {
|
||||
return m.handleUnknownChannelPacket(id, packet)
|
||||
}
|
||||
|
||||
return ch.handlePacket(packet)
|
||||
}
|
||||
|
||||
func (m *mux) handleGlobalPacket(packet []byte) error {
|
||||
msg, err := decode(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *globalRequestMsg:
|
||||
m.incomingRequests <- &Request{
|
||||
Type: msg.Type,
|
||||
WantReply: msg.WantReply,
|
||||
Payload: msg.Data,
|
||||
mux: m,
|
||||
}
|
||||
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
|
||||
m.globalResponses <- msg
|
||||
default:
|
||||
panic(fmt.Sprintf("not a global message %#v", msg))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleChannelOpen schedules a channel to be Accept()ed.
|
||||
func (m *mux) handleChannelOpen(packet []byte) error {
|
||||
var msg channelOpenMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
|
||||
failMsg := channelOpenFailureMsg{
|
||||
PeersID: msg.PeersID,
|
||||
Reason: ConnectionFailed,
|
||||
Message: "invalid request",
|
||||
Language: "en_US.UTF-8",
|
||||
}
|
||||
return m.sendMessage(failMsg)
|
||||
}
|
||||
|
||||
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
|
||||
c.remoteId = msg.PeersID
|
||||
c.maxRemotePayload = msg.MaxPacketSize
|
||||
c.remoteWin.add(msg.PeersWindow)
|
||||
m.incomingChannels <- c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
|
||||
ch, err := m.openChannel(chanType, extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return ch, ch.incomingRequests, nil
|
||||
}
|
||||
|
||||
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
|
||||
ch := m.newChannel(chanType, channelOutbound, extra)
|
||||
|
||||
ch.maxIncomingPayload = channelMaxPacket
|
||||
|
||||
open := channelOpenMsg{
|
||||
ChanType: chanType,
|
||||
PeersWindow: ch.myWindow,
|
||||
MaxPacketSize: ch.maxIncomingPayload,
|
||||
TypeSpecificData: extra,
|
||||
PeersID: ch.localId,
|
||||
}
|
||||
if err := m.sendMessage(open); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := (<-ch.msg).(type) {
|
||||
case *channelOpenConfirmMsg:
|
||||
return ch, nil
|
||||
case *channelOpenFailureMsg:
|
||||
return nil, &OpenChannelError{msg.Reason, msg.Message}
|
||||
default:
|
||||
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
|
||||
msg, err := decode(packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
// RFC 4254 section 5.4 says unrecognized channel requests should
|
||||
// receive a failure response.
|
||||
case *channelRequestMsg:
|
||||
if msg.WantReply {
|
||||
return m.sendMessage(channelRequestFailureMsg{
|
||||
PeersID: msg.PeersID,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("ssh: invalid channel %d", id)
|
||||
}
|
||||
}
|
||||
846
vendor/github.com/tailscale/golang-x-crypto/ssh/server.go
generated
vendored
Normal file
846
vendor/github.com/tailscale/golang-x-crypto/ssh/server.go
generated
vendored
Normal file
@@ -0,0 +1,846 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// The Permissions type holds fine-grained permissions that are
|
||||
// specific to a user or a specific authentication method for a user.
|
||||
// The Permissions value for a successful authentication attempt is
|
||||
// available in ServerConn, so it can be used to pass information from
|
||||
// the user-authentication phase to the application layer.
|
||||
type Permissions struct {
|
||||
// CriticalOptions indicate restrictions to the default
|
||||
// permissions, and are typically used in conjunction with
|
||||
// user certificates. The standard for SSH certificates
|
||||
// defines "force-command" (only allow the given command to
|
||||
// execute) and "source-address" (only allow connections from
|
||||
// the given address). The SSH package currently only enforces
|
||||
// the "source-address" critical option. It is up to server
|
||||
// implementations to enforce other critical options, such as
|
||||
// "force-command", by checking them after the SSH handshake
|
||||
// is successful. In general, SSH servers should reject
|
||||
// connections that specify critical options that are unknown
|
||||
// or not supported.
|
||||
CriticalOptions map[string]string
|
||||
|
||||
// Extensions are extra functionality that the server may
|
||||
// offer on authenticated connections. Lack of support for an
|
||||
// extension does not preclude authenticating a user. Common
|
||||
// extensions are "permit-agent-forwarding",
|
||||
// "permit-X11-forwarding". The Go SSH library currently does
|
||||
// not act on any extension, and it is up to server
|
||||
// implementations to honor them. Extensions can be used to
|
||||
// pass data from the authentication callbacks to the server
|
||||
// application layer.
|
||||
Extensions map[string]string
|
||||
}
|
||||
|
||||
type GSSAPIWithMICConfig struct {
|
||||
// AllowLogin, must be set, is called when gssapi-with-mic
|
||||
// authentication is selected (RFC 4462 section 3). The srcName is from the
|
||||
// results of the GSS-API authentication. The format is username@DOMAIN.
|
||||
// GSSAPI just guarantees to the server who the user is, but not if they can log in, and with what permissions.
|
||||
// This callback is called after the user identity is established with GSSAPI to decide if the user can login with
|
||||
// which permissions. If the user is allowed to login, it should return a nil error.
|
||||
AllowLogin func(conn ConnMetadata, srcName string) (*Permissions, error)
|
||||
|
||||
// Server must be set. It's the implementation
|
||||
// of the GSSAPIServer interface. See GSSAPIServer interface for details.
|
||||
Server GSSAPIServer
|
||||
}
|
||||
|
||||
// ServerConfig holds server specific configuration data.
|
||||
type ServerConfig struct {
|
||||
// Config contains configuration shared between client and server.
|
||||
Config
|
||||
|
||||
// PublicKeyAuthAlgorithms specifies the supported client public key
|
||||
// authentication algorithms. Note that this should not include certificate
|
||||
// types since those use the underlying algorithm. This list is sent to the
|
||||
// client if it supports the server-sig-algs extension. Order is irrelevant.
|
||||
// If unspecified then a default set of algorithms is used.
|
||||
PublicKeyAuthAlgorithms []string
|
||||
|
||||
hostKeys []Signer
|
||||
|
||||
// NoClientAuth is true if clients are allowed to connect without
|
||||
// authenticating.
|
||||
// To determine NoClientAuth at runtime, set NoClientAuth to true
|
||||
// and the optional NoClientAuthCallback to a non-nil value.
|
||||
NoClientAuth bool
|
||||
|
||||
// NoClientAuthCallback, if non-nil, is called when a user
|
||||
// attempts to authenticate with auth method "none".
|
||||
// NoClientAuth must also be set to true for this be used, or
|
||||
// this func is unused.
|
||||
// If the function returns ErrDenied, the connection is terminated.
|
||||
NoClientAuthCallback func(ConnMetadata) (*Permissions, error)
|
||||
|
||||
// MaxAuthTries specifies the maximum number of authentication attempts
|
||||
// permitted per connection. If set to a negative number, the number of
|
||||
// attempts are unlimited. If set to zero, the number of attempts are limited
|
||||
// to 6.
|
||||
MaxAuthTries int
|
||||
|
||||
// PasswordCallback, if non-nil, is called when a user
|
||||
// attempts to authenticate using a password.
|
||||
// If the function returns ErrDenied, the connection is terminated.
|
||||
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
|
||||
|
||||
// PublicKeyCallback, if non-nil, is called when a client
|
||||
// offers a public key for authentication. It must return a nil error
|
||||
// if the given public key can be used to authenticate the
|
||||
// given user. For example, see CertChecker.Authenticate. A
|
||||
// call to this function does not guarantee that the key
|
||||
// offered is in fact used to authenticate. To record any data
|
||||
// depending on the public key, store it inside a
|
||||
// Permissions.Extensions entry.
|
||||
// If the function returns ErrDenied, the connection is terminated.
|
||||
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
|
||||
|
||||
// KeyboardInteractiveCallback, if non-nil, is called when
|
||||
// keyboard-interactive authentication is selected (RFC
|
||||
// 4256). The client object's Challenge function should be
|
||||
// used to query the user. The callback may offer multiple
|
||||
// Challenge rounds. To avoid information leaks, the client
|
||||
// should be presented a challenge even if the user is
|
||||
// unknown.
|
||||
// If the function returns ErrDenied, the connection is terminated.
|
||||
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
|
||||
|
||||
// AuthLogCallback, if non-nil, is called to log all authentication
|
||||
// attempts.
|
||||
AuthLogCallback func(conn ConnMetadata, method string, err error)
|
||||
|
||||
// NextAuthMethodCallback, if non-nil, is called whenever an authentication
|
||||
// method fails. It's called after AuthLogCallback, if set. The return
|
||||
// values are the SSH auth types (from
|
||||
// https://www.iana.org/assignments/ssh-parameters/ssh-parameters.xhtml#ssh-parameters-10
|
||||
// such as "password", "publickey", "keyboard-interactive", etc)
|
||||
// to suggest that the client try next. If empty, authentication fails.
|
||||
NextAuthMethodCallback func(conn ConnMetadata, prevErrors []error) []string
|
||||
|
||||
// ServerVersion is the version identification string to announce in
|
||||
// the public handshake.
|
||||
// If empty, a reasonable default is used.
|
||||
// Note that RFC 4253 section 4.2 requires that this string start with
|
||||
// "SSH-2.0-".
|
||||
ServerVersion string
|
||||
|
||||
// BannerCallback, if present, is called and the return string is sent to
|
||||
// the client after key exchange completed but before authentication.
|
||||
BannerCallback func(conn ConnMetadata) string
|
||||
|
||||
// GSSAPIWithMICConfig includes gssapi server and callback, which if both non-nil, is used
|
||||
// when gssapi-with-mic authentication is selected (RFC 4462 section 3).
|
||||
GSSAPIWithMICConfig *GSSAPIWithMICConfig
|
||||
}
|
||||
|
||||
// AddHostKey adds a private key as a host key. If an existing host
|
||||
// key exists with the same public key format, it is replaced. Each server
|
||||
// config must have at least one host key.
|
||||
func (s *ServerConfig) AddHostKey(key Signer) {
|
||||
for i, k := range s.hostKeys {
|
||||
if k.PublicKey().Type() == key.PublicKey().Type() {
|
||||
s.hostKeys[i] = key
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.hostKeys = append(s.hostKeys, key)
|
||||
}
|
||||
|
||||
// cachedPubKey contains the results of querying whether a public key is
|
||||
// acceptable for a user.
|
||||
type cachedPubKey struct {
|
||||
user string
|
||||
pubKeyData []byte
|
||||
result error
|
||||
perms *Permissions
|
||||
}
|
||||
|
||||
const maxCachedPubKeys = 16
|
||||
|
||||
// pubKeyCache caches tests for public keys. Since SSH clients
|
||||
// will query whether a public key is acceptable before attempting to
|
||||
// authenticate with it, we end up with duplicate queries for public
|
||||
// key validity. The cache only applies to a single ServerConn.
|
||||
type pubKeyCache struct {
|
||||
keys []cachedPubKey
|
||||
}
|
||||
|
||||
// get returns the result for a given user/algo/key tuple.
|
||||
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
|
||||
for _, k := range c.keys {
|
||||
if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) {
|
||||
return k, true
|
||||
}
|
||||
}
|
||||
return cachedPubKey{}, false
|
||||
}
|
||||
|
||||
// add adds the given tuple to the cache.
|
||||
func (c *pubKeyCache) add(candidate cachedPubKey) {
|
||||
if len(c.keys) < maxCachedPubKeys {
|
||||
c.keys = append(c.keys, candidate)
|
||||
}
|
||||
}
|
||||
|
||||
// ServerConn is an authenticated SSH connection, as seen from the
|
||||
// server
|
||||
type ServerConn struct {
|
||||
Conn
|
||||
|
||||
// If the succeeding authentication callback returned a
|
||||
// non-nil Permissions pointer, it is stored here.
|
||||
Permissions *Permissions
|
||||
}
|
||||
|
||||
// NewServerConn starts a new SSH server with c as the underlying
|
||||
// transport. It starts with a handshake and, if the handshake is
|
||||
// unsuccessful, it closes the connection and returns an error. The
|
||||
// Request and NewChannel channels must be serviced, or the connection
|
||||
// will hang.
|
||||
//
|
||||
// The returned error may be of type *ServerAuthError for
|
||||
// authentication errors.
|
||||
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
|
||||
fullConf := *config
|
||||
fullConf.SetDefaults()
|
||||
if fullConf.MaxAuthTries == 0 {
|
||||
fullConf.MaxAuthTries = 6
|
||||
}
|
||||
if len(fullConf.PublicKeyAuthAlgorithms) == 0 {
|
||||
fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
|
||||
} else {
|
||||
for _, algo := range fullConf.PublicKeyAuthAlgorithms {
|
||||
if !contains(supportedPubKeyAuthAlgos, algo) {
|
||||
c.Close()
|
||||
return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check if the config contains any unsupported key exchanges
|
||||
for _, kex := range fullConf.KeyExchanges {
|
||||
if _, ok := serverForbiddenKexAlgos[kex]; ok {
|
||||
c.Close()
|
||||
return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
|
||||
}
|
||||
}
|
||||
|
||||
s := &connection{
|
||||
sshConn: sshConn{conn: c},
|
||||
}
|
||||
perms, err := s.serverHandshake(&fullConf)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
|
||||
}
|
||||
|
||||
// signAndMarshal signs the data with the appropriate algorithm,
|
||||
// and serializes the result in SSH wire format. algo is the negotiate
|
||||
// algorithm and may be a certificate type.
|
||||
func signAndMarshal(k AlgorithmSigner, rand io.Reader, data []byte, algo string) ([]byte, error) {
|
||||
sig, err := k.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Marshal(sig), nil
|
||||
}
|
||||
|
||||
// handshake performs key exchange and user authentication.
|
||||
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
|
||||
if len(config.hostKeys) == 0 {
|
||||
return nil, errors.New("ssh: server has no host keys")
|
||||
}
|
||||
|
||||
if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil &&
|
||||
config.KeyboardInteractiveCallback == nil && (config.GSSAPIWithMICConfig == nil ||
|
||||
config.GSSAPIWithMICConfig.AllowLogin == nil || config.GSSAPIWithMICConfig.Server == nil) {
|
||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
|
||||
}
|
||||
|
||||
if config.ServerVersion != "" {
|
||||
s.serverVersion = []byte(config.ServerVersion)
|
||||
} else {
|
||||
s.serverVersion = []byte(packageVersion)
|
||||
}
|
||||
var err error
|
||||
s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
|
||||
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
|
||||
|
||||
if err := s.transport.waitSession(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We just did the key change, so the session ID is established.
|
||||
s.sessionID = s.transport.getSessionID()
|
||||
|
||||
var packet []byte
|
||||
if packet, err = s.transport.readPacket(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serviceRequest serviceRequestMsg
|
||||
if err = Unmarshal(packet, &serviceRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if serviceRequest.Service != serviceUserAuth {
|
||||
return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
|
||||
}
|
||||
serviceAccept := serviceAcceptMsg{
|
||||
Service: serviceUserAuth,
|
||||
}
|
||||
if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
perms, err := s.serverAuthenticate(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.mux = newMux(s.transport)
|
||||
return perms, err
|
||||
}
|
||||
|
||||
func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
|
||||
if addr == nil {
|
||||
return errors.New("ssh: no address known for client, but source-address match required")
|
||||
}
|
||||
|
||||
tcpAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
|
||||
}
|
||||
|
||||
for _, sourceAddr := range strings.Split(sourceAddrs, ",") {
|
||||
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
|
||||
if allowedIP.Equal(tcpAddr.IP) {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
_, ipNet, err := net.ParseCIDR(sourceAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
|
||||
}
|
||||
|
||||
if ipNet.Contains(tcpAddr.IP) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
|
||||
}
|
||||
|
||||
func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection,
|
||||
sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) {
|
||||
gssAPIServer := gssapiConfig.Server
|
||||
defer gssAPIServer.DeleteSecContext()
|
||||
var srcName string
|
||||
for {
|
||||
var (
|
||||
outToken []byte
|
||||
needContinue bool
|
||||
)
|
||||
outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token)
|
||||
if err != nil {
|
||||
return err, nil, nil
|
||||
}
|
||||
if len(outToken) != 0 {
|
||||
if err := s.transport.writePacket(Marshal(&userAuthGSSAPIToken{
|
||||
Token: outToken,
|
||||
})); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
if !needContinue {
|
||||
break
|
||||
}
|
||||
packet, err := s.transport.readPacket()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
|
||||
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
token = userAuthGSSAPITokenReq.Token
|
||||
}
|
||||
packet, err := s.transport.readPacket()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
userAuthGSSAPIMICReq := &userAuthGSSAPIMIC{}
|
||||
if err := Unmarshal(packet, userAuthGSSAPIMICReq); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
mic := buildMIC(string(sessionID), userAuthReq.User, userAuthReq.Service, userAuthReq.Method)
|
||||
if err := gssAPIServer.VerifyMIC(mic, userAuthGSSAPIMICReq.MIC); err != nil {
|
||||
return err, nil, nil
|
||||
}
|
||||
perms, authErr = gssapiConfig.AllowLogin(s, srcName)
|
||||
return authErr, perms, nil
|
||||
}
|
||||
|
||||
// isAlgoCompatible checks if the signature format is compatible with the
|
||||
// selected algorithm taking into account edge cases that occur with old
|
||||
// clients.
|
||||
func isAlgoCompatible(algo, sigFormat string) bool {
|
||||
// Compatibility for old clients.
|
||||
//
|
||||
// For certificate authentication with OpenSSH 7.2-7.7 signature format can
|
||||
// be rsa-sha2-256 or rsa-sha2-512 for the algorithm
|
||||
// ssh-rsa-cert-v01@openssh.com.
|
||||
//
|
||||
// With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512
|
||||
// for signature format ssh-rsa.
|
||||
if isRSA(algo) && isRSA(sigFormat) {
|
||||
return true
|
||||
}
|
||||
// Standard case: the underlying algorithm must match the signature format.
|
||||
return underlyingAlgo(algo) == sigFormat
|
||||
}
|
||||
|
||||
// ServerAuthError represents server authentication errors and is
|
||||
// sometimes returned by NewServerConn. It appends any authentication
|
||||
// errors that may occur, and is returned if all of the authentication
|
||||
// methods provided by the user failed to authenticate.
|
||||
type ServerAuthError struct {
|
||||
// Errors contains authentication errors returned by the authentication
|
||||
// callback methods. The first entry is typically ErrNoAuth.
|
||||
Errors []error
|
||||
}
|
||||
|
||||
func (l ServerAuthError) Error() string {
|
||||
var errs []string
|
||||
for _, err := range l.Errors {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
return "[" + strings.Join(errs, ", ") + "]"
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrDenied can be returned from an authentication callback to inform the
|
||||
// client that access is denied and that no further attempt will be accepted
|
||||
// on the connection.
|
||||
ErrDenied = errors.New("ssh: access denied")
|
||||
|
||||
// ErrNoAuth is the error value returned if no
|
||||
// authentication method has been passed yet. This happens as a normal
|
||||
// part of the authentication loop, since the client first tries
|
||||
// 'none' authentication to discover available methods.
|
||||
// It is returned in ServerAuthError.Errors from NewServerConn.
|
||||
ErrNoAuth = errors.New("ssh: no auth passed yet")
|
||||
)
|
||||
|
||||
func (s *connection) SendAuthBanner(msg string) error {
|
||||
if s.mux != nil {
|
||||
return errors.New("ssh: SendAuthBanner called after handshake")
|
||||
}
|
||||
return s.transport.writePacket(Marshal(&userAuthBannerMsg{
|
||||
Message: msg,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
|
||||
sessionID := s.transport.getSessionID()
|
||||
var cache pubKeyCache
|
||||
var perms *Permissions
|
||||
|
||||
authFailures := 0
|
||||
var authErrs []error
|
||||
var displayedBanner bool
|
||||
|
||||
userAuthLoop:
|
||||
for {
|
||||
if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
|
||||
discMsg := &disconnectMsg{
|
||||
Reason: 2,
|
||||
Message: "too many authentication failures",
|
||||
}
|
||||
|
||||
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, discMsg
|
||||
}
|
||||
|
||||
var userAuthReq userAuthRequestMsg
|
||||
if packet, err := s.transport.readPacket(); err != nil {
|
||||
if err == io.EOF {
|
||||
return nil, &ServerAuthError{Errors: authErrs}
|
||||
}
|
||||
return nil, err
|
||||
} else if err = Unmarshal(packet, &userAuthReq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if userAuthReq.Service != serviceSSH {
|
||||
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
|
||||
}
|
||||
|
||||
s.user = userAuthReq.User
|
||||
|
||||
if !displayedBanner && config.BannerCallback != nil {
|
||||
displayedBanner = true
|
||||
msg := config.BannerCallback(s)
|
||||
if msg != "" {
|
||||
if err := s.SendAuthBanner(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
perms = nil
|
||||
authErr := ErrNoAuth
|
||||
|
||||
switch userAuthReq.Method {
|
||||
case "none":
|
||||
if config.NoClientAuth {
|
||||
if config.NoClientAuthCallback != nil {
|
||||
perms, authErr = config.NoClientAuthCallback(s)
|
||||
} else {
|
||||
authErr = nil
|
||||
}
|
||||
}
|
||||
|
||||
// allow initial attempt of 'none' without penalty
|
||||
if authFailures == 0 {
|
||||
authFailures--
|
||||
}
|
||||
case "password":
|
||||
if config.PasswordCallback == nil {
|
||||
authErr = errors.New("ssh: password auth not configured")
|
||||
break
|
||||
}
|
||||
payload := userAuthReq.Payload
|
||||
if len(payload) < 1 || payload[0] != 0 {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
payload = payload[1:]
|
||||
password, payload, ok := parseString(payload)
|
||||
if !ok || len(payload) > 0 {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
|
||||
perms, authErr = config.PasswordCallback(s, password)
|
||||
case "keyboard-interactive":
|
||||
if config.KeyboardInteractiveCallback == nil {
|
||||
authErr = errors.New("ssh: keyboard-interactive auth not configured")
|
||||
break
|
||||
}
|
||||
|
||||
prompter := &sshClientKeyboardInteractive{s}
|
||||
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
|
||||
case "publickey":
|
||||
if config.PublicKeyCallback == nil {
|
||||
authErr = errors.New("ssh: publickey auth not configured")
|
||||
break
|
||||
}
|
||||
payload := userAuthReq.Payload
|
||||
if len(payload) < 1 {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
isQuery := payload[0] == 0
|
||||
payload = payload[1:]
|
||||
algoBytes, payload, ok := parseString(payload)
|
||||
if !ok {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
algo := string(algoBytes)
|
||||
if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
|
||||
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
|
||||
break
|
||||
}
|
||||
|
||||
pubKeyData, payload, ok := parseString(payload)
|
||||
if !ok {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
|
||||
pubKey, err := ParsePublicKey(pubKeyData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidate, ok := cache.get(s.user, pubKeyData)
|
||||
if !ok {
|
||||
candidate.user = s.user
|
||||
candidate.pubKeyData = pubKeyData
|
||||
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
|
||||
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
|
||||
candidate.result = checkSourceAddress(
|
||||
s.RemoteAddr(),
|
||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption])
|
||||
}
|
||||
cache.add(candidate)
|
||||
}
|
||||
|
||||
if isQuery {
|
||||
// The client can query if the given public key
|
||||
// would be okay.
|
||||
|
||||
if len(payload) > 0 {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
|
||||
if candidate.result == nil {
|
||||
okMsg := userAuthPubKeyOkMsg{
|
||||
Algo: algo,
|
||||
PubKey: pubKeyData,
|
||||
}
|
||||
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue userAuthLoop
|
||||
}
|
||||
authErr = candidate.result
|
||||
} else {
|
||||
sig, payload, ok := parseSignature(payload)
|
||||
if !ok || len(payload) > 0 {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
// Ensure the declared public key algo is compatible with the
|
||||
// decoded one. This check will ensure we don't accept e.g.
|
||||
// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
|
||||
// key type. The algorithm and public key type must be
|
||||
// consistent: both must be certificate algorithms, or neither.
|
||||
if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
|
||||
authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
|
||||
pubKey.Type(), algo)
|
||||
break
|
||||
}
|
||||
// Ensure the public key algo and signature algo
|
||||
// are supported. Compare the private key
|
||||
// algorithm name that corresponds to algo with
|
||||
// sig.Format. This is usually the same, but
|
||||
// for certs, the names differ.
|
||||
if !contains(config.PublicKeyAuthAlgorithms, sig.Format) {
|
||||
authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
|
||||
break
|
||||
}
|
||||
if !isAlgoCompatible(algo, sig.Format) {
|
||||
authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo)
|
||||
break
|
||||
}
|
||||
|
||||
signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData)
|
||||
|
||||
if err := pubKey.Verify(signedData, sig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authErr = candidate.result
|
||||
perms = candidate.perms
|
||||
}
|
||||
case "gssapi-with-mic":
|
||||
if config.GSSAPIWithMICConfig == nil {
|
||||
authErr = errors.New("ssh: gssapi-with-mic auth not configured")
|
||||
break
|
||||
}
|
||||
gssapiConfig := config.GSSAPIWithMICConfig
|
||||
userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
|
||||
if err != nil {
|
||||
return nil, parseError(msgUserAuthRequest)
|
||||
}
|
||||
// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication.
|
||||
if userAuthRequestGSSAPI.N == 0 {
|
||||
authErr = fmt.Errorf("ssh: Mechanism negotiation is not supported")
|
||||
break
|
||||
}
|
||||
var i uint32
|
||||
present := false
|
||||
for i = 0; i < userAuthRequestGSSAPI.N; i++ {
|
||||
if userAuthRequestGSSAPI.OIDS[i].Equal(krb5Mesh) {
|
||||
present = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !present {
|
||||
authErr = fmt.Errorf("ssh: GSSAPI authentication must use the Kerberos V5 mechanism")
|
||||
break
|
||||
}
|
||||
// Initial server response, see RFC 4462 section 3.3.
|
||||
if err := s.transport.writePacket(Marshal(&userAuthGSSAPIResponse{
|
||||
SupportMech: krb5OID,
|
||||
})); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Exchange token, see RFC 4462 section 3.4.
|
||||
packet, err := s.transport.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
|
||||
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authErr, perms, err = gssExchangeToken(gssapiConfig, userAuthGSSAPITokenReq.Token, s, sessionID,
|
||||
userAuthReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
|
||||
}
|
||||
|
||||
authErrs = append(authErrs, authErr)
|
||||
|
||||
if config.AuthLogCallback != nil {
|
||||
config.AuthLogCallback(s, userAuthReq.Method, authErr)
|
||||
}
|
||||
|
||||
if authErr == nil {
|
||||
break userAuthLoop
|
||||
}
|
||||
if errors.Is(authErr, ErrDenied) {
|
||||
var failureMsg userAuthFailureMsg
|
||||
if config.NextAuthMethodCallback != nil {
|
||||
// We also call the NextAuthMethodCallback here, even though
|
||||
// we're hanging up by returning out of this func, as OpenSSH
|
||||
// will render this last one in parentheses in the rejection
|
||||
// error message which is a nice little place to put the product
|
||||
// name, auth type, or a hint about why it might've failed.
|
||||
failureMsg.Methods = config.NextAuthMethodCallback(s, authErrs)
|
||||
}
|
||||
if err := s.transport.writePacket(Marshal(failureMsg)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, authErr
|
||||
}
|
||||
|
||||
authFailures++
|
||||
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
|
||||
// If we have hit the max attempts, don't bother sending the
|
||||
// final SSH_MSG_USERAUTH_FAILURE message, since there are
|
||||
// no more authentication methods which can be attempted,
|
||||
// and this message may cause the client to re-attempt
|
||||
// authentication while we send the disconnect message.
|
||||
// Continue, and trigger the disconnect at the start of
|
||||
// the loop.
|
||||
//
|
||||
// The SSH specification is somewhat confusing about this,
|
||||
// RFC 4252 Section 5.1 requires each authentication failure
|
||||
// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
|
||||
// message, but Section 4 says the server should disconnect
|
||||
// after some number of attempts, but it isn't explicit which
|
||||
// message should take precedence (i.e. should there be a failure
|
||||
// message than a disconnect message, or if we are going to
|
||||
// disconnect, should we only send that message.)
|
||||
//
|
||||
// Either way, OpenSSH disconnects immediately after the last
|
||||
// failed authnetication attempt, and given they are typically
|
||||
// considered the golden implementation it seems reasonable
|
||||
// to match that behavior.
|
||||
continue
|
||||
}
|
||||
|
||||
var failureMsg userAuthFailureMsg
|
||||
if config.NextAuthMethodCallback != nil {
|
||||
failureMsg.Methods = config.NextAuthMethodCallback(s, authErrs)
|
||||
} else {
|
||||
if config.PasswordCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "password")
|
||||
}
|
||||
if config.PublicKeyCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "publickey")
|
||||
}
|
||||
if config.KeyboardInteractiveCallback != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
|
||||
}
|
||||
if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
|
||||
config.GSSAPIWithMICConfig.AllowLogin != nil {
|
||||
failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
|
||||
}
|
||||
}
|
||||
|
||||
if len(failureMsg.Methods) == 0 {
|
||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
|
||||
}
|
||||
|
||||
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return perms, nil
|
||||
}
|
||||
|
||||
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
|
||||
// asking the client on the other side of a ServerConn.
|
||||
type sshClientKeyboardInteractive struct {
|
||||
*connection
|
||||
}
|
||||
|
||||
func (c *sshClientKeyboardInteractive) Challenge(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
||||
if len(questions) != len(echos) {
|
||||
return nil, errors.New("ssh: echos and questions must have equal length")
|
||||
}
|
||||
|
||||
var prompts []byte
|
||||
for i := range questions {
|
||||
prompts = appendString(prompts, questions[i])
|
||||
prompts = appendBool(prompts, echos[i])
|
||||
}
|
||||
|
||||
if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
|
||||
Name: name,
|
||||
Instruction: instruction,
|
||||
NumPrompts: uint32(len(questions)),
|
||||
Prompts: prompts,
|
||||
})); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packet, err := c.transport.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if packet[0] != msgUserAuthInfoResponse {
|
||||
return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
|
||||
}
|
||||
packet = packet[1:]
|
||||
|
||||
n, packet, ok := parseUint32(packet)
|
||||
if !ok || int(n) != len(questions) {
|
||||
return nil, parseError(msgUserAuthInfoResponse)
|
||||
}
|
||||
|
||||
for i := uint32(0); i < n; i++ {
|
||||
ans, rest, ok := parseString(packet)
|
||||
if !ok {
|
||||
return nil, parseError(msgUserAuthInfoResponse)
|
||||
}
|
||||
|
||||
answers = append(answers, string(ans))
|
||||
packet = rest
|
||||
}
|
||||
if len(packet) != 0 {
|
||||
return nil, errors.New("ssh: junk at end of message")
|
||||
}
|
||||
|
||||
return answers, nil
|
||||
}
|
||||
647
vendor/github.com/tailscale/golang-x-crypto/ssh/session.go
generated
vendored
Normal file
647
vendor/github.com/tailscale/golang-x-crypto/ssh/session.go
generated
vendored
Normal file
@@ -0,0 +1,647 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
// Session implements an interactive session described in
|
||||
// "RFC 4254, section 6".
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Signal string
|
||||
|
||||
// POSIX signals as listed in RFC 4254 Section 6.10.
|
||||
const (
|
||||
SIGABRT Signal = "ABRT"
|
||||
SIGALRM Signal = "ALRM"
|
||||
SIGFPE Signal = "FPE"
|
||||
SIGHUP Signal = "HUP"
|
||||
SIGILL Signal = "ILL"
|
||||
SIGINT Signal = "INT"
|
||||
SIGKILL Signal = "KILL"
|
||||
SIGPIPE Signal = "PIPE"
|
||||
SIGQUIT Signal = "QUIT"
|
||||
SIGSEGV Signal = "SEGV"
|
||||
SIGTERM Signal = "TERM"
|
||||
SIGUSR1 Signal = "USR1"
|
||||
SIGUSR2 Signal = "USR2"
|
||||
)
|
||||
|
||||
var signals = map[Signal]int{
|
||||
SIGABRT: 6,
|
||||
SIGALRM: 14,
|
||||
SIGFPE: 8,
|
||||
SIGHUP: 1,
|
||||
SIGILL: 4,
|
||||
SIGINT: 2,
|
||||
SIGKILL: 9,
|
||||
SIGPIPE: 13,
|
||||
SIGQUIT: 3,
|
||||
SIGSEGV: 11,
|
||||
SIGTERM: 15,
|
||||
}
|
||||
|
||||
type TerminalModes map[uint8]uint32
|
||||
|
||||
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
|
||||
const (
|
||||
tty_OP_END = 0
|
||||
VINTR = 1
|
||||
VQUIT = 2
|
||||
VERASE = 3
|
||||
VKILL = 4
|
||||
VEOF = 5
|
||||
VEOL = 6
|
||||
VEOL2 = 7
|
||||
VSTART = 8
|
||||
VSTOP = 9
|
||||
VSUSP = 10
|
||||
VDSUSP = 11
|
||||
VREPRINT = 12
|
||||
VWERASE = 13
|
||||
VLNEXT = 14
|
||||
VFLUSH = 15
|
||||
VSWTCH = 16
|
||||
VSTATUS = 17
|
||||
VDISCARD = 18
|
||||
IGNPAR = 30
|
||||
PARMRK = 31
|
||||
INPCK = 32
|
||||
ISTRIP = 33
|
||||
INLCR = 34
|
||||
IGNCR = 35
|
||||
ICRNL = 36
|
||||
IUCLC = 37
|
||||
IXON = 38
|
||||
IXANY = 39
|
||||
IXOFF = 40
|
||||
IMAXBEL = 41
|
||||
IUTF8 = 42 // RFC 8160
|
||||
ISIG = 50
|
||||
ICANON = 51
|
||||
XCASE = 52
|
||||
ECHO = 53
|
||||
ECHOE = 54
|
||||
ECHOK = 55
|
||||
ECHONL = 56
|
||||
NOFLSH = 57
|
||||
TOSTOP = 58
|
||||
IEXTEN = 59
|
||||
ECHOCTL = 60
|
||||
ECHOKE = 61
|
||||
PENDIN = 62
|
||||
OPOST = 70
|
||||
OLCUC = 71
|
||||
ONLCR = 72
|
||||
OCRNL = 73
|
||||
ONOCR = 74
|
||||
ONLRET = 75
|
||||
CS7 = 90
|
||||
CS8 = 91
|
||||
PARENB = 92
|
||||
PARODD = 93
|
||||
TTY_OP_ISPEED = 128
|
||||
TTY_OP_OSPEED = 129
|
||||
)
|
||||
|
||||
// A Session represents a connection to a remote command or shell.
|
||||
type Session struct {
|
||||
// Stdin specifies the remote process's standard input.
|
||||
// If Stdin is nil, the remote process reads from an empty
|
||||
// bytes.Buffer.
|
||||
Stdin io.Reader
|
||||
|
||||
// Stdout and Stderr specify the remote process's standard
|
||||
// output and error.
|
||||
//
|
||||
// If either is nil, Run connects the corresponding file
|
||||
// descriptor to an instance of io.Discard. There is a
|
||||
// fixed amount of buffering that is shared for the two streams.
|
||||
// If either blocks it may eventually cause the remote
|
||||
// command to block.
|
||||
Stdout io.Writer
|
||||
Stderr io.Writer
|
||||
|
||||
ch Channel // the channel backing this session
|
||||
started bool // true once Start, Run or Shell is invoked.
|
||||
copyFuncs []func() error
|
||||
errors chan error // one send per copyFunc
|
||||
|
||||
// true if pipe method is active
|
||||
stdinpipe, stdoutpipe, stderrpipe bool
|
||||
|
||||
// stdinPipeWriter is non-nil if StdinPipe has not been called
|
||||
// and Stdin was specified by the user; it is the write end of
|
||||
// a pipe connecting Session.Stdin to the stdin channel.
|
||||
stdinPipeWriter io.WriteCloser
|
||||
|
||||
exitStatus chan error
|
||||
}
|
||||
|
||||
// SendRequest sends an out-of-band channel request on the SSH channel
|
||||
// underlying the session.
|
||||
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||
return s.ch.SendRequest(name, wantReply, payload)
|
||||
}
|
||||
|
||||
func (s *Session) Close() error {
|
||||
return s.ch.Close()
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.4.
|
||||
type setenvRequest struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Setenv sets an environment variable that will be applied to any
|
||||
// command executed by Shell or Run.
|
||||
func (s *Session) Setenv(name, value string) error {
|
||||
msg := setenvRequest{
|
||||
Name: name,
|
||||
Value: value,
|
||||
}
|
||||
ok, err := s.ch.SendRequest("env", true, Marshal(&msg))
|
||||
if err == nil && !ok {
|
||||
err = errors.New("ssh: setenv failed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.2.
|
||||
type ptyRequestMsg struct {
|
||||
Term string
|
||||
Columns uint32
|
||||
Rows uint32
|
||||
Width uint32
|
||||
Height uint32
|
||||
Modelist string
|
||||
}
|
||||
|
||||
// RequestPty requests the association of a pty with the session on the remote host.
|
||||
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error {
|
||||
var tm []byte
|
||||
for k, v := range termmodes {
|
||||
kv := struct {
|
||||
Key byte
|
||||
Val uint32
|
||||
}{k, v}
|
||||
|
||||
tm = append(tm, Marshal(&kv)...)
|
||||
}
|
||||
tm = append(tm, tty_OP_END)
|
||||
req := ptyRequestMsg{
|
||||
Term: term,
|
||||
Columns: uint32(w),
|
||||
Rows: uint32(h),
|
||||
Width: uint32(w * 8),
|
||||
Height: uint32(h * 8),
|
||||
Modelist: string(tm),
|
||||
}
|
||||
ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req))
|
||||
if err == nil && !ok {
|
||||
err = errors.New("ssh: pty-req failed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.5.
|
||||
type subsystemRequestMsg struct {
|
||||
Subsystem string
|
||||
}
|
||||
|
||||
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
|
||||
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
|
||||
func (s *Session) RequestSubsystem(subsystem string) error {
|
||||
msg := subsystemRequestMsg{
|
||||
Subsystem: subsystem,
|
||||
}
|
||||
ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg))
|
||||
if err == nil && !ok {
|
||||
err = errors.New("ssh: subsystem request failed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.7.
|
||||
type ptyWindowChangeMsg struct {
|
||||
Columns uint32
|
||||
Rows uint32
|
||||
Width uint32
|
||||
Height uint32
|
||||
}
|
||||
|
||||
// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns.
|
||||
func (s *Session) WindowChange(h, w int) error {
|
||||
req := ptyWindowChangeMsg{
|
||||
Columns: uint32(w),
|
||||
Rows: uint32(h),
|
||||
Width: uint32(w * 8),
|
||||
Height: uint32(h * 8),
|
||||
}
|
||||
_, err := s.ch.SendRequest("window-change", false, Marshal(&req))
|
||||
return err
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.9.
|
||||
type signalMsg struct {
|
||||
Signal string
|
||||
}
|
||||
|
||||
// Signal sends the given signal to the remote process.
|
||||
// sig is one of the SIG* constants.
|
||||
func (s *Session) Signal(sig Signal) error {
|
||||
msg := signalMsg{
|
||||
Signal: string(sig),
|
||||
}
|
||||
|
||||
_, err := s.ch.SendRequest("signal", false, Marshal(&msg))
|
||||
return err
|
||||
}
|
||||
|
||||
// RFC 4254 Section 6.5.
|
||||
type execMsg struct {
|
||||
Command string
|
||||
}
|
||||
|
||||
// Start runs cmd on the remote host. Typically, the remote
|
||||
// server passes cmd to the shell for interpretation.
|
||||
// A Session only accepts one call to Run, Start or Shell.
|
||||
func (s *Session) Start(cmd string) error {
|
||||
if s.started {
|
||||
return errors.New("ssh: session already started")
|
||||
}
|
||||
req := execMsg{
|
||||
Command: cmd,
|
||||
}
|
||||
|
||||
ok, err := s.ch.SendRequest("exec", true, Marshal(&req))
|
||||
if err == nil && !ok {
|
||||
err = fmt.Errorf("ssh: command %v failed", cmd)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.start()
|
||||
}
|
||||
|
||||
// Run runs cmd on the remote host. Typically, the remote
|
||||
// server passes cmd to the shell for interpretation.
|
||||
// A Session only accepts one call to Run, Start, Shell, Output,
|
||||
// or CombinedOutput.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the remote server does not send an exit status, an error of type
|
||||
// *ExitMissingError is returned. If the command completes
|
||||
// unsuccessfully or is interrupted by a signal, the error is of type
|
||||
// *ExitError. Other error types may be returned for I/O problems.
|
||||
func (s *Session) Run(cmd string) error {
|
||||
err := s.Start(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Wait()
|
||||
}
|
||||
|
||||
// Output runs cmd on the remote host and returns its standard output.
|
||||
func (s *Session) Output(cmd string) ([]byte, error) {
|
||||
if s.Stdout != nil {
|
||||
return nil, errors.New("ssh: Stdout already set")
|
||||
}
|
||||
var b bytes.Buffer
|
||||
s.Stdout = &b
|
||||
err := s.Run(cmd)
|
||||
return b.Bytes(), err
|
||||
}
|
||||
|
||||
type singleWriter struct {
|
||||
b bytes.Buffer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (w *singleWriter) Write(p []byte) (int, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.b.Write(p)
|
||||
}
|
||||
|
||||
// CombinedOutput runs cmd on the remote host and returns its combined
|
||||
// standard output and standard error.
|
||||
func (s *Session) CombinedOutput(cmd string) ([]byte, error) {
|
||||
if s.Stdout != nil {
|
||||
return nil, errors.New("ssh: Stdout already set")
|
||||
}
|
||||
if s.Stderr != nil {
|
||||
return nil, errors.New("ssh: Stderr already set")
|
||||
}
|
||||
var b singleWriter
|
||||
s.Stdout = &b
|
||||
s.Stderr = &b
|
||||
err := s.Run(cmd)
|
||||
return b.b.Bytes(), err
|
||||
}
|
||||
|
||||
// Shell starts a login shell on the remote host. A Session only
|
||||
// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
|
||||
func (s *Session) Shell() error {
|
||||
if s.started {
|
||||
return errors.New("ssh: session already started")
|
||||
}
|
||||
|
||||
ok, err := s.ch.SendRequest("shell", true, nil)
|
||||
if err == nil && !ok {
|
||||
return errors.New("ssh: could not start shell")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.start()
|
||||
}
|
||||
|
||||
func (s *Session) start() error {
|
||||
s.started = true
|
||||
|
||||
type F func(*Session)
|
||||
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} {
|
||||
setupFd(s)
|
||||
}
|
||||
|
||||
s.errors = make(chan error, len(s.copyFuncs))
|
||||
for _, fn := range s.copyFuncs {
|
||||
go func(fn func() error) {
|
||||
s.errors <- fn()
|
||||
}(fn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait waits for the remote command to exit.
|
||||
//
|
||||
// The returned error is nil if the command runs, has no problems
|
||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
||||
// status.
|
||||
//
|
||||
// If the remote server does not send an exit status, an error of type
|
||||
// *ExitMissingError is returned. If the command completes
|
||||
// unsuccessfully or is interrupted by a signal, the error is of type
|
||||
// *ExitError. Other error types may be returned for I/O problems.
|
||||
func (s *Session) Wait() error {
|
||||
if !s.started {
|
||||
return errors.New("ssh: session not started")
|
||||
}
|
||||
waitErr := <-s.exitStatus
|
||||
|
||||
if s.stdinPipeWriter != nil {
|
||||
s.stdinPipeWriter.Close()
|
||||
}
|
||||
var copyError error
|
||||
for range s.copyFuncs {
|
||||
if err := <-s.errors; err != nil && copyError == nil {
|
||||
copyError = err
|
||||
}
|
||||
}
|
||||
if waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
return copyError
|
||||
}
|
||||
|
||||
func (s *Session) wait(reqs <-chan *Request) error {
|
||||
wm := Waitmsg{status: -1}
|
||||
// Wait for msg channel to be closed before returning.
|
||||
for msg := range reqs {
|
||||
switch msg.Type {
|
||||
case "exit-status":
|
||||
wm.status = int(binary.BigEndian.Uint32(msg.Payload))
|
||||
case "exit-signal":
|
||||
var sigval struct {
|
||||
Signal string
|
||||
CoreDumped bool
|
||||
Error string
|
||||
Lang string
|
||||
}
|
||||
if err := Unmarshal(msg.Payload, &sigval); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Must sanitize strings?
|
||||
wm.signal = sigval.Signal
|
||||
wm.msg = sigval.Error
|
||||
wm.lang = sigval.Lang
|
||||
default:
|
||||
// This handles keepalives and matches
|
||||
// OpenSSH's behaviour.
|
||||
if msg.WantReply {
|
||||
msg.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
if wm.status == 0 {
|
||||
return nil
|
||||
}
|
||||
if wm.status == -1 {
|
||||
// exit-status was never sent from server
|
||||
if wm.signal == "" {
|
||||
// signal was not sent either. RFC 4254
|
||||
// section 6.10 recommends against this
|
||||
// behavior, but it is allowed, so we let
|
||||
// clients handle it.
|
||||
return &ExitMissingError{}
|
||||
}
|
||||
wm.status = 128
|
||||
if _, ok := signals[Signal(wm.signal)]; ok {
|
||||
wm.status += signals[Signal(wm.signal)]
|
||||
}
|
||||
}
|
||||
|
||||
return &ExitError{wm}
|
||||
}
|
||||
|
||||
// ExitMissingError is returned if a session is torn down cleanly, but
|
||||
// the server sends no confirmation of the exit status.
|
||||
type ExitMissingError struct{}
|
||||
|
||||
func (e *ExitMissingError) Error() string {
|
||||
return "wait: remote command exited without exit status or exit signal"
|
||||
}
|
||||
|
||||
func (s *Session) stdin() {
|
||||
if s.stdinpipe {
|
||||
return
|
||||
}
|
||||
var stdin io.Reader
|
||||
if s.Stdin == nil {
|
||||
stdin = new(bytes.Buffer)
|
||||
} else {
|
||||
r, w := io.Pipe()
|
||||
go func() {
|
||||
_, err := io.Copy(w, s.Stdin)
|
||||
w.CloseWithError(err)
|
||||
}()
|
||||
stdin, s.stdinPipeWriter = r, w
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(s.ch, stdin)
|
||||
if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
|
||||
err = err1
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) stdout() {
|
||||
if s.stdoutpipe {
|
||||
return
|
||||
}
|
||||
if s.Stdout == nil {
|
||||
s.Stdout = io.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(s.Stdout, s.ch)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) stderr() {
|
||||
if s.stderrpipe {
|
||||
return
|
||||
}
|
||||
if s.Stderr == nil {
|
||||
s.Stderr = io.Discard
|
||||
}
|
||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
||||
_, err := io.Copy(s.Stderr, s.ch.Stderr())
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// sessionStdin reroutes Close to CloseWrite.
|
||||
type sessionStdin struct {
|
||||
io.Writer
|
||||
ch Channel
|
||||
}
|
||||
|
||||
func (s *sessionStdin) Close() error {
|
||||
return s.ch.CloseWrite()
|
||||
}
|
||||
|
||||
// StdinPipe returns a pipe that will be connected to the
|
||||
// remote command's standard input when the command starts.
|
||||
func (s *Session) StdinPipe() (io.WriteCloser, error) {
|
||||
if s.Stdin != nil {
|
||||
return nil, errors.New("ssh: Stdin already set")
|
||||
}
|
||||
if s.started {
|
||||
return nil, errors.New("ssh: StdinPipe after process started")
|
||||
}
|
||||
s.stdinpipe = true
|
||||
return &sessionStdin{s.ch, s.ch}, nil
|
||||
}
|
||||
|
||||
// StdoutPipe returns a pipe that will be connected to the
|
||||
// remote command's standard output when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StdoutPipe reader is
|
||||
// not serviced fast enough it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StdoutPipe() (io.Reader, error) {
|
||||
if s.Stdout != nil {
|
||||
return nil, errors.New("ssh: Stdout already set")
|
||||
}
|
||||
if s.started {
|
||||
return nil, errors.New("ssh: StdoutPipe after process started")
|
||||
}
|
||||
s.stdoutpipe = true
|
||||
return s.ch, nil
|
||||
}
|
||||
|
||||
// StderrPipe returns a pipe that will be connected to the
|
||||
// remote command's standard error when the command starts.
|
||||
// There is a fixed amount of buffering that is shared between
|
||||
// stdout and stderr streams. If the StderrPipe reader is
|
||||
// not serviced fast enough it may eventually cause the
|
||||
// remote command to block.
|
||||
func (s *Session) StderrPipe() (io.Reader, error) {
|
||||
if s.Stderr != nil {
|
||||
return nil, errors.New("ssh: Stderr already set")
|
||||
}
|
||||
if s.started {
|
||||
return nil, errors.New("ssh: StderrPipe after process started")
|
||||
}
|
||||
s.stderrpipe = true
|
||||
return s.ch.Stderr(), nil
|
||||
}
|
||||
|
||||
// newSession returns a new interactive session on the remote host.
|
||||
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
|
||||
s := &Session{
|
||||
ch: ch,
|
||||
}
|
||||
s.exitStatus = make(chan error, 1)
|
||||
go func() {
|
||||
s.exitStatus <- s.wait(reqs)
|
||||
}()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// An ExitError reports unsuccessful completion of a remote command.
|
||||
type ExitError struct {
|
||||
Waitmsg
|
||||
}
|
||||
|
||||
func (e *ExitError) Error() string {
|
||||
return e.Waitmsg.String()
|
||||
}
|
||||
|
||||
// Waitmsg stores the information about an exited remote command
|
||||
// as reported by Wait.
|
||||
type Waitmsg struct {
|
||||
status int
|
||||
signal string
|
||||
msg string
|
||||
lang string
|
||||
}
|
||||
|
||||
// ExitStatus returns the exit status of the remote command.
|
||||
func (w Waitmsg) ExitStatus() int {
|
||||
return w.status
|
||||
}
|
||||
|
||||
// Signal returns the exit signal of the remote command if
|
||||
// it was terminated violently.
|
||||
func (w Waitmsg) Signal() string {
|
||||
return w.signal
|
||||
}
|
||||
|
||||
// Msg returns the exit message given by the remote command
|
||||
func (w Waitmsg) Msg() string {
|
||||
return w.msg
|
||||
}
|
||||
|
||||
// Lang returns the language tag. See RFC 3066
|
||||
func (w Waitmsg) Lang() string {
|
||||
return w.lang
|
||||
}
|
||||
|
||||
func (w Waitmsg) String() string {
|
||||
str := fmt.Sprintf("Process exited with status %v", w.status)
|
||||
if w.signal != "" {
|
||||
str += fmt.Sprintf(" from signal %v", w.signal)
|
||||
}
|
||||
if w.msg != "" {
|
||||
str += fmt.Sprintf(". Reason was: %v", w.msg)
|
||||
}
|
||||
return str
|
||||
}
|
||||
139
vendor/github.com/tailscale/golang-x-crypto/ssh/ssh_gss.go
generated
vendored
Normal file
139
vendor/github.com/tailscale/golang-x-crypto/ssh/ssh_gss.go
generated
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var krb5OID []byte
|
||||
|
||||
func init() {
|
||||
krb5OID, _ = asn1.Marshal(krb5Mesh)
|
||||
}
|
||||
|
||||
// GSSAPIClient provides the API to plug-in GSSAPI authentication for client logins.
|
||||
type GSSAPIClient interface {
|
||||
// InitSecContext initiates the establishment of a security context for GSS-API between the
|
||||
// ssh client and ssh server. Initially the token parameter should be specified as nil.
|
||||
// The routine may return a outputToken which should be transferred to
|
||||
// the ssh server, where the ssh server will present it to
|
||||
// AcceptSecContext. If no token need be sent, InitSecContext will indicate this by setting
|
||||
// needContinue to false. To complete the context
|
||||
// establishment, one or more reply tokens may be required from the ssh
|
||||
// server;if so, InitSecContext will return a needContinue which is true.
|
||||
// In this case, InitSecContext should be called again when the
|
||||
// reply token is received from the ssh server, passing the reply
|
||||
// token to InitSecContext via the token parameters.
|
||||
// See RFC 2743 section 2.2.1 and RFC 4462 section 3.4.
|
||||
InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error)
|
||||
// GetMIC generates a cryptographic MIC for the SSH2 message, and places
|
||||
// the MIC in a token for transfer to the ssh server.
|
||||
// The contents of the MIC field are obtained by calling GSS_GetMIC()
|
||||
// over the following, using the GSS-API context that was just
|
||||
// established:
|
||||
// string session identifier
|
||||
// byte SSH_MSG_USERAUTH_REQUEST
|
||||
// string user name
|
||||
// string service
|
||||
// string "gssapi-with-mic"
|
||||
// See RFC 2743 section 2.3.1 and RFC 4462 3.5.
|
||||
GetMIC(micFiled []byte) ([]byte, error)
|
||||
// Whenever possible, it should be possible for
|
||||
// DeleteSecContext() calls to be successfully processed even
|
||||
// if other calls cannot succeed, thereby enabling context-related
|
||||
// resources to be released.
|
||||
// In addition to deleting established security contexts,
|
||||
// gss_delete_sec_context must also be able to delete "half-built"
|
||||
// security contexts resulting from an incomplete sequence of
|
||||
// InitSecContext()/AcceptSecContext() calls.
|
||||
// See RFC 2743 section 2.2.3.
|
||||
DeleteSecContext() error
|
||||
}
|
||||
|
||||
// GSSAPIServer provides the API to plug in GSSAPI authentication for server logins.
|
||||
type GSSAPIServer interface {
|
||||
// AcceptSecContext allows a remotely initiated security context between the application
|
||||
// and a remote peer to be established by the ssh client. The routine may return a
|
||||
// outputToken which should be transferred to the ssh client,
|
||||
// where the ssh client will present it to InitSecContext.
|
||||
// If no token need be sent, AcceptSecContext will indicate this
|
||||
// by setting the needContinue to false. To
|
||||
// complete the context establishment, one or more reply tokens may be
|
||||
// required from the ssh client. if so, AcceptSecContext
|
||||
// will return a needContinue which is true, in which case it
|
||||
// should be called again when the reply token is received from the ssh
|
||||
// client, passing the token to AcceptSecContext via the
|
||||
// token parameters.
|
||||
// The srcName return value is the authenticated username.
|
||||
// See RFC 2743 section 2.2.2 and RFC 4462 section 3.4.
|
||||
AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error)
|
||||
// VerifyMIC verifies that a cryptographic MIC, contained in the token parameter,
|
||||
// fits the supplied message is received from the ssh client.
|
||||
// See RFC 2743 section 2.3.2.
|
||||
VerifyMIC(micField []byte, micToken []byte) error
|
||||
// Whenever possible, it should be possible for
|
||||
// DeleteSecContext() calls to be successfully processed even
|
||||
// if other calls cannot succeed, thereby enabling context-related
|
||||
// resources to be released.
|
||||
// In addition to deleting established security contexts,
|
||||
// gss_delete_sec_context must also be able to delete "half-built"
|
||||
// security contexts resulting from an incomplete sequence of
|
||||
// InitSecContext()/AcceptSecContext() calls.
|
||||
// See RFC 2743 section 2.2.3.
|
||||
DeleteSecContext() error
|
||||
}
|
||||
|
||||
var (
|
||||
// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,
|
||||
// so we also support the krb5 mechanism only.
|
||||
// See RFC 1964 section 1.
|
||||
krb5Mesh = asn1.ObjectIdentifier{1, 2, 840, 113554, 1, 2, 2}
|
||||
)
|
||||
|
||||
// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST
|
||||
// See RFC 4462 section 3.2.
|
||||
type userAuthRequestGSSAPI struct {
|
||||
N uint32
|
||||
OIDS []asn1.ObjectIdentifier
|
||||
}
|
||||
|
||||
func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) {
|
||||
n, rest, ok := parseUint32(payload)
|
||||
if !ok {
|
||||
return nil, errors.New("parse uint32 failed")
|
||||
}
|
||||
s := &userAuthRequestGSSAPI{
|
||||
N: n,
|
||||
OIDS: make([]asn1.ObjectIdentifier, n),
|
||||
}
|
||||
for i := 0; i < int(n); i++ {
|
||||
var (
|
||||
desiredMech []byte
|
||||
err error
|
||||
)
|
||||
desiredMech, rest, ok = parseString(rest)
|
||||
if !ok {
|
||||
return nil, errors.New("parse string failed")
|
||||
}
|
||||
if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// See RFC 4462 section 3.6.
|
||||
func buildMIC(sessionID string, username string, service string, authMethod string) []byte {
|
||||
out := make([]byte, 0, 0)
|
||||
out = appendString(out, sessionID)
|
||||
out = append(out, msgUserAuthRequest)
|
||||
out = appendString(out, username)
|
||||
out = appendString(out, service)
|
||||
out = appendString(out, authMethod)
|
||||
return out
|
||||
}
|
||||
116
vendor/github.com/tailscale/golang-x-crypto/ssh/streamlocal.go
generated
vendored
Normal file
116
vendor/github.com/tailscale/golang-x-crypto/ssh/streamlocal.go
generated
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
|
||||
// with "direct-streamlocal@openssh.com" string.
|
||||
//
|
||||
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
|
||||
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
|
||||
type streamLocalChannelOpenDirectMsg struct {
|
||||
socketPath string
|
||||
reserved0 string
|
||||
reserved1 uint32
|
||||
}
|
||||
|
||||
// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
|
||||
// with "forwarded-streamlocal@openssh.com" string.
|
||||
type forwardedStreamLocalPayload struct {
|
||||
SocketPath string
|
||||
Reserved0 string
|
||||
}
|
||||
|
||||
// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
|
||||
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
|
||||
type streamLocalChannelForwardMsg struct {
|
||||
socketPath string
|
||||
}
|
||||
|
||||
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
|
||||
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
|
||||
c.handleForwardsOnce.Do(c.handleForwards)
|
||||
m := streamLocalChannelForwardMsg{
|
||||
socketPath,
|
||||
}
|
||||
// send message
|
||||
ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
|
||||
}
|
||||
ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
|
||||
|
||||
return &unixListener{socketPath, c, ch}, nil
|
||||
}
|
||||
|
||||
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
|
||||
msg := streamLocalChannelOpenDirectMsg{
|
||||
socketPath: socketPath,
|
||||
}
|
||||
ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go DiscardRequests(in)
|
||||
return ch, err
|
||||
}
|
||||
|
||||
type unixListener struct {
|
||||
socketPath string
|
||||
|
||||
conn *Client
|
||||
in <-chan forward
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l *unixListener) Accept() (net.Conn, error) {
|
||||
s, ok := <-l.in
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
ch, incoming, err := s.newCh.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go DiscardRequests(incoming)
|
||||
|
||||
return &chanConn{
|
||||
Channel: ch,
|
||||
laddr: &net.UnixAddr{
|
||||
Name: l.socketPath,
|
||||
Net: "unix",
|
||||
},
|
||||
raddr: &net.UnixAddr{
|
||||
Name: "@",
|
||||
Net: "unix",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
func (l *unixListener) Close() error {
|
||||
// this also closes the listener.
|
||||
l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
|
||||
m := streamLocalChannelForwardMsg{
|
||||
l.socketPath,
|
||||
}
|
||||
ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
|
||||
if err == nil && !ok {
|
||||
err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *unixListener) Addr() net.Addr {
|
||||
return &net.UnixAddr{
|
||||
Name: l.socketPath,
|
||||
Net: "unix",
|
||||
}
|
||||
}
|
||||
509
vendor/github.com/tailscale/golang-x-crypto/ssh/tcpip.go
generated
vendored
Normal file
509
vendor/github.com/tailscale/golang-x-crypto/ssh/tcpip.go
generated
vendored
Normal file
@@ -0,0 +1,509 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Listen requests the remote peer open a listening socket on
|
||||
// addr. Incoming connections will be available by calling Accept on
|
||||
// the returned net.Listener. The listener must be serviced, or the
|
||||
// SSH connection may hang.
|
||||
// N must be "tcp", "tcp4", "tcp6", or "unix".
|
||||
func (c *Client) Listen(n, addr string) (net.Listener, error) {
|
||||
switch n {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
laddr, err := net.ResolveTCPAddr(n, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.ListenTCP(laddr)
|
||||
case "unix":
|
||||
return c.ListenUnix(addr)
|
||||
default:
|
||||
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
|
||||
}
|
||||
}
|
||||
|
||||
// Automatic port allocation is broken with OpenSSH before 6.0. See
|
||||
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In
|
||||
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
|
||||
// rather than the actual port number. This means you can never open
|
||||
// two different listeners with auto allocated ports. We work around
|
||||
// this by trying explicit ports until we succeed.
|
||||
|
||||
const openSSHPrefix = "OpenSSH_"
|
||||
|
||||
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
// isBrokenOpenSSHVersion returns true if the given version string
|
||||
// specifies a version of OpenSSH that is known to have a bug in port
|
||||
// forwarding.
|
||||
func isBrokenOpenSSHVersion(versionStr string) bool {
|
||||
i := strings.Index(versionStr, openSSHPrefix)
|
||||
if i < 0 {
|
||||
return false
|
||||
}
|
||||
i += len(openSSHPrefix)
|
||||
j := i
|
||||
for ; j < len(versionStr); j++ {
|
||||
if versionStr[j] < '0' || versionStr[j] > '9' {
|
||||
break
|
||||
}
|
||||
}
|
||||
version, _ := strconv.Atoi(versionStr[i:j])
|
||||
return version < 6
|
||||
}
|
||||
|
||||
// autoPortListenWorkaround simulates automatic port allocation by
|
||||
// trying random ports repeatedly.
|
||||
func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
|
||||
var sshListener net.Listener
|
||||
var err error
|
||||
const tries = 10
|
||||
for i := 0; i < tries; i++ {
|
||||
addr := *laddr
|
||||
addr.Port = 1024 + portRandomizer.Intn(60000)
|
||||
sshListener, err = c.ListenTCP(&addr)
|
||||
if err == nil {
|
||||
laddr.Port = addr.Port
|
||||
return sshListener, err
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
|
||||
}
|
||||
|
||||
// RFC 4254 7.1
|
||||
type channelForwardMsg struct {
|
||||
addr string
|
||||
rport uint32
|
||||
}
|
||||
|
||||
// handleForwards starts goroutines handling forwarded connections.
|
||||
// It's called on first use by (*Client).ListenTCP to not launch
|
||||
// goroutines until needed.
|
||||
func (c *Client) handleForwards() {
|
||||
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
|
||||
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
|
||||
}
|
||||
|
||||
// ListenTCP requests the remote peer open a listening socket
|
||||
// on laddr. Incoming connections will be available by calling
|
||||
// Accept on the returned net.Listener.
|
||||
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
|
||||
c.handleForwardsOnce.Do(c.handleForwards)
|
||||
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
|
||||
return c.autoPortListenWorkaround(laddr)
|
||||
}
|
||||
|
||||
m := channelForwardMsg{
|
||||
laddr.IP.String(),
|
||||
uint32(laddr.Port),
|
||||
}
|
||||
// send message
|
||||
ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("ssh: tcpip-forward request denied by peer")
|
||||
}
|
||||
|
||||
// If the original port was 0, then the remote side will
|
||||
// supply a real port number in the response.
|
||||
if laddr.Port == 0 {
|
||||
var p struct {
|
||||
Port uint32
|
||||
}
|
||||
if err := Unmarshal(resp, &p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
laddr.Port = int(p.Port)
|
||||
}
|
||||
|
||||
// Register this forward, using the port number we obtained.
|
||||
ch := c.forwards.add(laddr)
|
||||
|
||||
return &tcpListener{laddr, c, ch}, nil
|
||||
}
|
||||
|
||||
// forwardList stores a mapping between remote
|
||||
// forward requests and the tcpListeners.
|
||||
type forwardList struct {
|
||||
sync.Mutex
|
||||
entries []forwardEntry
|
||||
}
|
||||
|
||||
// forwardEntry represents an established mapping of a laddr on a
|
||||
// remote ssh server to a channel connected to a tcpListener.
|
||||
type forwardEntry struct {
|
||||
laddr net.Addr
|
||||
c chan forward
|
||||
}
|
||||
|
||||
// forward represents an incoming forwarded tcpip connection. The
|
||||
// arguments to add/remove/lookup should be address as specified in
|
||||
// the original forward-request.
|
||||
type forward struct {
|
||||
newCh NewChannel // the ssh client channel underlying this forward
|
||||
raddr net.Addr // the raddr of the incoming connection
|
||||
}
|
||||
|
||||
func (l *forwardList) add(addr net.Addr) chan forward {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
f := forwardEntry{
|
||||
laddr: addr,
|
||||
c: make(chan forward, 1),
|
||||
}
|
||||
l.entries = append(l.entries, f)
|
||||
return f.c
|
||||
}
|
||||
|
||||
// See RFC 4254, section 7.2
|
||||
type forwardedTCPPayload struct {
|
||||
Addr string
|
||||
Port uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
|
||||
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
|
||||
if port == 0 || port > 65535 {
|
||||
return nil, fmt.Errorf("ssh: port number out of range: %d", port)
|
||||
}
|
||||
ip := net.ParseIP(string(addr))
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
|
||||
}
|
||||
return &net.TCPAddr{IP: ip, Port: int(port)}, nil
|
||||
}
|
||||
|
||||
func (l *forwardList) handleChannels(in <-chan NewChannel) {
|
||||
for ch := range in {
|
||||
var (
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
err error
|
||||
)
|
||||
switch channelType := ch.ChannelType(); channelType {
|
||||
case "forwarded-tcpip":
|
||||
var payload forwardedTCPPayload
|
||||
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
|
||||
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// RFC 4254 section 7.2 specifies that incoming
|
||||
// addresses should list the address, in string
|
||||
// format. It is implied that this should be an IP
|
||||
// address, as it would be impossible to connect to it
|
||||
// otherwise.
|
||||
laddr, err = parseTCPAddr(payload.Addr, payload.Port)
|
||||
if err != nil {
|
||||
ch.Reject(ConnectionFailed, err.Error())
|
||||
continue
|
||||
}
|
||||
raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
|
||||
if err != nil {
|
||||
ch.Reject(ConnectionFailed, err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
case "forwarded-streamlocal@openssh.com":
|
||||
var payload forwardedStreamLocalPayload
|
||||
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
|
||||
ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
|
||||
continue
|
||||
}
|
||||
laddr = &net.UnixAddr{
|
||||
Name: payload.SocketPath,
|
||||
Net: "unix",
|
||||
}
|
||||
raddr = &net.UnixAddr{
|
||||
Name: "@",
|
||||
Net: "unix",
|
||||
}
|
||||
default:
|
||||
panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
|
||||
}
|
||||
if ok := l.forward(laddr, raddr, ch); !ok {
|
||||
// Section 7.2, implementations MUST reject spurious incoming
|
||||
// connections.
|
||||
ch.Reject(Prohibited, "no forward for address")
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// remove removes the forward entry, and the channel feeding its
|
||||
// listener.
|
||||
func (l *forwardList) remove(addr net.Addr) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
for i, f := range l.entries {
|
||||
if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
|
||||
l.entries = append(l.entries[:i], l.entries[i+1:]...)
|
||||
close(f.c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeAll closes and clears all forwards.
|
||||
func (l *forwardList) closeAll() {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
for _, f := range l.entries {
|
||||
close(f.c)
|
||||
}
|
||||
l.entries = nil
|
||||
}
|
||||
|
||||
func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
for _, f := range l.entries {
|
||||
if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
|
||||
f.c <- forward{newCh: ch, raddr: raddr}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type tcpListener struct {
|
||||
laddr *net.TCPAddr
|
||||
|
||||
conn *Client
|
||||
in <-chan forward
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l *tcpListener) Accept() (net.Conn, error) {
|
||||
s, ok := <-l.in
|
||||
if !ok {
|
||||
return nil, io.EOF
|
||||
}
|
||||
ch, incoming, err := s.newCh.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go DiscardRequests(incoming)
|
||||
|
||||
return &chanConn{
|
||||
Channel: ch,
|
||||
laddr: l.laddr,
|
||||
raddr: s.raddr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
func (l *tcpListener) Close() error {
|
||||
m := channelForwardMsg{
|
||||
l.laddr.IP.String(),
|
||||
uint32(l.laddr.Port),
|
||||
}
|
||||
|
||||
// this also closes the listener.
|
||||
l.conn.forwards.remove(l.laddr)
|
||||
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
|
||||
if err == nil && !ok {
|
||||
err = errors.New("ssh: cancel-tcpip-forward failed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *tcpListener) Addr() net.Addr {
|
||||
return l.laddr
|
||||
}
|
||||
|
||||
// DialContext initiates a connection to the addr from the remote host.
|
||||
//
|
||||
// The provided Context must be non-nil. If the context expires before the
|
||||
// connection is complete, an error is returned. Once successfully connected,
|
||||
// any expiration of the context will not affect the connection.
|
||||
//
|
||||
// See func Dial for additional information.
|
||||
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
type connErr struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
ch := make(chan connErr)
|
||||
go func() {
|
||||
conn, err := c.Dial(n, addr)
|
||||
select {
|
||||
case ch <- connErr{conn, err}:
|
||||
case <-ctx.Done():
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case res := <-ch:
|
||||
return res.conn, res.err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Dial initiates a connection to the addr from the remote host.
|
||||
// The resulting connection has a zero LocalAddr() and RemoteAddr().
|
||||
func (c *Client) Dial(n, addr string) (net.Conn, error) {
|
||||
var ch Channel
|
||||
switch n {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
// Parse the address into host and numeric port.
|
||||
host, portString, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use a zero address for local and remote address.
|
||||
zeroAddr := &net.TCPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: 0,
|
||||
}
|
||||
return &chanConn{
|
||||
Channel: ch,
|
||||
laddr: zeroAddr,
|
||||
raddr: zeroAddr,
|
||||
}, nil
|
||||
case "unix":
|
||||
var err error
|
||||
ch, err = c.dialStreamLocal(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &chanConn{
|
||||
Channel: ch,
|
||||
laddr: &net.UnixAddr{
|
||||
Name: "@",
|
||||
Net: "unix",
|
||||
},
|
||||
raddr: &net.UnixAddr{
|
||||
Name: addr,
|
||||
Net: "unix",
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
|
||||
}
|
||||
}
|
||||
|
||||
// DialTCP connects to the remote address raddr on the network net,
|
||||
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
|
||||
// as the local address for the connection.
|
||||
func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
|
||||
if laddr == nil {
|
||||
laddr = &net.TCPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &chanConn{
|
||||
Channel: ch,
|
||||
laddr: laddr,
|
||||
raddr: raddr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RFC 4254 7.2
|
||||
type channelOpenDirectMsg struct {
|
||||
raddr string
|
||||
rport uint32
|
||||
laddr string
|
||||
lport uint32
|
||||
}
|
||||
|
||||
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
|
||||
msg := channelOpenDirectMsg{
|
||||
raddr: raddr,
|
||||
rport: uint32(rport),
|
||||
laddr: laddr,
|
||||
lport: uint32(lport),
|
||||
}
|
||||
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go DiscardRequests(in)
|
||||
return ch, err
|
||||
}
|
||||
|
||||
type tcpChan struct {
|
||||
Channel // the backing channel
|
||||
}
|
||||
|
||||
// chanConn fulfills the net.Conn interface without
|
||||
// the tcpChan having to hold laddr or raddr directly.
|
||||
type chanConn struct {
|
||||
Channel
|
||||
laddr, raddr net.Addr
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (t *chanConn) LocalAddr() net.Addr {
|
||||
return t.laddr
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (t *chanConn) RemoteAddr() net.Addr {
|
||||
return t.raddr
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines associated
|
||||
// with the connection.
|
||||
func (t *chanConn) SetDeadline(deadline time.Time) error {
|
||||
if err := t.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return t.SetWriteDeadline(deadline)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline.
|
||||
// A zero value for t means Read will not time out.
|
||||
// After the deadline, the error from Read will implement net.Error
|
||||
// with Timeout() == true.
|
||||
func (t *chanConn) SetReadDeadline(deadline time.Time) error {
|
||||
// for compatibility with previous version,
|
||||
// the error message contains "tcpChan"
|
||||
return errors.New("ssh: tcpChan: deadline not supported")
|
||||
}
|
||||
|
||||
// SetWriteDeadline exists to satisfy the net.Conn interface
|
||||
// but is not implemented by this type. It always returns an error.
|
||||
func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
|
||||
return errors.New("ssh: tcpChan: deadline not supported")
|
||||
}
|
||||
380
vendor/github.com/tailscale/golang-x-crypto/ssh/transport.go
generated
vendored
Normal file
380
vendor/github.com/tailscale/golang-x-crypto/ssh/transport.go
generated
vendored
Normal file
@@ -0,0 +1,380 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
)
|
||||
|
||||
// debugTransport if set, will print packet types as they go over the
|
||||
// wire. No message decoding is done, to minimize the impact on timing.
|
||||
const debugTransport = false
|
||||
|
||||
const (
|
||||
gcm128CipherID = "aes128-gcm@openssh.com"
|
||||
gcm256CipherID = "aes256-gcm@openssh.com"
|
||||
aes128cbcID = "aes128-cbc"
|
||||
tripledescbcID = "3des-cbc"
|
||||
)
|
||||
|
||||
// packetConn represents a transport that implements packet based
|
||||
// operations.
|
||||
type packetConn interface {
|
||||
// Encrypt and send a packet of data to the remote peer.
|
||||
writePacket(packet []byte) error
|
||||
|
||||
// Read a packet from the connection. The read is blocking,
|
||||
// i.e. if error is nil, then the returned byte slice is
|
||||
// always non-empty.
|
||||
readPacket() ([]byte, error)
|
||||
|
||||
// Close closes the write-side of the connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// transport is the keyingTransport that implements the SSH packet
|
||||
// protocol.
|
||||
type transport struct {
|
||||
reader connectionState
|
||||
writer connectionState
|
||||
|
||||
bufReader *bufio.Reader
|
||||
bufWriter *bufio.Writer
|
||||
rand io.Reader
|
||||
isClient bool
|
||||
io.Closer
|
||||
|
||||
strictMode bool
|
||||
initialKEXDone bool
|
||||
}
|
||||
|
||||
// packetCipher represents a combination of SSH encryption/MAC
|
||||
// protocol. A single instance should be used for one direction only.
|
||||
type packetCipher interface {
|
||||
// writeCipherPacket encrypts the packet and writes it to w. The
|
||||
// contents of the packet are generally scrambled.
|
||||
writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
|
||||
|
||||
// readCipherPacket reads and decrypts a packet of data. The
|
||||
// returned packet may be overwritten by future calls of
|
||||
// readPacket.
|
||||
readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
|
||||
}
|
||||
|
||||
// connectionState represents one side (read or write) of the
|
||||
// connection. This is necessary because each direction has its own
|
||||
// keys, and can even have its own algorithms
|
||||
type connectionState struct {
|
||||
packetCipher
|
||||
seqNum uint32
|
||||
dir direction
|
||||
pendingKeyChange chan packetCipher
|
||||
}
|
||||
|
||||
func (t *transport) setStrictMode() error {
|
||||
if t.reader.seqNum != 1 {
|
||||
return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
|
||||
}
|
||||
t.strictMode = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *transport) setInitialKEXDone() {
|
||||
t.initialKEXDone = true
|
||||
}
|
||||
|
||||
// prepareKeyChange sets up key material for a keychange. The key changes in
|
||||
// both directions are triggered by reading and writing a msgNewKey packet
|
||||
// respectively.
|
||||
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
|
||||
ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.reader.pendingKeyChange <- ciph
|
||||
|
||||
ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.writer.pendingKeyChange <- ciph
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *transport) printPacket(p []byte, write bool) {
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
who := "server"
|
||||
if t.isClient {
|
||||
who = "client"
|
||||
}
|
||||
what := "read"
|
||||
if write {
|
||||
what = "write"
|
||||
}
|
||||
|
||||
log.Println(what, who, p[0])
|
||||
}
|
||||
|
||||
// Read and decrypt next packet.
|
||||
func (t *transport) readPacket() (p []byte, err error) {
|
||||
for {
|
||||
p, err = t.reader.readPacket(t.bufReader, t.strictMode)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
|
||||
if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
|
||||
break
|
||||
}
|
||||
}
|
||||
if debugTransport {
|
||||
t.printPacket(p, false)
|
||||
}
|
||||
|
||||
return p, err
|
||||
}
|
||||
|
||||
func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
|
||||
packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
|
||||
s.seqNum++
|
||||
if err == nil && len(packet) == 0 {
|
||||
err = errors.New("ssh: zero length packet")
|
||||
}
|
||||
|
||||
if len(packet) > 0 {
|
||||
switch packet[0] {
|
||||
case msgNewKeys:
|
||||
select {
|
||||
case cipher := <-s.pendingKeyChange:
|
||||
s.packetCipher = cipher
|
||||
if strictMode {
|
||||
s.seqNum = 0
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("ssh: got bogus newkeys message")
|
||||
}
|
||||
|
||||
case msgDisconnect:
|
||||
// Transform a disconnect message into an
|
||||
// error. Since this is lowest level at which
|
||||
// we interpret message types, doing it here
|
||||
// ensures that we don't have to handle it
|
||||
// elsewhere.
|
||||
var msg disconnectMsg
|
||||
if err := Unmarshal(packet, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, &msg
|
||||
}
|
||||
}
|
||||
|
||||
// The packet may point to an internal buffer, so copy the
|
||||
// packet out here.
|
||||
fresh := make([]byte, len(packet))
|
||||
copy(fresh, packet)
|
||||
|
||||
return fresh, err
|
||||
}
|
||||
|
||||
func (t *transport) writePacket(packet []byte) error {
|
||||
if debugTransport {
|
||||
t.printPacket(packet, true)
|
||||
}
|
||||
return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
|
||||
}
|
||||
|
||||
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
|
||||
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
|
||||
|
||||
err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = w.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.seqNum++
|
||||
if changeKeys {
|
||||
select {
|
||||
case cipher := <-s.pendingKeyChange:
|
||||
s.packetCipher = cipher
|
||||
if strictMode {
|
||||
s.seqNum = 0
|
||||
}
|
||||
default:
|
||||
panic("ssh: no key material for msgNewKeys")
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
|
||||
t := &transport{
|
||||
bufReader: bufio.NewReader(rwc),
|
||||
bufWriter: bufio.NewWriter(rwc),
|
||||
rand: rand,
|
||||
reader: connectionState{
|
||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
|
||||
pendingKeyChange: make(chan packetCipher, 1),
|
||||
},
|
||||
writer: connectionState{
|
||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
|
||||
pendingKeyChange: make(chan packetCipher, 1),
|
||||
},
|
||||
Closer: rwc,
|
||||
}
|
||||
t.isClient = isClient
|
||||
|
||||
if isClient {
|
||||
t.reader.dir = serverKeys
|
||||
t.writer.dir = clientKeys
|
||||
} else {
|
||||
t.reader.dir = clientKeys
|
||||
t.writer.dir = serverKeys
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type direction struct {
|
||||
ivTag []byte
|
||||
keyTag []byte
|
||||
macKeyTag []byte
|
||||
}
|
||||
|
||||
var (
|
||||
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
|
||||
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
|
||||
)
|
||||
|
||||
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
|
||||
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
||||
// (to setup server->client keys) or clientKeys (for client->server keys).
|
||||
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
|
||||
cipherMode := cipherModes[algs.Cipher]
|
||||
|
||||
iv := make([]byte, cipherMode.ivSize)
|
||||
key := make([]byte, cipherMode.keySize)
|
||||
|
||||
generateKeyMaterial(iv, d.ivTag, kex)
|
||||
generateKeyMaterial(key, d.keyTag, kex)
|
||||
|
||||
var macKey []byte
|
||||
if !aeadCiphers[algs.Cipher] {
|
||||
macMode := macModes[algs.MAC]
|
||||
macKey = make([]byte, macMode.keySize)
|
||||
generateKeyMaterial(macKey, d.macKeyTag, kex)
|
||||
}
|
||||
|
||||
return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
|
||||
}
|
||||
|
||||
// generateKeyMaterial fills out with key material generated from tag, K, H
|
||||
// and sessionId, as specified in RFC 4253, section 7.2.
|
||||
func generateKeyMaterial(out, tag []byte, r *kexResult) {
|
||||
var digestsSoFar []byte
|
||||
|
||||
h := r.Hash.New()
|
||||
for len(out) > 0 {
|
||||
h.Reset()
|
||||
h.Write(r.K)
|
||||
h.Write(r.H)
|
||||
|
||||
if len(digestsSoFar) == 0 {
|
||||
h.Write(tag)
|
||||
h.Write(r.SessionID)
|
||||
} else {
|
||||
h.Write(digestsSoFar)
|
||||
}
|
||||
|
||||
digest := h.Sum(nil)
|
||||
n := copy(out, digest)
|
||||
out = out[n:]
|
||||
if len(out) > 0 {
|
||||
digestsSoFar = append(digestsSoFar, digest...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const packageVersion = "SSH-2.0-Go"
|
||||
|
||||
// Sends and receives a version line. The versionLine string should
|
||||
// be US ASCII, start with "SSH-2.0-", and should not include a
|
||||
// newline. exchangeVersions returns the other side's version line.
|
||||
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
|
||||
// Contrary to the RFC, we do not ignore lines that don't
|
||||
// start with "SSH-2.0-" to make the library usable with
|
||||
// nonconforming servers.
|
||||
for _, c := range versionLine {
|
||||
// The spec disallows non US-ASCII chars, and
|
||||
// specifically forbids null chars.
|
||||
if c < 32 {
|
||||
return nil, errors.New("ssh: junk character in version line")
|
||||
}
|
||||
}
|
||||
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
them, err = readVersion(rw)
|
||||
return them, err
|
||||
}
|
||||
|
||||
// maxVersionStringBytes is the maximum number of bytes that we'll
|
||||
// accept as a version string. RFC 4253 section 4.2 limits this at 255
|
||||
// chars
|
||||
const maxVersionStringBytes = 255
|
||||
|
||||
// Read version string as specified by RFC 4253, section 4.2.
|
||||
func readVersion(r io.Reader) ([]byte, error) {
|
||||
versionString := make([]byte, 0, 64)
|
||||
var ok bool
|
||||
var buf [1]byte
|
||||
|
||||
for length := 0; length < maxVersionStringBytes; length++ {
|
||||
_, err := io.ReadFull(r, buf[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The RFC says that the version should be terminated with \r\n
|
||||
// but several SSH servers actually only send a \n.
|
||||
if buf[0] == '\n' {
|
||||
if !bytes.HasPrefix(versionString, []byte("SSH-")) {
|
||||
// RFC 4253 says we need to ignore all version string lines
|
||||
// except the one containing the SSH version (provided that
|
||||
// all the lines do not exceed 255 bytes in total).
|
||||
versionString = versionString[:0]
|
||||
continue
|
||||
}
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
|
||||
// non ASCII chars are disallowed, but we are lenient,
|
||||
// since Go doesn't use null-terminated strings.
|
||||
|
||||
// The RFC allows a comment after a space, however,
|
||||
// all of it (version and comments) goes into the
|
||||
// session hash.
|
||||
versionString = append(versionString, buf[0])
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("ssh: overflow reading version string")
|
||||
}
|
||||
|
||||
// There might be a '\r' on the end which we should remove.
|
||||
if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
|
||||
versionString = versionString[:len(versionString)-1]
|
||||
}
|
||||
return versionString, nil
|
||||
}
|
||||
2
vendor/github.com/tailscale/goupnp/.gitignore
generated
vendored
Normal file
2
vendor/github.com/tailscale/goupnp/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.zip
|
||||
*.sublime-workspace
|
||||
133
vendor/github.com/tailscale/goupnp/GUIDE.md
generated
vendored
Normal file
133
vendor/github.com/tailscale/goupnp/GUIDE.md
generated
vendored
Normal file
@@ -0,0 +1,133 @@
|
||||
# Guide
|
||||
|
||||
This is a quick guide with example code for common use cases that might be
|
||||
helpful for people wanting a quick guide to common ways that people would
|
||||
want to use this library.
|
||||
|
||||
## Internet Gateways
|
||||
|
||||
`goupnp/dcps/internetgateway1` and `goupnp/dcps/internetgateway2` implement
|
||||
different version standards that allow clients to interact with devices like
|
||||
home consumer routers, but you can probably get by with just
|
||||
`internetgateway2`. Some very common use cases to talk to such devices are:
|
||||
|
||||
- Requesting the external Internet-facing IP address.
|
||||
- Requesting a port be forwarded from the external (Internet-facing) interface
|
||||
to a port on the LAN.
|
||||
|
||||
Different routers implement different standards, so you may have to request
|
||||
multiple clients to find the one that your router needs. The most useful ones
|
||||
for the purpose above can be requested with the following functions:
|
||||
|
||||
- `internetgateway2.NewWANIPConnection1Clients()`
|
||||
- `internetgateway2.NewWANIPConnection2Clients()`
|
||||
- `internetgateway2.NewWANPPPConnection1Clients()`
|
||||
|
||||
Fortunately, each of the clients returned by these functions provide the same
|
||||
method signatures for the purposes listed above. So you could request multiple
|
||||
clients, and whichever one you find, and return it from a function in a variable
|
||||
of the common interface, e.g:
|
||||
|
||||
```go
|
||||
type RouterClient interface {
|
||||
AddPortMapping(
|
||||
NewRemoteHost string,
|
||||
NewExternalPort uint16,
|
||||
NewProtocol string,
|
||||
NewInternalPort uint16,
|
||||
NewInternalClient string,
|
||||
NewEnabled bool,
|
||||
NewPortMappingDescription string,
|
||||
NewLeaseDuration uint32,
|
||||
) (err error)
|
||||
|
||||
GetExternalIPAddress() (
|
||||
NewExternalIPAddress string,
|
||||
err error,
|
||||
)
|
||||
}
|
||||
|
||||
func PickRouterClient(ctx context.Context) (RouterClient, error) {
|
||||
tasks, _ := errgroup.WithContext(ctx)
|
||||
// Request each type of client in parallel, and return what is found.
|
||||
var ip1Clients []*internetgateway2.WANIPConnection1
|
||||
tasks.Go(func() error {
|
||||
var err error
|
||||
ip1Clients, _, err = internetgateway2.NewWANIPConnection1Clients()
|
||||
return err
|
||||
})
|
||||
var ip2Clients []*internetgateway2.WANIPConnection2
|
||||
tasks.Go(func() error {
|
||||
var err error
|
||||
ip2Clients, _, err = internetgateway2.NewWANIPConnection2Clients()
|
||||
return err
|
||||
})
|
||||
var ppp1Clients []*internetgateway2.WANPPPConnection1
|
||||
tasks.Go(func() error {
|
||||
var err error
|
||||
ppp1Clients, _, err = internetgateway2.NewWANPPPConnection1Clients()
|
||||
return err
|
||||
})
|
||||
|
||||
if err := tasks.Wait(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Trivial handling for where we find exactly one device to talk to, you
|
||||
// might want to provide more flexible handling than this if multiple
|
||||
// devices are found.
|
||||
switch {
|
||||
case len(ip2Clients) == 1:
|
||||
return ip2Clients[0], nil
|
||||
case len(ip1Clients) == 1:
|
||||
return ip1Clients[0], nil
|
||||
case len(ppp1Clients) == 1:
|
||||
return ppp1Clients[0], nil
|
||||
default:
|
||||
return nil, errors.New("multiple or no services found")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
You could then use this function to create a client, and both request the
|
||||
external IP address and forward it to a port on your local network, e.g:
|
||||
|
||||
```go
|
||||
func GetIPAndForwardPort(ctx context.Context) error {
|
||||
client, err := PickRouterClient(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
externalIP, err := client.GetExternalIPAddress()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Our external IP address is: ", externalIP)
|
||||
|
||||
return client.AddPortMapping(
|
||||
"",
|
||||
// External port number to expose to Internet:
|
||||
1234,
|
||||
// Forward TCP (this could be "UDP" if we wanted that instead).
|
||||
"TCP",
|
||||
// Internal port number on the LAN to forward to.
|
||||
// Some routers might not support this being different to the external
|
||||
// port number.
|
||||
1234,
|
||||
// Internal address on the LAN we want to forward to.
|
||||
"192.168.1.6",
|
||||
// Enabled:
|
||||
true,
|
||||
// Informational description for the client requesting the port forwarding.
|
||||
"MyProgramName",
|
||||
// How long should the port forward last for in seconds.
|
||||
// If you want to keep it open for longer and potentially across router
|
||||
// resets, you might want to periodically request before this elapses.
|
||||
3600,
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
The code above is of course just a relatively trivial example that you can
|
||||
tailor to your own use case.
|
||||
23
vendor/github.com/tailscale/goupnp/LICENSE
generated
vendored
Normal file
23
vendor/github.com/tailscale/goupnp/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
Copyright (c) 2013, John Beisley <johnbeisleyuk@gmail.com>
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this
|
||||
list of conditions and the following disclaimer in the documentation and/or
|
||||
other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
68
vendor/github.com/tailscale/goupnp/README.md
generated
vendored
Normal file
68
vendor/github.com/tailscale/goupnp/README.md
generated
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
# tailscale/goupnp
|
||||
|
||||
This repo is a temporary fork of [huin/goupnp](https://github.com/huin/goupnp)(A UPnP client
|
||||
library for Go), customized for UPnP support in Tailscale.
|
||||
|
||||
## Installation
|
||||
|
||||
Run `go get -u github.com/tailscale/goupnp`.
|
||||
|
||||
## Documentation
|
||||
|
||||
See [GUIDE.md](GUIDE.md) for a quick start on the most common use case for this
|
||||
library.
|
||||
|
||||
Supported DCPs (you probably want to start with one of these):
|
||||
|
||||
- [ av1](https://godoc.org/github.com/tailscale/goupnp/dcps/av1) - Client for UPnP Device Control Protocol MediaServer v1 and MediaRenderer v1.
|
||||
- [ internetgateway1](https://godoc.org/github.com/tailscale/goupnp/dcps/internetgateway1) - Client for UPnP Device Control Protocol Internet Gateway Device v1.
|
||||
- [ internetgateway2](https://godoc.org/github.com/tailscale/goupnp/dcps/internetgateway2) - Client for UPnP Device Control Protocol Internet Gateway Device v2.
|
||||
|
||||
Core components:
|
||||
|
||||
- [ (goupnp)](https://godoc.org/github.com/tailscale/goupnp) core library - contains datastructures and utilities typically used by the implemented DCPs.
|
||||
- [ httpu](https://godoc.org/github.com/tailscale/goupnp/httpu) HTTPU implementation, underlies SSDP.
|
||||
- [ ssdp](https://godoc.org/github.com/tailscale/goupnp/ssdp) SSDP client implementation (simple service discovery protocol) - used to discover UPnP services on a network.
|
||||
- [ soap](https://godoc.org/github.com/tailscale/goupnp/soap) SOAP client implementation (simple object access protocol) - used to communicate with discovered services.
|
||||
|
||||
## Regenerating dcps generated source code:
|
||||
|
||||
1. Build code generator:
|
||||
|
||||
`go get -u github.com/tailscale/goupnp/cmd/goupnpdcpgen`
|
||||
|
||||
2. Regenerate the code:
|
||||
|
||||
`go generate ./...`
|
||||
|
||||
## Supporting additional UPnP devices and services:
|
||||
|
||||
Supporting additional services is, in the trivial case, simply a matter of
|
||||
adding the service to the `dcpMetadata` whitelist in `cmd/goupnpdcpgen/metadata.go`,
|
||||
regenerating the source code (see above), and committing that source code.
|
||||
|
||||
However, it would be helpful if anyone needing such a service could test the
|
||||
service against the service they have, and then reporting any trouble
|
||||
encountered as an [issue on this
|
||||
project](https://github.com/tailscale/goupnp/issues/new). If it just works, then
|
||||
please report at least minimal working functionality as an issue, and
|
||||
optionally contribute the metadata upstream.
|
||||
|
||||
## Migrating due to Breaking Changes
|
||||
|
||||
- \#40 introduced a breaking change to handling non-utf8 encodings, but removes a heavy
|
||||
dependency on `golang.org/x/net` with charset encodings. If this breaks your usage of this
|
||||
library, you can return to the old behavior by modifying the exported variable and importing
|
||||
the package yourself:
|
||||
|
||||
```go
|
||||
import (
|
||||
"golang.org/x/net/html/charset"
|
||||
"github.com/tailscale/goupnp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// should be modified before goupnp libraries are in use.
|
||||
goupnp.CharsetReaderFault = charset.NewReaderLabel
|
||||
}
|
||||
```
|
||||
2
vendor/github.com/tailscale/goupnp/dcps/internetgateway2/gen.go
generated
vendored
Normal file
2
vendor/github.com/tailscale/goupnp/dcps/internetgateway2/gen.go
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
//go:generate goupnpdcpgen -dcp_name internetgateway2
|
||||
package internetgateway2
|
||||
2351
vendor/github.com/tailscale/goupnp/dcps/internetgateway2/internetgateway2.go
generated
vendored
Normal file
2351
vendor/github.com/tailscale/goupnp/dcps/internetgateway2/internetgateway2.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
192
vendor/github.com/tailscale/goupnp/device.go
generated
vendored
Normal file
192
vendor/github.com/tailscale/goupnp/device.go
generated
vendored
Normal file
@@ -0,0 +1,192 @@
|
||||
// This file contains XML structures for communicating with UPnP devices.
|
||||
|
||||
package goupnp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/tailscale/goupnp/scpd"
|
||||
"github.com/tailscale/goupnp/soap"
|
||||
)
|
||||
|
||||
const (
|
||||
DeviceXMLNamespace = "urn:schemas-upnp-org:device-1-0"
|
||||
)
|
||||
|
||||
// RootDevice is the device description as described by section 2.3 "Device
|
||||
// description" in
|
||||
// http://upnp.org/specs/arch/UPnP-arch-DeviceArchitecture-v1.1.pdf
|
||||
type RootDevice struct {
|
||||
XMLName xml.Name `xml:"root"`
|
||||
SpecVersion SpecVersion `xml:"specVersion"`
|
||||
URLBase url.URL `xml:"-"`
|
||||
URLBaseStr string `xml:"URLBase"`
|
||||
Device Device `xml:"device"`
|
||||
}
|
||||
|
||||
// SetURLBase sets the URLBase for the RootDevice and its underlying components.
|
||||
func (root *RootDevice) SetURLBase(urlBase *url.URL) {
|
||||
root.URLBase = *urlBase
|
||||
root.URLBaseStr = urlBase.String()
|
||||
root.Device.SetURLBase(urlBase)
|
||||
}
|
||||
|
||||
// SpecVersion is part of a RootDevice, describes the version of the
|
||||
// specification that the data adheres to.
|
||||
type SpecVersion struct {
|
||||
Major int32 `xml:"major"`
|
||||
Minor int32 `xml:"minor"`
|
||||
}
|
||||
|
||||
// Device is a UPnP device. It can have child devices.
|
||||
type Device struct {
|
||||
DeviceType string `xml:"deviceType"`
|
||||
FriendlyName string `xml:"friendlyName"`
|
||||
Manufacturer string `xml:"manufacturer"`
|
||||
ManufacturerURL URLField `xml:"manufacturerURL"`
|
||||
ModelDescription string `xml:"modelDescription"`
|
||||
ModelName string `xml:"modelName"`
|
||||
ModelNumber string `xml:"modelNumber"`
|
||||
ModelURL URLField `xml:"modelURL"`
|
||||
SerialNumber string `xml:"serialNumber"`
|
||||
UDN string `xml:"UDN"`
|
||||
UPC string `xml:"UPC,omitempty"`
|
||||
Icons []Icon `xml:"iconList>icon,omitempty"`
|
||||
Services []Service `xml:"serviceList>service,omitempty"`
|
||||
Devices []Device `xml:"deviceList>device,omitempty"`
|
||||
|
||||
// Extra observed elements:
|
||||
PresentationURL URLField `xml:"presentationURL"`
|
||||
}
|
||||
|
||||
// VisitDevices calls visitor for the device, and all its descendent devices.
|
||||
func (device *Device) VisitDevices(visitor func(*Device)) {
|
||||
visitor(device)
|
||||
for i := range device.Devices {
|
||||
device.Devices[i].VisitDevices(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
// VisitServices calls visitor for all Services under the device and all its
|
||||
// descendent devices.
|
||||
func (device *Device) VisitServices(visitor func(*Service)) {
|
||||
device.VisitDevices(func(d *Device) {
|
||||
for i := range d.Services {
|
||||
visitor(&d.Services[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FindService finds all (if any) Services under the device and its descendents
|
||||
// that have the given ServiceType.
|
||||
func (device *Device) FindService(ctx context.Context, serviceType string) []*Service {
|
||||
var services []*Service
|
||||
device.VisitServices(func(s *Service) {
|
||||
if s.ServiceType == serviceType {
|
||||
services = append(services, s)
|
||||
}
|
||||
})
|
||||
return services
|
||||
}
|
||||
|
||||
// SetURLBase sets the URLBase for the Device and its underlying components.
|
||||
func (device *Device) SetURLBase(urlBase *url.URL) {
|
||||
device.ManufacturerURL.SetURLBase(urlBase)
|
||||
device.ModelURL.SetURLBase(urlBase)
|
||||
device.PresentationURL.SetURLBase(urlBase)
|
||||
for i := range device.Icons {
|
||||
device.Icons[i].SetURLBase(urlBase)
|
||||
}
|
||||
for i := range device.Services {
|
||||
device.Services[i].SetURLBase(urlBase)
|
||||
}
|
||||
for i := range device.Devices {
|
||||
device.Devices[i].SetURLBase(urlBase)
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) String() string {
|
||||
return fmt.Sprintf("Device ID %s : %s (%s)", device.UDN, device.DeviceType, device.FriendlyName)
|
||||
}
|
||||
|
||||
// Icon is a representative image that a device might include in its
|
||||
// description.
|
||||
type Icon struct {
|
||||
Mimetype string `xml:"mimetype"`
|
||||
Width int32 `xml:"width"`
|
||||
Height int32 `xml:"height"`
|
||||
Depth int32 `xml:"depth"`
|
||||
URL URLField `xml:"url"`
|
||||
}
|
||||
|
||||
// SetURLBase sets the URLBase for the Icon.
|
||||
func (icon *Icon) SetURLBase(url *url.URL) {
|
||||
icon.URL.SetURLBase(url)
|
||||
}
|
||||
|
||||
// Service is a service provided by a UPnP Device.
|
||||
type Service struct {
|
||||
ServiceType string `xml:"serviceType"`
|
||||
ServiceId string `xml:"serviceId"`
|
||||
SCPDURL URLField `xml:"SCPDURL"`
|
||||
ControlURL URLField `xml:"controlURL"`
|
||||
EventSubURL URLField `xml:"eventSubURL"`
|
||||
}
|
||||
|
||||
// SetURLBase sets the URLBase for the Service.
|
||||
func (srv *Service) SetURLBase(urlBase *url.URL) {
|
||||
srv.SCPDURL.SetURLBase(urlBase)
|
||||
srv.ControlURL.SetURLBase(urlBase)
|
||||
srv.EventSubURL.SetURLBase(urlBase)
|
||||
}
|
||||
|
||||
func (srv *Service) String() string {
|
||||
return fmt.Sprintf("Service ID %s : %s", srv.ServiceId, srv.ServiceType)
|
||||
}
|
||||
|
||||
// RequestSCPD requests the SCPD (soap actions and state variables description)
|
||||
// for the service.
|
||||
func (srv *Service) RequestSCPD(ctx context.Context) (*scpd.SCPD, error) {
|
||||
if !srv.SCPDURL.Ok {
|
||||
return nil, errors.New("bad/missing SCPD URL, or no URLBase has been set")
|
||||
}
|
||||
s := new(scpd.SCPD)
|
||||
if err := requestXml(ctx, srv.SCPDURL.URL.String(), scpd.SCPDXMLNamespace, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewSOAPClient returns a new SOAP client to the service's control
|
||||
// URL, using the provided http.Client.
|
||||
// If httpc is nil, http.DefaultClient is used.
|
||||
func (srv *Service) NewSOAPClient(httpc *http.Client) *soap.SOAPClient {
|
||||
if httpc == nil {
|
||||
httpc = http.DefaultClient
|
||||
}
|
||||
return soap.NewSOAPClient(srv.ControlURL.URL, httpc)
|
||||
}
|
||||
|
||||
// URLField is a URL that is part of a device description.
|
||||
type URLField struct {
|
||||
URL url.URL `xml:"-"`
|
||||
Ok bool `xml:"-"`
|
||||
Str string `xml:",chardata"`
|
||||
}
|
||||
|
||||
func (uf *URLField) SetURLBase(urlBase *url.URL) {
|
||||
refUrl, err := url.Parse(uf.Str)
|
||||
if err != nil {
|
||||
uf.URL = url.URL{}
|
||||
uf.Ok = false
|
||||
return
|
||||
}
|
||||
|
||||
uf.URL = *urlBase.ResolveReference(refUrl)
|
||||
uf.Ok = true
|
||||
}
|
||||
164
vendor/github.com/tailscale/goupnp/goupnp.go
generated
vendored
Normal file
164
vendor/github.com/tailscale/goupnp/goupnp.go
generated
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
// goupnp is an implementation of a client for various UPnP services.
|
||||
//
|
||||
// For most uses, it is recommended to use the code-generated packages under
|
||||
// github.com/tailscale/goupnp/dcps. Example use is shown at
|
||||
// http://godoc.org/github.com/tailscale/goupnp/example
|
||||
//
|
||||
// A commonly used client is internetgateway1.WANPPPConnection1:
|
||||
// http://godoc.org/github.com/tailscale/goupnp/dcps/internetgateway1#WANPPPConnection1
|
||||
//
|
||||
// Currently only a couple of schemas have code generated for them from the
|
||||
// UPnP example XML specifications. Not all methods will work on these clients,
|
||||
// because the generated stubs contain the full set of specified methods from
|
||||
// the XML specifications, and the discovered services will likely support a
|
||||
// subset of those methods.
|
||||
package goupnp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/tailscale/goupnp/ssdp"
|
||||
)
|
||||
|
||||
// ContextError is an error that wraps an error with some context information.
|
||||
type ContextError struct {
|
||||
Context string
|
||||
Err error
|
||||
}
|
||||
|
||||
func ctxError(err error, msg string) ContextError {
|
||||
return ContextError{Context: msg, Err: err}
|
||||
}
|
||||
|
||||
func ctxErrorf(err error, msg string, args ...interface{}) ContextError {
|
||||
return ContextError{
|
||||
Context: fmt.Sprintf(msg, args...),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (err ContextError) Error() string {
|
||||
return fmt.Sprintf("%s: %v", err.Context, err.Err)
|
||||
}
|
||||
|
||||
// MaybeRootDevice contains either a RootDevice or an error.
|
||||
type MaybeRootDevice struct {
|
||||
// Identifier of the device.
|
||||
USN string
|
||||
|
||||
// Set iff Err == nil.
|
||||
Root *RootDevice
|
||||
|
||||
// The location the device was discovered at. This can be used with
|
||||
// DeviceByURL, assuming the device is still present. A location represents
|
||||
// the discovery of a device, regardless of if there was an error probing it.
|
||||
Location *url.URL
|
||||
|
||||
// Any error encountered probing a discovered device.
|
||||
Err error
|
||||
}
|
||||
|
||||
// DiscoverDevices attempts to find targets of the given type. This is
|
||||
// typically the entry-point for this package. searchTarget is typically a URN
|
||||
// in the form "urn:schemas-upnp-org:device:..." or
|
||||
// "urn:schemas-upnp-org:service:...". A single error is returned for errors
|
||||
// while attempting to send the query. An error or RootDevice is returned for
|
||||
// each discovered RootDevice.
|
||||
func DiscoverDevices(ctx context.Context, searchTarget string) ([]MaybeRootDevice, error) {
|
||||
hc, err := httpuClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer hc.Close()
|
||||
responses, err := ssdp.SSDPRawSearch(ctx, hc, string(searchTarget), 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results := make([]MaybeRootDevice, len(responses))
|
||||
for i, response := range responses {
|
||||
maybe := &results[i]
|
||||
maybe.USN = response.Header.Get("USN")
|
||||
loc, err := response.Location()
|
||||
if err != nil {
|
||||
maybe.Err = ContextError{"unexpected bad location from search", err}
|
||||
continue
|
||||
}
|
||||
maybe.Location = loc
|
||||
if root, err := DeviceByURL(ctx, loc); err != nil {
|
||||
maybe.Err = err
|
||||
} else {
|
||||
maybe.Root = root
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func DeviceByURL(ctx context.Context, loc *url.URL) (*RootDevice, error) {
|
||||
locStr := loc.String()
|
||||
root := new(RootDevice)
|
||||
if err := requestXml(ctx, locStr, DeviceXMLNamespace, root); err != nil {
|
||||
return nil, ContextError{fmt.Sprintf("error requesting root device details from %q", locStr), err}
|
||||
}
|
||||
var urlBaseStr string
|
||||
if root.URLBaseStr != "" {
|
||||
urlBaseStr = root.URLBaseStr
|
||||
} else {
|
||||
urlBaseStr = locStr
|
||||
}
|
||||
urlBase, err := url.Parse(urlBaseStr)
|
||||
if err != nil {
|
||||
return nil, ContextError{fmt.Sprintf("error parsing location URL %q", locStr), err}
|
||||
}
|
||||
root.SetURLBase(urlBase)
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// CharsetReaderDefault specifies the charset reader used while decoding the output
|
||||
// from a UPnP server. It can be modified in an init function to allow for non-utf8 encodings,
|
||||
// but should not be changed after requesting clients.
|
||||
var CharsetReaderDefault func(charset string, input io.Reader) (io.Reader, error)
|
||||
|
||||
// contextKey is an unexported type which prevents construction of other contextKeys.
|
||||
type httpContextKey struct{}
|
||||
|
||||
// WithHTTPClient returns a context wrapping ctx with the HTTP client set to c.
|
||||
// If c is nil, http.DefaultClient is used.
|
||||
func WithHTTPClient(ctx context.Context, c *http.Client) context.Context {
|
||||
return context.WithValue(ctx, httpContextKey{}, c)
|
||||
}
|
||||
|
||||
func httpClient(ctx context.Context) *http.Client {
|
||||
if c, _ := ctx.Value(httpContextKey{}).(*http.Client); c != nil {
|
||||
return c
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
func requestXml(ctx context.Context, url string, defaultSpace string, into interface{}) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := httpClient(ctx).Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("goupnp: got response status %s from %q", resp.Status, url)
|
||||
}
|
||||
|
||||
decoder := xml.NewDecoder(resp.Body)
|
||||
decoder.DefaultSpace = defaultSpace
|
||||
decoder.CharsetReader = CharsetReaderDefault
|
||||
|
||||
return decoder.Decode(into)
|
||||
}
|
||||
164
vendor/github.com/tailscale/goupnp/httpu/httpu.go
generated
vendored
Normal file
164
vendor/github.com/tailscale/goupnp/httpu/httpu.go
generated
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
package httpu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientInterface is the general interface provided to perform HTTP-over-UDP
|
||||
// requests.
|
||||
type ClientInterface interface {
|
||||
io.Closer
|
||||
// Do performs a request. The timeout is how long to wait for before returning
|
||||
// the responses that were received. An error is only returned for failing to
|
||||
// send the request. Failures in receipt simply do not add to the resulting
|
||||
// responses.
|
||||
Do(req *http.Request, numSends int) ([]*http.Response, error)
|
||||
}
|
||||
|
||||
// HTTPUClient is a client for dealing with HTTPU (HTTP over UDP). Its typical
|
||||
// function is for HTTPMU, and particularly SSDP.
|
||||
type HTTPUClient struct {
|
||||
connLock sync.Mutex // Protects use of conn.
|
||||
conn net.PacketConn
|
||||
}
|
||||
|
||||
var _ ClientInterface = &HTTPUClient{}
|
||||
|
||||
// NewHTTPUClient creates a new HTTPUClient, opening up a new UDP socket for the
|
||||
// purpose.
|
||||
func NewHTTPUClient() (*HTTPUClient, error) {
|
||||
conn, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &HTTPUClient{conn: conn}, nil
|
||||
}
|
||||
|
||||
// NewHTTPUClientAddr creates a new HTTPUClient which will broadcast packets
|
||||
// from the specified address, opening up a new UDP socket for the purpose
|
||||
func NewHTTPUClientAddr(addr string) (*HTTPUClient, error) {
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
return nil, errors.New("Invalid listening address")
|
||||
}
|
||||
conn, err := net.ListenPacket("udp", ip.String()+":0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &HTTPUClient{conn: conn}, nil
|
||||
}
|
||||
|
||||
// Close shuts down the client. The client will no longer be useful following
|
||||
// this.
|
||||
func (httpu *HTTPUClient) Close() error {
|
||||
httpu.connLock.Lock()
|
||||
defer httpu.connLock.Unlock()
|
||||
return httpu.conn.Close()
|
||||
}
|
||||
|
||||
// Do implements ClientInterface.Do.
|
||||
//
|
||||
// Note that at present only one concurrent connection will happen per
|
||||
// HTTPUClient.
|
||||
func (httpu *HTTPUClient) Do(
|
||||
req *http.Request,
|
||||
numSends int,
|
||||
) ([]*http.Response, error) {
|
||||
httpu.connLock.Lock()
|
||||
defer httpu.connLock.Unlock()
|
||||
|
||||
// Create the request. This is a subset of what http.Request.Write does
|
||||
// deliberately to avoid creating extra fields which may confuse some
|
||||
// devices.
|
||||
var requestBuf bytes.Buffer
|
||||
method := req.Method
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
if _, err := fmt.Fprintf(&requestBuf, "%s %s HTTP/1.1\r\n", method, req.URL.RequestURI()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Header.Write(&requestBuf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := requestBuf.Write([]byte{'\r', '\n'}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
destAddr, err := net.ResolveUDPAddr("udp", req.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx := req.Context()
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
}
|
||||
returned := make(chan struct{})
|
||||
defer close(returned)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// if context is cancelled, stop any connections by setting time in the past.
|
||||
httpu.conn.SetDeadline(time.Now().Add(-time.Second))
|
||||
case <-returned:
|
||||
}
|
||||
}()
|
||||
if err = httpu.conn.SetDeadline(deadline); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send request.
|
||||
for i := 0; i < numSends; i++ {
|
||||
if n, err := httpu.conn.WriteTo(requestBuf.Bytes(), destAddr); err != nil {
|
||||
return nil, err
|
||||
} else if n < len(requestBuf.Bytes()) {
|
||||
return nil, fmt.Errorf("httpu: wrote %d bytes rather than full %d in request",
|
||||
n, len(requestBuf.Bytes()))
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Await responses until timeout.
|
||||
var responses []*http.Response
|
||||
responseBytes := make([]byte, 2048)
|
||||
for {
|
||||
// 2048 bytes should be sufficient for most networks.
|
||||
n, _, err := httpu.conn.ReadFrom(responseBytes)
|
||||
if err != nil {
|
||||
if err, ok := err.(net.Error); ok {
|
||||
if err.Timeout() {
|
||||
break
|
||||
}
|
||||
if err.Temporary() {
|
||||
// Sleep in case this is a persistent error to avoid pegging CPU until deadline.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse response.
|
||||
response, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(responseBytes[:n])), req)
|
||||
if err != nil {
|
||||
log.Printf("httpu: error while parsing response: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
responses = append(responses, response)
|
||||
}
|
||||
|
||||
// Timeout reached - return discovered responses.
|
||||
return responses, nil
|
||||
}
|
||||
72
vendor/github.com/tailscale/goupnp/httpu/multiclient.go
generated
vendored
Normal file
72
vendor/github.com/tailscale/goupnp/httpu/multiclient.go
generated
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
package httpu
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// MultiClient dispatches requests out to all the delegated clients.
|
||||
type MultiClient struct {
|
||||
// The HTTPU clients to delegate to.
|
||||
delegates []ClientInterface
|
||||
}
|
||||
|
||||
var _ ClientInterface = &MultiClient{}
|
||||
|
||||
// NewMultiClient creates a new MultiClient that delegates to all the given
|
||||
// clients.
|
||||
func NewMultiClient(delegates []ClientInterface) *MultiClient {
|
||||
return &MultiClient{delegates: delegates}
|
||||
}
|
||||
|
||||
func (mc *MultiClient) Close() error {
|
||||
for _, d := range mc.delegates {
|
||||
d.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do implements ClientInterface.Do.
|
||||
func (mc *MultiClient) Do(
|
||||
req *http.Request,
|
||||
numSends int,
|
||||
) ([]*http.Response, error) {
|
||||
tasks := &errgroup.Group{}
|
||||
|
||||
results := make(chan []*http.Response)
|
||||
tasks.Go(func() error {
|
||||
defer close(results)
|
||||
return mc.sendRequests(results, req, numSends)
|
||||
})
|
||||
|
||||
var responses []*http.Response
|
||||
tasks.Go(func() error {
|
||||
for rs := range results {
|
||||
responses = append(responses, rs...)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return responses, tasks.Wait()
|
||||
}
|
||||
|
||||
func (mc *MultiClient) sendRequests(
|
||||
results chan<- []*http.Response,
|
||||
req *http.Request,
|
||||
numSends int,
|
||||
) error {
|
||||
tasks := &errgroup.Group{}
|
||||
for _, d := range mc.delegates {
|
||||
d := d // copy for closure
|
||||
tasks.Go(func() error {
|
||||
responses, err := d.Do(req, numSends)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
results <- responses
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return tasks.Wait()
|
||||
}
|
||||
108
vendor/github.com/tailscale/goupnp/httpu/serve.go
generated
vendored
Normal file
108
vendor/github.com/tailscale/goupnp/httpu/serve.go
generated
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
package httpu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxMessageBytes = 2048
|
||||
)
|
||||
|
||||
var (
|
||||
trailingWhitespaceRx = regexp.MustCompile(" +\r\n")
|
||||
crlf = []byte("\r\n")
|
||||
)
|
||||
|
||||
// Handler is the interface by which received HTTPU messages are passed to
|
||||
// handling code.
|
||||
type Handler interface {
|
||||
// ServeMessage is called for each HTTPU message received. peerAddr contains
|
||||
// the address that the message was received from.
|
||||
ServeMessage(r *http.Request)
|
||||
}
|
||||
|
||||
// HandlerFunc is a function-to-Handler adapter.
|
||||
type HandlerFunc func(r *http.Request)
|
||||
|
||||
func (f HandlerFunc) ServeMessage(r *http.Request) {
|
||||
f(r)
|
||||
}
|
||||
|
||||
// A Server defines parameters for running an HTTPU server.
|
||||
type Server struct {
|
||||
Addr string // UDP address to listen on
|
||||
Multicast bool // Should listen for multicast?
|
||||
Interface *net.Interface // Network interface to listen on for multicast, nil for default multicast interface
|
||||
Handler Handler // handler to invoke
|
||||
MaxMessageBytes int // maximum number of bytes to read from a packet, DefaultMaxMessageBytes if 0
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the UDP network address srv.Addr. If srv.Multicast
|
||||
// is true, then a multicast UDP listener will be used on srv.Interface (or
|
||||
// default interface if nil).
|
||||
func (srv *Server) ListenAndServe() error {
|
||||
var err error
|
||||
|
||||
var addr *net.UDPAddr
|
||||
if addr, err = net.ResolveUDPAddr("udp", srv.Addr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var conn net.PacketConn
|
||||
if srv.Multicast {
|
||||
if conn, err = net.ListenMulticastUDP("udp", srv.Interface, addr); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if conn, err = net.ListenUDP("udp", addr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return srv.Serve(conn)
|
||||
}
|
||||
|
||||
// Serve messages received on the given packet listener to the srv.Handler.
|
||||
func (srv *Server) Serve(l net.PacketConn) error {
|
||||
maxMessageBytes := DefaultMaxMessageBytes
|
||||
if srv.MaxMessageBytes != 0 {
|
||||
maxMessageBytes = srv.MaxMessageBytes
|
||||
}
|
||||
for {
|
||||
buf := make([]byte, maxMessageBytes)
|
||||
n, peerAddr, err := l.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[:n]
|
||||
|
||||
go func(buf []byte, peerAddr net.Addr) {
|
||||
// At least one router's UPnP implementation has added a trailing space
|
||||
// after "HTTP/1.1" - trim it.
|
||||
buf = trailingWhitespaceRx.ReplaceAllLiteral(buf, crlf)
|
||||
|
||||
req, err := http.ReadRequest(bufio.NewReader(bytes.NewBuffer(buf)))
|
||||
if err != nil {
|
||||
log.Printf("httpu: Failed to parse request: %v", err)
|
||||
return
|
||||
}
|
||||
req.RemoteAddr = peerAddr.String()
|
||||
srv.Handler.ServeMessage(req)
|
||||
// No need to call req.Body.Close - underlying reader is bytes.Buffer.
|
||||
}(buf, peerAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve messages received on the given packet listener to the given handler.
|
||||
func Serve(l net.PacketConn, handler Handler) error {
|
||||
srv := Server{
|
||||
Handler: handler,
|
||||
MaxMessageBytes: DefaultMaxMessageBytes,
|
||||
}
|
||||
return srv.Serve(l)
|
||||
}
|
||||
65
vendor/github.com/tailscale/goupnp/network.go
generated
vendored
Normal file
65
vendor/github.com/tailscale/goupnp/network.go
generated
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
package goupnp
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/tailscale/goupnp/httpu"
|
||||
)
|
||||
|
||||
// httpuClient creates a HTTPU client that multiplexes to all multicast-capable
|
||||
// IPv4 addresses on the host. Returns a function to clean up once the client is
|
||||
// no longer required.
|
||||
func httpuClient() (*httpu.MultiClient, error) {
|
||||
addrs, err := localIPv4MCastAddrs()
|
||||
if err != nil {
|
||||
return nil, ctxError(err, "requesting host IPv4 addresses")
|
||||
}
|
||||
|
||||
delegates := make([]httpu.ClientInterface, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
c, err := httpu.NewHTTPUClientAddr(addr)
|
||||
if err != nil {
|
||||
return nil, ctxErrorf(err, "creating HTTPU client for address %s", addr)
|
||||
}
|
||||
delegates = append(delegates, c)
|
||||
}
|
||||
|
||||
return httpu.NewMultiClient(delegates), nil
|
||||
}
|
||||
|
||||
// localIPv2MCastAddrs returns the set of IPv4 addresses on multicast-able
|
||||
// network interfaces.
|
||||
func localIPv4MCastAddrs() ([]string, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, ctxError(err, "requesting host interfaces")
|
||||
}
|
||||
|
||||
// Find the set of addresses to listen on.
|
||||
var addrs []string
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagMulticast == 0 || iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
// Does not support multicast or is a loopback address.
|
||||
continue
|
||||
}
|
||||
ifaceAddrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, ctxErrorf(err,
|
||||
"finding addresses on interface %s", iface.Name)
|
||||
}
|
||||
for _, netAddr := range ifaceAddrs {
|
||||
addr, ok := netAddr.(*net.IPNet)
|
||||
if !ok {
|
||||
// Not an IPNet address.
|
||||
continue
|
||||
}
|
||||
if addr.IP.To4() == nil {
|
||||
// Not IPv4.
|
||||
continue
|
||||
}
|
||||
addrs = append(addrs, addr.IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
167
vendor/github.com/tailscale/goupnp/scpd/scpd.go
generated
vendored
Normal file
167
vendor/github.com/tailscale/goupnp/scpd/scpd.go
generated
vendored
Normal file
@@ -0,0 +1,167 @@
|
||||
package scpd
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
SCPDXMLNamespace = "urn:schemas-upnp-org:service-1-0"
|
||||
)
|
||||
|
||||
func cleanWhitespace(s *string) {
|
||||
*s = strings.TrimSpace(*s)
|
||||
}
|
||||
|
||||
// SCPD is the service description as described by section 2.5 "Service
|
||||
// description" in
|
||||
// http://upnp.org/specs/arch/UPnP-arch-DeviceArchitecture-v1.1.pdf
|
||||
type SCPD struct {
|
||||
XMLName xml.Name `xml:"scpd"`
|
||||
ConfigId string `xml:"configId,attr"`
|
||||
SpecVersion SpecVersion `xml:"specVersion"`
|
||||
Actions []Action `xml:"actionList>action"`
|
||||
StateVariables []StateVariable `xml:"serviceStateTable>stateVariable"`
|
||||
}
|
||||
|
||||
// Clean attempts to remove stray whitespace etc. in the structure. It seems
|
||||
// unfortunately common for stray whitespace to be present in SCPD documents,
|
||||
// this method attempts to make it easy to clean them out.
|
||||
func (scpd *SCPD) Clean() {
|
||||
cleanWhitespace(&scpd.ConfigId)
|
||||
for i := range scpd.Actions {
|
||||
scpd.Actions[i].clean()
|
||||
}
|
||||
for i := range scpd.StateVariables {
|
||||
scpd.StateVariables[i].clean()
|
||||
}
|
||||
}
|
||||
|
||||
func (scpd *SCPD) GetStateVariable(variable string) *StateVariable {
|
||||
for i := range scpd.StateVariables {
|
||||
v := &scpd.StateVariables[i]
|
||||
if v.Name == variable {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (scpd *SCPD) GetAction(action string) *Action {
|
||||
for i := range scpd.Actions {
|
||||
a := &scpd.Actions[i]
|
||||
if a.Name == action {
|
||||
return a
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SpecVersion is part of a SCPD document, describes the version of the
|
||||
// specification that the data adheres to.
|
||||
type SpecVersion struct {
|
||||
Major int32 `xml:"major"`
|
||||
Minor int32 `xml:"minor"`
|
||||
}
|
||||
|
||||
type Action struct {
|
||||
Name string `xml:"name"`
|
||||
Arguments []Argument `xml:"argumentList>argument"`
|
||||
}
|
||||
|
||||
func (action *Action) clean() {
|
||||
cleanWhitespace(&action.Name)
|
||||
for i := range action.Arguments {
|
||||
action.Arguments[i].clean()
|
||||
}
|
||||
}
|
||||
|
||||
func (action *Action) InputArguments() []*Argument {
|
||||
var result []*Argument
|
||||
for i := range action.Arguments {
|
||||
arg := &action.Arguments[i]
|
||||
if arg.IsInput() {
|
||||
result = append(result, arg)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (action *Action) OutputArguments() []*Argument {
|
||||
var result []*Argument
|
||||
for i := range action.Arguments {
|
||||
arg := &action.Arguments[i]
|
||||
if arg.IsOutput() {
|
||||
result = append(result, arg)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type Argument struct {
|
||||
Name string `xml:"name"`
|
||||
Direction string `xml:"direction"` // in|out
|
||||
RelatedStateVariable string `xml:"relatedStateVariable"` // ?
|
||||
Retval string `xml:"retval"` // ?
|
||||
}
|
||||
|
||||
func (arg *Argument) clean() {
|
||||
cleanWhitespace(&arg.Name)
|
||||
cleanWhitespace(&arg.Direction)
|
||||
cleanWhitespace(&arg.RelatedStateVariable)
|
||||
cleanWhitespace(&arg.Retval)
|
||||
}
|
||||
|
||||
func (arg *Argument) IsInput() bool {
|
||||
return arg.Direction == "in"
|
||||
}
|
||||
|
||||
func (arg *Argument) IsOutput() bool {
|
||||
return arg.Direction == "out"
|
||||
}
|
||||
|
||||
type StateVariable struct {
|
||||
Name string `xml:"name"`
|
||||
SendEvents string `xml:"sendEvents,attr"` // yes|no
|
||||
Multicast string `xml:"multicast,attr"` // yes|no
|
||||
DataType DataType `xml:"dataType"`
|
||||
DefaultValue string `xml:"defaultValue"`
|
||||
AllowedValueRange *AllowedValueRange `xml:"allowedValueRange"`
|
||||
AllowedValues []string `xml:"allowedValueList>allowedValue"`
|
||||
}
|
||||
|
||||
func (v *StateVariable) clean() {
|
||||
cleanWhitespace(&v.Name)
|
||||
cleanWhitespace(&v.SendEvents)
|
||||
cleanWhitespace(&v.Multicast)
|
||||
v.DataType.clean()
|
||||
cleanWhitespace(&v.DefaultValue)
|
||||
if v.AllowedValueRange != nil {
|
||||
v.AllowedValueRange.clean()
|
||||
}
|
||||
for i := range v.AllowedValues {
|
||||
cleanWhitespace(&v.AllowedValues[i])
|
||||
}
|
||||
}
|
||||
|
||||
type AllowedValueRange struct {
|
||||
Minimum string `xml:"minimum"`
|
||||
Maximum string `xml:"maximum"`
|
||||
Step string `xml:"step"`
|
||||
}
|
||||
|
||||
func (r *AllowedValueRange) clean() {
|
||||
cleanWhitespace(&r.Minimum)
|
||||
cleanWhitespace(&r.Maximum)
|
||||
cleanWhitespace(&r.Step)
|
||||
}
|
||||
|
||||
type DataType struct {
|
||||
Name string `xml:",chardata"`
|
||||
Type string `xml:"type,attr"`
|
||||
}
|
||||
|
||||
func (dt *DataType) clean() {
|
||||
cleanWhitespace(&dt.Name)
|
||||
cleanWhitespace(&dt.Type)
|
||||
}
|
||||
88
vendor/github.com/tailscale/goupnp/service_client.go
generated
vendored
Normal file
88
vendor/github.com/tailscale/goupnp/service_client.go
generated
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
package goupnp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/tailscale/goupnp/soap"
|
||||
)
|
||||
|
||||
// ServiceClient is a SOAP client, root device and the service for the SOAP
|
||||
// client rolled into one value. The root device, location, and service are
|
||||
// intended to be informational. Location can be used to later recreate a
|
||||
// ServiceClient with NewServiceClientByURL if the service is still present;
|
||||
// bypassing the discovery process.
|
||||
type ServiceClient struct {
|
||||
SOAPClient *soap.SOAPClient
|
||||
RootDevice *RootDevice
|
||||
Location *url.URL
|
||||
Service *Service
|
||||
}
|
||||
|
||||
// NewServiceClients discovers services, and returns clients for them. err will
|
||||
// report any error with the discovery process (blocking any device/service
|
||||
// discovery), errors reports errors on a per-root-device basis.
|
||||
func NewServiceClients(ctx context.Context, searchTarget string) (clients []ServiceClient, errors []error, err error) {
|
||||
var maybeRootDevices []MaybeRootDevice
|
||||
if maybeRootDevices, err = DiscoverDevices(ctx, searchTarget); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
clients = make([]ServiceClient, 0, len(maybeRootDevices))
|
||||
|
||||
for _, maybeRootDevice := range maybeRootDevices {
|
||||
if maybeRootDevice.Err != nil {
|
||||
errors = append(errors, maybeRootDevice.Err)
|
||||
continue
|
||||
}
|
||||
|
||||
deviceClients, err := NewServiceClientsFromRootDevice(ctx, maybeRootDevice.Root, maybeRootDevice.Location, searchTarget)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
continue
|
||||
}
|
||||
clients = append(clients, deviceClients...)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// NewServiceClientsByURL creates client(s) for the given service URN, for a
|
||||
// root device at the given URL.
|
||||
func NewServiceClientsByURL(ctx context.Context, loc *url.URL, searchTarget string) ([]ServiceClient, error) {
|
||||
rootDevice, err := DeviceByURL(ctx, loc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewServiceClientsFromRootDevice(ctx, rootDevice, loc, searchTarget)
|
||||
}
|
||||
|
||||
// NewServiceClientsFromDevice creates client(s) for the given service URN, in
|
||||
// a given root device. The loc parameter is simply assigned to the
|
||||
// Location attribute of the returned ServiceClient(s).
|
||||
func NewServiceClientsFromRootDevice(ctx context.Context, rootDevice *RootDevice, loc *url.URL, searchTarget string) ([]ServiceClient, error) {
|
||||
device := &rootDevice.Device
|
||||
srvs := device.FindService(ctx, searchTarget)
|
||||
if len(srvs) == 0 {
|
||||
return nil, fmt.Errorf("goupnp: service %q not found within device %q (UDN=%q)",
|
||||
searchTarget, device.FriendlyName, device.UDN)
|
||||
}
|
||||
clients := make([]ServiceClient, 0, len(srvs))
|
||||
for _, srv := range srvs {
|
||||
clients = append(clients, ServiceClient{
|
||||
SOAPClient: srv.NewSOAPClient(httpClient(ctx)),
|
||||
RootDevice: rootDevice,
|
||||
Location: loc,
|
||||
Service: srv,
|
||||
})
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
// GetServiceClient returns the ServiceClient itself. This is provided so that the
|
||||
// service client attributes can be accessed via an interface method on a
|
||||
// wrapping type.
|
||||
func (client *ServiceClient) GetServiceClient() *ServiceClient {
|
||||
return client
|
||||
}
|
||||
192
vendor/github.com/tailscale/goupnp/soap/soap.go
generated
vendored
Normal file
192
vendor/github.com/tailscale/goupnp/soap/soap.go
generated
vendored
Normal file
@@ -0,0 +1,192 @@
|
||||
// Definition for the SOAP structure required for UPnP's SOAP usage.
|
||||
|
||||
package soap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
soapEncodingStyle = "http://schemas.xmlsoap.org/soap/encoding/"
|
||||
soapPrefix = xml.Header + `<s:Envelope xmlns:s="http://schemas.xmlsoap.org/soap/envelope/" s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/"><s:Body>`
|
||||
soapSuffix = `</s:Body></s:Envelope>`
|
||||
)
|
||||
|
||||
type SOAPClient struct {
|
||||
EndpointURL url.URL
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// NewSoapClient creates a new soap client to a specific URL using a given http.Client.
|
||||
// The http.Client must not be nil.
|
||||
func NewSOAPClient(endpointURL url.URL, httpClient *http.Client) *SOAPClient {
|
||||
return &SOAPClient{
|
||||
EndpointURL: endpointURL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
// PerformSOAPAction makes a SOAP request, with the given action.
|
||||
// inAction and outAction must both be pointers to structs with string fields
|
||||
// only.
|
||||
func (client *SOAPClient) PerformAction(
|
||||
ctx context.Context, actionNamespace, actionName string,
|
||||
inAction interface{}, outAction interface{},
|
||||
) error {
|
||||
requestBytes, err := encodeRequestAction(actionNamespace, actionName, inAction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &client.EndpointURL,
|
||||
Header: http.Header{
|
||||
"SOAPACTION": []string{`"` + actionNamespace + "#" + actionName + `"`},
|
||||
"CONTENT-TYPE": []string{"text/xml; charset=\"utf-8\""},
|
||||
},
|
||||
Body: ioutil.NopCloser(bytes.NewBuffer(requestBytes)),
|
||||
// Set ContentLength to avoid chunked encoding - some servers might not support it.
|
||||
ContentLength: int64(len(requestBytes)),
|
||||
}
|
||||
|
||||
response, err := client.HTTPClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return fmt.Errorf("goupnp: error performing SOAP HTTP request: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != 200 && response.ContentLength == 0 {
|
||||
return fmt.Errorf("goupnp: SOAP request got HTTP %s", response.Status)
|
||||
}
|
||||
|
||||
responseEnv := newSOAPEnvelope()
|
||||
decoder := xml.NewDecoder(response.Body)
|
||||
if err := decoder.Decode(responseEnv); err != nil {
|
||||
return fmt.Errorf("goupnp: error decoding response body: %v", err)
|
||||
}
|
||||
|
||||
if responseEnv.Body.Fault != nil {
|
||||
return responseEnv.Body.Fault
|
||||
} else if response.StatusCode != 200 {
|
||||
return fmt.Errorf("goupnp: SOAP request got HTTP %s", response.Status)
|
||||
}
|
||||
|
||||
if outAction == nil {
|
||||
return nil
|
||||
}
|
||||
if err := xml.Unmarshal(responseEnv.Body.RawAction, outAction); err != nil {
|
||||
return fmt.Errorf("goupnp: error unmarshalling out action: %v, %v", err, responseEnv.Body.RawAction)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newSOAPAction creates a soapEnvelope with the given action and arguments.
|
||||
func newSOAPEnvelope() *soapEnvelope {
|
||||
return &soapEnvelope{EncodingStyle: soapEncodingStyle}
|
||||
}
|
||||
|
||||
// encodeRequestAction is a hacky way to create an encoded SOAP envelope
|
||||
// containing the given action. Experiments with one router have shown that it
|
||||
// 500s for requests where the outer default xmlns is set to the SOAP
|
||||
// namespace, and then reassigning the default namespace within that to the
|
||||
// service namespace. Hand-coding the outer XML to work-around this.
|
||||
func encodeRequestAction(actionNamespace, actionName string, inAction interface{}) ([]byte, error) {
|
||||
requestBuf := new(bytes.Buffer)
|
||||
requestBuf.WriteString(soapPrefix)
|
||||
requestBuf.WriteString(`<u:`)
|
||||
xml.EscapeText(requestBuf, []byte(actionName))
|
||||
requestBuf.WriteString(` xmlns:u="`)
|
||||
xml.EscapeText(requestBuf, []byte(actionNamespace))
|
||||
requestBuf.WriteString(`">`)
|
||||
if inAction != nil {
|
||||
if err := encodeRequestArgs(requestBuf, inAction); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
requestBuf.WriteString(`</u:`)
|
||||
xml.EscapeText(requestBuf, []byte(actionName))
|
||||
requestBuf.WriteString(`>`)
|
||||
requestBuf.WriteString(soapSuffix)
|
||||
return requestBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
func encodeRequestArgs(w *bytes.Buffer, inAction interface{}) error {
|
||||
in := reflect.Indirect(reflect.ValueOf(inAction))
|
||||
if in.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("goupnp: SOAP inAction is not a struct but of type %v", in.Type())
|
||||
}
|
||||
enc := xml.NewEncoder(w)
|
||||
nFields := in.NumField()
|
||||
inType := in.Type()
|
||||
for i := 0; i < nFields; i++ {
|
||||
field := inType.Field(i)
|
||||
argName := field.Name
|
||||
if nameOverride := field.Tag.Get("soap"); nameOverride != "" {
|
||||
argName = nameOverride
|
||||
}
|
||||
value := in.Field(i)
|
||||
if value.Kind() != reflect.String {
|
||||
return fmt.Errorf("goupnp: SOAP arg %q is not of type string, but of type %v", argName, value.Type())
|
||||
}
|
||||
elem := xml.StartElement{Name: xml.Name{Local: argName}}
|
||||
if err := enc.EncodeToken(elem); err != nil {
|
||||
return fmt.Errorf("goupnp: error encoding start element for SOAP arg %q: %v", argName, err)
|
||||
}
|
||||
if err := enc.Flush(); err != nil {
|
||||
return fmt.Errorf("goupnp: error flushing start element for SOAP arg %q: %v", argName, err)
|
||||
}
|
||||
if _, err := w.Write([]byte(escapeXMLText(value.Interface().(string)))); err != nil {
|
||||
return fmt.Errorf("goupnp: error writing value for SOAP arg %q: %v", argName, err)
|
||||
}
|
||||
if err := enc.EncodeToken(elem.End()); err != nil {
|
||||
return fmt.Errorf("goupnp: error encoding end element for SOAP arg %q: %v", argName, err)
|
||||
}
|
||||
}
|
||||
enc.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
var replacer = strings.NewReplacer("<", "<", ">", ">", "&", "&")
|
||||
|
||||
// escapeXMLText is used by generated code to escape text in XML, but only
|
||||
// escaping the characters `<`, `>`, and `&`.
|
||||
//
|
||||
// This is provided in order to work around SOAP server implementations that
|
||||
// fail to decode XML correctly, specifically failing to decode `"`, `'`. Note
|
||||
// that this can only be safely used for injecting into XML text, but not into
|
||||
// attributes or other contexts.
|
||||
func escapeXMLText(s string) string {
|
||||
return replacer.Replace(s)
|
||||
}
|
||||
|
||||
type soapEnvelope struct {
|
||||
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"`
|
||||
EncodingStyle string `xml:"http://schemas.xmlsoap.org/soap/envelope/ encodingStyle,attr"`
|
||||
Body soapBody `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"`
|
||||
}
|
||||
|
||||
type soapBody struct {
|
||||
Fault *SOAPFaultError `xml:"Fault"`
|
||||
RawAction []byte `xml:",innerxml"`
|
||||
}
|
||||
|
||||
// SOAPFaultError implements error, and contains SOAP fault information.
|
||||
type SOAPFaultError struct {
|
||||
FaultCode string `xml:"faultCode"`
|
||||
FaultString string `xml:"faultString"`
|
||||
Detail struct {
|
||||
Raw []byte `xml:",innerxml"`
|
||||
} `xml:"detail"`
|
||||
}
|
||||
|
||||
func (err *SOAPFaultError) Error() string {
|
||||
return fmt.Sprintf("SOAP fault: %s", err.FaultString)
|
||||
}
|
||||
490
vendor/github.com/tailscale/goupnp/soap/types.go
generated
vendored
Normal file
490
vendor/github.com/tailscale/goupnp/soap/types.go
generated
vendored
Normal file
@@ -0,0 +1,490 @@
|
||||
package soap
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var (
|
||||
// localLoc acts like time.Local for this package, but is faked out by the
|
||||
// unit tests to ensure that things stay constant (especially when running
|
||||
// this test in a place where local time is UTC which might mask bugs).
|
||||
localLoc = time.Local
|
||||
)
|
||||
|
||||
func MarshalUi1(v uint8) (string, error) {
|
||||
return strconv.FormatUint(uint64(v), 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalUi1(s string) (uint8, error) {
|
||||
v, err := strconv.ParseUint(s, 10, 8)
|
||||
return uint8(v), err
|
||||
}
|
||||
|
||||
func MarshalUi2(v uint16) (string, error) {
|
||||
return strconv.FormatUint(uint64(v), 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalUi2(s string) (uint16, error) {
|
||||
v, err := strconv.ParseUint(s, 10, 16)
|
||||
return uint16(v), err
|
||||
}
|
||||
|
||||
func MarshalUi4(v uint32) (string, error) {
|
||||
return strconv.FormatUint(uint64(v), 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalUi4(s string) (uint32, error) {
|
||||
v, err := strconv.ParseUint(s, 10, 32)
|
||||
return uint32(v), err
|
||||
}
|
||||
|
||||
func MarshalUi8(v uint64) (string, error) {
|
||||
return strconv.FormatUint(v, 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalUi8(s string) (uint64, error) {
|
||||
v, err := strconv.ParseUint(s, 10, 64)
|
||||
return uint64(v), err
|
||||
}
|
||||
|
||||
func MarshalI1(v int8) (string, error) {
|
||||
return strconv.FormatInt(int64(v), 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalI1(s string) (int8, error) {
|
||||
v, err := strconv.ParseInt(s, 10, 8)
|
||||
return int8(v), err
|
||||
}
|
||||
|
||||
func MarshalI2(v int16) (string, error) {
|
||||
return strconv.FormatInt(int64(v), 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalI2(s string) (int16, error) {
|
||||
v, err := strconv.ParseInt(s, 10, 16)
|
||||
return int16(v), err
|
||||
}
|
||||
|
||||
func MarshalI4(v int32) (string, error) {
|
||||
return strconv.FormatInt(int64(v), 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalI4(s string) (int32, error) {
|
||||
v, err := strconv.ParseInt(s, 10, 32)
|
||||
return int32(v), err
|
||||
}
|
||||
|
||||
func MarshalInt(v int64) (string, error) {
|
||||
return strconv.FormatInt(v, 10), nil
|
||||
}
|
||||
|
||||
func UnmarshalInt(s string) (int64, error) {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
|
||||
func MarshalR4(v float32) (string, error) {
|
||||
return strconv.FormatFloat(float64(v), 'G', -1, 32), nil
|
||||
}
|
||||
|
||||
func UnmarshalR4(s string) (float32, error) {
|
||||
v, err := strconv.ParseFloat(s, 32)
|
||||
return float32(v), err
|
||||
}
|
||||
|
||||
func MarshalR8(v float64) (string, error) {
|
||||
return strconv.FormatFloat(v, 'G', -1, 64), nil
|
||||
}
|
||||
|
||||
func UnmarshalR8(s string) (float64, error) {
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
return float64(v), err
|
||||
}
|
||||
|
||||
// MarshalFixed14_4 marshals float64 to SOAP "fixed.14.4" type.
|
||||
func MarshalFixed14_4(v float64) (string, error) {
|
||||
if v >= 1e14 || v <= -1e14 {
|
||||
return "", fmt.Errorf("soap fixed14.4: value %v out of bounds", v)
|
||||
}
|
||||
return strconv.FormatFloat(v, 'f', 4, 64), nil
|
||||
}
|
||||
|
||||
// UnmarshalFixed14_4 unmarshals float64 from SOAP "fixed.14.4" type.
|
||||
func UnmarshalFixed14_4(s string) (float64, error) {
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if v >= 1e14 || v <= -1e14 {
|
||||
return 0, fmt.Errorf("soap fixed14.4: value %q out of bounds", s)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// MarshalChar marshals rune to SOAP "char" type.
|
||||
func MarshalChar(v rune) (string, error) {
|
||||
if v == 0 {
|
||||
return "", errors.New("soap char: rune 0 is not allowed")
|
||||
}
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
// UnmarshalChar unmarshals rune from SOAP "char" type.
|
||||
func UnmarshalChar(s string) (rune, error) {
|
||||
if len(s) == 0 {
|
||||
return 0, errors.New("soap char: got empty string")
|
||||
}
|
||||
r, n := utf8.DecodeRune([]byte(s))
|
||||
if n != len(s) {
|
||||
return 0, fmt.Errorf("soap char: value %q is not a single rune", s)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func MarshalString(v string) (string, error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func UnmarshalString(v string) (string, error) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func parseInt(s string, err *error) int {
|
||||
v, parseErr := strconv.ParseInt(s, 10, 64)
|
||||
if parseErr != nil {
|
||||
*err = parseErr
|
||||
}
|
||||
return int(v)
|
||||
}
|
||||
|
||||
var dateLayouts = []string{
|
||||
"2006-01-02",
|
||||
"20060102",
|
||||
|
||||
"2006-01",
|
||||
"200601",
|
||||
|
||||
"2006",
|
||||
}
|
||||
|
||||
func parseDateParts(s string) (year, month, day int, err error) {
|
||||
var parsed time.Time
|
||||
for _, layout := range dateLayouts {
|
||||
parsed, err = time.Parse(layout, s)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("soap date: value %q is not in a recognized ISO8601 date format", s)
|
||||
return
|
||||
}
|
||||
year = int(parsed.Year())
|
||||
month = int(parsed.Month())
|
||||
day = int(parsed.Day())
|
||||
return
|
||||
}
|
||||
|
||||
var timeLayouts = []string{
|
||||
"15:04:05",
|
||||
"150405",
|
||||
|
||||
"15:04",
|
||||
"1504",
|
||||
|
||||
"15",
|
||||
}
|
||||
|
||||
func parseTimeParts(s string) (hour, minute, second int, err error) {
|
||||
// special case 24 hour time limit
|
||||
if s == "24:00:00" {
|
||||
return 24, 0, 0, nil
|
||||
}
|
||||
var parsed time.Time
|
||||
for _, layout := range timeLayouts {
|
||||
parsed, err = time.Parse(layout, s)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("soap time: value %q is not in ISO8601 time format", s)
|
||||
return
|
||||
}
|
||||
|
||||
hour = parsed.Hour()
|
||||
minute = parsed.Minute()
|
||||
second = parsed.Second()
|
||||
return
|
||||
}
|
||||
|
||||
// (+|-)hh[[:]mm]
|
||||
var timezoneLayouts = []string{
|
||||
"+07:00",
|
||||
"-07:00",
|
||||
|
||||
"+0700",
|
||||
"-0700",
|
||||
|
||||
"+07",
|
||||
"-07",
|
||||
}
|
||||
|
||||
func parseTimezone(s string) (loc *time.Location, err error) {
|
||||
if s == "Z" {
|
||||
return time.UTC, nil
|
||||
}
|
||||
var parsed time.Time
|
||||
for _, layout := range timezoneLayouts {
|
||||
parsed, err = time.Parse(layout, s)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("soap timezone: value %q is not in ISO8601 timezone format", s)
|
||||
return
|
||||
}
|
||||
return parsed.Location(), nil
|
||||
}
|
||||
|
||||
// splitCompleteDateTimeZone splits date, time and timezone apart from an
|
||||
// ISO8601 string. It does not ensure that the contents of each part are
|
||||
// correct, it merely splits on certain delimiters.
|
||||
// e.g "2010-09-08T12:15:10+0700" => "2010-09-08", "12:15:10", "+0700".
|
||||
// Timezone can only be present if time is also present.
|
||||
func splitCompleteDateTimeZone(s string) (dateStr, timeStr, zoneStr string) {
|
||||
dateIndex := strings.Index(s, "T")
|
||||
if dateIndex == -1 {
|
||||
dateStr = s
|
||||
return
|
||||
}
|
||||
dateStr = s[:dateIndex]
|
||||
zoneIndex := strings.IndexAny(s[dateIndex:], "Z+-")
|
||||
if zoneIndex == -1 {
|
||||
timeStr = s[dateIndex+1:]
|
||||
return
|
||||
}
|
||||
timeStr = s[dateIndex+1 : dateIndex+zoneIndex]
|
||||
zoneStr = s[dateIndex+zoneIndex:]
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalDate marshals time.Time to SOAP "date" type. Note that this converts
|
||||
// to local time, and discards the time-of-day components.
|
||||
func MarshalDate(v time.Time) (string, error) {
|
||||
return v.In(localLoc).Format("2006-01-02"), nil
|
||||
}
|
||||
|
||||
var dateFmts = []string{"2006-01-02", "20060102"}
|
||||
|
||||
// UnmarshalDate unmarshals time.Time from SOAP "date" type. This outputs the
|
||||
// date as midnight in the local time zone.
|
||||
func UnmarshalDate(s string) (time.Time, error) {
|
||||
year, month, day, err := parseDateParts(s)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return time.Date(year, time.Month(month), day, 0, 0, 0, 0, localLoc), nil
|
||||
}
|
||||
|
||||
// TimeOfDay is used in cases where SOAP "time" or "time.tz" is used.
|
||||
type TimeOfDay struct {
|
||||
// Duration of time since midnight.
|
||||
FromMidnight time.Duration
|
||||
|
||||
// Offset is non-zero only if time.tz is used. It is otherwise ignored. If
|
||||
// non-zero, then it is regarded as a UTC offset in seconds. Note that the
|
||||
// sub-minutes is ignored by the marshal function.
|
||||
Location *time.Location
|
||||
}
|
||||
|
||||
// MarshalTimeOfDay marshals TimeOfDay to the "time" type.
|
||||
func MarshalTimeOfDay(v TimeOfDay) (string, error) {
|
||||
d := int64(v.FromMidnight / time.Second)
|
||||
hour := d / 3600
|
||||
d = d % 3600
|
||||
minute := d / 60
|
||||
second := d % 60
|
||||
|
||||
return fmt.Sprintf("%02d:%02d:%02d", hour, minute, second), nil
|
||||
}
|
||||
|
||||
// UnmarshalTimeOfDay unmarshals TimeOfDay from the "time" type.
|
||||
func UnmarshalTimeOfDay(s string) (TimeOfDay, error) {
|
||||
t, err := UnmarshalTimeOfDayTz(s)
|
||||
if err != nil {
|
||||
return t, err
|
||||
} else if t.Location != time.UTC {
|
||||
return t, fmt.Errorf("soap time: value %q contains unexpected timezone", s)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// MarshalTimeOfDayTz marshals TimeOfDay to the "time.tz" type.
|
||||
func MarshalTimeOfDayTz(v TimeOfDay) (string, error) {
|
||||
d := int64(v.FromMidnight / time.Second)
|
||||
hour := d / 3600
|
||||
d = d % 3600
|
||||
minute := d / 60
|
||||
second := d % 60
|
||||
if v.Location == nil {
|
||||
t := time.Date(0, 0, 0, int(hour), int(minute), int(second), 0, time.UTC)
|
||||
return t.Format("15:04:05"), nil
|
||||
}
|
||||
t := time.Date(0, 0, 0, int(hour), int(minute), int(second), 0, v.Location)
|
||||
return t.Format("15:04:05Z07:00"), nil
|
||||
}
|
||||
|
||||
// UnmarshalTimeOfDayTz unmarshals TimeOfDay from the "time.tz" type.
|
||||
func UnmarshalTimeOfDayTz(s string) (tod TimeOfDay, err error) {
|
||||
zoneIndex := strings.IndexAny(s, "Z+-")
|
||||
var timePart string
|
||||
loc := time.UTC
|
||||
if zoneIndex == -1 {
|
||||
timePart = s
|
||||
} else {
|
||||
timePart = s[:zoneIndex]
|
||||
if loc, err = parseTimezone(s[zoneIndex:]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hour, minute, second, err := parseTimeParts(timePart)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fromMidnight := time.Duration(hour*3600+minute*60+second) * time.Second
|
||||
|
||||
// ISO8601 special case - values up to 24:00:00 are allowed, so using
|
||||
// strictly greater-than for the maximum value.
|
||||
if fromMidnight > 24*time.Hour || minute >= 60 || second >= 60 {
|
||||
return TimeOfDay{Location: localLoc}, fmt.Errorf("soap time.tz: value %q has value(s) out of range", s)
|
||||
}
|
||||
|
||||
return TimeOfDay{
|
||||
FromMidnight: time.Duration(hour*3600+minute*60+second) * time.Second,
|
||||
Location: loc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MarshalDateTime marshals time.Time to SOAP "dateTime" type. Note that this
|
||||
// converts to local time.
|
||||
func MarshalDateTime(v time.Time) (string, error) {
|
||||
return v.In(localLoc).Format("2006-01-02T15:04:05"), nil
|
||||
}
|
||||
|
||||
// UnmarshalDateTime unmarshals time.Time from the SOAP "dateTime" type. This
|
||||
// returns a value in the local timezone.
|
||||
func UnmarshalDateTime(s string) (result time.Time, err error) {
|
||||
dateStr, timeStr, zoneStr := splitCompleteDateTimeZone(s)
|
||||
|
||||
if len(zoneStr) != 0 {
|
||||
err = fmt.Errorf("soap datetime: unexpected timezone in %q", s)
|
||||
return
|
||||
}
|
||||
|
||||
year, month, day, err := parseDateParts(dateStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var hour, minute, second int
|
||||
if len(timeStr) != 0 {
|
||||
hour, minute, second, err = parseTimeParts(timeStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
result = time.Date(year, time.Month(month), day, hour, minute, second, 0, localLoc)
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalDateTimeTz marshals time.Time to SOAP "dateTime.tz" type.
|
||||
func MarshalDateTimeTz(v time.Time) (string, error) {
|
||||
return v.Format("2006-01-02T15:04:05-07:00"), nil
|
||||
}
|
||||
|
||||
// UnmarshalDateTimeTz unmarshals time.Time from the SOAP "dateTime.tz" type.
|
||||
// This returns a value in the local timezone when the timezone is unspecified.
|
||||
func UnmarshalDateTimeTz(s string) (result time.Time, err error) {
|
||||
dateStr, timeStr, zoneStr := splitCompleteDateTimeZone(s)
|
||||
|
||||
year, month, day, err := parseDateParts(dateStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var hour, minute, second int
|
||||
location := localLoc
|
||||
if len(timeStr) != 0 {
|
||||
hour, minute, second, err = parseTimeParts(timeStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(zoneStr) != 0 {
|
||||
location, err = parseTimezone(zoneStr)
|
||||
}
|
||||
}
|
||||
|
||||
result = time.Date(year, time.Month(month), day, hour, minute, second, 0, location)
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalBoolean marshals bool to SOAP "boolean" type.
|
||||
func MarshalBoolean(v bool) (string, error) {
|
||||
if v {
|
||||
return "1", nil
|
||||
}
|
||||
return "0", nil
|
||||
}
|
||||
|
||||
// UnmarshalBoolean unmarshals bool from the SOAP "boolean" type.
|
||||
func UnmarshalBoolean(s string) (bool, error) {
|
||||
switch s {
|
||||
case "0", "false", "no":
|
||||
return false, nil
|
||||
case "1", "true", "yes":
|
||||
return true, nil
|
||||
}
|
||||
return false, fmt.Errorf("soap boolean: %q is not a valid boolean value", s)
|
||||
}
|
||||
|
||||
// MarshalBinBase64 marshals []byte to SOAP "bin.base64" type.
|
||||
func MarshalBinBase64(v []byte) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString(v), nil
|
||||
}
|
||||
|
||||
// UnmarshalBinBase64 unmarshals []byte from the SOAP "bin.base64" type.
|
||||
func UnmarshalBinBase64(s string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
// MarshalBinHex marshals []byte to SOAP "bin.hex" type.
|
||||
func MarshalBinHex(v []byte) (string, error) {
|
||||
return hex.EncodeToString(v), nil
|
||||
}
|
||||
|
||||
// UnmarshalBinHex unmarshals []byte from the SOAP "bin.hex" type.
|
||||
func UnmarshalBinHex(s string) ([]byte, error) {
|
||||
return hex.DecodeString(s)
|
||||
}
|
||||
|
||||
// MarshalURI marshals *url.URL to SOAP "uri" type.
|
||||
func MarshalURI(v *url.URL) (string, error) {
|
||||
return v.String(), nil
|
||||
}
|
||||
|
||||
// UnmarshalURI unmarshals *url.URL from the SOAP "uri" type.
|
||||
func UnmarshalURI(s string) (*url.URL, error) {
|
||||
return url.Parse(s)
|
||||
}
|
||||
312
vendor/github.com/tailscale/goupnp/ssdp/registry.go
generated
vendored
Normal file
312
vendor/github.com/tailscale/goupnp/ssdp/registry.go
generated
vendored
Normal file
@@ -0,0 +1,312 @@
|
||||
package ssdp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/goupnp/httpu"
|
||||
)
|
||||
|
||||
const (
|
||||
maxExpiryTimeSeconds = 24 * 60 * 60
|
||||
)
|
||||
|
||||
var (
|
||||
maxAgeRx = regexp.MustCompile("max-age= *([0-9]+)")
|
||||
)
|
||||
|
||||
const (
|
||||
EventAlive = EventType(iota)
|
||||
EventUpdate
|
||||
EventByeBye
|
||||
)
|
||||
|
||||
type EventType int8
|
||||
|
||||
func (et EventType) String() string {
|
||||
switch et {
|
||||
case EventAlive:
|
||||
return "EventAlive"
|
||||
case EventUpdate:
|
||||
return "EventUpdate"
|
||||
case EventByeBye:
|
||||
return "EventByeBye"
|
||||
default:
|
||||
return fmt.Sprintf("EventUnknown(%d)", int8(et))
|
||||
}
|
||||
}
|
||||
|
||||
type Update struct {
|
||||
// The USN of the service.
|
||||
USN string
|
||||
// What happened.
|
||||
EventType EventType
|
||||
// The entry, which is nil if the service was not known and
|
||||
// EventType==EventByeBye. The contents of this must not be modified as it is
|
||||
// shared with the registry and other listeners. Once created, the Registry
|
||||
// does not modify the Entry value - any updates are replaced with a new
|
||||
// Entry value.
|
||||
Entry *Entry
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
// The address that the entry data was actually received from.
|
||||
RemoteAddr string
|
||||
// Unique Service Name. Identifies a unique instance of a device or service.
|
||||
USN string
|
||||
// Notfication Type. The type of device or service being announced.
|
||||
NT string
|
||||
// Server's self-identifying string.
|
||||
Server string
|
||||
Host string
|
||||
// Location of the UPnP root device description.
|
||||
Location url.URL
|
||||
|
||||
// Despite BOOTID,CONFIGID being required fields, apparently they are not
|
||||
// always set by devices. Set to -1 if not present.
|
||||
|
||||
BootID int32
|
||||
ConfigID int32
|
||||
|
||||
SearchPort uint16
|
||||
|
||||
// When the last update was received for this entry identified by this USN.
|
||||
LastUpdate time.Time
|
||||
// When the last update's cached values are advised to expire.
|
||||
CacheExpiry time.Time
|
||||
}
|
||||
|
||||
func newEntryFromRequest(r *http.Request) (*Entry, error) {
|
||||
now := time.Now()
|
||||
expiryDuration, err := parseCacheControlMaxAge(r.Header.Get("CACHE-CONTROL"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssdp: error parsing CACHE-CONTROL max age: %v", err)
|
||||
}
|
||||
|
||||
loc, err := url.Parse(r.Header.Get("LOCATION"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ssdp: error parsing entry Location URL: %v", err)
|
||||
}
|
||||
|
||||
bootID, err := parseUpnpIntHeader(r.Header, "BOOTID.UPNP.ORG", -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configID, err := parseUpnpIntHeader(r.Header, "CONFIGID.UPNP.ORG", -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchPort, err := parseUpnpIntHeader(r.Header, "SEARCHPORT.UPNP.ORG", ssdpSearchPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if searchPort < 1 || searchPort > 65535 {
|
||||
return nil, fmt.Errorf("ssdp: search port %d is out of range", searchPort)
|
||||
}
|
||||
|
||||
return &Entry{
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
USN: r.Header.Get("USN"),
|
||||
NT: r.Header.Get("NT"),
|
||||
Server: r.Header.Get("SERVER"),
|
||||
Host: r.Header.Get("HOST"),
|
||||
Location: *loc,
|
||||
BootID: bootID,
|
||||
ConfigID: configID,
|
||||
SearchPort: uint16(searchPort),
|
||||
LastUpdate: now,
|
||||
CacheExpiry: now.Add(expiryDuration),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseCacheControlMaxAge(cc string) (time.Duration, error) {
|
||||
matches := maxAgeRx.FindStringSubmatch(cc)
|
||||
if len(matches) != 2 {
|
||||
return 0, fmt.Errorf("did not find exactly one max-age in cache control header: %q", cc)
|
||||
}
|
||||
expirySeconds, err := strconv.ParseInt(matches[1], 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if expirySeconds < 1 || expirySeconds > maxExpiryTimeSeconds {
|
||||
return 0, fmt.Errorf("rejecting bad expiry time of %d seconds", expirySeconds)
|
||||
}
|
||||
return time.Duration(expirySeconds) * time.Second, nil
|
||||
}
|
||||
|
||||
// parseUpnpIntHeader is intended to parse the
|
||||
// {BOOT,CONFIGID,SEARCHPORT}.UPNP.ORG header fields. It returns the def if
|
||||
// the head is empty or missing.
|
||||
func parseUpnpIntHeader(headers http.Header, headerName string, def int32) (int32, error) {
|
||||
s := headers.Get(headerName)
|
||||
if s == "" {
|
||||
return def, nil
|
||||
}
|
||||
v, err := strconv.ParseInt(s, 10, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("ssdp: could not parse header %s: %v", headerName, err)
|
||||
}
|
||||
return int32(v), nil
|
||||
}
|
||||
|
||||
var _ httpu.Handler = new(Registry)
|
||||
|
||||
// Registry maintains knowledge of discovered devices and services.
|
||||
//
|
||||
// NOTE: the interface for this is experimental and may change, or go away
|
||||
// entirely.
|
||||
type Registry struct {
|
||||
lock sync.Mutex
|
||||
byUSN map[string]*Entry
|
||||
|
||||
listenersLock sync.RWMutex
|
||||
listeners map[chan<- Update]struct{}
|
||||
}
|
||||
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
byUSN: make(map[string]*Entry),
|
||||
listeners: make(map[chan<- Update]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewServerAndRegistry is a convenience function to create a registry, and an
|
||||
// httpu server to pass it messages. Call ListenAndServe on the server for
|
||||
// messages to be processed.
|
||||
func NewServerAndRegistry() (*httpu.Server, *Registry) {
|
||||
reg := NewRegistry()
|
||||
srv := &httpu.Server{
|
||||
Addr: ssdpUDP4Addr,
|
||||
Multicast: true,
|
||||
Handler: reg,
|
||||
}
|
||||
return srv, reg
|
||||
}
|
||||
|
||||
func (reg *Registry) AddListener(c chan<- Update) {
|
||||
reg.listenersLock.Lock()
|
||||
defer reg.listenersLock.Unlock()
|
||||
reg.listeners[c] = struct{}{}
|
||||
}
|
||||
|
||||
func (reg *Registry) RemoveListener(c chan<- Update) {
|
||||
reg.listenersLock.Lock()
|
||||
defer reg.listenersLock.Unlock()
|
||||
delete(reg.listeners, c)
|
||||
}
|
||||
|
||||
func (reg *Registry) sendUpdate(u Update) {
|
||||
reg.listenersLock.RLock()
|
||||
defer reg.listenersLock.RUnlock()
|
||||
for c := range reg.listeners {
|
||||
c <- u
|
||||
}
|
||||
}
|
||||
|
||||
// GetService returns known service (or device) entries for the given service
|
||||
// URN.
|
||||
func (reg *Registry) GetService(serviceURN string) []*Entry {
|
||||
// Currently assumes that the map is small, so we do a linear search rather
|
||||
// than indexed to avoid maintaining two maps.
|
||||
var results []*Entry
|
||||
reg.lock.Lock()
|
||||
defer reg.lock.Unlock()
|
||||
for _, entry := range reg.byUSN {
|
||||
if entry.NT == serviceURN {
|
||||
results = append(results, entry)
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// ServeMessage implements httpu.Handler, and uses SSDP NOTIFY requests to
|
||||
// maintain the registry of devices and services.
|
||||
func (reg *Registry) ServeMessage(r *http.Request) {
|
||||
if r.Method != methodNotify {
|
||||
return
|
||||
}
|
||||
|
||||
nts := r.Header.Get("nts")
|
||||
|
||||
var err error
|
||||
switch nts {
|
||||
case ntsAlive:
|
||||
err = reg.handleNTSAlive(r)
|
||||
case ntsUpdate:
|
||||
err = reg.handleNTSUpdate(r)
|
||||
case ntsByebye:
|
||||
err = reg.handleNTSByebye(r)
|
||||
default:
|
||||
err = fmt.Errorf("unknown NTS value: %q", nts)
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("goupnp/ssdp: failed to handle %s message from %s: %v", nts, r.RemoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (reg *Registry) handleNTSAlive(r *http.Request) error {
|
||||
entry, err := newEntryFromRequest(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reg.lock.Lock()
|
||||
reg.byUSN[entry.USN] = entry
|
||||
reg.lock.Unlock()
|
||||
|
||||
reg.sendUpdate(Update{
|
||||
USN: entry.USN,
|
||||
EventType: EventAlive,
|
||||
Entry: entry,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (reg *Registry) handleNTSUpdate(r *http.Request) error {
|
||||
entry, err := newEntryFromRequest(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextBootID, err := parseUpnpIntHeader(r.Header, "NEXTBOOTID.UPNP.ORG", -1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
entry.BootID = nextBootID
|
||||
|
||||
reg.lock.Lock()
|
||||
reg.byUSN[entry.USN] = entry
|
||||
reg.lock.Unlock()
|
||||
|
||||
reg.sendUpdate(Update{
|
||||
USN: entry.USN,
|
||||
EventType: EventUpdate,
|
||||
Entry: entry,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (reg *Registry) handleNTSByebye(r *http.Request) error {
|
||||
usn := r.Header.Get("USN")
|
||||
|
||||
reg.lock.Lock()
|
||||
entry := reg.byUSN[usn]
|
||||
delete(reg.byUSN, usn)
|
||||
reg.lock.Unlock()
|
||||
|
||||
reg.sendUpdate(Update{
|
||||
USN: usn,
|
||||
EventType: EventByeBye,
|
||||
Entry: entry,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
113
vendor/github.com/tailscale/goupnp/ssdp/ssdp.go
generated
vendored
Normal file
113
vendor/github.com/tailscale/goupnp/ssdp/ssdp.go
generated
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
package ssdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ssdpDiscover = `"ssdp:discover"`
|
||||
ntsAlive = `ssdp:alive`
|
||||
ntsByebye = `ssdp:byebye`
|
||||
ntsUpdate = `ssdp:update`
|
||||
ssdpUDP4Addr = "239.255.255.250:1900"
|
||||
ssdpSearchPort = 1900
|
||||
methodSearch = "M-SEARCH"
|
||||
methodNotify = "NOTIFY"
|
||||
|
||||
// SSDPAll is a value for searchTarget that searches for all devices and services.
|
||||
SSDPAll = "ssdp:all"
|
||||
// UPNPRootDevice is a value for searchTarget that searches for all root devices.
|
||||
UPNPRootDevice = "upnp:rootdevice"
|
||||
)
|
||||
|
||||
// HTTPUClient is the interface required to perform HTTP-over-UDP requests.
|
||||
type HTTPUClient interface {
|
||||
Do(req *http.Request, numSends int) ([]*http.Response, error)
|
||||
}
|
||||
|
||||
func max(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// SSDPRawSearch performs a fairly raw SSDP search request, and returns the
|
||||
// unique response(s) that it receives. Each response has the requested
|
||||
// searchTarget, a USN, and a valid location. maxWaitSeconds states how long to
|
||||
// wait for responses in seconds, and must be a minimum of 1 (the
|
||||
// implementation waits an additional 100ms for responses to arrive), 2 is a
|
||||
// reasonable value for this. numSends is the number of requests to send - 3 is
|
||||
// a reasonable value for this.
|
||||
func SSDPRawSearch(
|
||||
ctx context.Context,
|
||||
httpu HTTPUClient,
|
||||
searchTarget string,
|
||||
numSends int,
|
||||
) ([]*http.Response, error) {
|
||||
// Must specify at least 1 second according to the spec.
|
||||
var wait int64 = 1
|
||||
// https://openconnectivity.org/upnp-specs/UPnP-arch-DeviceArchitecture-v2.0-20200417.pdf
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
wait = max(wait, int64(time.Until(deadline).Seconds()))
|
||||
}
|
||||
|
||||
header := http.Header{
|
||||
// Putting headers in here avoids them being title-cased.
|
||||
// (The UPnP discovery protocol uses case-sensitive headers)
|
||||
"HOST": []string{ssdpUDP4Addr},
|
||||
"MAN": []string{ssdpDiscover},
|
||||
"MX": []string{strconv.FormatInt(wait, 10)},
|
||||
"ST": []string{searchTarget},
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: methodSearch,
|
||||
// TODO: Support both IPv4 and IPv6.
|
||||
Host: ssdpUDP4Addr,
|
||||
URL: &url.URL{Opaque: "*"},
|
||||
Header: header,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
allResponses, err := httpu.Do(req, numSends)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
isExactSearch := searchTarget != SSDPAll && searchTarget != UPNPRootDevice
|
||||
|
||||
seenUSNs := make(map[string]bool)
|
||||
var responses []*http.Response
|
||||
for _, response := range allResponses {
|
||||
if response.StatusCode != 200 {
|
||||
log.Printf("ssdp: got response status code %q in search response", response.Status)
|
||||
continue
|
||||
}
|
||||
if st := response.Header.Get("ST"); isExactSearch && st != searchTarget {
|
||||
continue
|
||||
}
|
||||
usn := response.Header.Get("USN")
|
||||
if usn == "" {
|
||||
// Empty/missing USN in search response - using location instead.
|
||||
location, err := response.Location()
|
||||
if err != nil {
|
||||
// No usable location in search response - discard.
|
||||
continue
|
||||
}
|
||||
usn = location.String()
|
||||
}
|
||||
if _, alreadySeen := seenUSNs[usn]; !alreadySeen {
|
||||
seenUSNs[usn] = true
|
||||
responses = append(responses, response)
|
||||
}
|
||||
}
|
||||
|
||||
return responses, nil
|
||||
}
|
||||
27
vendor/github.com/tailscale/hujson/LICENSE
generated
vendored
Normal file
27
vendor/github.com/tailscale/hujson/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
Copyright (c) 2019 Tailscale Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
34
vendor/github.com/tailscale/hujson/README.md
generated
vendored
Normal file
34
vendor/github.com/tailscale/hujson/README.md
generated
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
# HuJSON - "Human JSON" ([JWCC](https://nigeltao.github.io/blog/2021/json-with-commas-comments.html))
|
||||
|
||||
The `github.com/tailscale/hujson` package implements
|
||||
the [JWCC](https://nigeltao.github.io/blog/2021/json-with-commas-comments.html) extension
|
||||
of [standard JSON](https://datatracker.ietf.org/doc/html/rfc8259).
|
||||
|
||||
The `JWCC` format permits two things over standard JSON:
|
||||
|
||||
1. C-style line comments and block comments intermixed with whitespace,
|
||||
2. allows trailing commas after the last member/element in an object/array.
|
||||
|
||||
All JSON is valid JWCC.
|
||||
|
||||
For details, see the JWCC docs at:
|
||||
|
||||
https://nigeltao.github.io/blog/2021/json-with-commas-comments.html
|
||||
|
||||
## Visual Studio Code association
|
||||
|
||||
Visual Studio Code supports a similar `jsonc` (JSON with comments) format. To
|
||||
treat all `*.hujson` files as `jsonc` with trailing commas allowed, you can add
|
||||
the following snippet to your Visual Studio Code configuration:
|
||||
|
||||
```json
|
||||
"files.associations": {
|
||||
"*.hujson": "jsonc"
|
||||
},
|
||||
"json.schemas": [{
|
||||
"fileMatch": ["*.hujson"],
|
||||
"schema": {
|
||||
"allowTrailingCommas": true
|
||||
}
|
||||
}]
|
||||
```
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user