Skip to content

feature(pu): add rope that use the true timestep index as pos_index #266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Mar 29, 2025

Conversation

puyuan1996
Copy link
Collaborator

@puyuan1996 puyuan1996 commented Aug 15, 2024

  • add rope that use the true timestep index as pos_index in unizero's transformer

  • Current performance on Pong:
    image

  • Current performance on cartpole-swingup:
    image

@puyuan1996 puyuan1996 added the enhancement New feature or request label Feb 13, 2025
@PaParaZz1 PaParaZz1 added the algorithm New algorithm label Mar 27, 2025
return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps
return_result = obs_act_embeddings
if not self.config.rotary_emb:
return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need to add this if branch here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The polishing of pos_emb will be addressed in a future PR.

@puyuan1996 puyuan1996 merged commit d9957ef into main Mar 29, 2025
3 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algorithm New algorithm enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants