

@benbennett (Owner)

Summary

  • enable optional automatic mixed precision (AMP) in the ATME model using GradScaler to speed up training (see the sketch below)
  • add a command-line flag to toggle AMP support for ATME runs
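
For context, the pattern the first bullet describes typically pairs torch.cuda.amp.autocast with GradScaler. The sketch below shows that pattern in isolation; `model`, `optimizer`, and the loss computation are hypothetical placeholders, not the actual ATME training loop.

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Hypothetical sketch of an AMP training step; `model` and `optimizer`
# are placeholders rather than the actual ATME internals.
def optimize_step(model, optimizer, scaler, batch, use_amp):
    optimizer.zero_grad()
    with autocast(enabled=use_amp):   # forward pass runs in mixed precision
        loss = model(batch).mean()    # placeholder loss computation
    scaler.scale(loss).backward()     # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)            # unscales gradients, then steps the optimizer
    scaler.update()                   # adjusts the scale factor for the next step

use_amp = torch.cuda.is_available()   # AMP only applies on CUDA
scaler = GradScaler(enabled=use_amp)  # acts as a no-op passthrough when disabled
```

The second bullet's toggle would typically be a boolean argparse flag; the flag name `--use_amp` matches the `use_amp` attribute read in the reviewed code, though exactly where it is registered in options/atme_options.py is an assumption:

```python
parser.add_argument('--use_amp', action='store_true',
                    help='enable automatic mixed precision training')
```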

Testing

  • python -m compileall models/atme_model.py options/atme_options.py

https://chatgpt.com/codex/tasks/task_e_68d35ae1eeec83319cf61e547d8fb9ae


@chatgpt-codex-connector (bot) left a comment


Codex Review: Here are some suggestions.

Reply with "@codex fix comments" to fix any unresolved comments.


Comment on lines +105 to +106

```python
self.use_amp = bool(getattr(opt, 'use_amp', False)) and self.isTrain and torch.cuda.is_available()
self.scaler = GradScaler(enabled=self.use_amp)
```


[P1] Enable AMP only when the model runs on a CUDA device

self.use_amp is set to True whenever torch.cuda.is_available() is true, regardless of whether the model is actually placed on CUDA. When the user selects CPU execution (e.g. --gpu_ids -1) on a machine that still has CUDA available, optimize_parameters enters a CUDA autocast context and uses GradScaler while all tensors stay on CPU. torch.cuda.amp raises at runtime in that configuration, so CPU training now crashes unless AMP is manually disabled. The previous implementation worked on CPU; this change regresses that path. The flag should also check the selected device (e.g. non-empty self.gpu_ids or self.device.type == 'cuda') before enabling AMP.
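
A minimal sketch of the suggested gating, written as a drop-in replacement for the reviewed lines +105 to +106 (with GradScaler imported as in that snippet). That the model exposes `self.device` and `self.gpu_ids` in the usual pix2pix-style BaseModel layout is an assumption; only the condition itself comes from the review.

```python
# Only enable AMP when training AND the model is actually placed on CUDA,
# not merely when a CUDA device exists somewhere on the machine.
# `self.device` / `self.gpu_ids` follow the pix2pix-style BaseModel
# convention, which is an assumption here.
on_cuda = self.device.type == 'cuda'  # equivalently: len(self.gpu_ids) > 0
self.use_amp = bool(getattr(opt, 'use_amp', False)) and self.isTrain and on_cuda
self.scaler = GradScaler(enabled=self.use_amp)  # disabled scaler is a no-op
```

With this gating, `--gpu_ids -1` on a CUDA-capable machine keeps AMP off and preserves the CPU training path that worked before this change.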


