Skip to content

Commit d5fdc27

Browse files
committed
fix more warnings
1 parent 5273678 commit d5fdc27

File tree

4 files changed

+23
-18
lines changed

4 files changed

+23
-18
lines changed

python/dalex/dalex/fairness/_group_fairness/plot.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,20 +328,20 @@ def plot_metric_scores(fobject,
328328
for metric in data.metric.unique():
329329
for label in data.label.unique():
330330
x = float(privileged_data.loc[(privileged_data.metric == metric) &
331-
(privileged_data.label == label), :].score)
331+
(privileged_data.label == label), :].score.iloc[0])
332332
if np.isnan(x):
333333
lines_nan = True
334334
continue
335335
# lines
336336
for subgroup in data.subgroup.unique():
337337
y = float(data.loc[(data.metric == metric) &
338338
(data.label == label) &
339-
(data.subgroup == subgroup)].subgroup_numeric)
339+
(data.subgroup == subgroup)].subgroup_numeric.iloc[0])
340340
# horizontal
341341

342342
x0 = float(data.loc[(data.metric == metric) &
343343
(data.label == label) &
344-
(data.subgroup == subgroup)].score)
344+
(data.subgroup == subgroup)].score.iloc[0])
345345

346346
if np.isnan(x0):
347347
lines_nan = True
@@ -691,7 +691,7 @@ def plot_ceteris_paribus_cutoff(fobject,
691691
def plot_density(fobject,
692692
other_objects,
693693
title):
694-
data = pd.DataFrame(columns=['y', 'y_hat', 'subgroup', 'model'])
694+
data_list = []
695695
objects = [fobject]
696696
if other_objects is not None:
697697
for other_obj in other_objects:
@@ -703,7 +703,9 @@ def plot_density(fobject,
703703
'y_hat': y_hat,
704704
'subgroup': np.repeat(subgroup, len(y)),
705705
'model': np.repeat(obj.label, len(y))})
706-
data = pd.concat([data, data_to_append])
706+
data_list.append(data_to_append)
707+
708+
data = pd.concat(data_list)
707709

708710
fig = go.Figure()
709711

python/dalex/dalex/fairness/_group_fairness/utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,14 @@ def __init__(self, sub_confusion_matrix):
9696

9797
def to_vertical_DataFrame(self) -> pd.DataFrame:
9898

99-
columns = ['metric', 'subgroup', 'score']
100-
data = pd.DataFrame(columns=columns)
99+
df_list = []
101100
metrics = self.subgroup_confusion_matrix_metrics
102101
for subgroup in metrics.keys():
103102
metric = metrics.get(subgroup)
104103
subgroup_vec = np.repeat(subgroup, len(metric))
105104
sub_df = pd.DataFrame({'metric': metric.keys(), 'subgroup': subgroup_vec, 'score': metric.values()})
106-
data = pd.concat([data, sub_df])
105+
df_list.append(sub_df)
106+
data = pd.concat(df_list, ignore_index=True)
107107
return data
108108

109109
def to_horizontal_DataFrame(self) -> pd.DataFrame:
@@ -286,8 +286,7 @@ def calculate_regression_measures(y, y_hat, protected, privileged):
286286
unique_protected = np.unique(protected)
287287
unique_unprivileged = unique_protected[unique_protected != privileged]
288288

289-
data = pd.DataFrame(columns=['subgroup', 'independence', 'separation', 'sufficiency'])
290-
289+
data_list = []
291290
for unprivileged in unique_unprivileged:
292291
# filter elements
293292
array_elements = np.isin(protected, [privileged, unprivileged])
@@ -319,8 +318,12 @@ def calculate_regression_measures(y, y_hat, protected, privileged):
319318
'independence': [r_ind],
320319
'separation': [r_sep],
321320
'sufficiency': [r_suf]})
321+
data_list.append(to_append)
322322

323-
data = pd.concat([data, to_append])
323+
if data_list:
324+
data = pd.concat(data_list, ignore_index=True)
325+
else:
326+
data = pd.DataFrame(columns=['subgroup', 'independence', 'separation', 'sufficiency'])
324327

325328
# append the scale
326329
to_append = pd.DataFrame({'subgroup': [privileged],

python/dalex/dalex/model_explanations/_variable_importance/checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def check_variable_groups(variable_groups, explainer):
3737
if not isinstance(variable_groups[key][0], str):
3838
raise TypeError("variable_groups' is a dict of lists of variables")
3939

40-
wrong_names[i] = np.in1d(variable_groups[key], explainer.data.columns).all()
40+
wrong_names[i] = np.isin(variable_groups[key], explainer.data.columns).all()
4141

4242
wrong_names = not wrong_names.all()
4343

python/dalex/test/test_arena_classification.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def setUp(self):
6060
FairnessCheckContainer, ShapleyValuesDependenceContainer, ShapleyValuesVariableImportanceContainer,
6161
VariableAgainstAnotherContainer, VariableDistributionContainer]
6262

63-
@unittest.skipIf(sys.platform.startswith("win"), "requires Windows")
63+
6464
def test_supported_plots(self):
6565
arena = dx.Arena()
6666
arena.push_model(self.exp)
@@ -74,15 +74,15 @@ def test_supported_plots(self):
7474
except Exception:
7575
pass
7676

77-
@unittest.skipIf(sys.platform.startswith("win"), "requires Windows")
77+
@unittest.skipUnless(sys.platform.startswith("ubuntu"), "requires Ubuntu")
7878
def test_server(self):
7979
arena = dx.Arena()
8080
arena.push_model(self.exp)
8181
arena.push_model(self.exp2)
8282
port = get_free_port()
8383
try:
8484
arena.run_server(port=port)
85-
time.sleep(2)
85+
time.sleep(10)
8686
self.assertFalse(try_port(port))
8787
arena.stop_server()
8888
except AssertionError as e:
@@ -93,7 +93,7 @@ def test_server(self):
9393
except Exception:
9494
pass
9595

96-
@unittest.skipIf(sys.platform.startswith("win"), "requires Windows")
96+
@unittest.skipUnless(sys.platform.startswith("ubuntu"), "requires Ubuntu")
9797
def test_plots(self):
9898
arena = dx.Arena()
9999
arena.push_model(self.exp)
@@ -110,7 +110,7 @@ def test_plots(self):
110110
except Exception:
111111
pass
112112

113-
@unittest.skipIf(sys.platform.startswith("win"), "requires Windows")
113+
@unittest.skipUnless(sys.platform.startswith("ubuntu"), "requires Ubuntu")
114114
def test_observation_attributes(self):
115115
arena = dx.Arena()
116116
arena.push_model(self.exp)
@@ -128,7 +128,7 @@ def test_observation_attributes(self):
128128
except Exception:
129129
pass
130130

131-
@unittest.skipIf(sys.platform.startswith("win"), "requires Windows")
131+
@unittest.skipUnless(sys.platform.startswith("ubuntu"), "requires Ubuntu")
132132
def test_variable_attributes(self):
133133
arena = dx.Arena()
134134
arena.push_model(self.exp)

0 commit comments

Comments
 (0)