1 /**
2    cuBLAS wrappers
3  */
4 module grain.cublas;
5 version (grain_cuda)  : import derelict.cuda;
6 
7 /++
8  TODO: make derelict-cublas (already exist?)
9  +/
10 
11 extern (C):
12 
13 alias cublasStatus_t = int;
14 ///
15 enum : cublasStatus_t {
16     CUBLAS_STATUS_SUCCESS = 0,
17     CUBLAS_STATUS_NOT_INITIALIZED = 1,
18     CUBLAS_STATUS_ALLOC_FAILED = 3,
19     CUBLAS_STATUS_INVALID_VALUE = 7,
20     CUBLAS_STATUS_ARCH_MISMATCH = 8,
21     CUBLAS_STATUS_MAPPING_ERROR = 11,
22     CUBLAS_STATUS_EXECUTION_FAILED = 13,
23     CUBLAS_STATUS_INTERNAL_ERROR = 14,
24     CUBLAS_STATUS_NOT_SUPPORTED = 15,
25     CUBLAS_STATUS_LICENSE_ERROR = 16
26 }
27 
28 ///
29 struct cublasContext;
30 ///
31 alias cublasHandle_t = cublasContext*;
32 ///
33 alias cublasOperation_t = int;
34 ///
35 enum : cublasOperation_t {
36     CUBLAS_OP_N, // the non-transpose operation is selected
37     CUBLAS_OP_T, // the transpose operation is selected
38     CUBLAS_OP_C // the conjugate transpose operation is selected
39 }
40 
41 // TODO: parse and retrieve cublas_api.h
42 
43 ///
44 cublasStatus_t cublasCreate_v2(cublasHandle_t*);
45 ///
46 cublasStatus_t cublasDestroy_v2(cublasHandle_t handle);
47 ///
48 cublasStatus_t cublasSgemm_v2(cublasHandle_t handle, cublasOperation_t transa,
49         cublasOperation_t transb,
50         int m, int n, int k, const float* alpha, const float* A, int lda,
51         const float* B, int ldb, const float* beta, float* C, int ldc);
52 
53 /*
54 cublasStatus_t cublasSgemm_v2(cublasHandle_t handle,
55                            cublasOperation_t transa, cublasOperation_t transb,
56                            int m, int n, int k,
57                            const float           *alpha,
58                            const float           *A, int lda,
59                            const float           *B, int ldb,
60                            const float           *beta,
61                            float           *C, int ldc);
62 */
63 
64 ///
65 cublasStatus_t cublasDgemm_v2(cublasHandle_t handle, cublasOperation_t transa,
66         cublasOperation_t transb,
67         int m, int n, int k, const double* alpha, const double* A, int lda,
68         const double* B, int ldb, const double* beta, double* C, int ldc);
69 ///
70 cublasStatus_t cublasSaxpy_v2(cublasHandle_t handle, int n, const float* alpha,
71         const float* x, int incx, float* y, int incy);
72 ///
73 cublasStatus_t cublasDaxpy_v2(cublasHandle_t handle, int n, const double* alpha,
74         const double* x, int incx, double* y, int incy);
75 
76 /// emit error message string from enum
77 auto cublasGetErrorEnum(cublasStatus_t error) {
78     final switch (error) {
79     case CUBLAS_STATUS_SUCCESS:
80         return "CUBLAS_STATUS_SUCCESS";
81 
82     case CUBLAS_STATUS_NOT_INITIALIZED:
83         return "CUBLAS_STATUS_NOT_INITIALIZED";
84 
85     case CUBLAS_STATUS_ALLOC_FAILED:
86         return "CUBLAS_STATUS_ALLOC_FAILED";
87 
88     case CUBLAS_STATUS_INVALID_VALUE:
89         return "CUBLAS_STATUS_INVALID_VALUE";
90 
91     case CUBLAS_STATUS_ARCH_MISMATCH:
92         return "CUBLAS_STATUS_ARCH_MISMATCH";
93 
94     case CUBLAS_STATUS_MAPPING_ERROR:
95         return "CUBLAS_STATUS_MAPPING_ERROR";
96 
97     case CUBLAS_STATUS_EXECUTION_FAILED:
98         return "CUBLAS_STATUS_EXECUTION_FAILED";
99 
100     case CUBLAS_STATUS_INTERNAL_ERROR:
101         return "CUBLAS_STATUS_INTERNAL_ERROR";
102     }
103 }