2015-10-10 3 views
10

Plotten 2 distplots oder Scatterplots in einem subplot funktioniert super:Wie man 2 seaborn lmplots Seite-an-Seite plant?

import matplotlib.pyplot as plt 
import numpy as np 
import seaborn as sns 
import pandas as pd 
%matplotlib inline 

# create df 
x = np.linspace(0, 2 * np.pi, 400) 
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)}) 

# Two subplots 
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) 
ax1.plot(df.x, df.y) 
ax1.set_title('Sharing Y axis') 
ax2.scatter(df.x, df.y) 

plt.show() 

Subplot example

Aber wenn ich mit einem lmplot das gleiche zu tun, anstatt eine der beiden anderen Arten von Diagrammen bekomme ich einen Fehler:

Gibt es eine Möglichkeit, diese Diagrammtypen nebeneinander darzustellen?

+0

BTW: Ihr Beispiel nicht ausgeführt. Die Variable "x" ist in der Definition der Spalte "y" nicht definiert. –

+0

Danke, dass Sie @PaulH bemerkt haben. Korrigiert. – samthebrand

Antwort

24

Sie erhalten diesen Fehler, weil Matplotlib und seine Objekte Seaborn-Funktionen vollständig nicht kennen.

Ihre Achsen Objekte Pass (dh ax1 und ax2) zu seaborn.regplot oder Sie können wie folgt seaborn.lmplot

Mit Ihrer gleichen Importe, Pre-Definition Ihrer Achsen und mit regplot sieht überspringen definieren diese und verwenden Sie die col kwarg von :

# create df 
x = np.linspace(0, 2 * np.pi, 400) 
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)}) 
df.index.names = ['obs'] 
df.columns.names = ['vars'] 

idx = np.array(df.index.tolist(), dtype='float') # make an array of x-values 

# call regplot on each axes 
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True) 
sns.regplot(x=idx, y=df['x'], ax=ax1) 
sns.regplot(x=idx, y=df['y'], ax=ax2) 

enter image description here

lmplot Verwendung erfordert Ihre dataframe to be tidy. Fortsetzung aus dem Code oben:

tidy = (
    df.stack() # pull the columns into row variables 
     .to_frame() # convert the resulting Series to a DataFrame 
     .reset_index() # pull the resulting MultiIndex into the columns 
     .rename(columns={0: 'val'}) # rename the unnamed column 
) 
sns.lmplot(x='obs', y='val', col='vars', hue='vars', data=tidy) 

enter image description here