summaryrefslogtreecommitdiff
path: root/crates/windows-kernel-rs/src/driver.rs
blob: 484c012ccb676789cbbfafbb897f3900332cdf22 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
use alloc::boxed::Box;

use widestring::U16CString;
use windows_kernel_sys::{base::DRIVER_OBJECT, ntoskrnl::IoCreateDevice};

use crate::{
  device::{
    Access,
    Device,
    DeviceDoFlags,
    DeviceExtension,
    DeviceFlags,
    DeviceOperations,
    DeviceOperationsVtable,
    DeviceType,
  },
  error::{Error, IntoResult},
  string::create_unicode_string,
};

pub struct Driver {
  pub(crate) raw: *mut DRIVER_OBJECT,
}

impl Driver {
  /// # Safety
  /// `unsafe`
  pub unsafe fn from_raw(raw: *mut DRIVER_OBJECT) -> Self {
    Self {
      raw,
    }
  }

  /// # Safety
  /// `unsafe`
  pub unsafe fn as_raw(&self) -> *const DRIVER_OBJECT { self.raw as _ }

  /// # Safety
  /// `unsafe`
  pub unsafe fn as_raw_mut(&mut self) -> *mut DRIVER_OBJECT { self.raw as _ }

  pub fn create_device<T>(
    &mut self,
    name: &str,
    device_type: DeviceType,
    device_flags: DeviceFlags,
    device_do_flags: DeviceDoFlags,
    access: Access,
    data: T,
  ) -> Result<Device, Error>
  where
    T: DeviceOperations,
  {
    // Box the data.
    let data = Box::new(data);

    // Convert the name to UTF-16 and then create a UNICODE_STRING.
    let name = U16CString::from_str(name).unwrap();
    let mut name = create_unicode_string(name.as_slice());

    // Create the device.
    let mut device = core::ptr::null_mut();

    unsafe {
      IoCreateDevice(
        self.raw,
        core::mem::size_of::<DeviceExtension>() as u32,
        &mut name,
        device_type.into(),
        device_flags.bits(),
        access.is_exclusive() as _,
        &mut device,
      )
    }
    .into_result()?;

    unsafe {
      (*device).Flags |= device_do_flags.bits();
    }

    let device = unsafe { Device::from_raw(device) };

    // Store the boxed data and vtable.
    let extension = device.extension_mut();
    extension.device_type = device_type;
    extension.vtable = &DeviceOperationsVtable::<T>::VTABLE;
    extension.data = Box::into_raw(data) as *mut cty::c_void;

    Ok(device)
  }
}