Class MetalTensorOperations
- Namespace: DotCompute.Backends.Metal.Execution
- Assembly: DotCompute.Backends.Metal.dll
Tensor operations for machine-learning workloads on Metal, implemented with Metal Performance Shaders (MPS). Provides matrix multiplication, 2D convolution, and batch normalization, targeting Apple Silicon and other Metal-capable GPUs.
public sealed class MetalTensorOperations : IDisposable
- Inheritance: object → MetalTensorOperations
- Implements: IDisposable
Constructors
MetalTensorOperations(nint, nint, ILogger<MetalTensorOperations>, MetalPerformanceShadersBackend?)
public MetalTensorOperations(nint device, nint commandQueue, ILogger<MetalTensorOperations> logger, MetalPerformanceShadersBackend? mpsBackend = null)
Parameters
- device (nint)
- commandQueue (nint)
- logger (ILogger<MetalTensorOperations>)
- mpsBackend (MetalPerformanceShadersBackend?, optional; defaults to null)
Methods
BatchNormAsync(float[], float[], float[], int, int, int, int, CancellationToken)
Performs batch normalization using MPS.
public Task<float[]> BatchNormAsync(float[] input, float[] scale, float[] bias, int batchSize, int channels, int height, int width, CancellationToken cancellationToken = default)
Parameters
- input (float[])
- scale (float[])
- bias (float[])
- batchSize (int)
- channels (int)
- height (int)
- width (int)
- cancellationToken (CancellationToken, optional)
Returns
- Task<float[]>: the normalized tensor.
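The entry above does not state the tensor layout or which statistics are used. When validating GPU output, a plain CPU reference helps. The sketch below assumes an NCHW row-major layout, a per-channel mean and biased variance computed over batch, height and width (training-style batch normalization, since the signature takes no running mean or variance), and an epsilon of 1e-5. All of these are assumptions, and `BatchNormReference` is a hypothetical helper, not part of DotCompute.

```csharp
using System;

// Hypothetical CPU reference for per-channel batch normalization.
// ASSUMPTIONS (not confirmed by the API docs): NCHW row-major layout,
// mean and biased variance computed over N, H, W per channel, epsilon = 1e-5.
static float[] BatchNormReference(
    float[] input, float[] scale, float[] bias,
    int batchSize, int channels, int height, int width, float epsilon = 1e-5f)
{
    int plane = height * width;
    if (input.Length != batchSize * channels * plane)
        throw new ArgumentException("input length must equal N*C*H*W");
    if (scale.Length != channels || bias.Length != channels)
        throw new ArgumentException("scale and bias need one entry per channel");

    var output = new float[input.Length];
    int count = batchSize * plane; // elements contributing to each channel's statistics

    for (int c = 0; c < channels; c++)
    {
        // First pass: per-channel mean.
        double sum = 0;
        for (int n = 0; n < batchSize; n++)
        {
            int offset = (n * channels + c) * plane;
            for (int i = 0; i < plane; i++) sum += input[offset + i];
        }
        double mean = sum / count;

        // Second pass: biased variance (divides by count, not count - 1).
        double sq = 0;
        for (int n = 0; n < batchSize; n++)
        {
            int offset = (n * channels + c) * plane;
            for (int i = 0; i < plane; i++)
            {
                double d = input[offset + i] - mean;
                sq += d * d;
            }
        }
        double invStd = 1.0 / Math.Sqrt(sq / count + epsilon);

        // Normalize, then apply the per-channel scale and bias.
        for (int n = 0; n < batchSize; n++)
        {
            int offset = (n * channels + c) * plane;
            for (int i = 0; i < plane; i++)
                output[offset + i] = (float)((input[offset + i] - mean) * invStd * scale[c] + bias[c]);
        }
    }
    return output;
}

// Usage: one sample, one channel, a 2x2 plane, identity scale and zero bias.
var normalized = BatchNormReference(
    new float[] { 1, 2, 3, 4 }, new float[] { 1 }, new float[] { 0 },
    batchSize: 1, channels: 1, height: 2, width: 2);
Console.WriteLine(string.Join(", ", normalized)); // ≈ -1.342, -0.447, 0.447, 1.342
```

Comparing such a reference element by element against the GPU result, with a small tolerance, is a quick way to confirm which layout and statistics the MPS path actually uses.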
ConvolutionAsync(float[], float[], ConvolutionConfig, CancellationToken)
Performs 2D convolution using MPS.
public Task<float[]> ConvolutionAsync(float[] input, float[] kernel, ConvolutionConfig config, CancellationToken cancellationToken = default)
Parameters
- input (float[])
- kernel (float[])
- config (ConvolutionConfig)
- cancellationToken (CancellationToken, optional)
Returns
- Task<float[]>: the convolution output.
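The fields of ConvolutionConfig are not documented in this section, so the CPU reference below stands in for it with hypothetical explicit parameters: channel counts, kernel size, one stride for both axes, and symmetric zero padding. It computes a direct cross-correlation (no kernel flip, the convention most ML frameworks call "convolution") over a single CHW image with an [outChannels, inChannels, kernelHeight, kernelWidth] kernel. Each output dimension is (in + 2·padding - kernel) / stride + 1. Whether the MPS path uses the same layouts is an assumption, and `Conv2DReference` is a hypothetical helper.

```csharp
using System;

// Hypothetical CPU reference for a direct 2D convolution (cross-correlation).
// ASSUMPTIONS: single CHW image, [outC, inC, kH, kW] kernel, equal stride in
// both axes, symmetric zero padding; these parameters stand in for whatever
// ConvolutionConfig actually carries.
static float[] Conv2DReference(
    float[] input, int inChannels, int inHeight, int inWidth,
    float[] kernel, int outChannels, int kernelHeight, int kernelWidth,
    int stride, int padding, out int outHeight, out int outWidth)
{
    if (input.Length != inChannels * inHeight * inWidth)
        throw new ArgumentException("input length must equal C*H*W");
    if (kernel.Length != outChannels * inChannels * kernelHeight * kernelWidth)
        throw new ArgumentException("kernel length must equal outC*inC*kH*kW");
    if (stride <= 0) throw new ArgumentOutOfRangeException(nameof(stride));

    outHeight = (inHeight + 2 * padding - kernelHeight) / stride + 1;
    outWidth = (inWidth + 2 * padding - kernelWidth) / stride + 1;
    var output = new float[outChannels * outHeight * outWidth];

    for (int oc = 0; oc < outChannels; oc++)
    for (int oy = 0; oy < outHeight; oy++)
    for (int ox = 0; ox < outWidth; ox++)
    {
        float acc = 0f;
        for (int ic = 0; ic < inChannels; ic++)
        for (int ky = 0; ky < kernelHeight; ky++)
        for (int kx = 0; kx < kernelWidth; kx++)
        {
            int iy = oy * stride + ky - padding;
            int ix = ox * stride + kx - padding;
            if (iy < 0 || iy >= inHeight || ix < 0 || ix >= inWidth) continue; // padded zeros contribute nothing
            acc += input[(ic * inHeight + iy) * inWidth + ix]
                 * kernel[((oc * inChannels + ic) * kernelHeight + ky) * kernelWidth + kx];
        }
        output[(oc * outHeight + oy) * outWidth + ox] = acc;
    }
    return output;
}

// Usage: a 3x3 image convolved with a 2x2 kernel of ones (stride 1, no padding).
var result = Conv2DReference(
    input: new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 }, inChannels: 1, inHeight: 3, inWidth: 3,
    kernel: new float[] { 1, 1, 1, 1 }, outChannels: 1, kernelHeight: 2, kernelWidth: 2,
    stride: 1, padding: 0, outHeight: out int h, outWidth: out int w);
Console.WriteLine($"{h}x{w}: {string.Join(", ", result)}"); // 2x2: 12, 16, 24, 28
```

The naive seven-level loop is deliberately simple; it is meant as a correctness oracle for small inputs, not as a fast CPU fallback.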
Dispose()
Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
public void Dispose()
GetMetrics()
Gets performance metrics for tensor operations.
public TensorPerformanceMetrics GetMetrics()
Returns
- TensorPerformanceMetrics: the metrics for tensor operations performed by this instance.
MatrixMultiplyAsync(float[], float[], int, int, int, CancellationToken)
Performs matrix multiplication using MPS: C = A × B.
public Task<float[]> MatrixMultiplyAsync(float[] a, float[] b, int m, int n, int k, CancellationToken cancellationToken = default)
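The roles of m, n and k are not spelled out above. Under the conventional BLAS reading, which is an assumption here, A is m×k, B is k×n and C is m×n, all stored row-major. With that reading, results can be checked against a plain CPU loop. The sketch also shows the usual 2·m·n·k FLOP count for turning your own timings into GFLOP/s. `MatMulReference` and `GemmGflops` are hypothetical helpers, not DotCompute APIs.

```csharp
using System;

// Hypothetical CPU reference for C = A × B.
// ASSUMPTION: row-major storage with A m×k, B k×n, C m×n (conventional BLAS naming).
static float[] MatMulReference(float[] a, float[] b, int m, int n, int k)
{
    if (a.Length != m * k || b.Length != k * n)
        throw new ArgumentException("expected a.Length == m*k and b.Length == k*n");

    var c = new float[m * n];
    for (int i = 0; i < m; i++)
        for (int p = 0; p < k; p++)
        {
            float aip = a[i * k + p];            // reuse A[i,p] across row i of C
            for (int j = 0; j < n; j++)
                c[i * n + j] += aip * b[p * n + j];
        }
    return c;
}

// A GEMM performs m*n*k multiply-adds, i.e. 2*m*n*k floating-point operations,
// the usual numerator when converting a measured duration into GFLOP/s.
static double GemmGflops(int m, int n, int k, TimeSpan elapsed) =>
    2.0 * m * n * k / elapsed.TotalSeconds / 1e9;

var product = MatMulReference(
    new float[] { 1, 2, 3, 4, 5, 6 },    // A: 2×3
    new float[] { 7, 8, 9, 10, 11, 12 }, // B: 3×2
    m: 2, n: 2, k: 3);
Console.WriteLine(string.Join(", ", product)); // 58, 64, 139, 154
```

The i-p-j loop order keeps the innermost accesses to B and C contiguous in row-major storage, which is usually faster on a CPU than the textbook i-j-p order while producing the same result.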