Ktezcan/dev/iss941 encode targets sepfstep #1019
base: develop
Conversation
```python
    time_win: tuple,
    normalizer,  # dataset,
    use_normalizer: str,  # "source_normalizer" or "target_normalizer"
```
Rename `use_normalizer` to `channel_to_normalize`. Even though the type and possible values are clearly documented, `use_normalizer` suggests a boolean value.
Another option is to rename `normalizer` to `normaliser_dataset` or `normaliser_ds`, so you can use `normalizer` instead of `use_normalizer`.
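A minimal sketch of what the suggested signature could look like (the function name here is hypothetical; the parameters follow the diff context above):

```python
# Hypothetical sketch of the rename suggestion; the function name is made up,
# the parameters follow the diff context shown above.
def tokenize_window(
    time_win: tuple,
    normalizer,                 # dataset providing the normalizers
    channel_to_normalize: str,  # "source_normalizer" or "target_normalizer"
):
    ...
```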
```python
        )
        for stl_b in batch
    ]
)
```
Use fewer lines, because it looks more complex than it actually is:

```python
target_source_like_tokens_lens = torch.stack([
    torch.stack([
        torch.stack([
            s.target_source_like_tokens_lens[fstep]
            if len(s.target_source_like_tokens_lens[fstep]) > 0
            else torch.tensor([])
            for fstep in range(len(s.target_source_like_tokens_lens))
        ]) for s in stl_b
    ]) for stl_b in batch
])
```

If this was caused by ruff then just forget about this comment...
```python
for ib, sb in enumerate(batch):
    for itype, s in enumerate(sb):
        for fstep in range(offsets.shape[0]):
            if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0:  # if not empty
```
Replace `if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0:` with `if any(target_source_like_tokens_lens[ib, itype, fstep]):` for better efficiency.
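A minimal self-contained check that the two conditions agree for a 1-D tensor of non-negative lengths (the values are made up; `Tensor.any()` is mentioned only as an alternative and is not part of the PR):

```python
import torch

# Illustrative values only; lengths are non-negative, so "sum != 0" and "any" agree.
lens = torch.tensor([0, 0, 3, 1])

assert (lens.sum() != 0) == any(lens)         # reviewer's suggested form
assert (lens.sum() != 0) == bool(lens.any())  # torch's built-in any() as an alternative
```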
```python
# batch sample list when non-empty
for fstep in range(len(self.target_source_like_tokens_cells)):
    if (
        torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
```
Replace

```python
if (
    torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
    > 0
):
```

with

```python
if any(len(s) > 0 for s in self.target_source_like_tokens_cells[fstep]):
```

for slightly better efficiency.
Maybe you can also find a way to do the check in constant time, without having to write multiple lines of code.
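One possible direction for a constant-time check, purely as a sketch (the class, method names, and data layout below are assumptions, not code from this PR): keep a running count per forecast step when cells are appended, so the emptiness test no longer scans every sample.

```python
# Hypothetical sketch: maintain a per-fstep total when cells are appended,
# so checking for emptiness later is O(1) instead of iterating over all samples.
class TargetCells:
    def __init__(self, num_fsteps: int):
        self.cells = [[] for _ in range(num_fsteps)]
        self.num_cells = [0] * num_fsteps  # running totals, one per forecast step

    def append(self, fstep: int, sample_cells: list):
        self.cells[fstep].append(sample_cells)
        self.num_cells[fstep] += len(sample_cells)

    def is_empty(self, fstep: int) -> bool:
        return self.num_cells[fstep] == 0
```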
```python
    times: np.array,
    time_win: tuple,
    normalizer,  # dataset
    use_normalizer: str,  # "source_normalizer" or "target_normalizer"
```
Rename `use_normalizer` as you did in `tokeniser_forecast.py` (see first comment).
```python
tokens_target_det = tokens_target.detach()  # explicitly detach as well
tokens_targets.append(tokens_target_det)

return_dict = {"preds_all": preds_all, "posteriors": posteriors}
```
Initialize `return_dict` above the first if check on `"encode_targets_latent"`.
Move the key accesses on `return_dict` to the end of the first if check on `"encode_targets_latent"`.
Remove the second if check on `"encode_targets_latent"`.
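A rough before/after sketch of the suggested restructuring (the function names, flag handling, and branch bodies below are placeholders, not the actual PR code):

```python
# Hypothetical sketch; names and branch bodies are placeholders, not the PR code.

def build_return_dict_before(cf, preds_all, posteriors, encode_targets):
    # Before: the "encode_targets_latent" flag is checked twice.
    if cf.get("encode_targets_latent"):
        tokens_targets = encode_targets()
    return_dict = {"preds_all": preds_all, "posteriors": posteriors}
    if cf.get("encode_targets_latent"):
        return_dict["tokens_targets"] = tokens_targets
    return return_dict

def build_return_dict_after(cf, preds_all, posteriors, encode_targets):
    # After: return_dict is initialized first; the flag is checked only once
    # and the key access happens at the end of that single branch.
    return_dict = {"preds_all": preds_all, "posteriors": posteriors}
    if cf.get("encode_targets_latent"):
        return_dict["tokens_targets"] = encode_targets()
    return return_dict
```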
```python
# # we don't append an empty tensor for the source
# tokens_all.append(torch.tensor([], dtype=self.dtype, device="cuda"))
# el
if source_tokens_lens.sum() != 0:
```
Replace `if source_tokens_lens.sum() != 0:` with `if source_tokens_lens.any():` for better efficiency.
Description
This is an additional PR on top of a previous PR: #961.
The previous one introduces a new function to embed cells for the targets. This PR uses the existing `embed_cells()` function to embed the target tokens. The purpose is to reduce duplicated code and prevent potential "code rot", etc. I have tested both training and inference with this.
Issue Number
Closes #941
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`