ringkernel_cpu/
runtime.rs

1//! CPU runtime implementation.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use tracing::{debug, info};
10
11use ringkernel_core::error::{Result, RingKernelError};
12use ringkernel_core::k2k::{K2KBroker, K2KBuilder, K2KConfig};
13use ringkernel_core::runtime::{
14    Backend, KernelHandle, KernelHandleInner, KernelId, LaunchOptions, RingKernelRuntime,
15    RuntimeMetrics,
16};
17
18use crate::kernel::CpuKernel;
19
20/// CPU-based implementation of RingKernelRuntime.
21///
22/// This runtime executes kernels on the CPU, simulating GPU behavior.
23/// It's primarily used for testing and as a fallback when no GPU is available.
24pub struct CpuRuntime {
25    /// Node ID for HLC.
26    node_id: u64,
27    /// Active kernels.
28    kernels: RwLock<HashMap<KernelId, Arc<CpuKernel>>>,
29    /// Total kernels launched.
30    total_launched: AtomicU64,
31    /// Total messages sent.
32    messages_sent: AtomicU64,
33    /// Total messages received.
34    messages_received: AtomicU64,
35    /// Shutdown flag.
36    shutdown: RwLock<bool>,
37    /// K2K broker for kernel-to-kernel messaging.
38    k2k_broker: Option<Arc<K2KBroker>>,
39}
40
41impl CpuRuntime {
42    /// Create a new CPU runtime.
43    pub async fn new() -> Result<Self> {
44        Self::with_node_id(1).await
45    }
46
47    /// Create a CPU runtime with specific node ID.
48    pub async fn with_node_id(node_id: u64) -> Result<Self> {
49        Self::with_config(node_id, true).await
50    }
51
52    /// Create a CPU runtime with configuration options.
53    pub async fn with_config(node_id: u64, enable_k2k: bool) -> Result<Self> {
54        info!(
55            "Initializing CPU runtime (node_id={}, k2k={})",
56            node_id, enable_k2k
57        );
58
59        let k2k_broker = if enable_k2k {
60            Some(K2KBuilder::new().build())
61        } else {
62            None
63        };
64
65        Ok(Self {
66            node_id,
67            kernels: RwLock::new(HashMap::new()),
68            total_launched: AtomicU64::new(0),
69            messages_sent: AtomicU64::new(0),
70            messages_received: AtomicU64::new(0),
71            shutdown: RwLock::new(false),
72            k2k_broker,
73        })
74    }
75
76    /// Create a CPU runtime with custom K2K configuration.
77    pub async fn with_k2k_config(node_id: u64, k2k_config: K2KConfig) -> Result<Self> {
78        info!(
79            "Initializing CPU runtime with custom K2K config (node_id={})",
80            node_id
81        );
82
83        Ok(Self {
84            node_id,
85            kernels: RwLock::new(HashMap::new()),
86            total_launched: AtomicU64::new(0),
87            messages_sent: AtomicU64::new(0),
88            messages_received: AtomicU64::new(0),
89            shutdown: RwLock::new(false),
90            k2k_broker: Some(K2KBroker::new(k2k_config)),
91        })
92    }
93
94    /// Get node ID.
95    pub fn node_id(&self) -> u64 {
96        self.node_id
97    }
98
99    /// Check if runtime is shut down.
100    pub fn is_shutdown(&self) -> bool {
101        *self.shutdown.read()
102    }
103
104    /// Check if K2K messaging is enabled.
105    pub fn is_k2k_enabled(&self) -> bool {
106        self.k2k_broker.is_some()
107    }
108
109    /// Get the K2K broker (if enabled).
110    pub fn k2k_broker(&self) -> Option<&Arc<K2KBroker>> {
111        self.k2k_broker.as_ref()
112    }
113}
114
115#[async_trait]
116impl RingKernelRuntime for CpuRuntime {
117    fn backend(&self) -> Backend {
118        Backend::Cpu
119    }
120
121    fn is_backend_available(&self, backend: Backend) -> bool {
122        matches!(backend, Backend::Cpu | Backend::Auto)
123    }
124
125    async fn launch(&self, kernel_id: &str, options: LaunchOptions) -> Result<KernelHandle> {
126        if self.is_shutdown() {
127            return Err(RingKernelError::BackendError(
128                "Runtime is shut down".to_string(),
129            ));
130        }
131
132        let id = KernelId::new(kernel_id);
133
134        // Check if kernel already exists
135        {
136            let kernels = self.kernels.read();
137            if kernels.contains_key(&id) {
138                return Err(RingKernelError::InvalidConfig(format!(
139                    "Kernel '{}' already exists",
140                    kernel_id
141                )));
142            }
143        }
144
145        debug!(
146            "Launching CPU kernel '{}' (grid={}, block={}, k2k={})",
147            kernel_id,
148            options.grid_size,
149            options.block_size,
150            self.is_k2k_enabled()
151        );
152
153        // Register with K2K broker if enabled
154        let k2k_endpoint = self
155            .k2k_broker
156            .as_ref()
157            .map(|broker| broker.register(id.clone()));
158
159        // Create kernel with K2K endpoint
160        let kernel = Arc::new(CpuKernel::new_with_k2k(
161            id.clone(),
162            options.clone(),
163            self.node_id,
164            k2k_endpoint,
165        ));
166        kernel.launch();
167
168        // Auto-activate if requested
169        if options.auto_activate {
170            kernel.activate().await?;
171        }
172
173        // Store kernel
174        {
175            let mut kernels = self.kernels.write();
176            kernels.insert(id.clone(), Arc::clone(&kernel));
177        }
178
179        self.total_launched.fetch_add(1, Ordering::Relaxed);
180
181        info!("CPU kernel '{}' launched successfully", kernel_id);
182
183        Ok(kernel.handle())
184    }
185
186    fn get_kernel(&self, kernel_id: &KernelId) -> Option<KernelHandle> {
187        let kernels = self.kernels.read();
188        kernels.get(kernel_id).map(|k| k.handle())
189    }
190
191    fn list_kernels(&self) -> Vec<KernelId> {
192        let kernels = self.kernels.read();
193        kernels.keys().cloned().collect()
194    }
195
196    fn metrics(&self) -> RuntimeMetrics {
197        let kernels = self.kernels.read();
198        let active = kernels.values().filter(|k| k.state().is_running()).count();
199
200        RuntimeMetrics {
201            active_kernels: active,
202            total_launched: self.total_launched.load(Ordering::Relaxed),
203            messages_sent: self.messages_sent.load(Ordering::Relaxed),
204            messages_received: self.messages_received.load(Ordering::Relaxed),
205            gpu_memory_used: 0,
206            host_memory_used: 0,
207        }
208    }
209
210    async fn shutdown(&self) -> Result<()> {
211        info!("Shutting down CPU runtime");
212
213        // Mark as shutdown
214        *self.shutdown.write() = true;
215
216        // Terminate all kernels
217        let kernel_ids: Vec<KernelId> = {
218            let kernels = self.kernels.read();
219            kernels.keys().cloned().collect()
220        };
221
222        for id in kernel_ids.iter() {
223            if let Some(kernel) = self.get_kernel(id) {
224                if let Err(e) = kernel.terminate().await {
225                    debug!("Error terminating kernel '{}': {}", id, e);
226                }
227            }
228            // Unregister from K2K broker
229            if let Some(broker) = &self.k2k_broker {
230                broker.unregister(id);
231            }
232        }
233
234        // Clear kernel map
235        {
236            let mut kernels = self.kernels.write();
237            kernels.clear();
238        }
239
240        info!("CPU runtime shut down complete");
241        Ok(())
242    }
243}
244
245impl Drop for CpuRuntime {
246    fn drop(&mut self) {
247        if !self.is_shutdown() {
248            // Best effort cleanup
249            let kernels = self.kernels.get_mut();
250            kernels.clear();
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_runtime_creation() {
        let runtime = CpuRuntime::new().await.unwrap();
        assert_eq!(runtime.backend(), Backend::Cpu);
        assert!(runtime.is_backend_available(Backend::Cpu));
        assert!(!runtime.is_backend_available(Backend::Cuda));
    }

    #[tokio::test]
    async fn test_kernel_launch() {
        let runtime = CpuRuntime::new().await.unwrap();

        let handle = runtime
            .launch("test_kernel", LaunchOptions::default())
            .await
            .unwrap();

        assert_eq!(handle.id().as_str(), "test_kernel");

        let status = handle.status();
        assert!(status.state.is_running());
    }

    #[tokio::test]
    async fn test_list_kernels() {
        let runtime = CpuRuntime::new().await.unwrap();

        runtime
            .launch("kernel1", LaunchOptions::default())
            .await
            .unwrap();
        runtime
            .launch("kernel2", LaunchOptions::default())
            .await
            .unwrap();

        let ids = runtime.list_kernels();
        assert_eq!(ids.len(), 2);
    }

    #[tokio::test]
    async fn test_duplicate_kernel() {
        let runtime = CpuRuntime::new().await.unwrap();

        runtime
            .launch("test", LaunchOptions::default())
            .await
            .unwrap();

        let result = runtime.launch("test", LaunchOptions::default()).await;
        assert!(matches!(result, Err(RingKernelError::InvalidConfig(_))));
        // The failed duplicate must not disturb the original registration.
        assert_eq!(runtime.list_kernels().len(), 1);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let runtime = CpuRuntime::new().await.unwrap();

        runtime
            .launch("kernel1", LaunchOptions::default())
            .await
            .unwrap();

        runtime.shutdown().await.unwrap();

        assert!(runtime.is_shutdown());
        assert!(runtime.list_kernels().is_empty());
    }

    #[tokio::test]
    async fn test_launch_after_shutdown_fails() {
        let runtime = CpuRuntime::new().await.unwrap();
        runtime.shutdown().await.unwrap();

        let result = runtime.launch("late", LaunchOptions::default()).await;
        assert!(matches!(result, Err(RingKernelError::BackendError(_))));
        assert!(runtime.list_kernels().is_empty());
    }

    #[tokio::test]
    async fn test_k2k_toggle() {
        let enabled = CpuRuntime::with_config(7, true).await.unwrap();
        assert_eq!(enabled.node_id(), 7);
        assert!(enabled.is_k2k_enabled());
        assert!(enabled.k2k_broker().is_some());

        let disabled = CpuRuntime::with_config(8, false).await.unwrap();
        assert_eq!(disabled.node_id(), 8);
        assert!(!disabled.is_k2k_enabled());
        assert!(disabled.k2k_broker().is_none());
    }

    #[tokio::test]
    async fn test_launch_without_k2k() {
        let runtime = CpuRuntime::with_config(1, false).await.unwrap();

        let handle = runtime
            .launch("solo", LaunchOptions::default())
            .await
            .unwrap();

        assert_eq!(handle.id().as_str(), "solo");
        assert_eq!(runtime.list_kernels().len(), 1);
    }

    #[tokio::test]
    async fn test_get_kernel() {
        let runtime = CpuRuntime::new().await.unwrap();

        runtime
            .launch("present", LaunchOptions::default())
            .await
            .unwrap();

        assert!(runtime.get_kernel(&KernelId::new("present")).is_some());
        assert!(runtime.get_kernel(&KernelId::new("missing")).is_none());
    }

    #[tokio::test]
    async fn test_metrics() {
        let runtime = CpuRuntime::new().await.unwrap();

        runtime
            .launch("kernel1", LaunchOptions::default())
            .await
            .unwrap();
        runtime
            .launch("kernel2", LaunchOptions::default())
            .await
            .unwrap();

        let metrics = runtime.metrics();
        assert_eq!(metrics.active_kernels, 2);
        assert_eq!(metrics.total_launched, 2);
    }
}