Files
2024-11-01 17:43:06 +00:00

557 lines
12 KiB
Go

// Copyright (c) 2023 Tailscale Inc & AUTHORS. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
package com
import (
"io"
"runtime"
"syscall"
"unsafe"
"github.com/dblohm7/wingoes"
"github.com/dblohm7/wingoes/internal"
"golang.org/x/sys/windows"
)
// IID_ISequentialStream is the interface identifier of ISequentialStream,
// {0C733A30-2A1C-11CE-ADE5-00AA0044773D}.
var IID_ISequentialStream = &IID{0x0C733A30, 0x2A1C, 0x11CE, [8]byte{0xAD, 0xE5, 0x00, 0xAA, 0x00, 0x44, 0x77, 0x3D}}

// IID_IStream is the interface identifier of IStream,
// {0000000C-0000-0000-C000-000000000046}.
var IID_IStream = &IID{0x0000000C, 0x0000, 0x0000, [8]byte{0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}}
// STGC holds the flags accepted by IStream::Commit (see Stream.Commit).
type STGC uint32

const (
	STGC_DEFAULT                            STGC = 0
	STGC_OVERWRITE                          STGC = 1
	STGC_ONLYIFCURRENT                      STGC = 2
	STGC_DANGEROUSLYCOMMITMERELYTODISKCACHE STGC = 4
	STGC_CONSOLIDATE                        STGC = 8
)

// LOCKTYPE selects the kind of range lock used by LockRegion/UnlockRegion and
// appears in STATSTG.LocksSupported.
type LOCKTYPE uint32

const (
	LOCK_WRITE     LOCKTYPE = 1
	LOCK_EXCLUSIVE LOCKTYPE = 2
	LOCK_ONLYONCE  LOCKTYPE = 4
)

// STGTY identifies the kind of storage object reported in STATSTG.Type.
type STGTY uint32

const (
	STGTY_STORAGE   STGTY = 1
	STGTY_STREAM    STGTY = 2
	STGTY_LOCKBYTES STGTY = 3
	STGTY_PROPERTY  STGTY = 4
)

// STATFLAG is passed to IStream::Stat (see Stream.Stat).
type STATFLAG uint32

const (
	STATFLAG_DEFAULT STATFLAG = 0
	STATFLAG_NONAME  STATFLAG = 1
	STATFLAG_NOOPEN  STATFLAG = 2
)
// STATSTG mirrors the Win32 STATSTG structure that IStream::Stat fills in.
// It is written directly by COM, so field order and sizes must match the C
// layout; adjacent fields of the same type are merely declared together.
type STATSTG struct {
	Name                COMAllocatedString
	Type                STGTY
	Size                uint64
	MTime, CTime, ATime windows.Filetime
	Mode                uint32
	LocksSupported      LOCKTYPE
	ClsID               CLSID
	_, _                uint32 // StateBits, reserved
}

