2012-12-16 7 views
13

Gibt es eine eingebaute Funktion in scipy/numpy für das Erhalten der PMF eines Multinomial? Ich bin mir nicht sicher, ob binom in der richtigen Weise verallgemeinert wird, z.B.multinomiale PMF in Python scipy/numpy

# Attempt to define multinomial with n = 10, p = [0.1, 0.1, 0.8] 
rv = scipy.stats.binom(10, [0.1, 0.1, 0.8]) 
# Score the outcome 4, 4, 2 
rv.pmf([4, 4, 2]) 

Was ist der richtige Weg? Vielen Dank.

Antwort

9

Es gibt keine eingebaute Funktion, die ich kenne, und die Binomialwahrscheinlichkeiten verallgemeinern nicht (Sie müssen über eine andere Menge von möglichen Ergebnissen normalisieren, da die Summe aller Zählungen n sein muss, was nicht sein wird wird von unabhängigen Binomen behandelt). Allerdings ist es ziemlich einfach, sich zu implementieren, zum Beispiel:

import math 

class Multinomial(object): 
    def __init__(self, params): 
    self._params = params 

    def pmf(self, counts): 
    if not(len(counts)==len(self._params)): 
     raise ValueError("Dimensionality of count vector is incorrect") 

    prob = 1. 
    for i,c in enumerate(counts): 
     prob *= self._params[i]**counts[i] 

    return prob * math.exp(self._log_multinomial_coeff(counts)) 

    def log_pmf(self,counts): 
    if not(len(counts)==len(self._params)): 
     raise ValueError("Dimensionality of count vector is incorrect") 

    prob = 0. 
    for i,c in enumerate(counts): 
     prob += counts[i]*math.log(self._params[i]) 

    return prob + self._log_multinomial_coeff(counts) 

    def _log_multinomial_coeff(self, counts): 
    return self._log_factorial(sum(counts)) - sum(self._log_factorial(c) 
                for c in counts) 

    def _log_factorial(self, num): 
    if not round(num)==num and num > 0: 
     raise ValueError("Can only compute the factorial of positive ints") 
    return sum(math.log(n) for n in range(1,num+1)) 

m = Multinomial([0.1, 0.1, 0.8]) 
print m.pmf([4,4,2]) 

>>2.016e-05 

Meine Umsetzung des Multinomialkoeffizient etwas naiv ist, und arbeitet in Log-Speicherplatz Überlauf zu verhindern. Beachten Sie auch, dass n als Parameter überflüssig ist, da es durch die Summe der Zählwerte gegeben ist (und der gleiche Parametersatz funktioniert für jedes n). Da dies bei mittleren oder großen Dimensionen schnell unterlaufen wird, arbeiten Sie besser im Protokollbereich (logPMF wird auch hier bereitgestellt!)