Update dependencies

This commit is contained in:
bluepython508
2024-11-01 17:33:34 +00:00
parent 033ac0b400
commit 5cdfab398d
3596 changed files with 1033483 additions and 259 deletions

View File

@@ -0,0 +1,519 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package authenticode contains Windows Authenticode signature verification code.
package authenticode
import (
"encoding/hex"
"errors"
"fmt"
"path/filepath"
"strings"
"unsafe"
"github.com/dblohm7/wingoes"
"github.com/dblohm7/wingoes/pe"
"golang.org/x/sys/windows"
)
var (
// ErrSigNotFound is returned if no authenticode signature could be found.
ErrSigNotFound = errors.New("authenticode signature not found")
// ErrUnexpectedCertSubject is wrapped with the actual cert subject and
// returned when the binary is signed by a different subject than expected.
ErrUnexpectedCertSubject = errors.New("unexpected cert subject")
errCertSubjectNotFound = errors.New("cert subject not found")
errCertSubjectDecodeLenMismatch = errors.New("length mismatch while decoding cert subject")
)
const (
_CERT_STRONG_SIGN_OID_INFO_CHOICE = 2
_CMSG_SIGNER_CERT_INFO_PARAM = 7
_MSI_INVALID_HASH_IS_FATAL = 1
_TRUST_E_NOSIGNATURE = wingoes.HRESULT(-((0x800B0100 ^ 0xFFFFFFFF) + 1))
)
// Verify performs authenticode verification on the file at path, and also
// ensures that expectedCertSubject matches the actual cert subject. path may
// point to either a PE binary or an MSI package. ErrSigNotFound is returned if
// no signature is found.
func Verify(path string, expectedCertSubject string) error {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return err
}
var subject string
if strings.EqualFold(filepath.Ext(path), ".msi") {
subject, err = verifyMSI(path16)
} else {
subject, _, err = queryPE(path16, true)
}
if err != nil {
return err
}
if subject != expectedCertSubject {
return fmt.Errorf("%w %q", ErrUnexpectedCertSubject, subject)
}
return nil
}
// SigProvenance indicates whether an authenticode signature was embedded within
// the file itself, or the signature applies to an associated catalog file.
type SigProvenance int
const (
SigProvUnknown = SigProvenance(iota)
SigProvEmbedded
SigProvCatalog
)
// QueryCertSubject obtains the subject associated with the certificate used to
// sign the PE binary located at path. When err == nil, it also returns the
// provenance of that signature. ErrSigNotFound is returned if no signature
// is found. Note that this function does *not* validate the chain of trust; use
// Verify for that purpose!
func QueryCertSubject(path string) (certSubject string, provenance SigProvenance, err error) {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return "", SigProvUnknown, err
}
return queryPE(path16, false)
}
func queryPE(utf16Path *uint16, verify bool) (string, SigProvenance, error) {
certSubject, err := queryEmbeddedCertSubject(utf16Path, verify)
switch {
case err == ErrSigNotFound:
// Try looking for the signature in a catalog file.
default:
return certSubject, SigProvEmbedded, err
}
certSubject, err = queryCatalogCertSubject(utf16Path, verify)
switch {
case err == ErrSigNotFound:
return "", SigProvUnknown, err
default:
return certSubject, SigProvCatalog, err
}
}
// CertSubjectError is returned if a cert subject was successfully resolved but
// there was a problem encountered during its extraction. The Subject is
// provided for informational purposes but is not presumed to be accurate.
type CertSubjectError struct {
Err error // The error that occurred while extracting the cert subject.
Subject string // The (possibly invalid) cert subject that was extracted.
}
func (e *CertSubjectError) Error() string {
if e == nil {
return "<nil>"
}
if e.Subject == "" {
return e.Err.Error()
}
return fmt.Sprintf("cert subject %q: %v", e.Subject, e.Err)
}
func (e *CertSubjectError) Unwrap() error {
return e.Err
}
func verifyMSI(path *uint16) (string, error) {
var certCtx *windows.CertContext
hr := msiGetFileSignatureInformation(path, _MSI_INVALID_HASH_IS_FATAL, &certCtx, nil, nil)
if e := wingoes.ErrorFromHRESULT(hr); e.Failed() {
if e == wingoes.ErrorFromHRESULT(_TRUST_E_NOSIGNATURE) {
return "", ErrSigNotFound
}
return "", e
}
defer windows.CertFreeCertificateContext(certCtx)
return certSubjectFromCertContext(certCtx)
}
func certSubjectFromCertContext(certCtx *windows.CertContext) (string, error) {
desiredLen := windows.CertGetNameString(
certCtx,
windows.CERT_NAME_SIMPLE_DISPLAY_TYPE,
0,
nil,
nil,
0,
)
if desiredLen <= 1 {
return "", errCertSubjectNotFound
}
buf := make([]uint16, desiredLen)
actualLen := windows.CertGetNameString(
certCtx,
windows.CERT_NAME_SIMPLE_DISPLAY_TYPE,
0,
nil,
&buf[0],
desiredLen,
)
if actualLen != desiredLen {
return "", errCertSubjectDecodeLenMismatch
}
return windows.UTF16ToString(buf), nil
}
type objectQuery struct {
certStore windows.Handle
cryptMsg windows.Handle
encodingType uint32
}
func newObjectQuery(utf16Path *uint16) (*objectQuery, error) {
var oq objectQuery
if err := windows.CryptQueryObject(
windows.CERT_QUERY_OBJECT_FILE,
unsafe.Pointer(utf16Path),
windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED,
windows.CERT_QUERY_FORMAT_FLAG_BINARY,
0,
&oq.encodingType,
nil,
nil,
&oq.certStore,
&oq.cryptMsg,
nil,
); err != nil {
return nil, err
}
return &oq, nil
}
func (oq *objectQuery) Close() error {
if oq.certStore != 0 {
if err := windows.CertCloseStore(oq.certStore, 0); err != nil {
return err
}
oq.certStore = 0
}
if oq.cryptMsg != 0 {
if err := cryptMsgClose(oq.cryptMsg); err != nil {
return err
}
oq.cryptMsg = 0
}
return nil
}
func (oq *objectQuery) certSubject() (string, error) {
var certInfoLen uint32
if err := cryptMsgGetParam(
oq.cryptMsg,
_CMSG_SIGNER_CERT_INFO_PARAM,
0,
unsafe.Pointer(nil),
&certInfoLen,
); err != nil {
return "", err
}
buf := make([]byte, certInfoLen)
if err := cryptMsgGetParam(
oq.cryptMsg,
_CMSG_SIGNER_CERT_INFO_PARAM,
0,
unsafe.Pointer(&buf[0]),
&certInfoLen,
); err != nil {
return "", err
}
certInfo := (*windows.CertInfo)(unsafe.Pointer(&buf[0]))
certCtx, err := windows.CertFindCertificateInStore(
oq.certStore,
oq.encodingType,
0,
windows.CERT_FIND_SUBJECT_CERT,
unsafe.Pointer(certInfo),
nil,
)
if err != nil {
return "", err
}
defer windows.CertFreeCertificateContext(certCtx)
return certSubjectFromCertContext(certCtx)
}
func extractCertBlob(hfile windows.Handle) ([]byte, error) {
pef, err := pe.NewPEFromFileHandle(hfile)
if err != nil {
return nil, err
}
defer pef.Close()
certsAny, err := pef.DataDirectoryEntry(pe.IMAGE_DIRECTORY_ENTRY_SECURITY)
if err != nil {
if errors.Is(err, pe.ErrNotPresent) {
err = ErrSigNotFound
}
return nil, err
}
certs, ok := certsAny.([]pe.AuthenticodeCert)
if !ok || len(certs) == 0 {
return nil, ErrSigNotFound
}
for _, cert := range certs {
if cert.Revision() != pe.WIN_CERT_REVISION_2_0 || cert.Type() != pe.WIN_CERT_TYPE_PKCS_SIGNED_DATA {
continue
}
return cert.Data(), nil
}
return nil, ErrSigNotFound
}
type _HCRYPTPROV windows.Handle
type _CRYPT_VERIFY_MESSAGE_PARA struct {
CBSize uint32
MsgAndCertEncodingType uint32
HCryptProv _HCRYPTPROV
FNGetSignerCertificate uintptr
GetArg uintptr
StrongSignPara *windows.CertStrongSignPara
}
func querySubjectFromBlob(blob []byte) (string, error) {
para := _CRYPT_VERIFY_MESSAGE_PARA{
CBSize: uint32(unsafe.Sizeof(_CRYPT_VERIFY_MESSAGE_PARA{})),
MsgAndCertEncodingType: windows.X509_ASN_ENCODING | windows.PKCS_7_ASN_ENCODING,
}
var certCtx *windows.CertContext
if err := cryptVerifyMessageSignature(&para, 0, &blob[0], uint32(len(blob)), nil, nil, &certCtx); err != nil {
return "", err
}
defer windows.CertFreeCertificateContext(certCtx)
return certSubjectFromCertContext(certCtx)
}
func queryEmbeddedCertSubject(utf16Path *uint16, verify bool) (string, error) {
peBinary, err := windows.CreateFile(
utf16Path,
windows.GENERIC_READ,
windows.FILE_SHARE_READ,
nil,
windows.OPEN_EXISTING,
0,
0,
)
if err != nil {
return "", err
}
defer windows.CloseHandle(peBinary)
blob, err := extractCertBlob(peBinary)
if err != nil {
return "", err
}
certSubj, err := querySubjectFromBlob(blob)
if err != nil {
return "", err
}
if !verify {
return certSubj, nil
}
wintrustArg := unsafe.Pointer(&windows.WinTrustFileInfo{
Size: uint32(unsafe.Sizeof(windows.WinTrustFileInfo{})),
FilePath: utf16Path,
File: peBinary,
})
if err := verifyTrust(windows.WTD_CHOICE_FILE, wintrustArg); err != nil {
// We might still want to know who the cert subject claims to be
// even if the validation has failed (eg for troubleshooting purposes),
// so we return a CertSubjectError.
return "", &CertSubjectError{Err: err, Subject: certSubj}
}
return certSubj, nil
}
var (
_BCRYPT_SHA256_ALGORITHM = &([]uint16{'S', 'H', 'A', '2', '5', '6', 0})[0]
_OID_CERT_STRONG_SIGN_OS_1 = &([]byte("1.3.6.1.4.1.311.72.1.1\x00"))[0]
)
type _HCATADMIN windows.Handle
type _HCATINFO windows.Handle
type _CATALOG_INFO struct {
size uint32
catalogFile [windows.MAX_PATH]uint16
}
type _WINTRUST_CATALOG_INFO struct {
size uint32
catalogVersion uint32
catalogFilePath *uint16
memberTag *uint16
memberFilePath *uint16
memberFile windows.Handle
pCalculatedFileHash *byte
cbCalculatedFileHash uint32
catalogContext uintptr
catAdmin _HCATADMIN
}
func queryCatalogCertSubject(utf16Path *uint16, verify bool) (string, error) {
var catAdmin _HCATADMIN
policy := windows.CertStrongSignPara{
Size: uint32(unsafe.Sizeof(windows.CertStrongSignPara{})),
InfoChoice: _CERT_STRONG_SIGN_OID_INFO_CHOICE,
InfoOrSerializedInfoOrOID: unsafe.Pointer(_OID_CERT_STRONG_SIGN_OS_1),
}
if err := cryptCATAdminAcquireContext2(
&catAdmin,
nil,
_BCRYPT_SHA256_ALGORITHM,
&policy,
0,
); err != nil {
return "", err
}
defer cryptCATAdminReleaseContext(catAdmin, 0)
// We use windows.CreateFile instead of standard library facilities because:
// 1. Subsequent API calls directly utilize the file's Win32 HANDLE;
// 2. We're going to be hashing the contents of this file, so we want to
// provide a sequential-scan hint to the kernel.
memberFile, err := windows.CreateFile(
utf16Path,
windows.GENERIC_READ,
windows.FILE_SHARE_READ,
nil,
windows.OPEN_EXISTING,
windows.FILE_FLAG_SEQUENTIAL_SCAN,
0,
)
if err != nil {
return "", err
}
defer windows.CloseHandle(memberFile)
var hashLen uint32
if err := cryptCATAdminCalcHashFromFileHandle2(
catAdmin,
memberFile,
&hashLen,
nil,
0,
); err != nil {
return "", err
}
hashBuf := make([]byte, hashLen)
if err := cryptCATAdminCalcHashFromFileHandle2(
catAdmin,
memberFile,
&hashLen,
&hashBuf[0],
0,
); err != nil {
return "", err
}
catInfoCtx, err := cryptCATAdminEnumCatalogFromHash(
catAdmin,
&hashBuf[0],
hashLen,
0,
nil,
)
if err != nil {
if err == windows.ERROR_NOT_FOUND {
err = ErrSigNotFound
}
return "", err
}
defer cryptCATAdminReleaseCatalogContext(catAdmin, catInfoCtx, 0)
catInfo := _CATALOG_INFO{
size: uint32(unsafe.Sizeof(_CATALOG_INFO{})),
}
if err := cryptCATAdminCatalogInfoFromContext(catInfoCtx, &catInfo, 0); err != nil {
return "", err
}
oq, err := newObjectQuery(&catInfo.catalogFile[0])
if err != nil {
return "", err
}
defer oq.Close()
certSubj, err := oq.certSubject()
if err != nil {
return "", err
}
if !verify {
return certSubj, nil
}
// memberTag is required to be formatted this way.
hbh := strings.ToUpper(hex.EncodeToString(hashBuf))
memberTag, err := windows.UTF16PtrFromString(hbh)
if err != nil {
return "", err
}
wintrustArg := unsafe.Pointer(&_WINTRUST_CATALOG_INFO{
size: uint32(unsafe.Sizeof(_WINTRUST_CATALOG_INFO{})),
catalogFilePath: &catInfo.catalogFile[0],
memberTag: memberTag,
memberFilePath: utf16Path,
memberFile: memberFile,
catAdmin: catAdmin,
})
if err := verifyTrust(windows.WTD_CHOICE_CATALOG, wintrustArg); err != nil {
// We might still want to know who the cert subject claims to be
// even if the validation has failed (eg for troubleshooting purposes),
// so we return a CertSubjectError.
return "", &CertSubjectError{Err: err, Subject: certSubj}
}
return certSubj, nil
}
func verifyTrust(infoType uint32, info unsafe.Pointer) error {
data := &windows.WinTrustData{
Size: uint32(unsafe.Sizeof(windows.WinTrustData{})),
UIChoice: windows.WTD_UI_NONE,
RevocationChecks: windows.WTD_REVOKE_WHOLECHAIN, // Full revocation checking, as this is called with network connectivity.
UnionChoice: infoType,
StateAction: windows.WTD_STATEACTION_VERIFY,
FileOrCatalogOrBlobOrSgnrOrCert: info,
}
err := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)
data.StateAction = windows.WTD_STATEACTION_CLOSE
windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)
return err
}

View File

@@ -0,0 +1,18 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package authenticode
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2
//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2
//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext
//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash
//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext
//sys cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext
//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose
//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam
//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature
//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW

View File

@@ -0,0 +1,135 @@
// Code generated by 'go generate'; DO NOT EDIT.
package authenticode
import (
"syscall"
"unsafe"
"github.com/dblohm7/wingoes"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modcrypt32 = windows.NewLazySystemDLL("crypt32.dll")
modmsi = windows.NewLazySystemDLL("msi.dll")
modwintrust = windows.NewLazySystemDLL("wintrust.dll")
procCryptMsgClose = modcrypt32.NewProc("CryptMsgClose")
procCryptMsgGetParam = modcrypt32.NewProc("CryptMsgGetParam")
procCryptVerifyMessageSignature = modcrypt32.NewProc("CryptVerifyMessageSignature")
procMsiGetFileSignatureInformationW = modmsi.NewProc("MsiGetFileSignatureInformationW")
procCryptCATAdminAcquireContext2 = modwintrust.NewProc("CryptCATAdminAcquireContext2")
procCryptCATAdminCalcHashFromFileHandle2 = modwintrust.NewProc("CryptCATAdminCalcHashFromFileHandle2")
procCryptCATAdminEnumCatalogFromHash = modwintrust.NewProc("CryptCATAdminEnumCatalogFromHash")
procCryptCATAdminReleaseCatalogContext = modwintrust.NewProc("CryptCATAdminReleaseCatalogContext")
procCryptCATAdminReleaseContext = modwintrust.NewProc("CryptCATAdminReleaseContext")
procCryptCATCatalogInfoFromContext = modwintrust.NewProc("CryptCATCatalogInfoFromContext")
)
func cryptMsgClose(cryptMsg windows.Handle) (err error) {
r1, _, e1 := syscall.Syscall(procCryptMsgClose.Addr(), 1, uintptr(cryptMsg), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procCryptMsgGetParam.Addr(), 5, uintptr(cryptMsg), uintptr(paramType), uintptr(index), uintptr(data), uintptr(unsafe.Pointer(dataLen)), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) {
r1, _, e1 := syscall.Syscall9(procCryptVerifyMessageSignature.Addr(), 7, uintptr(unsafe.Pointer(pVerifyPara)), uintptr(signerIndex), uintptr(unsafe.Pointer(pbSignedBlob)), uintptr(cbSignedBlob), uintptr(unsafe.Pointer(pbDecoded)), uintptr(unsafe.Pointer(pdbDecoded)), uintptr(unsafe.Pointer(ppSignerCert)), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) {
r0, _, _ := syscall.Syscall6(procMsiGetFileSignatureInformationW.Addr(), 5, uintptr(unsafe.Pointer(signedObjectPath)), uintptr(flags), uintptr(unsafe.Pointer(certCtx)), uintptr(unsafe.Pointer(pbHashData)), uintptr(unsafe.Pointer(cbHashData)), 0)
ret = wingoes.HRESULT(r0)
return
}
func cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procCryptCATAdminAcquireContext2.Addr(), 5, uintptr(unsafe.Pointer(hCatAdmin)), uintptr(unsafe.Pointer(pgSubsystem)), uintptr(unsafe.Pointer(hashAlgorithm)), uintptr(unsafe.Pointer(strongHashPolicy)), uintptr(flags), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procCryptCATAdminCalcHashFromFileHandle2.Addr(), 5, uintptr(hCatAdmin), uintptr(file), uintptr(unsafe.Pointer(pcbHash)), uintptr(unsafe.Pointer(pbHash)), uintptr(flags), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) {
r0, _, e1 := syscall.Syscall6(procCryptCATAdminEnumCatalogFromHash.Addr(), 5, uintptr(hCatAdmin), uintptr(unsafe.Pointer(pbHash)), uintptr(cbHash), uintptr(flags), uintptr(unsafe.Pointer(prevCatInfo)), 0)
ret = _HCATINFO(r0)
if ret == 0 {
err = errnoErr(e1)
}
return
}
func cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseCatalogContext.Addr(), 3, uintptr(hCatAdmin), uintptr(hCatInfo), uintptr(flags))
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseContext.Addr(), 2, uintptr(hCatAdmin), uintptr(flags), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procCryptCATCatalogInfoFromContext.Addr(), 3, uintptr(hCatInfo), uintptr(unsafe.Pointer(catInfo)), uintptr(flags))
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}

