
GPT4TS fails on time series imputation #710

@zltututu

Description

1. System Info

pypots v0.15

2. Information

  • The official example scripts
  • My own created scripts

3. Reproduction

import numpy as np
from pygrinder import calc_missing_rate
from benchpots.datasets import preprocess_physionet2012
from pypots.imputation import GPT4TS
from pypots.nn.functional import calc_mae

# Load PhysioNet-2012 (set-a) with 10% of the observed values artificially masked
data = preprocess_physionet2012(subset="set-a", rate=0.1)
train_X, val_X, test_X = data["train_X"], data["val_X"], data["test_X"]
print(train_X.shape)
print(val_X.shape)
print(f"We have {calc_missing_rate(train_X):.1%} values missing in train_X")

train_set = {"X": train_X}
# the validation set carries the ground truth so the validation error can be computed
val_set = {
    "X": val_X,
    "X_ori": data["val_X_ori"],
}
test_set = {"X": test_X}
test_X_ori = data["test_X_ori"]
# True only at the artificially masked positions, which are the ones evaluated
indicating_mask = np.isnan(test_X) ^ np.isnan(test_X_ori)

model = GPT4TS(
    n_steps=train_X.shape[1],
    n_features=train_X.shape[2],
    patch_size=1,
    patch_stride=1,
    n_layers=2,
    train_gpt_mlp=True,
    d_ffn=768,
    dropout=0.1,
    epochs=5,
    device="cuda:0",
)
model.fit(train_set, val_set)  # training fails here with the error below
imputation = model.impute(test_set)
mae = calc_mae(imputation, np.nan_to_num(test_X_ori), indicating_mask)
print(f"Testing MAE: {mae:.4f}")
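Before digging into the model, it may help to check how often a whole feature is unobserved within a single sample, since PhysioNet-2012 records can lack some variables entirely and the artificial masking removes even more values. Below is a minimal diagnostic sketch, assuming inputs shaped `[n_samples, n_steps, n_features]` with NaN marking missing values, as returned by `preprocess_physionet2012`. The helper name `count_fully_missing` is hypothetical and not part of PyPOTS.

```python
import numpy as np


def count_fully_missing(X: np.ndarray) -> int:
    """Count (sample, feature) pairs whose values are NaN at every time step.

    Hypothetical helper (not a PyPOTS API). X is assumed to be shaped
    [n_samples, n_steps, n_features], with NaN marking missing values.
    """
    observed_steps = (~np.isnan(X)).sum(axis=1)  # shape [n_samples, n_features]
    return int((observed_steps == 0).sum())


# Toy array: 2 samples, 3 steps, 2 features
toy = np.ones((2, 3, 2))
toy[0, :, 1] = np.nan  # feature 1 of sample 0 is never observed
toy[1, 0, 0] = np.nan  # a single missing value; the feature is still partially observed
print(count_fully_missing(toy))  # 1
```

On the real data, `count_fully_missing(train_X)` shows whether the model ever receives a sample in which some feature has no observed value at all.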

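4. Expected behavior

Training should finish and return imputations. Instead, `model.fit()` fails with the following error:
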
2025-04-05 17:21:32 [ERROR]: ❌ Exception: `predictions` mustn't contain NaN values, but detected NaN in it
Traceback (most recent call last):
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/base.py", line 737, in _train_model
    results = self.model(inputs, calc_criterion=True)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/imputation/gpt4ts/core.py", line 78, in forward
    loss = self.training_loss(reconstruction, X_ori, indicating_mask)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/nn/modules/loss.py", line 81, in forward
    value = calc_mae(logits, targets, masks)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/nn/functional/error.py", line 98, in calc_mae
    lib = _check_inputs(predictions, targets, masks)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/nn/functional/error.py", line 34, in _check_inputs
    assert not lib.isnan(predictions).any(), "`predictions` mustn't contain NaN values, but detected NaN in it"
AssertionError: `predictions` mustn't contain NaN values, but detected NaN in it

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "GPT4TS.py", line 36, in <module>
    model.fit(train_set, val_set)  # train the model on the dataset
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/imputation/gpt4ts/model.py", line 259, in fit
    self._train_model(train_dataloader, val_dataloader)
  File "/home/zli/data/miniconda3/envs/pypots/lib/python3.8/site-packages/pypots/base.py", line 814, in _train_model
    raise RuntimeError(
RuntimeError: Training got interrupted. Model was not trained. Please investigate the error printed above.
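The assertion only reports that `predictions` contains NaN. It does not say where the NaN comes from. One hypothesis worth checking, not confirmed here: the model may normalize each feature with a mean computed over observed time steps only (`sum / count` over the missing mask). If so, a feature with zero observed steps in a sample gives `0 / 0 = NaN`, and that NaN can then propagate through the network into the reconstruction. The NumPy sketch below reproduces that arithmetic and shows a guarded variant that clamps the denominator to at least 1. The function names are hypothetical, and this is not PyPOTS's actual code.

```python
import numpy as np


def masked_mean_naive(X: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Per-sample, per-feature mean over observed steps, computed as sum / count.

    X holds zeros at missing positions; mask is 1 where observed, 0 otherwise.
    Both are shaped [n_samples, n_steps, n_features].
    """
    with np.errstate(invalid="ignore"):  # silence the 0 / 0 RuntimeWarning
        return (X * mask).sum(axis=1) / mask.sum(axis=1)


def masked_mean_guarded(X: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Same mean, but a feature with no observed steps gets 0.0 instead of NaN."""
    count = np.maximum(mask.sum(axis=1), 1)  # never divide by zero
    return (X * mask).sum(axis=1) / count


# Toy batch: 1 sample, 3 steps, 2 features; feature 1 is never observed
X = np.array([[[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]])
mask = np.array([[[1, 0], [1, 0], [1, 0]]], dtype=float)

print(masked_mean_naive(X, mask))    # mean of feature 1 is NaN
print(masked_mean_guarded(X, mask))  # mean of feature 1 is 0.0
```

Whether GPT4TS in PyPOTS actually normalizes this way would need to be confirmed by reading `pypots/imputation/gpt4ts/core.py` and the backbone it calls.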

5. Your contribution

I have starred this repo.

Metadata

Labels

discussion, keep (Keep this issue away from being stale.)