Skip to main content

ringkernel_wgpu/
lib.rs

1//! WebGPU Backend for RingKernel
2//!
3//! This crate provides cross-platform GPU support via WebGPU (wgpu).
4//! Works on Vulkan, Metal, DX12, and browser environments.
5//!
6//! # Features
7//!
8//! - Cross-platform GPU access (Windows, macOS, Linux, Web)
9//! - Event-driven execution model (WebGPU limitation)
10//! - WGSL shader language support
11//!
12//! # Limitations
13//!
14//! - No true persistent kernels (WebGPU doesn't support cooperative groups)
15//! - No 64-bit atomics in WGSL
16//! - Event-driven execution only
17//!
18//! # Example
19//!
20//! ```ignore
21//! use ringkernel_wgpu::WgpuRuntime;
22//!
23//! #[tokio::main]
24//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
25//!     let runtime = WgpuRuntime::new().await?;
26//!     let kernel = runtime.launch("compute", Default::default()).await?;
27//!     kernel.activate().await?;
28//!     Ok(())
29//! }
30//! ```
31
32#![warn(missing_docs)]
33
// Real backend implementation. Every module is compiled only when the
// `wgpu` cargo feature is enabled; declaration order is kept stable.
#[cfg(feature = "wgpu")]
mod adapter;
#[cfg(feature = "wgpu")]
mod kernel;
#[cfg(feature = "wgpu")]
mod memory;
#[cfg(feature = "wgpu")]
mod runtime;
#[cfg(feature = "wgpu")]
mod shader;

// Public surface of the real backend, exposed at the crate root.
#[cfg(feature = "wgpu")]
pub use self::{
    adapter::WgpuAdapter, kernel::WgpuKernel, memory::WgpuBuffer, runtime::WgpuRuntime,
};
53
54// Stub implementation when wgpu feature is disabled
55#[cfg(not(feature = "wgpu"))]
56mod stub {
57    ringkernel_core::unavailable_backend!(
58        WgpuRuntime,
59        ringkernel_core::runtime::Backend::Wgpu,
60        "wgpu"
61    );
62}
63
64#[cfg(not(feature = "wgpu"))]
65pub use stub::WgpuRuntime;
66
/// Reports whether WebGPU can be used on this machine at runtime.
///
/// Builds a fresh `wgpu::Instance` from the default `InstanceDescriptor` and
/// returns `true` when enumerating adapters across all backends yields at
/// least one adapter.
#[cfg(feature = "wgpu")]
pub fn is_wgpu_available() -> bool {
    let descriptor = wgpu::InstanceDescriptor::default();
    let instance = wgpu::Instance::new(descriptor);
    let adapters = instance.enumerate_adapters(wgpu::Backends::all());
    !adapters.is_empty()
}

/// Reports whether WebGPU can be used on this machine at runtime.
///
/// This build was compiled without the `wgpu` feature, so the answer is
/// always `false`.
#[cfg(not(feature = "wgpu"))]
pub fn is_wgpu_available() -> bool {
    false
}
82
/// WGSL shader template for a ring kernel.
///
/// Layout expected by the template:
/// - `@group(0) @binding(0)`: `control`, a `read_write` storage `ControlBlock`
///   made entirely of `u32` fields. 64-bit quantities are split into
///   `_lo`/`_hi` halves because WGSL lacks 64-bit integers.
/// - `@group(0) @binding(1)` / `@binding(2)`: `input_queue` / `output_queue`,
///   `read_write` storage `array<u32>` buffers.
/// - Entry point `main`, `@compute @workgroup_size(256)`.
///
/// User kernel code is meant to go where the `// USER_KERNEL_CODE` marker
/// line sits.
///
/// The processed-message counter is bumped by exactly one invocation per
/// dispatch: global invocation `(0, 0, 0)`. Gating on the *local* invocation
/// index instead would let invocation 0 of every workgroup perform a
/// non-atomic read-modify-write on the same storage words. That is a data
/// race as soon as more than one workgroup is dispatched.
///
/// NOTE(review): `control` lives in `read_write` storage, so the early return
/// on `control.is_active` is non-uniform control flow under WGSL uniformity
/// analysis. Inserted user code that calls `workgroupBarrier()` after it may
/// fail validation — confirm before relying on barriers.
pub const RING_KERNEL_WGSL_TEMPLATE: &str = r#"
// RingKernel WGSL Template
// Generated by ringkernel-wgpu

// Control block binding
struct ControlBlock {
    is_active: u32,
    should_terminate: u32,
    has_terminated: u32,
    _pad1: u32,
    messages_processed_lo: u32,
    messages_processed_hi: u32,
    messages_in_flight_lo: u32,
    messages_in_flight_hi: u32,
    input_head_lo: u32,
    input_head_hi: u32,
    input_tail_lo: u32,
    input_tail_hi: u32,
    output_head_lo: u32,
    output_head_hi: u32,
    output_tail_lo: u32,
    output_tail_hi: u32,
    input_capacity: u32,
    output_capacity: u32,
    input_mask: u32,
    output_mask: u32,
    // HLC state (split for WGSL u32 limitation)
    hlc_physical_lo: u32,
    hlc_physical_hi: u32,
    hlc_logical_lo: u32,
    hlc_logical_hi: u32,
    last_error: u32,
    error_count: u32,
}

@group(0) @binding(0) var<storage, read_write> control: ControlBlock;
@group(0) @binding(1) var<storage, read_write> input_queue: array<u32>;
@group(0) @binding(2) var<storage, read_write> output_queue: array<u32>;

// Thread identification
var<private> thread_id: u32;
var<private> workgroup_id: u32;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>,
        @builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>) {
    thread_id = local_id.x;
    workgroup_id = wg_id.x;

    // Check if kernel should process
    if (control.is_active == 0u) {
        return;
    }

    // User kernel code will be inserted here
    // USER_KERNEL_CODE

    // Update message counter (simplified without 64-bit atomics).
    // Only global invocation (0, 0, 0) performs this non-atomic
    // read-modify-write, so workgroups never race on the counter.
    if (all(global_id == vec3<u32>(0u, 0u, 0u))) {
        control.messages_processed_lo = control.messages_processed_lo + 1u;
        if (control.messages_processed_lo == 0u) {
            control.messages_processed_hi = control.messages_processed_hi + 1u;
        }
    }
}
"#;