// Close closes st.Name (a COMAllocatedString), presumably freeing the
// COM-allocated name buffer.
func (st *STATSTG) Close() error {
	name := &st.Name
	return name.Close()
}
type (
	// ISequentialStreamABI is the raw COM ABI view of an ISequentialStream;
	// it extends IUnknownABI by embedding.
	ISequentialStreamABI struct {
		IUnknownABI
	}

	// IStreamABI is the raw COM ABI view of an IStream; it extends
	// ISequentialStreamABI by embedding.
	IStreamABI struct {
		ISequentialStreamABI
	}

	// SequentialStream is the Go-side wrapper around an ISequentialStreamABI.
	SequentialStream struct {
		GenericObject[ISequentialStreamABI]
	}

	// Stream is the Go-side wrapper around an IStreamABI.
	Stream struct {
		GenericObject[IStreamABI]
	}
)
// Read implements io.Reader on top of ISequentialStream::Read.
//
// A single call reads at most maxStreamRWLen bytes; longer buffers are
// clamped. A failing HRESULT is returned as the error together with whatever
// byte count the method reported. io.EOF is returned when the method reports
// S_FALSE, or when it transfers zero bytes into a non-empty buffer, because
// IStream implementations signal end of stream in different ways.
func (abi *ISequentialStreamABI) Read(p []byte) (int, error) {
	if len(p) > maxStreamRWLen {
		p = p[:maxStreamRWLen]
	}

	readFn := unsafe.Slice(abi.Vtbl, 5)[3] // vtable slot 3: ISequentialStream::Read
	var got uint32
	rc, _, _ := syscall.SyscallN(
		readFn,
		uintptr(unsafe.Pointer(abi)),
		uintptr(unsafe.Pointer(unsafe.SliceData(p))),
		uintptr(uint32(len(p))),
		uintptr(unsafe.Pointer(&got)),
	)

	n := int(got)
	status := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc))
	switch {
	case status.Failed():
		return n, status
	case status.AsHRESULT() == wingoes.S_FALSE, n == 0 && len(p) > 0:
		return n, io.EOF
	default:
		return n, nil
	}
}
// Write implements io.Writer on top of ISequentialStream::Write.
//
// At most maxStreamRWLen bytes are passed to COM in one call. Whenever fewer
// than len(p) bytes end up written, including when p had to be clamped, Write
// returns io.ErrShortWrite as io.Writer requires. A failing HRESULT is
// returned as the error together with the byte count the method reported.
// An empty p is a no-op that returns (0, nil) without calling into COM.
func (abi *ISequentialStreamABI) Write(p []byte) (int, error) {
	// unsafe.SliceData returns nil for a nil slice, so forwarding an empty
	// write would hand COM a NULL buffer pointer, which stream implementations
	// may reject with an error. A zero-byte write has nothing to do, so report
	// success directly, as io.Writer implementations conventionally do.
	if len(p) == 0 {
		return 0, nil
	}

	w := p
	if len(w) > maxStreamRWLen {
		w = w[:maxStreamRWLen]
	}

	var cbWritten uint32
	method := unsafe.Slice(abi.Vtbl, 5)[4] // vtable slot 4: ISequentialStream::Write
	rc, _, _ := syscall.SyscallN(
		method,
		uintptr(unsafe.Pointer(abi)),
		uintptr(unsafe.Pointer(unsafe.SliceData(w))),
		uintptr(uint32(len(w))),
		uintptr(unsafe.Pointer(&cbWritten)),
	)

	n := int(cbWritten)
	if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() {
		return n, e
	}
	// io.Writer must return a non-nil error whenever n < len(p). Compare against
	// the caller's original length so that clamping is reported as well.
	if n < len(p) {
		return n, io.ErrShortWrite
	}
	return n, nil
}
// IID returns the interface identifier for ISequentialStream.
func (o SequentialStream) IID() *IID {
	return IID_ISequentialStream
}

// Make wraps the COM pointer held by r as a SequentialStream. It registers
// ReleaseABI as r's finalizer (presumably releasing the COM reference once r
// is collected). A nil r yields a zero SequentialStream.
func (o SequentialStream) Make(r ABIReceiver) any {
	var s SequentialStream
	if r == nil {
		return s
	}
	runtime.SetFinalizer(r, ReleaseABI)
	s.Pp = (**ISequentialStreamABI)(unsafe.Pointer(r))
	return s
}
// UnsafeUnwrap returns the raw ABI pointer currently stored behind o.Pp.
func (o SequentialStream) UnsafeUnwrap() *ISequentialStreamABI {
	return *o.Pp
}

// Read forwards to the underlying ISequentialStreamABI.Read.
func (o SequentialStream) Read(b []byte) (n int, err error) {
	return o.UnsafeUnwrap().Read(b)
}

