ringkernel_derive/
lib.rs

1//! Procedural macros for RingKernel.
2//!
3//! This crate provides the following macros:
4//!
5//! - `#[derive(RingMessage)]` - Implement the RingMessage trait for message types
6//! - `#[ring_kernel]` - Define a ring kernel handler
//! - `#[stencil_kernel]` - Define a GPU stencil kernel (with `cuda-codegen` feature)
//! - `#[derive(GpuType)]` - Implement `bytemuck::Pod` and `bytemuck::Zeroable` for a type
8//!
9//! # Example
10//!
11//! ```ignore
12//! use ringkernel_derive::{RingMessage, ring_kernel};
13//!
14//! #[derive(RingMessage)]
15//! struct AddRequest {
16//!     #[message(id)]
17//!     id: MessageId,
18//!     a: f32,
19//!     b: f32,
20//! }
21//!
22//! #[derive(RingMessage)]
23//! struct AddResponse {
24//!     #[message(id)]
25//!     id: MessageId,
26//!     result: f32,
27//! }
28//!
29//! #[ring_kernel(id = "adder")]
30//! async fn process(ctx: &mut RingContext, req: AddRequest) -> AddResponse {
31//!     AddResponse {
32//!         id: MessageId::generate(),
33//!         result: req.a + req.b,
34//!     }
35//! }
36//! ```
37//!
38//! # Stencil Kernels (with `cuda-codegen` feature)
39//!
40//! ```ignore
41//! use ringkernel_derive::stencil_kernel;
42//! use ringkernel_cuda_codegen::GridPos;
43//!
44//! #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
45//! fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
46//!     let curr = p[pos.idx()];
47//!     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
48//!     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
49//! }
50//! ```
51
52use darling::{ast, FromDeriveInput, FromField, FromMeta};
53use proc_macro::TokenStream;
54use quote::{format_ident, quote};
55use syn::{parse_macro_input, DeriveInput, ItemFn};
56
57/// Attributes for the RingMessage derive macro.
58#[derive(Debug, FromDeriveInput)]
59#[darling(attributes(message), supports(struct_named))]
60struct RingMessageArgs {
61    ident: syn::Ident,
62    generics: syn::Generics,
63    data: ast::Data<(), RingMessageField>,
64    /// Optional explicit message type ID.
65    #[darling(default)]
66    type_id: Option<u64>,
67}
68
69/// Field attributes for RingMessage.
70#[derive(Debug, FromField)]
71#[darling(attributes(message))]
72struct RingMessageField {
73    ident: Option<syn::Ident>,
74    #[allow(dead_code)]
75    ty: syn::Type,
76    /// Mark this field as the message ID.
77    #[darling(default)]
78    id: bool,
79    /// Mark this field as the correlation ID.
80    #[darling(default)]
81    correlation: bool,
82    /// Mark this field as the priority.
83    #[darling(default)]
84    priority: bool,
85}
86
87/// Derive macro for implementing the RingMessage trait.
88///
89/// # Attributes
90///
91/// On the struct:
92/// - `#[message(type_id = 123)]` - Set explicit message type ID
93///
94/// On fields:
95/// - `#[message(id)]` - Mark as message ID field
96/// - `#[message(correlation)]` - Mark as correlation ID field
97/// - `#[message(priority)]` - Mark as priority field
98///
99/// # Example
100///
101/// ```ignore
102/// #[derive(RingMessage)]
103/// #[message(type_id = 1)]
104/// struct MyMessage {
105///     #[message(id)]
106///     id: MessageId,
107///     #[message(correlation)]
108///     correlation: CorrelationId,
109///     #[message(priority)]
110///     priority: Priority,
111///     payload: Vec<u8>,
112/// }
113/// ```
114#[proc_macro_derive(RingMessage, attributes(message))]
115pub fn derive_ring_message(input: TokenStream) -> TokenStream {
116    let input = parse_macro_input!(input as DeriveInput);
117
118    let args = match RingMessageArgs::from_derive_input(&input) {
119        Ok(args) => args,
120        Err(e) => return e.write_errors().into(),
121    };
122
123    let name = &args.ident;
124    let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
125
126    // Calculate type ID from name hash if not specified
127    let type_id = args.type_id.unwrap_or_else(|| {
128        use std::collections::hash_map::DefaultHasher;
129        use std::hash::{Hash, Hasher};
130        let mut hasher = DefaultHasher::new();
131        name.to_string().hash(&mut hasher);
132        hasher.finish()
133    });
134
135    // Find annotated fields
136    let fields = match &args.data {
137        ast::Data::Struct(fields) => fields,
138        _ => panic!("RingMessage can only be derived for structs"),
139    };
140
141    let mut id_field: Option<&syn::Ident> = None;
142    let mut correlation_field: Option<&syn::Ident> = None;
143    let mut priority_field: Option<&syn::Ident> = None;
144
145    for field in fields.iter() {
146        if field.id {
147            id_field = field.ident.as_ref();
148        }
149        if field.correlation {
150            correlation_field = field.ident.as_ref();
151        }
152        if field.priority {
153            priority_field = field.ident.as_ref();
154        }
155    }
156
157    // Generate message_id method
158    let message_id_impl = if let Some(field) = id_field {
159        quote! { self.#field }
160    } else {
161        quote! { ::ringkernel_core::message::MessageId::new(0) }
162    };
163
164    // Generate correlation_id method
165    let correlation_id_impl = if let Some(field) = correlation_field {
166        quote! { self.#field }
167    } else {
168        quote! { ::ringkernel_core::message::CorrelationId::none() }
169    };
170
171    // Generate priority method
172    let priority_impl = if let Some(field) = priority_field {
173        quote! { self.#field }
174    } else {
175        quote! { ::ringkernel_core::message::Priority::Normal }
176    };
177
178    let expanded = quote! {
179        impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
180            fn message_type() -> u64 {
181                #type_id
182            }
183
184            fn message_id(&self) -> ::ringkernel_core::message::MessageId {
185                #message_id_impl
186            }
187
188            fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
189                #correlation_id_impl
190            }
191
192            fn priority(&self) -> ::ringkernel_core::message::Priority {
193                #priority_impl
194            }
195
196            fn serialize(&self) -> Vec<u8> {
197                // Use rkyv for serialization with a 4KB scratch buffer
198                // For larger payloads, rkyv will allocate as needed
199                ::rkyv::to_bytes::<_, 4096>(self)
200                    .map(|v| v.to_vec())
201                    .unwrap_or_default()
202            }
203
204            fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
205            where
206                Self: Sized,
207            {
208                use ::rkyv::Deserialize as _;
209                let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
210                let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
211                    .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
212                        "rkyv deserialization failed".to_string()
213                    ))?;
214                Ok(deserialized)
215            }
216
217            fn size_hint(&self) -> usize {
218                ::std::mem::size_of::<Self>()
219            }
220        }
221    };
222
223    TokenStream::from(expanded)
224}
225
226/// Attributes for the ring_kernel macro.
227#[derive(Debug, FromMeta)]
228struct RingKernelArgs {
229    /// Kernel identifier.
230    id: String,
231    /// Execution mode (persistent or event_driven).
232    #[darling(default)]
233    mode: Option<String>,
234    /// Grid size.
235    #[darling(default)]
236    grid_size: Option<u32>,
237    /// Block size.
238    #[darling(default)]
239    block_size: Option<u32>,
240    /// Target kernels this kernel publishes to.
241    #[darling(default)]
242    publishes_to: Option<String>,
243}
244
245/// Attribute macro for defining ring kernel handlers.
246///
247/// # Attributes
248///
249/// - `id` (required) - Unique kernel identifier
250/// - `mode` - Execution mode: "persistent" (default) or "event_driven"
251/// - `grid_size` - Number of blocks (default: 1)
252/// - `block_size` - Threads per block (default: 256)
253/// - `publishes_to` - Comma-separated list of target kernel IDs
254///
255/// # Example
256///
257/// ```ignore
258/// #[ring_kernel(id = "processor", mode = "persistent", block_size = 128)]
259/// async fn handle(ctx: &mut RingContext, msg: MyMessage) -> MyResponse {
260///     // Process message
261///     MyResponse { ... }
262/// }
263/// ```
264#[proc_macro_attribute]
265pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
266    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
267        Ok(v) => v,
268        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
269    };
270
271    let args = match RingKernelArgs::from_list(&args) {
272        Ok(v) => v,
273        Err(e) => return TokenStream::from(e.write_errors()),
274    };
275
276    let input = parse_macro_input!(item as ItemFn);
277
278    let kernel_id = &args.id;
279    let fn_name = &input.sig.ident;
280    let fn_vis = &input.vis;
281    let fn_block = &input.block;
282    let fn_attrs = &input.attrs;
283
284    // Parse function signature
285    let inputs = &input.sig.inputs;
286    let output = &input.sig.output;
287
288    // Extract context and message types from signature
289    let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
290        let ctx = inputs.first();
291        let msg = inputs.iter().nth(1);
292        (ctx, msg)
293    } else {
294        (None, None)
295    };
296
297    // Get message type
298    let msg_type = msg_arg
299        .map(|arg| {
300            if let syn::FnArg::Typed(pat_type) = arg {
301                pat_type.ty.clone()
302            } else {
303                syn::parse_quote!(())
304            }
305        })
306        .unwrap_or_else(|| syn::parse_quote!(()));
307
308    // Generate kernel mode
309    let mode = args.mode.as_deref().unwrap_or("persistent");
310    let mode_expr = if mode == "event_driven" {
311        quote! { ::ringkernel_core::types::KernelMode::EventDriven }
312    } else {
313        quote! { ::ringkernel_core::types::KernelMode::Persistent }
314    };
315
316    // Generate grid/block size
317    let grid_size = args.grid_size.unwrap_or(1);
318    let block_size = args.block_size.unwrap_or(256);
319
320    // Parse publishes_to into a list of target kernel IDs
321    let publishes_to_targets: Vec<String> = args
322        .publishes_to
323        .as_ref()
324        .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
325        .unwrap_or_default();
326
327    // Generate registration struct name
328    let registration_name = format_ident!(
329        "__RINGKERNEL_REGISTRATION_{}",
330        fn_name.to_string().to_uppercase()
331    );
332    let handler_name = format_ident!("{}_handler", fn_name);
333
334    // Generate the expanded code
335    let expanded = quote! {
336        // Original function (preserved for documentation/testing)
337        #(#fn_attrs)*
338        #fn_vis async fn #fn_name #inputs #output #fn_block
339
340        // Kernel handler wrapper
341        #fn_vis fn #handler_name(
342            ctx: &mut ::ringkernel_core::RingContext<'_>,
343            envelope: ::ringkernel_core::message::MessageEnvelope,
344        ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + '_>> {
345            Box::pin(async move {
346                // Deserialize input message
347                let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;
348
349                // Call the actual handler
350                let response = #fn_name(ctx, msg).await;
351
352                // Serialize response
353                let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
354                let response_header = ::ringkernel_core::message::MessageHeader::new(
355                    <_ as ::ringkernel_core::message::RingMessage>::message_type(),
356                    envelope.header.dest_kernel,
357                    envelope.header.source_kernel,
358                    response_payload.len(),
359                    ctx.now(),
360                ).with_correlation(envelope.header.correlation_id);
361
362                Ok(::ringkernel_core::message::MessageEnvelope {
363                    header: response_header,
364                    payload: response_payload,
365                })
366            })
367        }
368
369        // Kernel registration
370        #[allow(non_upper_case_globals)]
371        #[::inventory::submit]
372        static #registration_name: ::ringkernel_core::__private::KernelRegistration = ::ringkernel_core::__private::KernelRegistration {
373            id: #kernel_id,
374            mode: #mode_expr,
375            grid_size: #grid_size,
376            block_size: #block_size,
377            publishes_to: &[#(#publishes_to_targets),*],
378        };
379    };
380
381    TokenStream::from(expanded)
382}
383
384/// Derive macro for GPU-compatible types.
385///
386/// Ensures the type has a stable memory layout suitable for GPU transfer.
387#[proc_macro_derive(GpuType)]
388pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
389    let input = parse_macro_input!(input as DeriveInput);
390    let name = &input.ident;
391    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
392
393    // Generate assertions for GPU compatibility
394    let expanded = quote! {
395        // Verify type is Copy (required for GPU transfer)
396        const _: fn() = || {
397            fn assert_copy<T: Copy>() {}
398            assert_copy::<#name #ty_generics>();
399        };
400
401        // Verify type is Pod (plain old data)
402        unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
403        unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
404    };
405
406    TokenStream::from(expanded)
407}
408
409// ============================================================================
410// Stencil Kernel Macro (requires cuda-codegen feature)
411// ============================================================================
412
413/// Attributes for the stencil_kernel macro.
414#[derive(Debug, FromMeta)]
415struct StencilKernelArgs {
416    /// Kernel identifier.
417    id: String,
418    /// Grid dimensionality: "1d", "2d", or "3d".
419    #[darling(default)]
420    grid: Option<String>,
421    /// Tile/block size (single value for square tiles).
422    #[darling(default)]
423    tile_size: Option<u32>,
424    /// Tile width (for non-square tiles).
425    #[darling(default)]
426    tile_width: Option<u32>,
427    /// Tile height (for non-square tiles).
428    #[darling(default)]
429    tile_height: Option<u32>,
430    /// Halo/ghost cell width (stencil radius).
431    #[darling(default)]
432    halo: Option<u32>,
433}
434
435/// Attribute macro for defining stencil kernels that transpile to CUDA.
436///
437/// This macro generates CUDA C code from Rust stencil kernel functions at compile time.
438/// The generated CUDA source is embedded in the binary and can be compiled at runtime
439/// using NVRTC.
440///
441/// # Attributes
442///
443/// - `id` (required) - Unique kernel identifier
444/// - `grid` - Grid dimensionality: "1d", "2d" (default), or "3d"
445/// - `tile_size` - Tile/block size (default: 16)
446/// - `tile_width` / `tile_height` - Non-square tile dimensions
447/// - `halo` - Stencil radius / ghost cell width (default: 1)
448///
449/// # Supported Rust Subset
450///
451/// - Primitives: `f32`, `f64`, `i32`, `u32`, `i64`, `u64`, `bool`
452/// - Slices: `&[T]`, `&mut [T]`
453/// - Arithmetic: `+`, `-`, `*`, `/`, `%`
454/// - Comparisons: `<`, `>`, `<=`, `>=`, `==`, `!=`
455/// - Let bindings: `let x = expr;`
456/// - If/else: `if cond { a } else { b }`
457/// - Stencil intrinsics via `GridPos`
458///
459/// # Example
460///
461/// ```ignore
462/// use ringkernel_derive::stencil_kernel;
463/// use ringkernel_cuda_codegen::GridPos;
464///
465/// #[stencil_kernel(id = "fdtd", grid = "2d", tile_size = 16, halo = 1)]
466/// fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
467///     let curr = p[pos.idx()];
468///     let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
469///     p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
470/// }
471///
472/// // Access generated CUDA source:
473/// assert!(FDTD_CUDA_SOURCE.contains("__global__"));
474/// ```
475#[proc_macro_attribute]
476pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
477    let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
478        Ok(v) => v,
479        Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
480    };
481
482    let args = match StencilKernelArgs::from_list(&args) {
483        Ok(v) => v,
484        Err(e) => return TokenStream::from(e.write_errors()),
485    };
486
487    let input = parse_macro_input!(item as ItemFn);
488
489    // Generate the stencil kernel code
490    stencil_kernel_impl(args, input)
491}
492
493fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
494    let kernel_id = &args.id;
495    let fn_name = &input.sig.ident;
496    let fn_vis = &input.vis;
497    let fn_block = &input.block;
498    let fn_inputs = &input.sig.inputs;
499    let fn_output = &input.sig.output;
500    let fn_attrs = &input.attrs;
501
502    // Parse configuration
503    let grid = args.grid.as_deref().unwrap_or("2d");
504    let tile_width = args
505        .tile_width
506        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
507    let tile_height = args
508        .tile_height
509        .unwrap_or_else(|| args.tile_size.unwrap_or(16));
510    let halo = args.halo.unwrap_or(1);
511
512    // Generate CUDA source constant name
513    let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());
514
515    // Generate registration name
516    let registration_name = format_ident!(
517        "__STENCIL_KERNEL_REGISTRATION_{}",
518        fn_name.to_string().to_uppercase()
519    );
520
521    // Transpile to CUDA (if feature enabled)
522    #[cfg(feature = "cuda-codegen")]
523    let cuda_source_code = {
524        use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
525
526        let grid_type = match grid {
527            "1d" => Grid::Grid1D,
528            "2d" => Grid::Grid2D,
529            "3d" => Grid::Grid3D,
530            _ => Grid::Grid2D,
531        };
532
533        let config = StencilConfig::new(kernel_id.clone())
534            .with_grid(grid_type)
535            .with_tile_size(tile_width as usize, tile_height as usize)
536            .with_halo(halo as usize);
537
538        match transpile_stencil_kernel(&input, &config) {
539            Ok(cuda) => cuda,
540            Err(e) => {
541                return TokenStream::from(
542                    syn::Error::new_spanned(
543                        &input.sig.ident,
544                        format!("CUDA transpilation failed: {}", e),
545                    )
546                    .to_compile_error(),
547                );
548            }
549        }
550    };
551
552    #[cfg(not(feature = "cuda-codegen"))]
553    let cuda_source_code = format!(
554        "// CUDA codegen not enabled. Enable 'cuda-codegen' feature.\n// Kernel: {}\n",
555        kernel_id
556    );
557
558    // Generate the expanded code
559    let expanded = quote! {
560        // Original function (for documentation/testing/CPU fallback)
561        #(#fn_attrs)*
562        #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
563
564        /// Generated CUDA source code for this stencil kernel.
565        #fn_vis const #cuda_const_name: &str = #cuda_source_code;
566
567        /// Stencil kernel registration for runtime discovery.
568        #[allow(non_upper_case_globals)]
569        #[::inventory::submit]
570        static #registration_name: ::ringkernel_core::__private::StencilKernelRegistration =
571            ::ringkernel_core::__private::StencilKernelRegistration {
572                id: #kernel_id,
573                grid: #grid,
574                tile_width: #tile_width,
575                tile_height: #tile_height,
576                halo: #halo,
577                cuda_source: #cuda_source_code,
578            };
579    };
580
581    TokenStream::from(expanded)
582}