-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathdemo_loading_utils.py
71 lines (61 loc) · 2.47 KB
/
demo_loading_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import logging
from typing import List
import numpy as np
from rlbench.demo import Demo
def _is_stopped(demo, i, obs, stopped_buffer, delta=0.1):
    """Report whether frame ``i`` of ``demo`` should count as a stop.

    The frame counts as stopped only when all of these hold:
      * ``stopped_buffer <= 0`` (the caller's cooldown counter has run out);
      * every value in ``obs.joint_velocities`` is within ``delta`` of 0;
      * ``i`` is strictly less than ``len(demo) - 2``;
      * the ``gripper_open`` values of frames ``i - 2``, ``i - 1``, ``i``
        (taken from ``obs``) and ``i + 1`` are all equal.

    NOTE(review): for ``i`` in {0, 1} the lookups ``demo[i - 1]`` and
    ``demo[i - 2]`` use negative indices and therefore read frames from the
    end of ``demo``. This mirrors the original logic and is kept unchanged.

    Args:
        demo: Sized, indexable sequence of observations.
        i: Index of the frame being tested.
        obs: The observation for frame ``i`` (callers pass ``demo[i]``).
        stopped_buffer: Cooldown counter maintained by the caller.
        delta: Absolute tolerance for ``np.allclose`` on joint velocities.

    Returns:
        Truthy when the frame is a stop, falsy otherwise.
    """
    final_check_index = len(demo) - 2
    next_is_final = i == final_check_index

    # The gripper state is only inspected while frame i + 1 exists and is not
    # the last frame; otherwise the frame can never be a stop.
    if i < final_check_index:
        current_state = obs.gripper_open
        gripper_steady = (
            current_state == demo[i + 1].gripper_open
            and current_state == demo[i - 1].gripper_open
            and demo[i - 2].gripper_open == demo[i - 1].gripper_open
        )
    else:
        gripper_steady = False

    joints_at_rest = np.allclose(obs.joint_velocities, 0, atol=delta)

    return (stopped_buffer <= 0
            and joints_at_rest
            and not next_is_final
            and gripper_steady)
def keypoint_discovery(demo: 'Demo',
                       stopping_delta=0.1,
                       method='heuristic',
                       num_keypoints: int = 20) -> List[int]:
    """Select keyframe indices from a demonstration.

    Args:
        demo: Sized, indexable sequence of observations. Each element is read
            for ``gripper_open``, and ``_is_stopped`` additionally reads
            ``joint_velocities``.
        stopping_delta: Absolute tolerance forwarded to ``_is_stopped`` when
            deciding whether joint velocities are close enough to zero.
        method: One of
            * ``'heuristic'``: frames where the gripper open state changes,
              frames ``_is_stopped`` reports as stopped, and the final frame.
              Frame 0 is never selected.
            * ``'random'``: ``num_keypoints`` distinct frames drawn with
              ``np.random.choice`` (all frames if the demo is shorter), sorted.
            * ``'fixed_interval'``: every ``len(demo) // num_keypoints``-th
              frame starting at 0 (every frame if the demo is shorter than
              ``num_keypoints``).
        num_keypoints: Target count for ``'random'`` and ``'fixed_interval'``.
            Ignored by ``'heuristic'``. Defaults to 20, the value that was
            previously hard-coded.

    Returns:
        Sorted keyframe indices. ``'random'`` returns a NumPy array; the other
        methods return a list.

    Raises:
        ValueError: if ``num_keypoints < 1`` for ``'random'`` or
            ``'fixed_interval'``.
        NotImplementedError: for an unrecognised ``method``.
    """
    if method in ('random', 'fixed_interval') and num_keypoints < 1:
        raise ValueError('num_keypoints must be >= 1, got %r' % (num_keypoints,))

    if method == 'heuristic':
        episode_keypoints = []
        if len(demo) == 0:
            # Nothing to select (indexing demo[0] below would fail).
            return episode_keypoints
        last_index = len(demo) - 1
        prev_gripper_open = demo[0].gripper_open
        stopped_buffer = 0
        for i, obs in enumerate(demo):
            stopped = _is_stopped(demo, i, obs, stopped_buffer, stopping_delta)
            # After a stop, the next 4 frames cannot be stops (presumably so a
            # single pause is not reported on several consecutive frames).
            stopped_buffer = 4 if stopped else stopped_buffer - 1
            # Keypoint on a gripper open/close change, a stop, or the final
            # frame.
            if i != 0 and (obs.gripper_open != prev_gripper_open or
                           i == last_index or stopped):
                episode_keypoints.append(i)
            prev_gripper_open = obs.gripper_open
        # If the final two keypoints are adjacent frames, keep only the later.
        if (len(episode_keypoints) > 1 and
                episode_keypoints[-1] - 1 == episode_keypoints[-2]):
            episode_keypoints.pop(-2)
        # Pass values as logging arguments. A pre-formatted message plus an
        # extra argument makes logging emit a formatting error instead.
        logging.debug('Found %d keypoints: %s', len(episode_keypoints),
                      episode_keypoints)
        return episode_keypoints

    elif method == 'random':
        # Sample without replacement. Clamp the size so demos shorter than
        # num_keypoints return every frame instead of raising.
        sample_size = min(num_keypoints, len(demo))
        episode_keypoints = np.random.choice(
            range(len(demo)),
            size=sample_size,
            replace=False)
        episode_keypoints.sort()
        return episode_keypoints

    elif method == 'fixed_interval':
        # A step of at least 1 keeps range() valid for short demos.
        segment_length = max(1, len(demo) // num_keypoints)
        return list(range(0, len(demo), segment_length))

    else:
        raise NotImplementedError(
            'Unknown keypoint discovery method: %r' % (method,))
def find_minimum_difference(lst):
    """Return the smallest gap between any two elements of ``lst``.

    The values are sorted first, so the result is the minimum of ``b - a``
    over adjacent pairs of the sorted values, which is the smallest pairwise
    difference. For already-sorted, non-negative input (such as a list of
    frame indices) this matches the previous implementation exactly.

    Args:
        lst: Non-empty sequence of numbers (a list or 1-D NumPy array).

    Returns:
        The minimum pairwise difference, which is never negative. For a
        single-element input the element itself is returned, preserving
        legacy behaviour.

    Raises:
        ValueError: if ``lst`` is empty.
    """
    # len() rather than truthiness so 1-D NumPy arrays are accepted.
    if len(lst) == 0:
        raise ValueError('find_minimum_difference() requires a non-empty sequence')
    if len(lst) == 1:
        # Legacy behaviour: the running minimum used to be seeded with the
        # last element, which is what a lone element returned.
        return lst[-1]
    ordered = sorted(lst)
    return min(later - earlier for earlier, later in zip(ordered, ordered[1:]))