2016-05-29 7 views
0

Ich habe einen Kernel für meine benutzerdefinierte Op implementiert, und steckte es in /tensorflow/core/user_ops als custom_op.cc. Innerhalb des Ops mache ich alle registrierenden Sachen, wie REGISTER_OP und .Wie kann man benutzerdefinierte Op in TensorFlow in Python importieren?

Dann habe ich Gradienten für diese Op in Python implementiert, und ich legte es in den gleichen Ordner wie custom_op_grad.py. Ich habe auch die Registrierung hier gemacht (@ops.RegisterGradient).

ich die BUILD-Datei erstellt haben, mit folgendem Inhalt:

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") 
tf_custom_op_library(
     name = "custom_op.so", 
     srcs = ["custom_op.cc"], 
) 

py_library(
     name = "custom_op_grad", 
     srcs = ["custom_op_grad.py"], 
     srcs_version = "PY2", 
     deps = [ 
     ":custom_op_grad", 
     "//tensorflow:tensorflow_py", 
     ], 
) 

Danach habe ich Tensorflow wieder aufbauen:

pip uninstall tensorflow 
bazel clean 
bazel build -c opt //tensorflow/tools/pip_package:build_pip_package 
cp -r bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/__main__/* bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/ 
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg 
pip install /tmp/tensorflow_pkg/tensorflow-0.8.0-py2-none-any.whl 

Wenn ich versuche, meine Op nach all dies zu nutzen, durch den Aufruf tf.user_ops.custom_op es sagt mir, dass Modul es nicht hat.

Vielleicht gibt es einige zusätzliche Schritte, die ich tun muss? Oder mache ich etwas falsch mit der Datei BUILD?

Antwort

0

Ok, ich habe die Lösung gefunden. Ich habe nur die BUILD Datei entfernt, und meine benutzerdefinierte Op wurde erfolgreich erstellt und konnte in Python mit tensorflow.user_ops.custom_op() importiert werden.

Um den Gradienten zu verwenden, musste ich seinen Code direkt in die tensorflow/python/user_ops/user_ops.py eingeben. Nicht die eleganteste Lösung, aber für jetzt arbeiten.