2016-08-07 48 views
0

Ich versuche, eine benutzerdefinierte Metrik mit sklearn.neighbors.BallTree zu verwenden, aber wenn sie meine Metrik aufruft, sehen die Eingaben nicht korrekt aus. Wenn ich scipy.spatial.distance.pdist mit derselben benutzerdefinierten Metrik verwende, funktioniert es wie erwartet. Wenn ich versuche, einen BallTree instanziieren, wird eine Ausnahme ausgelöst, wenn ich versuche, die Eingabe umzuformen. Wenn ich mir die tatsächlichen Eingaben ansehe, sehen die Form und die Werte nicht korrekt aus.Verwenden einer benutzerdefinierten Metrik mit sklearn.neighbors.BallTree gibt falsche Eingabe?

import numpy as np 
import scipy.spatial.distance as spdist 
import sklearn.neighbors.ball_tree as ball_tree 


# custom metric 
def minimum_average_direct_flip(x, y): 
    x = np.reshape(x, (-1, 3)) 
    y = np.reshape(y, (-1, 3)) 
    direct = np.mean(np.sqrt(np.sum(np.square(x - y), axis=1))) 
    flipped = np.mean(np.sqrt(np.sum(np.square(np.flipud(x) - y), axis=1))) 
    return min(direct, flipped) 

# create an X to test 
X = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19], [21, 22, 23, 24, 25, 26, 27, 28, 29]]) 

# works as expected 
distances = spdist.pdist(X, metric=minimum_average_direct_flip) 

# outputs: [ 17.32050808 34.64101615 17.32050808] 
print distances 

# raises exception, inputs to minimum_average_direct_flip look wrong 
# Traceback (most recent call last): 
# File ".../test_script.py", line 23, in <module> 
#  ball_tree.BallTree(X, metric=minimum_average_direct_flip) 
# File "sklearn/neighbors/binary_tree.pxi", line 1059, in sklearn.neighbors.ball_tree.BinaryTree.__init__ (sklearn\neighbors\ball_tree.c:8381) 
# File "sklearn/neighbors/dist_metrics.pyx", line 262, in sklearn.neighbors.dist_metrics.DistanceMetric.get_metric (sklearn\neighbors\dist_metrics.c:4032) 
# File "sklearn/neighbors/dist_metrics.pyx", line 1091, in sklearn.neighbors.dist_metrics.PyFuncDistance.__init__ (sklearn\neighbors\dist_metrics.c:10586) 
# File "C:/Users/danrs/Documents/neuro_atlas/test_script.py", line 8, in minimum_average_direct_flip 
#  x = np.reshape(x, (-1, 3)) 
# File "C:\Anaconda2\lib\site-packages\numpy\core\fromnumeric.py", line 225, in reshape 
#  return reshape(newshape, order=order) 
# ValueError: total size of new array must be unchanged 
ball_tree.BallTree(X, metric=minimum_average_direct_flip) 

Im ersten Anruf vom BallTree Code minimum_average_direct_flip, die Eingänge sind:

x = [ 0.4238394 0.55205233 0.04699435 0.19542642 0.20331665 0.44594837 0.35634537 0.8200018 0.28598294 0.34236847] 
y = [ 0.4238394 0.55205233 0.04699435 0.19542642 0.20331665 0.44594837 0.35634537 0.8200018 0.28598294 0.34236847] 

Diese sieht völlig falsch. Mache ich etwas falsch in der Art, wie ich das nenne oder ist das ein Fehler in Sklearn?

Antwort

0

Es scheint, dass dies ein bekanntes Problem ist: https://github.com/scikit-learn/scikit-learn/issues/6287

Sie irgendeine Art von Validierungsschritt tun, die problematisch ist. Als Problemumgehung kann ich eine Überprüfung der Eingabegröße hinzufügen, aber das Problem stellt fest, dass dies unerwünscht ist, da ich selbst keine eigentlichen Validierungsprüfungen durchführen kann.