-
-
Notifications
You must be signed in to change notification settings - Fork 172
Open
Labels
Description
1. System Info
pypots v0.15
2. Information
- The official example scripts
- My own created scripts
3. Reproduction
import numpy as np
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar, calc_missing_rate
from benchpots.datasets import preprocess_physionet2012
from pypots.imputation import SAITS, GPT4TS
from pypots.nn.functional import calc_mae
# Load the PhysioNet-2012 "set-a" subset through BenchPOTS. The returned dict is
# indexed below by "train_X", "val_X", "test_X", "val_X_ori" and "test_X_ori".
# NOTE(review): rate=0.1 presumably sets the fraction of observed values that
# BenchPOTS artificially hides for evaluation -- confirm against BenchPOTS docs.
data = preprocess_physionet2012(subset="set-a", rate=0.1)

# Per-split inputs; missing entries are NaN (the mask below relies on np.isnan).
train_X = data["train_X"]
val_X = data["val_X"]
test_X = data["test_X"]

# Quick sanity report on the loaded data.
print(train_X.shape)
print(val_X.shape)
print(f"We have {calc_missing_rate(train_X):.1%} values missing in train_X")

# PyPOTS-style dataset dicts. Only the validation split carries "X_ori";
# presumably fit() uses it to score validation error -- verify in PyPOTS docs.
train_set = dict(X=train_X)
val_set = dict(X=val_X, X_ori=data["val_X_ori"])
test_set = dict(X=test_X)

# Test ground truth and the evaluation mask: True exactly where one of
# test_X / test_X_ori is NaN and the other is not (presumably the positions
# hidden by BenchPOTS, which calc_mae scores afterwards).
test_X_ori = data["test_X_ori"]
indicating_mask = np.logical_xor(np.isnan(test_X), np.isnan(test_X_ori))

# Axis 1 of train_X is passed as the number of steps, axis 2 as the number of
# features -- the same indices the model configuration has always used.
n_steps, n_features = train_X.shape[1:3]

model = GPT4TS(
    n_steps=n_steps,
    n_features=n_features,
    patch_size=1,
    patch_stride=1,
    n_layers=2,
    train_gpt_mlp=True,
    d_ffn=768,
    dropout=0.1,
    epochs=5,
    device="cuda:0",
)

# NOTE(review): per the traceback in this report, fit() aborts while computing
# the training loss because GPT4TS's reconstruction contains NaN.
# Suspected (unconfirmed) trigger: samples in which some feature is never
# observed, which would make any per-feature normalisation over observed values
# divide by zero -- TODO confirm inside the PyPOTS GPT4TS backbone.
model.fit(train_set, val_set)
imputation = model.impute(test_set)
mae = calc_mae(imputation, np.nan_to_num(test_X_ori), indicating_mask)
4. Expected behavior
2025-04-05 17:21:32 [ERROR]: ❌ Exception: `predictions` mustn't contain NaN values, but detected NaN in it
Traceback (most recent call last):
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/base.py", line 737, in _train_model
results = self.model(inputs, calc_criterion=True)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/imputation/gpt4ts/core.py", line 78, in forward
loss = self.training_loss(reconstruction, X_ori, indicating_mask)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/nn/modules/loss.py", line 81, in forward
value = calc_mae(logits, targets, masks)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/nn/functional/error.py", line 98, in calc_mae
lib = _check_inputs(predictions, targets, masks)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/nn/functional/error.py", line 34, in _check_inputs
assert not lib.isnan(predictions).any(), "`predictions` mustn't contain NaN values, but detected NaN in it"
AssertionError: `predictions` mustn't contain NaN values, but detected NaN in it
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "GPT4TS.py", line 36, in <module>
model.fit(train_set, val_set) # train the model on the dataset
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/imputation/gpt4ts/model.py", line 259, in fit
self._train_model(train_dataloader, val_dataloader)
File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/base.py", line 814, in _train_model
raise RuntimeError(
RuntimeError: Training got interrupted. Model was not trained. Please investigate the error printed above.
5. Your contribution
I have starred this repo.