ringkernel_core/
context.rs

//! Ring context providing a GPU-intrinsics facade for kernel handlers.
//!
//! [`RingContext`] provides a unified interface for GPU operations,
//! abstracting over the different backends (CUDA, Metal, WebGPU, CPU).

use crate::hlc::{HlcClock, HlcTimestamp};
use crate::message::MessageEnvelope;
use crate::types::{BlockId, Dim3, FenceScope, GlobalThreadId, MemoryOrder, ThreadId, WarpId};

/// GPU intrinsics facade for kernel handlers.
///
/// This struct provides access to GPU-specific operations such as thread
/// identification, synchronization, and atomic operations. The actual
/// implementation varies by backend.
///
/// # Lifetime
///
/// The context is borrowed for the duration of the kernel handler execution.
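///
/// # Example
///
/// A minimal sketch of constructing a CPU-backed context for host-side
/// testing (mirrors the unit tests below; imports are elided, hence the
/// `ignore` fence):
///
/// ```ignore
/// let clock = HlcClock::new(1);
/// let ctx = RingContext::new(
///     ThreadId::new_1d(0),
///     BlockId::new_1d(0),
///     Dim3::new_1d(256), // 256 threads per block
///     Dim3::new_1d(1),   // a single block
///     &clock,
///     1, // kernel ID
///     ContextBackend::Cpu,
/// );
/// assert_eq!(ctx.lane_id(), 0);
/// ```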
pub struct RingContext<'a> {
    /// Thread identity within block.
    pub thread_id: ThreadId,
    /// Block identity within grid.
    pub block_id: BlockId,
    /// Block dimensions.
    pub block_dim: Dim3,
    /// Grid dimensions.
    pub grid_dim: Dim3,
    /// HLC clock instance.
    clock: &'a HlcClock,
    /// Kernel ID.
    kernel_id: u64,
    /// Backend implementation.
    backend: ContextBackend,
}

/// Backend-specific context implementation.
#[derive(Debug, Clone)]
pub enum ContextBackend {
    /// CPU backend (for testing).
    Cpu,
    /// CUDA backend.
    Cuda,
    /// Metal backend.
    Metal,
    /// WebGPU backend.
    Wgpu,
}

impl<'a> RingContext<'a> {
    /// Create a new context.
    pub fn new(
        thread_id: ThreadId,
        block_id: BlockId,
        block_dim: Dim3,
        grid_dim: Dim3,
        clock: &'a HlcClock,
        kernel_id: u64,
        backend: ContextBackend,
    ) -> Self {
        Self {
            thread_id,
            block_id,
            block_dim,
            grid_dim,
            clock,
            kernel_id,
            backend,
        }
    }

    // === Thread Identity ===

    /// Get thread ID within block.
    #[inline]
    pub fn thread_id(&self) -> ThreadId {
        self.thread_id
    }

    /// Get block ID within grid.
    #[inline]
    pub fn block_id(&self) -> BlockId {
        self.block_id
    }

    /// Get global thread ID across all blocks.
    #[inline]
    pub fn global_thread_id(&self) -> GlobalThreadId {
        GlobalThreadId::from_block_thread(self.block_id, self.thread_id, self.block_dim)
    }

    /// Get warp ID within block.
    #[inline]
    pub fn warp_id(&self) -> WarpId {
        let linear = self
            .thread_id
            .linear_for_dim(self.block_dim.x, self.block_dim.y);
        WarpId::from_thread_linear(linear)
    }

    /// Get lane ID within warp (0-31).
    #[inline]
    pub fn lane_id(&self) -> u32 {
        let linear = self
            .thread_id
            .linear_for_dim(self.block_dim.x, self.block_dim.y);
        WarpId::lane_id(linear)
    }

    /// Get block dimensions.
    #[inline]
    pub fn block_dim(&self) -> Dim3 {
        self.block_dim
    }

    /// Get grid dimensions.
    #[inline]
    pub fn grid_dim(&self) -> Dim3 {
        self.grid_dim
    }

    /// Get kernel ID.
    #[inline]
    pub fn kernel_id(&self) -> u64 {
        self.kernel_id
    }

    // === Synchronization ===

    /// Synchronize all threads in the block.
    ///
    /// All threads in the block must reach this barrier before any
    /// thread can proceed past it.
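    ///
    /// A minimal usage sketch (the block-shared buffer is hypothetical; this
    /// module does not define one):
    ///
    /// ```ignore
    /// // shared[tid] = partial;       // each thread publishes its slot
    /// ctx.sync_threads();             // barrier: all writes are now visible
    /// // let left = shared[tid - 1];  // safe to read a neighbor's slot
    /// ```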
    #[inline]
    pub fn sync_threads(&self) {
        match self.backend {
            ContextBackend::Cpu => {
                // CPU: no-op (single-threaded simulation)
            }
            _ => {
                // GPU backends would call __syncthreads() or equivalent
                // Placeholder for actual implementation
            }
        }
    }

    /// Synchronize all threads in the grid (cooperative groups).
    ///
    /// Requires cooperative kernel launch support.
    #[inline]
    pub fn sync_grid(&self) {
        match self.backend {
            ContextBackend::Cpu => {
                // CPU: no-op
            }
            _ => {
                // GPU backends would call cooperative grid sync
            }
        }
    }

    /// Synchronize threads within a warp.
    #[inline]
    pub fn sync_warp(&self) {
        match self.backend {
            ContextBackend::Cpu => {
                // CPU: no-op
            }
            _ => {
                // GPU backends would call __syncwarp()
            }
        }
    }

    // === Memory Fencing ===

    /// Memory fence at the specified scope.
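    ///
    /// A minimal publish sketch (`payload` and `flag` are hypothetical
    /// device-visible slots):
    ///
    /// ```ignore
    /// payload.store(value, std::sync::atomic::Ordering::Relaxed); // write data first
    /// ctx.thread_fence(FenceScope::Device);                       // order data before flag
    /// ctx.atomic_exchange(&flag, 1, MemoryOrder::Release);        // then raise the flag
    /// ```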
    #[inline]
    pub fn thread_fence(&self, scope: FenceScope) {
        match (&self.backend, scope) {
            (ContextBackend::Cpu, _) => {
                // CPU: conservatively issue a SeqCst fence for every scope.
                std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
            }
            _ => {
                // GPU backends would call the appropriate fence intrinsic
                // for the requested scope.
            }
        }
    }

    /// Thread-scope fence (compiler barrier).
    #[inline]
    pub fn fence_thread(&self) {
        self.thread_fence(FenceScope::Thread);
    }

    /// Block-scope fence.
    #[inline]
    pub fn fence_block(&self) {
        self.thread_fence(FenceScope::Block);
    }

    /// Device-scope fence.
    #[inline]
    pub fn fence_device(&self) {
        self.thread_fence(FenceScope::Device);
    }

    /// System-scope fence (CPU+GPU visible).
    #[inline]
    pub fn fence_system(&self) {
        self.thread_fence(FenceScope::System);
    }

    // === HLC Operations ===

    /// Get the current HLC timestamp (does not advance the clock).
    #[inline]
    pub fn now(&self) -> HlcTimestamp {
        self.clock.now()
    }

    /// Generate a new HLC timestamp (advances the clock).
    #[inline]
    pub fn tick(&self) -> HlcTimestamp {
        self.clock.tick()
    }

    /// Update the clock with a received timestamp.
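    ///
    /// A minimal sketch, where `remote_ts` stands in for a timestamp taken
    /// from a received message:
    ///
    /// ```ignore
    /// let merged = ctx.update_clock(&remote_ts)?;
    /// assert!(merged >= remote_ts);
    /// ```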
    #[inline]
    pub fn update_clock(&self, received: &HlcTimestamp) -> crate::error::Result<HlcTimestamp> {
        self.clock.update(received)
    }

    // === Atomic Operations ===

    /// Map a portable `MemoryOrder` onto the std atomic ordering.
    #[inline]
    fn std_ordering(order: MemoryOrder) -> std::sync::atomic::Ordering {
        match order {
            MemoryOrder::Relaxed => std::sync::atomic::Ordering::Relaxed,
            MemoryOrder::Acquire => std::sync::atomic::Ordering::Acquire,
            MemoryOrder::Release => std::sync::atomic::Ordering::Release,
            MemoryOrder::AcquireRelease => std::sync::atomic::Ordering::AcqRel,
            MemoryOrder::SeqCst => std::sync::atomic::Ordering::SeqCst,
        }
    }

    /// Atomic add and return old value.
    #[inline]
    pub fn atomic_add(
        &self,
        ptr: &std::sync::atomic::AtomicU64,
        val: u64,
        order: MemoryOrder,
    ) -> u64 {
        ptr.fetch_add(val, Self::std_ordering(order))
    }

    /// Atomic compare-and-swap.
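    ///
    /// Returns `Ok(previous)` on success and `Err(current)` on failure. Note
    /// that `std` panics if the failure ordering is `Release` or
    /// `AcquireRelease`.
    ///
    /// A minimal usage sketch, claiming a zero-initialized slot exactly once
    /// (`ctx` is a `RingContext`):
    ///
    /// ```ignore
    /// let slot = std::sync::atomic::AtomicU64::new(0);
    /// match ctx.atomic_cas(&slot, 0, 42, MemoryOrder::AcquireRelease, MemoryOrder::Relaxed) {
    ///     Ok(_) => { /* this thread claimed the slot */ }
    ///     Err(current) => { /* lost the race; `current` holds the winner's value */ }
    /// }
    /// ```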
    #[inline]
    pub fn atomic_cas(
        &self,
        ptr: &std::sync::atomic::AtomicU64,
        expected: u64,
        desired: u64,
        success: MemoryOrder,
        failure: MemoryOrder,
    ) -> Result<u64, u64> {
        ptr.compare_exchange(
            expected,
            desired,
            Self::std_ordering(success),
            Self::std_ordering(failure),
        )
    }

    /// Atomic exchange; returns the old value.
    #[inline]
    pub fn atomic_exchange(
        &self,
        ptr: &std::sync::atomic::AtomicU64,
        val: u64,
        order: MemoryOrder,
    ) -> u64 {
        ptr.swap(val, Self::std_ordering(order))
    }

    // === Warp Primitives ===

    /// Warp shuffle - get value from another lane.
    ///
    /// Returns the value from the specified source lane.
    #[inline]
    pub fn warp_shuffle<T: Copy>(&self, value: T, src_lane: u32) -> T {
        match self.backend {
            ContextBackend::Cpu => {
                // CPU: just return own value (no other lanes)
                let _ = src_lane;
                value
            }
            _ => {
                // GPU would use __shfl_sync(); this placeholder returns the
                // caller's own value.
                let _ = src_lane;
                value
            }
        }
    }

    /// Warp shuffle down - get value from lane + delta.
    #[inline]
    pub fn warp_shuffle_down<T: Copy>(&self, value: T, delta: u32) -> T {
        self.warp_shuffle(value, self.lane_id().saturating_add(delta))
    }

    /// Warp shuffle up - get value from lane - delta.
    #[inline]
    pub fn warp_shuffle_up<T: Copy>(&self, value: T, delta: u32) -> T {
        self.warp_shuffle(value, self.lane_id().saturating_sub(delta))
    }

    /// Warp shuffle XOR - get value from lane XOR mask.
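    ///
    /// This is the building block for warp-wide butterfly reductions. A
    /// minimal sketch (GPU semantics; on the CPU backend each shuffle
    /// returns the caller's own value, so the loop merely doubles `sum`):
    ///
    /// ```ignore
    /// let mut sum = local_value; // `local_value` is hypothetical per-lane data
    /// let mut offset = 16u32;    // half of the 32-lane warp
    /// while offset > 0 {
    ///     sum += ctx.warp_shuffle_xor(sum, offset);
    ///     offset /= 2;
    /// }
    /// // Every lane now holds the warp-wide sum.
    /// ```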
    #[inline]
    pub fn warp_shuffle_xor<T: Copy>(&self, value: T, mask: u32) -> T {
        self.warp_shuffle(value, self.lane_id() ^ mask)
    }

    /// Warp ballot - get bitmask of lanes where predicate is true.
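    ///
    /// A minimal sketch counting the active lanes (`value` and `threshold`
    /// are hypothetical per-lane data):
    ///
    /// ```ignore
    /// let mask = ctx.warp_ballot(value > threshold);
    /// let active_lanes = mask.count_ones();
    /// ```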
    #[inline]
    pub fn warp_ballot(&self, predicate: bool) -> u32 {
        match self.backend {
            ContextBackend::Cpu => {
                // CPU: single thread, return 1 or 0
                if predicate {
                    1
                } else {
                    0
                }
            }
            _ => {
                // GPU would use __ballot_sync()
                if predicate {
                    1 << self.lane_id()
                } else {
                    0
                }
            }
        }
    }

    /// Warp all - check if predicate is true for all lanes.
    #[inline]
    pub fn warp_all(&self, predicate: bool) -> bool {
        match self.backend {
            ContextBackend::Cpu => predicate,
            _ => {
                // GPU would use __all_sync()
                predicate
            }
        }
    }

    /// Warp any - check if predicate is true for any lane.
    #[inline]
    pub fn warp_any(&self, predicate: bool) -> bool {
        match self.backend {
            ContextBackend::Cpu => predicate,
            _ => {
                // GPU would use __any_sync()
                predicate
            }
        }
    }

    // === K2K Messaging ===

    /// Send message to another kernel (K2K).
    ///
    /// This is a placeholder; actual implementation requires runtime support.
    #[inline]
    pub fn k2k_send(
        &self,
        _target_kernel: u64,
        _envelope: &MessageEnvelope,
    ) -> crate::error::Result<()> {
        // K2K messaging requires runtime bridge support
        Err(crate::error::RingKernelError::NotSupported(
            "K2K messaging requires runtime".to_string(),
        ))
    }

    /// Try to receive a message from the K2K queue.
    ///
    /// This is a placeholder; actual implementation requires runtime support.
    #[inline]
    pub fn k2k_try_recv(&self) -> crate::error::Result<MessageEnvelope> {
        // K2K messaging requires runtime bridge support
        Err(crate::error::RingKernelError::NotSupported(
            "K2K messaging requires runtime".to_string(),
        ))
    }
}

impl<'a> std::fmt::Debug for RingContext<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RingContext")
            .field("thread_id", &self.thread_id)
            .field("block_id", &self.block_id)
            .field("block_dim", &self.block_dim)
            .field("grid_dim", &self.grid_dim)
            .field("kernel_id", &self.kernel_id)
            .field("backend", &self.backend)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_test_context(clock: &HlcClock) -> RingContext<'_> {
        RingContext::new(
            ThreadId::new_1d(0),
            BlockId::new_1d(0),
            Dim3::new_1d(256),
            Dim3::new_1d(1),
            clock,
            1,
            ContextBackend::Cpu,
        )
    }

    #[test]
    fn test_thread_identity() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        assert_eq!(ctx.thread_id().x, 0);
        assert_eq!(ctx.block_id().x, 0);
        assert_eq!(ctx.global_thread_id().x, 0);
    }

    #[test]
    fn test_warp_id() {
        let clock = HlcClock::new(1);
        let ctx = RingContext::new(
            ThreadId::new_1d(35), // Thread 35 is in warp 1, lane 3
            BlockId::new_1d(0),
            Dim3::new_1d(256),
            Dim3::new_1d(1),
            &clock,
            1,
            ContextBackend::Cpu,
        );

        assert_eq!(ctx.warp_id().0, 1);
        assert_eq!(ctx.lane_id(), 3);
    }

    #[test]
    fn test_hlc_operations() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        let ts1 = ctx.now();
        let ts2 = ctx.tick();
        assert!(ts2 >= ts1);
    }

    #[test]
    fn test_warp_ballot_cpu() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        assert_eq!(ctx.warp_ballot(true), 1);
        assert_eq!(ctx.warp_ballot(false), 0);
    }
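
    // Additional CPU-backend coverage. These tests exercise only behavior
    // that is fully determined by this file's CPU code paths.
    #[test]
    fn test_atomic_operations() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);
        let counter = std::sync::atomic::AtomicU64::new(0);

        // fetch_add returns the previous value.
        assert_eq!(ctx.atomic_add(&counter, 5, MemoryOrder::SeqCst), 0);
        // CAS succeeds when the expected value matches...
        assert_eq!(
            ctx.atomic_cas(&counter, 5, 7, MemoryOrder::SeqCst, MemoryOrder::SeqCst),
            Ok(5)
        );
        // ...and reports the current value when it does not.
        assert_eq!(
            ctx.atomic_cas(&counter, 5, 9, MemoryOrder::SeqCst, MemoryOrder::SeqCst),
            Err(7)
        );
        // Exchange stores the new value and returns the old one.
        assert_eq!(ctx.atomic_exchange(&counter, 11, MemoryOrder::SeqCst), 7);
    }

    #[test]
    fn test_warp_shuffles_cpu() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        // On the CPU backend every shuffle variant returns the caller's own
        // value, since there are no other lanes.
        assert_eq!(ctx.warp_shuffle(42u32, 7), 42);
        assert_eq!(ctx.warp_shuffle_down(1.5f64, 4), 1.5);
        assert_eq!(ctx.warp_shuffle_up(9u64, 2), 9);
        assert_eq!(ctx.warp_shuffle_xor(3i32, 1), 3);
        assert!(ctx.warp_all(true));
        assert!(ctx.warp_any(true));
    }

    #[test]
    fn test_fences_and_k2k_cpu() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        // Fences map to a SeqCst fence on the CPU backend; smoke-test all scopes.
        ctx.fence_thread();
        ctx.fence_block();
        ctx.fence_device();
        ctx.fence_system();

        // K2K messaging is unavailable without the runtime bridge.
        assert!(ctx.k2k_try_recv().is_err());
    }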
}