Skip to content

Commit 6604567

Browse files
committed
har_trees: Generalize dataset handling during training a bit
1 parent b7fff87 commit 6604567

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

examples/har_trees/har_train.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def extract_windows(sensordata : pandas.DataFrame,
6565
window_length : int,
6666
window_hop : int,
6767
groupby : list[str],
68+
time_column = 'time',
6869
):
6970

7071
groups = sensordata.groupby(groupby, observed=True)
@@ -75,7 +76,7 @@ def extract_windows(sensordata : pandas.DataFrame,
7576
windows = []
7677

7778
# make sure order is correct
78-
group_df = group_df.reset_index().set_index('time').sort_index()
79+
group_df = group_df.reset_index().set_index(time_column).sort_index()
7980

8081
# create windows
8182
win_start = 0
@@ -167,6 +168,7 @@ def extract_features(sensordata : pandas.DataFrame,
167168
quant_div = 4,
168169
quant_depth = 6,
169170
label_column='activity',
171+
time_column='time',
170172
) -> pandas.DataFrame:
171173
"""
172174
Convert sensor data into fixed-sized time windows and extract features
@@ -181,7 +183,7 @@ def extract_features(sensordata : pandas.DataFrame,
181183

182184
# Split into fixed-length windows
183185
features_values = []
184-
generator = extract_windows(sensordata, window_length, window_hop, groupby=groupby)
186+
generator = extract_windows(sensordata, window_length, window_hop, groupby=groupby, time_column=time_column)
185187
for windows in generator:
186188

187189
# drop invalid data
@@ -262,8 +264,21 @@ def run_pipeline(run, hyperparameters, dataset,
262264
'squat', 'jumpingjack', 'lunge', 'other',
263265
],
264266
),
267+
'toothbrush_hussain2021': dict(
268+
groups=['subject'],
269+
label_column = 'is_brushing',
270+
time_column = 'elapsed',
271+
data_columns = ['acc_x', 'acc_y', 'acc_z'],
272+
classes = [
273+
#'mixed',
274+
'True', 'False',
275+
],
276+
),
265277
}
266278

279+
if not dataset in dataset_config.keys():
280+
raise ValueError(f"Unknown dataset {dataset}")
281+
267282
if not os.path.exists(out_dir):
268283
os.makedirs(out_dir)
269284

@@ -278,21 +293,25 @@ def run_pipeline(run, hyperparameters, dataset,
278293
groups = dataset_config[dataset]['groups']
279294
data_columns = dataset_config[dataset]['data_columns']
280295
enabled_classes = dataset_config[dataset]['classes']
296+
label_column = dataset_config[dataset].get('label_column', 'activity')
297+
time_column = dataset_config[dataset].get('time_column', 'time')
298+
299+
data[label_column] = data[label_column].astype(str)
281300

282301
data_load_duration = time.time() - data_load_start
283302
log.info('data-loaded', dataset=dataset, samples=len(data), duration=data_load_duration)
284303

285-
286-
287304
feature_extraction_start = time.time()
288305
features = extract_features(data,
289306
columns=data_columns,
290307
groupby=groups,
291308
features=features,
292309
window_length=model_settings['window_length'],
293310
window_hop=model_settings['window_hop'],
311+
label_column=label_column,
312+
time_column=time_column,
294313
)
295-
labeled = numpy.count_nonzero(features['activity'].notna())
314+
labeled = numpy.count_nonzero(features[label_column].notna())
296315

297316
feature_extraction_duration = time.time() - feature_extraction_start
298317
log.info('feature-extraction-done',
@@ -303,19 +322,20 @@ def run_pipeline(run, hyperparameters, dataset,
303322
)
304323

305324
# Drop windows without labels
306-
features = features[features.activity.notna()]
325+
features = features[features[label_column].notna()]
307326

308327
# Keep only windows with enabled classes
309-
features = features[features.activity.isin(enabled_classes)]
328+
features = features[features[label_column].isin(enabled_classes)]
310329

311-
print('Class distribution\n', features['activity'].value_counts(dropna=False))
330+
print('Class distribution\n', features[label_column].value_counts(dropna=False))
312331

313332
# Run train-evaluate
314333
evaluate_groupby = groups[0]
315334
results, estimator = evaluate(features,
316335
hyperparameters=hyperparameters,
317336
groupby=evaluate_groupby,
318337
n_splits=n_splits,
338+
label_column=label_column,
319339
)
320340

321341
# Save a model
@@ -328,7 +348,6 @@ def run_pipeline(run, hyperparameters, dataset,
328348
export_model(estimator_path, model_path)
329349

330350
# Save testdata
331-
label_column = 'activity'
332351
classes = estimator.classes_
333352
class_mapping = dict(zip(classes, range(len(classes))))
334353
meta_path = os.path.join(out_dir, f'{dataset}.meta.json')

0 commit comments

Comments
 (0)