1#![warn(missing_docs)]
34
35#[cfg(feature = "ptx-cache")]
36pub mod compile;
37#[cfg(feature = "cooperative")]
38pub mod cooperative;
39#[cfg(feature = "cuda")]
40mod device;
41#[cfg(feature = "cuda")]
42pub mod driver_api;
43#[cfg(feature = "cuda")]
44pub mod k2k_gpu;
45#[cfg(feature = "cuda")]
46mod kernel;
47#[cfg(feature = "cuda")]
48pub mod launch_config;
49#[cfg(feature = "cuda")]
50mod memory;
51#[cfg(feature = "cuda")]
52pub mod memory_pool;
53#[cfg(feature = "cuda")]
54pub mod persistent;
55#[cfg(feature = "cuda")]
56pub mod phases;
57#[cfg(feature = "profiling")]
58pub mod profiling;
59#[cfg(feature = "cuda")]
60pub mod reduction;
61#[cfg(feature = "cuda")]
62mod runtime;
63#[cfg(feature = "cuda")]
64mod stencil;
65#[cfg(feature = "cuda")]
66pub mod stream;
67
68#[cfg(feature = "cuda")]
69pub use device::CudaDevice;
70#[cfg(feature = "cuda")]
71pub use kernel::CudaKernel;
72#[cfg(feature = "cuda")]
73pub use memory::{CudaBuffer, CudaControlBlock, CudaMemoryPool, CudaMessageQueue};
74#[cfg(feature = "cuda")]
75pub use persistent::CudaMappedBuffer;
76#[cfg(feature = "cuda")]
77pub use phases::{
78 InterPhaseReduction, KernelPhase, MultiPhaseConfig, MultiPhaseExecutor, PhaseExecutionStats,
79 SyncMode,
80};
81#[cfg(feature = "cuda")]
82pub use reduction::{
83 generate_block_reduce_code, generate_grid_reduce_code, generate_reduce_and_broadcast_code,
84 CacheKey, CacheStats, CachedReductionBuffer, ReductionBuffer, ReductionBufferBuilder,
85 ReductionBufferCache,
86};
87#[cfg(feature = "cuda")]
88pub use runtime::CudaRuntime;
89#[cfg(feature = "cuda")]
90pub use stencil::{CompiledStencilKernel, LaunchConfig, StencilKernelLoader};
91
92#[cfg(feature = "profiling")]
94pub use profiling::{
95 CudaEvent, CudaEventFlags, CudaMemoryKind, CudaMemoryTracker, CudaNvtxProfiler,
96 GpuChromeTraceBuilder, GpuEventArgs, GpuTimer, GpuTimerPool, GpuTraceEvent, KernelMetrics,
97 ProfilingSession, TrackedAllocation, TransferDirection, TransferMetrics,
98};
99
100#[cfg(feature = "ptx-cache")]
102pub use compile::{PtxCache, PtxCacheError, PtxCacheResult, PtxCacheStats, CACHE_VERSION};
103
104#[cfg(feature = "cuda")]
106pub use memory_pool::{
107 GpuBucketStats, GpuPoolConfig, GpuPoolDiagnostics, GpuSizeClass, GpuStratifiedPool,
108};
109
110#[cfg(feature = "cuda")]
112pub use stream::{
113 OverlapMetrics, StreamConfig, StreamConfigBuilder, StreamError, StreamId, StreamManager,
114 StreamPool, StreamPoolStats, StreamResult,
115};
116
/// Re-exports the CUDA memory types from the private `memory` module under a
/// dedicated module path.
#[cfg(feature = "cuda")]
pub mod memory_exports {
    pub use super::memory::CudaBuffer;
    pub use super::memory::CudaControlBlock;
    pub use super::memory::CudaMemoryPool;
    pub use super::memory::CudaMessageQueue;
}
122
// Placeholder module compiled only when the `cuda` feature is disabled.
// It invokes `ringkernel_core::unavailable_backend!` with the identifier
// `CudaRuntime`, the `Backend::Cuda` variant and the label "CUDA".
// NOTE(review): presumably the macro generates a `CudaRuntime` stand-in that
// reports the backend as unavailable — the macro body is not visible here,
// confirm against `ringkernel_core`.
123#[cfg(not(feature = "cuda"))]
125mod stub {
126 ringkernel_core::unavailable_backend!(
127 CudaRuntime,
128 ringkernel_core::runtime::Backend::Cuda,
129 "CUDA"
130 );
131}
132
133#[cfg(not(feature = "cuda"))]
134pub use stub::CudaRuntime;
135
/// Reports whether at least one CUDA device can be used.
///
/// Without the `cuda` feature this is always `false`. With the feature
/// enabled the device count is queried from the driver; a driver error, a
/// count of zero, or a panic raised during the query all yield `false`
/// instead of propagating.
pub fn is_cuda_available() -> bool {
    #[cfg(feature = "cuda")]
    fn probe() -> bool {
        // The query runs under `catch_unwind` so a panic inside it maps to `false`.
        let outcome = std::panic::catch_unwind(|| {
            matches!(cudarc::driver::CudaContext::device_count(), Ok(found) if found > 0)
        });
        outcome.unwrap_or(false)
    }

    #[cfg(not(feature = "cuda"))]
    fn probe() -> bool {
        false
    }

    probe()
}
160
/// Returns the number of CUDA devices reported by the driver.
///
/// Yields `0` when the `cuda` feature is disabled, when the driver query
/// returns an error, or when the query panics.
pub fn cuda_device_count() -> usize {
    #[cfg(feature = "cuda")]
    {
        // The query runs under `catch_unwind` so a panic inside it maps to 0.
        std::panic::catch_unwind(|| {
            cudarc::driver::CudaContext::device_count()
                .ok()
                // Checked conversion: a count that does not fit in `usize`
                // (e.g. a negative value) is reported as 0 instead of being
                // wrapped into a huge number by an `as` cast.
                .and_then(|count| usize::try_from(count).ok())
                .unwrap_or(0)
        })
        .unwrap_or(0)
    }
    #[cfg(not(feature = "cuda"))]
    {
        0
    }
}
178
/// Compiles CUDA source text to PTX through NVRTC and returns the PTX source.
///
/// # Errors
///
/// Returns `RingKernelError::CompilationError` carrying the NVRTC error text
/// when compilation fails.
#[cfg(feature = "cuda")]
pub fn compile_ptx(cuda_source: &str) -> ringkernel_core::error::Result<String> {
    use ringkernel_core::error::RingKernelError;

    match cudarc::nvrtc::compile_ptx(cuda_source) {
        Ok(compiled) => Ok(compiled.to_src().to_string()),
        Err(nvrtc_err) => Err(RingKernelError::CompilationError(format!(
            "NVRTC compilation failed: {}",
            nvrtc_err
        ))),
    }
}
216
217#[cfg(not(feature = "cuda"))]
219pub fn compile_ptx(_cuda_source: &str) -> ringkernel_core::error::Result<String> {
220 Err(ringkernel_core::error::RingKernelError::BackendUnavailable(
221 "CUDA feature not enabled".to_string(),
222 ))
223}
224
/// Minimal PTX module exposing the entry point `ring_kernel_main`.
///
/// The entry declares four `.u64` parameters (control block, input queue,
/// output queue and shared state pointers) but only reads
/// `control_block_ptr`: it stores the 32-bit value `1` at byte offset 8 from
/// that pointer — which the embedded comment describes as marking the kernel
/// terminated — and then returns.
///
/// NOTE(review): the module is pinned to `.version 8.0` / `.target sm_89`;
/// presumably this restricts loading to devices of that compute capability
/// or newer — confirm against the PTX ISA documentation.
225pub const RING_KERNEL_PTX_TEMPLATE: &str = r#"
230.version 8.0
231.target sm_89
232.address_size 64
233
234.visible .entry ring_kernel_main(
235 .param .u64 control_block_ptr,
236 .param .u64 input_queue_ptr,
237 .param .u64 output_queue_ptr,
238 .param .u64 shared_state_ptr
239) {
240 .reg .u64 %cb_ptr;
241 .reg .u32 %one;
242
243 // Load control block pointer
244 ld.param.u64 %cb_ptr, [control_block_ptr];
245
246 // Mark as terminated immediately (offset 8)
247 mov.u32 %one, 1;
248 st.global.u32 [%cb_ptr + 8], %one;
249
250 ret;
251}
252"#;