Skip to content

Commit 7b948a2

Browse files
Danilo Krummrichgregkh
authored andcommitted
rust: pci: fix unrestricted &mut pci::Device
As by now, pci::Device is implemented as: #[derive(Clone)] pub struct Device(ARef<device::Device>); This may be convenient, but has the implication that drivers can call device methods that require a mutable reference concurrently at any point of time. Instead define pci::Device as pub struct Device<Ctx: DeviceContext = Normal>( Opaque<bindings::pci_dev>, PhantomData<Ctx>, ); and manually implement the AlwaysRefCounted trait. With this we can implement methods that should only be called from bus callbacks (such as probe()) for pci::Device<Core>. Consequently, we make this type accessible in bus callbacks only. Arbitrary references taken by the driver are still of type ARef<pci::Device> and hence don't provide access to methods that are reserved for bus callbacks. Fixes: 1bd8b6b ("rust: pci: add basic PCI device / driver abstractions") Reviewed-by: Benno Lossin <[email protected]> Signed-off-by: Danilo Krummrich <[email protected]> Acked-by: Boqun Feng <[email protected]> Link: https://lore.kernel.org/r/[email protected] Signed-off-by: Greg Kroah-Hartman <[email protected]>
1 parent 4d03277 commit 7b948a2

File tree

2 files changed

+89
-51
lines changed

2 files changed

+89
-51
lines changed

rust/kernel/pci.rs

