-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathcusolverdx_samples_potrf.py
More file actions
83 lines (62 loc) · 2.27 KB
/
Copy pathcusolverdx_samples_potrf.py
File metadata and controls
83 lines (62 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0
#
# Mirrors https://github.com/NVIDIA/CUDALibrarySamples/tree/main/MathDx/cuSolverDx/01_Linear_Solve/potrf.cu
#
import warnings
import numpy as np
import scipy
from common import prepare_random_matrix, verify_triangular_relative_error
from common_numba import load_to_shared_strided, store_from_shared_strided
from numba import cuda
from nvmath.device import CholeskySolver
# Suppress warnings attributed to modules whose name begins with "numba".
warnings.filterwarnings(action="ignore", module="numba")

# Error bound looked up by the solver's precision and passed as the third argument to
# verify_triangular_relative_error in main().
# NOTE(review): presumably a relative tolerance -- confirm against common.py.
_PREC_TABLE = {np.float32: 1e-4, np.float64: 1e-6}

# Passed as the fourth argument to verify_triangular_relative_error in main().
# NOTE(review): presumably an absolute-error floor -- confirm against common.py.
_ABS_ERR = 1e-8
def main():
    """Factorize one matrix with a block-level cuSolverDx Cholesky solver and check it against SciPy.

    Steps:
      1. Configure a ``CholeskySolver`` (32x32, complex, ``np.float64`` precision, block
         execution, upper fill mode, leading dimensions of 33).
      2. Define and launch a one-block CUDA kernel that copies ``A`` into shared memory,
         factorizes it in place, and copies the result back.
      3. Compare the first matrix of the result with ``scipy.linalg.cholesky`` through
         ``verify_triangular_relative_error`` and print the reported error.

    Raises:
        RuntimeError: if the ``info`` status written by the kernel is nonzero. A negative
            value is reported as an invalid argument; a positive value ``i`` is reported
            as "leading minor of order ``i`` is not positive definite" (potrf convention).
    """
    solver = CholeskySolver(
        size=(32, 32),
        precision=np.float64,
        data_type="complex",
        execution="Block",
        leading_dimensions=(33, 33),
        block_dim=(256, 1, 1),
        fill_mode="upper",
    )

    # Computed on the host and captured by the kernel closure as the shared-array size.
    n_a = solver.a_size()

    @cuda.jit
    def kernel(a, info):
        # Stage A in shared memory; the factorization below operates on this buffer in place.
        smem_a = cuda.shared.array(n_a, dtype=solver.value_type)

        load_to_shared_strided(a, smem_a, solver.a_shape, solver.a_strides())
        # Barrier: every thread's part of the load must finish before the solver reads smem_a.
        cuda.syncthreads()

        solver.factorize(smem_a, info)
        # Barrier: the factorization must be complete before results are copied back out.
        cuda.syncthreads()

        store_from_shared_strided(smem_a, a, solver.a_shape, solver.a_strides())

    print(f"Use compile-time leading dimension LDA for shared memory = {solver.lda}")

    a = prepare_random_matrix(
        solver.a_shape,
        solver.precision,
        is_complex=True,
        # Input A is deliberately not Hermitian.
        # NOTE(review): presumably only the `fill_mode` triangle is read by the solver -- confirm.
        is_hermitian=False,
        is_diag_dom=True,
    )

    a_d = cuda.to_device(a)
    info_d = cuda.device_array(solver.info_shape, dtype=solver.info_type)

    # Launch a single block of solver.block_dim threads.
    kernel[1, solver.block_dim](a_d, info_d)
    cuda.synchronize()

    a_result = a_d.copy_to_host()
    info_result = info_d.copy_to_host()

    info = info_result[0]
    print(f"after cuSolverDx potrf kernel: info = {info}")
    # potrf status convention: 0 is success, -i flags the i-th argument as invalid, and
    # +i means the leading minor of order i is not positive definite. The two failure
    # modes are reported separately so a numerical failure is not mislabeled as a bad argument.
    if info < 0:
        raise RuntimeError(f"{-info}-th parameter is wrong")
    if info > 0:
        raise RuntimeError(f"leading minor of order {info} of A is not positive definite")

    # Reference factor for the first matrix, using the same triangle as the solver.
    scipy_result = scipy.linalg.cholesky(a[0], lower=(solver.fill_mode == "lower"))
    error = verify_triangular_relative_error(
        a_result[0],
        scipy_result,
        _PREC_TABLE[solver.precision],
        _ABS_ERR,
        solver,
        solver.fill_mode,
    )
    print(f"Successfully validated against scipy reference, with error: {error}")
# Script entry point: run the sample only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()