Ich versuche, die LAPACK-Funktion dgtsv
(ein Löser für tridiagonal Systeme von Gleichungen) mit Cython zu wickeln.Wrapping einer LAPACKE-Funktion mit Cython
Ich stieß auf this previous answer, aber seit dgtsv
ist nicht eine der LAPACK-Funktionen, die in scipy.linalg
gewickelt sind Ich glaube nicht, dass ich diesen bestimmten Ansatz verwenden kann. Stattdessen habe ich versucht, this example zu folgen.
Hier ist der Inhalt meiner lapacke.pxd
Datei:
ctypedef int lapack_int
cdef extern from "lapacke.h" nogil:
int LAPACK_ROW_MAJOR
int LAPACK_COL_MAJOR
lapack_int LAPACKE_dgtsv(int matrix_order,
lapack_int n,
lapack_int nrhs,
double * dl,
double * d,
double * du,
double * b,
lapack_int ldb)
... hier ist mein dünner Cython Wrapper in _solvers.pyx
:
#!python
cimport cython
from lapacke cimport *
cpdef TDMA_lapacke(double[::1] DL, double[::1] D, double[::1] DU,
double[:, ::1] B):
cdef:
lapack_int n = D.shape[0]
lapack_int nrhs = B.shape[1]
lapack_int ldb = B.shape[0]
double * dl = &DL[0]
double * d = &D[0]
double * du = &DU[0]
double * b = &B[0, 0]
lapack_int info
info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR, n, nrhs, dl, d, du, b, ldb)
return info
... und hier ist ein Python-Wrapper und Testskript:
import numpy as np
from scipy import sparse
from cymodules import _solvers
def trisolve_lapacke(dl, d, du, b, inplace=False):
if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1
or b.shape != d.shape):
raise ValueError('Invalid diagonal shapes')
if b.ndim == 1:
# b is (LDB, NRHS)
b = b[:, None]
# be sure to force a copy of d and b if we're not solving in place
if not inplace:
d = d.copy()
b = b.copy()
# this may also force copies if arrays are improperly typed/noncontiguous
dl, d, du, b = (np.ascontiguousarray(v, dtype=np.float64)
for v in (dl, d, du, b))
# b will now be modified in place to contain the solution
info = _solvers.TDMA_lapacke(dl, d, du, b)
print info
return b.ravel()
def test_trisolve(n=20000):
dl = np.random.randn(n - 1)
d = np.random.randn(n)
du = np.random.randn(n - 1)
M = sparse.diags((dl, d, du), (-1, 0, 1), format='csc')
x = np.random.randn(n)
b = M.dot(x)
x_hat = trisolve_lapacke(dl, d, du, b)
print "||x - x_hat|| = ", np.linalg.norm(x - x_hat)
Leider test_trisolve
nur se Standardwerte für den Anruf an _solvers.TDMA_lapacke
. Ich bin mir ziemlich sicher, dass meine setup.py
korrekt ist - ldd _solvers.so
zeigt, dass _solvers.so
zur Laufzeit mit den richtigen gemeinsam genutzten Bibliotheken verknüpft ist.
Ich bin nicht wirklich sicher, wie man von hier aus vorgeht - irgendwelche Ideen?
Ein kurzes Update:
für kleinere Werte von n
Ich neige dazu, nicht segfaults sofort zu bekommen, aber ich Unsinn Ergebnisse erhalten (|| x - x_hat || sollte sehr sein nahe bei 0):
In [28]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 6.23202576396
In [29]: test_trisolve2.test_trisolve(10)
-7
||x - x_hat|| = 3.88623414288
In [30]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 2.60190676562
In [31]: test_trisolve2.test_trisolve(10)
0
||x - x_hat|| = 3.86631743386
In [32]: test_trisolve2.test_trisolve(10)
Segmentation fault
Normalerweise LAPACKE_dgtsv
kehrt mit Code 0
(der Erfolg zeigen sollte), aber ich gelegentlich -7
, was bedeutet, dass das Argument 7 (b
) einen ungültigen Wert hatte. Was passiert, ist, dass nur der erste Wert von b
tatsächlich geändert wird. Wenn ich weiter test_trisolve
aufrufen werde ich schließlich einen segfault treffen, selbst wenn n
klein ist.