use crate::hlc::{HlcClock, HlcTimestamp};
use crate::message::MessageEnvelope;
use crate::types::{BlockId, Dim3, FenceScope, GlobalThreadId, MemoryOrder, ThreadId, WarpId};

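/// Per-thread execution context for a ring kernel: the thread's coordinates,
/// the block/grid geometry, a handle to the hybrid logical clock, and the
/// backend the kernel is running on.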
pub struct RingContext<'a> {
    pub thread_id: ThreadId,
    pub block_id: BlockId,
    pub block_dim: Dim3,
    pub grid_dim: Dim3,
    clock: &'a HlcClock,
    kernel_id: u64,
    backend: ContextBackend,
}

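/// Backend a [`RingContext`] is executing on; the methods below branch on it
/// to choose between host emulation and the (placeholder) GPU paths.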
#[derive(Debug, Clone, Copy)]
pub enum ContextBackend {
    /// Host CPU emulation.
    Cpu,
    /// NVIDIA CUDA.
    Cuda,
    /// Apple Metal.
    Metal,
    /// WebGPU via wgpu.
    Wgpu,
}

impl<'a> RingContext<'a> {
    /// Builds a context from explicit coordinates, geometry, clock, and backend.
    pub fn new(
        thread_id: ThreadId,
        block_id: BlockId,
        block_dim: Dim3,
        grid_dim: Dim3,
        clock: &'a HlcClock,
        kernel_id: u64,
        backend: ContextBackend,
    ) -> Self {
        Self {
            thread_id,
            block_id,
            block_dim,
            grid_dim,
            clock,
            kernel_id,
            backend,
        }
    }

    /// Thread index within the block.
    #[inline]
    pub fn thread_id(&self) -> ThreadId {
        self.thread_id
    }

    /// Block index within the grid.
    #[inline]
    pub fn block_id(&self) -> BlockId {
        self.block_id
    }

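    /// Global thread id across the grid, flattened from block and thread coordinates.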
    #[inline]
    pub fn global_thread_id(&self) -> GlobalThreadId {
        GlobalThreadId::from_block_thread(self.block_id, self.thread_id, self.block_dim)
    }

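    /// Warp this thread belongs to, derived from its linear index within the block.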
    #[inline]
    pub fn warp_id(&self) -> WarpId {
        let linear = self
            .thread_id
            .linear_for_dim(self.block_dim.x, self.block_dim.y);
        WarpId::from_thread_linear(linear)
    }

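    /// Lane index of this thread within its warp.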
    #[inline]
    pub fn lane_id(&self) -> u32 {
        let linear = self
            .thread_id
            .linear_for_dim(self.block_dim.x, self.block_dim.y);
        WarpId::lane_id(linear)
    }

    /// Dimensions of each block.
    #[inline]
    pub fn block_dim(&self) -> Dim3 {
        self.block_dim
    }

    /// Dimensions of the grid, in blocks.
    #[inline]
    pub fn grid_dim(&self) -> Dim3 {
        self.grid_dim
    }

    /// Identifier of the kernel this context belongs to.
    #[inline]
    pub fn kernel_id(&self) -> u64 {
        self.kernel_id
    }

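    /// Block-wide barrier. This standalone context performs no synchronization
    /// itself: the CPU arm is a no-op (host execution is assumed not to need
    /// it), and the GPU arms are placeholders, the native barrier presumably
    /// being supplied by the backend runtime.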
    #[inline]
    pub fn sync_threads(&self) {
        match self.backend {
            ContextBackend::Cpu => {
                // Nothing to do on the host emulation path.
            }
            _ => {
                // Intentionally a no-op; see the doc comment above.
            }
        }
    }

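    /// Grid-wide barrier across all blocks. Currently a no-op placeholder on
    /// every backend; on CUDA, for example, a real grid sync would require a
    /// cooperative launch.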
    #[inline]
    pub fn sync_grid(&self) {
        match self.backend {
            ContextBackend::Cpu => {
                // Nothing to do on the host emulation path.
            }
            _ => {
                // Intentionally a no-op; see the doc comment above.
            }
        }
    }

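    /// Warp-level barrier. Also a no-op placeholder here; GPU backends would
    /// lower this to their native warp synchronization.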
    #[inline]
    pub fn sync_warp(&self) {
        match self.backend {
            ContextBackend::Cpu => {
                // Nothing to do on the host emulation path.
            }
            _ => {
                // Intentionally a no-op; see the doc comment above.
            }
        }
    }

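    /// Memory fence at the requested scope. The CPU backend conservatively
    /// issues a `SeqCst` fence regardless of scope; the GPU arms are
    /// placeholders.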
    #[inline]
    pub fn thread_fence(&self, scope: FenceScope) {
        match (self.backend, scope) {
            (ContextBackend::Cpu, _) => {
                // A SeqCst fence is at least as strong as any requested scope.
                std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
            }
            _ => {
                // No-op placeholder for GPU backends.
            }
        }
    }

    /// Fence with `FenceScope::Thread`.
    #[inline]
    pub fn fence_thread(&self) {
        self.thread_fence(FenceScope::Thread);
    }

    /// Fence with `FenceScope::Block`.
    #[inline]
    pub fn fence_block(&self) {
        self.thread_fence(FenceScope::Block);
    }

    /// Fence with `FenceScope::Device`.
    #[inline]
    pub fn fence_device(&self) {
        self.thread_fence(FenceScope::Device);
    }

    /// Fence with `FenceScope::System`.
    #[inline]
    pub fn fence_system(&self) {
        self.thread_fence(FenceScope::System);
    }

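    /// Current HLC timestamp, without advancing the clock.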
    #[inline]
    pub fn now(&self) -> HlcTimestamp {
        self.clock.now()
    }

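    /// Advances the hybrid logical clock and returns the new timestamp.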
    #[inline]
    pub fn tick(&self) -> HlcTimestamp {
        self.clock.tick()
    }

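    /// Merges a received timestamp into the local clock, returning the
    /// updated local timestamp.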
    #[inline]
    pub fn update_clock(&self, received: &HlcTimestamp) -> crate::error::Result<HlcTimestamp> {
        self.clock.update(received)
    }

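    /// Maps the crate's portable `MemoryOrder` onto `std::sync::atomic::Ordering`;
    /// shared by the atomic wrappers below.
    #[inline]
    fn to_ordering(order: MemoryOrder) -> std::sync::atomic::Ordering {
        match order {
            MemoryOrder::Relaxed => std::sync::atomic::Ordering::Relaxed,
            MemoryOrder::Acquire => std::sync::atomic::Ordering::Acquire,
            MemoryOrder::Release => std::sync::atomic::Ordering::Release,
            MemoryOrder::AcquireRelease => std::sync::atomic::Ordering::AcqRel,
            MemoryOrder::SeqCst => std::sync::atomic::Ordering::SeqCst,
        }
    }

    /// Atomically adds `val` to the target, returning the previous value.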
    #[inline]
    pub fn atomic_add(
        &self,
        ptr: &std::sync::atomic::AtomicU64,
        val: u64,
        order: MemoryOrder,
    ) -> u64 {
        ptr.fetch_add(val, Self::to_ordering(order))
    }

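    /// Atomic compare-and-swap: returns `Ok(previous)` if the swap happened
    /// and `Err(actual)` with the observed value otherwise.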
    #[inline]
    pub fn atomic_cas(
        &self,
        ptr: &std::sync::atomic::AtomicU64,
        expected: u64,
        desired: u64,
        success: MemoryOrder,
        failure: MemoryOrder,
    ) -> Result<u64, u64> {
        let success_ord = Self::to_ordering(success);
        // `compare_exchange` panics if the failure ordering is `Release` or
        // `AcqRel`, so clamp those to the strongest legal load ordering.
        let failure_ord = match Self::to_ordering(failure) {
            std::sync::atomic::Ordering::AcqRel => std::sync::atomic::Ordering::Acquire,
            std::sync::atomic::Ordering::Release => std::sync::atomic::Ordering::Relaxed,
            other => other,
        };
        ptr.compare_exchange(expected, desired, success_ord, failure_ord)
    }

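    /// Atomically swaps in `val`, returning the previous value.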
    #[inline]
    pub fn atomic_exchange(
        &self,
        ptr: &std::sync::atomic::AtomicU64,
        val: u64,
        order: MemoryOrder,
    ) -> u64 {
        ptr.swap(val, Self::to_ordering(order))
    }

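    /// Reads `value` from lane `src_lane` of the warp. Without a real warp on
    /// the CPU path, and with no device intrinsic wired up here, every backend
    /// currently returns the caller's own value unchanged.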
    #[inline]
    pub fn warp_shuffle<T: Copy>(&self, value: T, src_lane: u32) -> T {
        match self.backend {
            ContextBackend::Cpu => {
                // Single-lane emulation: a shuffle is a self-read.
                let _ = src_lane;
                value
            }
            _ => {
                // Placeholder until a device shuffle is available.
                let _ = src_lane;
                value
            }
        }
    }

    /// Shuffle from the lane `delta` higher than this lane's index.
    #[inline]
    pub fn warp_shuffle_down<T: Copy>(&self, value: T, delta: u32) -> T {
        self.warp_shuffle(value, self.lane_id().saturating_add(delta))
    }

    /// Shuffle from the lane `delta` lower than this lane's index.
    #[inline]
    pub fn warp_shuffle_up<T: Copy>(&self, value: T, delta: u32) -> T {
        self.warp_shuffle(value, self.lane_id().saturating_sub(delta))
    }

    /// Shuffle from the lane at this lane's index XOR `mask`.
    #[inline]
    pub fn warp_shuffle_xor<T: Copy>(&self, value: T, mask: u32) -> T {
        self.warp_shuffle(value, self.lane_id() ^ mask)
    }

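    /// Warp ballot. On the CPU path only the calling thread votes, so the
    /// result is bit 0; other backends set the bit at the caller's lane index.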
    #[inline]
    pub fn warp_ballot(&self, predicate: bool) -> u32 {
        match self.backend {
            ContextBackend::Cpu => {
                if predicate {
                    1
                } else {
                    0
                }
            }
            _ => {
                if predicate {
                    1 << self.lane_id()
                } else {
                    0
                }
            }
        }
    }

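    /// Warp-wide AND vote. With only the calling thread visible here, this
    /// reduces to the caller's own predicate on every backend.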
    #[inline]
    pub fn warp_all(&self, predicate: bool) -> bool {
        match self.backend {
            ContextBackend::Cpu => predicate,
            _ => predicate,
        }
    }

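    /// Warp-wide OR vote; likewise reduces to the caller's own predicate.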
    #[inline]
    pub fn warp_any(&self, predicate: bool) -> bool {
        match self.backend {
            ContextBackend::Cpu => predicate,
            _ => predicate,
        }
    }

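    /// Kernel-to-kernel send. A standalone context has no message transport,
    /// so this always fails with `NotSupported`; a runtime-backed context is
    /// required for real K2K messaging.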
    #[inline]
    pub fn k2k_send(
        &self,
        _target_kernel: u64,
        _envelope: &MessageEnvelope,
    ) -> crate::error::Result<()> {
        Err(crate::error::RingKernelError::NotSupported(
            "K2K messaging requires runtime".to_string(),
        ))
    }

    /// Kernel-to-kernel receive; see [`Self::k2k_send`] for why this fails.
    #[inline]
    pub fn k2k_try_recv(&self) -> crate::error::Result<MessageEnvelope> {
        Err(crate::error::RingKernelError::NotSupported(
            "K2K messaging requires runtime".to_string(),
        ))
    }
}

impl<'a> std::fmt::Debug for RingContext<'a> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RingContext")
            .field("thread_id", &self.thread_id)
            .field("block_id", &self.block_id)
            .field("block_dim", &self.block_dim)
            .field("grid_dim", &self.grid_dim)
            .field("kernel_id", &self.kernel_id)
            .field("backend", &self.backend)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_test_context(clock: &HlcClock) -> RingContext<'_> {
        RingContext::new(
            ThreadId::new_1d(0),
            BlockId::new_1d(0),
            Dim3::new_1d(256),
            Dim3::new_1d(1),
            clock,
            1,
            ContextBackend::Cpu,
        )
    }

    #[test]
    fn test_thread_identity() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        assert_eq!(ctx.thread_id().x, 0);
        assert_eq!(ctx.block_id().x, 0);
        assert_eq!(ctx.global_thread_id().x, 0);
    }

    #[test]
    fn test_warp_id() {
        let clock = HlcClock::new(1);
        let ctx = RingContext::new(
            ThreadId::new_1d(35),
            BlockId::new_1d(0),
            Dim3::new_1d(256),
            Dim3::new_1d(1),
            &clock,
            1,
            ContextBackend::Cpu,
        );

        // Thread 35 sits in warp 1 (35 / 32) at lane 3 (35 % 32).
        assert_eq!(ctx.warp_id().0, 1);
        assert_eq!(ctx.lane_id(), 3);
    }

    #[test]
    fn test_hlc_operations() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        let ts1 = ctx.now();
        let ts2 = ctx.tick();
        assert!(ts2 >= ts1);
    }

    #[test]
    fn test_warp_ballot_cpu() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        assert_eq!(ctx.warp_ballot(true), 1);
        assert_eq!(ctx.warp_ballot(false), 0);
    }
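
    // Sanity-check the CPU atomic wrappers: fetch_add and swap return the
    // previous value, and a failed CAS reports the value actually observed.
    #[test]
    fn test_atomics_cpu() {
        let clock = HlcClock::new(1);
        let ctx = make_test_context(&clock);

        let counter = std::sync::atomic::AtomicU64::new(0);
        assert_eq!(ctx.atomic_add(&counter, 5, MemoryOrder::SeqCst), 0);
        assert_eq!(ctx.atomic_exchange(&counter, 7, MemoryOrder::SeqCst), 5);
        assert_eq!(
            ctx.atomic_cas(&counter, 7, 9, MemoryOrder::SeqCst, MemoryOrder::SeqCst),
            Ok(7)
        );
        assert_eq!(
            ctx.atomic_cas(&counter, 7, 11, MemoryOrder::SeqCst, MemoryOrder::SeqCst),
            Err(9)
        );
    }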
}