// Write forwards to the underlying ISequentialStreamABI.Write.
func (o SequentialStream) Write(b []byte) (int, error) {
	return o.UnsafeUnwrap().Write(b)
}
// Seek implements io.Seeker via IStream::Seek. whence is passed through
// unchanged: io.SeekStart, io.SeekCurrent and io.SeekEnd (0, 1, 2) have the
// same values as STREAM_SEEK_SET, STREAM_SEEK_CUR and STREAM_SEEK_END. On
// failure the returned position is 0.
func (abi *IStreamABI) Seek(offset int64, whence int) (n int64, _ error) {
	seekFn := unsafe.Slice(abi.Vtbl, 14)[5] // vtable slot 5: IStream::Seek
	origin := uintptr(uint32(whence))

	var rc uintptr
	switch runtime.GOARCH {
	case "386":
		// On 32-bit x86 the 64-bit offset is passed as two 32-bit words, low
		// word first. uintptr(u) keeps the low half and u>>32 gives the high half.
		u := uint64(offset)
		rc, _, _ = syscall.SyscallN(
			seekFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(u),
			uintptr(u>>32),
			origin,
			uintptr(unsafe.Pointer(&n)),
		)
	default:
		rc, _, _ = syscall.SyscallN(
			seekFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(offset),
			origin,
			uintptr(unsafe.Pointer(&n)),
		)
	}

	if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() {
		return 0, e
	}
	return n, nil
}
// SetSize changes the stream size to newSize bytes via IStream::SetSize.
func (abi *IStreamABI) SetSize(newSize uint64) error {
	setSizeFn := unsafe.Slice(abi.Vtbl, 14)[6] // vtable slot 6: IStream::SetSize

	var rc uintptr
	switch runtime.GOARCH {
	case "386":
		// The 64-bit size is passed as two 32-bit words, low word first.
		rc, _, _ = syscall.SyscallN(
			setSizeFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(newSize),
			uintptr(newSize>>32),
		)
	default:
		rc, _, _ = syscall.SyscallN(
			setSizeFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(newSize),
		)
	}

	if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() {
		return e
	}
	return nil
}
// CopyTo copies up to numBytesToCopy bytes from abi into dest via
// IStream::CopyTo. It reports the byte counts the method wrote to its out
// parameters, including when it also returns an error.
func (abi *IStreamABI) CopyTo(dest *IStreamABI, numBytesToCopy uint64) (bytesRead, bytesWritten uint64, _ error) {
	copyFn := unsafe.Slice(abi.Vtbl, 14)[7] // vtable slot 7: IStream::CopyTo

	var rc uintptr
	switch runtime.GOARCH {
	case "386":
		// The 64-bit count is passed as two 32-bit words, low word first.
		rc, _, _ = syscall.SyscallN(
			copyFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(unsafe.Pointer(dest)),
			uintptr(numBytesToCopy),
			uintptr(numBytesToCopy>>32),
			uintptr(unsafe.Pointer(&bytesRead)),
			uintptr(unsafe.Pointer(&bytesWritten)),
		)
	default:
		rc, _, _ = syscall.SyscallN(
			copyFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(unsafe.Pointer(dest)),
			uintptr(numBytesToCopy),
			uintptr(unsafe.Pointer(&bytesRead)),
			uintptr(unsafe.Pointer(&bytesWritten)),
		)
	}

	if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() {
		return bytesRead, bytesWritten, e
	}
	return bytesRead, bytesWritten, nil
}
// Commit calls IStream::Commit with the given STGC flags.
func (abi *IStreamABI) Commit(flags STGC) error {
	commitFn := unsafe.Slice(abi.Vtbl, 14)[8] // vtable slot 8: IStream::Commit
	rc, _, _ := syscall.SyscallN(commitFn, uintptr(unsafe.Pointer(abi)), uintptr(flags))

	status := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc))
	if !status.Failed() {
		return nil
	}
	return status
}
// Revert calls IStream::Revert.
func (abi *IStreamABI) Revert() error {
	revertFn := unsafe.Slice(abi.Vtbl, 14)[9] // vtable slot 9: IStream::Revert
	rc, _, _ := syscall.SyscallN(revertFn, uintptr(unsafe.Pointer(abi)))

	status := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc))
	if !status.Failed() {
		return nil
	}
	return status
}
// LockRegion calls IStream::LockRegion for the numBytes-long range starting
// at offset, using the given lock type.
func (abi *IStreamABI) LockRegion(offset, numBytes uint64, lockType LOCKTYPE) error {
	lockFn := unsafe.Slice(abi.Vtbl, 14)[10] // vtable slot 10: IStream::LockRegion

	var rc uintptr
	switch runtime.GOARCH {
	case "386":
		// Each 64-bit argument is passed as two 32-bit words, low word first.
		rc, _, _ = syscall.SyscallN(
			lockFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(offset),
			uintptr(offset>>32),
			uintptr(numBytes),
			uintptr(numBytes>>32),
			uintptr(lockType),
		)
	default:
		rc, _, _ = syscall.SyscallN(
			lockFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(offset),
			uintptr(numBytes),
			uintptr(lockType),
		)
	}

	if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() {
		return e
	}
	return nil
}
// UnlockRegion calls IStream::UnlockRegion for the numBytes-long range
// starting at offset, using the given lock type.
func (abi *IStreamABI) UnlockRegion(offset, numBytes uint64, lockType LOCKTYPE) error {
	unlockFn := unsafe.Slice(abi.Vtbl, 14)[11] // vtable slot 11: IStream::UnlockRegion

	var rc uintptr
	switch runtime.GOARCH {
	case "386":
		// Each 64-bit argument is passed as two 32-bit words, low word first.
		rc, _, _ = syscall.SyscallN(
			unlockFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(offset),
			uintptr(offset>>32),
			uintptr(numBytes),
			uintptr(numBytes>>32),
			uintptr(lockType),
		)
	default:
		rc, _, _ = syscall.SyscallN(
			unlockFn,
			uintptr(unsafe.Pointer(abi)),
			uintptr(offset),
			uintptr(numBytes),
			uintptr(lockType),
		)
	}

	if e := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc)); e.Failed() {
		return e
	}
	return nil
}
// Stat calls IStream::Stat and returns the STATSTG it fills in, or nil and
// the error on failure. Call Close on the result to close its Name field.
func (abi *IStreamABI) Stat(flags STATFLAG) (*STATSTG, error) {
	statFn := unsafe.Slice(abi.Vtbl, 14)[12] // vtable slot 12: IStream::Stat
	info := &STATSTG{}
	rc, _, _ := syscall.SyscallN(
		statFn,
		uintptr(unsafe.Pointer(abi)),
		uintptr(unsafe.Pointer(info)),
		uintptr(flags),
	)

	status := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc))
	if status.Failed() {
		return nil, status
	}
	return info, nil
}
// Clone calls IStream::Clone and returns the new object's IUnknown pointer,
// or nil and the error on failure.
func (abi *IStreamABI) Clone() (result *IUnknownABI, _ error) {
	cloneFn := unsafe.Slice(abi.Vtbl, 14)[13] // vtable slot 13: IStream::Clone
	var clone *IUnknownABI
	rc, _, _ := syscall.SyscallN(
		cloneFn,
		uintptr(unsafe.Pointer(abi)),
		uintptr(unsafe.Pointer(&clone)),
	)

	status := wingoes.ErrorFromHRESULT(wingoes.HRESULT(rc))
	if status.Failed() {
		return nil, status
	}
	return clone, nil
}
// IID returns the interface identifier for IStream.
func (o Stream) IID() *IID {
	return IID_IStream
}

