diff options
| author | Fuwn <[email protected]> | 2022-01-03 03:20:12 -0800 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2022-01-03 03:20:12 -0800 |
| commit | 85db2b507f3f69b32811c54a89d9ac7bbbc46121 (patch) | |
| tree | 2efd66da452f8a6a2cc6c91584c925f237506ddf /crates/windows-kernel-rs/src | |
| download | driver-85db2b507f3f69b32811c54a89d9ac7bbbc46121.tar.xz driver-85db2b507f3f69b32811c54a89d9ac7bbbc46121.zip | |
feat(driver): commit primer
Diffstat (limited to 'crates/windows-kernel-rs/src')
21 files changed, 2520 insertions, 0 deletions
diff --git a/crates/windows-kernel-rs/src/affinity.rs b/crates/windows-kernel-rs/src/affinity.rs new file mode 100644 index 0000000..bf2615f --- /dev/null +++ b/crates/windows-kernel-rs/src/affinity.rs @@ -0,0 +1,108 @@ +//! This module provides functions to get information about the logical CPUs in +//! the system, and to run closures on specific or all CPUs. + +use windows_kernel_sys::{ + base::{ALL_PROCESSOR_GROUPS, GROUP_AFFINITY, PROCESSOR_NUMBER, ULONG_PTR}, + ntoskrnl::{ + KeGetCurrentProcessorNumberEx, + KeGetProcessorNumberFromIndex, + KeIpiGenericCall, + KeQueryActiveProcessorCountEx, + KeRevertToUserGroupAffinityThread, + KeSetSystemGroupAffinityThread, + }, +}; + +use crate::error::{Error, IntoResult}; + +/// Uses [`KeGetCurrentProcessorNumberEx`] to get the logical number associated +/// with the CPU that is currently running our code. +pub fn get_current_cpu_num() -> u32 { + unsafe { KeGetCurrentProcessorNumberEx(core::ptr::null_mut()) } +} + +/// Uses [`KeQueryActiveProcessorCountEx`] to get the number of CPUs in the +/// system, that is all the CPUs from all the different CPU groups are counted, +/// such that each of them has a logical number. +pub fn get_cpu_count() -> u32 { + unsafe { KeQueryActiveProcessorCountEx(ALL_PROCESSOR_GROUPS as _) } +} + +/// This is the callback used by [`run_on_each_cpu_parallel`] to run the closure +/// on all CPUs. +unsafe extern "C" fn broadcast_callback<F>(context: ULONG_PTR) -> ULONG_PTR +where F: FnMut() { + let f = &mut *(context as *mut F); + f(); + + 0 +} + +/// Runs the given closure on all CPUs in the system without interrupting all +/// CPUs to force them to switch to kernel mode. Instead, this is a more +/// graceful version that simply relies on [`run_on_cpu`] to switch to all the +/// possible CPUs by configuring the affinity, and execute the closure on the +/// selected CPU. Upon executing the closure on all CPUs, the affinity is +/// restored. Also see [`run_on_each_cpu_parallel`] which is a more aggressive +/// version that relies on an IPI to run a given closure on all CPUs in +/// parallel. +pub fn run_on_each_cpu<F>(f: &mut F) -> Result<(), Error> +where F: FnMut() -> Result<(), Error> { + for cpu_num in 0..get_cpu_count() { + run_on_cpu(cpu_num, f)?; + } + + Ok(()) +} + +/// Runs the given closure on all CPUs in the system by broadcasting an +/// Inter-Processor Interrupt (IPI) to interrupt all CPUs to force them to +/// switch to kernel mode to run the given closure. Upon execution of the +/// closure, these CPUs resume their work. Also see [`run_on_each_cpu`] which is +/// a friendlier version that does not rely on an IPI but instead configures the +/// affinity to run a given a closure on all CPUs. +pub fn run_on_each_cpu_parallel<F>(f: &F) +where F: Fn() { + unsafe { + KeIpiGenericCall(Some(broadcast_callback::<F>), f as *const _ as ULONG_PTR); + } +} + +/// Runs the given closure on the CPU with the given CPU number by temporarily +/// configuring the CPU affinity to only contain the given CPU number. Upon +/// switching to the selected CPU, the CPU executes the closure. Then the +/// original affinity is restored. +pub fn run_on_cpu<F>(cpu_num: u32, f: &mut F) -> Result<(), Error> +where F: FnMut() -> Result<(), Error> { + let mut processor_num = PROCESSOR_NUMBER { + Group: 0, + Number: 0, + Reserved: 0, + }; + + unsafe { KeGetProcessorNumberFromIndex(cpu_num, &mut processor_num) }.into_result()?; + + let mut previous = GROUP_AFFINITY { + Mask: 0, + Group: 0, + Reserved: [0; 3], + }; + + let mut affinity = GROUP_AFFINITY { + Mask: 1 << processor_num.Number, + Group: processor_num.Group, + Reserved: [0; 3], + }; + + unsafe { + KeSetSystemGroupAffinityThread(&mut affinity, &mut previous); + } + + let result = f(); + + unsafe { + KeRevertToUserGroupAffinityThread(&mut previous); + } + + result +} diff --git a/crates/windows-kernel-rs/src/allocator.rs b/crates/windows-kernel-rs/src/allocator.rs new file mode 100644 index 0000000..7b1efc8 --- /dev/null +++ b/crates/windows-kernel-rs/src/allocator.rs @@ -0,0 +1,72 @@ +//! This module provides an allocator to use with the [`alloc`] crate. You can +//! define your own global allocator with the `#[global_allocator]` attribute +//! when not using the `alloc` feature, in case you want to specify your own tag +//! to use with [`ExAllocatePool2`] and [`ExAllocatePoolWithTag`]. + +use core::alloc::{GlobalAlloc, Layout}; + +use lazy_static::lazy_static; +use windows_kernel_sys::{ + base::_POOL_TYPE as POOL_TYPE, + ntoskrnl::{ExAllocatePool2, ExAllocatePoolWithTag, ExFreePool}, +}; + +use crate::version::VersionInfo; + +/// See issue #52191. +#[alloc_error_handler] +fn alloc_error(_: Layout) -> ! { loop {} } + +lazy_static! { + /// The version of Microsoft Windows that is currently running. This is used by + /// [`KernelAllocator`] to determine whether to use [`ExAllocatePool2`] or + /// [`ExAllocatePoolWithTag`]. + static ref VERSION_INFO: VersionInfo = { + VersionInfo::query().unwrap() + }; +} + +/// Represents a kernel allocator that relies on the `ExAllocatePool` family of +/// functions to allocate and free memory for the `alloc` crate. +pub struct KernelAllocator { + /// The 32-bit tag to use for the pool, this is usually derived from a + /// quadruplet of ASCII bytes, e.g. by invoking + /// `u32::from_ne_bytes(*b"rust")`. + tag: u32, +} + +impl KernelAllocator { + /// Sets up a new kernel allocator with the 32-bit tag specified. The tag is + /// usually derived from a quadruplet of ASCII bytes, e.g. by invoking + /// `u32::from_ne_bytes(*b"rust")`. + pub const fn new(tag: u32) -> Self { + Self { + tag, + } + } +} + +unsafe impl GlobalAlloc for KernelAllocator { + /// Uses [`ExAllocatePool2`] on Microsoft Windows 10.0.19041 and later, and + /// [`ExAllocatePoolWithTag`] on older versions of Microsoft Windows to + /// allocate memory. + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + let use_ex_allocate_pool2 = VERSION_INFO.major() > 10 + || (VERSION_INFO.major() == 10 && VERSION_INFO.build_number() == 19041); + + let ptr = if use_ex_allocate_pool2 { + ExAllocatePool2(POOL_TYPE::NonPagedPool as _, layout.size() as u64, self.tag) + } else { + ExAllocatePoolWithTag(POOL_TYPE::NonPagedPool, layout.size() as u64, self.tag) + }; + + if ptr.is_null() { + panic!("[kernel-alloc] failed to allocate pool."); + } + + ptr as _ + } + + /// Uses [`ExFreePool`] to free allocated memory. + unsafe fn dealloc(&self, ptr: *mut u8, _layout: Layout) { ExFreePool(ptr as _) } +} diff --git a/crates/windows-kernel-rs/src/device.rs b/crates/windows-kernel-rs/src/device.rs new file mode 100644 index 0000000..88ae85b --- /dev/null +++ b/crates/windows-kernel-rs/src/device.rs @@ -0,0 +1,443 @@ +use alloc::boxed::Box; + +use bitflags::bitflags; +use windows_kernel_sys::{ + base::{ + DEVICE_OBJECT, + IRP, + IRP_MJ_CLEANUP, + IRP_MJ_CLOSE, + IRP_MJ_CREATE, + IRP_MJ_DEVICE_CONTROL, + IRP_MJ_READ, + IRP_MJ_WRITE, + NTSTATUS, + STATUS_SUCCESS, + }, + ntoskrnl::{IoDeleteDevice, IoGetCurrentIrpStackLocation}, +}; + +use crate::{ + error::Error, + request::{IoControlRequest, IoRequest, ReadRequest, WriteRequest}, +}; + +#[derive(Copy, Clone, Debug)] +pub enum Access { + NonExclusive, + Exclusive, +} + +impl Access { + pub fn is_exclusive(&self) -> bool { + match *self { + Access::Exclusive => true, + _ => false, + } + } +} + +bitflags! { + pub struct DeviceFlags: u32 { + const SECURE_OPEN = windows_kernel_sys::base::FILE_DEVICE_SECURE_OPEN; + } +} + +bitflags! { + pub struct DeviceDoFlags: u32 { + const DO_BUFFERED_IO = windows_kernel_sys::base::DO_BUFFERED_IO; + const DO_DIRECT_IO = windows_kernel_sys::base::DO_DIRECT_IO; + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum DeviceType { + Port8042, + Acpi, + Battery, + Beep, + BusExtender, + Cdrom, + CdromFileSystem, + Changer, + Controller, + DataLink, + Dfs, + DfsFileSystem, + DfsVolume, + Disk, + DiskFileSystem, + Dvd, + FileSystem, + Fips, + FullscreenVideo, + InportPort, + Keyboard, + Ks, + Ksec, + Mailslot, + MassStorage, + MidiIn, + MidiOut, + Modem, + Mouse, + MultiUncProvider, + NamedPipe, + Network, + NetworkBrowser, + NetworkFileSystem, + NetworkRedirector, + Null, + ParallelPort, + PhysicalNetcard, + Printer, + Scanner, + Screen, + Serenum, + SerialPort, + SerialMousePort, + Smartcard, + Smb, + Sound, + Streams, + Tape, + TapeFileSystem, + Termsrv, + Transport, + Unknown, + Vdm, + Video, + VirtualDisk, + WaveIn, + WaveOut, +} + +impl Into<u32> for DeviceType { + fn into(self) -> u32 { + match self { + DeviceType::Port8042 => windows_kernel_sys::base::FILE_DEVICE_8042_PORT, + DeviceType::Acpi => windows_kernel_sys::base::FILE_DEVICE_ACPI, + DeviceType::Battery => windows_kernel_sys::base::FILE_DEVICE_BATTERY, + DeviceType::Beep => windows_kernel_sys::base::FILE_DEVICE_BEEP, + DeviceType::BusExtender => windows_kernel_sys::base::FILE_DEVICE_BUS_EXTENDER, + DeviceType::Cdrom => windows_kernel_sys::base::FILE_DEVICE_CD_ROM, + DeviceType::CdromFileSystem => windows_kernel_sys::base::FILE_DEVICE_CD_ROM_FILE_SYSTEM, + DeviceType::Changer => windows_kernel_sys::base::FILE_DEVICE_CHANGER, + DeviceType::Controller => windows_kernel_sys::base::FILE_DEVICE_CONTROLLER, + DeviceType::DataLink => windows_kernel_sys::base::FILE_DEVICE_DATALINK, + DeviceType::Dfs => windows_kernel_sys::base::FILE_DEVICE_DFS, + DeviceType::DfsFileSystem => windows_kernel_sys::base::FILE_DEVICE_DFS_FILE_SYSTEM, + DeviceType::DfsVolume => windows_kernel_sys::base::FILE_DEVICE_DFS_VOLUME, + DeviceType::Disk => windows_kernel_sys::base::FILE_DEVICE_DISK, + DeviceType::DiskFileSystem => windows_kernel_sys::base::FILE_DEVICE_DISK_FILE_SYSTEM, + DeviceType::Dvd => windows_kernel_sys::base::FILE_DEVICE_DVD, + DeviceType::FileSystem => windows_kernel_sys::base::FILE_DEVICE_FILE_SYSTEM, + DeviceType::Fips => windows_kernel_sys::base::FILE_DEVICE_FIPS, + DeviceType::FullscreenVideo => windows_kernel_sys::base::FILE_DEVICE_FULLSCREEN_VIDEO, + DeviceType::InportPort => windows_kernel_sys::base::FILE_DEVICE_INPORT_PORT, + DeviceType::Keyboard => windows_kernel_sys::base::FILE_DEVICE_KEYBOARD, + DeviceType::Ks => windows_kernel_sys::base::FILE_DEVICE_KS, + DeviceType::Ksec => windows_kernel_sys::base::FILE_DEVICE_KSEC, + DeviceType::Mailslot => windows_kernel_sys::base::FILE_DEVICE_MAILSLOT, + DeviceType::MassStorage => windows_kernel_sys::base::FILE_DEVICE_MASS_STORAGE, + DeviceType::MidiIn => windows_kernel_sys::base::FILE_DEVICE_MIDI_IN, + DeviceType::MidiOut => windows_kernel_sys::base::FILE_DEVICE_MIDI_OUT, + DeviceType::Modem => windows_kernel_sys::base::FILE_DEVICE_MODEM, + DeviceType::Mouse => windows_kernel_sys::base::FILE_DEVICE_MOUSE, + DeviceType::MultiUncProvider => windows_kernel_sys::base::FILE_DEVICE_MULTI_UNC_PROVIDER, + DeviceType::NamedPipe => windows_kernel_sys::base::FILE_DEVICE_NAMED_PIPE, + DeviceType::Network => windows_kernel_sys::base::FILE_DEVICE_NETWORK, + DeviceType::NetworkBrowser => windows_kernel_sys::base::FILE_DEVICE_NETWORK_BROWSER, + DeviceType::NetworkFileSystem => windows_kernel_sys::base::FILE_DEVICE_NETWORK_FILE_SYSTEM, + DeviceType::NetworkRedirector => windows_kernel_sys::base::FILE_DEVICE_NETWORK_REDIRECTOR, + DeviceType::Null => windows_kernel_sys::base::FILE_DEVICE_NULL, + DeviceType::ParallelPort => windows_kernel_sys::base::FILE_DEVICE_PARALLEL_PORT, + DeviceType::PhysicalNetcard => windows_kernel_sys::base::FILE_DEVICE_PHYSICAL_NETCARD, + DeviceType::Printer => windows_kernel_sys::base::FILE_DEVICE_PRINTER, + DeviceType::Scanner => windows_kernel_sys::base::FILE_DEVICE_SCANNER, + DeviceType::Screen => windows_kernel_sys::base::FILE_DEVICE_SCREEN, + DeviceType::Serenum => windows_kernel_sys::base::FILE_DEVICE_SERENUM, + DeviceType::SerialMousePort => windows_kernel_sys::base::FILE_DEVICE_SERIAL_MOUSE_PORT, + DeviceType::SerialPort => windows_kernel_sys::base::FILE_DEVICE_SERIAL_PORT, + DeviceType::Smartcard => windows_kernel_sys::base::FILE_DEVICE_SMARTCARD, + DeviceType::Smb => windows_kernel_sys::base::FILE_DEVICE_SMB, + DeviceType::Sound => windows_kernel_sys::base::FILE_DEVICE_SOUND, + DeviceType::Streams => windows_kernel_sys::base::FILE_DEVICE_STREAMS, + DeviceType::Tape => windows_kernel_sys::base::FILE_DEVICE_TAPE, + DeviceType::TapeFileSystem => windows_kernel_sys::base::FILE_DEVICE_TAPE_FILE_SYSTEM, + DeviceType::Termsrv => windows_kernel_sys::base::FILE_DEVICE_TERMSRV, + DeviceType::Transport => windows_kernel_sys::base::FILE_DEVICE_TRANSPORT, + DeviceType::Unknown => windows_kernel_sys::base::FILE_DEVICE_UNKNOWN, + DeviceType::Vdm => windows_kernel_sys::base::FILE_DEVICE_VDM, + DeviceType::Video => windows_kernel_sys::base::FILE_DEVICE_VIDEO, + DeviceType::VirtualDisk => windows_kernel_sys::base::FILE_DEVICE_VIRTUAL_DISK, + DeviceType::WaveIn => windows_kernel_sys::base::FILE_DEVICE_WAVE_IN, + DeviceType::WaveOut => windows_kernel_sys::base::FILE_DEVICE_WAVE_OUT, + } + } +} + +impl From<u32> for DeviceType { + fn from(value: u32) -> Self { + match value { + windows_kernel_sys::base::FILE_DEVICE_8042_PORT => DeviceType::Port8042, + windows_kernel_sys::base::FILE_DEVICE_ACPI => DeviceType::Acpi, + windows_kernel_sys::base::FILE_DEVICE_BATTERY => DeviceType::Battery, + windows_kernel_sys::base::FILE_DEVICE_BEEP => DeviceType::Beep, + windows_kernel_sys::base::FILE_DEVICE_BUS_EXTENDER => DeviceType::BusExtender, + windows_kernel_sys::base::FILE_DEVICE_CD_ROM => DeviceType::Cdrom, + windows_kernel_sys::base::FILE_DEVICE_CD_ROM_FILE_SYSTEM => DeviceType::CdromFileSystem, + windows_kernel_sys::base::FILE_DEVICE_CHANGER => DeviceType::Changer, + windows_kernel_sys::base::FILE_DEVICE_CONTROLLER => DeviceType::Controller, + windows_kernel_sys::base::FILE_DEVICE_DATALINK => DeviceType::DataLink, + windows_kernel_sys::base::FILE_DEVICE_DFS => DeviceType::Dfs, + windows_kernel_sys::base::FILE_DEVICE_DFS_FILE_SYSTEM => DeviceType::DfsFileSystem, + windows_kernel_sys::base::FILE_DEVICE_DFS_VOLUME => DeviceType::DfsVolume, + windows_kernel_sys::base::FILE_DEVICE_DISK => DeviceType::Disk, + windows_kernel_sys::base::FILE_DEVICE_DISK_FILE_SYSTEM => DeviceType::DiskFileSystem, + windows_kernel_sys::base::FILE_DEVICE_DVD => DeviceType::Dvd, + windows_kernel_sys::base::FILE_DEVICE_FILE_SYSTEM => DeviceType::FileSystem, + windows_kernel_sys::base::FILE_DEVICE_FIPS => DeviceType::Fips, + windows_kernel_sys::base::FILE_DEVICE_FULLSCREEN_VIDEO => DeviceType::FullscreenVideo, + windows_kernel_sys::base::FILE_DEVICE_INPORT_PORT => DeviceType::InportPort, + windows_kernel_sys::base::FILE_DEVICE_KEYBOARD => DeviceType::Keyboard, + windows_kernel_sys::base::FILE_DEVICE_KS => DeviceType::Ks, + windows_kernel_sys::base::FILE_DEVICE_KSEC => DeviceType::Ksec, + windows_kernel_sys::base::FILE_DEVICE_MAILSLOT => DeviceType::Mailslot, + windows_kernel_sys::base::FILE_DEVICE_MASS_STORAGE => DeviceType::MassStorage, + windows_kernel_sys::base::FILE_DEVICE_MIDI_IN => DeviceType::MidiIn, + windows_kernel_sys::base::FILE_DEVICE_MIDI_OUT => DeviceType::MidiOut, + windows_kernel_sys::base::FILE_DEVICE_MODEM => DeviceType::Modem, + windows_kernel_sys::base::FILE_DEVICE_MOUSE => DeviceType::Mouse, + windows_kernel_sys::base::FILE_DEVICE_MULTI_UNC_PROVIDER => DeviceType::MultiUncProvider, + windows_kernel_sys::base::FILE_DEVICE_NAMED_PIPE => DeviceType::NamedPipe, + windows_kernel_sys::base::FILE_DEVICE_NETWORK => DeviceType::Network, + windows_kernel_sys::base::FILE_DEVICE_NETWORK_BROWSER => DeviceType::NetworkBrowser, + windows_kernel_sys::base::FILE_DEVICE_NETWORK_FILE_SYSTEM => DeviceType::NetworkFileSystem, + windows_kernel_sys::base::FILE_DEVICE_NETWORK_REDIRECTOR => DeviceType::NetworkRedirector, + windows_kernel_sys::base::FILE_DEVICE_NULL => DeviceType::Null, + windows_kernel_sys::base::FILE_DEVICE_PARALLEL_PORT => DeviceType::ParallelPort, + windows_kernel_sys::base::FILE_DEVICE_PHYSICAL_NETCARD => DeviceType::PhysicalNetcard, + windows_kernel_sys::base::FILE_DEVICE_PRINTER => DeviceType::Printer, + windows_kernel_sys::base::FILE_DEVICE_SCANNER => DeviceType::Scanner, + windows_kernel_sys::base::FILE_DEVICE_SCREEN => DeviceType::Screen, + windows_kernel_sys::base::FILE_DEVICE_SERENUM => DeviceType::Serenum, + windows_kernel_sys::base::FILE_DEVICE_SERIAL_MOUSE_PORT => DeviceType::SerialMousePort, + windows_kernel_sys::base::FILE_DEVICE_SERIAL_PORT => DeviceType::SerialPort, + windows_kernel_sys::base::FILE_DEVICE_SMARTCARD => DeviceType::Smartcard, + windows_kernel_sys::base::FILE_DEVICE_SMB => DeviceType::Smb, + windows_kernel_sys::base::FILE_DEVICE_SOUND => DeviceType::Sound, + windows_kernel_sys::base::FILE_DEVICE_STREAMS => DeviceType::Streams, + windows_kernel_sys::base::FILE_DEVICE_TAPE => DeviceType::Tape, + windows_kernel_sys::base::FILE_DEVICE_TAPE_FILE_SYSTEM => DeviceType::TapeFileSystem, + windows_kernel_sys::base::FILE_DEVICE_TERMSRV => DeviceType::Termsrv, + windows_kernel_sys::base::FILE_DEVICE_TRANSPORT => DeviceType::Transport, + windows_kernel_sys::base::FILE_DEVICE_UNKNOWN => DeviceType::Unknown, + windows_kernel_sys::base::FILE_DEVICE_VDM => DeviceType::Vdm, + windows_kernel_sys::base::FILE_DEVICE_VIDEO => DeviceType::Video, + windows_kernel_sys::base::FILE_DEVICE_VIRTUAL_DISK => DeviceType::VirtualDisk, + windows_kernel_sys::base::FILE_DEVICE_WAVE_IN => DeviceType::WaveIn, + windows_kernel_sys::base::FILE_DEVICE_WAVE_OUT => DeviceType::WaveOut, + _ => DeviceType::Unknown, + } + } +} + +#[repr(C)] +pub struct device_operations { + dispatch: Option<extern "C" fn(*mut DEVICE_OBJECT, *mut IRP, u8) -> NTSTATUS>, + release: Option<extern "C" fn(*mut DEVICE_OBJECT)>, +} + +pub struct Device { + raw: *mut DEVICE_OBJECT, +} + +unsafe impl Send for Device {} +unsafe impl Sync for Device {} + +impl Device { + pub unsafe fn from_raw(raw: *mut DEVICE_OBJECT) -> Self { + Self { + raw, + } + } + + pub unsafe fn as_raw(&self) -> *const DEVICE_OBJECT { self.raw as *const _ } + + pub unsafe fn as_raw_mut(&self) -> *mut DEVICE_OBJECT { self.raw } + + pub fn into_raw(mut self) -> *mut DEVICE_OBJECT { + core::mem::replace(&mut self.raw, core::ptr::null_mut()) + } + + pub(crate) fn extension(&self) -> &DeviceExtension { + unsafe { &*((*self.raw).DeviceExtension as *const DeviceExtension) } + } + + pub(crate) fn extension_mut(&self) -> &mut DeviceExtension { + unsafe { &mut *((*self.raw).DeviceExtension as *mut DeviceExtension) } + } + + pub(crate) fn device_type(&self) -> DeviceType { self.extension().device_type } + + pub(crate) fn vtable(&self) -> &device_operations { + unsafe { &*(self.extension().vtable as *const _) } + } + + pub fn data<T: DeviceOperations>(&self) -> &T { unsafe { &*(self.extension().data as *const T) } } + + pub fn data_mut<T: DeviceOperations>(&self) -> &mut T { + unsafe { &mut *(self.extension().data as *mut T) } + } +} + +impl Drop for Device { + fn drop(&mut self) { + if self.raw.is_null() { + return; + } + + unsafe { + if let Some(release) = self.vtable().release { + release(self.raw); + } + + IoDeleteDevice(self.raw); + } + } +} + +pub struct RequestError(pub Error, pub IoRequest); + +pub enum Completion { + Complete(u32, IoRequest), +} + +pub trait DeviceOperations: Sync + Sized { + fn create(&mut self, _device: &Device, request: IoRequest) -> Result<Completion, RequestError> { + Ok(Completion::Complete(0, request)) + } + + fn close(&mut self, _device: &Device, request: IoRequest) -> Result<Completion, RequestError> { + Ok(Completion::Complete(0, request)) + } + + fn cleanup(&mut self, _device: &Device, request: IoRequest) -> Result<Completion, RequestError> { + Ok(Completion::Complete(0, request)) + } + + fn read(&mut self, _device: &Device, request: ReadRequest) -> Result<Completion, RequestError> { + Ok(Completion::Complete(0, request.into())) + } + + fn write(&mut self, _device: &Device, request: WriteRequest) -> Result<Completion, RequestError> { + Ok(Completion::Complete(0, request.into())) + } + + fn ioctl( + &mut self, + _device: &Device, + request: IoControlRequest, + ) -> Result<Completion, RequestError> { + Ok(Completion::Complete(0, request.into())) + } +} + +extern "C" fn dispatch_callback<T: DeviceOperations>( + device: *mut DEVICE_OBJECT, + irp: *mut IRP, + major: u8, +) -> NTSTATUS { + let device = unsafe { Device::from_raw(device) }; + let data: &mut T = device.data_mut(); + let request = unsafe { IoRequest::from_raw(irp) }; + + let result = match major as _ { + IRP_MJ_CREATE => data.create(&device, request), + IRP_MJ_CLOSE => data.close(&device, request), + IRP_MJ_CLEANUP => data.cleanup(&device, request), + IRP_MJ_READ => { + let read_request = ReadRequest { + inner: request + }; + + data.read(&device, read_request) + } + IRP_MJ_WRITE => { + let write_request = WriteRequest { + inner: request + }; + + data.write(&device, write_request) + } + IRP_MJ_DEVICE_CONTROL => { + let control_request = IoControlRequest { + inner: request + }; + + if device.device_type() == control_request.control_code().device_type() { + data.ioctl(&device, control_request) + } else { + Err(RequestError( + Error::INVALID_PARAMETER, + control_request.into(), + )) + } + } + _ => Err(RequestError(Error::INVALID_PARAMETER, request)), + }; + + device.into_raw(); + + match result { + Ok(Completion::Complete(size, request)) => { + request.complete(Ok(size)); + STATUS_SUCCESS + } + Err(RequestError(e, request)) => { + let status = e.to_ntstatus(); + request.complete(Err(e)); + status + } + } +} + +extern "C" fn release_callback<T: DeviceOperations>(device: *mut DEVICE_OBJECT) { + unsafe { + let extension = (*device).DeviceExtension as *mut DeviceExtension; + + let ptr = core::mem::replace(&mut (*extension).data, core::ptr::null_mut()); + Box::from_raw(ptr as *mut T); + } +} + +pub(crate) struct DeviceOperationsVtable<T>(core::marker::PhantomData<T>); + +impl<T: DeviceOperations> DeviceOperationsVtable<T> { + pub(crate) const VTABLE: device_operations = device_operations { + dispatch: Some(dispatch_callback::<T>), + release: Some(release_callback::<T>), + }; +} + +#[repr(C)] +pub struct DeviceExtension { + pub(crate) vtable: *const device_operations, + pub(crate) data: *mut cty::c_void, + pub(crate) device_type: DeviceType, +} + +pub extern "C" fn dispatch_device(device: *mut DEVICE_OBJECT, irp: *mut IRP) -> NTSTATUS { + let stack_location = unsafe { &*IoGetCurrentIrpStackLocation(irp) }; + let device = unsafe { Device::from_raw(device) }; + let vtable = device.vtable(); + + match vtable.dispatch { + Some(dispatch) => dispatch(device.into_raw(), irp, stack_location.MajorFunction), + _ => { + device.into_raw(); + STATUS_SUCCESS + } + } +} diff --git a/crates/windows-kernel-rs/src/driver.rs b/crates/windows-kernel-rs/src/driver.rs new file mode 100644 index 0000000..2bdac4e --- /dev/null +++ b/crates/windows-kernel-rs/src/driver.rs @@ -0,0 +1,85 @@ +use alloc::boxed::Box; + +use widestring::U16CString; +use windows_kernel_sys::{base::DRIVER_OBJECT, ntoskrnl::IoCreateDevice}; + +use crate::{ + device::{ + Access, + Device, + DeviceDoFlags, + DeviceExtension, + DeviceFlags, + DeviceOperations, + DeviceOperationsVtable, + DeviceType, + }, + error::{Error, IntoResult}, + string::create_unicode_string, +}; + +pub struct Driver { + pub(crate) raw: *mut DRIVER_OBJECT, +} + +impl Driver { + pub unsafe fn from_raw(raw: *mut DRIVER_OBJECT) -> Self { + Self { + raw, + } + } + + pub unsafe fn as_raw(&self) -> *const DRIVER_OBJECT { self.raw as _ } + + pub unsafe fn as_raw_mut(&mut self) -> *mut DRIVER_OBJECT { self.raw as _ } + + pub fn create_device<T>( + &mut self, + name: &str, + device_type: DeviceType, + device_flags: DeviceFlags, + device_do_flags: DeviceDoFlags, + access: Access, + data: T, + ) -> Result<Device, Error> + where + T: DeviceOperations, + { + // Box the data. + let data = Box::new(data); + + // Convert the name to UTF-16 and then create a UNICODE_STRING. + let name = U16CString::from_str(name).unwrap(); + let mut name = create_unicode_string(name.as_slice()); + + // Create the device. + let mut device = core::ptr::null_mut(); + + unsafe { + IoCreateDevice( + self.raw, + core::mem::size_of::<DeviceExtension>() as u32, + &mut name, + device_type.into(), + device_flags.bits(), + access.is_exclusive() as _, + &mut device, + ) + } + .into_result()?; + + unsafe { + (*device).Flags |= device_do_flags.bits(); + } + + let device = unsafe { Device::from_raw(device) }; + + // Store the boxed data and vtable. + let extension = device.extension_mut(); + extension.device_type = device_type; + extension.vtable = &DeviceOperationsVtable::<T>::VTABLE; + extension.data = Box::into_raw(data) as *mut cty::c_void; + + Ok(device) + } +} diff --git a/crates/windows-kernel-rs/src/error.rs b/crates/windows-kernel-rs/src/error.rs new file mode 100644 index 0000000..075287f --- /dev/null +++ b/crates/windows-kernel-rs/src/error.rs @@ -0,0 +1,87 @@ +use windows_kernel_sys::base::{ + NTSTATUS, + STATUS_ACCESS_VIOLATION, + STATUS_ARRAY_BOUNDS_EXCEEDED, + STATUS_BREAKPOINT, + STATUS_DATATYPE_MISALIGNMENT, + STATUS_END_OF_FILE, + STATUS_FLOAT_DENORMAL_OPERAND, + STATUS_FLOAT_DIVIDE_BY_ZERO, + STATUS_FLOAT_INEXACT_RESULT, + STATUS_FLOAT_INVALID_OPERATION, + STATUS_FLOAT_OVERFLOW, + STATUS_FLOAT_STACK_CHECK, + STATUS_FLOAT_UNDERFLOW, + STATUS_GUARD_PAGE_VIOLATION, + STATUS_ILLEGAL_INSTRUCTION, + STATUS_INSUFFICIENT_RESOURCES, + STATUS_INTEGER_DIVIDE_BY_ZERO, + STATUS_INTEGER_OVERFLOW, + STATUS_INVALID_DISPOSITION, + STATUS_INVALID_HANDLE, + STATUS_INVALID_PARAMETER, + STATUS_INVALID_USER_BUFFER, + STATUS_IN_PAGE_ERROR, + STATUS_NONCONTINUABLE_EXCEPTION, + STATUS_NOT_IMPLEMENTED, + STATUS_NO_MEMORY, + STATUS_PRIVILEGED_INSTRUCTION, + STATUS_SINGLE_STEP, + STATUS_STACK_OVERFLOW, + STATUS_SUCCESS, + STATUS_UNSUCCESSFUL, + STATUS_UNWIND_CONSOLIDATE, +}; + +#[derive(Clone, Copy, Debug)] +pub struct Error(NTSTATUS); + +impl Error { + pub const ACCESS_VIOLATION: Error = Error(STATUS_ACCESS_VIOLATION); + pub const ARRAY_BOUNDS_EXCEEDED: Error = Error(STATUS_ARRAY_BOUNDS_EXCEEDED); + pub const BREAKPOINT: Error = Error(STATUS_BREAKPOINT); + pub const DATATYPE_MISALIGNMENT: Error = Error(STATUS_DATATYPE_MISALIGNMENT); + pub const END_OF_FILE: Error = Error(STATUS_END_OF_FILE); + pub const FLOAT_DENORMAL_OPERAND: Error = Error(STATUS_FLOAT_DENORMAL_OPERAND); + pub const FLOAT_DIVIDE_BY_ZERO: Error = Error(STATUS_FLOAT_DIVIDE_BY_ZERO); + pub const FLOAT_INEXACT_RESULT: Error = Error(STATUS_FLOAT_INEXACT_RESULT); + pub const FLOAT_INVALID_OPERATION: Error = Error(STATUS_FLOAT_INVALID_OPERATION); + pub const FLOAT_OVERFLOW: Error = Error(STATUS_FLOAT_OVERFLOW); + pub const FLOAT_STACK_CHECK: Error = Error(STATUS_FLOAT_STACK_CHECK); + pub const FLOAT_UNDERFLOW: Error = Error(STATUS_FLOAT_UNDERFLOW); + pub const GUARD_PAGE_VIOLATION: Error = Error(STATUS_GUARD_PAGE_VIOLATION); + pub const ILLEGAL_INSTRUCTION: Error = Error(STATUS_ILLEGAL_INSTRUCTION); + pub const INSUFFICIENT_RESOURCES: Error = Error(STATUS_INSUFFICIENT_RESOURCES); + pub const INTEGER_DIVIDE_BY_ZERO: Error = Error(STATUS_INTEGER_DIVIDE_BY_ZERO); + pub const INTEGER_OVERFLOW: Error = Error(STATUS_INTEGER_OVERFLOW); + pub const INVALID_DISPOSITION: Error = Error(STATUS_INVALID_DISPOSITION); + pub const INVALID_HANDLE: Error = Error(STATUS_INVALID_HANDLE); + pub const INVALID_PARAMETER: Error = Error(STATUS_INVALID_PARAMETER); + pub const INVALID_USER_BUFFER: Error = Error(STATUS_INVALID_USER_BUFFER); + pub const IN_PAGE_ERROR: Error = Error(STATUS_IN_PAGE_ERROR); + pub const NONCONTINUABLE_EXCEPTION: Error = Error(STATUS_NONCONTINUABLE_EXCEPTION); + pub const NOT_IMPLEMENTED: Error = Error(STATUS_NOT_IMPLEMENTED); + pub const NO_MEMORY: Error = Error(STATUS_NO_MEMORY); + pub const PRIVILEGED_INSTRUCTION: Error = Error(STATUS_PRIVILEGED_INSTRUCTION); + pub const SINGLE_STEP: Error = Error(STATUS_SINGLE_STEP); + pub const STACK_OVERFLOW: Error = Error(STATUS_STACK_OVERFLOW); + pub const UNSUCCESSFUL: Error = Error(STATUS_UNSUCCESSFUL); + pub const UNWIND_CONSOLIDATE: Error = Error(STATUS_UNWIND_CONSOLIDATE); + + pub fn from_ntstatus(status: NTSTATUS) -> Error { Error(status) } + + pub fn to_ntstatus(&self) -> NTSTATUS { self.0 } +} + +pub trait IntoResult { + fn into_result(self) -> Result<(), Error>; +} + +impl IntoResult for NTSTATUS { + fn into_result(self) -> Result<(), Error> { + match self { + STATUS_SUCCESS => Ok(()), + status => Err(Error::from_ntstatus(status)), + } + } +} diff --git a/crates/windows-kernel-rs/src/intrin.rs b/crates/windows-kernel-rs/src/intrin.rs new file mode 100644 index 0000000..f415f8d --- /dev/null +++ b/crates/windows-kernel-rs/src/intrin.rs @@ -0,0 +1,25 @@ +use windows_kernel_sys::intrin::{read_msr_safe, write_msr_safe}; + +use crate::error::{Error, IntoResult}; + +/// Attempts to read the given model-specific register. Accessing an invalid +/// model-specific register would normally result in a CPU exception. This +/// function uses Structured Exception Handling (SEH) to safely catch CPU +/// exceptions and to turn them into an [`Error`]. This prevents a hang. +pub fn read_msr(register: u32) -> Result<u64, Error> { + let mut value = 0; + + unsafe { read_msr_safe(register, &mut value) }.into_result()?; + + Ok(value) +} + +/// Attempts to write the given value to the given model-specific register. +/// Accessing an invalid model-specific register would normally result in a CPU +/// exception. This function uses Structured Handling (SEH) to safely catch CPU +/// exceptions and to turn them into an [`Error`]. This prevents a hang. +pub fn write_msr(register: u32, value: u64) -> Result<(), Error> { + unsafe { write_msr_safe(register, value) }.into_result()?; + + Ok(()) +} diff --git a/crates/windows-kernel-rs/src/io.rs b/crates/windows-kernel-rs/src/io.rs new file mode 100644 index 0000000..7d783ce --- /dev/null +++ b/crates/windows-kernel-rs/src/io.rs @@ -0,0 +1,31 @@ +use windows_kernel_sys::{base::ANSI_STRING, ntoskrnl::DbgPrint}; + +#[macro_export] +macro_rules! print { + ($($arg:tt)*) => ($crate::io::_print(format_args!($($arg)*))); +} + +#[macro_export] +macro_rules! println { + () => ($crate::print!("\n")); + ($($arg:tt)*) => ($crate::print!("{}\n", format_args!($($arg)*))); +} + +#[doc(hidden)] +pub fn _print(args: core::fmt::Arguments) { + // Format the string using the `alloc::format!` as this is guaranteed to return + // a `String` instead of a `Result` that we would have to `unwrap`. This + // ensures that this code stays panic-free. + let s = alloc::format!("{}", args); + + // Print the string. We must make sure to not pass this user-supplied string as + // the format string, as `DbgPrint` may then format any format specifiers it + // contains. This could potentially be an attack vector. + let s = ANSI_STRING { + Length: s.len() as u16, + MaximumLength: s.len() as u16, + Buffer: s.as_ptr() as _, + }; + + unsafe { DbgPrint("%Z\0".as_ptr() as _, &s) }; +} diff --git a/crates/windows-kernel-rs/src/ioctl.rs b/crates/windows-kernel-rs/src/ioctl.rs new file mode 100644 index 0000000..c1d493f --- /dev/null +++ b/crates/windows-kernel-rs/src/ioctl.rs @@ -0,0 +1,111 @@ +use bitflags::bitflags; +use windows_kernel_sys::base::{ + FILE_ANY_ACCESS, + FILE_READ_DATA, + FILE_WRITE_DATA, + METHOD_BUFFERED, + METHOD_IN_DIRECT, + METHOD_NEITHER, + METHOD_OUT_DIRECT, +}; + +use crate::device::DeviceType; + +bitflags! { + pub struct RequiredAccess: u32 { + const ANY_ACCESS = FILE_ANY_ACCESS; + const READ_DATA = FILE_READ_DATA; + const WRITE_DATA = FILE_WRITE_DATA; + const READ_WRITE_DATA = FILE_READ_DATA | FILE_WRITE_DATA; + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u32)] +pub enum TransferMethod { + Neither = METHOD_NEITHER, + InputDirect = METHOD_IN_DIRECT, + OutputDirect = METHOD_OUT_DIRECT, + Buffered = METHOD_BUFFERED, +} + +impl From<u32> for TransferMethod { + fn from(value: u32) -> Self { + match value & 0x3 { + METHOD_NEITHER => Self::Neither, + METHOD_IN_DIRECT => Self::InputDirect, + METHOD_OUT_DIRECT => Self::OutputDirect, + METHOD_BUFFERED => Self::Buffered, + _ => unreachable!(), + } + } +} + +impl Into<u32> for TransferMethod { + fn into(self) -> u32 { + match self { + Self::Neither => METHOD_NEITHER, + Self::InputDirect => METHOD_IN_DIRECT, + Self::OutputDirect => METHOD_OUT_DIRECT, + Self::Buffered => METHOD_BUFFERED, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ControlCode( + pub DeviceType, + pub RequiredAccess, + pub u32, + pub TransferMethod, +); + +impl ControlCode { + const ACCESS_BITS: usize = 2; + const ACCESS_MASK: u32 = (1 << Self::ACCESS_BITS) - 1; + const ACCESS_SHIFT: usize = Self::NUM_SHIFT + Self::NUM_BITS; + const METHOD_BITS: usize = 2; + const METHOD_MASK: u32 = (1 << Self::METHOD_BITS) - 1; + const METHOD_SHIFT: usize = 0; + const NUM_BITS: usize = 12; + const NUM_MASK: u32 = (1 << Self::NUM_BITS) - 1; + const NUM_SHIFT: usize = Self::METHOD_SHIFT + Self::METHOD_BITS; + const TYPE_BITS: usize = 16; + const TYPE_MASK: u32 = (1 << Self::TYPE_BITS) - 1; + const TYPE_SHIFT: usize = Self::ACCESS_SHIFT + Self::ACCESS_BITS; + + pub fn device_type(&self) -> DeviceType { self.0 } + + pub fn required_access(&self) -> RequiredAccess { self.1 } + + pub fn number(&self) -> u32 { self.2 } + + pub fn transfer_method(&self) -> TransferMethod { self.3 } +} + +impl From<u32> for ControlCode { + fn from(value: u32) -> Self { + let method = (value >> Self::METHOD_SHIFT) & Self::METHOD_MASK; + let num = (value >> Self::NUM_SHIFT) & Self::NUM_MASK; + let access = (value >> Self::ACCESS_SHIFT) & Self::ACCESS_MASK; + let ty = (value >> Self::TYPE_SHIFT) & Self::TYPE_MASK; + + Self( + ty.into(), + RequiredAccess::from_bits(access).unwrap_or(RequiredAccess::READ_DATA), + num, + method.into(), + ) + } +} + +impl Into<u32> for ControlCode { + fn into(self) -> u32 { + let method = Into::<u32>::into(self.3) << Self::METHOD_SHIFT; + let num = self.2 << Self::NUM_SHIFT; + let access = self.1.bits() << Self::ACCESS_SHIFT; + let ty = Into::<u32>::into(self.0) << Self::TYPE_SHIFT; + + ty | access | num | method + } +} diff --git a/crates/windows-kernel-rs/src/lib.rs b/crates/windows-kernel-rs/src/lib.rs new file mode 100644 index 0000000..a0ddf9a --- /dev/null +++ b/crates/windows-kernel-rs/src/lib.rs @@ -0,0 +1,119 @@ +#![no_std] +#![feature(alloc_error_handler)] + +extern crate alloc; + +pub mod affinity; +pub mod allocator; +pub mod device; +pub mod driver; +pub mod error; +pub mod intrin; +pub mod io; +pub mod ioctl; +pub mod mdl; +pub mod memory; +pub mod process; +pub mod request; +pub mod section; +pub mod string; +pub mod symbolic_link; +pub mod sync; +pub mod user_ptr; +pub mod version; + +pub use widestring::U16CString; +pub use windows_kernel_sys::base::{ + DRIVER_OBJECT, + IRP_MJ_MAXIMUM_FUNCTION, + NTSTATUS, + STATUS_SUCCESS, + UNICODE_STRING, +}; + +pub use crate::{ + affinity::{get_cpu_count, get_current_cpu_num, run_on_cpu, run_on_each_cpu}, + device::{ + dispatch_device, + Access, + Completion, + Device, + DeviceDoFlags, + DeviceFlags, + DeviceOperations, + DeviceType, + RequestError, + }, + driver::Driver, + error::Error, + ioctl::{ControlCode, RequiredAccess, TransferMethod}, + request::{IoControlRequest, IoRequest, ReadRequest, WriteRequest}, + symbolic_link::SymbolicLink, + user_ptr::UserPtr, +}; + +#[cfg(feature = "alloc")] +#[global_allocator] +static ALLOCATOR: allocator::KernelAllocator = + allocator::KernelAllocator::new(u32::from_ne_bytes(*b"rust")); + +#[panic_handler] +fn panic(_info: &core::panic::PanicInfo) -> ! { loop {} } + +#[used] +#[no_mangle] +pub static _fltused: i32 = 0; + +#[no_mangle] +pub extern "system" fn __CxxFrameHandler3() -> i32 { 0 } + +#[macro_export] +macro_rules! kernel_module { + ($module:ty) => { + static mut __MOD: Option<$module> = None; + + #[no_mangle] + pub extern "system" fn driver_entry( + driver: &mut $crate::DRIVER_OBJECT, + registry_path: &$crate::UNICODE_STRING, + ) -> $crate::NTSTATUS { + unsafe { + driver.DriverUnload = Some(driver_exit); + + for i in 0..$crate::IRP_MJ_MAXIMUM_FUNCTION { + driver.MajorFunction[i as usize] = Some($crate::dispatch_device); + } + } + + let driver = unsafe { Driver::from_raw(driver) }; + + let registry_path = unsafe { $crate::U16CString::from_ptr_str(registry_path.Buffer) }; + let registry_path = registry_path.to_string_lossy(); + + match <$module as $crate::KernelModule>::init(driver, registry_path.as_str()) { + Ok(m) => { + unsafe { + __MOD = Some(m); + } + + $crate::STATUS_SUCCESS + } + Err(e) => e.to_ntstatus(), + } + } + + pub unsafe extern "C" fn driver_exit(driver: *mut $crate::DRIVER_OBJECT) { + let driver = unsafe { Driver::from_raw(driver) }; + + match __MOD.take() { + Some(mut m) => m.cleanup(driver), + _ => (), + } + } + }; +} + +pub trait KernelModule: Sized + Sync { + fn init(driver: Driver, registry_path: &str) -> Result<Self, Error>; + fn cleanup(&mut self, _driver: Driver) {} +} diff --git a/crates/windows-kernel-rs/src/mdl.rs b/crates/windows-kernel-rs/src/mdl.rs new file mode 100644 index 0000000..7d002f0 --- /dev/null +++ b/crates/windows-kernel-rs/src/mdl.rs @@ -0,0 +1,117 @@ +use crate::{error::Error, memory::MemoryCaching}; + +#[repr(i32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AccessMode { + KernelMode = windows_kernel_sys::base::_MODE::KernelMode, + UserMode = windows_kernel_sys::base::_MODE::UserMode, +} + +pub struct MemoryDescriptorList { + raw: *mut windows_kernel_sys::base::MDL, +} + +unsafe impl Send for MemoryDescriptorList {} +unsafe impl Sync for MemoryDescriptorList {} + +impl MemoryDescriptorList { + pub fn new(addr: *mut core::ffi::c_void, size: usize) -> Result<Self, Error> { + use windows_kernel_sys::ntoskrnl::IoAllocateMdl; + + let raw = unsafe { + IoAllocateMdl( + addr, + size as _, + false as _, + false as _, + core::ptr::null_mut(), + ) + }; + + if raw.is_null() { + return Err(Error::INSUFFICIENT_RESOURCES); + } + + Ok(Self { + raw, + }) + } + + pub fn build_for_non_paged_pool(&mut self) { + use windows_kernel_sys::ntoskrnl::MmBuildMdlForNonPagedPool; + + unsafe { + MmBuildMdlForNonPagedPool(self.raw); + } + } + + pub fn map_locked_pages( + self, + access: AccessMode, + caching: MemoryCaching, + desired_addr: Option<*mut core::ffi::c_void>, + ) -> Result<LockedMapping, Error> { + use windows_kernel_sys::ntoskrnl::MmMapLockedPagesSpecifyCache; + + let ptr = unsafe { + MmMapLockedPagesSpecifyCache( + self.raw, + access as _, + caching as _, + desired_addr.unwrap_or(core::ptr::null_mut()), + false as _, + 0, + ) + }; + + Ok(LockedMapping { + raw: self.raw, + ptr, + }) + } +} + +impl Drop for MemoryDescriptorList { + fn drop(&mut self) { + use windows_kernel_sys::ntoskrnl::IoFreeMdl; + + unsafe { + IoFreeMdl(self.raw); + } + } +} + +pub struct LockedMapping { + raw: *mut windows_kernel_sys::base::MDL, + ptr: *mut core::ffi::c_void, +} + +unsafe impl Send for LockedMapping {} +unsafe impl Sync for LockedMapping {} + +impl LockedMapping { + pub fn ptr(&self) -> *mut core::ffi::c_void { self.ptr } + + pub fn unlock(self) -> MemoryDescriptorList { + use windows_kernel_sys::ntoskrnl::MmUnmapLockedPages; + + unsafe { + MmUnmapLockedPages(self.ptr, self.raw); + } + + MemoryDescriptorList { + raw: self.raw + } + } +} + +impl Drop for LockedMapping { + fn drop(&mut self) { + use windows_kernel_sys::ntoskrnl::{IoFreeMdl, MmUnmapLockedPages}; + + unsafe { + MmUnmapLockedPages(self.ptr, self.raw); + IoFreeMdl(self.raw); + } + } +} diff --git a/crates/windows-kernel-rs/src/memory.rs b/crates/windows-kernel-rs/src/memory.rs new file mode 100644 index 0000000..cb5d2e9 --- /dev/null +++ b/crates/windows-kernel-rs/src/memory.rs @@ -0,0 +1,173 @@ +use windows_kernel_sys::base::{ + MM_COPY_ADDRESS, + MM_COPY_MEMORY_PHYSICAL, + MM_COPY_MEMORY_VIRTUAL, + PHYSICAL_ADDRESS, + _MEMORY_CACHING_TYPE as MEMORY_CACHING_TYPE, +}; + +use crate::error::{Error, IntoResult}; + +#[repr(i32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum MemoryCaching { + NonCached = MEMORY_CACHING_TYPE::MmNonCached, + Cached = MEMORY_CACHING_TYPE::MmCached, + WriteCombined = MEMORY_CACHING_TYPE::MmWriteCombined, + #[cfg(feature = "system")] + HardwareCoherentCached = MEMORY_CACHING_TYPE::MmHardwareCoherentCached, + #[cfg(feature = "system")] + NonCachedUnordered = MEMORY_CACHING_TYPE::MmNonCachedUnordered, + #[cfg(feature = "system")] + USWCCached = MEMORY_CACHING_TYPE::MmUSWCCached, + NotMapped = MEMORY_CACHING_TYPE::MmNotMapped, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct PhysicalAddress(u64); + +impl From<u64> for PhysicalAddress { + fn from(value: u64) -> Self { Self(value) } +} + +impl Into<u64> for PhysicalAddress { + fn into(self) -> u64 { self.0 } +} + +impl From<PHYSICAL_ADDRESS> for PhysicalAddress { + fn from(value: PHYSICAL_ADDRESS) -> Self { Self(unsafe { value.QuadPart } as _) } +} + +impl Into<PHYSICAL_ADDRESS> for PhysicalAddress { + fn into(self) -> PHYSICAL_ADDRESS { + let mut addr: PHYSICAL_ADDRESS = unsafe { core::mem::zeroed() }; + + addr.QuadPart = self.0 as _; + + addr + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CopyAddress { + Virtual(*mut core::ffi::c_void), + Physical(PhysicalAddress), +} + +unsafe impl Send for CopyAddress {} +unsafe impl Sync for CopyAddress {} + +impl Into<(u32, MM_COPY_ADDRESS)> for CopyAddress { + fn into(self) -> (u32, MM_COPY_ADDRESS) { + let mut copy_addr: MM_COPY_ADDRESS = unsafe { core::mem::zeroed() }; + + let flags = match self { + CopyAddress::Virtual(addr) => { + copy_addr.__bindgen_anon_1.VirtualAddress = addr as _; + MM_COPY_MEMORY_VIRTUAL + } + CopyAddress::Physical(addr) => { + copy_addr.__bindgen_anon_1.PhysicalAddress = addr.into(); + MM_COPY_MEMORY_PHYSICAL + } + }; + + (flags, copy_addr) + } +} + +pub struct IoMapping { + ptr: *mut core::ffi::c_void, + size: usize, +} + +unsafe impl Send for IoMapping {} +unsafe impl Sync for IoMapping {} + +impl IoMapping { + pub fn new(addr: PhysicalAddress, size: usize, caching: MemoryCaching) -> Result<Self, Error> { + use windows_kernel_sys::ntoskrnl::MmMapIoSpace; + + let ptr = unsafe { MmMapIoSpace(addr.into(), size as _, caching as _) }; + + if ptr.is_null() { + return Err(Error::INVALID_PARAMETER); + } + + Ok(Self { + ptr, + size, + }) + } + + pub fn ptr(&self) -> *mut core::ffi::c_void { self.ptr } + + pub fn size(&self) -> usize { self.size } +} + +impl Drop for IoMapping { + fn drop(&mut self) { + use windows_kernel_sys::ntoskrnl::MmUnmapIoSpace; + + unsafe { + MmUnmapIoSpace(self.ptr, self.size as _); + } + } +} + +#[cfg(feature = "system")] +pub fn get_virtual_for_physical(addr: PhysicalAddress) -> *mut core::ffi::c_void { + use windows_kernel_sys::ntoskrnl::MmGetVirtualForPhysical; + + let virt_addr = unsafe { MmGetVirtualForPhysical(addr.into()) }; + + virt_addr as _ +} + +pub fn read_memory(buffer: &mut [u8], source: CopyAddress) -> Result<usize, Error> { + use windows_kernel_sys::ntoskrnl::MmCopyMemory; + + let (flags, copy_addr) = source.into(); + let mut bytes = 0; + + unsafe { + MmCopyMemory( + buffer.as_mut_ptr() as _, + copy_addr, + buffer.len() as _, + flags, + &mut bytes, + ) + } + .into_result()?; + + Ok(bytes as _) +} + +#[cfg(feature = "system")] +pub fn write_memory(target: CopyAddress, buffer: &[u8]) -> Result<usize, Error> { + use windows_kernel_sys::ntoskrnl::MmCopyMemory; + + let mut copy_addr: MM_COPY_ADDRESS = unsafe { core::mem::zeroed() }; + let mut bytes = 0; + + let target = match target { + CopyAddress::Virtual(addr) => addr, + CopyAddress::Physical(addr) => get_virtual_for_physical(addr), + }; + + copy_addr.__bindgen_anon_1.VirtualAddress = buffer.as_ptr() as _; + + unsafe { + MmCopyMemory( + target as _, + copy_addr, + buffer.len() as _, + MM_COPY_MEMORY_VIRTUAL, + &mut bytes, + ) + } + .into_result()?; + + Ok(bytes as _) +} diff --git a/crates/windows-kernel-rs/src/process.rs b/crates/windows-kernel-rs/src/process.rs new file mode 100644 index 0000000..e0133c0 --- /dev/null +++ b/crates/windows-kernel-rs/src/process.rs @@ -0,0 +1,132 @@ +use bitflags::bitflags; +use windows_kernel_sys::{ + base::{CLIENT_ID, HANDLE, KAPC_STATE, OBJECT_ATTRIBUTES, PEPROCESS}, + ntoskrnl::{ + KeStackAttachProcess, + KeUnstackDetachProcess, + ObDereferenceObject, + ObReferenceObject, + PsGetCurrentProcess, + PsLookupProcessByProcessId, + ZwClose, + ZwOpenProcess, + }, +}; + +use crate::error::{Error, IntoResult}; + +pub type ProcessId = usize; + +#[derive(Clone, Debug)] +pub struct Process { + pub process: PEPROCESS, +} + +impl Process { + pub fn as_ptr(&self) -> PEPROCESS { self.process } + + pub fn current() -> Self { + let process = unsafe { PsGetCurrentProcess() }; + + unsafe { + ObReferenceObject(process as _); + } + + Self { + process, + } + } + + pub fn by_id(process_id: ProcessId) -> Result<Self, Error> { + let mut process = core::ptr::null_mut(); + + unsafe { PsLookupProcessByProcessId(process_id as _, &mut process) }.into_result()?; + + Ok(Self { + process, + }) + } + + pub fn id(&self) -> ProcessId { + let handle = unsafe { windows_kernel_sys::ntoskrnl::PsGetProcessId(self.process) }; + + handle as _ + } + + pub fn attach(&self) -> ProcessAttachment { unsafe { ProcessAttachment::attach(self.process) } } +} + +impl Drop for Process { + fn drop(&mut self) { + unsafe { + ObDereferenceObject(self.process as _); + } + } +} + +pub struct ProcessAttachment { + process: PEPROCESS, + state: KAPC_STATE, +} + +impl ProcessAttachment { + pub unsafe fn attach(process: PEPROCESS) -> Self { + let mut state: KAPC_STATE = core::mem::zeroed(); + + ObReferenceObject(process as _); + KeStackAttachProcess(process, &mut state); + + Self { + process, + state, + } + } +} + +impl Drop for ProcessAttachment { + fn drop(&mut self) { + unsafe { + KeUnstackDetachProcess(&mut self.state); + ObDereferenceObject(self.process as _); + } + } +} + +bitflags! { + pub struct ProcessAccess: u32 { + const ALL_ACCESS = windows_kernel_sys::base::PROCESS_ALL_ACCESS; + } +} + +pub struct ZwProcess { + pub(crate) handle: HANDLE, +} + +impl ZwProcess { + pub fn open(id: ProcessId, access: ProcessAccess) -> Result<Self, Error> { + let mut attrs: OBJECT_ATTRIBUTES = unsafe { core::mem::zeroed() }; + attrs.Length = core::mem::size_of::<OBJECT_ATTRIBUTES>() as u32; + + let mut client_id = CLIENT_ID { + UniqueProcess: id as _, + UniqueThread: core::ptr::null_mut(), + }; + + let mut handle = core::ptr::null_mut(); + + unsafe { ZwOpenProcess(&mut handle, access.bits(), &mut attrs, &mut client_id) } + .into_result()?; + + Ok(Self { + handle, + }) + } +} + +impl Drop for ZwProcess { + fn drop(&mut self) { + unsafe { + ZwClose(self.handle); + } + } +} diff --git a/crates/windows-kernel-rs/src/request.rs b/crates/windows-kernel-rs/src/request.rs new file mode 100644 index 0000000..43a6a83 --- /dev/null +++ b/crates/windows-kernel-rs/src/request.rs @@ -0,0 +1,265 @@ +use core::ops::Deref; + +use bitflags::bitflags; +use windows_kernel_sys::{ + base::{ + IO_NO_INCREMENT, + IO_STACK_LOCATION, + IRP, + STATUS_SUCCESS, + _MM_PAGE_PRIORITY as MM_PAGE_PRIORITY, + }, + ntoskrnl::{ + IoCompleteRequest, + IoGetCurrentIrpStackLocation, + MmGetMdlByteCount, + MmGetMdlByteOffset, + MmGetSystemAddressForMdlSafe, + }, +}; + +use crate::{ + error::Error, + ioctl::{ControlCode, RequiredAccess, TransferMethod}, + user_ptr::UserPtr, +}; + +bitflags! { + pub struct IrpFlags: u32 { + const NOCACHE = windows_kernel_sys::base::IRP_NOCACHE; + const PAGING_IO = windows_kernel_sys::base::IRP_PAGING_IO; + const MOUNT_COMPLETION = windows_kernel_sys::base::IRP_MOUNT_COMPLETION; + const SYNCHRONOUS_API = windows_kernel_sys::base::IRP_SYNCHRONOUS_API; + const ASSOCIATED_IRP = windows_kernel_sys::base::IRP_ASSOCIATED_IRP; + const BUFFERED_IO = windows_kernel_sys::base::IRP_BUFFERED_IO; + const DEALLOCATE_BUFFER = windows_kernel_sys::base::IRP_DEALLOCATE_BUFFER; + const INPUT_OPERATION = windows_kernel_sys::base::IRP_INPUT_OPERATION; + const SYNCHRONOUS_PAGING_IO = windows_kernel_sys::base::IRP_SYNCHRONOUS_PAGING_IO; + const CREATE_OPERATION = windows_kernel_sys::base::IRP_CREATE_OPERATION; + const READ_OPERATION = windows_kernel_sys::base::IRP_READ_OPERATION; + const WRITE_OPERATION = windows_kernel_sys::base::IRP_WRITE_OPERATION; + const CLOSE_OPERATION = windows_kernel_sys::base::IRP_CLOSE_OPERATION; + const DEFER_IO_COMPLETION = windows_kernel_sys::base::IRP_DEFER_IO_COMPLETION; + const OB_QUERY_NAME = windows_kernel_sys::base::IRP_OB_QUERY_NAME; + const HOLD_DEVICE_QUEUE = windows_kernel_sys::base::IRP_HOLD_DEVICE_QUEUE; + const UM_DRIVER_INITIATED_IO = windows_kernel_sys::base::IRP_UM_DRIVER_INITIATED_IO; + } +} + +pub struct IoRequest { + irp: *mut IRP, +} + +impl IoRequest { + pub unsafe fn from_raw(irp: *mut IRP) -> Self { + Self { + irp, + } + } + + pub fn irp(&self) -> &IRP { unsafe { &*self.irp } } + + pub fn irp_mut(&self) -> &mut IRP { unsafe { &mut *self.irp } } + + pub fn flags(&self) -> IrpFlags { + IrpFlags::from_bits(self.irp().Flags).unwrap_or(IrpFlags::empty()) + } + + pub fn stack_location(&self) -> &IO_STACK_LOCATION { + unsafe { &*IoGetCurrentIrpStackLocation(self.irp_mut()) } + } + + pub fn major(&self) -> u8 { self.stack_location().MajorFunction } + + pub(crate) fn complete(&self, value: Result<u32, Error>) { + let irp = self.irp_mut(); + + match value { + Ok(value) => { + irp.IoStatus.Information = value as _; + irp.IoStatus.__bindgen_anon_1.Status = STATUS_SUCCESS; + } + Err(error) => { + irp.IoStatus.Information = 0; + irp.IoStatus.__bindgen_anon_1.Status = error.to_ntstatus(); + } + } + + unsafe { + IoCompleteRequest(irp, IO_NO_INCREMENT as _); + } + } +} + +pub struct ReadRequest { + pub(crate) inner: IoRequest, +} + +impl Deref for ReadRequest { + type Target = IoRequest; + + fn deref(&self) -> &Self::Target { &self.inner } +} + +impl ReadRequest { + pub fn user_ptr(&self) -> UserPtr { + let stack_location = self.stack_location(); + let irp = self.irp(); + + let (ptr, size) = if !irp.MdlAddress.is_null() { + let ptr = unsafe { + MmGetSystemAddressForMdlSafe(irp.MdlAddress, MM_PAGE_PRIORITY::HighPagePriority as _) + }; + + let size = unsafe { MmGetMdlByteCount(irp.MdlAddress) } as usize; + + (ptr, size) + } else if !unsafe { irp.AssociatedIrp.SystemBuffer }.is_null() { + let ptr = unsafe { irp.AssociatedIrp.SystemBuffer }; + let size = unsafe { stack_location.Parameters.Read }.Length as usize; + + (ptr, size) + } else { + (core::ptr::null_mut(), 0) + }; + + unsafe { UserPtr::new_buffered(ptr, 0, size) } + } + + pub fn offset(&self) -> i64 { + let stack_location = self.stack_location(); + let irp = self.irp(); + + if !irp.MdlAddress.is_null() { + (unsafe { MmGetMdlByteOffset(irp.MdlAddress) }) as i64 + } else if !unsafe { irp.AssociatedIrp.SystemBuffer }.is_null() { + unsafe { stack_location.Parameters.Read.ByteOffset.QuadPart } + } else { + 0 + } + } +} + +impl Into<IoRequest> for ReadRequest { + fn into(self) -> IoRequest { self.inner } +} + +pub struct WriteRequest { + pub(crate) inner: IoRequest, +} + +impl Deref for WriteRequest { + type Target = IoRequest; + + fn deref(&self) -> &Self::Target { &self.inner } +} + +impl WriteRequest { + pub fn user_ptr(&self) -> UserPtr { + let stack_location = self.stack_location(); + let irp = self.irp(); + + let (ptr, size) = if !irp.MdlAddress.is_null() { + let ptr = unsafe { + MmGetSystemAddressForMdlSafe(irp.MdlAddress, MM_PAGE_PRIORITY::HighPagePriority as _) + }; + + let size = unsafe { MmGetMdlByteCount(irp.MdlAddress) } as usize; + + (ptr, size) + } else if !unsafe { irp.AssociatedIrp.SystemBuffer }.is_null() { + let ptr = unsafe { irp.AssociatedIrp.SystemBuffer }; + let size = unsafe { stack_location.Parameters.Write }.Length as usize; + + (ptr, size) + } else { + (core::ptr::null_mut(), 0) + }; + + unsafe { UserPtr::new_buffered(ptr, size, 0) } + } + + pub fn offset(&self) -> i64 { + let stack_location = self.stack_location(); + let irp = self.irp(); + + if !irp.MdlAddress.is_null() { + (unsafe { MmGetMdlByteOffset(irp.MdlAddress) }) as i64 + } else if !unsafe { irp.AssociatedIrp.SystemBuffer }.is_null() { + unsafe { stack_location.Parameters.Write.ByteOffset.QuadPart } + } else { + 0 + } + } +} + +impl Into<IoRequest> for WriteRequest { + fn into(self) -> IoRequest { self.inner } +} + +pub struct IoControlRequest { + pub(crate) inner: IoRequest, +} + +impl Deref for IoControlRequest { + type Target = IoRequest; + + fn deref(&self) -> &Self::Target { &self.inner } +} + +impl IoControlRequest { + pub fn control_code(&self) -> ControlCode { + let stack_location = self.stack_location(); + + unsafe { + stack_location + .Parameters + .DeviceIoControl + .IoControlCode + .into() + } + } + + pub fn function(&self) -> (RequiredAccess, u32) { + let code = self.control_code(); + + (code.required_access(), code.number()) + } + + pub fn user_ptr(&self) -> UserPtr { + let stack_location = self.stack_location(); + let irp = self.irp(); + + let system_buffer = unsafe { irp.AssociatedIrp.SystemBuffer }; + + let mdl_address = if !irp.MdlAddress.is_null() { + unsafe { + MmGetSystemAddressForMdlSafe(irp.MdlAddress, MM_PAGE_PRIORITY::HighPagePriority as _) + } + } else { + core::ptr::null_mut() + }; + + let input_size = + unsafe { stack_location.Parameters.DeviceIoControl.InputBufferLength } as usize; + let output_size = + unsafe { stack_location.Parameters.DeviceIoControl.OutputBufferLength } as usize; + + match self.control_code().transfer_method() { + TransferMethod::Buffered => unsafe { + UserPtr::new_buffered(system_buffer, input_size, output_size) + }, + TransferMethod::InputDirect => unsafe { + UserPtr::new_direct(mdl_address, system_buffer, output_size, input_size) + }, + TransferMethod::OutputDirect => unsafe { + UserPtr::new_direct(system_buffer, mdl_address, input_size, output_size) + }, + TransferMethod::Neither => unsafe { UserPtr::new_neither() }, + } + } +} + +impl Into<IoRequest> for IoControlRequest { + fn into(self) -> IoRequest { self.inner } +} diff --git a/crates/windows-kernel-rs/src/section.rs b/crates/windows-kernel-rs/src/section.rs new file mode 100644 index 0000000..f80b8b7 --- /dev/null +++ b/crates/windows-kernel-rs/src/section.rs @@ -0,0 +1,163 @@ +use bitflags::bitflags; +use widestring::U16CString; +use windows_kernel_sys::{ + base::{HANDLE, LARGE_INTEGER, OBJECT_ATTRIBUTES}, + ntoskrnl::{ZwClose, ZwMapViewOfSection, ZwOpenSection, ZwUnmapViewOfSection}, +}; + +use crate::{ + error::{Error, IntoResult}, + process::ZwProcess, + string::create_unicode_string, +}; + +bitflags! { + pub struct AllocationFlags: u32 { + const RESERVE = windows_kernel_sys::base::MEM_RESERVE; + const LARGE_PAGES = windows_kernel_sys::base::MEM_LARGE_PAGES; + const TOP_DOWN = windows_kernel_sys::base::MEM_TOP_DOWN; + } +} + +bitflags! { + pub struct ProtectFlags: u32 { + const READ_WRITE = windows_kernel_sys::base::PAGE_READWRITE; + } +} + +bitflags! { + pub struct SectionAccess: u32 { + const EXTEND_SIZE = windows_kernel_sys::base::SECTION_EXTEND_SIZE; + const MAP_EXECUTE = windows_kernel_sys::base::SECTION_MAP_EXECUTE; + const MAP_READ = windows_kernel_sys::base::SECTION_MAP_READ; + const MAP_WRITE = windows_kernel_sys::base::SECTION_MAP_WRITE; + const QUERY = windows_kernel_sys::base::SECTION_QUERY; + const ALL_ACCESS = windows_kernel_sys::base::SECTION_ALL_ACCESS; + } +} + +bitflags! { + pub struct ObjectFlags: u32 { + const CASE_INSENSITIVE = windows_kernel_sys::base::OBJ_CASE_INSENSITIVE; + const KERNEL_HANDLE = windows_kernel_sys::base::OBJ_KERNEL_HANDLE; + } +} + +#[repr(i32)] +pub enum SectionInherit { + ViewShare = windows_kernel_sys::base::_SECTION_INHERIT::ViewShare, + ViewUnmap = windows_kernel_sys::base::_SECTION_INHERIT::ViewUnmap, +} + +pub enum BaseAddress { + Desired(*mut core::ffi::c_void), + ZeroBits(usize), +} + +pub struct Section { + handle: HANDLE, +} + +unsafe impl Send for Section {} +unsafe impl Sync for Section {} + +impl Section { + pub fn open(path: &str, obj_flags: ObjectFlags, access: SectionAccess) -> Result<Self, Error> { + let name = U16CString::from_str(path).unwrap(); + let mut name = create_unicode_string(name.as_slice()); + + let mut attrs = OBJECT_ATTRIBUTES { + Length: core::mem::size_of::<OBJECT_ATTRIBUTES>() as u32, + RootDirectory: core::ptr::null_mut(), + ObjectName: &mut name, + Attributes: obj_flags.bits(), + SecurityDescriptor: core::ptr::null_mut(), + SecurityQualityOfService: core::ptr::null_mut(), + }; + + let mut handle: HANDLE = core::ptr::null_mut(); + + unsafe { ZwOpenSection(&mut handle, access.bits(), &mut attrs) }.into_result()?; + + Ok(Self { + handle, + }) + } + + pub fn map_view( + &mut self, + process: ZwProcess, + base_address: BaseAddress, + commit_size: usize, + offset: Option<u64>, + view_size: usize, + inherit: SectionInherit, + allocation: AllocationFlags, + protection: ProtectFlags, + ) -> Result<SectionView, Error> { + let (mut base_address, zero_bits) = match base_address { + BaseAddress::Desired(ptr) => (ptr, 0), + BaseAddress::ZeroBits(bits) => (core::ptr::null_mut(), bits), + }; + + let mut offset = offset.map(|value| { + let mut offset: LARGE_INTEGER = unsafe { core::mem::zeroed() }; + offset.QuadPart = value as _; + offset + }); + + let mut size: u64 = view_size as _; + + unsafe { + ZwMapViewOfSection( + self.handle, + process.handle, + &mut base_address, + zero_bits as _, + commit_size as _, + match offset { + Some(ref mut offset) => offset as _, + _ => core::ptr::null_mut(), + }, + &mut size, + inherit as _, + allocation.bits(), + protection.bits(), + ) + } + .into_result()?; + + Ok(SectionView { + process, + address: base_address, + }) + } +} + +impl Drop for Section { + fn drop(&mut self) { + unsafe { + ZwClose(self.handle); + } + } +} + +pub struct SectionView { + process: ZwProcess, + address: *mut core::ffi::c_void, +} + +unsafe impl Send for SectionView {} +unsafe impl Sync for SectionView {} + +impl SectionView { + pub fn address(&self) -> *mut core::ffi::c_void { self.address } +} + +impl Drop for SectionView { + fn drop(&mut self) { + unsafe { + ZwUnmapViewOfSection(self.process.handle, self.address); + } + } +} diff --git a/crates/windows-kernel-rs/src/string.rs b/crates/windows-kernel-rs/src/string.rs new file mode 100644 index 0000000..71ad169 --- /dev/null +++ b/crates/windows-kernel-rs/src/string.rs @@ -0,0 +1,17 @@ +use windows_kernel_sys::base::UNICODE_STRING; + +pub fn create_unicode_string(s: &[u16]) -> UNICODE_STRING { + let len = s.len(); + + let n = if len > 0 && s[len - 1] == 0 { + len - 1 + } else { + len + }; + + UNICODE_STRING { + Length: (n * 2) as u16, + MaximumLength: (len * 2) as u16, + Buffer: s.as_ptr() as _, + } +} diff --git a/crates/windows-kernel-rs/src/symbolic_link.rs b/crates/windows-kernel-rs/src/symbolic_link.rs new file mode 100644 index 0000000..8c30f6b --- /dev/null +++ b/crates/windows-kernel-rs/src/symbolic_link.rs @@ -0,0 +1,39 @@ +use widestring::U16CString; +use windows_kernel_sys::ntoskrnl::{IoCreateSymbolicLink, IoDeleteSymbolicLink}; + +use crate::{ + error::{Error, IntoResult}, + string::create_unicode_string, +}; + +pub struct SymbolicLink { + name: U16CString, +} + +impl SymbolicLink { + pub fn new(name: &str, target: &str) -> Result<Self, Error> { + // Convert the name to UTF-16 and then create a UNICODE_STRING. + let name = U16CString::from_str(name).unwrap(); + let mut name_ptr = create_unicode_string(name.as_slice()); + + // Convert the target to UTF-16 and then create a UNICODE_STRING. + let target = U16CString::from_str(target).unwrap(); + let mut target_ptr = create_unicode_string(target.as_slice()); + + unsafe { IoCreateSymbolicLink(&mut name_ptr, &mut target_ptr) }.into_result()?; + + Ok(Self { + name, + }) + } +} + +impl Drop for SymbolicLink { + fn drop(&mut self) { + let mut name_ptr = create_unicode_string(self.name.as_slice()); + + unsafe { + IoDeleteSymbolicLink(&mut name_ptr); + } + } +} diff --git a/crates/windows-kernel-rs/src/sync/fast_mutex.rs b/crates/windows-kernel-rs/src/sync/fast_mutex.rs new file mode 100644 index 0000000..9a82524 --- /dev/null +++ b/crates/windows-kernel-rs/src/sync/fast_mutex.rs @@ -0,0 +1,137 @@ +use alloc::boxed::Box; +use core::{ + cell::UnsafeCell, + ops::{Deref, DerefMut}, +}; + +use windows_kernel_sys::{ + base::FAST_MUTEX, + ntoskrnl::{ + ExAcquireFastMutex, + ExInitializeFastMutex, + ExReleaseFastMutex, + ExTryToAcquireFastMutex, + }, +}; + +/// A mutual exclusion primitive useful for protecting shared data. +/// +/// This mutex will block threads waiting for the lock to become available. The +/// mutex can also be statically initialized or created via a [`new`] +/// constructor. Each mutex has a type parameter which represents the data that +/// it is protecting. The data can only be accessed through the RAII +/// guards returned from [`lock`] and [`try_lock`], which guarantees that the +/// data is only ever accessed when the mutex is locked. +/// +/// [`new`]: FastMutex::new +/// [`lock`]: FastMutex::lock +/// [`try_lock`]: FastMutex::try_lock +pub struct FastMutex<T: ?Sized> { + pub(crate) lock: Box<FAST_MUTEX>, + pub(crate) data: UnsafeCell<T>, +} + +unsafe impl<T> Send for FastMutex<T> {} +unsafe impl<T> Sync for FastMutex<T> {} + +impl<T> FastMutex<T> { + /// Creates a new mutex in an unlocked state ready for use. + pub fn new(data: T) -> Self { + let mut lock: Box<FAST_MUTEX> = Box::new(unsafe { core::mem::zeroed() }); + + unsafe { ExInitializeFastMutex(&mut *lock) }; + + Self { + lock, + data: UnsafeCell::new(data), + } + } + + /// Consumes this `FastMutex`, returning the underlying data. + #[inline] + pub fn into_inner(self) -> T { + let Self { + data, .. + } = self; + data.into_inner() + } + + /// Attempts to acquire this lock. + /// + /// If the lock could not be acquired at this time, then `None` is returned. + /// Otherwise, an RAII guard is returned. The lock will be unlocked when the + /// guard is dropped. + /// + /// This function does not block. + #[inline] + pub fn try_lock(&mut self) -> Option<FastMutexGuard<T>> { + let status = unsafe { ExTryToAcquireFastMutex(&mut *self.lock) } != 0; + + match status { + true => + Some(FastMutexGuard { + lock: &mut self.lock, + data: unsafe { &mut *self.data.get() }, + }), + _ => None, + } + } + + /// Acquires a mutex, blocking the current thread until it is able to do so. + /// + /// This function will block the local thread until it is available to acquire + /// the mutex. Upon returning, the thread is the only thread with the lock + /// held. An RAII guard is returned to allow scoped unlock of the lock. When + /// the guard goes out of scope, the mutex will be unlocked. + /// + /// The underlying function does not allow for recursion. If the thread + /// already holds the lock and tries to lock the mutex again, this function + /// will return `None` instead. + #[inline] + pub fn lock(&mut self) -> Option<FastMutexGuard<T>> { + unsafe { ExAcquireFastMutex(&mut *self.lock) }; + + Some(FastMutexGuard { + lock: &mut self.lock, + data: unsafe { &mut *self.data.get() }, + }) + } +} + +impl<T: ?Sized + Default> Default for FastMutex<T> { + fn default() -> Self { Self::new(T::default()) } +} + +impl<T> From<T> for FastMutex<T> { + fn from(data: T) -> Self { Self::new(data) } +} + +/// An RAII implementation of a "scoped lock" of a mutex. When this structure is +/// dropped (falls out of scope), the lock will be unlocked. +/// +/// The data protected by the mutex can be accessed through this guard via its +/// [`Deref`] and [`DerefMut`] implementations. +/// +/// This structure is created by the [`lock`] and [`try_lock`] methods on +/// [`FastMutex`]. +/// +/// [`lock`]: FastMutex::lock +/// [`try_lock`]: FastMutex::try_lock +pub struct FastMutexGuard<'a, T: 'a + ?Sized> { + pub(crate) lock: &'a mut FAST_MUTEX, + pub(crate) data: &'a mut T, +} + +impl<'a, T: ?Sized> Drop for FastMutexGuard<'a, T> { + fn drop(&mut self) { unsafe { ExReleaseFastMutex(&mut *self.lock) }; } +} + +impl<'a, T: ?Sized> Deref for FastMutexGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &T { self.data } +} + +impl<'a, T: ?Sized> DerefMut for FastMutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { self.data } +} diff --git a/crates/windows-kernel-rs/src/sync/mod.rs b/crates/windows-kernel-rs/src/sync/mod.rs new file mode 100644 index 0000000..b024c55 --- /dev/null +++ b/crates/windows-kernel-rs/src/sync/mod.rs @@ -0,0 +1,4 @@ +pub mod fast_mutex; +pub mod push_lock; + +pub use self::{fast_mutex::FastMutex as Mutex, push_lock::PushLock as RwLock}; diff --git a/crates/windows-kernel-rs/src/sync/push_lock.rs b/crates/windows-kernel-rs/src/sync/push_lock.rs new file mode 100644 index 0000000..71b016a --- /dev/null +++ b/crates/windows-kernel-rs/src/sync/push_lock.rs @@ -0,0 +1,185 @@ +use alloc::boxed::Box; +use core::{ + cell::UnsafeCell, + ops::{Deref, DerefMut}, +}; + +use windows_kernel_sys::{ + base::EX_PUSH_LOCK, + ntoskrnl::{ + ExAcquirePushLockExclusive, + ExAcquirePushLockShared, + ExInitializePushLock, + ExReleasePushLockExclusive, + ExReleasePushLockShared, + KeEnterCriticalRegion, + KeLeaveCriticalRegion, + }, +}; + +/// A [`PushLock`] is an efficient implementation of a reader-writer lock that +/// can be stored both in paged and non-paged memory. +/// +/// This type of lock allows a number of readers or at most one writer at any +/// point in time. The write portion of this lock typically allows modifications +/// of the underlying data (exclusive access) and the read portion of this lock +/// typically allows for read-only access (shared access). +/// +/// In comparison, a [`FastMutex`] does not distinguish between readers or +/// writers that acquire the lock, therefore blocking any threads waiting for +/// the lock to become available. A [`PushLock`] will allow any number of +/// readers to acquire the lock as long as a writer is not holding the lock. +/// +/// The priority policy is such that a thread trying to acquire the [`PushLock`] +/// for exclusive access will be prioritized over threads trying to acquire the +/// [`PushLock`] for shared access. More specifically, if a thread cannot lock +/// the [`PushLock`] for exclusive access immediately, it will wait for the +/// thread(s) that currently holds the lock to release the lock. If another +/// thread tries to acquire the [`PushLock`] for shared access while a thread is +/// waiting to acquire the lock for exclusive access, it will yield to the +/// thread(s) trying to acquire the [`PushLock`] for exclusive access, even in +/// the event that the [`PushLock`] is acquired for shared access. +/// +/// [`FastMutex`]: crate::fast_mutex::FastMutex +pub struct PushLock<T: ?Sized> { + pub(crate) lock: Box<EX_PUSH_LOCK>, + pub(crate) data: UnsafeCell<T>, +} + +unsafe impl<T> Send for PushLock<T> {} +unsafe impl<T> Sync for PushLock<T> {} + +impl<T> PushLock<T> { + /// Creates new instance of [`PushLock<T>`] that is unlocked. + pub fn new(data: T) -> Self { + let mut lock: Box<EX_PUSH_LOCK> = Box::new(0); + + unsafe { ExInitializePushLock(&mut *lock) }; + + Self { + lock, + data: UnsafeCell::new(data), + } + } + + /// Consumes this [`PushLock`], returning the underlying data. + #[inline] + pub fn into_inner(self) -> T { + let Self { + data, .. + } = self; + data.into_inner() + } + + /// Locks this [`PushLock`] with shared read access, blocking the current + /// thread until it can be acquired. + /// + /// The calling thread will be blocked until there are no more writers which + /// hold the lock. There may be other readers currently inside the lock when + /// this method returns. + /// + /// This function will yield to threads waiting to acquire the [`PushLock`] + /// for exclusive access, even in the event that the [`PushLock`] is + /// currently held by one or more threads for shared access. + /// + /// While the underlying function does allow for recursion, this atomically + /// increments a shared reader counter. Since dropping the RAII guard + /// releases the lock by atomically decrementing this shared counter, it + /// will eventually reach zero once all RAII guards have been dropped. + #[inline] + pub fn read(&mut self) -> Option<PushLockReadGuard<T>> { + unsafe { KeEnterCriticalRegion() }; + + unsafe { ExAcquirePushLockShared(&mut *self.lock) }; + + Some(PushLockReadGuard { + lock: &mut self.lock, + data: unsafe { &mut *self.data.get() }, + }) + } + + /// Locks this [`PushLock`] with exclusive write access, blocking the current + /// thread until it can be acquired. + /// + /// This function will not return while other writers or other readers + /// currently have access to the lock. + /// + /// Returns an RAII guard which will drop the write access of this + /// [`PushLock`] when dropped. + /// + /// This thread will take priority over any threads that are trying to acquire + /// the lock for shared access but that do not currently hold the lock for + /// shared access. + /// + /// The underlying function does not allow for recursion, which ensures + /// correct behavior. + #[inline] + pub fn write(&mut self) -> Option<PushLockWriteGuard<T>> { + unsafe { KeEnterCriticalRegion() }; + + unsafe { ExAcquirePushLockExclusive(&mut *self.lock) }; + + Some(PushLockWriteGuard { + lock: &mut self.lock, + data: unsafe { &mut *self.data.get() }, + }) + } +} + +/// RAII structure used to release the shared read access of a lock when +/// dropped. +/// +/// This structure is created by the [`read`] and [`try_read`] methods on +/// [`PushLock`] +/// +/// [`read`]: PushLock::read +/// [`try_read`]: PushLock::try_read +pub struct PushLockReadGuard<'a, T: 'a + ?Sized> { + pub(crate) lock: &'a mut EX_PUSH_LOCK, + pub(crate) data: &'a T, +} + +impl<'a, T: ?Sized> Drop for PushLockReadGuard<'a, T> { + fn drop(&mut self) { + unsafe { ExReleasePushLockShared(&mut *self.lock) }; + + unsafe { KeLeaveCriticalRegion() }; + } +} + +impl<'a, T: ?Sized> Deref for PushLockReadGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &T { self.data } +} + +/// RAII structure used to release the exclusive write access of a lock when +/// dropped. +/// +/// This structure is created by the [`write`] and [`try_write`] methods on +/// [`PushLock`] +/// +/// [`write`]: PushLock::write +/// [`try_write`]: PushLock::try_write +pub struct PushLockWriteGuard<'a, T: 'a + ?Sized> { + pub(crate) lock: &'a mut EX_PUSH_LOCK, + pub(crate) data: &'a mut T, +} + +impl<'a, T: ?Sized> Drop for PushLockWriteGuard<'a, T> { + fn drop(&mut self) { + unsafe { ExReleasePushLockExclusive(&mut *self.lock) }; + + unsafe { KeLeaveCriticalRegion() }; + } +} + +impl<'a, T: ?Sized> Deref for PushLockWriteGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &T { self.data } +} + +impl<'a, T: ?Sized> DerefMut for PushLockWriteGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { self.data } +} diff --git a/crates/windows-kernel-rs/src/user_ptr.rs b/crates/windows-kernel-rs/src/user_ptr.rs new file mode 100644 index 0000000..54dcb58 --- /dev/null +++ b/crates/windows-kernel-rs/src/user_ptr.rs @@ -0,0 +1,172 @@ +use crate::error::Error; + +pub enum UserPtr { + Buffered { + ptr: *mut cty::c_void, + read_size: usize, + write_size: usize, + }, + Direct { + read_ptr: *const cty::c_void, + write_ptr: *mut cty::c_void, + read_size: usize, + write_size: usize, + }, + Neither, +} + +impl UserPtr { + pub unsafe fn new_buffered(ptr: *mut cty::c_void, read_size: usize, write_size: usize) -> Self { + Self::Buffered { + ptr, + read_size, + write_size, + } + } + + pub unsafe fn new_direct( + read_ptr: *const cty::c_void, + write_ptr: *mut cty::c_void, + read_size: usize, + write_size: usize, + ) -> Self { + Self::Direct { + read_ptr, + write_ptr, + read_size, + write_size, + } + } + + pub unsafe fn new_neither() -> Self { Self::Neither } + + pub fn read_size(&self) -> usize { + match self { + Self::Buffered { + read_size, .. + } => *read_size, + Self::Direct { + read_size, .. + } => *read_size, + Self::Neither => 0, + } + } + + pub fn write_size(&self) -> usize { + match self { + Self::Buffered { + write_size, .. + } => *write_size, + Self::Direct { + write_size, .. + } => *write_size, + Self::Neither => 0, + } + } + + pub fn as_slice(&self) -> &[u8] { + let (ptr, size) = match self { + Self::Buffered { + ptr, + read_size, + .. + } => (*ptr as _, *read_size), + Self::Direct { + read_ptr, + read_size, + .. + } => (*read_ptr, *read_size), + Self::Neither => (core::ptr::null(), 0), + }; + + if ptr.is_null() || size == 0 { + &[] + } else { + unsafe { core::slice::from_raw_parts(ptr as *const u8, size) } + } + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + let (ptr, size) = match self { + Self::Buffered { + ptr, + write_size, + .. + } => (*ptr, *write_size), + Self::Direct { + write_ptr, + write_size, + .. + } => (*write_ptr, *write_size), + Self::Neither => (core::ptr::null_mut(), 0), + }; + + if ptr.is_null() || size == 0 { + &mut [] + } else { + unsafe { core::slice::from_raw_parts_mut(ptr as *mut u8, size) } + } + } + + pub fn read<T: Copy + Default>(&self) -> Result<T, Error> { + let (ptr, size) = match self { + Self::Buffered { + ptr, + read_size, + .. + } => (*ptr as _, *read_size), + Self::Direct { + read_ptr, + read_size, + .. + } => (*read_ptr, *read_size), + Self::Neither => (core::ptr::null(), 0), + }; + + if ptr.is_null() || size == 0 { + return Err(Error::INVALID_PARAMETER); + } + + if core::mem::size_of::<T>() > size { + return Err(Error::INVALID_USER_BUFFER); + } + + let mut obj = T::default(); + + unsafe { + core::ptr::copy_nonoverlapping(ptr as _, &mut obj, 1); + } + + Ok(obj) + } + + pub fn write<T: Copy>(&mut self, obj: &T) -> Result<(), Error> { + let (ptr, size) = match self { + Self::Buffered { + ptr, + write_size, + .. + } => (*ptr, *write_size), + Self::Direct { + write_ptr, + write_size, + .. + } => (*write_ptr, *write_size), + Self::Neither => (core::ptr::null_mut(), 0), + }; + + if ptr.is_null() || size == 0 { + return Err(Error::INVALID_PARAMETER); + } + + if core::mem::size_of::<T>() > size { + return Err(Error::INVALID_USER_BUFFER); + } + + unsafe { + core::ptr::copy_nonoverlapping(obj, ptr as _, 1); + } + + Ok(()) + } +} diff --git a/crates/windows-kernel-rs/src/version.rs b/crates/windows-kernel-rs/src/version.rs new file mode 100644 index 0000000..3f46a87 --- /dev/null +++ b/crates/windows-kernel-rs/src/version.rs @@ -0,0 +1,35 @@ +//! This module provides utilities to query information about the version of +//! Microsoft Windows. + +use windows_kernel_sys::{base::RTL_OSVERSIONINFOW, ntoskrnl::RtlGetVersion}; + +use crate::error::{Error, IntoResult}; + +/// Represents version information for Microsoft Windows. +pub struct VersionInfo { + version_info: RTL_OSVERSIONINFOW, +} + +impl VersionInfo { + /// Uses [`RtlGetVersion`] to query the version info for Microsoft Windows. + pub fn query() -> Result<Self, Error> { + let mut version_info: RTL_OSVERSIONINFOW = unsafe { core::mem::zeroed() }; + + version_info.dwOSVersionInfoSize = core::mem::size_of::<RTL_OSVERSIONINFOW>() as u32; + + unsafe { RtlGetVersion(&mut version_info) }.into_result()?; + + Ok(Self { + version_info, + }) + } + + /// Retrieves the major version of Microsoft Windows. + pub fn major(&self) -> u32 { self.version_info.dwMajorVersion } + + /// Retrieves the minor version of Microsoft Windows. + pub fn minor(&self) -> u32 { self.version_info.dwMinorVersion } + + /// Retrieves the build number of Microsoft Windows. + pub fn build_number(&self) -> u32 { self.version_info.dwBuildNumber } +} |