Compare commits

...

7 commits

5 changed files with 76 additions and 0 deletions

View file

@ -1,5 +1,10 @@
from sys import platform
from pci_passthrough_assist.permissions import is_ran_by_root
if platform != "linux":
print("This tool will only work on Linux based OS.")
exit(1)
if not is_ran_by_root():
print("This script needs to run as root.")
exit(1)

View file

@ -0,0 +1,16 @@
from os import listdir
from os.path import realpath, basename
def all_pci_device_ids() -> list[str]:
return listdir("/sys/bus/pci/devices")
def all_pci_driver_ids() -> list[str]:
return listdir("/sys/bus/pci/drivers")
def driver_of_pci_device(pci_device_id: str) -> str:
driver_directory: str = realpath(
f"/sys/bus/pci/devices/{pci_device_id}/driver")
return basename(driver_directory)

View file

@ -0,0 +1,36 @@
from pci_passthrough_assist.pci import driver_of_pci_device
from os.path import exists
from os import listdir
class PciDevice:
def __init__(self, device_id: str):
self.device_id = device_id
def driver_name(self) -> str:
return driver_of_pci_device(self.device_id)
def is_vga(self) -> bool:
return exists(f"/sys/bus/pci/devices/{self.device_id}/boot_vga")
def unbind_driver(self):
with open(f"/sys/bus/pci/devices/{self.device_id}/driver/unbind",
"w") as device_driver:
device_driver.write(self.device_id)
def bind_to_driver(self, driver_to_bind: str, unbind_first: bool = True):
if unbind_first:
self.unbind_driver()
with open(f"/sys/bus/pci/drivers/{driver_to_bind}/bind",
"w") as driver:
driver.write(self.device_id)
def devices_in_iommu_group(self) -> list['PciDevice']:
device_ids: list[str] = listdir(
f"/sys/bus/pci/devices/{self.device_id}/iommu_group/devices")
return [PciDevice(device_id) for device_id in device_ids]
def __str__(self) -> str:
return f"{self.device_id} driver: {self.driver_name()} VGA: {self.is_vga()}"

View file

@ -0,0 +1,5 @@
from pci_passthrough_assist.process_runner import sh
def is_ran_by_root() -> bool:
return sh(["whoami"]) == "root"

View file

@ -0,0 +1,14 @@
from subprocess import run
def sh_binary(args: list[str], ) -> bytes:
return run(args, capture_output=True).stdout
def sh(args: list[str], rstrip_newline=True):
string_output: str = sh_binary(args).decode()
return string_output if not rstrip_newline else string_output.rstrip("\n")
def sh_lines(args: list[str]) -> list[str]:
return sh(args, rstrip_newline=True).splitlines()