Ich versuche, eine benutzerdefinierte Metrik mit sklearn.neighbors.BallTree zu verwenden, aber wenn sie meine Metrik aufruft, sehen die Eingaben nicht korrekt aus. Wenn ich scipy.spatial.distance.pdist mit derselben benutzerdefinierten Metrik verwende, funktioniert es wie erwartet. Wenn ich versuche, einen BallTree instanziieren, wird eine Ausnahme ausgelöst, wenn ich versuche, die Eingabe umzuformen. Wenn ich mir die tatsächlichen Eingaben ansehe, sehen die Form und die Werte nicht korrekt aus.Verwenden einer benutzerdefinierten Metrik mit sklearn.neighbors.BallTree gibt falsche Eingabe?
import numpy as np
import scipy.spatial.distance as spdist
import sklearn.neighbors.ball_tree as ball_tree
# custom metric
def minimum_average_direct_flip(x, y):
x = np.reshape(x, (-1, 3))
y = np.reshape(y, (-1, 3))
direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1)))
flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1)))
return min(direct, flipped)
# create an X to test
X = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19], [21, 22, 23, 24, 25, 26, 27, 28, 29]])
# works as expected
distances = spdist.pdist(X, metric=minimum_average_direct_flip)
# outputs: [ 17.32050808 34.64101615 17.32050808]
print distances
# raises exception, inputs to minimum_average_direct_flip look wrong
# Traceback (most recent call last):
# File ".../test_script.py", line 23, in <module>
# ball_tree.BallTree(X, metric=minimum_average_direct_flip)
# File "sklearn/neighbors/binary_tree.pxi", line 1059, in sklearn.neighbors.ball_tree.BinaryTree.__init__ (sklearn\neighbors\ball_tree.c:8381)
# File "sklearn/neighbors/dist_metrics.pyx", line 262, in sklearn.neighbors.dist_metrics.DistanceMetric.get_metric (sklearn\neighbors\dist_metrics.c:4032)
# File "sklearn/neighbors/dist_metrics.pyx", line 1091, in sklearn.neighbors.dist_metrics.PyFuncDistance.__init__ (sklearn\neighbors\dist_metrics.c:10586)
# File "C:/Users/danrs/Documents/neuro_atlas/test_script.py", line 8, in minimum_average_direct_flip
# x = np.reshape(x, (-1, 3))
# File "C:\Anaconda2\lib\site-packages\numpy\core\fromnumeric.py", line 225, in reshape
# return reshape(newshape, order=order)
# ValueError: total size of new array must be unchanged
ball_tree.BallTree(X, metric=minimum_average_direct_flip)
Im ersten Anruf vom BallTree Code minimum_average_direct_flip, die Eingänge sind:
x = [ 0.4238394 0.55205233 0.04699435 0.19542642 0.20331665 0.44594837 0.35634537 0.8200018 0.28598294 0.34236847]
y = [ 0.4238394 0.55205233 0.04699435 0.19542642 0.20331665 0.44594837 0.35634537 0.8200018 0.28598294 0.34236847]
Diese sieht völlig falsch. Mache ich etwas falsch in der Art, wie ich das nenne oder ist das ein Fehler in Sklearn?