1use darling::{ast, FromDeriveInput, FromField, FromMeta};
53use proc_macro::TokenStream;
54use quote::{format_ident, quote};
55use syn::{parse_macro_input, DeriveInput, ItemFn};
56
/// Container-level options for `#[derive(RingMessage)]`, parsed by darling from the
/// `#[message(...)]` attributes on the deriving struct.
///
/// `supports(struct_named)` makes darling reject anything other than a struct with named
/// fields (enums, unions, tuple and unit structs) before `derive_ring_message` runs.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(message), supports(struct_named))]
struct RingMessageArgs {
    /// Name of the deriving struct; used as the `impl` target.
    ident: syn::Ident,
    /// Generics of the deriving struct, forwarded to the generated `impl` via `split_for_impl`.
    generics: syn::Generics,
    /// The struct's named fields together with their per-field `#[message(...)]` options.
    data: ast::Data<(), RingMessageField>,
    /// Explicit `#[message(type_id = N)]` returned by `message_type()`. When absent,
    /// `derive_ring_message` derives the ID from a hash of the struct name.
    #[darling(default)]
    type_id: Option<u64>,
}
68
/// Per-field options for `#[derive(RingMessage)]`, parsed from `#[message(...)]` on a field.
///
/// Each marker is a boolean flag, e.g. `#[message(id)]`, and is `false` when not written.
#[derive(Debug, FromField)]
#[darling(attributes(message))]
struct RingMessageField {
    /// Field name; `Some` for every field of the named-field structs `RingMessageArgs` accepts.
    ident: Option<syn::Ident>,
    // darling fills in the field type, but the derive does not read it yet; the allow keeps
    // the field available without triggering the unused-field lint.
    #[allow(dead_code)]
    ty: syn::Type,
    /// Marks the field whose value `RingMessage::message_id()` returns.
    #[darling(default)]
    id: bool,
    /// Marks the field whose value `RingMessage::correlation_id()` returns.
    #[darling(default)]
    correlation: bool,
    /// Marks the field whose value `RingMessage::priority()` returns.
    #[darling(default)]
    priority: bool,
}
86
87#[proc_macro_derive(RingMessage, attributes(message))]
115pub fn derive_ring_message(input: TokenStream) -> TokenStream {
116 let input = parse_macro_input!(input as DeriveInput);
117
118 let args = match RingMessageArgs::from_derive_input(&input) {
119 Ok(args) => args,
120 Err(e) => return e.write_errors().into(),
121 };
122
123 let name = &args.ident;
124 let (impl_generics, ty_generics, where_clause) = args.generics.split_for_impl();
125
126 let type_id = args.type_id.unwrap_or_else(|| {
128 use std::collections::hash_map::DefaultHasher;
129 use std::hash::{Hash, Hasher};
130 let mut hasher = DefaultHasher::new();
131 name.to_string().hash(&mut hasher);
132 hasher.finish()
133 });
134
135 let fields = match &args.data {
137 ast::Data::Struct(fields) => fields,
138 _ => panic!("RingMessage can only be derived for structs"),
139 };
140
141 let mut id_field: Option<&syn::Ident> = None;
142 let mut correlation_field: Option<&syn::Ident> = None;
143 let mut priority_field: Option<&syn::Ident> = None;
144
145 for field in fields.iter() {
146 if field.id {
147 id_field = field.ident.as_ref();
148 }
149 if field.correlation {
150 correlation_field = field.ident.as_ref();
151 }
152 if field.priority {
153 priority_field = field.ident.as_ref();
154 }
155 }
156
157 let message_id_impl = if let Some(field) = id_field {
159 quote! { self.#field }
160 } else {
161 quote! { ::ringkernel_core::message::MessageId::new(0) }
162 };
163
164 let correlation_id_impl = if let Some(field) = correlation_field {
166 quote! { self.#field }
167 } else {
168 quote! { ::ringkernel_core::message::CorrelationId::none() }
169 };
170
171 let priority_impl = if let Some(field) = priority_field {
173 quote! { self.#field }
174 } else {
175 quote! { ::ringkernel_core::message::Priority::Normal }
176 };
177
178 let expanded = quote! {
179 impl #impl_generics ::ringkernel_core::message::RingMessage for #name #ty_generics #where_clause {
180 fn message_type() -> u64 {
181 #type_id
182 }
183
184 fn message_id(&self) -> ::ringkernel_core::message::MessageId {
185 #message_id_impl
186 }
187
188 fn correlation_id(&self) -> ::ringkernel_core::message::CorrelationId {
189 #correlation_id_impl
190 }
191
192 fn priority(&self) -> ::ringkernel_core::message::Priority {
193 #priority_impl
194 }
195
196 fn serialize(&self) -> Vec<u8> {
197 ::rkyv::to_bytes::<_, 4096>(self)
200 .map(|v| v.to_vec())
201 .unwrap_or_default()
202 }
203
204 fn deserialize(bytes: &[u8]) -> ::ringkernel_core::error::Result<Self>
205 where
206 Self: Sized,
207 {
208 use ::rkyv::Deserialize as _;
209 let archived = unsafe { ::rkyv::archived_root::<Self>(bytes) };
210 let deserialized: Self = archived.deserialize(&mut ::rkyv::Infallible)
211 .map_err(|_| ::ringkernel_core::error::RingKernelError::DeserializationError(
212 "rkyv deserialization failed".to_string()
213 ))?;
214 Ok(deserialized)
215 }
216
217 fn size_hint(&self) -> usize {
218 ::std::mem::size_of::<Self>()
219 }
220 }
221 };
222
223 TokenStream::from(expanded)
224}
225
/// Arguments of the `#[ring_kernel(...)]` attribute, parsed by darling.
#[derive(Debug, FromMeta)]
struct RingKernelArgs {
    /// Kernel identifier (required); recorded as `id` in the kernel registration.
    id: String,
    /// Kernel mode: `"persistent"` (used when omitted) or `"event_driven"`.
    #[darling(default)]
    mode: Option<String>,
    /// Grid size recorded in the registration; `ring_kernel` uses 1 when omitted.
    #[darling(default)]
    grid_size: Option<u32>,
    /// Block size recorded in the registration; `ring_kernel` uses 256 when omitted.
    #[darling(default)]
    block_size: Option<u32>,
    /// Comma-separated IDs of the kernels this kernel publishes to, e.g. `"a, b"`; each entry
    /// is whitespace-trimmed.
    #[darling(default)]
    publishes_to: Option<String>,
}
244
245#[proc_macro_attribute]
265pub fn ring_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
266 let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
267 Ok(v) => v,
268 Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
269 };
270
271 let args = match RingKernelArgs::from_list(&args) {
272 Ok(v) => v,
273 Err(e) => return TokenStream::from(e.write_errors()),
274 };
275
276 let input = parse_macro_input!(item as ItemFn);
277
278 let kernel_id = &args.id;
279 let fn_name = &input.sig.ident;
280 let fn_vis = &input.vis;
281 let fn_block = &input.block;
282 let fn_attrs = &input.attrs;
283
284 let inputs = &input.sig.inputs;
286 let output = &input.sig.output;
287
288 let (_ctx_arg, msg_arg) = if inputs.len() >= 2 {
290 let ctx = inputs.first();
291 let msg = inputs.iter().nth(1);
292 (ctx, msg)
293 } else {
294 (None, None)
295 };
296
297 let msg_type = msg_arg
299 .map(|arg| {
300 if let syn::FnArg::Typed(pat_type) = arg {
301 pat_type.ty.clone()
302 } else {
303 syn::parse_quote!(())
304 }
305 })
306 .unwrap_or_else(|| syn::parse_quote!(()));
307
308 let mode = args.mode.as_deref().unwrap_or("persistent");
310 let mode_expr = if mode == "event_driven" {
311 quote! { ::ringkernel_core::types::KernelMode::EventDriven }
312 } else {
313 quote! { ::ringkernel_core::types::KernelMode::Persistent }
314 };
315
316 let grid_size = args.grid_size.unwrap_or(1);
318 let block_size = args.block_size.unwrap_or(256);
319
320 let publishes_to_targets: Vec<String> = args
322 .publishes_to
323 .as_ref()
324 .map(|s| s.split(',').map(|t| t.trim().to_string()).collect())
325 .unwrap_or_default();
326
327 let registration_name = format_ident!(
329 "__RINGKERNEL_REGISTRATION_{}",
330 fn_name.to_string().to_uppercase()
331 );
332 let handler_name = format_ident!("{}_handler", fn_name);
333
334 let expanded = quote! {
336 #(#fn_attrs)*
338 #fn_vis async fn #fn_name #inputs #output #fn_block
339
340 #fn_vis fn #handler_name(
342 ctx: &mut ::ringkernel_core::RingContext<'_>,
343 envelope: ::ringkernel_core::message::MessageEnvelope,
344 ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::ringkernel_core::error::Result<::ringkernel_core::message::MessageEnvelope>> + Send + '_>> {
345 Box::pin(async move {
346 let msg: #msg_type = ::ringkernel_core::message::RingMessage::deserialize(&envelope.payload)?;
348
349 let response = #fn_name(ctx, msg).await;
351
352 let response_payload = ::ringkernel_core::message::RingMessage::serialize(&response);
354 let response_header = ::ringkernel_core::message::MessageHeader::new(
355 <_ as ::ringkernel_core::message::RingMessage>::message_type(),
356 envelope.header.dest_kernel,
357 envelope.header.source_kernel,
358 response_payload.len(),
359 ctx.now(),
360 ).with_correlation(envelope.header.correlation_id);
361
362 Ok(::ringkernel_core::message::MessageEnvelope {
363 header: response_header,
364 payload: response_payload,
365 })
366 })
367 }
368
369 #[allow(non_upper_case_globals)]
371 #[::inventory::submit]
372 static #registration_name: ::ringkernel_core::__private::KernelRegistration = ::ringkernel_core::__private::KernelRegistration {
373 id: #kernel_id,
374 mode: #mode_expr,
375 grid_size: #grid_size,
376 block_size: #block_size,
377 publishes_to: &[#(#publishes_to_targets),*],
378 };
379 };
380
381 TokenStream::from(expanded)
382}
383
384#[proc_macro_derive(GpuType)]
388pub fn derive_gpu_type(input: TokenStream) -> TokenStream {
389 let input = parse_macro_input!(input as DeriveInput);
390 let name = &input.ident;
391 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
392
393 let expanded = quote! {
395 const _: fn() = || {
397 fn assert_copy<T: Copy>() {}
398 assert_copy::<#name #ty_generics>();
399 };
400
401 unsafe impl #impl_generics ::bytemuck::Pod for #name #ty_generics #where_clause {}
403 unsafe impl #impl_generics ::bytemuck::Zeroable for #name #ty_generics #where_clause {}
404 };
405
406 TokenStream::from(expanded)
407}
408
/// Arguments of the `#[stencil_kernel(...)]` attribute, parsed by darling.
#[derive(Debug, FromMeta)]
struct StencilKernelArgs {
    /// Kernel identifier (required); used for the registration and the CUDA config.
    id: String,
    /// Grid dimensionality: `"1d"`, `"2d"` (used when omitted) or `"3d"`.
    #[darling(default)]
    grid: Option<String>,
    /// Square tile edge, used for whichever of `tile_width` / `tile_height` is not given
    /// (`stencil_kernel_impl` falls back to 16 when this is also absent).
    #[darling(default)]
    tile_size: Option<u32>,
    /// Tile width; overrides `tile_size` for the width.
    #[darling(default)]
    tile_width: Option<u32>,
    /// Tile height; overrides `tile_size` for the height.
    #[darling(default)]
    tile_height: Option<u32>,
    /// Halo width around each tile; `stencil_kernel_impl` uses 1 when omitted.
    #[darling(default)]
    halo: Option<u32>,
}
434
435#[proc_macro_attribute]
476pub fn stencil_kernel(attr: TokenStream, item: TokenStream) -> TokenStream {
477 let args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
478 Ok(v) => v,
479 Err(e) => return TokenStream::from(darling::Error::from(e).write_errors()),
480 };
481
482 let args = match StencilKernelArgs::from_list(&args) {
483 Ok(v) => v,
484 Err(e) => return TokenStream::from(e.write_errors()),
485 };
486
487 let input = parse_macro_input!(item as ItemFn);
488
489 stencil_kernel_impl(args, input)
491}
492
493fn stencil_kernel_impl(args: StencilKernelArgs, input: ItemFn) -> TokenStream {
494 let kernel_id = &args.id;
495 let fn_name = &input.sig.ident;
496 let fn_vis = &input.vis;
497 let fn_block = &input.block;
498 let fn_inputs = &input.sig.inputs;
499 let fn_output = &input.sig.output;
500 let fn_attrs = &input.attrs;
501
502 let grid = args.grid.as_deref().unwrap_or("2d");
504 let tile_width = args
505 .tile_width
506 .unwrap_or_else(|| args.tile_size.unwrap_or(16));
507 let tile_height = args
508 .tile_height
509 .unwrap_or_else(|| args.tile_size.unwrap_or(16));
510 let halo = args.halo.unwrap_or(1);
511
512 let cuda_const_name = format_ident!("{}_CUDA_SOURCE", fn_name.to_string().to_uppercase());
514
515 let registration_name = format_ident!(
517 "__STENCIL_KERNEL_REGISTRATION_{}",
518 fn_name.to_string().to_uppercase()
519 );
520
521 #[cfg(feature = "cuda-codegen")]
523 let cuda_source_code = {
524 use ringkernel_cuda_codegen::{transpile_stencil_kernel, Grid, StencilConfig};
525
526 let grid_type = match grid {
527 "1d" => Grid::Grid1D,
528 "2d" => Grid::Grid2D,
529 "3d" => Grid::Grid3D,
530 _ => Grid::Grid2D,
531 };
532
533 let config = StencilConfig::new(kernel_id.clone())
534 .with_grid(grid_type)
535 .with_tile_size(tile_width as usize, tile_height as usize)
536 .with_halo(halo as usize);
537
538 match transpile_stencil_kernel(&input, &config) {
539 Ok(cuda) => cuda,
540 Err(e) => {
541 return TokenStream::from(
542 syn::Error::new_spanned(
543 &input.sig.ident,
544 format!("CUDA transpilation failed: {}", e),
545 )
546 .to_compile_error(),
547 );
548 }
549 }
550 };
551
552 #[cfg(not(feature = "cuda-codegen"))]
553 let cuda_source_code = format!(
554 "// CUDA codegen not enabled. Enable 'cuda-codegen' feature.\n// Kernel: {}\n",
555 kernel_id
556 );
557
558 let expanded = quote! {
560 #(#fn_attrs)*
562 #fn_vis fn #fn_name #fn_inputs #fn_output #fn_block
563
564 #fn_vis const #cuda_const_name: &str = #cuda_source_code;
566
567 #[allow(non_upper_case_globals)]
569 #[::inventory::submit]
570 static #registration_name: ::ringkernel_core::__private::StencilKernelRegistration =
571 ::ringkernel_core::__private::StencilKernelRegistration {
572 id: #kernel_id,
573 grid: #grid,
574 tile_width: #tile_width,
575 tile_height: #tile_height,
576 halo: #halo,
577 cuda_source: #cuda_source_code,
578 };
579 };
580
581 TokenStream::from(expanded)
582}