79
vendor/tailscale.com/util/winutil/gp/gp_windows.go generated vendored Normal file
View File

@@ -0,0 +1,79 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package gp contains [Group Policy]-related functions and types.
//
// [Group Policy]: https://web.archive.org/web/20240630210707/https://learn.microsoft.com/en-us/previous-versions/windows/desktop/policy/group-policy-start-page
package gp
import (
"fmt"
"runtime"
"golang.org/x/sys/windows"
)
// Scope is a user or machine policy scope.
type Scope int
const (
// MachinePolicy indicates a machine policy.
// Registry-based machine policies reside in HKEY_LOCAL_MACHINE.
MachinePolicy Scope = iota
// UserPolicy indicates a user policy.
// Registry-based user policies reside in HKEY_CURRENT_USER of the corresponding user.
UserPolicy
)
// _RP_FORCE causes RefreshPolicyEx to reapply policy even if no policy change was detected.
// See [RP_FORCE] for details.
//
// [RP_FORCE]: https://web.archive.org/save/https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
const _RP_FORCE = 0x1
// RefreshUserPolicy triggers a machine policy refresh, but does not wait for it to complete.
// When the force parameter is true, it causes the Group Policy to reapply policy even
// if no policy change was detected.
func RefreshMachinePolicy(force bool) error {
return refreshPolicyEx(true, toRefreshPolicyFlags(force))
}
// RefreshUserPolicy triggers a user policy refresh, but does not wait for it to complete.
// When the force parameter is true, it causes the Group Policy to reapply policy even
// if no policy change was detected.
//
// The token indicates user whose policy should be refreshed.
// If specified, the token must be either a primary token with TOKEN_QUERY and TOKEN_DUPLICATE
// access, or an impersonation token with TOKEN_QUERY and TOKEN_IMPERSONATE access,
// and the specified user must be logged in interactively.
//
// Otherwise, a zero token value indicates the current user. It should not
// be used by services or other applications running under system identities.
//
// The function fails with windows.ERROR_ACCESS_DENIED if the user represented by the token
// is not logged in interactively at the time of the call.
func RefreshUserPolicy(token windows.Token, force bool) error {
if token != 0 {
// Impersonate the user whose policy we need to refresh.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
if err := impersonateLoggedOnUser(token); err != nil {
return err
}
defer func() {
if err := windows.RevertToSelf(); err != nil {
// RevertToSelf errors are non-recoverable.
panic(fmt.Errorf("could not revert impersonation: %w", err))
}
}()
}
return refreshPolicyEx(true, toRefreshPolicyFlags(force))
}
func toRefreshPolicyFlags(force bool) uint32 {
if force {
return _RP_FORCE
}
return 0
}

13
vendor/tailscale.com/util/winutil/gp/mksyscall.go generated vendored Normal file
View File

@@ -0,0 +1,13 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package gp
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//sys enterCriticalPolicySection(machine bool) (handle policyLockHandle, err error) [int32(failretval)==0] = userenv.EnterCriticalPolicySection
//sys impersonateLoggedOnUser(token windows.Token) (err error) [int32(failretval)==0] = advapi32.ImpersonateLoggedOnUser
//sys leaveCriticalPolicySection(handle policyLockHandle) (err error) [int32(failretval)==0] = userenv.LeaveCriticalPolicySection
//sys registerGPNotification(event windows.Handle, machine bool) (err error) [int32(failretval)==0] = userenv.RegisterGPNotification
//sys refreshPolicyEx(machine bool, flags uint32) (err error) [int32(failretval)==0] = userenv.RefreshPolicyEx
//sys unregisterGPNotification(event windows.Handle) (err error) [int32(failretval)==0] = userenv.UnregisterGPNotification

View File

@@ -0,0 +1,292 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package gp
import (
"errors"
"fmt"
"runtime"
"sync"
"sync/atomic"
"golang.org/x/sys/windows"
)
// PolicyLock allows pausing the application of policy to safely read Group Policy
// settings. A PolicyLock is an R-lock that can be held by multiple readers simultaneously,
// preventing the Group Policy Client service (which maintains its W-counterpart) from
// modifying policies while they are being read.
//
// It is not possible to pause group policy processing for longer than 10 minutes.
// If the system needs to apply policies and the lock is being held for more than that,
// the Group Policy Client service will release the lock and continue policy processing.
//
// To avoid deadlocks when acquiring both machine and user locks, acquire the
// user lock before the machine lock.
type PolicyLock struct {
scope Scope
token windows.Token
// hooks for testing
enterFn func(bool) (policyLockHandle, error)
leaveFn func(policyLockHandle) error
closing chan struct{} // closing is closed when the Close method is called.
mu sync.Mutex
handle policyLockHandle
lockCnt atomic.Int32 // A non-zero LSB indicates that the lock can be acquired.
}
// policyLockHandle is the underlying lock handle returned by enterCriticalPolicySection.
type policyLockHandle uintptr
type policyLockResult struct {
handle policyLockHandle
err error
}
var (
// ErrInvalidLockState is returned by (*PolicyLock).Lock if the lock has a zero value or has already been closed.
ErrInvalidLockState = errors.New("the lock has not been created or has already been closed")
)
// NewMachinePolicyLock creates a PolicyLock that facilitates pausing the
// application of computer policy. To avoid deadlocks when acquiring both
// machine and user locks, acquire the user lock before the machine lock.
func NewMachinePolicyLock() *PolicyLock {
lock := &PolicyLock{
scope: MachinePolicy,
closing: make(chan struct{}),
enterFn: enterCriticalPolicySection,
leaveFn: leaveCriticalPolicySection,
}
lock.lockCnt.Store(1) // mark as initialized
return lock
}
// NewUserPolicyLock creates a PolicyLock that facilitates pausing the
// application of the user policy for the specified user. To avoid deadlocks
// when acquiring both machine and user locks, acquire the user lock before the
// machine lock.
//
// The token indicates which user's policy should be locked for reading.
// If specified, the token must have TOKEN_DUPLICATE access,
// the specified user must be logged in interactively.
// and the caller retains ownership of the token.
//
// Otherwise, a zero token value indicates the current user. It should not
// be used by services or other applications running under system identities.
func NewUserPolicyLock(token windows.Token) (*PolicyLock, error) {
lock := &PolicyLock{
scope: UserPolicy,
closing: make(chan struct{}),
enterFn: enterCriticalPolicySection,
leaveFn: leaveCriticalPolicySection,
}
if token != 0 {
err := windows.DuplicateHandle(
windows.CurrentProcess(),
windows.Handle(token),
windows.CurrentProcess(),
(*windows.Handle)(&lock.token),
windows.TOKEN_QUERY|windows.TOKEN_DUPLICATE|windows.TOKEN_IMPERSONATE,
false,
0)
if err != nil {
return nil, err
}
}
lock.lockCnt.Store(1) // mark as initialized
return lock, nil
}
// Lock locks l.
// It returns ErrNotInitialized if l has a zero value or has already been closed,
// or an Errno if the underlying Group Policy lock cannot be acquired.
//
// As a special case, it fails with windows.ERROR_ACCESS_DENIED
// if l is a user policy lock, and the corresponding user is not logged in
// interactively at the time of the call.
func (l *PolicyLock) Lock() error {
l.mu.Lock()
defer l.mu.Unlock()
if l.lockCnt.Add(2)&1 == 0 {
// The lock cannot be acquired because it has either never been properly
// created or its Close method has already been called. However, we need
// to call Unlock to both decrement lockCnt and leave the underlying
// CriticalPolicySection if we won the race with another goroutine and
// now own the lock.
l.Unlock()
return ErrInvalidLockState
}
if l.handle != 0 {
// The underlying CriticalPolicySection is already acquired.
// It is an R-Lock (with the W-counterpart owned by the Group Policy service),
// meaning that it can be acquired by multiple readers simultaneously.
// So we can just return.
return nil
}
return l.lockSlow()
}
// lockSlow calls enterCriticalPolicySection to acquire the underlying GP read lock.
// It waits for either the lock to be acquired, or for the Close method to be called.
//
// l.mu must be held.
func (l *PolicyLock) lockSlow() (err error) {
defer func() {
if err != nil {
// Decrement the counter if the lock cannot be acquired,
// and complete the pending close request if we're the last owner.
if l.lockCnt.Add(-2) == 0 {
l.closeInternal()
}
}
}()
// In some cases in production environments, the Group Policy service may
// hold the corresponding W-Lock for extended periods of time (minutes
// rather than seconds or milliseconds). We need to make our wait operation
// cancellable. So, if one goroutine invokes (*PolicyLock).Close while another
// initiates (*PolicyLock).Lock and waits for the underlying R-lock to be
// acquired by enterCriticalPolicySection, the Close method should cancel
// the wait.
initCh := make(chan error)
resultCh := make(chan policyLockResult)
go func() {
closing := l.closing
if l.scope == UserPolicy && l.token != 0 {
// Impersonate the user whose critical policy section we want to acquire.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
if err := impersonateLoggedOnUser(l.token); err != nil {
initCh <- err
return
}
defer func() {
if err := windows.RevertToSelf(); err != nil {
// RevertToSelf errors are non-recoverable.
panic(fmt.Errorf("could not revert impersonation: %w", err))
}
}()
}
close(initCh)
var machine bool
if l.scope == MachinePolicy {
machine = true
}
handle, err := l.enterFn(machine)
send_result:
for {
select {
case resultCh <- policyLockResult{handle, err}:
// lockSlow has received the result.
break send_result
default:
select {
case <-closing:
// The lock is being closed, and we lost the race to l.closing
// it the calling goroutine.
if err == nil {
l.leaveFn(handle)
}
break send_result
default:
// The calling goroutine did not enter the select block yet.
runtime.Gosched() // allow other routines to run
continue send_result
}
}
}
}()
// lockSlow should not return until the goroutine above has been fully initialized,
// even if the lock is being closed.
if err = <-initCh; err != nil {
return err
}
select {
case result := <-resultCh:
if result.err == nil {
l.handle = result.handle
}
return result.err
case <-l.closing:
return ErrInvalidLockState
}
}
// Unlock unlocks l.
// It panics if l is not locked on entry to Unlock.
func (l *PolicyLock) Unlock() {
l.mu.Lock()
defer l.mu.Unlock()
lockCnt := l.lockCnt.Add(-2)
if lockCnt < 0 {
panic("negative lockCnt")
}
if lockCnt > 1 {
// The lock is still being used by other readers.
// We compare against 1 rather than 0 because the least significant bit
// of lockCnt indicates that l has been initialized and a close
// has not been requested yet.
return
}
if l.handle != 0 {
// Impersonation is not required to unlock a critical policy section.
// The handle we pass determines which mutex will be unlocked.
leaveCriticalPolicySection(l.handle)
l.handle = 0
}
if lockCnt == 0 {
// Complete the pending close request if there's no more readers.
l.closeInternal()
}
}
// Close releases resources associated with l.
// It is a no-op for the machine policy lock.
func (l *PolicyLock) Close() error {
lockCnt := l.lockCnt.Load()
if lockCnt&1 == 0 {
// The lock has never been initialized, or close has already been called.
return nil
}
close(l.closing)
// Unset the LSB to indicate a pending close request.
for !l.lockCnt.CompareAndSwap(lockCnt, lockCnt&^int32(1)) {
lockCnt = l.lockCnt.Load()
}
if lockCnt != 0 {
// The lock is still being used and will be closed upon the final Unlock call.
return nil
}
return l.closeInternal()
}
func (l *PolicyLock) closeInternal() error {
if l.token != 0 {
if err := l.token.Close(); err != nil {
return err
}
l.token = 0
}
l.closing = nil
return nil
}

107
vendor/tailscale.com/util/winutil/gp/watcher_windows.go generated vendored Normal file
View File

@@ -0,0 +1,107 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package gp
import (
"golang.org/x/sys/windows"
)
// ChangeWatcher calls the handler whenever a policy in the specified scope changes.
type ChangeWatcher struct {
gpWaitEvents [2]windows.Handle
handler func()
done chan struct{}
}
// NewChangeWatcher creates an instance of ChangeWatcher that invokes handler
// every time Windows notifies it of a group policy change in the specified scope.
func NewChangeWatcher(scope Scope, handler func()) (*ChangeWatcher, error) {
var err error
// evtDone is signaled by (*gpNotificationWatcher).Close() to indicate that
// the doWatch goroutine should exit.
evtDone, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
windows.CloseHandle(evtDone)
}
}()
// evtChanged is registered with the Windows policy engine to become
// signalled any time group policy has been refreshed.
evtChanged, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
windows.CloseHandle(evtChanged)
}
}()
// Tell Windows to signal evtChanged whenever group policies are refreshed.
if err := registerGPNotification(evtChanged, scope == MachinePolicy); err != nil {
return nil, err
}
result := &ChangeWatcher{
// Ordering of the event handles in gpWaitEvents is important:
// When calling windows.WaitForMultipleObjects and multiple objects are
// signalled simultaneously, it always returns the wait code for the
// lowest-indexed handle in its input array. evtDone is higher priority for
// us than evtChanged, so the former must be placed into the array ahead of
// the latter.
gpWaitEvents: [2]windows.Handle{
evtDone,
evtChanged,
},
handler: handler,
done: make(chan struct{}),
}
go result.doWatch()
return result, nil
}
func (w *ChangeWatcher) doWatch() {
// The wait code corresponding to the event that is signalled when a group
// policy change occurs. That is, w.gpWaitEvents[1] aka evtChanged.
const expectedWaitCode = windows.WAIT_OBJECT_0 + 1
for {
if waitCode, _ := windows.WaitForMultipleObjects(w.gpWaitEvents[:], false, windows.INFINITE); waitCode != expectedWaitCode {
break
}
w.handler()
}
close(w.done)
}
// Close unsubscribes from further Group Policy notifications,
// waits for any running handlers to complete, and releases any remaining resources
// associated with w.
func (w *ChangeWatcher) Close() error {
// Notify doWatch that we're done and it should exit.
if err := windows.SetEvent(w.gpWaitEvents[0]); err != nil {
return err
}
unregisterGPNotification(w.gpWaitEvents[1])
// Wait for doWatch to complete.
<-w.done
// Now we may safely clean up all the things.
for i, evt := range w.gpWaitEvents {
windows.CloseHandle(evt)
w.gpWaitEvents[i] = 0
}
w.handler = nil
return nil
}

View File

