Files
2024-11-01 17:43:06 +00:00

269 lines
8.9 KiB
Go

// Copyright (c) 2022 Tailscale Inc & AUTHORS. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
package com
import (
"os"
"runtime"
"github.com/dblohm7/wingoes"
"golang.org/x/sys/windows"
)
// ProcessType enumerates the kinds of process that can be described to
// StartRuntime and StartRuntimeWithDACL.
type ProcessType uint

// The supported ProcessType values.
//
// Console apps and services are not yet treated differently by this
// implementation. They are nevertheless distinct values because the
// distinction may be useful in the future (for example, a service could
// receive more restrictive default security settings than a console app), and
// having both in the API now avoids future breakage.
const (
	// ConsoleApp is a text-mode Windows program.
	ConsoleApp ProcessType = iota
	// Service is a Windows service.
	Service
	// GUIApp is a GUI-mode Windows program.
	GUIApp
)
// StartRuntime permanently initializes COM for the remaining lifetime of the
// current process. To avoid errors, it should be called as early as possible
// during program initialization. When processType == GUIApp, the current
// OS thread becomes permanently locked to the current goroutine; any subsequent
// GUI *must* be created on the same OS thread.
//
// StartRuntime is shorthand for calling StartRuntimeWithDACL with a nil dacl,
// and it returns whatever error that call returns.
//
// An excellent location to call StartRuntime is in the init function of the
// main package.
func StartRuntime(processType ProcessType) error {
// A nil dacl means a default ACL is used when COM security is initialized.
return StartRuntimeWithDACL(processType, nil)
}
// StartRuntimeWithDACL permanently initializes COM for the remaining lifetime
// of the current process. To avoid errors, it should be called as early as
// possible during program initialization.
//
// When processType == GUIApp, the current OS thread becomes permanently locked
// to the current goroutine; any subsequent GUI *must* be created on the same
// OS thread. For every other process type the OS thread is unlocked again
// before this function returns. An unrecognized processType yields
// os.ErrInvalid.
//
// dacl is an ACL that controls access of other processes connecting to the
// current process over COM; a nil dacl selects the defaults. For further
// information about COM access control, look up the COM_RIGHTS_* access flags
// in the Windows developer documentation.
//
// An excellent location to call StartRuntimeWithDACL is in the init function of
// the main package.
func StartRuntimeWithDACL(processType ProcessType, dacl *windows.ACL) error {
	runtime.LockOSThread()

	// staEntered becomes true only once this OS thread has successfully entered
	// its single-threaded apartment. Until then, every return path (success or
	// failure) is free to unlock the OS thread; afterwards it must never be
	// unlocked.
	staEntered := false
	defer func() {
		if !staEntered {
			runtime.UnlockOSThread()
		}
	}()

	switch processType {
	case ConsoleApp, Service, GUIApp:
		// Supported; carry on below.
	default:
		return os.ErrInvalid
	}

	// Every supported process type starts the MTA implicitly. For console apps
	// and services that is all the apartment setup there is. For GUI apps it
	// ensures that all OS threads that have not explicitly entered a COM
	// apartment reside inside the multi-threaded apartment (MTA)...
	if err := startMTAImplicitly(); err != nil {
		return err
	}

	if processType == GUIApp {
		// ...while this OS thread subsequently enters a single-threaded
		// apartment (STA) explicitly, which automatically removes it from the
		// MTA. The way to do this is strictly in that order: MTA first, then STA.
		if err := enterSTA(); err != nil {
			return err
		}
		staEntered = true
	}

	// Order is extremely important here: initSecurity must be called immediately
	// after apartments are set up, but before doing anything else.
	if err := initSecurity(dacl); err != nil {
		return err
	}

	// By default, for compatibility reasons, COM internally sets a catch-all
	// exception handler at its API boundary. This is dangerous, so we override
	// it. This work must happen after security settings are initialized, but
	// before anything "significant" is done with COM.
	opts, err := CreateInstance[GlobalOptions](CLSID_GlobalOptions)
	if err != nil {
		return err
	}
	err = opts.Set(COMGLB_EXCEPTION_HANDLING, COMGLB_EXCEPTION_DONOT_HANDLE_ANY)

	// The BSTR cache never invalidates itself, so it is switched off whether or
	// not the Set call above succeeded. Doing it here ensures that the cache is
	// off before anything can possibly start using oleaut32.dll.
	setOaNoCache()

	return err
}
// startMTAImplicitly creates an implicit multi-threaded apartment (MTA) for
// all threads in a process that do not otherwise explicitly enter a COM
// apartment. It returns a non-nil error when the MTA could not be started.
func startMTAImplicitly() error {
	// CoIncrementMTAUsage is the modern way to create the MTA implicitly. When
	// that API cannot be found, fall back to the legacy mechanism instead.
	if procCoIncrementMTAUsage.Find() != nil {
		return startMTAImplicitlyLegacy()
	}

	// The cookie is deliberately discarded when this function returns: there is
	// no intention of ever tearing the MTA back down.
	var discarded coMTAUsageCookie
	result := wingoes.ErrorFromHRESULT(coIncrementMTAUsage(&discarded))
	if result.Failed() {
		return result
	}

	return nil
}
// startMTAImplicitlyLegacy works by having a background OS thread explicitly
// enter the multi-threaded apartment. All other OS threads that have not
// explicitly entered an apartment will become implicit members of that MTA.
// This function is written assuming that the current OS thread has already
// been locked. It returns the error, if any, reported by the background
// goroutine's attempt to enter the MTA.
func startMTAImplicitlyLegacy() error {
	// The MTA must be started on a background OS thread, HOWEVER the caller
	// needs the outcome before proceeding. Receiving from the unbuffered channel
	// makes this call synchronous: it completes only once bgMTASustainer has
	// sent the result of MTA initialization.
	ready := make(chan error)
	go bgMTASustainer(ready)

	err := <-ready
	return err
}
// bgMTASustainer locks the current goroutine to the current OS thread and
// enters the COM multi-threaded apartment. The outcome is sent to c so that
// startMTAImplicitlyLegacy can wait for the MTA to be ready before proceeding.
// On success this function never returns: it blocks for the remainder of the
// process's lifetime. On failure it unlocks the OS thread and returns.
func bgMTASustainer(c chan error) {
	runtime.LockOSThread()

	err := enterMTA()
	c <- err

	if err == nil {
		// The MTA was entered; park this goroutine (and with it the locked OS
		// thread) forever.
		select {}
	}

	// We didn't enter the MTA, so just unlock and bail.
	runtime.UnlockOSThread()
}
// enterMTA causes the current OS thread to explicitly declare itself to be a
// member of COM's multi-threaded apartment. It does so by calling coInit with
// windows.COINIT_MULTITHREADED and returns coInit's result unchanged. Note
// that this function affects thread-local state, so use carefully!
func enterMTA() error {
return coInit(windows.COINIT_MULTITHREADED)
}
// enterSTA causes the current OS thread to create and enter a single-threaded
// apartment. It does so by calling coInit with windows.COINIT_APARTMENTTHREADED
// and returns coInit's result unchanged.
//
// The current OS thread must be locked and remain locked for the duration of
// the thread's time in the apartment. For our purposes, the calling OS thread
// never leaves the STA, so it must effectively remain locked for the remaining
// lifetime of the process.
//
// A single-threaded apartment should be used if and only if an OS thread is
// going to be creating windows and pumping messages; STAs are NOT generic
// containers for single-threaded COM code, contrary to popular belief. Note
// that this function affects thread-local state, so use carefully!
func enterSTA() error {
return coInit(windows.COINIT_APARTMENTTHREADED)
}
// coInit is a wrapper for CoInitializeEx that properly handles the S_FALSE
// error code (x/sys/windows.CoInitializeEx does not).
func coInit(apartment uint32) error {
hr := coInitializeEx(0, apartment)
if e := wingoes.ErrorFromHRESULT(hr); e.Failed() {
return e
}
return nil
}
// authSvcCOMChooses is the value initSecurity passes to coInitializeSecurity
// immediately before the (nil) authSvc argument; with this value the authSvc
// argument is not used.
// NOTE(review): per the Win32 CoInitializeSecurity documentation, -1 for
// cAuthSvc asks COM to choose which authentication services to register;
// confirm against the coInitializeSecurity declaration.
const authSvcCOMChooses = -1
// initSecurity initializes COM security using the ACL specified by dacl.
// A nil dacl implies that a default ACL should be used instead. It returns any
// error from building the security descriptor, or the failing HRESULT from
// coInitializeSecurity.
func initSecurity(dacl *windows.ACL) error {
	sd, err := buildSecurityDescriptor(dacl)
	if err != nil {
		return err
	}

	// With no explicit security descriptor, this bit must be set for COM to
	// fall back to system-wide defaults.
	capabilities := authCapNone
	if sd == nil {
		capabilities |= authCapAppID
	}

	result := wingoes.ErrorFromHRESULT(coInitializeSecurity(
		sd,
		authSvcCOMChooses,
		nil, // authSvc (not used because previous arg is authSvcCOMChooses)
		0,   // Reserved, must be 0
		rpcAuthnLevelDefault,
		rpcImpLevelIdentify,
		nil, // authlist: use defaults
		capabilities,
		0, // Reserved, must be 0
	))
	if result.Failed() {
		return result
	}

	return nil
}
// buildSecurityDescriptor inserts dacl into a valid security descriptor for use
// with CoInitializeSecurity. A nil dacl results in a nil security descriptor
// and a nil error, which we consider to be a valid "use defaults" sentinel.
// Otherwise the descriptor's owner and group are set from the current
// process's user SIDs. On any failure the returned descriptor is nil.
func buildSecurityDescriptor(dacl *windows.ACL) (*windows.SECURITY_DESCRIPTOR, error) {
	if dacl == nil {
		// Not an error, just use defaults.
		return nil, nil
	}

	desc, err := windows.NewSecurityDescriptor()
	if err != nil {
		return nil, err
	}

	if err = desc.SetDACL(dacl, true, false); err != nil {
		return nil, err
	}

	// CoInitializeSecurity will fail unless the SD's owner and group are both
	// set, so populate them from the identity of the current process.
	sids, err := wingoes.CurrentProcessUserSIDs()
	if err != nil {
		return nil, err
	}

	if err = desc.SetOwner(sids.User, false); err != nil {
		return nil, err
	}

	if err = desc.SetGroup(sids.PrimaryGroup, false); err != nil {
		return nil, err
	}

	return desc, nil
}