2016-08-08 51 views
3

Ich versuche, gewichtetes Mittel in pyspark zu berechnen, aber nicht viele FortschritteBerechnung gewichteten Mittelwert in PySpark

# Example data 
df = sc.parallelize([ 
    ("a", 7, 1), ("a", 5, 2), ("a", 4, 3), 
    ("b", 2, 2), ("b", 5, 4), ("c", 1, -1) 
]).toDF(["k", "v1", "v2"]) 
df.show() 

import numpy as np 
def weighted_mean(workclass, final_weight): 
    return np.average(workclass, weights=final_weight) 

weighted_mean_udaf = pyspark.sql.functions.udf(weighted_mean, 
    pyspark.sql.types.IntegerType()) 

machen, aber wenn ich versuche, diesen Code

df.groupby('k').agg(weighted_mean_udaf(df.v1,df.v2)).show() 

Ich erhalte die auszuführen Fehler

Meine Frage ist, kann ich eine benutzerdefinierte Funktion (mit mehreren Argumenten) als Argument angeben zu agg? Wenn nicht, gibt es eine Alternative, um Operationen wie das gewichtete Mittel nach der Gruppierung mit einem Schlüssel durchzuführen?

+1

Meinen Sie die 'weighted_mean' Funktion außer Kraft zu setzen? –

+0

Was ich tun möchte, ist a) groupby b) führen Sie eine Operation abhängig von mehreren Spalten des Datenrahmens. Der gewichtete Mittelwert ist nur ein Beispiel. – MARK

+0

Ich denke, was @ cricket_007 bedeutete ist, Sie absichtlich überschreiben 'weighted_mean' durch diese Zeile' weighted_mean = pyspark.sql.functions.udf (weighted_mean, 'oder es ist ein Tippfehler? – akarilimano

Antwort

3

Benutzerdefinierte Aggregationsfunktion (UDAF, die auf pyspark.sql.GroupedData funktioniert, aber nicht in PYSPARK unterstützt) ist keine benutzerdefinierte Funktion (UDF, die auf pyspark.sql.DataFrame funktioniert).

Da in pyspark können Sie nicht Ihre eigene UDAF erstellen, und die gelieferten UDAFs kann das Problem nicht lösen, können Sie zu RDD Welt gehen müssen zurück:

from numpy import sum 

def weighted_mean(vals): 
    vals = list(vals) # save the values from the iterator 
    sum_of_weights = sum(tup[1] for tup in vals) 
    return sum(1. * tup[0] * tup[1]/sum_of_weights for tup in vals) 

df.map(
    lambda x: (x[0], tuple(x[1:])) # reshape to (key, val) so grouping could work 
).groupByKey().mapValues(
    weighted_mean 
).collect()