Skip to content

Commit 2b6e162

Browse files
authored
KeyError bugfix for matplotlib bar chart (#391)
- matplotlib bar chart bugfix
- test fix
1 parent 29e0c70 commit 2b6e162

File tree

4 files changed

+18
-20
lines changed

4 files changed

+18
-20
lines changed

lux/vislib/matplotlib/BarChart.py

Lines changed: 16 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -79,17 +79,16 @@ def initialize_chart(self):
7979
)
8080

8181
df = self.data
82-
83-
bars = df[bar_attr].apply(lambda x: str(x))
84-
measurements = df[measure_attr]
82+
bar = df[bar_attr].apply(lambda x: str(x))
83+
bars = list(bar)
84+
measurements = list(df[measure_attr])
8585

8686
plot_code = ""
8787

8888
color_attr = self.vis.get_attr_by_channel("color")
8989
if len(color_attr) == 1:
9090
self.fig, self.ax = matplotlib_setup(6, 4)
9191
color_attr_name = color_attr[0].attribute
92-
color_attr_type = color_attr[0].data_type
9392
colors = df[color_attr_name].values
9493
unique = list(set(colors))
9594
d_x = {}
@@ -101,22 +100,22 @@ def initialize_chart(self):
101100
d_x[colors[i]].append(bars[i])
102101
d_y[colors[i]].append(measurements[i])
103102
for i in range(len(unique)):
104-
self.ax.barh(d_x[unique[i]], d_y[unique[i]], label=unique[i])
105-
plot_code += (
106-
f"ax.barh({d_x}[{unique}[{i}]], {d_y}[{unique}[{i}]], label={unique}[{i}])\n"
107-
)
103+
xval = d_x[unique[i]]
104+
yval = d_y[unique[i]]
105+
l = unique[i]
106+
self.ax.barh(xval, yval, label=l)
107+
plot_code += f"ax.barh({xval},{yval}, label='{l}')\n"
108108
self.ax.legend(
109109
title=color_attr_name, bbox_to_anchor=(1.05, 1), loc="upper left", ncol=1, frameon=False
110110
)
111-
plot_code += f"""ax.legend(
112-
title='{color_attr_name}',
113-
bbox_to_anchor=(1.05, 1),
114-
loc='upper left',
115-
ncol=1,
116-
frameon=False,)\n"""
111+
plot_code += f"""ax.legend(title='{color_attr_name}',
112+
bbox_to_anchor=(1.05, 1),
113+
loc='upper left',
114+
ncol=1,
115+
frameon=False)\n"""
117116
else:
118-
self.ax.barh(bars, measurements, align="center")
119-
plot_code += f"ax.barh(bars, measurements, align='center')\n"
117+
self.ax.barh(bar, df[measure_attr], align="center")
118+
plot_code += f"ax.barh({bar}, {df[measure_attr]}, align='center')\n"
120119

121120
y_ticks_abbev = df[bar_attr].apply(lambda x: str(x)[:10] + "..." if len(str(x)) > 10 else str(x))
122121
self.ax.set_yticks(bars)
@@ -128,7 +127,7 @@ def initialize_chart(self):
128127

129128
self.code += "import numpy as np\n"
130129
self.code += "from math import nan\n"
131-
130+
self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n"
132131
self.code += f"fig, ax = plt.subplots()\n"
133132
self.code += f"bars = df['{bar_attr}']\n"
134133
self.code += f"measurements = df['{measure_attr}']\n"

lux/vislib/matplotlib/LineChart.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def initialize_chart(self):
110110

111111
self.code += "import numpy as np\n"
112112
self.code += "from math import nan\n"
113-
113+
self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n"
114114
self.code += f"fig, ax = plt.subplots()\n"
115115
self.code += f"x_pts = df['{x_attr.attribute}']\n"
116116
self.code += f"y_pts = df['{y_attr.attribute}']\n"

lux/vislib/matplotlib/ScatterChart.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ def initialize_chart(self):
140140
self.code += "import numpy as np\n"
141141
self.code += "from math import nan\n"
142142
self.code += "from matplotlib.cm import ScalarMappable\n"
143-
143+
self.code += f"df = pd.DataFrame({str(self.data.to_dict())})\n"
144144
self.code += set_fig_code
145145
self.code += f"x_pts = df['{x_attr.attribute}']\n"
146146
self.code += f"y_pts = df['{y_attr.attribute}']\n"

tests/test_vis.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,6 @@ def test_bar_chart(global_var):
218218
lux.config.plotting_backend = "matplotlib"
219219
vis = Vis(["Origin", "Acceleration"], df)
220220
vis_code = vis.to_matplotlib()
221-
assert "ax.barh(bars, measurements, align='center')" in vis_code
222221
assert "ax.set_xlabel('Acceleration')" in vis_code
223222
assert "ax.set_ylabel('Origin')" in vis_code
224223

0 commit comments

Comments
 (0)