// Make wraps the COM pointer held by r as a Stream. It registers ReleaseABI as
// r's finalizer (presumably releasing the COM reference once r is collected).
// A nil r yields a zero Stream.
func (o Stream) Make(r ABIReceiver) any {
	var s Stream
	if r == nil {
		return s
	}
	runtime.SetFinalizer(r, ReleaseABI)
	s.Pp = (**IStreamABI)(unsafe.Pointer(r))
	return s
}
// UnsafeUnwrap returns the raw ABI pointer currently stored behind o.Pp.
func (o Stream) UnsafeUnwrap() *IStreamABI {
	return *o.Pp
}

// Read forwards to the underlying IStreamABI.
func (o Stream) Read(buf []byte) (int, error) {
	return o.UnsafeUnwrap().Read(buf)
}

// Write forwards to the underlying IStreamABI.
func (o Stream) Write(buf []byte) (int, error) {
	return o.UnsafeUnwrap().Write(buf)
}

// Seek forwards to IStreamABI.Seek.
func (o Stream) Seek(offset int64, whence int) (n int64, _ error) {
	return o.UnsafeUnwrap().Seek(offset, whence)
}

// SetSize forwards to IStreamABI.SetSize.
func (o Stream) SetSize(newSize uint64) error {
	return o.UnsafeUnwrap().SetSize(newSize)
}

// CopyTo forwards to IStreamABI.CopyTo, using dest's raw ABI pointer.
func (o Stream) CopyTo(dest Stream, numBytesToCopy uint64) (bytesRead, bytesWritten uint64, _ error) {
	return o.UnsafeUnwrap().CopyTo(dest.UnsafeUnwrap(), numBytesToCopy)
}

// Commit forwards to IStreamABI.Commit.
func (o Stream) Commit(flags STGC) error {
	return o.UnsafeUnwrap().Commit(flags)
}

// Revert forwards to IStreamABI.Revert.
func (o Stream) Revert() error {
	return o.UnsafeUnwrap().Revert()
}

