ringkernel_core/
k2k.rs

1//! Kernel-to-Kernel (K2K) direct messaging.
2//!
3//! This module provides infrastructure for direct communication between
4//! GPU kernels without host-side mediation.
5
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use tokio::sync::mpsc;
11
12use crate::error::{Result, RingKernelError};
13use crate::hlc::HlcTimestamp;
14use crate::message::{MessageEnvelope, MessageId};
15use crate::runtime::KernelId;
16
/// Tunable settings for the K2K messaging layer.
#[derive(Debug, Clone)]
pub struct K2KConfig {
    /// Upper bound on queued messages; the broker uses this value as the
    /// capacity of the channel it creates for each registered endpoint.
    pub max_pending_messages: usize,
    /// Delivery timeout, expressed in milliseconds.
    pub delivery_timeout_ms: u64,
    /// Whether message tracing is switched on.
    pub enable_tracing: bool,
    /// Largest hop count a routed message may carry before the broker
    /// refuses to forward it.
    pub max_hops: u8,
}
29
30impl Default for K2KConfig {
31    fn default() -> Self {
32        Self {
33            max_pending_messages: 1024,
34            delivery_timeout_ms: 5000,
35            enable_tracing: false,
36            max_hops: 8,
37        }
38    }
39}
40
41/// A K2K message with routing information.
42#[derive(Debug, Clone)]
43pub struct K2KMessage {
44    /// Unique message ID.
45    pub id: MessageId,
46    /// Source kernel.
47    pub source: KernelId,
48    /// Destination kernel.
49    pub destination: KernelId,
50    /// The message envelope.
51    pub envelope: MessageEnvelope,
52    /// Hop count (for detecting routing loops).
53    pub hops: u8,
54    /// Timestamp when message was sent.
55    pub sent_at: HlcTimestamp,
56    /// Priority (higher = more urgent).
57    pub priority: u8,
58}
59
60impl K2KMessage {
61    /// Create a new K2K message.
62    pub fn new(
63        source: KernelId,
64        destination: KernelId,
65        envelope: MessageEnvelope,
66        timestamp: HlcTimestamp,
67    ) -> Self {
68        Self {
69            id: MessageId::generate(),
70            source,
71            destination,
72            envelope,
73            hops: 0,
74            sent_at: timestamp,
75            priority: 0,
76        }
77    }
78
79    /// Create with priority.
80    pub fn with_priority(mut self, priority: u8) -> Self {
81        self.priority = priority;
82        self
83    }
84
85    /// Increment hop count.
86    pub fn increment_hops(&mut self) -> Result<()> {
87        self.hops += 1;
88        if self.hops > 16 {
89            return Err(RingKernelError::K2KError(
90                "Maximum hop count exceeded".to_string(),
91            ));
92        }
93        Ok(())
94    }
95}
96
97/// Receipt for a K2K message delivery.
98#[derive(Debug, Clone)]
99pub struct DeliveryReceipt {
100    /// Message ID.
101    pub message_id: MessageId,
102    /// Source kernel.
103    pub source: KernelId,
104    /// Destination kernel.
105    pub destination: KernelId,
106    /// Delivery status.
107    pub status: DeliveryStatus,
108    /// Timestamp of delivery/failure.
109    pub timestamp: HlcTimestamp,
110}
111
/// Possible outcomes of a delivery attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryStatus {
    /// The message was placed in the destination's queue.
    Delivered,
    /// The message has been handed on but has not reached its destination yet.
    Pending,
    /// No endpoint could be found for the destination kernel.
    NotFound,
    /// The destination's queue had no free capacity.
    QueueFull,
    /// Delivery did not complete within the allowed time.
    Timeout,
    /// The message exceeded the permitted number of hops.
    MaxHopsExceeded,
}
128
129/// K2K endpoint for a single kernel.
130pub struct K2KEndpoint {
131    /// Kernel ID.
132    kernel_id: KernelId,
133    /// Incoming message channel.
134    receiver: mpsc::Receiver<K2KMessage>,
135    /// Reference to the broker.
136    broker: Arc<K2KBroker>,
137}
138
139impl K2KEndpoint {
140    /// Receive a K2K message (blocking).
141    pub async fn receive(&mut self) -> Option<K2KMessage> {
142        self.receiver.recv().await
143    }
144
145    /// Try to receive a K2K message (non-blocking).
146    pub fn try_receive(&mut self) -> Option<K2KMessage> {
147        self.receiver.try_recv().ok()
148    }
149
150    /// Send a message to another kernel.
151    pub async fn send(
152        &self,
153        destination: KernelId,
154        envelope: MessageEnvelope,
155    ) -> Result<DeliveryReceipt> {
156        self.broker
157            .send(self.kernel_id.clone(), destination, envelope)
158            .await
159    }
160
161    /// Send a high-priority message.
162    pub async fn send_priority(
163        &self,
164        destination: KernelId,
165        envelope: MessageEnvelope,
166        priority: u8,
167    ) -> Result<DeliveryReceipt> {
168        self.broker
169            .send_priority(self.kernel_id.clone(), destination, envelope, priority)
170            .await
171    }
172
173    /// Get pending message count.
174    pub fn pending_count(&self) -> usize {
175        // Note: This is an estimate since the channel may be modified concurrently
176        0 // mpsc doesn't provide len() directly
177    }
178}
179
180/// K2K message broker for routing messages between kernels.
181pub struct K2KBroker {
182    /// Configuration.
183    config: K2KConfig,
184    /// Registered endpoints (kernel_id -> sender).
185    endpoints: RwLock<HashMap<KernelId, mpsc::Sender<K2KMessage>>>,
186    /// Message counter.
187    message_counter: AtomicU64,
188    /// Delivery receipts (for acknowledgment).
189    receipts: RwLock<HashMap<MessageId, DeliveryReceipt>>,
190    /// Routing table for indirect delivery.
191    routing_table: RwLock<HashMap<KernelId, KernelId>>,
192}
193
194impl K2KBroker {
195    /// Create a new K2K broker.
196    pub fn new(config: K2KConfig) -> Arc<Self> {
197        Arc::new(Self {
198            config,
199            endpoints: RwLock::new(HashMap::new()),
200            message_counter: AtomicU64::new(0),
201            receipts: RwLock::new(HashMap::new()),
202            routing_table: RwLock::new(HashMap::new()),
203        })
204    }
205
206    /// Register a kernel endpoint.
207    pub fn register(self: &Arc<Self>, kernel_id: KernelId) -> K2KEndpoint {
208        let (sender, receiver) = mpsc::channel(self.config.max_pending_messages);
209
210        self.endpoints.write().insert(kernel_id.clone(), sender);
211
212        K2KEndpoint {
213            kernel_id,
214            receiver,
215            broker: Arc::clone(self),
216        }
217    }
218
219    /// Unregister a kernel endpoint.
220    pub fn unregister(&self, kernel_id: &KernelId) {
221        self.endpoints.write().remove(kernel_id);
222        self.routing_table.write().remove(kernel_id);
223    }
224
225    /// Check if a kernel is registered.
226    pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
227        self.endpoints.read().contains_key(kernel_id)
228    }
229
230    /// Get all registered kernels.
231    pub fn registered_kernels(&self) -> Vec<KernelId> {
232        self.endpoints.read().keys().cloned().collect()
233    }
234
235    /// Send a message from one kernel to another.
236    pub async fn send(
237        &self,
238        source: KernelId,
239        destination: KernelId,
240        envelope: MessageEnvelope,
241    ) -> Result<DeliveryReceipt> {
242        self.send_priority(source, destination, envelope, 0).await
243    }
244
245    /// Send a priority message.
246    pub async fn send_priority(
247        &self,
248        source: KernelId,
249        destination: KernelId,
250        envelope: MessageEnvelope,
251        priority: u8,
252    ) -> Result<DeliveryReceipt> {
253        let timestamp = envelope.header.timestamp;
254        let mut message = K2KMessage::new(source.clone(), destination.clone(), envelope, timestamp);
255        message.priority = priority;
256
257        self.deliver(message).await
258    }
259
260    /// Deliver a message to its destination.
261    async fn deliver(&self, message: K2KMessage) -> Result<DeliveryReceipt> {
262        let message_id = message.id;
263        let source = message.source.clone();
264        let destination = message.destination.clone();
265        let timestamp = message.sent_at;
266
267        // Try direct delivery first
268        let endpoints = self.endpoints.read();
269        if let Some(sender) = endpoints.get(&destination) {
270            match sender.try_send(message) {
271                Ok(()) => {
272                    self.message_counter.fetch_add(1, Ordering::Relaxed);
273                    let receipt = DeliveryReceipt {
274                        message_id,
275                        source,
276                        destination,
277                        status: DeliveryStatus::Delivered,
278                        timestamp,
279                    };
280                    self.receipts.write().insert(message_id, receipt.clone());
281                    return Ok(receipt);
282                }
283                Err(mpsc::error::TrySendError::Full(_)) => {
284                    return Ok(DeliveryReceipt {
285                        message_id,
286                        source,
287                        destination,
288                        status: DeliveryStatus::QueueFull,
289                        timestamp,
290                    });
291                }
292                Err(mpsc::error::TrySendError::Closed(_)) => {
293                    return Ok(DeliveryReceipt {
294                        message_id,
295                        source,
296                        destination,
297                        status: DeliveryStatus::NotFound,
298                        timestamp,
299                    });
300                }
301            }
302        }
303        drop(endpoints);
304
305        // Try routing table
306        let next_hop = {
307            let routing = self.routing_table.read();
308            routing.get(&destination).cloned()
309        };
310
311        if let Some(next_hop) = next_hop {
312            let routed_message = K2KMessage {
313                id: message_id,
314                source,
315                destination: destination.clone(),
316                envelope: message.envelope,
317                hops: message.hops + 1,
318                sent_at: message.sent_at,
319                priority: message.priority,
320            };
321
322            if routed_message.hops > self.config.max_hops {
323                return Ok(DeliveryReceipt {
324                    message_id,
325                    source: routed_message.source,
326                    destination,
327                    status: DeliveryStatus::MaxHopsExceeded,
328                    timestamp,
329                });
330            }
331
332            // Try to deliver to next hop
333            let endpoints = self.endpoints.read();
334            if let Some(sender) = endpoints.get(&next_hop) {
335                if sender.try_send(routed_message).is_ok() {
336                    self.message_counter.fetch_add(1, Ordering::Relaxed);
337                    return Ok(DeliveryReceipt {
338                        message_id,
339                        source: message.source,
340                        destination,
341                        status: DeliveryStatus::Pending,
342                        timestamp,
343                    });
344                }
345            }
346        }
347
348        // Destination not found
349        Ok(DeliveryReceipt {
350            message_id,
351            source: message.source,
352            destination,
353            status: DeliveryStatus::NotFound,
354            timestamp,
355        })
356    }
357
358    /// Add a route to the routing table.
359    pub fn add_route(&self, destination: KernelId, next_hop: KernelId) {
360        self.routing_table.write().insert(destination, next_hop);
361    }
362
363    /// Remove a route from the routing table.
364    pub fn remove_route(&self, destination: &KernelId) {
365        self.routing_table.write().remove(destination);
366    }
367
368    /// Get statistics.
369    pub fn stats(&self) -> K2KStats {
370        K2KStats {
371            registered_endpoints: self.endpoints.read().len(),
372            messages_delivered: self.message_counter.load(Ordering::Relaxed),
373            routes_configured: self.routing_table.read().len(),
374        }
375    }
376
377    /// Get delivery receipt for a message.
378    pub fn get_receipt(&self, message_id: &MessageId) -> Option<DeliveryReceipt> {
379        self.receipts.read().get(message_id).cloned()
380    }
381}
382
/// Point-in-time counters describing the broker's state.
#[derive(Debug, Clone, Default)]
pub struct K2KStats {
    /// How many endpoints are currently registered.
    pub registered_endpoints: usize,
    /// Running total of messages the broker has delivered.
    pub messages_delivered: u64,
    /// How many entries the routing table currently holds.
    pub routes_configured: usize,
}
393
394/// Builder for creating K2K infrastructure.
395pub struct K2KBuilder {
396    config: K2KConfig,
397}
398
399impl K2KBuilder {
400    /// Create a new builder.
401    pub fn new() -> Self {
402        Self {
403            config: K2KConfig::default(),
404        }
405    }
406
407    /// Set maximum pending messages.
408    pub fn max_pending_messages(mut self, count: usize) -> Self {
409        self.config.max_pending_messages = count;
410        self
411    }
412
413    /// Set delivery timeout.
414    pub fn delivery_timeout_ms(mut self, timeout: u64) -> Self {
415        self.config.delivery_timeout_ms = timeout;
416        self
417    }
418
419    /// Enable message tracing.
420    pub fn enable_tracing(mut self, enable: bool) -> Self {
421        self.config.enable_tracing = enable;
422        self
423    }
424
425    /// Set maximum hop count.
426    pub fn max_hops(mut self, hops: u8) -> Self {
427        self.config.max_hops = hops;
428        self
429    }
430
431    /// Build the K2K broker.
432    pub fn build(self) -> Arc<K2KBroker> {
433        K2KBroker::new(self.config)
434    }
435}
436
437impl Default for K2KBuilder {
438    fn default() -> Self {
439        Self::new()
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering two kernels makes both visible through the broker.
    #[tokio::test]
    async fn test_k2k_broker_registration() {
        let broker = K2KBuilder::new().build();

        let first_id = KernelId::new("kernel1");
        let second_id = KernelId::new("kernel2");

        // Hold on to the endpoints for the duration of the test.
        let _first_endpoint = broker.register(first_id.clone());
        let _second_endpoint = broker.register(second_id.clone());

        assert!(broker.is_registered(&first_id));
        assert!(broker.is_registered(&second_id));

        let registered = broker.registered_kernels();
        assert_eq!(registered.len(), 2);
    }

    /// A message sent to a registered kernel is delivered to its queue.
    #[tokio::test]
    async fn test_k2k_message_delivery() {
        let broker = K2KBuilder::new().build();

        let sender_id = KernelId::new("kernel1");
        let receiver_id = KernelId::new("kernel2");

        let outbox = broker.register(sender_id.clone());
        let mut inbox = broker.register(receiver_id.clone());

        // Minimal envelope to carry through the broker.
        let envelope = MessageEnvelope::empty(1, 2, HlcTimestamp::now(1));

        // kernel1 -> kernel2
        let receipt = outbox.send(receiver_id.clone(), envelope).await.unwrap();
        assert_eq!(receipt.status, DeliveryStatus::Delivered);

        // The message must already be waiting on kernel2's endpoint.
        let received = inbox.try_receive();
        assert!(received.is_some());
        assert_eq!(received.unwrap().source, sender_id);
    }

    /// The default configuration exposes the documented limits.
    #[test]
    fn test_k2k_config_default() {
        let defaults = K2KConfig::default();
        assert_eq!(defaults.max_pending_messages, 1024);
        assert_eq!(defaults.delivery_timeout_ms, 5000);
    }
}