@@ -0,0 +1,111 @@
// Code generated by 'go generate'; DO NOT EDIT.
package gp
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
moduserenv = windows.NewLazySystemDLL("userenv.dll")
procImpersonateLoggedOnUser = modadvapi32.NewProc("ImpersonateLoggedOnUser")
procEnterCriticalPolicySection = moduserenv.NewProc("EnterCriticalPolicySection")
procLeaveCriticalPolicySection = moduserenv.NewProc("LeaveCriticalPolicySection")
procRefreshPolicyEx = moduserenv.NewProc("RefreshPolicyEx")
procRegisterGPNotification = moduserenv.NewProc("RegisterGPNotification")
procUnregisterGPNotification = moduserenv.NewProc("UnregisterGPNotification")
)
func impersonateLoggedOnUser(token windows.Token) (err error) {
r1, _, e1 := syscall.Syscall(procImpersonateLoggedOnUser.Addr(), 1, uintptr(token), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func enterCriticalPolicySection(machine bool) (handle policyLockHandle, err error) {
var _p0 uint32
if machine {
_p0 = 1
}
r0, _, e1 := syscall.Syscall(procEnterCriticalPolicySection.Addr(), 1, uintptr(_p0), 0, 0)
handle = policyLockHandle(r0)
if int32(handle) == 0 {
err = errnoErr(e1)
}
return
}
func leaveCriticalPolicySection(handle policyLockHandle) (err error) {
r1, _, e1 := syscall.Syscall(procLeaveCriticalPolicySection.Addr(), 1, uintptr(handle), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func refreshPolicyEx(machine bool, flags uint32) (err error) {
var _p0 uint32
if machine {
_p0 = 1
}
r1, _, e1 := syscall.Syscall(procRefreshPolicyEx.Addr(), 2, uintptr(_p0), uintptr(flags), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func registerGPNotification(event windows.Handle, machine bool) (err error) {
var _p0 uint32
if machine {
_p0 = 1
}
r1, _, e1 := syscall.Syscall(procRegisterGPNotification.Addr(), 2, uintptr(event), uintptr(_p0), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func unregisterGPNotification(event windows.Handle) (err error) {
r1, _, e1 := syscall.Syscall(procUnregisterGPNotification.Addr(), 1, uintptr(event), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}

21
vendor/tailscale.com/util/winutil/mksyscall.go generated vendored Normal file
View File

@@ -0,0 +1,21 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
//sys dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) = netapi32.DsGetDcNameW
//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
//sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
//sys netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) = netapi32.NetValidateName
//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W
//sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart
//sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession
//sys rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) = rstrtmgr.RmGetList
//sys rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) = rstrtmgr.RmJoinSession
//sys rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) = rstrtmgr.RmRegisterResources
//sys rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) = rstrtmgr.RmStartSession
//sys unloadUserProfile(token windows.Token, profile registry.Key) (err error) [int32(failretval)==0] = userenv.UnloadUserProfile

View File

@@ -0,0 +1,155 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package policy contains higher-level abstractions for accessing Windows enterprise policies.
package policy
import (
"time"
"tailscale.com/util/winutil"
)
// PreferenceOptionPolicy is a policy that governs whether a boolean variable
// is forcibly assigned an administrator-defined value, or allowed to receive
// a user-defined value.
type PreferenceOptionPolicy int
const (
showChoiceByPolicy PreferenceOptionPolicy = iota
neverByPolicy
alwaysByPolicy
)
// Show returns if the UI option that controls the choice administered by this
// policy should be shown. Currently this is true if and only if the policy is
// showChoiceByPolicy.
func (p PreferenceOptionPolicy) Show() bool {
return p == showChoiceByPolicy
}
// ShouldEnable checks if the choice administered by this policy should be
// enabled. If the administrator has chosen a setting, the administrator's
// setting is returned, otherwise userChoice is returned.
func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool {
switch p {
case neverByPolicy:
return false
case alwaysByPolicy:
return true
default:
return userChoice
}
}
// GetPreferenceOptionPolicy loads a policy from the registry that can be
// managed by an enterprise policy management system and allows administrative
// overrides of users' choices in a way that we do not want tailcontrol to have
// the authority to set. It describes user-decides/always/never options, where
// "always" and "never" remove the user's ability to make a selection. If not
// present or set to a different value, "user-decides" is the default.
func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy {
opt, err := winutil.GetPolicyString(name)
if opt == "" || err != nil {
return showChoiceByPolicy
}
switch opt {
case "always":
return alwaysByPolicy
case "never":
return neverByPolicy
default:
return showChoiceByPolicy
}
}
// VisibilityPolicy is a policy that controls whether or not a particular
// component of a user interface is to be shown.
type VisibilityPolicy byte
const (
visibleByPolicy VisibilityPolicy = 'v'
hiddenByPolicy VisibilityPolicy = 'h'
)
// Show reports whether the UI option administered by this policy should be shown.
// Currently this is true if and only if the policy is visibleByPolicy.
func (p VisibilityPolicy) Show() bool {
return p == visibleByPolicy
}
// GetVisibilityPolicy loads a policy from the registry that can be managed
// by an enterprise policy management system and describes show/hide decisions
// for UI elements. The registry value should be a string set to "show" (return
// true) or "hide" (return true). If not present or set to a different value,
// "show" (return false) is the default.
func GetVisibilityPolicy(name string) VisibilityPolicy {
opt, err := winutil.GetPolicyString(name)
if opt == "" || err != nil {
return visibleByPolicy
}
switch opt {
case "hide":
return hiddenByPolicy
default:
return visibleByPolicy
}
}
// GetDurationPolicy loads a policy from the registry that can be managed
// by an enterprise policy management system and describes a duration for some
// action. The registry value should be a string that time.ParseDuration
// understands. If the registry value is "" or can not be processed,
// defaultValue is returned instead.
func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration {
opt, err := winutil.GetPolicyString(name)
if opt == "" || err != nil {
return defaultValue
}
v, err := time.ParseDuration(opt)
if err != nil || v < 0 {
return defaultValue
}
return v
}
// SelectControlURL returns the ControlURL to use based on a value in
// the registry (LoginURL) and the one on disk (in the GUI's
// prefs.conf). If both are empty, it returns a default value. (It
// always return a non-empty value)
//
// See https://github.com/tailscale/tailscale/issues/2798 for some background.
func SelectControlURL(reg, disk string) string {
const def = "https://controlplane.tailscale.com"
// Prior to Dec 2020's commit 739b02e6, the installer
// wrote a LoginURL value of https://login.tailscale.com to the registry.
const oldRegDef = "https://login.tailscale.com"
// If they have an explicit value in the registry, use it,
// unless it's an old default value from an old installer.
// Then we have to see which is better.
if reg != "" {
if reg != oldRegDef {
// Something explicit in the registry that we didn't
// set ourselves by the installer.
return reg
}
if disk == "" {
// Something in the registry is better than nothing on disk.
return reg
}
if disk != def && disk != oldRegDef {
// The value in the registry is the old
// default (login.tailscale.com) but the value
// on disk is neither our old nor new default
// value, so it must be some custom thing that
// the user cares about. Prefer the disk value.
return disk
}
}
if disk != "" {
return disk
}
return def
}

845
vendor/tailscale.com/util/winutil/restartmgr_windows.go generated vendored Normal file
View File

@@ -0,0 +1,845 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"unicode/utf16"
"unsafe"
"github.com/dblohm7/wingoes"
"golang.org/x/sys/windows"
"tailscale.com/types/logger"
"tailscale.com/util/multierr"
)
var (
// ErrDefunctProcess is returned when the process no longer exists.
ErrDefunctProcess = errors.New("process is defunct")
// ErrProcessNotRestartable is returned by (*UniqueProcess).AsRestartableProcess
// when the process has previously indicated that it must not be restarted
// during a patch/upgrade.
ErrProcessNotRestartable = errors.New("process is not restartable")
)
// Implementation note: the code in this file will be invoked from within
// MSI custom actions, so please try to return windows.Errno error codes
// whenever possible; this makes the action return more accurate errors to
// the installer engine.
const (
_RESTART_NO_CRASH = 1
_RESTART_NO_HANG = 2
_RESTART_NO_PATCH = 4
_RESTART_NO_REBOOT = 8
)
func registerForRestart(opts RegisterForRestartOpts) error {
var flags uint32
if !opts.RestartOnCrash {
flags |= _RESTART_NO_CRASH
}
if !opts.RestartOnHang {
flags |= _RESTART_NO_HANG
}
if !opts.RestartOnUpgrade {
flags |= _RESTART_NO_PATCH
}
if !opts.RestartOnReboot {
flags |= _RESTART_NO_REBOOT
}
var cmdLine *uint16
if opts.UseCmdLineArgs {
if len(opts.CmdLineArgs) == 0 {
// re-use our current args, excluding the exe name itself
opts.CmdLineArgs = os.Args[1:]
}
var b strings.Builder
for _, arg := range opts.CmdLineArgs {
if b.Len() > 0 {
b.WriteByte(' ')
}
b.WriteString(windows.EscapeArg(arg))
}
if b.Len() > 0 {
var err error
cmdLine, err = windows.UTF16PtrFromString(b.String())
if err != nil {
return err
}
}
}
hr := registerApplicationRestart(cmdLine, flags)
if e := wingoes.ErrorFromHRESULT(hr); e.Failed() {
return e
}
return nil
}
type _RMHANDLE uint32
// See https://web.archive.org/web/20231128212837/https://learn.microsoft.com/en-us/windows/win32/rstmgr/using-restart-manager-with-a-secondary-installer
const _INVALID_RMHANDLE = ^_RMHANDLE(0)
type _RM_UNIQUE_PROCESS struct {
PID uint32
ProcessStartTime windows.Filetime
}
type _RM_APP_TYPE int32
const (
_RmUnknownApp _RM_APP_TYPE = 0
_RmMainWindow _RM_APP_TYPE = 1
_RmOtherWindow _RM_APP_TYPE = 2
_RmService _RM_APP_TYPE = 3
_RmExplorer _RM_APP_TYPE = 4
_RmConsole _RM_APP_TYPE = 5
_RmCritical _RM_APP_TYPE = 1000
)
type _RM_APP_STATUS uint32
const (
//lint:ignore U1000 maps to a win32 API
_RmStatusUnknown _RM_APP_STATUS = 0x0
_RmStatusRunning _RM_APP_STATUS = 0x1
_RmStatusStopped _RM_APP_STATUS = 0x2
_RmStatusStoppedOther _RM_APP_STATUS = 0x4
_RmStatusRestarted _RM_APP_STATUS = 0x8
_RmStatusErrorOnStop _RM_APP_STATUS = 0x10
_RmStatusErrorOnRestart _RM_APP_STATUS = 0x20
_RmStatusShutdownMasked _RM_APP_STATUS = 0x40
_RmStatusRestartMasked _RM_APP_STATUS = 0x80
)
type _RM_PROCESS_INFO struct {
Process _RM_UNIQUE_PROCESS
AppName [256]uint16
ServiceShortName [64]uint16
AppType _RM_APP_TYPE
AppStatus _RM_APP_STATUS
TSSessionID uint32
Restartable int32 // Win32 BOOL
}
// RestartManagerSession represents an open Restart Manager session.
type RestartManagerSession interface {
io.Closer
// AddPaths adds the fully-qualified paths in fqPaths to the set of binaries
// that will be monitored by this restart manager session. NOTE: This
// method is expensive to call, so it is better to make a single call with
// a larger slice than to make multiple calls with smaller slices.
AddPaths(fqPaths []string) error
// AffectedProcesses returns the UniqueProcess information for all running
// processes that utilize the binaries previously specified by calls to
// AddPaths.
AffectedProcesses() ([]UniqueProcess, error)
// Key returns the session key associated with this instance.
Key() string
}
// rmSession encapsulates the necessary information to represent an open
// restart manager session.
//
// Implementation note: rmSession methods that return errors should use
// windows.Errno codes whenever possible, as we call them from the custom
// action DLL. MSI custom actions are expected to return windows.Errno values;
// to ensure our compliance with this expectation, we should also use those
// values. Failure to do so will result in a generic windows.Errno being
// returned to the Windows Installer, which obviously is less than ideal.
type rmSession struct {
session _RMHANDLE
key string
logf logger.Logf
}
const _CCH_RM_SESSION_KEY = 32 // (excludes NUL terminator)
// NewRestartManagerSession creates a new RestartManagerSession that utilizes
// logf for logging.
func NewRestartManagerSession(logf logger.Logf) (RestartManagerSession, error) {
var sessionKeyBuf [_CCH_RM_SESSION_KEY + 1]uint16
result := rmSession{
logf: logf,
}
if err := rmStartSession(&result.session, 0, &sessionKeyBuf[0]); err != nil {
return nil, err
}
result.key = windows.UTF16ToString(sessionKeyBuf[:_CCH_RM_SESSION_KEY])
return &result, nil
}
// AttachRestartManagerSession opens a connection to an existing session
// specified by sessionKey, using logf for logging.
func AttachRestartManagerSession(logf logger.Logf, sessionKey string) (RestartManagerSession, error) {
sessionKey16, err := windows.UTF16PtrFromString(sessionKey)
if err != nil {
return nil, err
}
result := rmSession{
key: sessionKey,
logf: logf,
}
if err := rmJoinSession(&result.session, sessionKey16); err != nil {
return nil, err
}
return &result, nil
}
func (rms *rmSession) Close() error {
if rms == nil || rms.session == _INVALID_RMHANDLE {
return nil
}
if err := rmEndSession(rms.session); err != nil {
return err
}
rms.session = _INVALID_RMHANDLE
return nil
}
func (rms *rmSession) Key() string {
return rms.key
}
func (rms *rmSession) AffectedProcesses() ([]UniqueProcess, error) {
infos, err := rms.processList()
if err != nil {
return nil, err
}
result := make([]UniqueProcess, 0, len(infos))
for _, info := range infos {
result = append(result, UniqueProcess{
_RM_UNIQUE_PROCESS: info.Process,
CanReceiveGUIMsgs: info.AppType == _RmMainWindow || info.AppType == _RmOtherWindow,
})
}
return result, nil
}
func (rms *rmSession) processList() ([]_RM_PROCESS_INFO, error) {
const maxAttempts = 5
var avail, rebootReasons uint32
needed := uint32(1)
var buf []_RM_PROCESS_INFO
err := error(windows.ERROR_MORE_DATA)
numAttempts := 0
for err == windows.ERROR_MORE_DATA && numAttempts < maxAttempts {
numAttempts++
buf = make([]_RM_PROCESS_INFO, needed)
avail = needed
err = rmGetList(rms.session, &needed, &avail, unsafe.SliceData(buf), &rebootReasons)
}
if err != nil {
if err == windows.ERROR_SESSION_CREDENTIAL_CONFLICT {
// Add some more context about the meaning of this error.
err = fmt.Errorf("%w (the Restart Manager does not permit calling RmGetList from a process that did not originally create the session)", err)
}
return nil, err
}
return buf[:avail], nil
}
func (rms *rmSession) AddPaths(fqPaths []string) error {
if len(fqPaths) == 0 {
return nil
}
fqPaths16 := make([]*uint16, 0, len(fqPaths))
for _, fqPath := range fqPaths {
if !filepath.IsAbs(fqPath) {
return fmt.Errorf("%w: paths must be fully-qualified", windows.ERROR_BAD_PATHNAME)
}
fqPath16, err := windows.UTF16PtrFromString(fqPath)
if err != nil {
return err
}
fqPaths16 = append(fqPaths16, fqPath16)
}
return rmRegisterResources(rms.session, uint32(len(fqPaths16)), unsafe.SliceData(fqPaths16), 0, nil, 0, nil)
}
// UniqueProcess contains the necessary information to uniquely identify a
// process in the face of potential PID reuse.
type UniqueProcess struct {
_RM_UNIQUE_PROCESS
// CanReceiveGUIMsgs is true when the process has open top-level windows.
CanReceiveGUIMsgs bool
}
// AsRestartableProcess obtains a RestartableProcess populated using the
// information obtained from up.
func (up *UniqueProcess) AsRestartableProcess() (*RestartableProcess, error) {
// We need PROCESS_QUERY_INFORMATION instead of PROCESS_QUERY_LIMITED_INFORMATION
// in order for ProcessImageName to be able to work from within a privileged
// Windows Installer process.
// We need PROCESS_VM_READ for GetApplicationRestartSettings.
// We need PROCESS_TERMINATE and SYNCHRONIZE to terminate the process and
// to be able to wait for the terminated process's handle to signal.
access := uint32(windows.PROCESS_QUERY_INFORMATION | windows.PROCESS_TERMINATE | windows.PROCESS_VM_READ | windows.SYNCHRONIZE)
h, err := windows.OpenProcess(access, false, up.PID)
if err != nil {
return nil, fmt.Errorf("OpenProcess(%d[%#X]): %w", up.PID, up.PID, err)
}
defer func() {
if h == 0 {
return
}
windows.CloseHandle(h)
}()
var creationTime, exitTime, kernelTime, userTime windows.Filetime
if err := windows.GetProcessTimes(h, &creationTime, &exitTime, &kernelTime, &userTime); err != nil {
return nil, fmt.Errorf("GetProcessTimes: %w", err)
}
if creationTime != up.ProcessStartTime {
// The PID has been reused and does not actually reference the original process.
return nil, ErrDefunctProcess
}
var tok windows.Token
if err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &tok); err != nil {
return nil, fmt.Errorf("OpenProcessToken: %w", err)
}
defer tok.Close()
tsSessionID, err := TSSessionID(tok)
if err != nil {
return nil, fmt.Errorf("TSSessionID: %w", err)
}
logonSessionID, err := LogonSessionID(tok)
if err != nil {
return nil, fmt.Errorf("LogonSessionID: %w", err)
}
img, err := ProcessImageName(h)
if err != nil {
return nil, fmt.Errorf("ProcessImageName: %w", err)
}
const _RESTART_MAX_CMD_LINE = 1024
var cmdLine [_RESTART_MAX_CMD_LINE]uint16
cmdLineLen := uint32(len(cmdLine))
var rmFlags uint32
hr := getApplicationRestartSettings(h, &cmdLine[0], &cmdLineLen, &rmFlags)
// Not found is not an error; it just means that the app never set any restart settings.
if e := wingoes.ErrorFromHRESULT(hr); e.Failed() && e != wingoes.ErrorFromErrno(windows.ERROR_NOT_FOUND) {
return nil, fmt.Errorf("GetApplicationRestartSettings: %w", error(e))
}
if (rmFlags & _RESTART_NO_PATCH) != 0 {
// The application explicitly stated that it cannot be restarted during
// an upgrade.
return nil, ErrProcessNotRestartable
}
var logonSID string
// Non-fatal, so we'll proceed with best-effort.
if tokenGroups, err := tok.GetTokenGroups(); err == nil {
for _, group := range tokenGroups.AllGroups() {
if (group.Attributes & windows.SE_GROUP_LOGON_ID) != 0 {
logonSID = group.Sid.String()
break
}
}
}
var userSID string
// Non-fatal, so we'll proceed with best-effort.
if tokenUser, err := tok.GetTokenUser(); err == nil {
// Save the user's SID so that we can later check it against the currently
// logged-in Tailscale profile.
userSID = tokenUser.User.Sid.String()
}
result := &RestartableProcess{
Process: *up,
SessionInfo: SessionID{
LogonSession: logonSessionID,
TSSession: tsSessionID,
},
CommandLineInfo: CommandLineInfo{
ExePath: img,
Args: windows.UTF16ToString(cmdLine[:cmdLineLen]),
},
LogonSID: logonSID,
UserSID: userSID,
handle: h,
}
runtime.SetFinalizer(result, func(rp *RestartableProcess) { rp.Close() })
h = 0
return result, nil
}
// RestartableProcess contains the necessary information to uniquely identify
// an existing process, as well as the necessary information to be able to
// terminate it and later start a new instance in the identical logon session
// to the previous instance.
type RestartableProcess struct {
// Process uniquely identifies the existing process.
Process UniqueProcess
// SessionInfo uniquely identifies the Terminal Services (RDP) and logon
// sessions the existing process is running under.
SessionInfo SessionID
// CommandLineInfo contains the command line information necessary for restarting.
CommandLineInfo CommandLineInfo
// LogonSID contains the stringified SID of the existing process's token's logon session.
LogonSID string
// UserSID contains the stringified SID of the existing process's token's user.
UserSID string
// handle specifies the Win32 HANDLE associated with the existing process.
// When non-zero, it includes access rights for querying, terminating, and synchronizing.
handle windows.Handle
// hasExitCode is true when the exitCode field is valid.
hasExitCode bool
// exitCode contains exit code returned by this RestartableProcess once
// its termination has been recorded by (RestartableProcesses).Terminate.
// It is only valid when hasExitCode == true.
exitCode uint32
}
func (rp *RestartableProcess) Close() error {
if rp.handle == 0 {
return nil
}
windows.CloseHandle(rp.handle)
runtime.SetFinalizer(rp, nil)
rp.handle = 0
return nil
}
// RestartableProcesses is a map of PID to *RestartableProcess instance.
type RestartableProcesses map[uint32]*RestartableProcess
// NewRestartableProcesses instantiates a new RestartableProcesses.
func NewRestartableProcesses() RestartableProcesses {
return make(RestartableProcesses)
}
// Add inserts rp into rps.
func (rps RestartableProcesses) Add(rp *RestartableProcess) {
if rp != nil {
rps[rp.Process.PID] = rp
}
}
// Delete removes rp from rps.
func (rps RestartableProcesses) Delete(rp *RestartableProcess) {
if rp != nil {
delete(rps, rp.Process.PID)
}
}
// Close invokes (*RestartableProcess).Close on every value in rps, and then
// clears rps.
func (rps RestartableProcesses) Close() error {
for _, v := range rps {
v.Close()
}
clear(rps)
return nil
}
// _MAXIMUM_WAIT_OBJECTS is the Win32 constant for the maximum number of
// handles that a call to WaitForMultipleObjects may receive at once.
const _MAXIMUM_WAIT_OBJECTS = 64
// Terminate forcibly terminates all processes in rps using exitCode, and then
// waits for their process handles to signal, up to timeout.
func (rps RestartableProcesses) Terminate(logf logger.Logf, exitCode uint32, timeout time.Duration) error {
if len(rps) == 0 {
return nil
}
millis, err := wingoes.DurationToTimeoutMilliseconds(timeout)
if err != nil {
return err
}
errs := make([]error, 0, len(rps))
procs := make([]*RestartableProcess, 0, len(rps))
handles := make([]windows.Handle, 0, len(rps))
for _, v := range rps {
if err := windows.TerminateProcess(v.handle, exitCode); err != nil {
if err == windows.ERROR_ACCESS_DENIED {
// If v terminated before we attempted to terminate, we'll receive
// ERROR_ACCESS_DENIED, which is not really an error worth reporting in
// our use case. Just obtain the exit code and then close the process.
if err := windows.GetExitCodeProcess(v.handle, &v.exitCode); err != nil {
logf("GetExitCodeProcess failed: %v", err)
} else {
v.hasExitCode = true
}
v.Close()
} else {
errs = append(errs, &terminationError{rp: v, err: err})
}
continue
}
procs = append(procs, v)
handles = append(handles, v.handle)
}
for len(handles) > 0 {
// WaitForMultipleObjects can only wait on _MAXIMUM_WAIT_OBJECTS handles per
// call, so we batch them as necessary.
count := uint32(min(len(handles), _MAXIMUM_WAIT_OBJECTS))
waitCode, err := windows.WaitForMultipleObjects(handles[:count], true, millis)
if err != nil {
errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", err))
break
}
if e := windows.Errno(waitCode); e == windows.WAIT_TIMEOUT {
errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", error(e)))
break
}
if waitCode >= windows.WAIT_OBJECT_0 && waitCode < (windows.WAIT_OBJECT_0+count) {
// The first count process handles have all been signaled. Close them out.
for _, proc := range procs[:count] {
if err := windows.GetExitCodeProcess(proc.handle, &proc.exitCode); err != nil {
logf("GetExitCodeProcess failed: %v", err)
} else {
proc.hasExitCode = true
}
proc.Close()
}
procs = procs[count:]
handles = handles[count:]
continue
}
// We really shouldn't be reaching this point
panic(fmt.Sprintf("unexpected state from WaitForMultipleObjects: %d", waitCode))
}
if len(errs) != 0 {
return multierr.New(errs...)
}
return nil
}
type terminationError struct {
rp *RestartableProcess
err error
}
func (te *terminationError) Error() string {
pid := te.rp.Process.PID
return fmt.Sprintf("terminating process %d (%#X): %v", pid, pid, te.err)
}
func (te *terminationError) Unwrap() error {
return te.err
}
// SessionID encapsulates the necessary information for uniquely identifying
// sessions. In particular, SessionID contains enough information to detect
// reuse of Terminal Service session IDs.
type SessionID struct {
// LogonSession is the NT logon session ID.
LogonSession windows.LUID
// TSSession is the terminal services session ID.
TSSession uint32
}
// OpenToken obtains the security token associated with sessID.
func (sessID *SessionID) OpenToken() (windows.Token, error) {
var token windows.Token
if err := windows.WTSQueryUserToken(sessID.TSSession, &token); err != nil {
return 0, err
}
var err error
defer func() {
if err != nil {
token.Close()
}
}()
tokenLogonSession, err := LogonSessionID(token)
if err != nil {
return 0, err
}
if tokenLogonSession != sessID.LogonSession {
err = windows.ERROR_NO_SUCH_LOGON_SESSION
return 0, err
}
return token, nil
}
// ContainsToken determines whether token is contained within sessID.
func (sessID *SessionID) ContainsToken(token windows.Token) (bool, error) {
tokenTSSessionID, err := TSSessionID(token)
if err != nil {
return false, err
}
if tokenTSSessionID != sessID.TSSession {
return false, nil
}
tokenLogonSession, err := LogonSessionID(token)
if err != nil {
return false, err
}
return tokenLogonSession == sessID.LogonSession, nil
}
// This is the Window Station and Desktop within a particular session that must
// be specified for interactive processes: "Winsta0\\default\x00"
var defaultDesktop = unsafe.SliceData([]uint16{'W', 'i', 'n', 's', 't', 'a', '0', '\\', 'd', 'e', 'f', 'a', 'u', 'l', 't', 0})
// CommandLineInfo manages the necessary information for creating a Win32
// process using a specific command line.
type CommandLineInfo struct {
// ExePath must be a fully-qualified path to a Windows executable binary.
ExePath string
// Args must be any arguments supplied to the process, excluding the
// path to the binary itself. Args must be properly quoted according to
// Windows path rules. To create a properly quoted Args from scratch, call the
// SetArgs method instead.
Args string `json:",omitempty"`
}
// SetArgs converts args to a string quoted as necessary to satisfy the rules
// for Win32 command lines, and sets cli.Args to that string.
func (cli *CommandLineInfo) SetArgs(args []string) {
var buf strings.Builder
for _, arg := range args {
if buf.Len() > 0 {
buf.WriteByte(' ')
}
buf.WriteString(windows.EscapeArg(arg))
}
cli.Args = buf.String()
}
// Validate ensures that cli.ExePath contains an absolute path.
func (cli *CommandLineInfo) Validate() error {
if cli == nil {
return windows.ERROR_INVALID_PARAMETER
}
if !filepath.IsAbs(cli.ExePath) {
return fmt.Errorf("%w: CommandLineInfo requires absolute ExePath", windows.ERROR_BAD_PATHNAME)
}
return nil
}
// Resolve converts the information in cli to a format compatible with the Win32
// CreateProcess* family of APIs, as pointers to C-style UTF-16 strings. It also
// returns the full command line as a Go string for logging purposes.
func (cli *CommandLineInfo) Resolve() (exePath *uint16, cmdLine *uint16, cmdLineStr string, err error) {
// Resolve cmdLine first since that also does a Validate.
cmdLineStr, cmdLine, err = cli.resolveArgsAsUTF16Ptr()
if err != nil {
return nil, nil, "", err
}
exePath, err = windows.UTF16PtrFromString(cli.ExePath)
if err != nil {
return nil, nil, "", err
}
return exePath, cmdLine, cmdLineStr, nil
}
// resolveArgs quotes cli.ExePath as necessary, appends Args, and returns the result.
func (cli *CommandLineInfo) resolveArgs() (string, error) {
if err := cli.Validate(); err != nil {
return "", err
}
var cmdLineBuf strings.Builder
cmdLineBuf.WriteString(windows.EscapeArg(cli.ExePath))
if args := cli.Args; args != "" {
cmdLineBuf.WriteByte(' ')
cmdLineBuf.WriteString(args)
}
return cmdLineBuf.String(), nil
}
func (cli *CommandLineInfo) resolveArgsAsUTF16Ptr() (string, *uint16, error) {
s, err := cli.resolveArgs()
if err != nil {
return "", nil, err
}
s16, err := windows.UTF16PtrFromString(s)
if err != nil {
return "", nil, err
}
return s, s16, nil
}
// StartProcessInSession creates a new process using cmdLineInfo that will
// reside inside the session identified by sessID, with the security token whose
// logon is associated with sessID. The child process's environment will be
// inherited from the session token's environment.
func StartProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo) error {
return StartProcessInSessionWithHandler(sessID, cmdLineInfo, nil)
}
// PostCreateProcessHandler is a function that is invoked by
// StartProcessInSessionWithHandler when the child process has been successfully
// created. It is the responsibility of the handler to close the pi.Thread and
// pi.Process handles.
type PostCreateProcessHandler func(pi *windows.ProcessInformation)
// StartProcessInSessionWithHandler creates a new process using cmdLineInfo that
// will reside inside the session identified by sessID, with the security token
// whose logon is associated with sessID. The child process's environment will be
// inherited from the session token's environment. When the child process has
// been successfully created, handler is invoked with the windows.ProcessInformation
// that was returned by the OS.
func StartProcessInSessionWithHandler(sessID SessionID, cmdLineInfo CommandLineInfo, handler PostCreateProcessHandler) error {
pi, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0)
if err != nil {
return err
}
if handler != nil {
handler(pi)
return nil
}
windows.CloseHandle(pi.Process)
windows.CloseHandle(pi.Thread)
return nil
}
// RunProcessInSession creates a new process and waits up to timeout for that
// child process to complete its execution. The process is created using
// cmdLineInfo and will reside inside the session identified by sessID, with the
// security token whose logon is associated with sessID. The child process's
// environment will be inherited from the session token's environment.
func RunProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo, timeout time.Duration) (uint32, error) {
timeoutMillis, err := wingoes.DurationToTimeoutMilliseconds(timeout)
if err != nil {
return 1, err
}
pi, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0)
if err != nil {
return 1, err
}
windows.CloseHandle(pi.Thread)
defer windows.CloseHandle(pi.Process)
waitCode, err := windows.WaitForSingleObject(pi.Process, timeoutMillis)
if err != nil {
return 1, fmt.Errorf("WaitForSingleObject: %w", err)
}
if e := windows.Errno(waitCode); e == windows.WAIT_TIMEOUT {
return 1, e
}
if waitCode != windows.WAIT_OBJECT_0 {
// This should not be possible; log
return 1, fmt.Errorf("unexpected state from WaitForSingleObject: %d", waitCode)
}
var exitCode uint32
if err := windows.GetExitCodeProcess(pi.Process, &exitCode); err != nil {
return 1, err
}
return exitCode, nil
}
func startProcessInSessionInternal(sessID SessionID, cmdLineInfo CommandLineInfo, extraFlags uint32) (*windows.ProcessInformation, error) {
if err := cmdLineInfo.Validate(); err != nil {
return nil, err
}
token, err := sessID.OpenToken()
if err != nil {
return nil, fmt.Errorf("(*SessionID).OpenToken: %w", err)
}
defer token.Close()
exePath16, commandLine16, _, err := cmdLineInfo.Resolve()
if err != nil {
return nil, fmt.Errorf("(*CommandLineInfo).Resolve(): %w", err)
}
wd16, err := windows.UTF16PtrFromString(filepath.Dir(cmdLineInfo.ExePath))
if err != nil {
return nil, fmt.Errorf("UTF16PtrFromString(wd): %w", err)
}
env, err := token.Environ(false)
if err != nil {
return nil, fmt.Errorf("token environment: %w", err)
}
env16 := NewEnvBlock(env)
// The privileges in privNames are required for CreateProcessAsUser to be
// able to start processes as other users in other logon sessions.
privNames := []string{
"SeAssignPrimaryTokenPrivilege",
"SeIncreaseQuotaPrivilege",
}
dropPrivs, err := EnableCurrentThreadPrivileges(privNames)
if err != nil {
return nil, fmt.Errorf("EnableCurrentThreadPrivileges(%#v): %w", privNames, err)
}
defer dropPrivs()
createFlags := extraFlags | windows.CREATE_UNICODE_ENVIRONMENT | windows.DETACHED_PROCESS
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: defaultDesktop,
}
var pi windows.ProcessInformation
if err := windows.CreateProcessAsUser(token, exePath16, commandLine16, nil, nil,
false, createFlags, env16, wd16, &si, &pi); err != nil {
return nil, fmt.Errorf("CreateProcessAsUser: %w", err)
}
return &pi, nil
}
// NewEnvBlock processes a slice of strings containing "NAME=value" pairs
// representing a process envionment into the environment block format used by
// Windows APIs such as CreateProcess. env must be sorted case-insensitively
// by variable name.
func NewEnvBlock(env []string) *uint16 {
// Intentionally using bytes.Buffer here because we're writing nul bytes (the standard library does this too).
var buf bytes.Buffer
for _, v := range env {
buf.WriteString(v)
buf.WriteByte(0)
}
if buf.Len() == 0 {
// So that we end with a double-null in the empty env case
buf.WriteByte(0)
}
buf.WriteByte(0)
return unsafe.SliceData(utf16.Encode([]rune(string(buf.Bytes()))))
}

View File

@@ -0,0 +1,319 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"errors"
"fmt"
"os"
"slices"
"unsafe"
"github.com/dblohm7/wingoes"
"golang.org/x/sys/windows"
)
var (
// ErrAlreadyResolved is returned by (*StartupInfoBuilder).Resolve when the
// StartupInfoBuilder has already been resolved.
ErrAlreadyResolved = errors.New("StartupInfo already resolved")
// ErrAlreadySet is returned by StartupInfoBuilder setters if the value
// has already been set.
ErrAlreadySet = errors.New("StartupInfoBuilder value already set")
// ErrTooManyMitigationPolicyArguments is returned by
// (*StartupInfoBuilder).AddMitigationPolicyFlags if more arguments are
// passed than are supported by the current version of Windows. This error
// may be wrapped with additional information, so use [errors.Is] to check for it.
ErrTooManyMitigationPolicyArguments = errors.New("too many mitigation policy arguments for current Windows version")
)
// Attribute IDs not yet present in x/sys/windows
const (
_PROC_THREAD_ATTRIBUTE_JOB_LIST = 0x0002000D
)
// Mitigation flags from the Win32 SDK
const (
PROCESS_CREATION_MITIGATION_POLICY_EXTENSION_POINT_DISABLE_ALWAYS_ON = (1 << 32)
PROCESS_CREATION_MITIGATION_POLICY_BLOCK_NON_MICROSOFT_BINARIES_ALWAYS_ON = (1 << 44)
PROCESS_CREATION_MITIGATION_POLICY_IMAGE_LOAD_NO_REMOTE_ALWAYS_ON = (1 << 52)
PROCESS_CREATION_MITIGATION_POLICY_IMAGE_LOAD_NO_LOW_LABEL_ALWAYS_ON = (1 << 56)
PROCESS_CREATION_MITIGATION_POLICY_IMAGE_LOAD_PREFER_SYSTEM32_ALWAYS_ON = (1 << 60)
)
// StartupInfoBuilder constructs a Windows STARTUPINFOEX and optional
// process/thread attribute list for use with the CreateProcess family of APIs.
type StartupInfoBuilder struct {
siex windows.StartupInfoEx
attrs map[uintptr]any // attr -> value
attrContainer *windows.ProcThreadAttributeListContainer
}
func (sib *StartupInfoBuilder) Close() error {
si := &sib.siex.StartupInfo
if (si.Flags & windows.STARTF_USESTDHANDLES) != 0 {
for _, h := range []windows.Handle{si.StdInput, si.StdOutput, si.StdErr} {
if canBeInherited(h) {
windows.CloseHandle(h)
}
}
}
sib.siex = windows.StartupInfoEx{}
if sib.attrContainer != nil {
sib.attrContainer.Delete()
sib.attrContainer = nil
}
sib.attrs = nil
return nil
}
// Resolve causes all settings and attributes stored within sib to be processed
// and formatted into valid arguments for use by CreateProcess* APIs.
// The returned values will not be altered any further by sib, so the caller
// is free to make additional customizations to the returned values prior to
// passing them into CreateProcess.
func (sib *StartupInfoBuilder) Resolve() (startupInfo *windows.StartupInfo, inheritHandles bool, createProcessFlags uint32, err error) {
if sib.siex.StartupInfo.Cb != 0 {
return nil, false, 0, ErrAlreadyResolved
}
// Always create a Unicode environment.
createProcessFlags = windows.CREATE_UNICODE_ENVIRONMENT
if l := uint32(len(sib.attrs)); l > 0 {
attrCont, err := windows.NewProcThreadAttributeList(l)
if err != nil {
return nil, false, 0, err
}
defer func() {
if err != nil {
attrCont.Delete()
}
}()
for attr, val := range sib.attrs {
var pval unsafe.Pointer
var sval uintptr
switch v := val.(type) {
case windows.Handle:
// An individual handle is pointer-width and is thus passed by value.
pval = unsafe.Pointer(v)
sval = unsafe.Sizeof(v)
case []uint64:
pval = unsafe.Pointer(unsafe.SliceData(v))
sval = unsafe.Sizeof(v[0]) * uintptr(len(v))
case []windows.Handle:
pval = unsafe.Pointer(unsafe.SliceData(v))
sval = unsafe.Sizeof(v[0]) * uintptr(len(v))
default:
panic("unsupported data type")
}
// Note that pointer keepalives are managed by attrCont.
if err := attrCont.Update(attr, pval, sval); err != nil {
return nil, false, 0, err
}
if attr == windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST {
inheritHandles = true
}
}
sib.attrContainer = attrCont
sib.siex.ProcThreadAttributeList = attrCont.List()
sib.siex.StartupInfo.Cb = uint32(unsafe.Sizeof(sib.siex))
createProcessFlags |= windows.EXTENDED_STARTUPINFO_PRESENT
} else {
sib.siex.StartupInfo.Cb = uint32(unsafe.Sizeof(sib.siex.StartupInfo))
}
return &sib.siex.StartupInfo, inheritHandles, createProcessFlags, nil
}
func canBeInherited(h windows.Handle) bool {
if h == 0 || h == windows.InvalidHandle {
return false
}
ft, _ := windows.GetFileType(h)
switch ft {
case windows.FILE_TYPE_DISK, windows.FILE_TYPE_PIPE:
return true
case windows.FILE_TYPE_CHAR:
// Console handles are treated differently from other character devices.
// In particular, they should not be set up to be inherited like other
// kernel handles. We determine whether h is a console handle by attempting
// to retrieve its console mode. If this call fails then h is not a console.
var mode uint32
return windows.GetConsoleMode(h, &mode) != nil
default:
return false
}
}
// SetStdHandles sets the StdInput, StdOutput, and StdErr handles and configures
// their inheritability as needed. When the handles are valid, non-console
// kernel objects, sib takes ownership of of them. All three handles may be set
// to zero to indicate that the parent's std handles should not be implicitly
// inherited.
//
// It returns ErrAlreadySet if the handles have already been set by a previous call.
func (sib *StartupInfoBuilder) SetStdHandles(stdin, stdout, stderr windows.Handle) error {
if (sib.siex.StartupInfo.Flags & windows.STARTF_USESTDHANDLES) != 0 {
return ErrAlreadySet
}
toInherit := make([]windows.Handle, 0, 3)
for _, h := range []windows.Handle{stdin, stdout, stderr} {
if !canBeInherited(h) {
continue
}
toInherit = append(toInherit, h)
}
if err := sib.InheritHandles(toInherit...); err != nil {
return err
}
sib.siex.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
sib.siex.StartupInfo.StdInput = stdin
sib.siex.StartupInfo.StdOutput = stdout
sib.siex.StartupInfo.StdErr = stderr
return nil
}
func (sib *StartupInfoBuilder) makeAttrs() {
if sib.attrs == nil {
// The size of this map should correspond to the number of distinct
// attribute values supported by the StartupInfoBuilder API. Currently
// we support four:
// * Inheritable handle list;
// * Pseudoconsole;
// * Mitigation policy;
// * Job list
sib.attrs = make(map[uintptr]any, 4)
}
}
func (sib *StartupInfoBuilder) getAttr(attr uintptr) any {
sib.makeAttrs()
return sib.attrs[attr]
}
// InheritHandles configures each handle in handles to be inheritable and adds
// it to the inheritable handle list proc/thread attribute. handles must consist
// entirely of kernel objects (handles that are closed via windows.CloseHandle).
// InheritHandles may be called multiple times; each successive call accumulates
// handles into an internal list maintained by sib.
func (sib *StartupInfoBuilder) InheritHandles(handles ...windows.Handle) error {
if len(handles) == 0 {
return nil
}
newHandles := make([]windows.Handle, 0, len(handles))
for _, h := range handles {
if h == 0 || h == windows.InvalidHandle || slices.Contains(newHandles, h) {
continue
}
if err := windows.SetHandleInformation(h, windows.HANDLE_FLAG_INHERIT, windows.HANDLE_FLAG_INHERIT); err != nil {
return err
}
newHandles = append(newHandles, h)
}
if len(newHandles) == 0 {
return nil
}
var handleList []windows.Handle
if attrv := sib.getAttr(windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST); attrv != nil {
handleList = attrv.([]windows.Handle)
}
sib.attrs[windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST] = append(handleList, newHandles...)
return nil
}
// AddMitigationPolicyFlags sets the process mitigation policy flags in newFlags
// on the mitigation policy proc/thread attribute. It accepts a different
// number of arguments depending on the current Windows version. If the
// current Windows version is Windows 10 build 1703 or newer, it accepts up to
// two arguments. It only accepts one argument on older versions of Windows 10.
// If too many arguments are supplied, AddMitigationPolicyFlags returns
// ErrTooManyMitigationPolicyArguments wrapped with additional information;
// use errors.Is to check for this error.
// AddMitigationPolicyFlags may be called multiple times; each successive call
// accumulates additional flags into the mitigation policy.
func (sib *StartupInfoBuilder) AddMitigationPolicyFlags(newFlags ...uint64) error {
if len(newFlags) == 0 {
return nil
}
supportedLen := 1
if wingoes.IsWin10BuildOrGreater(wingoes.Win10Build1703) {
supportedLen++
}
if len(newFlags) > supportedLen {
return fmt.Errorf("%w: no more than %d allowed", ErrTooManyMitigationPolicyArguments, supportedLen)
}
attrv := sib.getAttr(windows.PROC_THREAD_ATTRIBUTE_MITIGATION_POLICY)
switch v := attrv.(type) {
case nil:
sib.attrs[windows.PROC_THREAD_ATTRIBUTE_MITIGATION_POLICY] = newFlags
case []uint64:
if newElems := len(newFlags) - len(v); newElems > 0 {
v = append(v, make([]uint64, newElems)...)
sib.attrs[windows.PROC_THREAD_ATTRIBUTE_MITIGATION_POLICY] = v
}
for i := range v {
v[i] |= newFlags[i]
}
default:
panic("unexpected attribute type")
}
return nil
}
// SetPseudoConsole sets pty as the pseudoconsole proc/thread attribute.
// pty must be a conpty handle. It returns ErrAlreadySet if the pty has already
// been successfully set by a previous call.
func (sib *StartupInfoBuilder) SetPseudoConsole(pty windows.Handle) error {
if pty == 0 {
return os.ErrInvalid
}
if attrv := sib.getAttr(windows.PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE); attrv != nil {
return ErrAlreadySet
}
sib.attrs[windows.PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE] = pty
return nil
}
// AssignToJob assigns the process created by sib to job. AssignToJob may be
// called multiple times to assign the process to multiple jobs.
func (sib *StartupInfoBuilder) AssignToJob(job windows.Handle) error {
if job == 0 {
return os.ErrInvalid
}
var jobList []windows.Handle
if attrv := sib.getAttr(_PROC_THREAD_ATTRIBUTE_JOB_LIST); attrv != nil {
jobList = attrv.([]windows.Handle)
}
if slices.Contains(jobList, job) {
return nil
}
sib.attrs[_PROC_THREAD_ATTRIBUTE_JOB_LIST] = append(jobList, job)
return nil
}

303
vendor/tailscale.com/util/winutil/svcdiag_windows.go generated vendored Normal file
View File

@@ -0,0 +1,303 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
"tailscale.com/types/logger"
"tailscale.com/util/set"
)
// LogSvcState obtains the state of the Windows service named rootSvcName and
// all of its dependencies, and then emits that state to logf.
func LogSvcState(logf logger.Logf, rootSvcName string) {
logEntries := []svcStateLogEntry{}
walkFn := func(svc *mgr.Service, config mgr.Config) {
status, err := svc.Query()
if err != nil {
logf("Failed retrieving Status for service %q: %v", svc.Name, err)
}
logEntries = append(logEntries, makeLogEntry(svc, status, config))
}
err := walkServices(rootSvcName, walkFn)
if err != nil {
logf("LogSvcState error: %v", err)
return
}
json, err := json.MarshalIndent(logEntries, "", " ")
if err != nil {
logf("Error marshaling service log entries: %v", err)
return
}
var builder strings.Builder
builder.WriteString("State of service ")
fmt.Fprintf(&builder, "%q", rootSvcName)
builder.WriteString(" and its dependencies:")
builder.WriteString("\n")
builder.Write(json)
builder.WriteString("\n")
logf(builder.String())
}
// walkSvcFunc is type of the callback function invoked by WalkServices.
type walkSvcFunc func(*mgr.Service, mgr.Config)
// walkServices opens the service named rootSvcName and walks its dependency
// graph, invoking callback for each service (including the root itself).
func walkServices(rootSvcName string, callback walkSvcFunc) error {
scm, err := ConnectToLocalSCMForRead()
if err != nil {
return fmt.Errorf("connecting to Service Control Manager: %w", err)
}
defer scm.Disconnect()
rootSvc, err := OpenServiceForRead(scm, rootSvcName)
if err != nil {
return fmt.Errorf("opening service %q: %w", rootSvcName, err)
}
deps := []*mgr.Service{rootSvc}
defer func() {
// Any service still in deps when we return is open and must be closed.
for _, dep := range deps {
dep.Close()
}
}()
seen := set.Set[string]{}
for err == nil && len(deps) > 0 {
err = func() error {
curSvc := deps[len(deps)-1]
defer curSvc.Close()
deps = deps[:len(deps)-1]
seen.Add(curSvc.Name)
curCfg, err := curSvc.Config()
if err != nil {
return fmt.Errorf("retrieving Config for service %q: %w", curSvc.Name, err)
}
callback(curSvc, curCfg)
for _, depName := range curCfg.Dependencies {
if seen.Contains(depName) {
continue
}
depSvc, err := OpenServiceForRead(scm, depName)
if err != nil {
return fmt.Errorf("opening service %q: %w", depName, err)
}
deps = append(deps, depSvc)
}
return nil
}()
}
return err
}
type svcStateLogEntry struct {
ServiceName string `json:"serviceName"`
ServiceType string `json:"serviceType"`
State string `json:"state"`
StartupType string `json:"startupType"`
Triggers *_SERVICE_TRIGGER_INFO `json:"triggers,omitempty"`
TriggersError error `json:"triggersError,omitempty"`
}
type _SERVICE_TRIGGER_SPECIFIC_DATA_ITEM struct {
dataType uint32
cbData uint32
data *byte
}
type serviceTriggerSpecificDataItemJSONMarshal struct {
DataType uint32 `json:"dataType"`
Data string `json:"data,omitempty"`
}
func (tsdi *_SERVICE_TRIGGER_SPECIFIC_DATA_ITEM) MarshalJSON() ([]byte, error) {
m := serviceTriggerSpecificDataItemJSONMarshal{DataType: tsdi.dataType}
const maxDataLen = 128
data := unsafe.Slice(tsdi.data, tsdi.cbData)
if len(data) > maxDataLen {
// Only output the first maxDataLen bytes.
m.Data = fmt.Sprintf("%s... (truncated %d bytes)", hex.EncodeToString(data[:maxDataLen]), len(data)-maxDataLen)
} else {
m.Data = hex.EncodeToString(data)
}
return json.Marshal(m)
}
type _SERVICE_TRIGGER struct {
triggerType uint32
action uint32
triggerSubtype *windows.GUID
cDataItems uint32
pDataItems *_SERVICE_TRIGGER_SPECIFIC_DATA_ITEM
}
type serviceTriggerJSONMarshal struct {
TriggerType uint32 `json:"triggerType"`
Action uint32 `json:"action"`
TriggerSubtype string `json:"triggerSubtype,omitempty"`
DataItems []_SERVICE_TRIGGER_SPECIFIC_DATA_ITEM `json:"dataItems"`
}
func (ti *_SERVICE_TRIGGER) MarshalJSON() ([]byte, error) {
m := serviceTriggerJSONMarshal{
TriggerType: ti.triggerType,
Action: ti.action,
DataItems: unsafe.Slice(ti.pDataItems, ti.cDataItems),
}
if ti.triggerSubtype != nil {
m.TriggerSubtype = ti.triggerSubtype.String()
}
return json.Marshal(m)
}
type _SERVICE_TRIGGER_INFO struct {
cTriggers uint32
pTriggers *_SERVICE_TRIGGER
_ *byte // pReserved
}
func (sti *_SERVICE_TRIGGER_INFO) MarshalJSON() ([]byte, error) {
triggers := unsafe.Slice(sti.pTriggers, sti.cTriggers)
return json.Marshal(triggers)
}
// getSvcTriggerInfo obtains information about any system events that may be
// used to start svc. Only relevant for demand-start (aka manual) services.
func getSvcTriggerInfo(svc *mgr.Service) (*_SERVICE_TRIGGER_INFO, error) {
var desiredLen uint32
err := queryServiceConfig2(svc.Handle, windows.SERVICE_CONFIG_TRIGGER_INFO,
nil, 0, &desiredLen)
if err != windows.ERROR_INSUFFICIENT_BUFFER {
return nil, err
}
buf := make([]byte, desiredLen)
err = queryServiceConfig2(svc.Handle, windows.SERVICE_CONFIG_TRIGGER_INFO,
&buf[0], desiredLen, &desiredLen)
if err != nil {
return nil, err
}
return (*_SERVICE_TRIGGER_INFO)(unsafe.Pointer(&buf[0])), nil
}
// makeLogEntry consolidates relevant service information into a svcStateLogEntry.
// We record the values of various service configuration constants as strings
// so the the log entries are easy to interpret at a glance by humans.
func makeLogEntry(svc *mgr.Service, status svc.Status, cfg mgr.Config) (entry svcStateLogEntry) {
entry.ServiceName = svc.Name
switch status.State {
case windows.SERVICE_STOPPED:
entry.State = "STOPPED"
case windows.SERVICE_START_PENDING:
entry.State = "START_PENDING"
case windows.SERVICE_STOP_PENDING:
entry.State = "STOP_PENDING"
case windows.SERVICE_RUNNING:
entry.State = "RUNNING"
case windows.SERVICE_CONTINUE_PENDING:
entry.State = "CONTINUE_PENDING"
case windows.SERVICE_PAUSE_PENDING:
entry.State = "PAUSE_PENDING"
case windows.SERVICE_PAUSED:
entry.State = "PAUSED"
case windows.SERVICE_NO_CHANGE:
entry.State = "NO_CHANGE"
default:
entry.State = fmt.Sprintf("Unknown constant %d", status.State)
}
switch cfg.ServiceType {
case windows.SERVICE_FILE_SYSTEM_DRIVER:
entry.ServiceType = "FILE_SYSTEM_DRIVER"
case windows.SERVICE_KERNEL_DRIVER:
entry.ServiceType = "KERNEL_DRIVER"
case windows.SERVICE_WIN32_OWN_PROCESS, windows.SERVICE_WIN32_SHARE_PROCESS:
entry.ServiceType = "WIN32"
default:
entry.ServiceType = fmt.Sprintf("Unknown constant %d", cfg.ServiceType)
}
switch cfg.StartType {
case windows.SERVICE_BOOT_START:
entry.StartupType = "BOOT_START"
case windows.SERVICE_SYSTEM_START:
entry.StartupType = "SYSTEM_START"
case windows.SERVICE_AUTO_START:
if cfg.DelayedAutoStart {
entry.StartupType = "DELAYED_AUTO_START"
} else {
entry.StartupType = "AUTO_START"
}
case windows.SERVICE_DEMAND_START:
entry.StartupType = "DEMAND_START"
triggerInfo, err := getSvcTriggerInfo(svc)
if err == nil {
entry.Triggers = triggerInfo
} else {
entry.TriggersError = err
}
case windows.SERVICE_DISABLED:
entry.StartupType = "DISABLED"
default:
entry.StartupType = fmt.Sprintf("Unknown constant %d", cfg.StartType)
}
return entry
}
// ConnectToLocalSCMForRead connects to the Windows Service Control Manager with
// read-only access. x/sys/windows/svc/mgr/Connect requests read+write access,
// which requires Administrative access rights.
func ConnectToLocalSCMForRead() (*mgr.Mgr, error) {
h, err := windows.OpenSCManager(nil, nil, windows.GENERIC_READ)
if err != nil {
return nil, err
}
return &mgr.Mgr{Handle: h}, nil
}
// OpenServiceForRead opens a service with read-only access.
// x/sys/windows/svc/mgr/(*Mgr).OpenService requests read+write access,
// which requires Administrative access rights.
func OpenServiceForRead(scm *mgr.Mgr, svcName string) (*mgr.Service, error) {
svcNamePtr, err := windows.UTF16PtrFromString(svcName)
if err != nil {
return nil, err
}
h, err := windows.OpenService(scm.Handle, svcNamePtr, windows.GENERIC_READ)
if err != nil {
return nil, err
}
return &mgr.Service{Name: svcName, Handle: h}, nil
}

