vlad-moroshan committed on
Commit
91656f9
·
1 Parent(s): 19e1efc

Update checkpoint configuration.

Browse files
configs/example.yaml CHANGED
@@ -105,11 +105,12 @@ TimeSeriesModel:
105
  encoder_config:
106
  attn_mode: chunk
107
  num_heads: 4
108
- expand_v: 1.0
109
- use_gate: false
110
  use_short_conv: true
111
- conv_size: 16
112
- allow_neg_eigval: true
 
 
113
  use_forget_gate: true
114
  num_householder: 4
115
  weaving: true
 
105
  encoder_config:
106
  attn_mode: chunk
107
  num_heads: 4
108
+ expand_v: 1.0
 
109
  use_short_conv: true
110
+ conv_size: 32
111
+ allow_neg_eigval: true
112
+ hidden_ratio: 1.0
113
+ use_gate: true
114
  use_forget_gate: true
115
  num_householder: 4
116
  weaving: true
examples/quick_start_tempo_pfn.ipynb CHANGED
@@ -30,7 +30,6 @@
30
  "metadata": {},
31
  "outputs": [],
32
  "source": [
33
- "import os\n",
34
  "import urllib.request\n",
35
  "import torch\n",
36
  "import numpy as np\n",
@@ -66,7 +65,8 @@
66
  "metadata": {},
67
  "outputs": [],
68
  "source": [
69
- "DROPBOX_URL = \"https://www.dropbox.com/scl/fi/5vmjr7nx9wj9w1vl2giuv/checkpoint.pth?rlkey=qmk08ojp7wj0l6kpm8hzgbzju&st=dyr07d00&dl=1\"\n",
 
70
  "CHECKPOINT_DIR = repo_root / \"models\"\n",
71
  "CHECKPOINT_PATH = CHECKPOINT_DIR / \"checkpoint.pth\"\n",
72
  "\n",
@@ -202,7 +202,7 @@
202
  "source": [
203
  "import matplotlib.pyplot as plt\n",
204
  "\n",
205
- "plt.set_loglevel('error') \n",
206
  "\n",
207
  "# preds: [B, P, N, Q] for quantiles (univariate -> N=1)\n",
208
  "preds_np = preds.cpu().numpy()\n",
 
30
  "metadata": {},
31
  "outputs": [],
32
  "source": [
 
33
  "import urllib.request\n",
34
  "import torch\n",
35
  "import numpy as np\n",
 
65
  "metadata": {},
66
  "outputs": [],
67
  "source": [
68
+ "DROPBOX_URL = \"https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=1\"\n",
69
+ "\n",
70
  "CHECKPOINT_DIR = repo_root / \"models\"\n",
71
  "CHECKPOINT_PATH = CHECKPOINT_DIR / \"checkpoint.pth\"\n",
72
  "\n",
 
202
  "source": [
203
  "import matplotlib.pyplot as plt\n",
204
  "\n",
205
+ "plt.set_loglevel(\"error\")\n",
206
  "\n",
207
  "# preds: [B, P, N, Q] for quantiles (univariate -> N=1)\n",
208
  "preds_np = preds.cpu().numpy()\n",
examples/quick_start_tempo_pfn.py CHANGED
@@ -51,7 +51,7 @@ def main():
51
  if args.checkpoint:
52
  model_path = args.checkpoint
53
  else:
54
- dropbox_url = "https://www.dropbox.com/scl/fi/5vmjr7nx9wj9w1vl2giuv/checkpoint.pth?rlkey=qmk08ojp7wj0l6kpm8hzgbzju&st=dyr07d00&dl=0"
55
  model_path = download_checkpoint_if_needed(dropbox_url, target_dir="models")
56
 
57
  logger.info("=== Time Series Model Demo (Univariate Quantile) ===")
 
51
  if args.checkpoint:
52
  model_path = args.checkpoint
53
  else:
54
+ dropbox_url = "https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=0"
55
  model_path = download_checkpoint_if_needed(dropbox_url, target_dir="models")
56
 
57
  logger.info("=== Time Series Model Demo (Univariate Quantile) ===")
src/models/blocks.py CHANGED
@@ -22,6 +22,7 @@ class GatedDeltaProductEncoder(nn.Module):
22
  use_gate: bool = False,
23
  use_short_conv: bool = True,
24
  conv_size: int = 4,
 
25
  allow_neg_eigval: bool = True,
26
  use_forget_gate: bool = True,
27
  num_householder: int = 1,
@@ -36,7 +37,7 @@ class GatedDeltaProductEncoder(nn.Module):
36
  use_short_conv=use_short_conv,
37
  conv_size=conv_size,
38
  head_dim=token_embed_dim // num_heads,
39
- hidden_ratio=0.5,
40
  num_heads=num_heads,
41
  allow_neg_eigval=allow_neg_eigval,
42
  use_forget_gate=use_forget_gate,
 
22
  use_gate: bool = False,
23
  use_short_conv: bool = True,
24
  conv_size: int = 4,
25
+ hidden_ratio: int = 1.0,
26
  allow_neg_eigval: bool = True,
27
  use_forget_gate: bool = True,
28
  num_householder: int = 1,
 
37
  use_short_conv=use_short_conv,
38
  conv_size=conv_size,
39
  head_dim=token_embed_dim // num_heads,
40
+ hidden_ratio=hidden_ratio,
41
  num_heads=num_heads,
42
  allow_neg_eigval=allow_neg_eigval,
43
  use_forget_gate=use_forget_gate,