Upload folder using huggingface_hub
Browse files- geocalib/extractor.py +2 -7
- gradio_app.py +8 -5
- siclib/models/extractor.py +2 -7
geocalib/extractor.py
CHANGED
|
@@ -22,14 +22,9 @@ class GeoCalib(nn.Module):
|
|
| 22 |
weights (str): trained variant, "pinhole" (default) or "distorted".
|
| 23 |
"""
|
| 24 |
super().__init__()
|
| 25 |
-
if weights
|
| 26 |
-
url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
|
| 27 |
-
elif weights == "distorted":
|
| 28 |
-
url = (
|
| 29 |
-
"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
|
| 30 |
-
)
|
| 31 |
-
else:
|
| 32 |
raise ValueError(f"Unknown weights: {weights}")
|
|
|
|
| 33 |
|
| 34 |
# load checkpoint
|
| 35 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|
|
|
|
| 22 |
weights (str): trained variant, "pinhole" (default) or "distorted".
|
| 23 |
"""
|
| 24 |
super().__init__()
|
| 25 |
+
if weights not in {"pinhole", "distorted"}:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
raise ValueError(f"Unknown weights: {weights}")
|
| 27 |
+
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"
|
| 28 |
|
| 29 |
# load checkpoint
|
| 30 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|
gradio_app.py
CHANGED
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
| 8 |
import spaces
|
| 9 |
import torch
|
| 10 |
|
| 11 |
-
from geocalib import viz2d
|
| 12 |
from geocalib.camera import camera_models
|
| 13 |
from geocalib.extractor import GeoCalib
|
| 14 |
from geocalib.perspective_fields import get_perspective_field
|
|
@@ -77,7 +77,9 @@ def format_output(results):
|
|
| 77 |
@spaces.GPU(duration=10)
|
| 78 |
def inference(img, camera_model):
|
| 79 |
out = model.calibrate(img.to(device), camera_model=camera_model)
|
| 80 |
-
save_keys = ["camera", "gravity"] + [
|
|
|
|
|
|
|
| 81 |
res = {k: v.cpu() for k, v in out.items() if k in save_keys}
|
| 82 |
# not converting to numpy results in gpu abort
|
| 83 |
res["up_confidence"] = out["up_confidence"].cpu().numpy()
|
|
@@ -100,10 +102,9 @@ def process_results(
|
|
| 100 |
raise gr.Error("Please upload an image first.")
|
| 101 |
|
| 102 |
img = model.load_image(image_path)
|
| 103 |
-
print("Running inference...")
|
| 104 |
start = time()
|
| 105 |
inference_result = inference(img, camera_model)
|
| 106 |
-
|
| 107 |
inference_result["image"] = img.cpu()
|
| 108 |
|
| 109 |
if inference_result is None:
|
|
@@ -158,7 +159,9 @@ def update_plot(
|
|
| 158 |
viz2d.plot_confidences([torch.tensor(inference_result["up_confidence"][0])], axes=[ax[0]])
|
| 159 |
|
| 160 |
if plot_latitude_confidence:
|
| 161 |
-
viz2d.plot_confidences(
|
|
|
|
|
|
|
| 162 |
|
| 163 |
fig.canvas.draw()
|
| 164 |
img = np.array(fig.canvas.renderer.buffer_rgba())
|
|
|
|
| 8 |
import spaces
|
| 9 |
import torch
|
| 10 |
|
| 11 |
+
from geocalib import logger, viz2d
|
| 12 |
from geocalib.camera import camera_models
|
| 13 |
from geocalib.extractor import GeoCalib
|
| 14 |
from geocalib.perspective_fields import get_perspective_field
|
|
|
|
| 77 |
@spaces.GPU(duration=10)
|
| 78 |
def inference(img, camera_model):
|
| 79 |
out = model.calibrate(img.to(device), camera_model=camera_model)
|
| 80 |
+
save_keys = ["camera", "gravity"] + [
|
| 81 |
+
f"{k}_uncertainty" for k in ["roll", "pitch", "vfov", "focal"]
|
| 82 |
+
]
|
| 83 |
res = {k: v.cpu() for k, v in out.items() if k in save_keys}
|
| 84 |
# not converting to numpy results in gpu abort
|
| 85 |
res["up_confidence"] = out["up_confidence"].cpu().numpy()
|
|
|
|
| 102 |
raise gr.Error("Please upload an image first.")
|
| 103 |
|
| 104 |
img = model.load_image(image_path)
|
|
|
|
| 105 |
start = time()
|
| 106 |
inference_result = inference(img, camera_model)
|
| 107 |
+
logger.info(f"Calibration took {time() - start:.2f} sec. ({camera_model})")
|
| 108 |
inference_result["image"] = img.cpu()
|
| 109 |
|
| 110 |
if inference_result is None:
|
|
|
|
| 159 |
viz2d.plot_confidences([torch.tensor(inference_result["up_confidence"][0])], axes=[ax[0]])
|
| 160 |
|
| 161 |
if plot_latitude_confidence:
|
| 162 |
+
viz2d.plot_confidences(
|
| 163 |
+
[torch.tensor(inference_result["latitude_confidence"][0])], axes=[ax[0]]
|
| 164 |
+
)
|
| 165 |
|
| 166 |
fig.canvas.draw()
|
| 167 |
img = np.array(fig.canvas.renderer.buffer_rgba())
|
siclib/models/extractor.py
CHANGED
|
@@ -22,14 +22,9 @@ class GeoCalib(nn.Module):
|
|
| 22 |
weights (str, optional): Weights to load. Defaults to "pinhole".
|
| 23 |
"""
|
| 24 |
super().__init__()
|
| 25 |
-
if weights
|
| 26 |
-
url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
|
| 27 |
-
elif weights == "distorted":
|
| 28 |
-
url = (
|
| 29 |
-
"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
|
| 30 |
-
)
|
| 31 |
-
else:
|
| 32 |
raise ValueError(f"Unknown weights: {weights}")
|
|
|
|
| 33 |
|
| 34 |
# load checkpoint
|
| 35 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|
|
|
|
| 22 |
weights (str, optional): Weights to load. Defaults to "pinhole".
|
| 23 |
"""
|
| 24 |
super().__init__()
|
| 25 |
+
if weights not in {"pinhole", "distorted"}:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
raise ValueError(f"Unknown weights: {weights}")
|
| 27 |
+
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"
|
| 28 |
|
| 29 |
# load checkpoint
|
| 30 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|