Lines changed: 85 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
use crate::{
88
alloc::flags::*,
9-
bindings, container_of, device,
9+
bindings, device,
1010
device_id::RawDeviceId,
1111
devres::Devres,
1212
driver,
@@ -17,7 +17,11 @@ use crate::{
1717
types::{ARef, ForeignOwnable, Opaque},
1818
ThisModule,
1919
};
20-
use core::{ops::Deref, ptr::addr_of_mut};
20+
use core::{
21+
marker::PhantomData,
22+
ops::Deref,
23+
ptr::{addr_of_mut, NonNull},
24+
};
2125
use kernel::prelude::*;
2226

2327
/// An adapter for the registration of PCI drivers.
@@ -60,17 +64,16 @@ impl<T: Driver + 'static> Adapter<T> {
6064
) -> kernel::ffi::c_int {
6165
// SAFETY: The PCI bus only ever calls the probe callback with a valid pointer to a
6266
// `struct pci_dev`.
63-
let dev = unsafe { device::Device::get_device(addr_of_mut!((*pdev).dev)) };
64-
// SAFETY: `dev` is guaranteed to be embedded in a valid `struct pci_dev` by the call
65-
// above.
66-
let mut pdev = unsafe { Device::from_dev(dev) };
67+
//
68+
// INVARIANT: `pdev` is valid for the duration of `probe_callback()`.
69+
let pdev = unsafe { &*pdev.cast::<Device<device::Core>>() };
6770

6871
// SAFETY: `DeviceId` is a `#[repr(transparent)` wrapper of `struct pci_device_id` and
6972
// does not add additional invariants, so it's safe to transmute.
7073
let id = unsafe { &*id.cast::<DeviceId>() };
7174
let info = T::ID_TABLE.info(id.index());
7275

73-
match T::probe(&mut pdev, info) {
76+
match T::probe(pdev, info) {
7477
Ok(data) => {
7578
// Let the `struct pci_dev` own a reference of the driver's private data.
7679
// SAFETY: By the type invariant `pdev.as_raw` returns a valid pointer to a
@@ -192,7 +195,7 @@ macro_rules! pci_device_table {
192195
/// # Example
193196
///
194197
///```
195-
/// # use kernel::{bindings, pci};
198+
/// # use kernel::{bindings, device::Core, pci};
196199
///
197200
/// struct MyDriver;
198201
///
@@ -210,7 +213,7 @@ macro_rules! pci_device_table {
210213
/// const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
211214
///
212215
/// fn probe(
213-
/// _pdev: &mut pci::Device,
216+
/// _pdev: &pci::Device<Core>,
214217
/// _id_info: &Self::IdInfo,
215218
/// ) -> Result<Pin<KBox<Self>>> {
216219
/// Err(ENODEV)
@@ -234,20 +237,23 @@ pub trait Driver {
234237
///
235238
/// Called when a new platform device is added or discovered.
236239
/// Implementers should attempt to initialize the device here.
237-
fn probe(dev: &mut Device, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>;
240+
fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>;
238241
}
239242

240243
/// The PCI device representation.
241244
///
242-
/// A PCI device is based on an always reference counted `device:Device` instance. Cloning a PCI
243-
/// device, hence, also increments the base device' reference count.
245+
/// This structure represents the Rust abstraction for a C `struct pci_dev`. The implementation
246+
/// abstracts the usage of an already existing C `struct pci_dev` within Rust code that we get
247+
/// passed from the C side.
244248
///
245249
/// # Invariants
246250
///
247-
/// `Device` hold a valid reference of `ARef<device::Device>` whose underlying `struct device` is a
248-
/// member of a `struct pci_dev`.
249-
#[derive(Clone)]
250-
pub struct Device(ARef<device::Device>);
251+
/// A [`Device`] instance represents a valid `struct device` created by the C portion of the kernel.
252+
#[repr(transparent)]
253+
pub struct Device<Ctx: device::DeviceContext = device::Normal>(
254+
Opaque<bindings::pci_dev>,
255+
PhantomData<Ctx>,
256+
);
251257

252258
/// A PCI BAR to perform I/O-Operations on.
253259
///
@@ -256,13 +262,13 @@ pub struct Device(ARef<device::Device>);
256262
/// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O
257263
/// memory mapped PCI bar and its size.
258264
pub struct Bar<const SIZE: usize = 0> {
259-
pdev: Device,
265+
pdev: ARef<Device>,
260266
io: IoRaw<SIZE>,
261267
num: i32,
262268
}
263269

264270
impl<const SIZE: usize> Bar<SIZE> {
265-
fn new(pdev: Device, num: u32, name: &CStr) -> Result<Self> {
271+
fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> {
266272
let len = pdev.resource_len(num)?;
267273
if len == 0 {
268274
return Err(ENOMEM);
@@ -300,12 +306,16 @@ impl<const SIZE: usize> Bar<SIZE> {
300306
// `pdev` is valid by the invariants of `Device`.
301307
// `ioptr` is guaranteed to be the start of a valid I/O mapped memory region.
302308
// `num` is checked for validity by a previous call to `Device::resource_len`.
303-
unsafe { Self::do_release(&pdev, ioptr, num) };
309+
unsafe { Self::do_release(pdev, ioptr, num) };
304310
return Err(err);
305311
}
306312
};
307313

308-
Ok(Bar { pdev, io, num })
314+
Ok(Bar {
315+
pdev: pdev.into(),
316+
io,
317+
num,
318+
})
309319
}
310320

311321
/// # Safety
@@ -351,20 +361,8 @@ impl<const SIZE: usize> Deref for Bar<SIZE> {
351361
}
352362

353363
impl Device {
354-
/// Create a PCI Device instance from an existing `device::Device`.
355-
///
356-
/// # Safety
357-
///
358-
/// `dev` must be an `ARef<device::Device>` whose underlying `bindings::device` is a member of
359-
/// a `bindings::pci_dev`.
360-
pub unsafe fn from_dev(dev: ARef<device::Device>) -> Self {
361-
Self(dev)
362-
}
363-
364364
fn as_raw(&self) -> *mut bindings::pci_dev {
365-
// SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device`
366-
// embedded in `struct pci_dev`.
367-
unsafe { container_of!(self.0.as_raw(), bindings::pci_dev, dev) as _ }
365+
self.0.get()
368366
}
369367

370368
/// Returns the PCI vendor ID.
@@ -379,18 +377,6 @@ impl Device {
379377
unsafe { (*self.as_raw()).device }
380378
}
381379

382-
/// Enable memory resources for this device.
383-
pub fn enable_device_mem(&self) -> Result {
384-
// SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
385-
to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) })
386-
}
387-
388-
/// Enable bus-mastering for this device.
389-
pub fn set_master(&self) {
390-
// SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
391-
unsafe { bindings::pci_set_master(self.as_raw()) };
392-
}
393-
394380
/// Returns the size of the given PCI bar resource.
395381
pub fn resource_len(&self, bar: u32) -> Result<bindings::resource_size_t> {
396382
if !Bar::index_is_valid(bar) {
@@ -410,7 +396,7 @@ impl Device {
410396
bar: u32,
411397
name: &CStr,
412398
) -> Result<Devres<Bar<SIZE>>> {
413-
let bar = Bar::<SIZE>::new(self.clone(), bar, name)?;
399+
let bar = Bar::<SIZE>::new(self, bar, name)?;
414400
let devres = Devres::new(self.as_ref(), bar, GFP_KERNEL)?;
415401

416402
Ok(devres)
@@ -422,8 +408,60 @@ impl Device {
422408
}
423409
}
424410

411+
impl Device<device::Core> {
412+
/// Enable memory resources for this device.
413+
pub fn enable_device_mem(&self) -> Result {
414+
// SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
415+
to_result(unsafe { bindings::pci_enable_device_mem(self.as_raw()) })
416+
}
417+
418+
/// Enable bus-mastering for this device.
419+
pub fn set_master(&self) {
420+
// SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`.
421+
unsafe { bindings::pci_set_master(self.as_raw()) };
422+
}
423+
}
424+
425+
impl Deref for Device<device::Core> {
426+
type Target = Device;
427+
428+
fn deref(&self) -> &Self::Target {
429+
let ptr: *const Self = self;
430+
431+
// CAST: `Device<Ctx>` is a transparent wrapper of `Opaque<bindings::pci_dev>`.
432+
let ptr = ptr.cast::<Device>();
433+
434+
// SAFETY: `ptr` was derived from `&self`.
435+
unsafe { &*ptr }
436+
}
437+
}
438+
439+
impl From<&Device<device::Core>> for ARef<Device> {
440+
fn from(dev: &Device<device::Core>) -> Self {
441+
(&**dev).into()
442+
}
443+
}
444+
445+
// SAFETY: Instances of `Device` are always reference-counted.
446+
unsafe impl crate::types::AlwaysRefCounted for Device {
447+
fn inc_ref(&self) {
448+
// SAFETY: The existence of a shared reference guarantees that the refcount is non-zero.
449+
unsafe { bindings::pci_dev_get(self.as_raw()) };
450+
}
451+
452+
unsafe fn dec_ref(obj: NonNull<Self>) {
453+
// SAFETY: The safety requirements guarantee that the refcount is non-zero.
454+
unsafe { bindings::pci_dev_put(obj.cast().as_ptr()) }
455+
}
456+
}
457+
425458
impl AsRef<device::Device> for Device {
426459
fn as_ref(&self) -> &device::Device {
427-
&self.0
460+
// SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid
461+
// `struct pci_dev`.
462+
let dev = unsafe { addr_of_mut!((*self.as_raw()).dev) };
463+
464+
// SAFETY: `dev` points to a valid `struct device`.
465+
unsafe { device::Device::as_ref(dev) }
428466
}
429467
}

samples/rust/rust_driver_pci.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
//!
55
//! To make this driver probe, QEMU must be run with `-device pci-testdev`.
66
7-
use kernel::{bindings, c_str, devres::Devres, pci, prelude::*};
7+
use kernel::{bindings, c_str, device::Core, devres::Devres, pci, prelude::*, types::ARef};
88

99
struct Regs;
1010

@@ -26,7 +26,7 @@ impl TestIndex {
2626
}
2727

2828
struct SampleDriver {
29-
pdev: pci::Device,
29+
pdev: ARef<pci::Device>,
3030
bar: Devres<Bar0>,
3131
}
3232

@@ -62,7 +62,7 @@ impl pci::Driver for SampleDriver {
6262

6363
const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
6464

65-
fn probe(pdev: &mut pci::Device, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
65+
fn probe(pdev: &pci::Device<Core>, info: &Self::IdInfo) -> Result<Pin<KBox<Self>>> {
6666
dev_dbg!(
6767
pdev.as_ref(),
6868
"Probe Rust PCI driver sample (PCI ID: 0x{:x}, 0x{:x}).\n",
@@ -77,7 +77,7 @@ impl pci::Driver for SampleDriver {
7777

7878
let drvdata = KBox::new(
7979
Self {
80-
pdev: pdev.clone(),
80+
pdev: pdev.into(),
8181
bar,
8282
},
8383
GFP_KERNEL,

0 commit comments

Comments
 (0)