2016-04-13 5 views
3

Ich wechselte vor kurzem Form Tensorflow zu Skflow. Im Tensorflow würden wir unserem Verlust Lambda * tf.nn.l2_loss (Gewichte) hinzufügen. Jetzt habe ich den folgenden Code in skflow:Hinzufügen von Regularizer zu Skflow

def deep_psi(X, y): 
    layers = skflow.ops.dnn(X, [5, 10, 20, 10, 5], keep_prob=0.5) 
    preds, loss = skflow.models.logistic_regression(layers, y) 
    return preds, loss 

def exp_decay(global_step): 
    return tf.train.exponential_decay(learning_rate=0.01, 
             global_step=global_step, 
             decay_steps=1000, 
             decay_rate=0.005) 

deep_cd = skflow.TensorFlowEstimator(model_fn=deep_psi, 
            n_classes=2, 
            steps=10000, 
            batch_size=10, 
            learning_rate=exp_decay, 
            verbose=True,) 

Wie und wo füge ich einen Regularizer hier? Illia deutet etwas an here, aber ich konnte es nicht herausfinden.

Antwort

3

Sie können noch zusätzliche Komponenten hinzufügen Verlust, müssen Sie nur Gewichte von dnn/logistic_regression abzurufen und sie auf den Verlust hinzu:

def regularize_loss(loss, weights, lambda): 
    for weight in weights: 
     loss = loss + lambda * tf.nn.l2_loss(weight) 
    return loss  


def deep_psi(X, y): 
    layers = skflow.ops.dnn(X, [5, 10, 20, 10, 5], keep_prob=0.5) 
    preds, loss = skflow.models.logistic_regression(layers, y) 

    weights = [] 
    for layer in range(5): # n layers you passed to dnn 
     weights.append(tf.get_variable("dnn/layer%d/linear/Matrix" % layer)) 
     # biases are also available at dnn/layer%d/linear/Bias 
    weights.append(tf.get_variable('logistic_regression/weights')) 

    return preds, regularize_loss(loss, weights, lambda) 

`` `

Hinweis, den Pfad zu Variablen kann found here sein.

Auch wollen wir Regularisator Unterstützung für alle Schichten mit Variablen (wie dnn, conv2d oder fully_connected), so kann nächste Woche sein Nacht-Build von Tensorflow sollte so etwas wie dieses dnn(.., regularize=tf.contrib.layers.l2_regularizer(lambda)) haben hinzuzufügen. Ich werde diese Antwort aktualisieren, wenn dies passiert.

+0

dnn (.., regularize = tf.contrib.layers.l2_regularizer (Lambda)) wird Liebe sein. – plumSemPy