2016-04-20 17 views
5

Ich möchte alle Beobachtungen untersuchen, die einen Knoten in einem Rpart-Entscheidungsbaum erreicht haben. Zum Beispiel in dem folgenden Code:Abrufen der Beobachtungen in einem Rpart-Knoten (z. B. CART)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 
fit 

n= 81 

node), split, n, loss, yval, (yprob) 
     * denotes terminal node 

1) root 81 17 absent (0.79.20987654) 
    2) Start>=8.5 62 6 absent (0.90322581 0.09677419) 
    4) Start>=14.5 29 0 absent (1.00000000 0.00000000) * 
    5) Start< 14.5 33 6 absent (0.81818182 0.18181818) 
     10) Age< 55 12 0 absent (1.00000000 0.00000000) * 
     11) Age>=55 21 6 absent (0.71428571 0.28571429) 
     22) Age>=111 14 2 absent (0.85714286 0.14285714) * 
     23) Age< 111 7 3 present (0.42857143 0.57142857) * 
    3) Start< 8.5 19 8 present (0.42105263 0.57894737) * 

Ich möchte alle Beobachtungen in Knoten sehen (5) (d.h .: die 33 Beobachtungen für die Startseite> = 8,5 & starten < 14,5). Offensichtlich konnte ich manuell zu ihnen kommen. Aber ich hätte gerne eine Funktion wie (zB) "get_node_date". Für die ich einfach get_node_date (5) ausführen könnte - und die relevanten Beobachtungen erhalten.

Irgendwelche Vorschläge, wie man das macht?

Antwort

1

Es scheint keine solche Funktion zu sein, die eine Extraktion der Beobachtungen von einem bestimmten Knoten ermöglicht. Ich würde es wie folgt lösen: Bestimmen Sie zuerst, welche Regel/en für den Knoten verwendet werden, für den Sie interessiert sind. Sie können path.rpart dafür verwenden. Dann könnten Sie die Regel (n) nacheinander anwenden, um die Beobachtungen zu extrahieren.

Dieser Ansatz als Funktion:

get_node_date <- function(tree = fit, node = 5){ 
    rule <- path.rpart(tree, node) 
    rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE)) 
    ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all) 
    kyphosis[ind,] 
    } 

Für Knoten 5 erhalten Sie:

get_node_date() 

node number: 5 
    root 
    Start>=8.5 
    Start< 14.5 
    Kyphosis Age Number Start 
2 absent 158  3 14 
10 present 59  6 12 
11 present 82  5 14 
14 absent 1  4 12 
18 absent 175  5 13 
20 absent 27  4  9 
23 present 96  3 12 
26 absent 9  5 13 
28 absent 100  3 14 
32 absent 125  2 11 
33 absent 130  5 13 
35 absent 140  5 11 
37 absent 1  3  9 
39 absent 20  6  9 
40 present 91  5 12 
42 absent 35  3 13 
46 present 139  3 10 
48 absent 131  5 13 
50 absent 177  2 14 
51 absent 68  5 10 
57 absent 2  3 13 
59 absent 51  7  9 
60 absent 102  3 13 
66 absent 17  4 10 
68 absent 159  4 13 
69 absent 18  4 11 
71 absent 158  5 14 
72 absent 127  4 12 
74 absent 206  4 10 
77 present 157  3 13 
78 absent 26  7 13 
79 absent 120  2 13 
81 absent 36  4 13 
1

rpart gibt rpart.object Element, das die Informationen enthält, die Sie brauchen:

require(rpart) 
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 
fit2 

get_node_date <-function(nodeId,fit) 
{ 
    fit$frame[toString(nodeId),"n"] 
} 


for (i in c(1,2,4,5,10,11,22,23,3)) 
    cat(get_node_date(i,fit2),"\n") 
+1

Sie haben nicht die Beobachtungen durchkommen, aber nur die Anzahl der abservations, die in eine Kategorie fallen – DatamineR

+1

, Sie haben Recht, die Frage falsch verstanden –

1

Das partykit Paket bietet auch eine vorgefertigte Lösung. Sie müssen nur das Objekt rpart in die Klasse party konvertieren, um die einheitliche Oberfläche für den Umgang mit Bäumen zu verwenden. Und dann können Sie die data_party() Funktion verwenden.

Mit dem fit von der Frage und library("partykit") geladen haben, können Sie zunächst die nötigen rpart Baum party:

pfit <- as.party(fit) 
plot(pfit) 

full pfit tree

Es gibt nur zwei kleine Belästigungen für die Daten in der Art und Weise Extrahieren Sie wollen: (1) Die model.frame() von der ursprünglichen Passform wird immer in den Zwang fallen gelassen und muss manuell wieder angebracht werden. (2) Für die Knoten wird ein anderes Nummerierungsschema verwendet. Sie wollen jetzt Knoten 4 (statt 5).

pfit$data <- model.frame(fit) 
data4 <- data_party(pfit, 4) 
dim(data4) 
## [1] 33 5 
head(data4) 
## Kyphosis Age Start (fitted) (response) 
## 2 absent 158 14  7  absent 
## 10 present 59 12  8 present 
## 11 present 82 14  8 present 
## 14 absent 1 12  5  absent 
## 18 absent 175 13  7  absent 
## 20 absent 27  9  5  absent 

Ein anderer Weg ist die Unterstruktur ausgehend von Knoten 4 und dann die Daten aus, dass die Einnahme der Teilmenge:

pfit4 <- pfit[4] 
plot(pfit4) 

subtree of pfit from node 4

Dann data_party(pfit4) gibt Ihnen die gleichen wie data4 oben. Und pfit4$data gibt Ihnen die Daten ohne den Knoten (fitted) und die vorhergesagten (response).

+0

wenn Sie 'ptree $ data verwendet <- model.frame (eval (Baum $ call $ data)) 'die Variablen, die nicht in der Formel verwendet werden, würden nicht gelöscht werden – rawr

+0

Wahr ... aber nur, wenn' data' alle Variablen in der 'formula' enthält, was nicht unbedingt der Fall ist. Mit dem 'model.frame()' erhalten Sie auch transformierte Variablen, z. B. 'log()', 'Surv()' oder 'factor()' Versionen von Variablen, die oft im laufenden Betrieb erzeugt werden. –

+0

BTW: Die 'as.party()' Zwangs für 'rpart' Objekte jetzt _heeps die Daten_ standardmäßig! So können Sie 'as.party (fit, data = TRUE)' (was der neue Standardwert ist) oder 'as.party (fit, data = FALSE)' (was dem alten Verhalten entspricht) tun. –

1

Noch eine andere Möglichkeit, das funktioniert, indem Sie alle Endknoten eines bestimmten Knotens finden und die Teilmenge der Daten, die im Aufruf verwendet werden, zurückgeben.

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis) 

head(subset.rpart(fit, 5)) 
# Kyphosis Age Number Start 
# 2 absent 158  3 14 
# 10 present 59  6 12 
# 11 present 82  5 14 
# 14 absent 1  4 12 
# 18 absent 175  5 13 
# 20 absent 27  4  9 


subset.rpart <- function(tree, node = 1L) { 
    data <- eval(tree$call$data, parent.frame(1L)) 
    wh <- sapply(as.integer(rownames(tree$frame)), parent) 
    wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)])) 
    data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ] 
} 

parent <- function(x) { 
    if (x[1] != 1) 
    c(Recall(if (x %% 2 == 0L) x/2 else (x - 1)/2), x) else x 
}