1use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use async_trait::async_trait;
8use parking_lot::{Mutex, RwLock};
9use tokio::sync::{mpsc, Notify};
10
11use ringkernel_core::control::ControlBlock;
12use ringkernel_core::error::{Result, RingKernelError};
13use ringkernel_core::hlc::HlcClock;
14use ringkernel_core::k2k::{DeliveryReceipt, K2KEndpoint, K2KMessage};
15use ringkernel_core::message::{CorrelationId, MessageEnvelope};
16use ringkernel_core::queue::{BoundedQueue, MessageQueue};
17use ringkernel_core::runtime::{
18 KernelHandle, KernelHandleInner, KernelId, KernelState, KernelStatus, LaunchOptions,
19};
20use ringkernel_core::telemetry::{KernelMetrics, TelemetryBuffer};
21
/// CPU-backed kernel instance.
///
/// Bundles the lifecycle state, control block, telemetry, message queues and
/// optional kernel-to-kernel (K2K) endpoint of a single kernel. All mutable
/// parts sit behind locks or atomics, so the methods in this file operate on
/// `&self`.
pub struct CpuKernel {
    /// Identifier returned by `id()` and reported in status, metrics and errors.
    id: KernelId,
    /// Numeric id taken from a process-wide counter at construction time.
    id_num: u64,
    /// Current lifecycle state.
    state: RwLock<KernelState>,
    /// Options the kernel was created with (queue capacities, mode).
    options: LaunchOptions,
    /// Control block; the lifecycle methods write its `is_active`,
    /// `should_terminate` and `has_terminated` flags.
    control: RwLock<ControlBlock>,
    /// Telemetry counters; `process_message` bumps `messages_processed`.
    telemetry: RwLock<TelemetryBuffer>,
    /// Queue that `send_envelope` enqueues into.
    input_queue: Arc<BoundedQueue>,
    /// Queue that `process_message` enqueues into and the receive methods drain.
    output_queue: Arc<BoundedQueue>,
    /// Hybrid logical clock used by `current_timestamp`.
    clock: Arc<HlcClock>,
    /// Senders for callers waiting in `receive_correlated`, keyed by the raw
    /// correlation id.
    correlation_waiters: Mutex<std::collections::HashMap<u64, mpsc::Sender<MessageEnvelope>>>,
    /// Signalled by `terminate` to wake tasks parked in `wait`.
    terminate_notify: Notify,
    /// Instant captured when the value was constructed; basis for `uptime`.
    launched_at: Instant,
    /// Number of envelopes successfully enqueued by `send_envelope`.
    message_counter: AtomicU64,
    /// Optional K2K endpoint; `None` when K2K is disabled. The async K2K
    /// methods move it out of this slot for the duration of a call.
    k2k_endpoint: Mutex<Option<K2KEndpoint>>,
}
53
54impl CpuKernel {
55 pub fn new(id: KernelId, options: LaunchOptions, node_id: u64) -> Self {
57 Self::new_with_k2k(id, options, node_id, None)
58 }
59
60 pub fn new_with_k2k(
62 id: KernelId,
63 options: LaunchOptions,
64 node_id: u64,
65 k2k_endpoint: Option<K2KEndpoint>,
66 ) -> Self {
67 static KERNEL_COUNTER: AtomicU64 = AtomicU64::new(1);
68 let id_num = KERNEL_COUNTER.fetch_add(1, Ordering::Relaxed);
69
70 let input_capacity = options.input_queue_capacity;
72 let output_capacity = options.output_queue_capacity;
73
74 let control = ControlBlock::with_capacities(input_capacity as u32, output_capacity as u32);
75
76 Self {
77 id,
78 id_num,
79 state: RwLock::new(KernelState::Created),
80 options,
81 control: RwLock::new(control),
82 telemetry: RwLock::new(TelemetryBuffer::new()),
83 input_queue: Arc::new(BoundedQueue::new(input_capacity)),
84 output_queue: Arc::new(BoundedQueue::new(output_capacity)),
85 clock: Arc::new(HlcClock::new(node_id)),
86 correlation_waiters: Mutex::new(std::collections::HashMap::new()),
87 terminate_notify: Notify::new(),
88 launched_at: Instant::now(),
89 message_counter: AtomicU64::new(0),
90 k2k_endpoint: Mutex::new(k2k_endpoint),
91 }
92 }
93
94 pub fn launch(&self) {
96 let mut state = self.state.write();
97 if *state == KernelState::Created {
98 *state = KernelState::Launched;
99 }
100 }
101
102 pub fn id(&self) -> &KernelId {
104 &self.id
105 }
106
107 pub fn state(&self) -> KernelState {
109 *self.state.read()
110 }
111
112 pub fn process_message(&self, envelope: MessageEnvelope) -> Result<()> {
114 let mut telemetry = self.telemetry.write();
116 telemetry.messages_processed += 1;
117
118 self.output_queue.try_enqueue(envelope)?;
121
122 Ok(())
123 }
124
125 pub fn handle(self: &Arc<Self>) -> KernelHandle {
127 KernelHandle::new(
128 self.id.clone(),
129 Arc::clone(self) as Arc<dyn KernelHandleInner>,
130 )
131 }
132
133 pub fn is_k2k_enabled(&self) -> bool {
135 self.k2k_endpoint.lock().is_some()
136 }
137
138 pub async fn k2k_send(
140 &self,
141 destination: KernelId,
142 envelope: MessageEnvelope,
143 ) -> Result<DeliveryReceipt> {
144 let endpoint = {
145 let mut endpoint_guard = self.k2k_endpoint.lock();
146 endpoint_guard.take().ok_or_else(|| {
147 RingKernelError::K2KError("K2K not enabled for this kernel".to_string())
148 })?
149 };
150 let result = endpoint.send(destination, envelope).await;
151 *self.k2k_endpoint.lock() = Some(endpoint);
153 result
154 }
155
156 pub fn k2k_try_recv(&self) -> Option<K2KMessage> {
158 let mut endpoint_guard = self.k2k_endpoint.lock();
159 endpoint_guard.as_mut()?.try_receive()
160 }
161
162 pub async fn k2k_recv(&self) -> Option<K2KMessage> {
164 let mut endpoint = {
166 let mut endpoint_guard = self.k2k_endpoint.lock();
167 endpoint_guard.take()?
168 };
169 let result = endpoint.receive().await;
170 *self.k2k_endpoint.lock() = Some(endpoint);
172 result
173 }
174}
175
176#[async_trait]
177impl KernelHandleInner for CpuKernel {
178 fn kernel_id_num(&self) -> u64 {
179 self.id_num
180 }
181
182 fn current_timestamp(&self) -> ringkernel_core::hlc::HlcTimestamp {
183 self.clock.now()
184 }
185
186 async fn activate(&self) -> Result<()> {
187 let mut state = self.state.write();
188 if !state.can_activate() {
189 return Err(RingKernelError::InvalidStateTransition {
190 from: format!("{:?}", *state),
191 to: "Active".to_string(),
192 });
193 }
194 *state = KernelState::Active;
195
196 let mut control = self.control.write();
198 control.is_active = 1;
199
200 Ok(())
201 }
202
203 async fn deactivate(&self) -> Result<()> {
204 let mut state = self.state.write();
205 if !state.can_deactivate() {
206 return Err(RingKernelError::InvalidStateTransition {
207 from: format!("{:?}", *state),
208 to: "Deactivated".to_string(),
209 });
210 }
211 *state = KernelState::Deactivated;
212
213 let mut control = self.control.write();
215 control.is_active = 0;
216
217 Ok(())
218 }
219
220 async fn terminate(&self) -> Result<()> {
221 let mut state = self.state.write();
222 if !state.can_terminate() {
223 return Err(RingKernelError::InvalidStateTransition {
224 from: format!("{:?}", *state),
225 to: "Terminated".to_string(),
226 });
227 }
228 *state = KernelState::Terminating;
229
230 {
232 let mut control = self.control.write();
233 control.should_terminate = 1;
234 control.is_active = 0;
235 }
236
237 self.terminate_notify.notify_waiters();
239
240 *state = KernelState::Terminated;
242 {
243 let mut control = self.control.write();
244 control.has_terminated = 1;
245 }
246
247 Ok(())
248 }
249
250 async fn send_envelope(&self, envelope: MessageEnvelope) -> Result<()> {
251 let state = self.state();
252 if !state.is_running() {
253 return Err(RingKernelError::KernelNotActive(self.id.to_string()));
254 }
255
256 self.input_queue
257 .enqueue_timeout(envelope, Duration::from_secs(5))?;
258 self.message_counter.fetch_add(1, Ordering::Relaxed);
259
260 Ok(())
261 }
262
263 async fn receive(&self) -> Result<MessageEnvelope> {
264 self.receive_timeout(Duration::from_secs(30)).await
265 }
266
267 async fn receive_timeout(&self, timeout: Duration) -> Result<MessageEnvelope> {
268 let envelope = self.output_queue.dequeue_timeout(timeout)?;
269
270 if envelope.header.correlation_id.is_some() {
272 let waiters = self.correlation_waiters.lock();
273 if let Some(sender) = waiters.get(&envelope.header.correlation_id.0) {
274 let _ = sender.try_send(envelope.clone());
275 }
276 }
277
278 Ok(envelope)
279 }
280
281 fn try_receive(&self) -> Result<MessageEnvelope> {
282 self.output_queue.try_dequeue()
283 }
284
285 async fn receive_correlated(
286 &self,
287 correlation: CorrelationId,
288 timeout: Duration,
289 ) -> Result<MessageEnvelope> {
290 let (tx, mut rx) = mpsc::channel(1);
291
292 {
294 let mut waiters = self.correlation_waiters.lock();
295 waiters.insert(correlation.0, tx);
296 }
297
298 let result = tokio::time::timeout(timeout, rx.recv()).await;
300
301 {
303 let mut waiters = self.correlation_waiters.lock();
304 waiters.remove(&correlation.0);
305 }
306
307 match result {
308 Ok(Some(envelope)) => Ok(envelope),
309 Ok(None) => Err(RingKernelError::ChannelClosed),
310 Err(_) => Err(RingKernelError::Timeout(timeout)),
311 }
312 }
313
314 fn status(&self) -> KernelStatus {
315 let state = *self.state.read();
316 let control = self.control.read();
317
318 KernelStatus {
319 id: self.id.clone(),
320 state,
321 mode: self.options.mode,
322 input_queue_depth: self.input_queue.len(),
323 output_queue_depth: self.output_queue.len(),
324 messages_processed: control.messages_processed,
325 uptime: self.launched_at.elapsed(),
326 }
327 }
328
329 fn metrics(&self) -> KernelMetrics {
330 let telemetry = *self.telemetry.read();
331 KernelMetrics {
332 telemetry,
333 kernel_id: self.id.to_string(),
334 collected_at: Instant::now(),
335 uptime: self.launched_at.elapsed(),
336 invocations: 0,
337 bytes_to_device: 0,
338 bytes_from_device: 0,
339 gpu_memory_used: 0,
340 host_memory_used: 0,
341 }
342 }
343
344 async fn wait(&self) -> Result<()> {
345 loop {
346 if self.state().is_finished() {
347 return Ok(());
348 }
349 self.terminate_notify.notified().await;
350 }
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use ringkernel_core::hlc::HlcTimestamp;
    use ringkernel_core::message::MessageHeader;

    /// Builds a small envelope with an eight-byte payload.
    fn make_envelope() -> MessageEnvelope {
        let header = MessageHeader::new(1, 0, 1, 8, HlcTimestamp::now(1));
        let payload = vec![1, 2, 3, 4, 5, 6, 7, 8];
        MessageEnvelope { header, payload }
    }

    /// Builds a freshly created kernel named "test" on node 1.
    fn created_kernel() -> Arc<CpuKernel> {
        let kernel = CpuKernel::new(KernelId::new("test"), LaunchOptions::default(), 1);
        Arc::new(kernel)
    }

    /// Builds a kernel that has already been launched and activated.
    async fn active_kernel() -> Arc<CpuKernel> {
        let kernel = created_kernel();
        kernel.launch();
        kernel.activate().await.unwrap();
        kernel
    }

    /// Walks a kernel through every lifecycle transition.
    #[tokio::test]
    async fn test_kernel_lifecycle() {
        let kernel = created_kernel();
        assert_eq!(kernel.state(), KernelState::Created);

        kernel.launch();
        assert_eq!(kernel.state(), KernelState::Launched);

        kernel.activate().await.unwrap();
        assert_eq!(kernel.state(), KernelState::Active);

        kernel.deactivate().await.unwrap();
        assert_eq!(kernel.state(), KernelState::Deactivated);

        // A deactivated kernel can be activated again.
        kernel.activate().await.unwrap();
        assert_eq!(kernel.state(), KernelState::Active);

        kernel.terminate().await.unwrap();
        assert_eq!(kernel.state(), KernelState::Terminated);
    }

    /// Sends an envelope in and reads it back after moving it by hand from
    /// the input queue to the output queue.
    #[tokio::test]
    async fn test_send_receive() {
        let kernel = active_kernel().await;

        let sent = make_envelope();
        kernel.send_envelope(sent.clone()).await.unwrap();

        // Stand in for the processing step: input queue -> output queue.
        let in_flight = kernel.input_queue.try_dequeue().unwrap();
        kernel.output_queue.try_enqueue(in_flight).unwrap();

        let received = kernel.try_receive().unwrap();
        assert_eq!(received.header.message_type, sent.header.message_type);
    }

    /// Checks that the status snapshot reflects the id and the active state.
    #[tokio::test]
    async fn test_status() {
        let kernel = active_kernel().await;

        let status = kernel.status();
        assert_eq!(status.id.as_str(), "test");
        assert_eq!(status.state, KernelState::Active);
    }
}