ringkernel_core/
multi_gpu.rs

//! Multi-GPU coordination and load balancing.
//!
//! This module provides infrastructure for coordinating work across
//! multiple GPUs, including device selection, load balancing, and
//! cross-device communication.
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::Arc;
11
12use crate::error::{Result, RingKernelError};
13use crate::runtime::{Backend, KernelId, LaunchOptions};
14
15/// Configuration for multi-GPU coordination.
16#[derive(Debug, Clone)]
17pub struct MultiGpuConfig {
18    /// Load balancing strategy.
19    pub load_balancing: LoadBalancingStrategy,
20    /// Enable automatic device selection.
21    pub auto_select_device: bool,
22    /// Maximum kernels per device.
23    pub max_kernels_per_device: usize,
24    /// Enable peer-to-peer transfers when available.
25    pub enable_p2p: bool,
26    /// Preferred devices (by index).
27    pub preferred_devices: Vec<usize>,
28}
29
30impl Default for MultiGpuConfig {
31    fn default() -> Self {
32        Self {
33            load_balancing: LoadBalancingStrategy::LeastLoaded,
34            auto_select_device: true,
35            max_kernels_per_device: 64,
36            enable_p2p: true,
37            preferred_devices: vec![],
38        }
39    }
40}
41
/// Policy for spreading kernel launches over the available devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Pick the first device that is available.
    FirstAvailable,
    /// Pick the device currently running the fewest kernels.
    LeastLoaded,
    /// Cycle through the available devices in order.
    RoundRobin,
    /// Pick the device with the most free memory.
    MemoryBased,
    /// Pick the device with the highest compute capability.
    ComputeCapability,
    /// Delegate selection to a user-installed callback.
    Custom,
}
58
59/// Information about a GPU device.
60#[derive(Debug, Clone)]
61pub struct DeviceInfo {
62    /// Device index.
63    pub index: usize,
64    /// Device name.
65    pub name: String,
66    /// Backend type.
67    pub backend: Backend,
68    /// Total memory in bytes.
69    pub total_memory: u64,
70    /// Available memory in bytes.
71    pub available_memory: u64,
72    /// Compute capability (for CUDA).
73    pub compute_capability: Option<(u32, u32)>,
74    /// Maximum threads per block.
75    pub max_threads_per_block: u32,
76    /// Number of multiprocessors.
77    pub multiprocessor_count: u32,
78    /// Whether device supports P2P with other devices.
79    pub p2p_capable: bool,
80}
81
82impl DeviceInfo {
83    /// Create a new device info.
84    pub fn new(index: usize, name: String, backend: Backend) -> Self {
85        Self {
86            index,
87            name,
88            backend,
89            total_memory: 0,
90            available_memory: 0,
91            compute_capability: None,
92            max_threads_per_block: 1024,
93            multiprocessor_count: 1,
94            p2p_capable: false,
95        }
96    }
97
98    /// Get memory utilization (0.0-1.0).
99    pub fn memory_utilization(&self) -> f64 {
100        if self.total_memory == 0 {
101            0.0
102        } else {
103            1.0 - (self.available_memory as f64 / self.total_memory as f64)
104        }
105    }
106}
107
108/// Status of a device in the multi-GPU coordinator.
109#[derive(Debug, Clone)]
110pub struct DeviceStatus {
111    /// Device info.
112    pub info: DeviceInfo,
113    /// Number of kernels running on this device.
114    pub kernel_count: usize,
115    /// Kernels running on this device.
116    pub kernels: Vec<KernelId>,
117    /// Whether device is available for new kernels.
118    pub available: bool,
119    /// Current load estimate (0.0-1.0).
120    pub load: f64,
121}
122
123/// Multi-GPU coordinator for managing kernels across devices.
124pub struct MultiGpuCoordinator {
125    /// Configuration.
126    config: MultiGpuConfig,
127    /// Available devices.
128    devices: RwLock<Vec<DeviceInfo>>,
129    /// Kernel-to-device mapping.
130    kernel_device_map: RwLock<HashMap<KernelId, usize>>,
131    /// Device kernel counts.
132    device_kernel_counts: RwLock<Vec<AtomicUsize>>,
133    /// Round-robin counter.
134    round_robin_counter: AtomicUsize,
135    /// Total kernels launched.
136    total_kernels: AtomicU64,
137    /// Device selection callbacks (for custom strategy).
138    #[allow(clippy::type_complexity)]
139    custom_selector:
140        RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
141}
142
143impl MultiGpuCoordinator {
144    /// Create a new multi-GPU coordinator.
145    pub fn new(config: MultiGpuConfig) -> Arc<Self> {
146        Arc::new(Self {
147            config,
148            devices: RwLock::new(Vec::new()),
149            kernel_device_map: RwLock::new(HashMap::new()),
150            device_kernel_counts: RwLock::new(Vec::new()),
151            round_robin_counter: AtomicUsize::new(0),
152            total_kernels: AtomicU64::new(0),
153            custom_selector: RwLock::new(None),
154        })
155    }
156
157    /// Register a device with the coordinator.
158    pub fn register_device(&self, device: DeviceInfo) {
159        let index = device.index;
160        let mut devices = self.devices.write();
161        let mut counts = self.device_kernel_counts.write();
162
163        // Ensure we have enough slots
164        let mut current_len = devices.len();
165        while current_len <= index {
166            devices.push(DeviceInfo::new(
167                current_len,
168                "Unknown".to_string(),
169                Backend::Cpu,
170            ));
171            counts.push(AtomicUsize::new(0));
172            current_len += 1;
173        }
174
175        devices[index] = device;
176    }
177
178    /// Unregister a device.
179    pub fn unregister_device(&self, index: usize) {
180        let devices = self.devices.read();
181        if index < devices.len() {
182            // Move kernels to another device (TODO: implement migration)
183        }
184    }
185
186    /// Get all registered devices.
187    pub fn devices(&self) -> Vec<DeviceInfo> {
188        self.devices.read().clone()
189    }
190
191    /// Get device info by index.
192    pub fn device(&self, index: usize) -> Option<DeviceInfo> {
193        self.devices.read().get(index).cloned()
194    }
195
196    /// Get number of devices.
197    pub fn device_count(&self) -> usize {
198        self.devices.read().len()
199    }
200
201    /// Select a device for launching a kernel.
202    pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
203        let devices = self.devices.read();
204        if devices.is_empty() {
205            return Err(RingKernelError::BackendUnavailable(
206                "No GPU devices available".to_string(),
207            ));
208        }
209
210        // Get current status
211        let status = self.get_all_status();
212
213        // Check for custom selector
214        if self.config.load_balancing == LoadBalancingStrategy::Custom {
215            if let Some(selector) = &*self.custom_selector.read() {
216                return Ok(selector(&status, options));
217            }
218        }
219
220        // Apply preferred devices filter if specified
221        let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
222            status
223                .into_iter()
224                .filter(|s| self.config.preferred_devices.contains(&s.info.index))
225                .collect()
226        } else {
227            status
228        };
229
230        if candidates.is_empty() {
231            return Err(RingKernelError::BackendUnavailable(
232                "No suitable GPU device available".to_string(),
233            ));
234        }
235
236        let selected = match self.config.load_balancing {
237            LoadBalancingStrategy::FirstAvailable => {
238                candidates.first().map(|s| s.info.index).unwrap_or(0)
239            }
240            LoadBalancingStrategy::LeastLoaded => candidates
241                .iter()
242                .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
243                .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
244                .map(|s| s.info.index)
245                .unwrap_or(0),
246            LoadBalancingStrategy::RoundRobin => {
247                let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();
248
249                if available.is_empty() {
250                    candidates.first().map(|s| s.info.index).unwrap_or(0)
251                } else {
252                    let idx =
253                        self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
254                    available[idx].info.index
255                }
256            }
257            LoadBalancingStrategy::MemoryBased => candidates
258                .iter()
259                .filter(|s| s.available)
260                .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
261                .map(|s| s.info.index)
262                .unwrap_or(0),
263            LoadBalancingStrategy::ComputeCapability => candidates
264                .iter()
265                .filter(|s| s.available)
266                .max_by(|a, b| {
267                    let a_cap = a.info.compute_capability.unwrap_or((0, 0));
268                    let b_cap = b.info.compute_capability.unwrap_or((0, 0));
269                    a_cap.cmp(&b_cap)
270                })
271                .map(|s| s.info.index)
272                .unwrap_or(0),
273            LoadBalancingStrategy::Custom => {
274                // Should have been handled above
275                0
276            }
277        };
278
279        Ok(selected)
280    }
281
282    /// Assign a kernel to a device.
283    pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
284        self.kernel_device_map
285            .write()
286            .insert(kernel_id, device_index);
287
288        let counts = self.device_kernel_counts.read();
289        if device_index < counts.len() {
290            counts[device_index].fetch_add(1, Ordering::Relaxed);
291        }
292
293        self.total_kernels.fetch_add(1, Ordering::Relaxed);
294    }
295
296    /// Remove a kernel assignment.
297    pub fn remove_kernel(&self, kernel_id: &KernelId) {
298        if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
299            let counts = self.device_kernel_counts.read();
300            if device_index < counts.len() {
301                counts[device_index].fetch_sub(1, Ordering::Relaxed);
302            }
303        }
304    }
305
306    /// Get device for a kernel.
307    pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
308        self.kernel_device_map.read().get(kernel_id).copied()
309    }
310
311    /// Get all kernels on a device.
312    pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
313        self.kernel_device_map
314            .read()
315            .iter()
316            .filter(|(_, &idx)| idx == device_index)
317            .map(|(k, _)| k.clone())
318            .collect()
319    }
320
321    /// Get status of all devices.
322    pub fn get_all_status(&self) -> Vec<DeviceStatus> {
323        let devices = self.devices.read();
324        let kernel_map = self.kernel_device_map.read();
325        let counts = self.device_kernel_counts.read();
326
327        devices
328            .iter()
329            .enumerate()
330            .map(|(idx, info)| {
331                let kernel_count = if idx < counts.len() {
332                    counts[idx].load(Ordering::Relaxed)
333                } else {
334                    0
335                };
336
337                let kernels: Vec<_> = kernel_map
338                    .iter()
339                    .filter(|(_, &dev_idx)| dev_idx == idx)
340                    .map(|(k, _)| k.clone())
341                    .collect();
342
343                let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
344                let available = kernel_count < self.config.max_kernels_per_device;
345
346                DeviceStatus {
347                    info: info.clone(),
348                    kernel_count,
349                    kernels,
350                    available,
351                    load,
352                }
353            })
354            .collect()
355    }
356
357    /// Get status of a specific device.
358    pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
359        self.get_all_status().into_iter().nth(device_index)
360    }
361
362    /// Set custom device selector.
363    pub fn set_custom_selector<F>(&self, selector: F)
364    where
365        F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
366    {
367        *self.custom_selector.write() = Some(Arc::new(selector));
368    }
369
370    /// Get coordinator statistics.
371    pub fn stats(&self) -> MultiGpuStats {
372        let status = self.get_all_status();
373        let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
374        let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
375        let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();
376
377        MultiGpuStats {
378            device_count: status.len(),
379            total_kernels,
380            total_memory,
381            available_memory,
382            kernels_launched: self.total_kernels.load(Ordering::Relaxed),
383        }
384    }
385
386    /// Check if P2P is available between two devices.
387    pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
388        if !self.config.enable_p2p {
389            return false;
390        }
391
392        let devices = self.devices.read();
393        if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
394            a.p2p_capable && b.p2p_capable
395        } else {
396            false
397        }
398    }
399
400    /// Update device memory info.
401    pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
402        let mut devices = self.devices.write();
403        if let Some(device) = devices.get_mut(device_index) {
404            device.available_memory = available_memory;
405        }
406    }
407}
408
/// Aggregate statistics reported by the multi-GPU coordinator.
#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
    /// How many device slots are registered.
    pub device_count: usize,
    /// Kernels currently assigned, summed over all devices.
    pub total_kernels: usize,
    /// Combined total memory of all devices, in bytes.
    pub total_memory: u64,
    /// Combined free memory of all devices, in bytes.
    pub available_memory: u64,
    /// Cumulative number of kernels launched since startup.
    pub kernels_launched: u64,
}
423
424/// Builder for multi-GPU coordinator.
425pub struct MultiGpuBuilder {
426    config: MultiGpuConfig,
427}
428
429impl MultiGpuBuilder {
430    /// Create a new builder.
431    pub fn new() -> Self {
432        Self {
433            config: MultiGpuConfig::default(),
434        }
435    }
436
437    /// Set load balancing strategy.
438    pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
439        self.config.load_balancing = strategy;
440        self
441    }
442
443    /// Set auto device selection.
444    pub fn auto_select_device(mut self, enable: bool) -> Self {
445        self.config.auto_select_device = enable;
446        self
447    }
448
449    /// Set max kernels per device.
450    pub fn max_kernels_per_device(mut self, max: usize) -> Self {
451        self.config.max_kernels_per_device = max;
452        self
453    }
454
455    /// Enable P2P transfers.
456    pub fn enable_p2p(mut self, enable: bool) -> Self {
457        self.config.enable_p2p = enable;
458        self
459    }
460
461    /// Set preferred devices.
462    pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
463        self.config.preferred_devices = devices;
464        self
465    }
466
467    /// Build the coordinator.
468    pub fn build(self) -> Arc<MultiGpuCoordinator> {
469        MultiGpuCoordinator::new(self.config)
470    }
471}
472
473impl Default for MultiGpuBuilder {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
/// Specification of a data copy between two devices.
pub struct CrossDeviceTransfer {
    /// Index of the device the data is read from.
    pub source_device: usize,
    /// Index of the device the data is written to.
    pub dest_device: usize,
    /// Number of bytes to move.
    pub size: usize,
    /// Prefer a direct P2P copy when the hardware allows it.
    pub use_p2p: bool,
}

impl CrossDeviceTransfer {
    /// Describe a transfer of `size` bytes from `source` to `dest`.
    /// P2P is preferred by default.
    pub fn new(source: usize, dest: usize, size: usize) -> Self {
        CrossDeviceTransfer {
            source_device: source,
            dest_device: dest,
            size,
            use_p2p: true,
        }
    }

    /// Opt this transfer out of P2P.
    pub fn without_p2p(mut self) -> Self {
        self.use_p2p = false;
        self
    }
}
508
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_info() {
        let gpu = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
        assert_eq!(gpu.index, 0);
        assert_eq!(gpu.name, "Test GPU");
        // total_memory defaults to 0, which is defined as 0.0 utilization.
        assert_eq!(gpu.memory_utilization(), 0.0);
    }

    #[test]
    fn test_coordinator_registration() {
        let coordinator = MultiGpuBuilder::new().build();
        coordinator.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        assert_eq!(coordinator.device_count(), 1);
        assert!(coordinator.device(0).is_some());
    }

    #[test]
    fn test_kernel_assignment() {
        let coordinator = MultiGpuBuilder::new().build();
        coordinator.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let id = KernelId::new("test_kernel");
        coordinator.assign_kernel(id.clone(), 0);

        assert_eq!(coordinator.get_kernel_device(&id), Some(0));
        assert_eq!(coordinator.kernels_on_device(0).len(), 1);
    }

    #[test]
    fn test_load_balancing_least_loaded() {
        let coordinator = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::LeastLoaded)
            .build();

        for idx in 0..2 {
            coordinator.register_device(DeviceInfo::new(
                idx,
                format!("GPU {}", idx),
                Backend::Cuda,
            ));
        }

        // Device 0 gets a kernel, leaving device 1 the least loaded.
        coordinator.assign_kernel(KernelId::new("k1"), 0);

        let chosen = coordinator.select_device(&LaunchOptions::default()).unwrap();
        assert_eq!(chosen, 1);
    }

    #[test]
    fn test_round_robin() {
        let coordinator = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::RoundRobin)
            .build();

        for idx in 0..2 {
            coordinator.register_device(DeviceInfo::new(
                idx,
                format!("GPU {}", idx),
                Backend::Cuda,
            ));
        }

        let first = coordinator.select_device(&LaunchOptions::default()).unwrap();
        let second = coordinator.select_device(&LaunchOptions::default()).unwrap();
        let third = coordinator.select_device(&LaunchOptions::default()).unwrap();

        // Selection alternates between the two devices.
        assert_ne!(first, second);
        assert_eq!(first, third);
    }
}