Commit
·
91656f9
1
Parent(s):
19e1efc
Update checkpoint configuration.
Browse files
configs/example.yaml
CHANGED
|
@@ -105,11 +105,12 @@ TimeSeriesModel:
|
|
| 105 |
encoder_config:
|
| 106 |
attn_mode: chunk
|
| 107 |
num_heads: 4
|
| 108 |
-
expand_v: 1.0
|
| 109 |
-
use_gate: false
|
| 110 |
use_short_conv: true
|
| 111 |
-
conv_size:
|
| 112 |
-
allow_neg_eigval: true
|
|
|
|
|
|
|
| 113 |
use_forget_gate: true
|
| 114 |
num_householder: 4
|
| 115 |
weaving: true
|
|
|
|
| 105 |
encoder_config:
|
| 106 |
attn_mode: chunk
|
| 107 |
num_heads: 4
|
| 108 |
+
expand_v: 1.0
|
|
|
|
| 109 |
use_short_conv: true
|
| 110 |
+
conv_size: 32
|
| 111 |
+
allow_neg_eigval: true
|
| 112 |
+
hidden_ratio: 1.0
|
| 113 |
+
use_gate: true
|
| 114 |
use_forget_gate: true
|
| 115 |
num_householder: 4
|
| 116 |
weaving: true
|
examples/quick_start_tempo_pfn.ipynb
CHANGED
|
@@ -30,7 +30,6 @@
|
|
| 30 |
"metadata": {},
|
| 31 |
"outputs": [],
|
| 32 |
"source": [
|
| 33 |
-
"import os\n",
|
| 34 |
"import urllib.request\n",
|
| 35 |
"import torch\n",
|
| 36 |
"import numpy as np\n",
|
|
@@ -66,7 +65,8 @@
|
|
| 66 |
"metadata": {},
|
| 67 |
"outputs": [],
|
| 68 |
"source": [
|
| 69 |
-
"DROPBOX_URL = \"https://www.dropbox.com/scl/fi/
|
|
|
|
| 70 |
"CHECKPOINT_DIR = repo_root / \"models\"\n",
|
| 71 |
"CHECKPOINT_PATH = CHECKPOINT_DIR / \"checkpoint.pth\"\n",
|
| 72 |
"\n",
|
|
@@ -202,7 +202,7 @@
|
|
| 202 |
"source": [
|
| 203 |
"import matplotlib.pyplot as plt\n",
|
| 204 |
"\n",
|
| 205 |
-
"plt.set_loglevel(
|
| 206 |
"\n",
|
| 207 |
"# preds: [B, P, N, Q] for quantiles (univariate -> N=1)\n",
|
| 208 |
"preds_np = preds.cpu().numpy()\n",
|
|
|
|
| 30 |
"metadata": {},
|
| 31 |
"outputs": [],
|
| 32 |
"source": [
|
|
|
|
| 33 |
"import urllib.request\n",
|
| 34 |
"import torch\n",
|
| 35 |
"import numpy as np\n",
|
|
|
|
| 65 |
"metadata": {},
|
| 66 |
"outputs": [],
|
| 67 |
"source": [
|
| 68 |
+
"DROPBOX_URL = \"https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=1\"\n",
|
| 69 |
+
"\n",
|
| 70 |
"CHECKPOINT_DIR = repo_root / \"models\"\n",
|
| 71 |
"CHECKPOINT_PATH = CHECKPOINT_DIR / \"checkpoint.pth\"\n",
|
| 72 |
"\n",
|
|
|
|
| 202 |
"source": [
|
| 203 |
"import matplotlib.pyplot as plt\n",
|
| 204 |
"\n",
|
| 205 |
+
"plt.set_loglevel(\"error\")\n",
|
| 206 |
"\n",
|
| 207 |
"# preds: [B, P, N, Q] for quantiles (univariate -> N=1)\n",
|
| 208 |
"preds_np = preds.cpu().numpy()\n",
|
examples/quick_start_tempo_pfn.py
CHANGED
|
@@ -51,7 +51,7 @@ def main():
|
|
| 51 |
if args.checkpoint:
|
| 52 |
model_path = args.checkpoint
|
| 53 |
else:
|
| 54 |
-
dropbox_url = "https://www.dropbox.com/scl/fi/
|
| 55 |
model_path = download_checkpoint_if_needed(dropbox_url, target_dir="models")
|
| 56 |
|
| 57 |
logger.info("=== Time Series Model Demo (Univariate Quantile) ===")
|
|
|
|
| 51 |
if args.checkpoint:
|
| 52 |
model_path = args.checkpoint
|
| 53 |
else:
|
| 54 |
+
dropbox_url = "https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=0"
|
| 55 |
model_path = download_checkpoint_if_needed(dropbox_url, target_dir="models")
|
| 56 |
|
| 57 |
logger.info("=== Time Series Model Demo (Univariate Quantile) ===")
|
src/models/blocks.py
CHANGED
|
@@ -22,6 +22,7 @@ class GatedDeltaProductEncoder(nn.Module):
|
|
| 22 |
use_gate: bool = False,
|
| 23 |
use_short_conv: bool = True,
|
| 24 |
conv_size: int = 4,
|
|
|
|
| 25 |
allow_neg_eigval: bool = True,
|
| 26 |
use_forget_gate: bool = True,
|
| 27 |
num_householder: int = 1,
|
|
@@ -36,7 +37,7 @@ class GatedDeltaProductEncoder(nn.Module):
|
|
| 36 |
use_short_conv=use_short_conv,
|
| 37 |
conv_size=conv_size,
|
| 38 |
head_dim=token_embed_dim // num_heads,
|
| 39 |
-
hidden_ratio=
|
| 40 |
num_heads=num_heads,
|
| 41 |
allow_neg_eigval=allow_neg_eigval,
|
| 42 |
use_forget_gate=use_forget_gate,
|
|
|
|
| 22 |
use_gate: bool = False,
|
| 23 |
use_short_conv: bool = True,
|
| 24 |
conv_size: int = 4,
|
| 25 |
+
hidden_ratio: int = 1.0,
|
| 26 |
allow_neg_eigval: bool = True,
|
| 27 |
use_forget_gate: bool = True,
|
| 28 |
num_householder: int = 1,
|
|
|
|
| 37 |
use_short_conv=use_short_conv,
|
| 38 |
conv_size=conv_size,
|
| 39 |
head_dim=token_embed_dim // num_heads,
|
| 40 |
+
hidden_ratio=hidden_ratio,
|
| 41 |
num_heads=num_heads,
|
| 42 |
allow_neg_eigval=allow_neg_eigval,
|
| 43 |
use_forget_gate=use_forget_gate,
|