diff options
| author | a1xd <[email protected]> | 2021-09-24 02:04:43 -0400 |
|---|---|---|
| committer | GitHub <[email protected]> | 2021-09-24 02:04:43 -0400 |
| commit | 2896b8a09ce42e965705c58593b8738adc454f7f (patch) | |
| tree | 71e4d0cff60b5a1ad11427d78e1f8c7b775e5690 /driver/driver.cpp | |
| parent | Merge pull request #107 from a1xd/1.5.0-fix (diff) | |
| parent | make note clearer (diff) | |
| download | rawaccel-1.6.0.tar.xz rawaccel-1.6.0.zip | |
v1.6
Diffstat (limited to 'driver/driver.cpp')
| -rw-r--r-- | driver/driver.cpp | 340 |
1 files changed, 282 insertions, 58 deletions
diff --git a/driver/driver.cpp b/driver/driver.cpp index feace77..febf64f 100644 --- a/driver/driver.cpp +++ b/driver/driver.cpp @@ -7,18 +7,26 @@ #ifdef ALLOC_PRAGMA #pragma alloc_text (INIT, DriverEntry) +#pragma alloc_text (INIT, RawaccelInit) +#pragma alloc_text (INIT, CreateControlDevice) #pragma alloc_text (PAGE, EvtDeviceAdd) #pragma alloc_text (PAGE, EvtIoInternalDeviceControl) #pragma alloc_text (PAGE, RawaccelControl) +#pragma alloc_text (PAGE, DeviceCleanup) +#pragma alloc_text (PAGE, DeviceSetup) +#pragma alloc_text (PAGE, WriteDelay) #endif using milliseconds = double; struct { - ra::settings args; + bool initialized; + WDFCOLLECTION device_collection; + WDFWAITLOCK collection_lock; + ra::io_base base_data; + ra::modifier_settings* modifier_data; + ra::device_settings* device_data; milliseconds tick_interval; - vec2<ra::accel_invoker> invokers; - ra::mouse_modifier modifier; } global = {}; extern "C" PULONG InitSafeBootMode; @@ -58,26 +66,39 @@ Arguments: auto num_packets = InputDataEnd - InputDataStart; if (num_packets > 0 && - !(InputDataStart->Flags & MOUSE_MOVE_ABSOLUTE) && - (global.args.device_id[0] == 0 || - bool(wcsncmp(devExt->dev_id, global.args.device_id, ra::MAX_DEV_ID_LEN)) == - global.args.ignore)) { - counter_t now = KeQueryPerformanceCounter(NULL).QuadPart; - counter_t ticks = now - devExt->counter; - devExt->counter = now; - milliseconds raw_elapsed = ticks * global.tick_interval; - milliseconds time = ra::clampsd(raw_elapsed / num_packets, - global.args.time_min, - global.args.time_max); + !(InputDataStart->Flags & MOUSE_MOVE_ABSOLUTE) && + devExt->enable) { + + milliseconds time; + if (devExt->keep_time) { + counter_t now = KeQueryPerformanceCounter(NULL).QuadPart; + counter_t ticks = now - devExt->counter; + devExt->counter = now; + milliseconds raw = ticks * global.tick_interval / num_packets; + time = ra::clampsd(raw, devExt->clamp.min, devExt->clamp.max); + } + else { + time = devExt->clamp.min; + } + auto it = InputDataStart; do { + if (devExt->set_extra_info) { + union { + short input[2]; + ULONG data; + } u = { short(it->LastX), short(it->LastY) }; + + it->ExtraInformation = u.data; + } + if (it->LastX || it->LastY) { vec2d input = { static_cast<double>(it->LastX), static_cast<double>(it->LastY) }; - global.modifier.modify(input, global.invokers, time); + devExt->mod.modify(input, devExt->mod_settings, devExt->dpi_factor, time); double carried_result_x = input.x + devExt->carry.x; double carried_result_y = input.y + devExt->carry.y; @@ -89,8 +110,8 @@ Arguments: double carry_y = carried_result_y - out_y; if (!ra::infnan(carry_x + carry_y)) { - devExt->carry.x = carried_result_x - out_x; - devExt->carry.y = carried_result_y - out_y; + devExt->carry.x = carry_x; + devExt->carry.y = carry_y; it->LastX = out_x; it->LastY = out_y; } @@ -139,7 +160,7 @@ Return Value: { NTSTATUS status; void* buffer; - + size_t buffer_length; size_t bytes_out = 0; UNREFERENCED_PARAMETER(Queue); @@ -149,49 +170,126 @@ Return Value: DebugPrint(("Ioctl received into filter control object.\n")); + if (!global.initialized) { + WdfRequestCompleteWithInformation(Request, STATUS_CANCELLED, 0); + return; + } + + const auto SIZEOF_BASE = sizeof(ra::io_base); + switch (IoControlCode) { - case RA_READ: + case ra::READ: status = WdfRequestRetrieveOutputBuffer( Request, - sizeof(ra::io_t), + SIZEOF_BASE, &buffer, - NULL + &buffer_length ); if (!NT_SUCCESS(status)) { DebugPrint(("RetrieveOutputBuffer failed: 0x%x\n", status)); } else { - ra::io_t& output = *reinterpret_cast<ra::io_t*>(buffer); + *static_cast<ra::io_base*>(buffer) = global.base_data; + + size_t modifier_bytes = global.base_data.modifier_data_size * sizeof(ra::modifier_settings); + size_t device_bytes = global.base_data.device_data_size * sizeof(ra::device_settings); + size_t total_bytes = SIZEOF_BASE + modifier_bytes + device_bytes; - output.args = global.args; - output.mod = global.modifier; + if (buffer_length < total_bytes) { + bytes_out = SIZEOF_BASE; + } + else { + BYTE* output_ptr = static_cast<BYTE*>(buffer) + SIZEOF_BASE; - bytes_out = sizeof(ra::io_t); + if (global.modifier_data) RtlCopyMemory(output_ptr, global.modifier_data, modifier_bytes); + output_ptr += modifier_bytes; + if (global.device_data) RtlCopyMemory(output_ptr, global.device_data, device_bytes); + bytes_out = total_bytes; + } } break; - case RA_WRITE: + case ra::WRITE: status = WdfRequestRetrieveInputBuffer( Request, - sizeof(ra::io_t), + SIZEOF_BASE, &buffer, - NULL + &buffer_length ); if (!NT_SUCCESS(status)) { DebugPrint(("RetrieveInputBuffer failed: 0x%x\n", status)); } else { - LARGE_INTEGER interval; - interval.QuadPart = static_cast<LONGLONG>(ra::WRITE_DELAY) * -10000; - KeDelayExecutionThread(KernelMode, FALSE, &interval); + WriteDelay(); + + ra::io_base& input = *static_cast<ra::io_base*>(buffer); + + auto modifier_bytes = size_t(input.modifier_data_size) * sizeof(ra::modifier_settings); + auto device_bytes = size_t(input.device_data_size) * sizeof(ra::device_settings); + auto alloc_size = modifier_bytes + device_bytes; + auto total_size = alloc_size + SIZEOF_BASE; + + auto max_u32 = unsigned(-1); + if (modifier_bytes > max_u32 || device_bytes > max_u32 || total_size > max_u32) { + status = STATUS_CANCELLED; + break; + } + + if (input.modifier_data_size == 0) { + // clear data and disable all devices + WdfWaitLockAcquire(global.collection_lock, NULL); + + global.base_data = {}; - ra::io_t& input = *reinterpret_cast<ra::io_t*>(buffer); + if (global.modifier_data) { + ExFreePoolWithTag(global.modifier_data, 'g'); + global.modifier_data = NULL; + global.device_data = NULL; + } + + auto count = WdfCollectionGetCount(global.device_collection); + + for (auto i = 0u; i < count; i++) { + DeviceSetup(WdfCollectionGetItem(global.device_collection, i)); + } + + WdfWaitLockRelease(global.collection_lock); + } + else if (buffer_length == total_size) { + void* pool = ExAllocatePoolWithTag(PagedPool, alloc_size, 'g'); + if (!pool) { + DebugPrint(("ExAllocatePoolWithTag (PagedPool) failed")); + status = STATUS_UNSUCCESSFUL; + break; + } + RtlCopyMemory(pool, static_cast<BYTE*>(buffer) + SIZEOF_BASE, alloc_size); - global.args = input.args; - global.invokers = ra::invokers(input.args); - global.modifier = input.mod; + WdfWaitLockAcquire(global.collection_lock, NULL); + + if (global.modifier_data) { + ExFreePoolWithTag(global.modifier_data, 'g'); + } + + void* dev_data = static_cast<BYTE*>(pool) + modifier_bytes; + global.device_data = input.device_data_size > 0 ? + static_cast<ra::device_settings*>(dev_data) : + NULL; + global.modifier_data = static_cast<ra::modifier_settings*>(pool); + global.base_data = input; + + auto count = WdfCollectionGetCount(global.device_collection); + + for (auto i = 0u; i < count; i++) { + DeviceSetup(WdfCollectionGetItem(global.device_collection, i)); + } + + WdfWaitLockRelease(global.collection_lock); + } + else { + status = STATUS_CANCELLED; + } } break; - case RA_GET_VERSION: + case ra::GET_VERSION: status = WdfRequestRetrieveOutputBuffer( Request, sizeof(ra::version_t), @@ -202,7 +300,7 @@ Return Value: DebugPrint(("RetrieveOutputBuffer failed: 0x%x\n", status)); } else { - *reinterpret_cast<ra::version_t*>(buffer) = ra::version; + *static_cast<ra::version_t*>(buffer) = ra::version; bytes_out = sizeof(ra::version_t); } break; @@ -216,6 +314,120 @@ Return Value: } #pragma warning(pop) // enable 28118 again +VOID +RawaccelInit(WDFDRIVER driver) +{ + NTSTATUS status; + + status = CreateControlDevice(driver); + + if (!NT_SUCCESS(status)) { + DebugPrint(("CreateControlDevice failed with status 0x%x\n", status)); + return; + } + + status = WdfCollectionCreate( + WDF_NO_OBJECT_ATTRIBUTES, + &global.device_collection + ); + + if (!NT_SUCCESS(status)) { + DebugPrint(("WdfCollectionCreate failed with status 0x%x\n", status)); + return; + } + + status = WdfWaitLockCreate( + WDF_NO_OBJECT_ATTRIBUTES, + &global.collection_lock + ); + + if (!NT_SUCCESS(status)) { + DebugPrint(("WdfWaitLockCreate failed with status 0x%x\n", status)); + return; + } + + LARGE_INTEGER freq; + KeQueryPerformanceCounter(&freq); + global.tick_interval = 1e3 / freq.QuadPart; + + global.initialized = true; +} + +VOID +DeviceSetup(WDFOBJECT hDevice) +{ + auto* devExt = FilterGetData(hDevice); + + auto set_ext_from_cfg = [devExt](const ra::device_config& cfg) { + devExt->enable = !cfg.disable; + devExt->set_extra_info = cfg.set_extra_info; + devExt->keep_time = cfg.polling_rate <= 0; + devExt->dpi_factor = (cfg.dpi > 0) ? (1000.0 / cfg.dpi) : 1; + + if (devExt->keep_time) { + devExt->clamp = cfg.clamp; + } + else { + milliseconds interval = 1000.0 / cfg.polling_rate; + devExt->clamp = { interval, interval }; + } + }; + + auto set_mod_if_found = [devExt](auto* prof_name) { + for (auto i = 0u; i < global.base_data.modifier_data_size; i++) { + auto& profile = global.modifier_data[i].prof; + + if (wcsncmp(prof_name, profile.name, ra::MAX_NAME_LEN) == 0) { + devExt->mod_settings = global.modifier_data[i]; + devExt->mod = { devExt->mod_settings }; + return; + } + } + }; + + if (!global.modifier_data) { + devExt->enable = false; + devExt->mod = {}; + return; + } + + set_ext_from_cfg(global.base_data.default_dev_cfg); + devExt->mod_settings = *global.modifier_data; + devExt->mod = { devExt->mod_settings }; + + for (auto i = 0u; i < global.base_data.device_data_size; i++) { + auto& dev_settings = global.device_data[i]; + + if (wcsncmp(devExt->dev_id, dev_settings.id, ra::MAX_DEV_ID_LEN) == 0) { + set_ext_from_cfg(dev_settings.config); + + if (dev_settings.profile[0] != L'\0') { + set_mod_if_found(dev_settings.profile); + } + + break; + } + } +} + +VOID +DeviceCleanup(WDFOBJECT hDevice) +{ + PAGED_CODE(); + DebugPrint(("Removing device from collection\n")); + + WdfWaitLockAcquire(global.collection_lock, NULL); + WdfCollectionRemove(global.device_collection, hDevice); + WdfWaitLockRelease(global.collection_lock); +} + +VOID +WriteDelay() +{ + LARGE_INTEGER interval; + interval.QuadPart = static_cast<LONGLONG>(ra::WRITE_DELAY) * -10000; + KeDelayExecutionThread(KernelMode, FALSE, &interval); +} NTSTATUS DriverEntry( @@ -259,14 +471,10 @@ Routine Description: WDF_NO_OBJECT_ATTRIBUTES, &config, &driver); - - if (NT_SUCCESS(status)) { - LARGE_INTEGER freq; - KeQueryPerformanceCounter(&freq); - global.tick_interval = 1e3 / freq.QuadPart; - CreateControlDevice(driver); + if (NT_SUCCESS(status)) { + if (*InitSafeBootMode == 0) RawaccelInit(driver); } else { DebugPrint(("WdfDriverCreate failed with status 0x%x\n", status)); @@ -276,8 +484,7 @@ Routine Description: } -inline -VOID +NTSTATUS CreateControlDevice(WDFDRIVER Driver) /*++ Routine Description: @@ -378,7 +585,7 @@ Return Value: // WdfControlFinishInitializing(controlDevice); - return; + return STATUS_SUCCESS; Error: @@ -393,8 +600,9 @@ Error: } DebugPrint(("CreateControlDevice failed with status code 0x%x\n", status)); -} + return status; +} NTSTATUS EvtDeviceAdd( @@ -437,7 +645,7 @@ Return Value: DebugPrint(("Enter FilterEvtDeviceAdd \n")); - if (*InitSafeBootMode > 0) { + if (!global.initialized) { return STATUS_SUCCESS; } @@ -453,6 +661,7 @@ Return Value: WDF_OBJECT_ATTRIBUTES_INIT_CONTEXT_TYPE(&deviceAttributes, DEVICE_EXTENSION); + deviceAttributes.EvtCleanupCallback = DeviceCleanup; // // Create a framework device object. This call will in turn create @@ -469,7 +678,7 @@ Return Value: // get device id from bus driver // DEVICE_OBJECT* pdo = WdfDeviceWdmGetPhysicalDevice(hDevice); - + KEVENT ke; KeInitializeEvent(&ke, NotificationEvent, FALSE); IO_STATUS_BLOCK iosb = {}; @@ -480,18 +689,36 @@ Return Value: stack->MinorFunction = IRP_MN_QUERY_ID; stack->Parameters.QueryId.IdType = BusQueryDeviceID; - NTSTATUS nts = IoCallDriver(pdo, Irp); + NTSTATUS tmp = IoCallDriver(pdo, Irp); - if (nts == STATUS_PENDING) { + if (tmp == STATUS_PENDING) { KeWaitForSingleObject(&ke, Executive, KernelMode, FALSE, NULL); + tmp = iosb.Status; } - if (NT_SUCCESS(nts)) { - auto* id_ptr = reinterpret_cast<WCHAR*>(iosb.Information); - wcsncpy(FilterGetData(hDevice)->dev_id, id_ptr, ra::MAX_DEV_ID_LEN); - DebugPrint(("Device ID = %ws\n", id_ptr)); + auto* devExt = FilterGetData(hDevice); + + if (NT_SUCCESS(tmp)) { + auto* id_ptr = reinterpret_cast<WCHAR*>(iosb.Information); + wcsncpy(devExt->dev_id, id_ptr, ra::MAX_DEV_ID_LEN); ExFreePool(id_ptr); } + else { + DebugPrint(("IoCallDriver failed with status 0x%x\n", tmp)); + *devExt->dev_id = L'\0'; + } + + WdfWaitLockAcquire(global.collection_lock, NULL); + + DeviceSetup(hDevice); + + tmp = WdfCollectionAdd(global.device_collection, hDevice); + + if (!NT_SUCCESS(tmp)) { + DebugPrint(("WdfCollectionAdd failed with status 0x%x\n", tmp)); + } + + WdfWaitLockRelease(global.collection_lock); // // Configure the default queue to be Parallel. Do not use sequential queue @@ -597,8 +824,6 @@ Routine Description: break; } - devExt->counter = 0; - devExt->carry = {}; devExt->UpperConnectData = *connectData; // @@ -636,7 +861,6 @@ Routine Description: } -inline VOID DispatchPassThrough( _In_ WDFREQUEST Request, |