ringkernel_core/
message.rs

1//! Message types and traits for kernel-to-kernel communication.
2//!
3//! This module defines the core message abstraction used for communication
4//! between GPU kernels and between host and device.
5
6use bytemuck::{Pod, Zeroable};
7use rkyv::{Archive, Deserialize, Serialize};
8use zerocopy::{AsBytes, FromBytes, FromZeroes};
9
10use crate::hlc::HlcTimestamp;
11
12/// Unique message identifier.
13#[derive(
14    Debug,
15    Clone,
16    Copy,
17    PartialEq,
18    Eq,
19    Hash,
20    Default,
21    AsBytes,
22    FromBytes,
23    FromZeroes,
24    Pod,
25    Zeroable,
26    Archive,
27    Serialize,
28    Deserialize,
29)]
30#[repr(C)]
31pub struct MessageId(pub u64);
32
33impl MessageId {
34    /// Create a new message ID.
35    pub const fn new(id: u64) -> Self {
36        Self(id)
37    }
38
39    /// Generate a new unique message ID.
40    pub fn generate() -> Self {
41        use std::sync::atomic::{AtomicU64, Ordering};
42        static COUNTER: AtomicU64 = AtomicU64::new(1);
43        Self(COUNTER.fetch_add(1, Ordering::Relaxed))
44    }
45
46    /// Get the inner value.
47    pub const fn inner(&self) -> u64 {
48        self.0
49    }
50}
51
52impl std::fmt::Display for MessageId {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "msg:{:016x}", self.0)
55    }
56}
57
58/// Correlation ID for request-response patterns.
59#[derive(
60    Debug,
61    Clone,
62    Copy,
63    PartialEq,
64    Eq,
65    Hash,
66    Default,
67    AsBytes,
68    FromBytes,
69    FromZeroes,
70    Pod,
71    Zeroable,
72    Archive,
73    Serialize,
74    Deserialize,
75)]
76#[repr(C)]
77pub struct CorrelationId(pub u64);
78
79impl CorrelationId {
80    /// Create a new correlation ID.
81    pub const fn new(id: u64) -> Self {
82        Self(id)
83    }
84
85    /// Generate a new unique correlation ID.
86    pub fn generate() -> Self {
87        Self(MessageId::generate().0)
88    }
89
90    /// No correlation (for fire-and-forget messages).
91    pub const fn none() -> Self {
92        Self(0)
93    }
94
95    /// Check if this is a valid correlation ID.
96    pub const fn is_some(&self) -> bool {
97        self.0 != 0
98    }
99}
100
101/// Message priority levels.
102#[derive(
103    Debug,
104    Clone,
105    Copy,
106    PartialEq,
107    Eq,
108    Hash,
109    Default,
110    rkyv::Archive,
111    rkyv::Serialize,
112    rkyv::Deserialize,
113)]
114#[archive(compare(PartialEq))]
115#[repr(u8)]
116pub enum Priority {
117    /// Low priority (background tasks).
118    Low = 0,
119    /// Normal priority (default).
120    #[default]
121    Normal = 1,
122    /// High priority (important tasks).
123    High = 2,
124    /// Critical priority (system messages).
125    Critical = 3,
126}
127
128impl Priority {
129    /// Convert from u8.
130    pub const fn from_u8(value: u8) -> Self {
131        match value {
132            0 => Self::Low,
133            1 => Self::Normal,
134            2 => Self::High,
135            _ => Self::Critical,
136        }
137    }
138
139    /// Convert to u8.
140    pub const fn as_u8(self) -> u8 {
141        self as u8
142    }
143}
144
/// Priority constants on a coarse 0–255 scale, for APIs that take a raw `u8`
/// priority (such as the `LaunchOptions::with_priority` call in the example
/// below).
///
/// These values are **not** the discriminants of [`Priority`](super::Priority),
/// which run 0–3. `Priority::from_u8` saturates everything above 3 to
/// `Priority::Critical`, so feeding it `NORMAL`, `HIGH` or `CRITICAL`
/// (64, 128, 192) yields `Critical` for all three. When filling
/// `MessageHeader::priority`, use `Priority::as_u8` instead.
///
/// # Example
/// ```ignore
/// use ringkernel::prelude::*;
///
/// let opts = LaunchOptions::default()
///     .with_priority(priority::HIGH);
/// ```
pub mod priority {
    /// Low priority (0) - background tasks.
    pub const LOW: u8 = 0;
    /// Normal priority (64) - default.
    pub const NORMAL: u8 = 64;
    /// High priority (128) - important tasks.
    pub const HIGH: u8 = 128;
    /// Critical priority (192) - system messages.
    pub const CRITICAL: u8 = 192;
}
164
165/// Fixed-size message header (256 bytes, cache-line aligned).
166///
167/// This header precedes the variable-length payload and provides
168/// all metadata needed for routing and processing.
169#[derive(Debug, Clone, Copy)]
170#[repr(C, align(64))]
171pub struct MessageHeader {
172    /// Magic number for validation (0xRINGKERN).
173    pub magic: u64,
174    /// Header version for compatibility.
175    pub version: u32,
176    /// Message flags.
177    pub flags: u32,
178    /// Unique message identifier.
179    pub message_id: MessageId,
180    /// Correlation ID for request-response.
181    pub correlation_id: CorrelationId,
182    /// Source kernel ID (0 for host).
183    pub source_kernel: u64,
184    /// Destination kernel ID (0 for host).
185    pub dest_kernel: u64,
186    /// Message type discriminator.
187    pub message_type: u64,
188    /// Priority level.
189    pub priority: u8,
190    /// Reserved for alignment.
191    pub _reserved1: [u8; 7],
192    /// Payload size in bytes.
193    pub payload_size: u64,
194    /// Checksum of payload (CRC32).
195    pub checksum: u32,
196    /// Reserved for alignment.
197    pub _reserved2: u32,
198    /// HLC timestamp when message was created.
199    pub timestamp: HlcTimestamp,
200    /// Deadline timestamp (0 = no deadline).
201    pub deadline: HlcTimestamp,
202    /// Reserved for future use (split for derive compatibility).
203    pub _reserved3a: [u8; 32],
204    /// Reserved for future use.
205    pub _reserved3b: [u8; 32],
206    /// Reserved for future use.
207    pub _reserved3c: [u8; 32],
208    /// Reserved for future use.
209    pub _reserved3d: [u8; 8],
210}
211
212impl MessageHeader {
213    /// Magic number for validation.
214    pub const MAGIC: u64 = 0x52494E474B45524E; // "RINGKERN"
215
216    /// Current header version.
217    pub const VERSION: u32 = 1;
218
219    /// Maximum payload size (64KB).
220    pub const MAX_PAYLOAD_SIZE: usize = 64 * 1024;
221
222    /// Convert header to bytes.
223    pub fn as_bytes(&self) -> &[u8] {
224        unsafe {
225            std::slice::from_raw_parts(
226                self as *const Self as *const u8,
227                std::mem::size_of::<Self>(),
228            )
229        }
230    }
231
232    /// Read header from bytes.
233    pub fn read_from(bytes: &[u8]) -> Option<Self> {
234        if bytes.len() < std::mem::size_of::<Self>() {
235            return None;
236        }
237        unsafe { Some(std::ptr::read_unaligned(bytes.as_ptr() as *const Self)) }
238    }
239
240    /// Create a new message header.
241    pub fn new(
242        message_type: u64,
243        source_kernel: u64,
244        dest_kernel: u64,
245        payload_size: usize,
246        timestamp: HlcTimestamp,
247    ) -> Self {
248        Self {
249            magic: Self::MAGIC,
250            version: Self::VERSION,
251            flags: 0,
252            message_id: MessageId::generate(),
253            correlation_id: CorrelationId::none(),
254            source_kernel,
255            dest_kernel,
256            message_type,
257            priority: Priority::Normal as u8,
258            _reserved1: [0; 7],
259            payload_size: payload_size as u64,
260            checksum: 0,
261            _reserved2: 0,
262            timestamp,
263            deadline: HlcTimestamp::zero(),
264            _reserved3a: [0; 32],
265            _reserved3b: [0; 32],
266            _reserved3c: [0; 32],
267            _reserved3d: [0; 8],
268        }
269    }
270
271    /// Validate the header.
272    pub fn validate(&self) -> bool {
273        self.magic == Self::MAGIC
274            && self.version <= Self::VERSION
275            && self.payload_size <= Self::MAX_PAYLOAD_SIZE as u64
276    }
277
278    /// Set correlation ID.
279    pub fn with_correlation(mut self, correlation_id: CorrelationId) -> Self {
280        self.correlation_id = correlation_id;
281        self
282    }
283
284    /// Set priority.
285    pub fn with_priority(mut self, priority: Priority) -> Self {
286        self.priority = priority as u8;
287        self
288    }
289
290    /// Set deadline.
291    pub fn with_deadline(mut self, deadline: HlcTimestamp) -> Self {
292        self.deadline = deadline;
293        self
294    }
295}
296
297impl Default for MessageHeader {
298    fn default() -> Self {
299        Self {
300            magic: Self::MAGIC,
301            version: Self::VERSION,
302            flags: 0,
303            message_id: MessageId::default(),
304            correlation_id: CorrelationId::none(),
305            source_kernel: 0,
306            dest_kernel: 0,
307            message_type: 0,
308            priority: Priority::Normal as u8,
309            _reserved1: [0; 7],
310            payload_size: 0,
311            checksum: 0,
312            _reserved2: 0,
313            timestamp: HlcTimestamp::zero(),
314            deadline: HlcTimestamp::zero(),
315            _reserved3a: [0; 32],
316            _reserved3b: [0; 32],
317            _reserved3c: [0; 32],
318            _reserved3d: [0; 8],
319        }
320    }
321}
322
323// Verify size at compile time
324const _: () = assert!(std::mem::size_of::<MessageHeader>() == 256);
325
326/// Trait for types that can be sent as kernel messages.
327///
328/// This trait is typically implemented via the `#[derive(RingMessage)]` macro.
329///
330/// # Example
331///
332/// ```ignore
333/// #[derive(RingMessage)]
334/// struct MyRequest {
335///     #[message(id)]
336///     id: MessageId,
337///     data: Vec<f32>,
338/// }
339/// ```
340pub trait RingMessage: Send + Sync + 'static {
341    /// Get the message type discriminator.
342    fn message_type() -> u64;
343
344    /// Get the message ID.
345    fn message_id(&self) -> MessageId;
346
347    /// Get the correlation ID (if any).
348    fn correlation_id(&self) -> CorrelationId {
349        CorrelationId::none()
350    }
351
352    /// Get the priority.
353    fn priority(&self) -> Priority {
354        Priority::Normal
355    }
356
357    /// Serialize the message to bytes.
358    fn serialize(&self) -> Vec<u8>;
359
360    /// Deserialize a message from bytes.
361    fn deserialize(bytes: &[u8]) -> crate::error::Result<Self>
362    where
363        Self: Sized;
364
365    /// Get the serialized size hint.
366    fn size_hint(&self) -> usize
367    where
368        Self: Sized,
369    {
370        std::mem::size_of::<Self>()
371    }
372}
373
374/// Envelope containing header and serialized payload.
375#[derive(Debug, Clone)]
376pub struct MessageEnvelope {
377    /// Message header.
378    pub header: MessageHeader,
379    /// Serialized payload.
380    pub payload: Vec<u8>,
381}
382
383impl MessageEnvelope {
384    /// Create a new envelope from a message.
385    pub fn new<M: RingMessage>(
386        message: &M,
387        source_kernel: u64,
388        dest_kernel: u64,
389        timestamp: HlcTimestamp,
390    ) -> Self {
391        let payload = message.serialize();
392        let header = MessageHeader::new(
393            M::message_type(),
394            source_kernel,
395            dest_kernel,
396            payload.len(),
397            timestamp,
398        )
399        .with_correlation(message.correlation_id())
400        .with_priority(message.priority());
401
402        Self { header, payload }
403    }
404
405    /// Get total size (header + payload).
406    pub fn total_size(&self) -> usize {
407        std::mem::size_of::<MessageHeader>() + self.payload.len()
408    }
409
410    /// Serialize to contiguous bytes.
411    pub fn to_bytes(&self) -> Vec<u8> {
412        let mut bytes = Vec::with_capacity(self.total_size());
413        bytes.extend_from_slice(self.header.as_bytes());
414        bytes.extend_from_slice(&self.payload);
415        bytes
416    }
417
418    /// Deserialize from bytes.
419    pub fn from_bytes(bytes: &[u8]) -> crate::error::Result<Self> {
420        if bytes.len() < std::mem::size_of::<MessageHeader>() {
421            return Err(crate::error::RingKernelError::DeserializationError(
422                "buffer too small for header".to_string(),
423            ));
424        }
425
426        let header_bytes = &bytes[..std::mem::size_of::<MessageHeader>()];
427        let header = MessageHeader::read_from(header_bytes).ok_or_else(|| {
428            crate::error::RingKernelError::DeserializationError("invalid header".to_string())
429        })?;
430
431        if !header.validate() {
432            return Err(crate::error::RingKernelError::ValidationError(
433                "header validation failed".to_string(),
434            ));
435        }
436
437        let payload_start = std::mem::size_of::<MessageHeader>();
438        let payload_end = payload_start + header.payload_size as usize;
439
440        if bytes.len() < payload_end {
441            return Err(crate::error::RingKernelError::DeserializationError(
442                "buffer too small for payload".to_string(),
443            ));
444        }
445
446        let payload = bytes[payload_start..payload_end].to_vec();
447
448        Ok(Self { header, payload })
449    }
450
451    /// Create an empty envelope (for testing).
452    pub fn empty(source_kernel: u64, dest_kernel: u64, timestamp: HlcTimestamp) -> Self {
453        let header = MessageHeader::new(0, source_kernel, dest_kernel, 0, timestamp);
454        Self {
455            header,
456            payload: Vec::new(),
457        }
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_id_generation() {
        let id1 = MessageId::generate();
        let id2 = MessageId::generate();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_message_id_display() {
        assert_eq!(MessageId::new(0xff).to_string(), "msg:00000000000000ff");
    }

    #[test]
    fn test_correlation_id_none_and_generate() {
        assert!(!CorrelationId::none().is_some());
        // Generated IDs come from a counter starting at 1, so they are non-zero.
        assert!(CorrelationId::generate().is_some());
    }

    #[test]
    fn test_magic_spells_ringkern() {
        assert_eq!(MessageHeader::MAGIC.to_be_bytes(), *b"RINGKERN");
    }

    #[test]
    fn test_header_validation() {
        let header = MessageHeader::new(1, 0, 1, 100, HlcTimestamp::zero());
        assert!(header.validate());

        let mut invalid = header;
        invalid.magic = 0;
        assert!(!invalid.validate());
    }

    #[test]
    fn test_header_rejects_oversized_payload() {
        let at_limit = MessageHeader::new(
            1,
            0,
            1,
            MessageHeader::MAX_PAYLOAD_SIZE,
            HlcTimestamp::zero(),
        );
        assert!(at_limit.validate());

        let too_big = MessageHeader::new(
            1,
            0,
            1,
            MessageHeader::MAX_PAYLOAD_SIZE + 1,
            HlcTimestamp::zero(),
        );
        assert!(!too_big.validate());
    }

    #[test]
    fn test_header_builders() {
        let header = MessageHeader::new(1, 0, 1, 0, HlcTimestamp::zero())
            .with_correlation(CorrelationId::new(9))
            .with_priority(Priority::High);
        assert_eq!(header.correlation_id, CorrelationId::new(9));
        assert_eq!(header.priority, Priority::High.as_u8());
    }

    #[test]
    fn test_header_size() {
        assert_eq!(std::mem::size_of::<MessageHeader>(), 256);
    }

    #[test]
    fn test_priority_conversion() {
        assert_eq!(Priority::from_u8(0), Priority::Low);
        assert_eq!(Priority::from_u8(1), Priority::Normal);
        assert_eq!(Priority::from_u8(2), Priority::High);
        assert_eq!(Priority::from_u8(3), Priority::Critical);
        assert_eq!(Priority::from_u8(255), Priority::Critical);
    }

    #[test]
    fn test_priority_roundtrip() {
        for &p in &[
            Priority::Low,
            Priority::Normal,
            Priority::High,
            Priority::Critical,
        ] {
            assert_eq!(Priority::from_u8(p.as_u8()), p);
        }
    }

    #[test]
    fn test_envelope_roundtrip() {
        let header = MessageHeader::new(42, 0, 1, 8, HlcTimestamp::now(1));
        let envelope = MessageEnvelope {
            header,
            payload: vec![1, 2, 3, 4, 5, 6, 7, 8],
        };

        let bytes = envelope.to_bytes();
        let restored = MessageEnvelope::from_bytes(&bytes).unwrap();

        assert_eq!(envelope.header.message_type, restored.header.message_type);
        assert_eq!(envelope.payload, restored.payload);
    }

    #[test]
    fn test_envelope_rejects_truncated_buffers() {
        let envelope = MessageEnvelope {
            header: MessageHeader::new(7, 0, 1, 4, HlcTimestamp::zero()),
            payload: vec![9; 4],
        };
        let bytes = envelope.to_bytes();
        assert_eq!(bytes.len(), envelope.total_size());

        // Too short to contain even the header.
        assert!(MessageEnvelope::from_bytes(&bytes[..10]).is_err());
        // Header intact but the declared payload is cut short.
        assert!(MessageEnvelope::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }

    #[test]
    fn test_envelope_rejects_bad_magic() {
        let mut header = MessageHeader::new(7, 0, 1, 0, HlcTimestamp::zero());
        header.magic = 0;
        let bytes = MessageEnvelope {
            header,
            payload: Vec::new(),
        }
        .to_bytes();
        assert!(MessageEnvelope::from_bytes(&bytes).is_err());
    }
}