// LockRegion forwards to IStreamABI.LockRegion.
func (o Stream) LockRegion(offset, numBytes uint64, lockType LOCKTYPE) error {
	return o.UnsafeUnwrap().LockRegion(offset, numBytes, lockType)
}

// UnlockRegion forwards to IStreamABI.UnlockRegion.
func (o Stream) UnlockRegion(offset, numBytes uint64, lockType LOCKTYPE) error {
	return o.UnsafeUnwrap().UnlockRegion(offset, numBytes, lockType)
}

// Stat forwards to IStreamABI.Stat.
func (o Stream) Stat(flags STATFLAG) (*STATSTG, error) {
	return o.UnsafeUnwrap().Stat(flags)
}
// Clone duplicates the stream via IStream::Clone and wraps the result as a new
// Stream. On failure it returns a zero Stream and the error.
func (o Stream) Clone() (result Stream, _ error) {
	cloned, err := (*o.Pp).Clone()
	if err == nil {
		result = result.Make(&cloned).(Stream)
	}
	return result, err
}
// hrE_OUTOFMEMORY is E_OUTOFMEMORY (0x8007000E) written as its signed 32-bit
// value, since the unsigned literal does not fit in a signed 32-bit HRESULT.
const hrE_OUTOFMEMORY = wingoes.HRESULT(-0x7FF8FFF2)
// NewMemoryStream returns a new in-memory Stream whose initial contents are a
// copy of initialBytes. The returned stream's seek pointer is positioned at
// the beginning of the stream.
func NewMemoryStream(initialBytes []byte) (result Stream, _ error) {
	const forceLegacy = false
	return newMemoryStreamInternal(initialBytes, forceLegacy)
}
// newMemoryStreamInternal implements NewMemoryStream. It uses SHCreateMemStream
// on Windows 8 and later, and the HGLOBAL-backed legacy path otherwise or when
// forceLegacy is set. Inputs longer than maxStreamRWLen are rejected with
// E_OUTOFMEMORY.
func newMemoryStreamInternal(initialBytes []byte, forceLegacy bool) (result Stream, _ error) {
	if len(initialBytes) > maxStreamRWLen {
		return result, wingoes.ErrorFromHRESULT(hrE_OUTOFMEMORY)
	}

	// SHCreateMemStream exists on Win7 but is not safe for us to use until Win8.
	useLegacy := forceLegacy || !wingoes.IsWin8OrGreater()
	if useLegacy {
		return newMemoryStreamLegacy(initialBytes)
	}

	// Pass a nil buffer and zero length when there is nothing to copy.
	var buf *byte
	length := uint32(len(initialBytes))
	if length == 0 {
		buf = nil
	} else {
		buf = unsafe.SliceData(initialBytes)
	}

	punk := shCreateMemStream(buf, length)
	if punk == nil {
		return result, wingoes.ErrorFromHRESULT(hrE_OUTOFMEMORY)
	}

	stream := result.Make(&punk).(Stream)
	// Explicitly rewind so NewMemoryStream's seek-pointer guarantee holds.
	if _, err := stream.Seek(0, io.SeekStart); err != nil {
		return result, err
	}
	return stream, nil
}
// newMemoryStreamLegacy builds a memory stream with CreateStreamOnHGlobal,
// sizes it to len(initialBytes), copies initialBytes into it, and rewinds the
// seek pointer to the start.
func newMemoryStreamLegacy(initialBytes []byte) (result Stream, _ error) {
	recv := NewABIReceiver()
	if e := wingoes.ErrorFromHRESULT(createStreamOnHGlobal(internal.HGLOBAL(0), true, recv)); e.Failed() {
		return result, e
	}
	stream := result.Make(recv).(Stream)

	// Set the final size before writing (presumably so Write never has to grow it).
	if err := stream.SetSize(uint64(len(initialBytes))); err != nil {
		return result, err
	}
	// Nothing to copy, and no Write has moved the seek pointer.
	if len(initialBytes) == 0 {
		return stream, nil
	}
	if _, err := stream.Write(initialBytes); err != nil {
		return result, err
	}
	// Write advanced the seek pointer, so rewind it to the beginning.
	if _, err := stream.Seek(0, io.SeekStart); err != nil {
		return result, err
	}
	return stream, nil
}