ringkernel_core/memory.rs

//! GPU and host memory management abstractions.
//!
//! This module provides RAII wrappers for GPU memory, pinned host memory,
//! and memory pools for efficient allocation.

use std::alloc::{alloc, dealloc, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;

use crate::error::{Result, RingKernelError};

/// Trait for GPU buffer operations.
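///
/// # Example
///
/// A minimal sketch of consuming the trait through dynamic dispatch. The
/// `ringkernel_core` crate path is assumed from the directory name and the
/// `upload` helper is hypothetical, so the snippet is not compiled as a doctest.
///
/// ```ignore
/// use ringkernel_core::error::Result;
/// use ringkernel_core::memory::GpuBuffer;
///
/// /// Upload a byte slice into any backend's buffer, checking capacity first.
/// fn upload(buf: &dyn GpuBuffer, bytes: &[u8]) -> Result<()> {
///     assert!(bytes.len() <= buf.size(), "destination buffer too small");
///     buf.copy_from_host(bytes)
/// }
/// ```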
pub trait GpuBuffer: Send + Sync {
    /// Get buffer size in bytes.
    fn size(&self) -> usize;

    /// Get device pointer (as usize for FFI compatibility).
    fn device_ptr(&self) -> usize;

    /// Copy data from host to device.
    fn copy_from_host(&self, data: &[u8]) -> Result<()>;

    /// Copy data from device to host.
    fn copy_to_host(&self, data: &mut [u8]) -> Result<()>;
}

/// Trait for device memory allocation.
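///
/// # Example
///
/// A sketch of a host/device round trip through a backend that implements
/// this trait. The crate path and the existence of a concrete backend are
/// assumptions, so the snippet is not compiled as a doctest.
///
/// ```ignore
/// use ringkernel_core::error::Result;
/// use ringkernel_core::memory::DeviceMemory;
///
/// fn round_trip(device: &dyn DeviceMemory) -> Result<()> {
///     // Allocate 1 KiB on the device, write a pattern, and read it back.
///     let buf = device.allocate(1024)?;
///     let input = vec![7u8; 1024];
///     buf.copy_from_host(&input)?;
///     let mut output = vec![0u8; 1024];
///     buf.copy_to_host(&mut output)?;
///     assert_eq!(input, output);
///     Ok(())
/// }
/// ```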
pub trait DeviceMemory: Send + Sync {
    /// Allocate device memory.
    fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Allocate device memory with alignment.
    fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Get total device memory.
    fn total_memory(&self) -> usize;

    /// Get free device memory.
    fn free_memory(&self) -> usize;
}

/// Pinned (page-locked) host memory for efficient DMA transfers.
///
/// Pinned memory allows direct DMA transfers between host and device
/// without intermediate copying, significantly improving transfer performance.
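///
/// # Example
///
/// A small sketch of copying a slice into pinned memory and reading it back.
/// The crate path is assumed from the directory name, so the snippet is not
/// compiled as a doctest.
///
/// ```ignore
/// use ringkernel_core::memory::PinnedMemory;
///
/// // Copy an existing slice, so the pinned memory is fully initialized.
/// let staged = PinnedMemory::from_slice(&[1.0f32, 2.0, 3.0]).unwrap();
/// assert_eq!(staged.len(), 3);
/// assert_eq!(staged.size_bytes(), 12);
/// assert_eq!(staged.as_slice()[1], 2.0);
/// ```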
pub struct PinnedMemory<T: Copy> {
    ptr: NonNull<T>,
    len: usize,
    layout: Layout,
    _marker: PhantomData<T>,
}

impl<T: Copy> PinnedMemory<T> {
    /// Allocate pinned memory for `count` elements.
    ///
    /// # Safety
    ///
    /// The underlying memory is uninitialized. Caller must ensure
    /// data is initialized before reading.
    pub fn new(count: usize) -> Result<Self> {
        if count == 0 {
            return Err(RingKernelError::InvalidConfig(
                "Cannot allocate zero-sized buffer".to_string(),
            ));
        }

        let layout =
            Layout::array::<T>(count).map_err(|_| RingKernelError::HostAllocationFailed {
                size: count * std::mem::size_of::<T>(),
            })?;

        // In production, this would use platform-specific pinned allocation
        // (e.g., cuMemAllocHost for CUDA, or malloc + mlock in the general case).
        let ptr = unsafe { alloc(layout) };

        if ptr.is_null() {
            return Err(RingKernelError::HostAllocationFailed {
                size: layout.size(),
            });
        }

        Ok(Self {
            ptr: NonNull::new(ptr as *mut T).unwrap(),
            len: count,
            layout,
            _marker: PhantomData,
        })
    }

    /// Create pinned memory from a slice, copying the data.
    pub fn from_slice(data: &[T]) -> Result<Self> {
        let mut mem = Self::new(data.len())?;
        mem.as_mut_slice().copy_from_slice(data);
        Ok(mem)
    }

    /// Get slice reference.
    pub fn as_slice(&self) -> &[T] {
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// Get mutable slice reference.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }

    /// Get raw pointer.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    /// Get mutable raw pointer.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Get number of elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Get size in bytes.
    pub fn size_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }
}

impl<T: Copy> Drop for PinnedMemory<T> {
    fn drop(&mut self) {
        // In production, this would use platform-specific deallocation
        unsafe {
            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
        }
    }
}

// SAFETY: PinnedMemory uniquely owns its allocation, so it can be sent to
// another thread when `T: Send` and shared between threads when `T: Sync`.
unsafe impl<T: Copy + Send> Send for PinnedMemory<T> {}
unsafe impl<T: Copy + Sync> Sync for PinnedMemory<T> {}

/// Memory pool for efficient allocation/deallocation.
///
/// Memory pools amortize allocation costs by maintaining a free list
/// of pre-allocated buffers.
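///
/// # Example
///
/// A sketch of the allocate/return cycle. The crate path is assumed from the
/// directory name, so the snippet is not compiled as a doctest.
///
/// ```ignore
/// use ringkernel_core::memory::MemoryPool;
///
/// let pool = MemoryPool::new("staging", 4096, 8);
/// {
///     let mut buf = pool.allocate(); // fresh allocation on a cold pool
///     buf.as_mut_slice()[0] = 0xFF;
/// } // dropped here, so the buffer returns to the free list
/// let _reused = pool.allocate();     // served from the free list
/// assert!(pool.hit_rate() > 0.0);
/// ```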
pub struct MemoryPool {
    /// Pool name for debugging.
    name: String,
    /// Buffer size for this pool.
    buffer_size: usize,
    /// Maximum number of buffers to pool.
    max_buffers: usize,
    /// Free list of buffers.
    free_list: Mutex<Vec<Vec<u8>>>,
    /// Statistics: total allocations.
    total_allocations: AtomicUsize,
    /// Statistics: cache hits.
    cache_hits: AtomicUsize,
    /// Statistics: current pool size.
    pool_size: AtomicUsize,
}

impl MemoryPool {
    /// Create a new memory pool.
    pub fn new(name: impl Into<String>, buffer_size: usize, max_buffers: usize) -> Self {
        Self {
            name: name.into(),
            buffer_size,
            max_buffers,
            free_list: Mutex::new(Vec::with_capacity(max_buffers)),
            total_allocations: AtomicUsize::new(0),
            cache_hits: AtomicUsize::new(0),
            pool_size: AtomicUsize::new(0),
        }
    }

    /// Allocate a buffer from the pool.
    pub fn allocate(&self) -> PooledBuffer<'_> {
        self.total_allocations.fetch_add(1, Ordering::Relaxed);

        let buffer = {
            let mut free = self.free_list.lock();
            if let Some(buf) = free.pop() {
                self.cache_hits.fetch_add(1, Ordering::Relaxed);
                self.pool_size.fetch_sub(1, Ordering::Relaxed);
                buf
            } else {
                vec![0u8; self.buffer_size]
            }
        };

        PooledBuffer {
            buffer: Some(buffer),
            pool: self,
        }
    }

    /// Return a buffer to the pool.
    fn return_buffer(&self, mut buffer: Vec<u8>) {
        let mut free = self.free_list.lock();
        if free.len() < self.max_buffers {
            buffer.clear();
            buffer.resize(self.buffer_size, 0);
            free.push(buffer);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
        // If pool is full, buffer is dropped
    }

    /// Get pool name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Get buffer size.
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Get current pool size.
    pub fn current_size(&self) -> usize {
        self.pool_size.load(Ordering::Relaxed)
    }

    /// Get cache hit rate.
    pub fn hit_rate(&self) -> f64 {
        let total = self.total_allocations.load(Ordering::Relaxed);
        let hits = self.cache_hits.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Pre-allocate buffers.
    pub fn preallocate(&self, count: usize) {
        let count = count.min(self.max_buffers);
        let mut free = self.free_list.lock();
        for _ in free.len()..count {
            free.push(vec![0u8; self.buffer_size]);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
    }
}

/// A buffer from a memory pool.
///
/// When dropped, the buffer is returned to the pool for reuse.
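///
/// # Example
///
/// Because of the `Deref`/`DerefMut` impls below, a pooled buffer can be used
/// like a plain byte slice. A brief sketch (crate path assumed, not compiled
/// as a doctest):
///
/// ```ignore
/// use ringkernel_core::memory::MemoryPool;
///
/// let pool = MemoryPool::new("scratch", 256, 4);
/// let mut buf = pool.allocate();
/// buf[0] = 42;                  // DerefMut to &mut [u8]
/// assert_eq!(buf.len(), 256);   // inherent len(), same as the slice length
/// ```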
pub struct PooledBuffer<'a> {
    buffer: Option<Vec<u8>>,
    pool: &'a MemoryPool,
}

impl<'a> PooledBuffer<'a> {
    /// Get slice reference.
    pub fn as_slice(&self) -> &[u8] {
        self.buffer.as_deref().unwrap_or(&[])
    }

    /// Get mutable slice reference.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buffer.as_deref_mut().unwrap_or(&mut [])
    }

    /// Get buffer length.
    pub fn len(&self) -> usize {
        self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<'a> Drop for PooledBuffer<'a> {
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            self.pool.return_buffer(buffer);
        }
    }
}

impl<'a> std::ops::Deref for PooledBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for PooledBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

/// Shared memory pool that can be cloned.
pub type SharedMemoryPool = Arc<MemoryPool>;

/// Create a shared memory pool.
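///
/// A brief sketch of sharing one pool across threads via `Arc` (crate path
/// assumed, not compiled as a doctest):
///
/// ```ignore
/// use ringkernel_core::memory::create_pool;
///
/// let pool = create_pool("io", 8192, 16);
/// let worker_pool = pool.clone(); // cheap Arc clone
/// std::thread::spawn(move || {
///     let _buf = worker_pool.allocate();
/// })
/// .join()
/// .unwrap();
/// ```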
pub fn create_pool(
    name: impl Into<String>,
    buffer_size: usize,
    max_buffers: usize,
) -> SharedMemoryPool {
    Arc::new(MemoryPool::new(name, buffer_size, max_buffers))
}

/// Alignment utilities.
///
/// All helpers assume `alignment` is a non-zero power of two.
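///
/// A short worked example of the round-up helpers (crate path assumed, not
/// compiled as a doctest):
///
/// ```ignore
/// use ringkernel_core::memory::align::{align_up, padding_for, CACHE_LINE_SIZE};
///
/// // Pack a 100-byte record at cache-line granularity: 100 rounds up to 128.
/// assert_eq!(align_up(100, CACHE_LINE_SIZE), 128);
/// // 28 bytes of padding are needed to reach that boundary.
/// assert_eq!(padding_for(100, CACHE_LINE_SIZE), 28);
/// ```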
pub mod align {
    /// Cache line size (64 bytes on most modern CPUs).
    pub const CACHE_LINE_SIZE: usize = 64;

    /// GPU cache line size (128 bytes on many GPUs).
    pub const GPU_CACHE_LINE_SIZE: usize = 128;

    /// Align a value up to the next multiple of alignment.
    #[inline]
    pub const fn align_up(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        (value + mask) & !mask
    }

    /// Align a value down to the previous multiple of alignment.
    #[inline]
    pub const fn align_down(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        value & !mask
    }

    /// Check if a value is aligned.
    #[inline]
    pub const fn is_aligned(value: usize, alignment: usize) -> bool {
        value & (alignment - 1) == 0
    }

    /// Get required padding for alignment.
    #[inline]
    pub const fn padding_for(offset: usize, alignment: usize) -> usize {
        let misalignment = offset & (alignment - 1);
        if misalignment == 0 {
            0
        } else {
            alignment - misalignment
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pinned_memory() {
        let mut mem = PinnedMemory::<f32>::new(1024).unwrap();
        assert_eq!(mem.len(), 1024);
        assert_eq!(mem.size_bytes(), 1024 * 4);

        // Write some data
        let slice = mem.as_mut_slice();
        for (i, v) in slice.iter_mut().enumerate() {
            *v = i as f32;
        }

        // Read back
        assert_eq!(mem.as_slice()[42], 42.0);
    }

    #[test]
    fn test_pinned_memory_from_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mem = PinnedMemory::from_slice(&data).unwrap();
        assert_eq!(mem.as_slice(), &data[..]);
    }

    #[test]
    fn test_memory_pool() {
        let pool = MemoryPool::new("test", 1024, 10);

        // First allocation should be fresh
        let buf1 = pool.allocate();
        assert_eq!(buf1.len(), 1024);
        drop(buf1);

        // Second allocation should be cached
        let _buf2 = pool.allocate();
        assert_eq!(pool.hit_rate(), 0.5); // 1 hit out of 2 allocations
    }

    #[test]
    fn test_pool_preallocate() {
        let pool = MemoryPool::new("test", 1024, 10);
        pool.preallocate(5);
        assert_eq!(pool.current_size(), 5);

        // All allocations should hit cache
        for _ in 0..5 {
            let _ = pool.allocate();
        }
        assert_eq!(pool.hit_rate(), 1.0);
    }

    #[test]
    fn test_align_up() {
        use align::*;

        assert_eq!(align_up(0, 64), 0);
        assert_eq!(align_up(1, 64), 64);
        assert_eq!(align_up(64, 64), 64);
        assert_eq!(align_up(65, 64), 128);
    }

    #[test]
    fn test_is_aligned() {
        use align::*;

        assert!(is_aligned(0, 64));
        assert!(is_aligned(64, 64));
        assert!(is_aligned(128, 64));
        assert!(!is_aligned(1, 64));
        assert!(!is_aligned(63, 64));
    }

    #[test]
    fn test_padding_for() {
        use align::*;

        assert_eq!(padding_for(0, 64), 0);
        assert_eq!(padding_for(1, 64), 63);
        assert_eq!(padding_for(63, 64), 1);
        assert_eq!(padding_for(64, 64), 0);
    }
}