2016-02-02 5 views

Antwort

9

Es gibt drei relevante ops für die Umsetzung Theano des dimshuffle in TensorFlow:

  • tf.transpose() verwendet wird, um die Abmessungen eines Tensor permutieren. Wenn das in den Argumenten zu dimshuffle angegebene Muster eine Permutation der Abmessungen des Eingangstensors ist (d. H. Es gibt keine 'x' oder fehlende Dimension), können Sie verwenden, um dimshuffle() zu implementieren.

  • tf.expand_dims() wird verwendet, um einem Tensor eine oder mehrere Dimensionen der Größe 1 hinzuzufügen. Dies behandelt den Fall, in dem 'x' als Teil des dimshuffle()-Musters angegeben ist, aber die vorhandenen Dimensionen nicht neu angeordnet werden.

  • tf.squeeze() wird verwendet, um eine oder mehrere Dimensionen der Größe 1 aus einem Tensor zu entfernen. Dies behandelt den Fall, bei dem eine Dimension in einem dimshuffle() Muster weggelassen wird, aber die vorhandenen Dimensionen werden nicht neu angeordnet.

Unter der Annahme, dass die Eingabe ein Vektor ist, Ihr Beispiel (dimshuffle(0, 'x')) kann mit tf.expand_dims() nur ausgedrückt werden:

input = tf.placeholder(tf.float32, [None]) # Defines an arbitrary-sized vector. 
result = tf.expand_dims(input, 1) 

print result.get_shape() # ==> TensorShape([Dimension(None), Dimension(1)]) 

ein komplizierteres Beispiel nehmen, dimshuffle(1, 'x', 0) auf eine Matrix aufgebracht wäre:

input = tf.placeholder(tf.float32, [128, 32]) # Defines a matrix. 
output = tf.expand_dims(tf.transpose(input, [1, 0]), 1) 

print output.get_shape() 
# ==> TensorShape([Dimension(32), Dimension(1), Dimension(128)]) 
0

I umgesetzt dimshuffle für TensorFlow in our framework Returnn (here). Der Code ist dies:

def expand_multiple_dims(x, axes, name="expand_multiple_dims"): 
    """ 
    :param tf.Tensor x: 
    :param list[int]|tuple[int] axes: after completion, tf.shape(y)[axis] == 1 for axis in axes 
    :param str name: scope name 
    :return: y where we have a new broadcast axis for each axis in axes 
    :rtype: tf.Tensor 
    """ 
    with tf.name_scope(name): 
    for i in sorted(axes): 
     x = tf.expand_dims(x, axis=i, name="expand_axis_%i" % i) 
    return x 


def dimshuffle(x, axes, name="dimshuffle"): 
    """ 
    Like Theanos dimshuffle. 
    Combines tf.transpose, tf.expand_dims and tf.squeeze. 

    :param tf.Tensor x: 
    :param list[int|str]|tuple[int|str] axes: 
    :param str name: scope name 
    :rtype: tf.Tensor 
    """ 
    with tf.name_scope(name): 
    assert all([i == "x" or isinstance(i, int) for i in axes]) 
    real_axes = [i for i in axes if isinstance(i, int)] 
    bc_axes = [i for (i, j) in enumerate(axes) if j == "x"] 
    if x.get_shape().ndims is None: 
     x_shape = tf.shape(x) 
     x = tf.reshape(x, [x_shape[i] for i in range(max(real_axes) + 1)]) # will have static ndims 
    assert x.get_shape().ndims is not None 

    # First squeeze missing axes. 
    i = 0 
    while i < x.get_shape().ndims: 
     if i not in real_axes: 
     x = tf.squeeze(x, axis=i) 
     real_axes = [(j if (j < i) else (j - 1)) for j in real_axes] 
     else: 
     i += 1 

    # Now permute. 
    assert list(sorted(real_axes)) == list(range(x.get_shape().ndims)) 
    if real_axes != list(range(x.get_shape().ndims)): 
     x = tf.transpose(x, real_axes) 

    # Now add broadcast dimensions. 
    if bc_axes: 
     x = expand_multiple_dims(x, bc_axes) 
    assert len(axes) == x.get_shape().ndims 
    return x 
0

Wenn tensorflow Backend ist

from keras import baskend as K 
K.permute_dimension should do