//! Host-side memory utilities: stable-address staging buffers, a recycling
//! buffer pool, and alignment helpers, plus backend-agnostic traits for
//! device memory.

use std::alloc::{alloc_zeroed, dealloc, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use parking_lot::Mutex;

use crate::error::{Result, RingKernelError};

/// A buffer that lives in device memory.
///
/// Implementations are provided by concrete device backends; the trait only
/// exposes what the runtime needs: the buffer's size, its raw device
/// address, and host <-> device copies.
pub trait GpuBuffer: Send + Sync {
    /// Buffer size in bytes.
    fn size(&self) -> usize;

    /// Raw device address of the buffer, as an integer.
    fn device_ptr(&self) -> usize;

    /// Copies `data` from the host into this device buffer.
    fn copy_from_host(&self, data: &[u8]) -> Result<()>;

    /// Copies this device buffer's contents into `data` on the host.
    fn copy_to_host(&self, data: &mut [u8]) -> Result<()>;
}

/// A device-memory allocator, implemented once per backend.
pub trait DeviceMemory: Send + Sync {
    /// Allocates a buffer of `size` bytes in device memory.
    fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Allocates a buffer of `size` bytes with at least `alignment`-byte alignment.
    fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>>;

    /// Total device memory in bytes.
    fn total_memory(&self) -> usize;

    /// Currently available device memory in bytes.
    fn free_memory(&self) -> usize;
}
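
// A minimal sketch of how these two traits compose, kept under `#[cfg(test)]`
// so it never ships: a host-backed stand-in where "device" memory is just a
// `Vec<u8>`. The `HostMemory`/`HostBuffer` names are illustrative only, not
// part of the crate's API.
#[cfg(test)]
mod host_backend_sketch {
    use super::*;

    struct HostBuffer {
        data: Mutex<Vec<u8>>,
    }

    impl GpuBuffer for HostBuffer {
        fn size(&self) -> usize {
            self.data.lock().len()
        }

        fn device_ptr(&self) -> usize {
            self.data.lock().as_ptr() as usize
        }

        fn copy_from_host(&self, data: &[u8]) -> Result<()> {
            let mut guard = self.data.lock();
            if data.len() > guard.len() {
                return Err(RingKernelError::InvalidConfig(
                    "copy exceeds buffer size".to_string(),
                ));
            }
            guard[..data.len()].copy_from_slice(data);
            Ok(())
        }

        fn copy_to_host(&self, data: &mut [u8]) -> Result<()> {
            let guard = self.data.lock();
            if data.len() > guard.len() {
                return Err(RingKernelError::InvalidConfig(
                    "copy exceeds buffer size".to_string(),
                ));
            }
            data.copy_from_slice(&guard[..data.len()]);
            Ok(())
        }
    }

    struct HostMemory;

    impl DeviceMemory for HostMemory {
        fn allocate(&self, size: usize) -> Result<Box<dyn GpuBuffer>> {
            Ok(Box::new(HostBuffer {
                data: Mutex::new(vec![0u8; size]),
            }))
        }

        fn allocate_aligned(&self, size: usize, alignment: usize) -> Result<Box<dyn GpuBuffer>> {
            // A Vec is aligned enough for a host stand-in; just round the
            // length up so callers observe an aligned size.
            self.allocate(align::align_up(size, alignment))
        }

        fn total_memory(&self) -> usize {
            usize::MAX
        }

        fn free_memory(&self) -> usize {
            usize::MAX
        }
    }

    #[test]
    fn host_round_trip() {
        let mem = HostMemory;
        let buf = mem.allocate(4).unwrap();
        buf.copy_from_host(&[1, 2, 3, 4]).unwrap();
        let mut out = [0u8; 4];
        buf.copy_to_host(&mut out).unwrap();
        assert_eq!(out, [1, 2, 3, 4]);
    }
}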

/// A host memory region with a stable address for the lifetime of the value,
/// used for staging data to and from the device.
///
/// Note that this currently allocates through the global allocator rather
/// than a backend-specific page-locking API.
pub struct PinnedMemory<T: Copy> {
    /// Pointer to the first element.
    ptr: NonNull<T>,
    /// Number of elements (not bytes).
    len: usize,
    /// Layout used for allocation, kept so `Drop` can deallocate.
    layout: Layout,
    _marker: PhantomData<T>,
}

impl<T: Copy> PinnedMemory<T> {
    /// Allocates zero-initialized storage for `count` elements of `T`.
    ///
    /// Returns an error if `count` is zero, `T` is zero-sized, or the
    /// allocation fails.
    pub fn new(count: usize) -> Result<Self> {
        if count == 0 {
            return Err(RingKernelError::InvalidConfig(
                "Cannot allocate zero-sized buffer".to_string(),
            ));
        }

        let layout =
            Layout::array::<T>(count).map_err(|_| RingKernelError::HostAllocationFailed {
                size: count.saturating_mul(std::mem::size_of::<T>()),
            })?;

        // Zero-sized layouts (e.g. `T = ()`) must not be passed to the
        // global allocator.
        if layout.size() == 0 {
            return Err(RingKernelError::InvalidConfig(
                "Cannot allocate buffer of zero-sized elements".to_string(),
            ));
        }

        // SAFETY: `layout` has non-zero size (checked above). Zero-filling
        // means the slice accessors never expose uninitialized memory.
        let ptr = unsafe { alloc_zeroed(layout) };

        if ptr.is_null() {
            return Err(RingKernelError::HostAllocationFailed {
                size: layout.size(),
            });
        }

        Ok(Self {
            // `ptr` was just checked to be non-null.
            ptr: NonNull::new(ptr as *mut T).unwrap(),
            len: count,
            layout,
            _marker: PhantomData,
        })
    }

    /// Allocates a new buffer and copies `data` into it.
    pub fn from_slice(data: &[T]) -> Result<Self> {
        let mut mem = Self::new(data.len())?;
        mem.as_mut_slice().copy_from_slice(data);
        Ok(mem)
    }

    /// Views the buffer as a slice.
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: `ptr` points to `len` initialized (zeroed or copied) elements.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// Views the buffer as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: as above, plus `&mut self` guarantees exclusive access.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }

    /// Raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    /// Raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Number of elements.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Always `false` in practice, since zero-length buffers fail to allocate.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Total size of the buffer in bytes.
    pub fn size_bytes(&self) -> usize {
        self.len * std::mem::size_of::<T>()
    }
}

impl<T: Copy> Drop for PinnedMemory<T> {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated in `new` with exactly this `layout`.
        unsafe {
            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
        }
    }
}

// SAFETY: `PinnedMemory` exclusively owns its allocation, so it can be
// transferred or shared across threads whenever `T` itself can.
unsafe impl<T: Copy + Send> Send for PinnedMemory<T> {}
unsafe impl<T: Copy + Sync> Sync for PinnedMemory<T> {}

/// A recycling pool of fixed-size byte buffers.
///
/// Buffers are handed out as [`PooledBuffer`] guards and returned to the
/// free list on drop, up to `max_buffers` idle buffers.
pub struct MemoryPool {
    /// Human-readable pool name, for diagnostics.
    name: String,
    /// Size in bytes of every buffer in this pool.
    buffer_size: usize,
    /// Maximum number of idle buffers kept on the free list.
    max_buffers: usize,
    /// Idle buffers available for reuse.
    free_list: Mutex<Vec<Vec<u8>>>,
    /// Total `allocate` calls served (hits and misses).
    total_allocations: AtomicUsize,
    /// Allocations served from the free list.
    cache_hits: AtomicUsize,
    /// Current number of idle buffers.
    pool_size: AtomicUsize,
}

impl MemoryPool {
    /// Creates an empty pool that vends `buffer_size`-byte buffers and keeps
    /// at most `max_buffers` of them idle.
    pub fn new(name: impl Into<String>, buffer_size: usize, max_buffers: usize) -> Self {
        Self {
            name: name.into(),
            buffer_size,
            max_buffers,
            free_list: Mutex::new(Vec::with_capacity(max_buffers)),
            total_allocations: AtomicUsize::new(0),
            cache_hits: AtomicUsize::new(0),
            pool_size: AtomicUsize::new(0),
        }
    }

    /// Hands out a buffer, reusing an idle one when available and allocating
    /// a fresh zeroed buffer otherwise.
    pub fn allocate(&self) -> PooledBuffer<'_> {
        self.total_allocations.fetch_add(1, Ordering::Relaxed);

        let buffer = {
            let mut free = self.free_list.lock();
            if let Some(buf) = free.pop() {
                self.cache_hits.fetch_add(1, Ordering::Relaxed);
                self.pool_size.fetch_sub(1, Ordering::Relaxed);
                buf
            } else {
                vec![0u8; self.buffer_size]
            }
        };

        PooledBuffer {
            buffer: Some(buffer),
            pool: self,
        }
    }

    /// Returns a buffer to the free list, re-zeroing it first; buffers beyond
    /// `max_buffers` are dropped instead of pooled.
    fn return_buffer(&self, mut buffer: Vec<u8>) {
        let mut free = self.free_list.lock();
        if free.len() < self.max_buffers {
            buffer.clear();
            buffer.resize(self.buffer_size, 0);
            free.push(buffer);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// The pool's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Size in bytes of each buffer vended by this pool.
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Number of idle buffers currently on the free list.
    pub fn current_size(&self) -> usize {
        self.pool_size.load(Ordering::Relaxed)
    }

    /// Fraction of allocations served from the free list, in `[0.0, 1.0]`.
    pub fn hit_rate(&self) -> f64 {
        let total = self.total_allocations.load(Ordering::Relaxed);
        let hits = self.cache_hits.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Fills the free list up to `count` buffers (capped at `max_buffers`) so
    /// that early allocations avoid cold allocation cost.
    pub fn preallocate(&self, count: usize) {
        let count = count.min(self.max_buffers);
        let mut free = self.free_list.lock();
        for _ in free.len()..count {
            free.push(vec![0u8; self.buffer_size]);
            self.pool_size.fetch_add(1, Ordering::Relaxed);
        }
    }
}

/// A buffer borrowed from a [`MemoryPool`], returned to the pool on drop.
pub struct PooledBuffer<'a> {
    /// `Some` until the buffer is taken back in `Drop`.
    buffer: Option<Vec<u8>>,
    pool: &'a MemoryPool,
}

impl<'a> PooledBuffer<'a> {
    /// Views the buffer as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        self.buffer.as_deref().unwrap_or(&[])
    }

    /// Views the buffer as a mutable byte slice.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.buffer.as_deref_mut().unwrap_or(&mut [])
    }

    /// Buffer length in bytes (the pool's `buffer_size`).
    pub fn len(&self) -> usize {
        self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
    }

    /// Whether the buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<'a> Drop for PooledBuffer<'a> {
    fn drop(&mut self) {
        // Hand the buffer back so the pool can recycle it.
        if let Some(buffer) = self.buffer.take() {
            self.pool.return_buffer(buffer);
        }
    }
}

impl<'a> std::ops::Deref for PooledBuffer<'a> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.as_slice()
    }
}

impl<'a> std::ops::DerefMut for PooledBuffer<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_slice()
    }
}

/// A memory pool shareable across threads.
pub type SharedMemoryPool = Arc<MemoryPool>;

/// Convenience constructor for a shared [`MemoryPool`].
pub fn create_pool(
    name: impl Into<String>,
    buffer_size: usize,
    max_buffers: usize,
) -> SharedMemoryPool {
    Arc::new(MemoryPool::new(name, buffer_size, max_buffers))
}

/// Alignment helpers. All functions assume `alignment` is a non-zero power
/// of two; other values give meaningless results.
pub mod align {
    /// Typical CPU cache line size in bytes.
    pub const CACHE_LINE_SIZE: usize = 64;

    /// Typical GPU cache line / memory transaction size in bytes.
    pub const GPU_CACHE_LINE_SIZE: usize = 128;

    /// Rounds `value` up to the next multiple of `alignment`.
    #[inline]
    pub const fn align_up(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        (value + mask) & !mask
    }

    /// Rounds `value` down to the previous multiple of `alignment`.
    #[inline]
    pub const fn align_down(value: usize, alignment: usize) -> usize {
        let mask = alignment - 1;
        value & !mask
    }

    /// Whether `value` is a multiple of `alignment`.
    #[inline]
    pub const fn is_aligned(value: usize, alignment: usize) -> bool {
        value & (alignment - 1) == 0
    }

    /// Number of padding bytes needed after `offset` to reach alignment.
    #[inline]
    pub const fn padding_for(offset: usize, alignment: usize) -> usize {
        let misalignment = offset & (alignment - 1);
        if misalignment == 0 {
            0
        } else {
            alignment - misalignment
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pinned_memory() {
        let mut mem = PinnedMemory::<f32>::new(1024).unwrap();
        assert_eq!(mem.len(), 1024);
        assert_eq!(mem.size_bytes(), 1024 * 4);

        let slice = mem.as_mut_slice();
        for (i, v) in slice.iter_mut().enumerate() {
            *v = i as f32;
        }

        assert_eq!(mem.as_slice()[42], 42.0);
    }

    #[test]
    fn test_pinned_memory_from_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mem = PinnedMemory::from_slice(&data).unwrap();
        assert_eq!(mem.as_slice(), &data[..]);
    }
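
    // Exercises the manual `unsafe impl Send`: ownership of the allocation
    // moves to another thread, which reads it and frees it there.
    #[test]
    fn test_pinned_memory_send_across_threads() {
        let mem = PinnedMemory::from_slice(&[1u32, 2, 3]).unwrap();
        let handle = std::thread::spawn(move || mem.as_slice().iter().sum::<u32>());
        assert_eq!(handle.join().unwrap(), 6);
    }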

    #[test]
    fn test_memory_pool() {
        let pool = MemoryPool::new("test", 1024, 10);

        // First allocation is a miss (the pool starts empty); dropping the
        // buffer returns it to the free list.
        let buf1 = pool.allocate();
        assert_eq!(buf1.len(), 1024);
        drop(buf1);

        // Second allocation reuses the returned buffer: 1 hit out of 2.
        let _buf2 = pool.allocate();
        assert_eq!(pool.hit_rate(), 0.5);
    }
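
    // Buffers returned beyond `max_buffers` are dropped, so the free list
    // never exceeds the cap.
    #[test]
    fn test_pool_respects_max_buffers() {
        let pool = MemoryPool::new("test", 64, 2);
        let bufs: Vec<_> = (0..4).map(|_| pool.allocate()).collect();
        drop(bufs);
        assert_eq!(pool.current_size(), 2);
    }

    // `Deref`/`DerefMut` let a `PooledBuffer` be used anywhere a byte slice
    // is expected.
    #[test]
    fn test_pooled_buffer_deref() {
        fn checksum(bytes: &[u8]) -> u64 {
            bytes.iter().map(|&b| u64::from(b)).sum()
        }

        let pool = MemoryPool::new("test", 8, 2);
        let mut buf = pool.allocate();
        buf[0] = 7;
        assert_eq!(checksum(&buf), 7);
    }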

    #[test]
    fn test_pool_preallocate() {
        let pool = MemoryPool::new("test", 1024, 10);
        pool.preallocate(5);
        assert_eq!(pool.current_size(), 5);

        // Each buffer is dropped (and returned) immediately, so every
        // allocation is served from the free list.
        for _ in 0..5 {
            let _ = pool.allocate();
        }
        assert_eq!(pool.hit_rate(), 1.0);
    }
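
    // `SharedMemoryPool` is just `Arc<MemoryPool>`: clones may allocate
    // concurrently because the free list sits behind a `Mutex`.
    #[test]
    fn test_shared_pool_across_threads() {
        let pool = create_pool("shared", 256, 4);
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let pool = Arc::clone(&pool);
                std::thread::spawn(move || pool.allocate().len())
            })
            .collect();
        for handle in handles {
            assert_eq!(handle.join().unwrap(), 256);
        }
    }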

    #[test]
    fn test_align_up() {
        use align::*;

        assert_eq!(align_up(0, 64), 0);
        assert_eq!(align_up(1, 64), 64);
        assert_eq!(align_up(64, 64), 64);
        assert_eq!(align_up(65, 64), 128);
    }
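
    // `align_down` mirrors `align_up` but had no coverage; these cases are
    // the duals of the ones above.
    #[test]
    fn test_align_down() {
        use align::*;

        assert_eq!(align_down(0, 64), 0);
        assert_eq!(align_down(1, 64), 0);
        assert_eq!(align_down(64, 64), 64);
        assert_eq!(align_down(127, 64), 64);
    }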

    #[test]
    fn test_is_aligned() {
        use align::*;

        assert!(is_aligned(0, 64));
        assert!(is_aligned(64, 64));
        assert!(is_aligned(128, 64));
        assert!(!is_aligned(1, 64));
        assert!(!is_aligned(63, 64));
    }

    #[test]
    fn test_padding_for() {
        use align::*;

        assert_eq!(padding_for(0, 64), 0);
        assert_eq!(padding_for(1, 64), 63);
        assert_eq!(padding_for(63, 64), 1);
        assert_eq!(padding_for(64, 64), 0);
    }
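
    // Worked example: packing two fields into a byte buffer with
    // `padding_for`, the way a staging layout might be computed.
    #[test]
    fn test_padding_for_layout() {
        use align::*;

        // A 4-byte field at offset 0, then a field needing 16-byte alignment.
        let offset = 4;
        let padded = offset + padding_for(offset, 16);
        assert_eq!(padded, 16);
        assert!(is_aligned(padded, 16));
    }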
}