View File

@@ -0,0 +1,237 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"os/user"
"strings"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"tailscale.com/types/logger"
"tailscale.com/util/winutil/winenv"
)
type _PROFILEINFO struct {
Size uint32
Flags uint32
UserName *uint16
ProfilePath *uint16
DefaultPath *uint16
ServerName *uint16
PolicyPath *uint16
Profile registry.Key
}
// _PROFILEINFO flags
const (
_PI_NOUI = 0x00000001
)
type _USER_INFO_4 struct {
Name *uint16
Password *uint16
PasswordAge uint32
Priv uint32
HomeDir *uint16
Comment *uint16
Flags uint32
ScriptPath *uint16
AuthFlags uint32
FullName *uint16
UsrComment *uint16
Parms *uint16
Workstations *uint16
LastLogon uint32
LastLogoff uint32
AcctExpires uint32
MaxStorage uint32
UnitsPerWeek uint32
LogonHours *byte
BadPwCount uint32
NumLogons uint32
LogonServer *uint16
CountryCode uint32
CodePage uint32
UserSID *windows.SID
PrimaryGroupID uint32
Profile *uint16
HomeDirDrive *uint16
PasswordExpired uint32
}
// UserProfile encapsulates a loaded Windows user profile.
type UserProfile struct {
token windows.Token
profileKey registry.Key
}
// LoadUserProfile loads the Windows user profile associated with token and u.
// u serves simply as a hint for speeding up resolution of the username and thus
// must reference the same user as token. u may also be nil, in which case token
// is queried for the username.
func LoadUserProfile(token windows.Token, u *user.User) (up *UserProfile, err error) {
computerName, userName, err := getComputerAndUserName(token, u)
if err != nil {
return nil, err
}
var roamingProfilePath *uint16
if winenv.IsDomainJoined() {
roamingProfilePath, err = getRoamingProfilePath(nil, token, computerName, userName)
if err != nil {
return nil, err
}
}
pi := _PROFILEINFO{
Size: uint32(unsafe.Sizeof(_PROFILEINFO{})),
Flags: _PI_NOUI,
UserName: userName,
ProfilePath: roamingProfilePath,
ServerName: computerName,
}
if err := loadUserProfile(token, &pi); err != nil {
return nil, err
}
// Duplicate the token so that we have a copy to use during cleanup without
// consuming the token passed into this function.
var dupToken windows.Handle
cp := windows.CurrentProcess()
if err := windows.DuplicateHandle(cp, windows.Handle(token), cp, &dupToken, 0,
false, windows.DUPLICATE_SAME_ACCESS); err != nil {
return nil, err
}
return &UserProfile{
token: windows.Token(dupToken),
profileKey: pi.Profile,
}, nil
}
// RegKey returns the registry key associated with the user profile.
// The caller must not close the returned key.
func (up *UserProfile) RegKey() registry.Key {
return up.profileKey
}
// Close unloads the user profile and cleans up any other resources held by up.
func (up *UserProfile) Close() error {
if up.profileKey != 0 {
if err := unloadUserProfile(up.token, up.profileKey); err != nil {
return err
}
up.profileKey = 0
}
if up.token != 0 {
up.token.Close()
up.token = 0
}
return nil
}
func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
// logf is for debugging/testing. While we would normally replace a nil logf
// with logger.Discard, we're using explicit checks within this func so that
// we don't waste time allocating and converting UTF-16 strings unnecessarily.
var comp string
if logf != nil {
comp = windows.UTF16PtrToString(computerName)
user := windows.UTF16PtrToString(userName)
logf("BEGIN getRoamingProfilePath(%q, %q)", comp, user)
defer logf("END getRoamingProfilePath(%q, %q)", comp, user)
}
isDomainName, err := isDomainName(computerName)
if err != nil {
return nil, err
}
if isDomainName {
if logf != nil {
logf("computerName %q is a domain, resolving...", comp)
}
dcInfo, err := resolveDomainController(computerName, nil)
if err != nil {
return nil, err
}
defer dcInfo.Close()
computerName = dcInfo.DomainControllerName
if logf != nil {
dom := windows.UTF16PtrToString(computerName)
logf("%q resolved to %q", comp, dom)
}
}
var pbuf *byte
if err := windows.NetUserGetInfo(computerName, userName, 4, &pbuf); err != nil {
return nil, err
}
defer windows.NetApiBufferFree(pbuf)
ui4 := (*_USER_INFO_4)(unsafe.Pointer(pbuf))
if logf != nil {
logf("getRoamingProfilePath: got %#v", *ui4)
}
profilePath := ui4.Profile
if profilePath == nil {
return nil, nil
}
if *profilePath == 0 {
// Empty string
return nil, nil
}
var expanded [windows.MAX_PATH + 1]uint16
if err := expandEnvironmentStringsForUser(token, profilePath, &expanded[0], uint32(len(expanded))); err != nil {
return nil, err
}
if logf != nil {
logf("returning %q", windows.UTF16ToString(expanded[:]))
}
// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
return &expanded[0], nil
}
func getComputerAndUserName(token windows.Token, u *user.User) (computerName *uint16, userName *uint16, err error) {
if u == nil {
tokenUser, err := token.GetTokenUser()
if err != nil {
return nil, nil, err
}
u, err = user.LookupId(tokenUser.User.Sid.String())
if err != nil {
return nil, nil, err
}
}
var strComputer, strUser string
before, after, hasBackslash := strings.Cut(u.Username, `\`)
if hasBackslash {
strComputer = before
strUser = after
} else {
strUser = before
}
if strComputer != "" {
computerName, err = windows.UTF16PtrFromString(strComputer)
if err != nil {
return nil, nil, err
}
}
userName, err = windows.UTF16PtrFromString(strUser)
if err != nil {
return nil, nil, err
}
return computerName, userName, nil
}

15
vendor/tailscale.com/util/winutil/winenv/mksyscall.go generated vendored Normal file
View File

@@ -0,0 +1,15 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winenv
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
// https://web.archive.org/web/20240407040123/https://learn.microsoft.com/en-us/windows/win32/api/mdmregistration/nf-mdmregistration-isdeviceregisteredwithmanagement
//sys isDeviceRegisteredWithManagement(isMDMRegistered *bool, upnBufLen uint32, upnBuf *uint16) (hr int32, err error) = MDMRegistration.IsDeviceRegisteredWithManagement?
// https://web.archive.org/web/20240407035921/https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-verifyversioninfow
//sys verifyVersionInfo(verInfo *osVersionInfoEx, typ verTypeMask, cond verCondMask) (res bool) = kernel32.VerifyVersionInfoW
// https://web.archive.org/web/20240407035706/https://learn.microsoft.com/en-us/windows/win32/api/winnt/nf-winnt-versetconditionmask
//sys verSetConditionMask(condMask verCondMask, typ verTypeMask, cond verCond) (res verCondMask) = kernel32.VerSetConditionMask

View File

@@ -0,0 +1,109 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package winenv provides information about the current Windows environment.
// This includes details such as whether the device is a server or workstation,
// if it is AD domain-joined, MDM-registered, or neither, and other characteristics.
package winenv
import (
"runtime"
"unsafe"
"golang.org/x/sys/windows"
)
// osVersionInfoEx contains operating system version information.
// See [OSVERSIONINFOEXW] for details.
//
// [OSVERSIONINFOEXW]: https://web.archive.org/web/20240407035213/https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexw
type osVersionInfoEx struct {
cbSize uint32
majorVersion uint32
minorVersion uint32
buildNumber uint32
platformId uint32
csdVersion [128]uint16
servicePackMajor uint16
servicePackMinor uint16
suiteMask uint16
productType verProductType
reserved uint8
}
type (
verTypeMask uint32
verCondMask uint64
verCond uint8
verProductType uint8
)
// See [VER_SET_CONDITION] and [VerSetConditionMask] for details.
//
// [VER_SET_CONDITION]: https://web.archive.org/web/20240407035400/https://learn.microsoft.com/en-us/windows/win32/api/winnt/nf-winnt-ver_set_condition
// [VerSetConditionMask]: https://web.archive.org/web/20240407035706/https://learn.microsoft.com/en-us/windows/win32/api/winnt/nf-winnt-versetconditionmask
const (
_VER_MINORVERSION = verTypeMask(0x0000001)
_VER_MAJORVERSION = verTypeMask(0x0000002)
_VER_BUILDNUMBER = verTypeMask(0x0000004)
_VER_PLATFORMID = verTypeMask(0x0000008)
_VER_SERVICEPACKMINOR = verTypeMask(0x0000010)
_VER_SERVICEPACKMAJOR = verTypeMask(0x0000020)
_VER_SUITENAME = verTypeMask(0x0000040)
_VER_PRODUCT_TYPE = verTypeMask(0x0000080)
_VER_NT_WORKSTATION = verProductType(1)
_VER_NT_DOMAIN_CONTROLLER = verProductType(2)
_VER_NT_SERVER = verProductType(3)
_VER_EQUAL = verCond(1)
_VER_GREATER = verCond(2)
_VER_GREATER_EQUAL = verCond(3)
_VER_LESS = verCond(4)
_VER_LESS_EQUAL = verCond(5)
_VER_AND = verCond(6)
_VER_OR = verCond(7)
)
// IsDomainJoined reports whether the device is domain-joined.
func IsDomainJoined() bool {
var domain *uint16
var status uint32
if err := windows.NetGetJoinInformation(nil, &domain, &status); err != nil {
return false
}
windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
return status == windows.NetSetupDomainName
}
// IsMDMRegistered reports whether the device is MDM-registered.
func IsMDMRegistered() bool {
const S_OK int32 = 0
var isMDMRegistered bool
if hr, err := isDeviceRegisteredWithManagement(&isMDMRegistered, 0, nil); err != nil || hr != S_OK {
return false
}
return isMDMRegistered
}
// IsManaged reports whether the device is managed through AD or MDM.
func IsManaged() bool {
return IsDomainJoined() || IsMDMRegistered()
}
// IsWindowsServer reports whether the device is running a Windows Server operating system.
func IsWindowsServer() bool {
if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" {
// TODO(nickkhyl): the Windows Server versions we support do not have 32-bit editions.
// But we should remove this check once we adopt mkwinsyscallx, as it can handle 64-bit
// long arguments such as verCondMask.
return false
}
osvi := &osVersionInfoEx{
cbSize: uint32(unsafe.Sizeof(osVersionInfoEx{})),
productType: _VER_NT_WORKSTATION,
}
condMask := verSetConditionMask(0, _VER_PRODUCT_TYPE, _VER_EQUAL)
return !verifyVersionInfo(osvi, _VER_PRODUCT_TYPE, condMask)
}

View File

@@ -0,0 +1,77 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winenv
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modMDMRegistration = windows.NewLazySystemDLL("MDMRegistration.dll")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procIsDeviceRegisteredWithManagement = modMDMRegistration.NewProc("IsDeviceRegisteredWithManagement")
procVerSetConditionMask = modkernel32.NewProc("VerSetConditionMask")
procVerifyVersionInfoW = modkernel32.NewProc("VerifyVersionInfoW")
)
func isDeviceRegisteredWithManagement(isMDMRegistered *bool, upnBufLen uint32, upnBuf *uint16) (hr int32, err error) {
err = procIsDeviceRegisteredWithManagement.Find()
if err != nil {
return
}
var _p0 uint32
if *isMDMRegistered {
_p0 = 1
}
r0, _, e1 := syscall.Syscall(procIsDeviceRegisteredWithManagement.Addr(), 3, uintptr(unsafe.Pointer(&_p0)), uintptr(upnBufLen), uintptr(unsafe.Pointer(upnBuf)))
*isMDMRegistered = _p0 != 0
hr = int32(r0)
if hr == 0 {
err = errnoErr(e1)
}
return
}
func verSetConditionMask(condMask verCondMask, typ verTypeMask, cond verCond) (res verCondMask) {
r0, _, _ := syscall.Syscall(procVerSetConditionMask.Addr(), 3, uintptr(condMask), uintptr(typ), uintptr(cond))
res = verCondMask(r0)
return
}
func verifyVersionInfo(verInfo *osVersionInfoEx, typ verTypeMask, cond verCondMask) (res bool) {
r0, _, _ := syscall.Syscall(procVerifyVersionInfoW.Addr(), 3, uintptr(unsafe.Pointer(verInfo)), uintptr(typ), uintptr(cond))
res = r0 != 0
return
}

116
vendor/tailscale.com/util/winutil/winutil.go generated vendored Normal file
View File

@@ -0,0 +1,116 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package winutil contains misc Windows/Win32 helper functions.
package winutil
import (
"os/user"
)
const (
// RegBase is the registry path inside HKEY_LOCAL_MACHINE where registry settings
// are stored. This constant is a non-empty string only when GOOS=windows.
RegBase = regBase
// RegPolicyBase is the registry path inside HKEY_LOCAL_MACHINE where registry
// policies are stored. This constant is a non-empty string only when
// GOOS=windows.
RegPolicyBase = regPolicyBase
)
// GetPolicyString looks up a registry value in the local machine's path for
// system policies, or returns empty string and the error.
// Use this function to read values that may be set by sysadmins via the MSI
// installer or via GPO. For registry settings that you do *not* want to be
// visible to sysadmin tools, use GetRegString instead.
//
// This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return an empty string and ErrNoValue.
// If value does not exist or another error happens, returns empty string and error.
func GetPolicyString(name string) (string, error) {
return getPolicyString(name)
}
// GetPolicyInteger looks up a registry value in the local machine's path for
// system policies, or returns 0 and the associated error.
// Use this function to read values that may be set by sysadmins via the MSI
// installer or via GPO. For registry settings that you do *not* want to be
// visible to sysadmin tools, use GetRegInteger instead.
//
// This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return 0 and ErrNoValue.
// If value does not exist or another error happens, returns 0 and error.
func GetPolicyInteger(name string) (uint64, error) {
return getPolicyInteger(name)
}
func GetPolicyStringArray(name string) ([]string, error) {
return getPolicyStringArray(name)
}
// GetRegString looks up a registry path in the local machine path, or returns
// an empty string and error.
//
// This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return an empty string and ErrNoValue.
// If value does not exist or another error happens, returns empty string and error.
func GetRegString(name string) (string, error) {
return getRegString(name)
}
// GetRegInteger looks up a registry path in the local machine path, or returns
// 0 and the error.
//
// This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return 0 and ErrNoValue.
// If value does not exist or another error happens, returns 0 and error.
func GetRegInteger(name string) (uint64, error) {
return getRegInteger(name)
}
// IsSIDValidPrincipal determines whether the SID contained in uid represents a
// type that is a valid security principal under Windows. This check helps us
// work around a bug in the standard library's Windows implementation of
// LookupId in os/user.
// See https://github.com/tailscale/tailscale/issues/869
//
// This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return false.
func IsSIDValidPrincipal(uid string) bool {
return isSIDValidPrincipal(uid)
}
// LookupPseudoUser attempts to resolve the user specified by uid by checking
// against well-known pseudo-users on Windows. This is a temporary workaround
// until https://github.com/golang/go/issues/49509 is resolved and shipped.
//
// This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return an error.
func LookupPseudoUser(uid string) (*user.User, error) {
return lookupPseudoUser(uid)
}
// RegisterForRestartOpts supplies options to RegisterForRestart.
type RegisterForRestartOpts struct {
RestartOnCrash bool // When true, this program will be restarted after a crash.
RestartOnHang bool // When true, this program will be restarted after a hang.
RestartOnUpgrade bool // When true, this program will be restarted after an upgrade.
RestartOnReboot bool // When true, this program will be restarted after a reboot.
UseCmdLineArgs bool // When true, CmdLineArgs will be used as the program's arguments upon restart. Otherwise no arguments will be provided.
CmdLineArgs []string // When UseCmdLineArgs == true, contains the command line arguments, excluding the executable name itself. If nil or empty, the arguments from the current process will be re-used.
}
// RegisterForRestart registers the current process' restart preferences with
// the Windows Restart Manager. This enables the OS to intelligently restart
// the calling executable as requested via opts. This should be called by any
// programs which need to be restarted by the installer post-update.
//
// This function may be called multiple times; the opts from the most recent
// call will override those from any previous invocations.
//
// This function will only work on GOOS=windows. Trying to run it on any other
// OS will always return nil.
func RegisterForRestart(opts RegisterForRestartOpts) error {
return registerForRestart(opts)
}

View File

@@ -0,0 +1,38 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !windows
package winutil
import (
"errors"
"fmt"
"os/user"
"runtime"
)
const regBase = ``
const regPolicyBase = ``
var ErrNoValue = errors.New("no value because registry is unavailable on this OS")
func getPolicyString(name string) (string, error) { return "", ErrNoValue }
func getPolicyInteger(name string) (uint64, error) { return 0, ErrNoValue }
func getPolicyStringArray(name string) ([]string, error) { return nil, ErrNoValue }
func getRegString(name string) (string, error) { return "", ErrNoValue }
func getRegInteger(name string) (uint64, error) { return 0, ErrNoValue }
func isSIDValidPrincipal(uid string) bool { return false }
func lookupPseudoUser(uid string) (*user.User, error) {
return nil, fmt.Errorf("unimplemented on %v", runtime.GOOS)
}
func IsCurrentProcessElevated() bool { return false }
func registerForRestart(opts RegisterForRestartOpts) error { return nil }

949
vendor/tailscale.com/util/winutil/winutil_windows.go generated vendored Normal file
View File

@@ -0,0 +1,949 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"errors"
"fmt"
"log"
"math"
"os/exec"
"os/user"
"reflect"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/exp/constraints"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
const (
regBase = `SOFTWARE\Tailscale IPN`
regPolicyBase = `SOFTWARE\Policies\Tailscale`
)
// ErrNoShell is returned when the shell process is not found.
var ErrNoShell = errors.New("no Shell process is present")
// ErrNoValue is returned when the value doesn't exist in the registry.
var ErrNoValue = registry.ErrNotExist
// GetDesktopPID searches the PID of the process that's running the
// currently active desktop. Returns ErrNoShell if the shell is not present.
// Usually the PID will be for explorer.exe.
func GetDesktopPID() (uint32, error) {
hwnd := windows.GetShellWindow()
if hwnd == 0 {
return 0, ErrNoShell
}
var pid uint32
windows.GetWindowThreadProcessId(hwnd, &pid)
if pid == 0 {
return 0, fmt.Errorf("invalid PID for HWND %v", hwnd)
}
return pid, nil
}
func getPolicyString(name string) (string, error) {
s, err := getRegStringInternal(registry.LOCAL_MACHINE, regPolicyBase, name)
if err != nil {
// Fall back to the legacy path
return getRegString(name)
}
return s, err
}
func getPolicyStringArray(name string) ([]string, error) {
return getRegStringsInternal(regPolicyBase, name)
}
func getRegString(name string) (string, error) {
s, err := getRegStringInternal(registry.LOCAL_MACHINE, regBase, name)
if err != nil {
return "", err
}
return s, err
}
func getPolicyInteger(name string) (uint64, error) {
i, err := getRegIntegerInternal(regPolicyBase, name)
if err != nil {
// Fall back to the legacy path
return getRegInteger(name)
}
return i, err
}
func getRegInteger(name string) (uint64, error) {
i, err := getRegIntegerInternal(regBase, name)
if err != nil {
return 0, err
}
return i, err
}
func getRegStringInternal(key registry.Key, subKey, name string) (string, error) {
key, err := registry.OpenKey(key, subKey, registry.READ)
if err != nil {
if err != ErrNoValue {
log.Printf("registry.OpenKey(%v): %v", subKey, err)
}
return "", err
}
defer key.Close()
val, _, err := key.GetStringValue(name)
if err != nil {
if err != ErrNoValue {
log.Printf("registry.GetStringValue(%v): %v", name, err)
}
return "", err
}
return val, nil
}
// GetRegUserString looks up a registry path in the current user key, or returns
// an empty string and error.
func GetRegUserString(name string) (string, error) {
return getRegStringInternal(registry.CURRENT_USER, regBase, name)
}
// SetRegUserString sets a SZ value identified by name in the current user key
// to the string specified by value.
func SetRegUserString(name, value string) error {
key, _, err := registry.CreateKey(registry.CURRENT_USER, regBase, registry.SET_VALUE)
if err != nil {
log.Printf("registry.CreateKey(%v): %v", regBase, err)
}
defer key.Close()
return key.SetStringValue(name, value)
}
// GetRegStrings looks up a registry value in the local machine path, or returns
// the given default if it can't.
func GetRegStrings(name string, defval []string) []string {
s, err := getRegStringsInternal(regBase, name)
if err != nil {
return defval
}
return s
}
func getRegStringsInternal(subKey, name string) ([]string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
if err != nil {
if err != ErrNoValue {
log.Printf("registry.OpenKey(%v): %v", subKey, err)
}
return nil, err
}
defer key.Close()
val, _, err := key.GetStringsValue(name)
if err != nil {
if err != ErrNoValue {
log.Printf("registry.GetStringValue(%v): %v", name, err)
}
return nil, err
}
return val, nil
}
// SetRegStrings sets a MULTI_SZ value in the in the local machine path
// to the strings specified by values.
func SetRegStrings(name string, values []string) error {
return setRegStringsInternal(regBase, name, values)
}
func setRegStringsInternal(subKey, name string, values []string) error {
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE)
if err != nil {
log.Printf("registry.CreateKey(%v): %v", subKey, err)
}
defer key.Close()
return key.SetStringsValue(name, values)
}
// DeleteRegValue removes a registry value in the local machine path.
func DeleteRegValue(name string) error {
return deleteRegValueInternal(regBase, name)
}
func deleteRegValueInternal(subKey, name string) error {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.SET_VALUE)
if err == ErrNoValue {
return nil
}
if err != nil {
log.Printf("registry.OpenKey(%v): %v", subKey, err)
return err
}
defer key.Close()
err = key.DeleteValue(name)
if err == ErrNoValue {
err = nil
}
return err
}
func getRegIntegerInternal(subKey, name string) (uint64, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, subKey, registry.READ)
if err != nil {
if err != ErrNoValue {
log.Printf("registry.OpenKey(%v): %v", subKey, err)
}
return 0, err
}
defer key.Close()
val, _, err := key.GetIntegerValue(name)
if err != nil {
if err != ErrNoValue {
log.Printf("registry.GetIntegerValue(%v): %v", name, err)
}
return 0, err
}
return val, nil
}
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
)
// TODO(crawshaw): replace with x/sys/windows... one day.
// https://go-review.googlesource.com/c/sys/+/331909
func WTSGetActiveConsoleSessionId() uint32 {
r1, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r1)
}
func isSIDValidPrincipal(uid string) bool {
usid, err := syscall.StringToSid(uid)
if err != nil {
return false
}
_, _, accType, err := usid.LookupAccount("")
if err != nil {
return false
}
switch accType {
case syscall.SidTypeUser, syscall.SidTypeGroup, syscall.SidTypeDomain, syscall.SidTypeAlias, syscall.SidTypeWellKnownGroup, syscall.SidTypeComputer:
return true
default:
// Reject deleted users, invalid SIDs, unknown SIDs, mandatory label SIDs, etc.
return false
}
}
// EnableCurrentThreadPrivilege enables the named privilege
// in the current thread's access token. The current goroutine is also locked to
// the OS thread (runtime.LockOSThread). Callers must call the returned disable
// function when done with the privileged task.
func EnableCurrentThreadPrivilege(name string) (disable func(), err error) {
return EnableCurrentThreadPrivileges([]string{name})
}
// EnableCurrentThreadPrivileges enables the named privileges
// in the current thread's access token. The current goroutine is also locked to
// the OS thread (runtime.LockOSThread). Callers must call the returned disable
// function when done with the privileged task.
func EnableCurrentThreadPrivileges(names []string) (disable func(), err error) {
runtime.LockOSThread()
if len(names) == 0 {
// Nothing to enable; no-op isn't really an error...
return runtime.UnlockOSThread, nil
}
if err := windows.ImpersonateSelf(windows.SecurityImpersonation); err != nil {
runtime.UnlockOSThread()
return nil, err
}
disable = func() {
defer runtime.UnlockOSThread()
// If RevertToSelf fails, it's not really recoverable and we should panic.
// Failure to do so would leak the privileges we're enabling, which is a
// security issue.
if err := windows.RevertToSelf(); err != nil {
panic(fmt.Sprintf("RevertToSelf failed: %v", err))
}
}
defer func() {
if err != nil {
disable()
}
}()
var t windows.Token
err = windows.OpenThreadToken(windows.CurrentThread(),
windows.TOKEN_QUERY|windows.TOKEN_ADJUST_PRIVILEGES, false, &t)
if err != nil {
return nil, err
}
defer t.Close()
tp := newTokenPrivileges(len(names))
privs := tp.AllPrivileges()
for i := range privs {
var privStr *uint16
privStr, err = windows.UTF16PtrFromString(names[i])
if err != nil {
return nil, err
}
err = windows.LookupPrivilegeValue(nil, privStr, &privs[i].Luid)
if err != nil {
return nil, err
}
privs[i].Attributes = windows.SE_PRIVILEGE_ENABLED
}
err = windows.AdjustTokenPrivileges(t, false, tp, 0, nil, nil)
if err != nil {
return nil, err
}
return disable, nil
}
func newTokenPrivileges(numPrivs int) *windows.Tokenprivileges {
if numPrivs <= 0 {
panic("numPrivs must be > 0")
}
numBytes := unsafe.Sizeof(windows.Tokenprivileges{}) + (uintptr(numPrivs-1) * unsafe.Sizeof(windows.LUIDAndAttributes{}))
buf := make([]byte, numBytes)
result := (*windows.Tokenprivileges)(unsafe.Pointer(unsafe.SliceData(buf)))
result.PrivilegeCount = uint32(numPrivs)
return result
}
// StartProcessAsChild starts exePath process as a child of parentPID.
// StartProcessAsChild copies parentPID's environment variables into
// the new process, along with any optional environment variables in extraEnv.
func StartProcessAsChild(parentPID uint32, exePath string, extraEnv []string) error {
// The rest of this function requires SeDebugPrivilege to be held.
//
// According to https://docs.microsoft.com/en-us/windows/win32/procthread/process-security-and-access-rights
//
// ... To open a handle to another process and obtain full access rights,
// you must enable the SeDebugPrivilege privilege. ...
//
// But we only need PROCESS_CREATE_PROCESS. So perhaps SeDebugPrivilege is too much.
//
// https://devblogs.microsoft.com/oldnewthing/20080314-00/?p=23113
//
// TODO: try look for something less than SeDebugPrivilege
disableSeDebug, err := EnableCurrentThreadPrivilege("SeDebugPrivilege")
if err != nil {
return err
}
defer disableSeDebug()
ph, err := windows.OpenProcess(
windows.PROCESS_CREATE_PROCESS|windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_DUP_HANDLE,
false, parentPID)
if err != nil {
return err
}
defer windows.CloseHandle(ph)
var pt windows.Token
err = windows.OpenProcessToken(ph, windows.TOKEN_QUERY, &pt)
if err != nil {
return err
}
defer pt.Close()
env, err := pt.Environ(false)
if err != nil {
return err
}
env = append(env, extraEnv...)
sys := &syscall.SysProcAttr{ParentProcess: syscall.Handle(ph)}
cmd := exec.Command(exePath)
cmd.Env = env
cmd.SysProcAttr = sys
return cmd.Start()
}
// StartProcessAsCurrentGUIUser is like StartProcessAsChild, but if finds
// current logged in user desktop process (normally explorer.exe),
// and passes found PID to StartProcessAsChild.
func StartProcessAsCurrentGUIUser(exePath string, extraEnv []string) error {
// as described in https://devblogs.microsoft.com/oldnewthing/20190425-00/?p=102443
desktop, err := GetDesktopPID()
if err != nil {
return fmt.Errorf("failed to find desktop: %v", err)
}
err = StartProcessAsChild(desktop, exePath, extraEnv)
if err != nil {
return fmt.Errorf("failed to start executable: %v", err)
}
return nil
}
// CreateAppMutex creates a named Windows mutex, returning nil if the mutex
// is created successfully or an error if the mutex already exists or could not
// be created for some other reason.
func CreateAppMutex(name string) (windows.Handle, error) {
return windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(name))
}
// getTokenInfoFixedLen obtains known fixed-length token information. Use this
// function for information classes that output enumerations, BOOLs, integers etc.
func getTokenInfoFixedLen[T any](token windows.Token, infoClass uint32) (result T, err error) {
var actualLen uint32
p := (*byte)(unsafe.Pointer(&result))
err = windows.GetTokenInformation(token, infoClass, p, uint32(unsafe.Sizeof(result)), &actualLen)
return result, err
}
type tokenElevationType int32
const (
tokenElevationTypeDefault tokenElevationType = 1
tokenElevationTypeFull tokenElevationType = 2
tokenElevationTypeLimited tokenElevationType = 3
)
// IsTokenLimited returns whether token is a limited UAC token.
func IsTokenLimited(token windows.Token) (bool, error) {
elevationType, err := getTokenInfoFixedLen[tokenElevationType](token, windows.TokenElevationType)
if err != nil {
return false, err
}
return elevationType == tokenElevationTypeLimited, nil
}
// UserSIDs contains the SIDs for a Windows NT token object's associated user
// as well as its primary group.
type UserSIDs struct {
User *windows.SID
PrimaryGroup *windows.SID
}
// GetCurrentUserSIDs returns a UserSIDs struct containing SIDs for the
// current process' user and primary group.
func GetCurrentUserSIDs() (*UserSIDs, error) {
token, err := windows.OpenCurrentProcessToken()
if err != nil {
return nil, err
}
defer token.Close()
userInfo, err := token.GetTokenUser()
if err != nil {
return nil, err
}
primaryGroup, err := token.GetTokenPrimaryGroup()
if err != nil {
return nil, err
}
return &UserSIDs{userInfo.User.Sid, primaryGroup.PrimaryGroup}, nil
}
// IsCurrentProcessElevated returns true when the current process is
// running with an elevated token, implying Administrator access.
func IsCurrentProcessElevated() bool {
token, err := windows.OpenCurrentProcessToken()
if err != nil {
return false
}
defer token.Close()
return token.IsElevated()
}
// keyOpenTimeout is how long we wait for a registry key to appear. For some
// reason, registry keys tied to ephemeral interfaces can take a long while to
// appear after interface creation, and we can end up racing with that.
const keyOpenTimeout = 20 * time.Second
// RegistryPath represents a path inside a root registry.Key.
type RegistryPath string
// RegistryPathPrefix specifies a RegistryPath prefix that must be suffixed with
// another RegistryPath to make a valid RegistryPath.
type RegistryPathPrefix string
// WithSuffix returns a RegistryPath with the given suffix appended.
func (p RegistryPathPrefix) WithSuffix(suf string) RegistryPath {
return RegistryPath(string(p) + suf)
}
const (
IPv4TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
IPv6TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
NetBTBase RegistryPath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters`
IPv4TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
IPv6TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
NetBTInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces\Tcpip_`
)
// ErrKeyWaitTimeout is returned by OpenKeyWait when calls timeout.
var ErrKeyWaitTimeout = errors.New("timeout waiting for registry key")
// OpenKeyWait opens a registry key, waiting for it to appear if necessary. It
// returns the opened key, or ErrKeyWaitTimeout if the key does not appear
// within 20s. The caller must call Close on the returned key.
func OpenKeyWait(k registry.Key, path RegistryPath, access uint32) (registry.Key, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
deadline := time.Now().Add(keyOpenTimeout)
pathSpl := strings.Split(string(path), "\\")
for i := 0; ; i++ {
keyName := pathSpl[i]
isLast := i+1 == len(pathSpl)
event, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return 0, fmt.Errorf("windows.CreateEvent: %w", err)
}
defer windows.CloseHandle(event)
var key registry.Key
for {
err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true)
if err != nil {
return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %w", err)
}
var accessFlags uint32
if isLast {
accessFlags = access
} else {
accessFlags = registry.NOTIFY
}
key, err = registry.OpenKey(k, keyName, accessFlags)
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
timeout := time.Until(deadline) / time.Millisecond
if timeout < 0 {
timeout = 0
}
s, err := windows.WaitForSingleObject(event, uint32(timeout))
if err != nil {
return 0, fmt.Errorf("windows.WaitForSingleObject: %w", err)
}
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
return 0, ErrKeyWaitTimeout
}
} else if err != nil {
return 0, fmt.Errorf("registry.OpenKey(%v): %w", path, err)
} else {
if isLast {
return key, nil
}
defer key.Close()
break
}
}
k = key
}
}
func lookupPseudoUser(uid string) (*user.User, error) {
sid, err := windows.StringToSid(uid)
if err != nil {
return nil, err
}
// We're looking for SIDs "S-1-5-x" where 17 <= x <= 20.
// This is checking for the the "5"
if sid.IdentifierAuthority() != windows.SECURITY_NT_AUTHORITY {
return nil, fmt.Errorf(`SID %q does not use "NT AUTHORITY"`, uid)
}
// This is ensuring that there is only one sub-authority.
// In other words, only one value after the "5".
if sid.SubAuthorityCount() != 1 {
return nil, fmt.Errorf("SID %q should have only one subauthority", uid)
}
// Get that sub-authority value (this is "x" above) and check it.
rid := sid.SubAuthority(0)
if rid < 17 || rid > 20 {
return nil, fmt.Errorf("SID %q does not represent a known pseudo-user", uid)
}
// We've got one of the known pseudo-users. Look up the localized name of the
// account.
username, domain, _, err := sid.LookupAccount("")
if err != nil {
return nil, err
}
// This call is best-effort. If it fails, homeDir will be empty.
homeDir, _ := findHomeDirInRegistry(uid)
result := &user.User{
Uid: uid,
Gid: uid, // Gid == Uid with these accounts.
Username: fmt.Sprintf(`%s\%s`, domain, username),
Name: username,
HomeDir: homeDir,
}
return result, nil
}
// findHomeDirInRegistry finds the user home path based on the uid.
// This is borrowed from Go's std lib.
func findHomeDirInRegistry(uid string) (dir string, err error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion\ProfileList\`+uid, registry.QUERY_VALUE)
if err != nil {
return "", err
}
defer k.Close()
dir, _, err = k.GetStringValue("ProfileImagePath")
if err != nil {
return "", err
}
return dir, nil
}
// ProcessImageName returns the fully-qualified path to the executable image
// associated with process.
func ProcessImageName(process windows.Handle) (string, error) {
var pathBuf [windows.MAX_PATH]uint16
pathBufLen := uint32(len(pathBuf))
if err := windows.QueryFullProcessImageName(process, 0, &pathBuf[0], &pathBufLen); err != nil {
return "", err
}
return windows.UTF16ToString(pathBuf[:pathBufLen]), nil
}
// TSSessionIDToLogonSessionID retrieves the logon session ID associated with
// tsSessionId, which is a Terminal Services / RDP session ID. The calling
// process must be running as LocalSystem.
func TSSessionIDToLogonSessionID(tsSessionID uint32) (logonSessionID windows.LUID, err error) {
var token windows.Token
if err := windows.WTSQueryUserToken(tsSessionID, &token); err != nil {
return logonSessionID, fmt.Errorf("WTSQueryUserToken: %w", err)
}
defer token.Close()
return LogonSessionID(token)
}
// TSSessionID obtains the Terminal Services (RDP) session ID associated with token.
func TSSessionID(token windows.Token) (tsSessionID uint32, err error) {
return getTokenInfoFixedLen[uint32](token, windows.TokenSessionId)
}
type tokenOrigin struct {
originatingLogonSession windows.LUID
}
// LogonSessionID obtains the logon session ID associated with token.
func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error) {
origin, err := getTokenInfoFixedLen[tokenOrigin](token, windows.TokenOrigin)
if err != nil {
return logonSessionID, err
}
return origin.originatingLogonSession, nil
}
// BufUnit is a type constraint for buffers passed into AllocateContiguousBuffer
// and SetNTString.
type BufUnit interface {
byte | uint16
}
// AllocateContiguousBuffer allocates memory to satisfy the Windows idiom where
// some structs contain pointers that are expected to refer to memory within the
// same buffer containing the struct itself. T is the type that contains
// the pointers. values must contain the actual data that is to be copied
// into the buffer after T. AllocateContiguousBuffer returns a pointer to the
// struct, the total length of the buffer in bytes, and a slice containing
// each value within the buffer. The caller may use slcs to populate any
// pointers in t as needed. Each element of slcs corresponds to the element of
// values in the same position.
//
// It is the responsibility of the caller to ensure that any values expected
// to contain null-terminated strings are in fact null-terminated!
//
// AllocateContiguousBuffer panics if no values are passed in, as there are
// better alternatives for allocating a struct in that case.
func AllocateContiguousBuffer[T any, BU BufUnit](values ...[]BU) (t *T, tLenBytes uint32, slcs [][]BU) {
if len(values) == 0 {
panic("len(values) must be > 0")
}
// Get the sizes of T and BU, then compute a preferred alignment for T.
tT := reflect.TypeFor[T]()
szT := tT.Size()
szBU := int(unsafe.Sizeof(BU(0)))
alignment := max(tT.Align(), szBU)
// Our buffers for values will start at the next szBU boundary.
tLenBytes = alignUp(uint32(szT), szBU)
firstValueOffset := tLenBytes
// Accumulate the length of each value into tLenBytes
for _, v := range values {
tLenBytes += uint32(len(v) * szBU)
}
// Now that we know the final length, align up to our preferred boundary.
tLenBytes = alignUp(tLenBytes, alignment)
// Allocate the buffer. We choose a type for the slice that is appropriate
// for the desired alignment. Note that we do not have a strict requirement
// that T contain pointer fields; we could just be appending more data
// within the same buffer.
bufLen := tLenBytes / uint32(alignment)
var pt unsafe.Pointer
switch alignment {
case 1:
pt = unsafe.Pointer(unsafe.SliceData(make([]byte, bufLen)))
case 2:
pt = unsafe.Pointer(unsafe.SliceData(make([]uint16, bufLen)))
case 4:
pt = unsafe.Pointer(unsafe.SliceData(make([]uint32, bufLen)))
case 8:
pt = unsafe.Pointer(unsafe.SliceData(make([]uint64, bufLen)))
default:
panic(fmt.Sprintf("bad alignment %d", alignment))
}
t = (*T)(pt)
slcs = make([][]BU, 0, len(values))
// Use the limits of the buffer area after t to construct a slice representing the remaining buffer.
firstValuePtr := unsafe.Pointer(uintptr(pt) + uintptr(firstValueOffset))
buf := unsafe.Slice((*BU)(firstValuePtr), (tLenBytes-firstValueOffset)/uint32(szBU))
// Copy each value into the buffer and record a slice describing each value's limits into slcs.
var index int
for _, v := range values {
if len(v) == 0 {
// We allow zero-length values; we simply append a nil slice.
slcs = append(slcs, nil)
continue
}
valueSlice := buf[index : index+len(v)]
copy(valueSlice, v)
slcs = append(slcs, valueSlice)
index += len(v)
}
return t, tLenBytes, slcs
}
// alignment must be a power of 2
func alignUp[V constraints.Integer](v V, alignment int) V {
return v + ((-v) & (V(alignment) - 1))
}
// NTStr is a type constraint requiring the type to be either a
// windows.NTString or a windows.NTUnicodeString.
type NTStr interface {
windows.NTString | windows.NTUnicodeString
}
// SetNTString sets the value of nts in-place to point to the string contained
// within buf. A nul terminator is optional in buf.
func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
isEmpty := len(buf) == 0
codeUnitSize := uint16(unsafe.Sizeof(BU(0)))
lenBytes := len(buf) * int(codeUnitSize)
if lenBytes > math.MaxUint16 {
panic("buffer length must fit into uint16")
}
lenBytes16 := uint16(lenBytes)
switch p := any(nts).(type) {
case *windows.NTString:
if isEmpty {
*p = windows.NTString{}
break
}
p.Buffer = unsafe.SliceData(any(buf).([]byte))
p.MaximumLength = lenBytes16
p.Length = lenBytes16
// account for nul terminator when present
if buf[len(buf)-1] == 0 {
p.Length -= codeUnitSize
}
case *windows.NTUnicodeString:
if isEmpty {
*p = windows.NTUnicodeString{}
break
}
p.Buffer = unsafe.SliceData(any(buf).([]uint16))
p.MaximumLength = lenBytes16
p.Length = lenBytes16
// account for nul terminator when present
if buf[len(buf)-1] == 0 {
p.Length -= codeUnitSize
}
default:
panic("unknown type")
}
}
type domainControllerAddressType uint32
const (
//lint:ignore U1000 maps to a win32 API
_DS_INET_ADDRESS domainControllerAddressType = 1
_DS_NETBIOS_ADDRESS domainControllerAddressType = 2
)
type domainControllerFlag uint32
const (
//lint:ignore U1000 maps to a win32 API
_DS_PDC_FLAG domainControllerFlag = 0x00000001
_DS_GC_FLAG domainControllerFlag = 0x00000004
_DS_LDAP_FLAG domainControllerFlag = 0x00000008
_DS_DS_FLAG domainControllerFlag = 0x00000010
_DS_KDC_FLAG domainControllerFlag = 0x00000020
_DS_TIMESERV_FLAG domainControllerFlag = 0x00000040
_DS_CLOSEST_FLAG domainControllerFlag = 0x00000080
_DS_WRITABLE_FLAG domainControllerFlag = 0x00000100
_DS_GOOD_TIMESERV_FLAG domainControllerFlag = 0x00000200
_DS_NDNC_FLAG domainControllerFlag = 0x00000400
_DS_SELECT_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00000800
_DS_FULL_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00001000
_DS_WS_FLAG domainControllerFlag = 0x00002000
_DS_DS_8_FLAG domainControllerFlag = 0x00004000
_DS_DS_9_FLAG domainControllerFlag = 0x00008000
_DS_DS_10_FLAG domainControllerFlag = 0x00010000
_DS_KEY_LIST_FLAG domainControllerFlag = 0x00020000
_DS_PING_FLAGS domainControllerFlag = 0x000FFFFF
_DS_DNS_CONTROLLER_FLAG domainControllerFlag = 0x20000000
_DS_DNS_DOMAIN_FLAG domainControllerFlag = 0x40000000
_DS_DNS_FOREST_FLAG domainControllerFlag = 0x80000000
)
type _DOMAIN_CONTROLLER_INFO struct {
DomainControllerName *uint16
DomainControllerAddress *uint16
DomainControllerAddressType domainControllerAddressType
DomainGuid windows.GUID
DomainName *uint16
DnsForestName *uint16
Flags domainControllerFlag
DcSiteName *uint16
ClientSiteName *uint16
}
func (dci *_DOMAIN_CONTROLLER_INFO) Close() error {
if dci == nil {
return nil
}
return windows.NetApiBufferFree((*byte)(unsafe.Pointer(dci)))
}
type dsGetDcNameFlag uint32
const (
//lint:ignore U1000 maps to a win32 API
_DS_FORCE_REDISCOVERY dsGetDcNameFlag = 0x00000001
_DS_DIRECTORY_SERVICE_REQUIRED dsGetDcNameFlag = 0x00000010
_DS_DIRECTORY_SERVICE_PREFERRED dsGetDcNameFlag = 0x00000020
_DS_GC_SERVER_REQUIRED dsGetDcNameFlag = 0x00000040
_DS_PDC_REQUIRED dsGetDcNameFlag = 0x00000080
_DS_BACKGROUND_ONLY dsGetDcNameFlag = 0x00000100
_DS_IP_REQUIRED dsGetDcNameFlag = 0x00000200
_DS_KDC_REQUIRED dsGetDcNameFlag = 0x00000400
_DS_TIMESERV_REQUIRED dsGetDcNameFlag = 0x00000800
_DS_WRITABLE_REQUIRED dsGetDcNameFlag = 0x00001000
_DS_GOOD_TIMESERV_PREFERRED dsGetDcNameFlag = 0x00002000
_DS_AVOID_SELF dsGetDcNameFlag = 0x00004000
_DS_ONLY_LDAP_NEEDED dsGetDcNameFlag = 0x00008000
_DS_IS_FLAT_NAME dsGetDcNameFlag = 0x00010000
_DS_IS_DNS_NAME dsGetDcNameFlag = 0x00020000
_DS_TRY_NEXTCLOSEST_SITE dsGetDcNameFlag = 0x00040000
_DS_DIRECTORY_SERVICE_6_REQUIRED dsGetDcNameFlag = 0x00080000
_DS_WEB_SERVICE_REQUIRED dsGetDcNameFlag = 0x00100000
_DS_DIRECTORY_SERVICE_8_REQUIRED dsGetDcNameFlag = 0x00200000
_DS_DIRECTORY_SERVICE_9_REQUIRED dsGetDcNameFlag = 0x00400000
_DS_DIRECTORY_SERVICE_10_REQUIRED dsGetDcNameFlag = 0x00800000
_DS_KEY_LIST_SUPPORT_REQUIRED dsGetDcNameFlag = 0x01000000
_DS_RETURN_DNS_NAME dsGetDcNameFlag = 0x40000000
_DS_RETURN_FLAT_NAME dsGetDcNameFlag = 0x80000000
)
func resolveDomainController(domainName *uint16, domainGUID *windows.GUID) (*_DOMAIN_CONTROLLER_INFO, error) {
const flags = _DS_DIRECTORY_SERVICE_REQUIRED | _DS_IS_FLAT_NAME | _DS_RETURN_DNS_NAME
var dcInfo *_DOMAIN_CONTROLLER_INFO
if err := dsGetDcName(nil, domainName, domainGUID, nil, flags, &dcInfo); err != nil {
return nil, err
}
return dcInfo, nil
}
// ResolveDomainController resolves the DNS name of the nearest available
// domain controller for the domain specified by domainName.
func ResolveDomainController(domainName string) (string, error) {
domainName16, err := windows.UTF16PtrFromString(domainName)
if err != nil {
return "", err
}
dcInfo, err := resolveDomainController(domainName16, nil)
if err != nil {
return "", err
}
defer dcInfo.Close()
return windows.UTF16PtrToString(dcInfo.DomainControllerName), nil
}
type _NETSETUP_NAME_TYPE int32
const (
_NetSetupUnknown _NETSETUP_NAME_TYPE = 0
_NetSetupMachine _NETSETUP_NAME_TYPE = 1
_NetSetupWorkgroup _NETSETUP_NAME_TYPE = 2
_NetSetupDomain _NETSETUP_NAME_TYPE = 3
_NetSetupNonExistentDomain _NETSETUP_NAME_TYPE = 4
_NetSetupDnsMachine _NETSETUP_NAME_TYPE = 5
)
func isDomainName(name *uint16) (bool, error) {
err := netValidateName(nil, name, nil, nil, _NetSetupDomain)
switch err {
case nil:
return true, nil
case windows.ERROR_NO_SUCH_DOMAIN:
return false, nil
default:
return false, err
}
}
// IsDomainName checks whether name represents an existing domain reachable by
// the current machine.
func IsDomainName(name string) (bool, error) {
name16, err := windows.UTF16PtrFromString(name)
if err != nil {
return false, err
}
return isDomainName(name16)
}

162
vendor/tailscale.com/util/winutil/zsyscall_windows.go generated vendored Normal file
View File

@@ -0,0 +1,162 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winutil
import (
"syscall"
"unsafe"
"github.com/dblohm7/wingoes"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modnetapi32 = windows.NewLazySystemDLL("netapi32.dll")
modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
moduserenv = windows.NewLazySystemDLL("userenv.dll")
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
procDsGetDcNameW = modnetapi32.NewProc("DsGetDcNameW")
procNetValidateName = modnetapi32.NewProc("NetValidateName")
procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
procRmGetList = modrstrtmgr.NewProc("RmGetList")
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources")
procRmStartSession = modrstrtmgr.NewProc("RmStartSession")
procExpandEnvironmentStringsForUserW = moduserenv.NewProc("ExpandEnvironmentStringsForUserW")
procLoadUserProfileW = moduserenv.NewProc("LoadUserProfileW")
procUnloadUserProfile = moduserenv.NewProc("UnloadUserProfile")
)
func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procQueryServiceConfig2W.Addr(), 5, uintptr(hService), uintptr(infoLevel), uintptr(unsafe.Pointer(buf)), uintptr(bufLen), uintptr(unsafe.Pointer(bytesNeeded)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) {
r0, _, _ := syscall.Syscall6(procGetApplicationRestartSettings.Addr(), 4, uintptr(process), uintptr(unsafe.Pointer(commandLine)), uintptr(unsafe.Pointer(commandLineLen)), uintptr(unsafe.Pointer(flags)), 0, 0)
ret = wingoes.HRESULT(r0)
return
}
func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) {
r0, _, _ := syscall.Syscall(procRegisterApplicationRestart.Addr(), 2, uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags), 0)
ret = wingoes.HRESULT(r0)
return
}
func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) {
r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo)))
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) {
r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func rmEndSession(session _RMHANDLE) (ret error) {
r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) {
r0, _, _ := syscall.Syscall6(procRmGetList.Addr(), 5, uintptr(session), uintptr(unsafe.Pointer(nProcInfoNeeded)), uintptr(unsafe.Pointer(nProcInfo)), uintptr(unsafe.Pointer(rgAffectedApps)), uintptr(unsafe.Pointer(pRebootReasons)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) {
r0, _, _ := syscall.Syscall(procRmJoinSession.Addr(), 2, uintptr(unsafe.Pointer(pSession)), uintptr(unsafe.Pointer(sessionKey)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) {
r0, _, _ := syscall.Syscall9(procRmRegisterResources.Addr(), 7, uintptr(session), uintptr(nFiles), uintptr(unsafe.Pointer(rgsFileNames)), uintptr(nApplications), uintptr(unsafe.Pointer(rgApplications)), uintptr(nServices), uintptr(unsafe.Pointer(rgsServiceNames)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) {
r0, _, _ := syscall.Syscall(procRmStartSession.Addr(), 3, uintptr(unsafe.Pointer(pSession)), uintptr(flags), uintptr(unsafe.Pointer(sessionKey)))
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procExpandEnvironmentStringsForUserW.Addr(), 4, uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) {
r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func unloadUserProfile(token windows.Token, profile registry.Key) (err error) {
r1, _, e1 := syscall.Syscall(procUnloadUserProfile.Addr(), 2, uintptr(token), uintptr(profile), 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}