1use bytemuck::{Pod, Zeroable};
7use rkyv::{Archive, Deserialize, Serialize};
8use zerocopy::{AsBytes, FromBytes, FromZeroes};
9
10use crate::hlc::HlcTimestamp;
11
12#[derive(
14 Debug,
15 Clone,
16 Copy,
17 PartialEq,
18 Eq,
19 Hash,
20 Default,
21 AsBytes,
22 FromBytes,
23 FromZeroes,
24 Pod,
25 Zeroable,
26 Archive,
27 Serialize,
28 Deserialize,
29)]
30#[repr(C)]
31pub struct MessageId(pub u64);
32
33impl MessageId {
34 pub const fn new(id: u64) -> Self {
36 Self(id)
37 }
38
39 pub fn generate() -> Self {
41 use std::sync::atomic::{AtomicU64, Ordering};
42 static COUNTER: AtomicU64 = AtomicU64::new(1);
43 Self(COUNTER.fetch_add(1, Ordering::Relaxed))
44 }
45
46 pub const fn inner(&self) -> u64 {
48 self.0
49 }
50}
51
52impl std::fmt::Display for MessageId {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "msg:{:016x}", self.0)
55 }
56}
57
58#[derive(
60 Debug,
61 Clone,
62 Copy,
63 PartialEq,
64 Eq,
65 Hash,
66 Default,
67 AsBytes,
68 FromBytes,
69 FromZeroes,
70 Pod,
71 Zeroable,
72 Archive,
73 Serialize,
74 Deserialize,
75)]
76#[repr(C)]
77pub struct CorrelationId(pub u64);
78
79impl CorrelationId {
80 pub const fn new(id: u64) -> Self {
82 Self(id)
83 }
84
85 pub fn generate() -> Self {
87 Self(MessageId::generate().0)
88 }
89
90 pub const fn none() -> Self {
92 Self(0)
93 }
94
95 pub const fn is_some(&self) -> bool {
97 self.0 != 0
98 }
99}
100
101#[derive(
103 Debug,
104 Clone,
105 Copy,
106 PartialEq,
107 Eq,
108 Hash,
109 Default,
110 rkyv::Archive,
111 rkyv::Serialize,
112 rkyv::Deserialize,
113)]
114#[archive(compare(PartialEq))]
115#[repr(u8)]
116pub enum Priority {
117 Low = 0,
119 #[default]
121 Normal = 1,
122 High = 2,
124 Critical = 3,
126}
127
128impl Priority {
129 pub const fn from_u8(value: u8) -> Self {
131 match value {
132 0 => Self::Low,
133 1 => Self::Normal,
134 2 => Self::High,
135 _ => Self::Critical,
136 }
137 }
138
139 pub const fn as_u8(self) -> u8 {
141 self as u8
142 }
143}
144
/// Byte-scale priority constants, spaced 64 apart across the `u8` range.
///
/// NOTE(review): these values differ from the `Priority` discriminants (0–3)
/// that `MessageHeader::with_priority` stores. `Priority::from_u8` maps every
/// constant here except `LOW` to `Critical`. Confirm which scale consumers expect.
pub mod priority {
    /// Low priority (0x00 = 0).
    pub const LOW: u8 = 0x00;
    /// Normal priority (0x40 = 64).
    pub const NORMAL: u8 = 0x40;
    /// High priority (0x80 = 128).
    pub const HIGH: u8 = 0x80;
    /// Critical priority (0xC0 = 192).
    pub const CRITICAL: u8 = 0xC0;
}
164
/// Fixed-size header that precedes the payload in every encoded `MessageEnvelope`.
///
/// `repr(C)` keeps the fields in declaration order. `align(64)` rounds the size
/// up to a multiple of 64 bytes. The `const` assertion after the `Default` impl
/// pins the total size at 256 bytes.
///
/// NOTE(review): the non-timestamp fields below total 184 bytes. The struct is
/// therefore padding-free only if the two `HlcTimestamp` fields together occupy
/// exactly 72 bytes. Any compiler-inserted padding is uninitialized, and
/// `as_bytes` exposes it. Confirm `HlcTimestamp`'s size; this type does not
/// derive `Pod`, which would have enforced the absence of padding.
#[derive(Debug, Clone, Copy)]
#[repr(C, align(64))]
pub struct MessageHeader {
    /// Format marker; `validate` requires it to equal `MessageHeader::MAGIC`.
    pub magic: u64,
    /// Header format version; `validate` rejects values above `MessageHeader::VERSION`.
    pub version: u32,
    /// Flag bits; the constructors in this file always set this to 0.
    pub flags: u32,
    /// Message id; `MessageHeader::new` fills it from `MessageId::generate`.
    pub message_id: MessageId,
    /// Correlation id; `CorrelationId::none()` unless set via `with_correlation`.
    pub correlation_id: CorrelationId,
    /// Source kernel id, as passed to `MessageHeader::new`.
    pub source_kernel: u64,
    /// Destination kernel id, as passed to `MessageHeader::new`.
    pub dest_kernel: u64,
    /// Message type tag; `MessageEnvelope::new` sets it from `RingMessage::message_type()`.
    pub message_type: u64,
    /// Priority byte; `with_priority` stores a `Priority` discriminant (0–3).
    pub priority: u8,
    /// Zero-filled reserved bytes; they also keep `payload_size` 8-byte aligned.
    pub _reserved1: [u8; 7],
    /// Payload length in bytes; `validate` rejects values above `MAX_PAYLOAD_SIZE`.
    pub payload_size: u64,
    /// Checksum slot; set to 0 here, and nothing in this file computes or checks it.
    pub checksum: u32,
    /// Zero-filled reserved word.
    pub _reserved2: u32,
    /// Timestamp passed to `MessageHeader::new`.
    pub timestamp: HlcTimestamp,
    /// Deadline; `HlcTimestamp::zero()` unless set via `with_deadline`.
    pub deadline: HlcTimestamp,
    // Reserved tail (104 bytes in total across the four arrays below), zero-filled
    // by the constructors in this file.
    pub _reserved3a: [u8; 32],
    pub _reserved3b: [u8; 32],
    pub _reserved3c: [u8; 32],
    pub _reserved3d: [u8; 8],
}
211
212impl MessageHeader {
213 pub const MAGIC: u64 = 0x52494E474B45524E; pub const VERSION: u32 = 1;
218
219 pub const MAX_PAYLOAD_SIZE: usize = 64 * 1024;
221
222 pub fn as_bytes(&self) -> &[u8] {
224 unsafe {
225 std::slice::from_raw_parts(
226 self as *const Self as *const u8,
227 std::mem::size_of::<Self>(),
228 )
229 }
230 }
231
232 pub fn read_from(bytes: &[u8]) -> Option<Self> {
234 if bytes.len() < std::mem::size_of::<Self>() {
235 return None;
236 }
237 unsafe { Some(std::ptr::read_unaligned(bytes.as_ptr() as *const Self)) }
238 }
239
240 pub fn new(
242 message_type: u64,
243 source_kernel: u64,
244 dest_kernel: u64,
245 payload_size: usize,
246 timestamp: HlcTimestamp,
247 ) -> Self {
248 Self {
249 magic: Self::MAGIC,
250 version: Self::VERSION,
251 flags: 0,
252 message_id: MessageId::generate(),
253 correlation_id: CorrelationId::none(),
254 source_kernel,
255 dest_kernel,
256 message_type,
257 priority: Priority::Normal as u8,
258 _reserved1: [0; 7],
259 payload_size: payload_size as u64,
260 checksum: 0,
261 _reserved2: 0,
262 timestamp,
263 deadline: HlcTimestamp::zero(),
264 _reserved3a: [0; 32],
265 _reserved3b: [0; 32],
266 _reserved3c: [0; 32],
267 _reserved3d: [0; 8],
268 }
269 }
270
271 pub fn validate(&self) -> bool {
273 self.magic == Self::MAGIC
274 && self.version <= Self::VERSION
275 && self.payload_size <= Self::MAX_PAYLOAD_SIZE as u64
276 }
277
278 pub fn with_correlation(mut self, correlation_id: CorrelationId) -> Self {
280 self.correlation_id = correlation_id;
281 self
282 }
283
284 pub fn with_priority(mut self, priority: Priority) -> Self {
286 self.priority = priority as u8;
287 self
288 }
289
290 pub fn with_deadline(mut self, deadline: HlcTimestamp) -> Self {
292 self.deadline = deadline;
293 self
294 }
295}
296
297impl Default for MessageHeader {
298 fn default() -> Self {
299 Self {
300 magic: Self::MAGIC,
301 version: Self::VERSION,
302 flags: 0,
303 message_id: MessageId::default(),
304 correlation_id: CorrelationId::none(),
305 source_kernel: 0,
306 dest_kernel: 0,
307 message_type: 0,
308 priority: Priority::Normal as u8,
309 _reserved1: [0; 7],
310 payload_size: 0,
311 checksum: 0,
312 _reserved2: 0,
313 timestamp: HlcTimestamp::zero(),
314 deadline: HlcTimestamp::zero(),
315 _reserved3a: [0; 32],
316 _reserved3b: [0; 32],
317 _reserved3c: [0; 32],
318 _reserved3d: [0; 8],
319 }
320 }
321}
322
323const _: () = assert!(std::mem::size_of::<MessageHeader>() == 256);
325
326pub trait RingMessage: Send + Sync + 'static {
341 fn message_type() -> u64;
343
344 fn message_id(&self) -> MessageId;
346
347 fn correlation_id(&self) -> CorrelationId {
349 CorrelationId::none()
350 }
351
352 fn priority(&self) -> Priority {
354 Priority::Normal
355 }
356
357 fn serialize(&self) -> Vec<u8>;
359
360 fn deserialize(bytes: &[u8]) -> crate::error::Result<Self>
362 where
363 Self: Sized;
364
365 fn size_hint(&self) -> usize
367 where
368 Self: Sized,
369 {
370 std::mem::size_of::<Self>()
371 }
372}
373
374#[derive(Debug, Clone)]
376pub struct MessageEnvelope {
377 pub header: MessageHeader,
379 pub payload: Vec<u8>,
381}
382
383impl MessageEnvelope {
384 pub fn new<M: RingMessage>(
386 message: &M,
387 source_kernel: u64,
388 dest_kernel: u64,
389 timestamp: HlcTimestamp,
390 ) -> Self {
391 let payload = message.serialize();
392 let header = MessageHeader::new(
393 M::message_type(),
394 source_kernel,
395 dest_kernel,
396 payload.len(),
397 timestamp,
398 )
399 .with_correlation(message.correlation_id())
400 .with_priority(message.priority());
401
402 Self { header, payload }
403 }
404
405 pub fn total_size(&self) -> usize {
407 std::mem::size_of::<MessageHeader>() + self.payload.len()
408 }
409
410 pub fn to_bytes(&self) -> Vec<u8> {
412 let mut bytes = Vec::with_capacity(self.total_size());
413 bytes.extend_from_slice(self.header.as_bytes());
414 bytes.extend_from_slice(&self.payload);
415 bytes
416 }
417
418 pub fn from_bytes(bytes: &[u8]) -> crate::error::Result<Self> {
420 if bytes.len() < std::mem::size_of::<MessageHeader>() {
421 return Err(crate::error::RingKernelError::DeserializationError(
422 "buffer too small for header".to_string(),
423 ));
424 }
425
426 let header_bytes = &bytes[..std::mem::size_of::<MessageHeader>()];
427 let header = MessageHeader::read_from(header_bytes).ok_or_else(|| {
428 crate::error::RingKernelError::DeserializationError("invalid header".to_string())
429 })?;
430
431 if !header.validate() {
432 return Err(crate::error::RingKernelError::ValidationError(
433 "header validation failed".to_string(),
434 ));
435 }
436
437 let payload_start = std::mem::size_of::<MessageHeader>();
438 let payload_end = payload_start + header.payload_size as usize;
439
440 if bytes.len() < payload_end {
441 return Err(crate::error::RingKernelError::DeserializationError(
442 "buffer too small for payload".to_string(),
443 ));
444 }
445
446 let payload = bytes[payload_start..payload_end].to_vec();
447
448 Ok(Self { header, payload })
449 }
450
451 pub fn empty(source_kernel: u64, dest_kernel: u64, timestamp: HlcTimestamp) -> Self {
453 let header = MessageHeader::new(0, source_kernel, dest_kernel, 0, timestamp);
454 Self {
455 header,
456 payload: Vec::new(),
457 }
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_id_generation() {
        let id1 = MessageId::generate();
        let id2 = MessageId::generate();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_message_id_display() {
        assert_eq!(MessageId::new(0xff).to_string(), "msg:00000000000000ff");
    }

    #[test]
    fn test_correlation_id() {
        assert!(!CorrelationId::none().is_some());
        assert!(!CorrelationId::default().is_some());
        assert!(CorrelationId::new(7).is_some());
        // The shared counter starts at 1, so generated ids are never the sentinel.
        assert!(CorrelationId::generate().is_some());
    }

    #[test]
    fn test_header_validation() {
        let header = MessageHeader::new(1, 0, 1, 100, HlcTimestamp::zero());
        assert!(header.validate());

        let mut invalid = header;
        invalid.magic = 0;
        assert!(!invalid.validate());

        let mut future_version = header;
        future_version.version = MessageHeader::VERSION + 1;
        assert!(!future_version.validate());

        let mut at_limit = header;
        at_limit.payload_size = MessageHeader::MAX_PAYLOAD_SIZE as u64;
        assert!(at_limit.validate());

        let mut oversized = header;
        oversized.payload_size = MessageHeader::MAX_PAYLOAD_SIZE as u64 + 1;
        assert!(!oversized.validate());
    }

    #[test]
    fn test_header_size() {
        assert_eq!(std::mem::size_of::<MessageHeader>(), 256);
    }

    #[test]
    fn test_priority_conversion() {
        assert_eq!(Priority::from_u8(0), Priority::Low);
        assert_eq!(Priority::from_u8(1), Priority::Normal);
        assert_eq!(Priority::from_u8(2), Priority::High);
        assert_eq!(Priority::from_u8(3), Priority::Critical);
        assert_eq!(Priority::from_u8(255), Priority::Critical);

        for p in [Priority::Low, Priority::Normal, Priority::High, Priority::Critical] {
            assert_eq!(Priority::from_u8(p.as_u8()), p);
        }
        assert_eq!(Priority::default(), Priority::Normal);
    }

    #[test]
    fn test_envelope_roundtrip() {
        let header = MessageHeader::new(42, 0, 1, 8, HlcTimestamp::now(1))
            .with_correlation(CorrelationId::new(7))
            .with_priority(Priority::High);
        let envelope = MessageEnvelope {
            header,
            payload: vec![1, 2, 3, 4, 5, 6, 7, 8],
        };

        let bytes = envelope.to_bytes();
        assert_eq!(bytes.len(), envelope.total_size());

        let restored = MessageEnvelope::from_bytes(&bytes).unwrap();

        assert_eq!(envelope.header.message_type, restored.header.message_type);
        assert_eq!(envelope.header.message_id, restored.header.message_id);
        assert_eq!(restored.header.correlation_id, CorrelationId::new(7));
        assert_eq!(restored.header.priority, Priority::High.as_u8());
        assert_eq!(restored.header.source_kernel, 0);
        assert_eq!(restored.header.dest_kernel, 1);
        assert_eq!(restored.header.payload_size, 8);
        assert_eq!(envelope.payload, restored.payload);
    }

    #[test]
    fn test_envelope_ignores_trailing_bytes() {
        let envelope = MessageEnvelope {
            header: MessageHeader::new(3, 0, 1, 2, HlcTimestamp::zero()),
            payload: vec![0xAB, 0xCD],
        };
        let mut bytes = envelope.to_bytes();
        bytes.extend_from_slice(&[0xFF; 5]);

        let restored = MessageEnvelope::from_bytes(&bytes).unwrap();
        assert_eq!(restored.payload, vec![0xAB, 0xCD]);
    }

    #[test]
    fn test_envelope_rejects_truncated_input() {
        let envelope = MessageEnvelope {
            header: MessageHeader::new(1, 0, 1, 4, HlcTimestamp::zero()),
            payload: vec![9; 4],
        };
        let bytes = envelope.to_bytes();
        let header_len = std::mem::size_of::<MessageHeader>();

        assert!(MessageEnvelope::from_bytes(&bytes).is_ok());
        assert!(MessageEnvelope::from_bytes(&bytes[..bytes.len() - 1]).is_err());
        assert!(MessageEnvelope::from_bytes(&bytes[..header_len - 1]).is_err());
        assert!(MessageEnvelope::from_bytes(&[]).is_err());
    }

    #[test]
    fn test_envelope_rejects_invalid_header() {
        let mut header = MessageHeader::new(1, 0, 1, 0, HlcTimestamp::zero());
        header.magic = 0;
        let bytes = MessageEnvelope {
            header,
            payload: Vec::new(),
        }
        .to_bytes();

        assert!(MessageEnvelope::from_bytes(&bytes).is_err());
    }
}