1use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11
12use crate::error::{Result, RingKernelError};
13use crate::runtime::{Backend, KernelId, LaunchOptions};
14
/// Configuration governing how kernels are distributed across devices.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Strategy used to pick a device for each kernel launch.
    pub load_balancing: LoadBalancingStrategy,
    /// Automatically choose a device when none is specified explicitly.
    pub auto_select_device: bool,
    /// Upper bound on kernels concurrently assigned to a single device.
    pub max_kernels_per_device: usize,
    /// Allow peer-to-peer transfers between capable devices.
    pub enable_p2p: bool,
    /// If non-empty, device selection is restricted to these indices.
    pub preferred_devices: Vec<usize>,
}

impl Default for MultiGpuConfig {
    /// Least-loaded balancing, auto-selection on, 64 kernels per device,
    /// P2P enabled, no preferred-device restriction.
    fn default() -> Self {
        Self {
            preferred_devices: Vec::new(),
            enable_p2p: true,
            max_kernels_per_device: 64,
            auto_select_device: true,
            load_balancing: LoadBalancingStrategy::LeastLoaded,
        }
    }
}

/// Available policies for spreading kernels over devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Always take the first candidate device.
    FirstAvailable,
    /// Prefer the device with the fewest assigned kernels.
    LeastLoaded,
    /// Cycle through available devices in turn.
    RoundRobin,
    /// Prefer the device reporting the most available memory.
    MemoryBased,
    /// Prefer the device with the highest compute capability.
    ComputeCapability,
    /// Delegate the decision to a user-installed selector function.
    Custom,
}
58
/// Static description of a single compute device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Position of this device in the coordinator's device list.
    pub index: usize,
    /// Human-readable device name.
    pub name: String,
    /// Backend that drives this device (e.g. CUDA, CPU).
    pub backend: Backend,
    /// Total device memory (presumably bytes — confirm against backend probe).
    pub total_memory: u64,
    /// Currently available device memory, same unit as `total_memory`.
    pub available_memory: u64,
    /// (major, minor) compute capability; `None` when not applicable/unknown.
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum threads per block supported by the device.
    pub max_threads_per_block: u32,
    /// Number of multiprocessors (SMs) on the device.
    pub multiprocessor_count: u32,
    /// Whether this device can participate in peer-to-peer transfers.
    pub p2p_capable: bool,
}
81
82impl DeviceInfo {
83 pub fn new(index: usize, name: String, backend: Backend) -> Self {
85 Self {
86 index,
87 name,
88 backend,
89 total_memory: 0,
90 available_memory: 0,
91 compute_capability: None,
92 max_threads_per_block: 1024,
93 multiprocessor_count: 1,
94 p2p_capable: false,
95 }
96 }
97
98 pub fn memory_utilization(&self) -> f64 {
100 if self.total_memory == 0 {
101 0.0
102 } else {
103 1.0 - (self.available_memory as f64 / self.total_memory as f64)
104 }
105 }
106}
107
/// Point-in-time snapshot of one device's state, as built by
/// `MultiGpuCoordinator::get_all_status`.
#[derive(Debug, Clone)]
pub struct DeviceStatus {
    /// Static device information.
    pub info: DeviceInfo,
    /// Number of kernels currently assigned to this device.
    pub kernel_count: usize,
    /// IDs of the kernels currently assigned to this device.
    pub kernels: Vec<KernelId>,
    /// True while `kernel_count < max_kernels_per_device`.
    pub available: bool,
    /// `kernel_count / max_kernels_per_device` (may exceed 1.0 only if the
    /// limit is lowered after assignment).
    pub load: f64,
}
122
/// Coordinates kernel placement and bookkeeping across multiple devices.
pub struct MultiGpuCoordinator {
    // Load-balancing strategy, limits, and P2P policy.
    config: MultiGpuConfig,
    // Registered devices, indexed by device index.
    devices: RwLock<Vec<DeviceInfo>>,
    // Maps each kernel to the index of the device it was assigned to.
    kernel_device_map: RwLock<HashMap<KernelId, usize>>,
    // Per-device count of assigned kernels; index-aligned with `devices`.
    device_kernel_counts: RwLock<Vec<AtomicUsize>>,
    // Monotonic counter driving the round-robin strategy.
    round_robin_counter: AtomicUsize,
    // Lifetime total of kernels ever assigned (never decremented).
    total_kernels: AtomicU64,
    // Optional user-provided selector used by `LoadBalancingStrategy::Custom`.
    #[allow(clippy::type_complexity)]
    custom_selector:
        RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
}
142
143impl MultiGpuCoordinator {
144 pub fn new(config: MultiGpuConfig) -> Arc<Self> {
146 Arc::new(Self {
147 config,
148 devices: RwLock::new(Vec::new()),
149 kernel_device_map: RwLock::new(HashMap::new()),
150 device_kernel_counts: RwLock::new(Vec::new()),
151 round_robin_counter: AtomicUsize::new(0),
152 total_kernels: AtomicU64::new(0),
153 custom_selector: RwLock::new(None),
154 })
155 }
156
157 pub fn register_device(&self, device: DeviceInfo) {
159 let index = device.index;
160 let mut devices = self.devices.write();
161 let mut counts = self.device_kernel_counts.write();
162
163 let mut current_len = devices.len();
165 while current_len <= index {
166 devices.push(DeviceInfo::new(
167 current_len,
168 "Unknown".to_string(),
169 Backend::Cpu,
170 ));
171 counts.push(AtomicUsize::new(0));
172 current_len += 1;
173 }
174
175 devices[index] = device;
176 }
177
178 pub fn unregister_device(&self, index: usize) {
180 let devices = self.devices.read();
181 if index < devices.len() {
182 }
184 }
185
186 pub fn devices(&self) -> Vec<DeviceInfo> {
188 self.devices.read().clone()
189 }
190
191 pub fn device(&self, index: usize) -> Option<DeviceInfo> {
193 self.devices.read().get(index).cloned()
194 }
195
196 pub fn device_count(&self) -> usize {
198 self.devices.read().len()
199 }
200
201 pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
203 let devices = self.devices.read();
204 if devices.is_empty() {
205 return Err(RingKernelError::BackendUnavailable(
206 "No GPU devices available".to_string(),
207 ));
208 }
209
210 let status = self.get_all_status();
212
213 if self.config.load_balancing == LoadBalancingStrategy::Custom {
215 if let Some(selector) = &*self.custom_selector.read() {
216 return Ok(selector(&status, options));
217 }
218 }
219
220 let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
222 status
223 .into_iter()
224 .filter(|s| self.config.preferred_devices.contains(&s.info.index))
225 .collect()
226 } else {
227 status
228 };
229
230 if candidates.is_empty() {
231 return Err(RingKernelError::BackendUnavailable(
232 "No suitable GPU device available".to_string(),
233 ));
234 }
235
236 let selected = match self.config.load_balancing {
237 LoadBalancingStrategy::FirstAvailable => {
238 candidates.first().map(|s| s.info.index).unwrap_or(0)
239 }
240 LoadBalancingStrategy::LeastLoaded => candidates
241 .iter()
242 .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
243 .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
244 .map(|s| s.info.index)
245 .unwrap_or(0),
246 LoadBalancingStrategy::RoundRobin => {
247 let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();
248
249 if available.is_empty() {
250 candidates.first().map(|s| s.info.index).unwrap_or(0)
251 } else {
252 let idx =
253 self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
254 available[idx].info.index
255 }
256 }
257 LoadBalancingStrategy::MemoryBased => candidates
258 .iter()
259 .filter(|s| s.available)
260 .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
261 .map(|s| s.info.index)
262 .unwrap_or(0),
263 LoadBalancingStrategy::ComputeCapability => candidates
264 .iter()
265 .filter(|s| s.available)
266 .max_by(|a, b| {
267 let a_cap = a.info.compute_capability.unwrap_or((0, 0));
268 let b_cap = b.info.compute_capability.unwrap_or((0, 0));
269 a_cap.cmp(&b_cap)
270 })
271 .map(|s| s.info.index)
272 .unwrap_or(0),
273 LoadBalancingStrategy::Custom => {
274 0
276 }
277 };
278
279 Ok(selected)
280 }
281
282 pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
284 self.kernel_device_map
285 .write()
286 .insert(kernel_id, device_index);
287
288 let counts = self.device_kernel_counts.read();
289 if device_index < counts.len() {
290 counts[device_index].fetch_add(1, Ordering::Relaxed);
291 }
292
293 self.total_kernels.fetch_add(1, Ordering::Relaxed);
294 }
295
296 pub fn remove_kernel(&self, kernel_id: &KernelId) {
298 if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
299 let counts = self.device_kernel_counts.read();
300 if device_index < counts.len() {
301 counts[device_index].fetch_sub(1, Ordering::Relaxed);
302 }
303 }
304 }
305
306 pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
308 self.kernel_device_map.read().get(kernel_id).copied()
309 }
310
311 pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
313 self.kernel_device_map
314 .read()
315 .iter()
316 .filter(|(_, &idx)| idx == device_index)
317 .map(|(k, _)| k.clone())
318 .collect()
319 }
320
321 pub fn get_all_status(&self) -> Vec<DeviceStatus> {
323 let devices = self.devices.read();
324 let kernel_map = self.kernel_device_map.read();
325 let counts = self.device_kernel_counts.read();
326
327 devices
328 .iter()
329 .enumerate()
330 .map(|(idx, info)| {
331 let kernel_count = if idx < counts.len() {
332 counts[idx].load(Ordering::Relaxed)
333 } else {
334 0
335 };
336
337 let kernels: Vec<_> = kernel_map
338 .iter()
339 .filter(|(_, &dev_idx)| dev_idx == idx)
340 .map(|(k, _)| k.clone())
341 .collect();
342
343 let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
344 let available = kernel_count < self.config.max_kernels_per_device;
345
346 DeviceStatus {
347 info: info.clone(),
348 kernel_count,
349 kernels,
350 available,
351 load,
352 }
353 })
354 .collect()
355 }
356
357 pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
359 self.get_all_status().into_iter().nth(device_index)
360 }
361
362 pub fn set_custom_selector<F>(&self, selector: F)
364 where
365 F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
366 {
367 *self.custom_selector.write() = Some(Arc::new(selector));
368 }
369
370 pub fn stats(&self) -> MultiGpuStats {
372 let status = self.get_all_status();
373 let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
374 let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
375 let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();
376
377 MultiGpuStats {
378 device_count: status.len(),
379 total_kernels,
380 total_memory,
381 available_memory,
382 kernels_launched: self.total_kernels.load(Ordering::Relaxed),
383 }
384 }
385
386 pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
388 if !self.config.enable_p2p {
389 return false;
390 }
391
392 let devices = self.devices.read();
393 if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
394 a.p2p_capable && b.p2p_capable
395 } else {
396 false
397 }
398 }
399
400 pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
402 let mut devices = self.devices.write();
403 if let Some(device) = devices.get_mut(device_index) {
404 device.available_memory = available_memory;
405 }
406 }
407}
408
/// Aggregate statistics across all registered devices.
#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
    /// Number of registered device slots.
    pub device_count: usize,
    /// Kernels currently assigned across all devices.
    pub total_kernels: usize,
    /// Sum of `total_memory` over all devices.
    pub total_memory: u64,
    /// Sum of `available_memory` over all devices.
    pub available_memory: u64,
    /// Lifetime count of kernels assigned through the coordinator.
    pub kernels_launched: u64,
}
423
/// Fluent builder producing an `Arc<MultiGpuCoordinator>`.
pub struct MultiGpuBuilder {
    // Configuration accumulated by the builder methods.
    config: MultiGpuConfig,
}
428
429impl MultiGpuBuilder {
430 pub fn new() -> Self {
432 Self {
433 config: MultiGpuConfig::default(),
434 }
435 }
436
437 pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
439 self.config.load_balancing = strategy;
440 self
441 }
442
443 pub fn auto_select_device(mut self, enable: bool) -> Self {
445 self.config.auto_select_device = enable;
446 self
447 }
448
449 pub fn max_kernels_per_device(mut self, max: usize) -> Self {
451 self.config.max_kernels_per_device = max;
452 self
453 }
454
455 pub fn enable_p2p(mut self, enable: bool) -> Self {
457 self.config.enable_p2p = enable;
458 self
459 }
460
461 pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
463 self.config.preferred_devices = devices;
464 self
465 }
466
467 pub fn build(self) -> Arc<MultiGpuCoordinator> {
469 MultiGpuCoordinator::new(self.config)
470 }
471}
472
473impl Default for MultiGpuBuilder {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
/// Describes a memory transfer between two devices.
pub struct CrossDeviceTransfer {
    /// Index of the device the data is copied from.
    pub source_device: usize,
    /// Index of the device the data is copied to.
    pub dest_device: usize,
    /// Number of elements/bytes to transfer (unit decided by the caller).
    pub size: usize,
    /// Whether a direct peer-to-peer copy should be attempted.
    pub use_p2p: bool,
}

impl CrossDeviceTransfer {
    /// Creates a transfer descriptor; the P2P path is requested by default.
    pub fn new(source: usize, dest: usize, size: usize) -> Self {
        CrossDeviceTransfer {
            source_device: source,
            dest_device: dest,
            size,
            use_p2p: true,
        }
    }

    /// Returns the descriptor with the peer-to-peer path disabled.
    pub fn without_p2p(mut self) -> Self {
        self.use_p2p = false;
        self
    }
}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly-created device reports its identity and zero utilization.
    #[test]
    fn test_device_info() {
        let gpu = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
        assert_eq!(gpu.index, 0);
        assert_eq!(gpu.name, "Test GPU");
        assert_eq!(gpu.memory_utilization(), 0.0);
    }

    /// Registering one device makes it visible through the accessors.
    #[test]
    fn test_coordinator_registration() {
        let coordinator = MultiGpuBuilder::new().build();
        coordinator.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        assert_eq!(coordinator.device_count(), 1);
        assert!(coordinator.device(0).is_some());
    }

    /// Assigning a kernel updates both the mapping and the per-device list.
    #[test]
    fn test_kernel_assignment() {
        let coordinator = MultiGpuBuilder::new().build();
        coordinator.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let id = KernelId::new("test_kernel");
        coordinator.assign_kernel(id.clone(), 0);

        assert_eq!(coordinator.get_kernel_device(&id), Some(0));
        assert_eq!(coordinator.kernels_on_device(0).len(), 1);
    }

    /// With one kernel on device 0, least-loaded must pick device 1.
    #[test]
    fn test_load_balancing_least_loaded() {
        let coordinator = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::LeastLoaded)
            .build();

        coordinator.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coordinator.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        coordinator.assign_kernel(KernelId::new("k1"), 0);

        let picked = coordinator.select_device(&LaunchOptions::default()).unwrap();
        assert_eq!(picked, 1);
    }

    /// Successive round-robin selections alternate between the two devices.
    #[test]
    fn test_round_robin() {
        let coordinator = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::RoundRobin)
            .build();

        coordinator.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coordinator.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let first = coordinator.select_device(&LaunchOptions::default()).unwrap();
        let second = coordinator.select_device(&LaunchOptions::default()).unwrap();
        let third = coordinator.select_device(&LaunchOptions::default()).unwrap();

        assert_ne!(first, second);
        assert_eq!(first, third);
    }
}