vaibhavpandeyvpz committed
Commit ce06be0 · 1 Parent(s): 75cf114

Reformat code style

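The changes are pure formatting: single quotes become double quotes, long calls are wrapped with trailing commas, and blank lines are normalized, which matches what the Black formatter produces with default settings. A minimal sketch of reproducing one of the quote changes, assuming Black is the tool that was used (the commit itself does not name it):

    # Hypothetical check via Black's Python API; `black` is not a dependency of this repo.
    import black

    old = "os.environ['SPCONV_ALGO'] = 'native'\n"
    print(black.format_str(old, mode=black.Mode()))
    # -> os.environ["SPCONV_ALGO"] = "native"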
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +166 -95
  2. trellis/models/__init__.py +23 -15
  3. trellis/models/sparse_structure_flow.py +49 -26
  4. trellis/models/sparse_structure_vae.py +50 -42
  5. trellis/models/structured_latent_flow.py +70 -45
  6. trellis/models/structured_latent_vae/base.py +36 -20
  7. trellis/models/structured_latent_vae/decoder_gs.py +54 -23
  8. trellis/models/structured_latent_vae/decoder_mesh.py +43 -26
  9. trellis/models/structured_latent_vae/decoder_rf.py +38 -14
  10. trellis/models/structured_latent_vae/encoder.py +5 -3
  11. trellis/modules/attention/__init__.py +18 -11
  12. trellis/modules/attention/full_attn.py +62 -43
  13. trellis/modules/attention/modules.py +47 -22
  14. trellis/modules/norm.py +5 -5
  15. trellis/modules/sparse/__init__.py +51 -44
  16. trellis/modules/sparse/attention/full_attn.py +149 -66
  17. trellis/modules/sparse/attention/modules.py +44 -17
  18. trellis/modules/sparse/attention/serialized_attn.py +105 -51
  19. trellis/modules/sparse/attention/windowed_attn.py +85 -45
  20. trellis/modules/sparse/basic.py +193 -117
  21. trellis/modules/sparse/conv/__init__.py +12 -7
  22. trellis/modules/sparse/conv/conv_spconv.py +78 -20
  23. trellis/modules/sparse/conv/conv_torchsparse.py +52 -14
  24. trellis/modules/sparse/linear.py +1 -3
  25. trellis/modules/sparse/nonlinearity.py +2 -8
  26. trellis/modules/sparse/norm.py +10 -5
  27. trellis/modules/sparse/spatial.py +49 -29
  28. trellis/modules/sparse/transformer/__init__.py +1 -1
  29. trellis/modules/sparse/transformer/blocks.py +14 -4
  30. trellis/modules/sparse/transformer/modulated.py +34 -15
  31. trellis/modules/spatial.py +24 -8
  32. trellis/modules/transformer/__init__.py +1 -1
  33. trellis/modules/transformer/blocks.py +21 -7
  34. trellis/modules/transformer/modulated.py +22 -11
  35. trellis/modules/utils.py +1 -0
  36. trellis/pipelines/__init__.py +4 -2
  37. trellis/pipelines/base.py +8 -6
  38. trellis/pipelines/samplers/__init__.py +5 -1
  39. trellis/pipelines/samplers/base.py +1 -6
  40. trellis/pipelines/samplers/flow_euler.py +40 -16
  41. trellis/pipelines/samplers/guidance_interval_mixin.py +3 -1
  42. trellis/pipelines/trellis_image_to_3d.py +142 -83
  43. trellis/renderers/__init__.py +6 -5
  44. trellis/renderers/gaussian_render.py +122 -76
  45. trellis/renderers/mesh_renderer.py +69 -42
  46. trellis/renderers/octree_renderer.py +213 -124
  47. trellis/renderers/sh_utils.py +41 -31
  48. trellis/representations/gaussian/__init__.py +1 -1
  49. trellis/representations/gaussian/gaussian_model.py +119 -67
  50. trellis/representations/gaussian/general_utils.py +35 -17
app.py CHANGED
@@ -4,7 +4,8 @@ from gradio_litmodel3d import LitModel3D
4
 
5
  import os
6
  import shutil
7
- os.environ['SPCONV_ALGO'] = 'native'
8
  from typing import *
9
  import torch
10
  import numpy as np
@@ -17,15 +18,15 @@ from trellis.utils import render_utils, postprocessing_utils
17
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
 
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
  os.makedirs(user_dir, exist_ok=True)
27
-
28
-
29
  def end_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
31
  if os.path.exists(user_dir):
@@ -35,7 +36,7 @@ def end_session(req: gr.Request):
35
  def preprocess_image(image: Image.Image) -> Image.Image:
36
  """
37
  Preprocess the input image for 3D generation.
38
-
39
  This function is called when a user uploads an image or selects an example.
40
  It applies background removal and other preprocessing steps necessary for
41
  optimal 3D model generation.
@@ -53,13 +54,13 @@ def preprocess_image(image: Image.Image) -> Image.Image:
53
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
54
  """
55
  Preprocess a list of input images for multi-image 3D generation.
56
-
57
  This function is called when users upload multiple images in the gallery.
58
  It processes each image to prepare them for the multi-image 3D generation pipeline.
59
-
60
  Args:
61
  images (List[Tuple[Image.Image, str]]): The input images from the gallery
62
-
63
  Returns:
64
  List[Image.Image]: The preprocessed images ready for 3D generation
65
  """
@@ -70,55 +71,55 @@ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image
70
 
71
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
72
  return {
73
- 'gaussian': {
74
  **gs.init_params,
75
- '_xyz': gs._xyz.cpu().numpy(),
76
- '_features_dc': gs._features_dc.cpu().numpy(),
77
- '_scaling': gs._scaling.cpu().numpy(),
78
- '_rotation': gs._rotation.cpu().numpy(),
79
- '_opacity': gs._opacity.cpu().numpy(),
80
  },
81
- 'mesh': {
82
- 'vertices': mesh.vertices.cpu().numpy(),
83
- 'faces': mesh.faces.cpu().numpy(),
84
  },
85
  }
86
-
87
-
88
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
89
  gs = Gaussian(
90
- aabb=state['gaussian']['aabb'],
91
- sh_degree=state['gaussian']['sh_degree'],
92
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
93
- scaling_bias=state['gaussian']['scaling_bias'],
94
- opacity_bias=state['gaussian']['opacity_bias'],
95
- scaling_activation=state['gaussian']['scaling_activation'],
96
  )
97
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
98
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
99
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
100
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
101
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
102
-
103
  mesh = edict(
104
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
105
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
106
  )
107
-
108
  return gs, mesh
109
 
110
 
111
  def get_seed(randomize_seed: bool, seed: int) -> int:
112
  """
113
  Get the random seed for generation.
114
-
115
  This function is called by the generate button to determine whether to use
116
  a random seed or the user-specified seed value.
117
-
118
  Args:
119
  randomize_seed (bool): Whether to generate a random seed
120
  seed (int): The user-specified seed value
121
-
122
  Returns:
123
  int: The seed to use for generation
124
  """
@@ -163,7 +164,7 @@ def generate_and_extract_glb(
163
  str: The path to the extracted GLB file (for download).
164
  """
165
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
166
-
167
  # Generate 3D model
168
  if not is_multiimage:
169
  outputs = pipeline.run(
@@ -196,24 +197,28 @@ def generate_and_extract_glb(
196
  },
197
  mode=multiimage_algo,
198
  )
199
-
200
  # Render video
201
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
202
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
203
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
204
- video_path = os.path.join(user_dir, 'sample.mp4')
205
  imageio.mimsave(video_path, video, fps=15)
206
-
207
  # Extract GLB
208
- gs = outputs['gaussian'][0]
209
- mesh = outputs['mesh'][0]
210
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
211
- glb_path = os.path.join(user_dir, 'sample.glb')
212
  glb.export(glb_path)
213
-
214
  # Pack state for optional Gaussian extraction
215
  state = pack_state(gs, mesh)
216
-
217
  torch.cuda.empty_cache()
218
  return state, video_path, glb_path, glb_path
219
 
@@ -222,7 +227,7 @@ def generate_and_extract_glb(
222
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
223
  """
224
  Extract a Gaussian splatting file from the generated 3D model.
225
-
226
  This function is called when the user clicks "Extract Gaussian" button.
227
  It converts the 3D model state into a .ply file format containing
228
  Gaussian splatting data for advanced 3D applications.
@@ -236,7 +241,7 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
236
  """
237
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
238
  gs, _ = unpack_state(state)
239
- gaussian_path = os.path.join(user_dir, 'sample.ply')
240
  gs.save_ply(gaussian_path)
241
  torch.cuda.empty_cache()
242
  return gaussian_path, gaussian_path
@@ -250,82 +255,124 @@ def prepare_multi_example() -> List[Image.Image]:
250
  def split_image(image: Image.Image) -> List[Image.Image]:
251
  """
252
  Split a multi-view image into separate view images.
253
-
254
  This function is called when users select multi-image examples that contain
255
  multiple views in a single concatenated image. It automatically splits them
256
  based on alpha channel boundaries and preprocesses each view.
257
-
258
  Args:
259
  image (Image.Image): A concatenated image containing multiple views
260
-
261
  Returns:
262
  List[Image.Image]: List of individual preprocessed view images
263
  """
264
  image = np.array(image)
265
  alpha = image[..., 3]
266
- alpha = np.any(alpha>0, axis=0)
267
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
268
  end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
269
  images = []
270
  for s, e in zip(start_pos, end_pos):
271
- images.append(Image.fromarray(image[:, s:e+1]))
272
  return [preprocess_image(image) for image in images]
273
 
274
 
275
  with gr.Blocks(delete_cache=(600, 600)) as demo:
276
- gr.Markdown("""
277
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
278
  * Upload an image and click "Generate & Extract GLB" to create a 3D asset and automatically extract the GLB file.
279
  * If you want the Gaussian file as well, click "Extract Gaussian" after generation.
280
  * If the image has alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
281
 
282
  ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
283
- """)
284
-
285
  with gr.Row():
286
  with gr.Column():
287
  with gr.Tabs() as input_tabs:
288
  with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
289
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
290
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
291
- multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
292
- gr.Markdown("""
293
  Input different views of the object in separate images.
294
 
295
  *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
296
- """)
297
-
298
  with gr.Accordion(label="Generation Settings", open=False):
299
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
300
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
301
  gr.Markdown("Stage 1: Sparse Structure Generation")
302
  with gr.Row():
303
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
304
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
305
  gr.Markdown("Stage 2: Structured Latent Generation")
306
  with gr.Row():
307
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
308
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
309
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
310
-
311
  with gr.Accordion(label="GLB Extraction Settings", open=False):
312
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
313
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
314
 
315
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
316
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
317
- gr.Markdown("""
318
  *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
319
- """)
320
 
321
  with gr.Column():
322
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
323
- model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
324
-
325
  with gr.Row():
326
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
327
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
328
-
329
  is_multiimage = gr.State(False)
330
  output_buf = gr.State()
331
 
@@ -334,9 +381,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
334
  if os.path.exists("assets/images"):
335
  examples = gr.Examples(
336
  examples=[
337
- f'assets/images/{image}'
338
  for image in os.listdir("assets/images")
339
- if image.endswith(('.png', '.jpg', '.jpeg', '.webp'))
340
  ],
341
  inputs=[image_prompt],
342
  fn=preprocess_image,
@@ -346,7 +393,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
346
  )
347
  else:
348
  examples = gr.Examples(examples=[], inputs=[image_prompt])
349
-
350
  with gr.Row(visible=False) as multiimage_example:
351
  examples_multi = gr.Examples(
352
  examples=prepare_multi_example(),
@@ -360,16 +407,20 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
360
  # Handlers
361
  demo.load(start_session)
362
  demo.unload(end_session)
363
-
364
  single_image_input_tab.select(
365
- lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
366
- outputs=[is_multiimage, single_image_example, multiimage_example]
367
  )
368
  multiimage_input_tab.select(
369
- lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
370
- outputs=[is_multiimage, single_image_example, multiimage_example]
371
  )
372
-
373
  image_prompt.upload(
374
  preprocess_image,
375
  inputs=[image_prompt],
@@ -387,7 +438,19 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
387
  outputs=[seed],
388
  ).then(
389
  generate_and_extract_glb,
390
- inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo, mesh_simplify, texture_size],
391
  outputs=[output_buf, video_output, model_output, download_glb],
392
  ).then(
393
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
@@ -395,10 +458,16 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
395
  )
396
 
397
  video_output.clear(
398
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
399
  outputs=[extract_gs_btn, download_glb, download_gs],
400
  )
401
-
402
  extract_gs_btn.click(
403
  extract_gaussian,
404
  inputs=[output_buf],
@@ -412,14 +481,16 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
412
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
413
  outputs=[download_glb, download_gs],
414
  )
415
-
416
 
417
  # Launch the Gradio app
418
  if __name__ == "__main__":
419
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
420
  pipeline.cuda()
421
  try:
422
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
423
  except:
424
  pass
425
  demo.launch()
 
4
 
5
  import os
6
  import shutil
7
+
8
+ os.environ["SPCONV_ALGO"] = "native"
9
  from typing import *
10
  import torch
11
  import numpy as np
 
18
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
22
  os.makedirs(TMP_DIR, exist_ok=True)
23
 
24
 
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
  os.makedirs(user_dir, exist_ok=True)
28
+
29
+
30
  def end_session(req: gr.Request):
31
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
32
  if os.path.exists(user_dir):
 
36
  def preprocess_image(image: Image.Image) -> Image.Image:
37
  """
38
  Preprocess the input image for 3D generation.
39
+
40
  This function is called when a user uploads an image or selects an example.
41
  It applies background removal and other preprocessing steps necessary for
42
  optimal 3D model generation.
 
54
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
55
  """
56
  Preprocess a list of input images for multi-image 3D generation.
57
+
58
  This function is called when users upload multiple images in the gallery.
59
  It processes each image to prepare them for the multi-image 3D generation pipeline.
60
+
61
  Args:
62
  images (List[Tuple[Image.Image, str]]): The input images from the gallery
63
+
64
  Returns:
65
  List[Image.Image]: The preprocessed images ready for 3D generation
66
  """
 
71
 
72
  def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
73
  return {
74
+ "gaussian": {
75
  **gs.init_params,
76
+ "_xyz": gs._xyz.cpu().numpy(),
77
+ "_features_dc": gs._features_dc.cpu().numpy(),
78
+ "_scaling": gs._scaling.cpu().numpy(),
79
+ "_rotation": gs._rotation.cpu().numpy(),
80
+ "_opacity": gs._opacity.cpu().numpy(),
81
  },
82
+ "mesh": {
83
+ "vertices": mesh.vertices.cpu().numpy(),
84
+ "faces": mesh.faces.cpu().numpy(),
85
  },
86
  }
87
+
88
+
89
  def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
90
  gs = Gaussian(
91
+ aabb=state["gaussian"]["aabb"],
92
+ sh_degree=state["gaussian"]["sh_degree"],
93
+ mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
94
+ scaling_bias=state["gaussian"]["scaling_bias"],
95
+ opacity_bias=state["gaussian"]["opacity_bias"],
96
+ scaling_activation=state["gaussian"]["scaling_activation"],
97
  )
98
+ gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device="cuda")
99
+ gs._features_dc = torch.tensor(state["gaussian"]["_features_dc"], device="cuda")
100
+ gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device="cuda")
101
+ gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device="cuda")
102
+ gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device="cuda")
103
+
104
  mesh = edict(
105
+ vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"),
106
+ faces=torch.tensor(state["mesh"]["faces"], device="cuda"),
107
  )
108
+
109
  return gs, mesh
110
 
111
 
112
  def get_seed(randomize_seed: bool, seed: int) -> int:
113
  """
114
  Get the random seed for generation.
115
+
116
  This function is called by the generate button to determine whether to use
117
  a random seed or the user-specified seed value.
118
+
119
  Args:
120
  randomize_seed (bool): Whether to generate a random seed
121
  seed (int): The user-specified seed value
122
+
123
  Returns:
124
  int: The seed to use for generation
125
  """
 
164
  str: The path to the extracted GLB file (for download).
165
  """
166
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
167
+
168
  # Generate 3D model
169
  if not is_multiimage:
170
  outputs = pipeline.run(
 
197
  },
198
  mode=multiimage_algo,
199
  )
200
+
201
  # Render video
202
+ video = render_utils.render_video(outputs["gaussian"][0], num_frames=120)["color"]
203
+ video_geo = render_utils.render_video(outputs["mesh"][0], num_frames=120)["normal"]
204
+ video = [
205
+ np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))
206
+ ]
207
+ video_path = os.path.join(user_dir, "sample.mp4")
208
  imageio.mimsave(video_path, video, fps=15)
209
+
210
  # Extract GLB
211
+ gs = outputs["gaussian"][0]
212
+ mesh = outputs["mesh"][0]
213
+ glb = postprocessing_utils.to_glb(
214
+ gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False
215
+ )
216
+ glb_path = os.path.join(user_dir, "sample.glb")
217
  glb.export(glb_path)
218
+
219
  # Pack state for optional Gaussian extraction
220
  state = pack_state(gs, mesh)
221
+
222
  torch.cuda.empty_cache()
223
  return state, video_path, glb_path, glb_path
224
 
 
227
  def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
228
  """
229
  Extract a Gaussian splatting file from the generated 3D model.
230
+
231
  This function is called when the user clicks "Extract Gaussian" button.
232
  It converts the 3D model state into a .ply file format containing
233
  Gaussian splatting data for advanced 3D applications.
 
241
  """
242
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
243
  gs, _ = unpack_state(state)
244
+ gaussian_path = os.path.join(user_dir, "sample.ply")
245
  gs.save_ply(gaussian_path)
246
  torch.cuda.empty_cache()
247
  return gaussian_path, gaussian_path
 
255
  def split_image(image: Image.Image) -> List[Image.Image]:
256
  """
257
  Split a multi-view image into separate view images.
258
+
259
  This function is called when users select multi-image examples that contain
260
  multiple views in a single concatenated image. It automatically splits them
261
  based on alpha channel boundaries and preprocesses each view.
262
+
263
  Args:
264
  image (Image.Image): A concatenated image containing multiple views
265
+
266
  Returns:
267
  List[Image.Image]: List of individual preprocessed view images
268
  """
269
  image = np.array(image)
270
  alpha = image[..., 3]
271
+ alpha = np.any(alpha > 0, axis=0)
272
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
273
  end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
274
  images = []
275
  for s, e in zip(start_pos, end_pos):
276
+ images.append(Image.fromarray(image[:, s : e + 1]))
277
  return [preprocess_image(image) for image in images]
278
 
279
 
280
  with gr.Blocks(delete_cache=(600, 600)) as demo:
281
+ gr.Markdown(
282
+ """
283
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
284
  * Upload an image and click "Generate & Extract GLB" to create a 3D asset and automatically extract the GLB file.
285
  * If you want the Gaussian file as well, click "Extract Gaussian" after generation.
286
  * If the image has alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
287
 
288
  ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
289
+ """
290
+ )
291
+
292
  with gr.Row():
293
  with gr.Column():
294
  with gr.Tabs() as input_tabs:
295
  with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
296
+ image_prompt = gr.Image(
297
+ label="Image Prompt",
298
+ format="png",
299
+ image_mode="RGBA",
300
+ type="pil",
301
+ height=300,
302
+ )
303
  with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
304
+ multiimage_prompt = gr.Gallery(
305
+ label="Image Prompt",
306
+ format="png",
307
+ type="pil",
308
+ height=300,
309
+ columns=3,
310
+ )
311
+ gr.Markdown(
312
+ """
313
  Input different views of the object in separate images.
314
 
315
  *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
316
+ """
317
+ )
318
+
319
  with gr.Accordion(label="Generation Settings", open=False):
320
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
321
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
322
  gr.Markdown("Stage 1: Sparse Structure Generation")
323
  with gr.Row():
324
+ ss_guidance_strength = gr.Slider(
325
+ 0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1
326
+ )
327
+ ss_sampling_steps = gr.Slider(
328
+ 1, 50, label="Sampling Steps", value=12, step=1
329
+ )
330
  gr.Markdown("Stage 2: Structured Latent Generation")
331
  with gr.Row():
332
+ slat_guidance_strength = gr.Slider(
333
+ 0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1
334
+ )
335
+ slat_sampling_steps = gr.Slider(
336
+ 1, 50, label="Sampling Steps", value=12, step=1
337
+ )
338
+ multiimage_algo = gr.Radio(
339
+ ["stochastic", "multidiffusion"],
340
+ label="Multi-image Algorithm",
341
+ value="stochastic",
342
+ )
343
+
344
  with gr.Accordion(label="GLB Extraction Settings", open=False):
345
+ mesh_simplify = gr.Slider(
346
+ 0.9, 0.98, label="Simplify", value=0.95, step=0.01
347
+ )
348
+ texture_size = gr.Slider(
349
+ 512, 2048, label="Texture Size", value=1024, step=512
350
+ )
351
 
352
  generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
353
  extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
354
+ gr.Markdown(
355
+ """
356
  *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
357
+ """
358
+ )
359
 
360
  with gr.Column():
361
+ video_output = gr.Video(
362
+ label="Generated 3D Asset", autoplay=True, loop=True, height=300
363
+ )
364
+ model_output = LitModel3D(
365
+ label="Extracted GLB/Gaussian", exposure=10.0, height=300
366
+ )
367
+
368
  with gr.Row():
369
+ download_glb = gr.DownloadButton(
370
+ label="Download GLB", interactive=False
371
+ )
372
+ download_gs = gr.DownloadButton(
373
+ label="Download Gaussian", interactive=False
374
+ )
375
+
376
  is_multiimage = gr.State(False)
377
  output_buf = gr.State()
378
 
 
381
  if os.path.exists("assets/images"):
382
  examples = gr.Examples(
383
  examples=[
384
+ f"assets/images/{image}"
385
  for image in os.listdir("assets/images")
386
+ if image.endswith((".png", ".jpg", ".jpeg", ".webp"))
387
  ],
388
  inputs=[image_prompt],
389
  fn=preprocess_image,
 
393
  )
394
  else:
395
  examples = gr.Examples(examples=[], inputs=[image_prompt])
396
+
397
  with gr.Row(visible=False) as multiimage_example:
398
  examples_multi = gr.Examples(
399
  examples=prepare_multi_example(),
 
407
  # Handlers
408
  demo.load(start_session)
409
  demo.unload(end_session)
410
+
411
  single_image_input_tab.select(
412
+ lambda: tuple(
413
+ [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
414
+ ),
415
+ outputs=[is_multiimage, single_image_example, multiimage_example],
416
  )
417
  multiimage_input_tab.select(
418
+ lambda: tuple(
419
+ [True, gr.Row.update(visible=False), gr.Row.update(visible=True)]
420
+ ),
421
+ outputs=[is_multiimage, single_image_example, multiimage_example],
422
  )
423
+
424
  image_prompt.upload(
425
  preprocess_image,
426
  inputs=[image_prompt],
 
438
  outputs=[seed],
439
  ).then(
440
  generate_and_extract_glb,
441
+ inputs=[
442
+ image_prompt,
443
+ multiimage_prompt,
444
+ is_multiimage,
445
+ seed,
446
+ ss_guidance_strength,
447
+ ss_sampling_steps,
448
+ slat_guidance_strength,
449
+ slat_sampling_steps,
450
+ multiimage_algo,
451
+ mesh_simplify,
452
+ texture_size,
453
+ ],
454
  outputs=[output_buf, video_output, model_output, download_glb],
455
  ).then(
456
  lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
 
458
  )
459
 
460
  video_output.clear(
461
+ lambda: tuple(
462
+ [
463
+ gr.Button(interactive=False),
464
+ gr.Button(interactive=False),
465
+ gr.Button(interactive=False),
466
+ ]
467
+ ),
468
  outputs=[extract_gs_btn, download_glb, download_gs],
469
  )
470
+
471
  extract_gs_btn.click(
472
  extract_gaussian,
473
  inputs=[output_buf],
 
481
  lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
482
  outputs=[download_glb, download_gs],
483
  )
484
+
485
 
486
  # Launch the Gradio app
487
  if __name__ == "__main__":
488
  pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
489
  pipeline.cuda()
490
  try:
491
+ pipeline.preprocess_image(
492
+ Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
493
+ ) # Preload rembg
494
  except:
495
  pass
496
  demo.launch()
trellis/models/__init__.py CHANGED
@@ -1,20 +1,21 @@
1
  import importlib
2
 
3
  __attributes = {
4
- 'SparseStructureEncoder': 'sparse_structure_vae',
5
- 'SparseStructureDecoder': 'sparse_structure_vae',
6
- 'SparseStructureFlowModel': 'sparse_structure_flow',
7
- 'SLatEncoder': 'structured_latent_vae',
8
- 'SLatGaussianDecoder': 'structured_latent_vae',
9
- 'SLatRadianceFieldDecoder': 'structured_latent_vae',
10
- 'SLatMeshDecoder': 'structured_latent_vae',
11
- 'SLatFlowModel': 'structured_latent_flow',
12
  }
13
 
14
  __submodules = []
15
 
16
  __all__ = list(__attributes.keys()) + __submodules
17
 
 
18
  def __getattr__(name):
19
  if name not in globals():
20
  if name in __attributes:
@@ -41,6 +42,7 @@ def from_pretrained(path: str, **kwargs):
41
  import os
42
  import json
43
  from safetensors.torch import load_file
44
  is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
45
 
46
  if is_local:
@@ -48,23 +50,29 @@ def from_pretrained(path: str, **kwargs):
48
  model_file = f"{path}.safetensors"
49
  else:
50
  from huggingface_hub import hf_hub_download
51
- path_parts = path.split('/')
52
- repo_id = f'{path_parts[0]}/{path_parts[1]}'
53
- model_name = '/'.join(path_parts[2:])
 
54
  config_file = hf_hub_download(repo_id, f"{model_name}.json")
55
  model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
56
 
57
- with open(config_file, 'r') as f:
58
  config = json.load(f)
59
- model = __getattr__(config['name'])(**config['args'], **kwargs)
60
  model.load_state_dict(load_file(model_file))
61
 
62
  return model
63
 
64
 
65
  # For Pylance
66
- if __name__ == '__main__':
67
  from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
68
  from .sparse_structure_flow import SparseStructureFlowModel
69
- from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder
70
  from .structured_latent_flow import SLatFlowModel
 
1
  import importlib
2
 
3
  __attributes = {
4
+ "SparseStructureEncoder": "sparse_structure_vae",
5
+ "SparseStructureDecoder": "sparse_structure_vae",
6
+ "SparseStructureFlowModel": "sparse_structure_flow",
7
+ "SLatEncoder": "structured_latent_vae",
8
+ "SLatGaussianDecoder": "structured_latent_vae",
9
+ "SLatRadianceFieldDecoder": "structured_latent_vae",
10
+ "SLatMeshDecoder": "structured_latent_vae",
11
+ "SLatFlowModel": "structured_latent_flow",
12
  }
13
 
14
  __submodules = []
15
 
16
  __all__ = list(__attributes.keys()) + __submodules
17
 
18
+
19
  def __getattr__(name):
20
  if name not in globals():
21
  if name in __attributes:
 
42
  import os
43
  import json
44
  from safetensors.torch import load_file
45
+
46
  is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
47
 
48
  if is_local:
 
50
  model_file = f"{path}.safetensors"
51
  else:
52
  from huggingface_hub import hf_hub_download
53
+
54
+ path_parts = path.split("/")
55
+ repo_id = f"{path_parts[0]}/{path_parts[1]}"
56
+ model_name = "/".join(path_parts[2:])
57
  config_file = hf_hub_download(repo_id, f"{model_name}.json")
58
  model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
59
 
60
+ with open(config_file, "r") as f:
61
  config = json.load(f)
62
+ model = __getattr__(config["name"])(**config["args"], **kwargs)
63
  model.load_state_dict(load_file(model_file))
64
 
65
  return model
66
 
67
 
68
  # For Pylance
69
+ if __name__ == "__main__":
70
  from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
71
  from .sparse_structure_flow import SparseStructureFlowModel
72
+ from .structured_latent_vae import (
73
+ SLatEncoder,
74
+ SLatGaussianDecoder,
75
+ SLatRadianceFieldDecoder,
76
+ SLatMeshDecoder,
77
+ )
78
  from .structured_latent_flow import SLatFlowModel
trellis/models/sparse_structure_flow.py CHANGED
@@ -4,7 +4,10 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  import numpy as np
6
  from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
- from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
8
  from ..modules.spatial import patchify, unpatchify
9
 
10
 
@@ -12,6 +15,7 @@ class TimestepEmbedder(nn.Module):
12
  """
13
  Embeds scalar timesteps into vector representations.
14
  """
 
15
  def __init__(self, hidden_size, frequency_embedding_size=256):
16
  super().__init__()
17
  self.mlp = nn.Sequential(
@@ -38,12 +42,16 @@ class TimestepEmbedder(nn.Module):
38
  # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
39
  half = dim // 2
40
  freqs = torch.exp(
41
- -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
 
 
42
  ).to(device=t.device)
43
  args = t[:, None].float() * freqs[None]
44
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
  if dim % 2:
46
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 
 
47
  return embedding
48
 
49
  def forward(self, t):
@@ -93,34 +101,41 @@ class SparseStructureFlowModel(nn.Module):
93
  self.t_embedder = TimestepEmbedder(model_channels)
94
  if share_mod:
95
  self.adaLN_modulation = nn.Sequential(
96
- nn.SiLU(),
97
- nn.Linear(model_channels, 6 * model_channels, bias=True)
98
  )
99
 
100
  if pe_mode == "ape":
101
  pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
102
- coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
 
 
 
 
 
 
103
  coords = torch.stack(coords, dim=-1).reshape(-1, 3)
104
  pos_emb = pos_embedder(coords)
105
  self.register_buffer("pos_emb", pos_emb)
106
 
107
  self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
108
-
109
- self.blocks = nn.ModuleList([
110
- ModulatedTransformerCrossBlock(
111
- model_channels,
112
- cond_channels,
113
- num_heads=self.num_heads,
114
- mlp_ratio=self.mlp_ratio,
115
- attn_mode='full',
116
- use_checkpoint=self.use_checkpoint,
117
- use_rope=(pe_mode == "rope"),
118
- share_mod=share_mod,
119
- qk_rms_norm=self.qk_rms_norm,
120
- qk_rms_norm_cross=self.qk_rms_norm_cross,
121
- )
122
- for _ in range(num_blocks)
123
- ])
 
 
124
 
125
  self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
126
 
@@ -154,6 +169,7 @@ class SparseStructureFlowModel(nn.Module):
154
  torch.nn.init.xavier_uniform_(module.weight)
155
  if module.bias is not None:
156
  nn.init.constant_(module.bias, 0)
 
157
  self.apply(_basic_init)
158
 
159
  # Initialize timestep embedding MLP:
@@ -173,9 +189,14 @@ class SparseStructureFlowModel(nn.Module):
173
  nn.init.constant_(self.out_layer.weight, 0)
174
  nn.init.constant_(self.out_layer.bias, 0)
175
 
176
- def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
177
- assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
178
- f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
 
 
 
 
 
179
 
180
  h = patchify(x, self.patch_size)
181
  h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
@@ -194,7 +215,9 @@ class SparseStructureFlowModel(nn.Module):
194
  h = F.layer_norm(h, h.shape[-1:])
195
  h = self.out_layer(h)
196
 
197
- h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
 
 
198
  h = unpatchify(h, self.patch_size).contiguous()
199
 
200
  return h
 
4
  import torch.nn.functional as F
5
  import numpy as np
6
  from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import (
8
+ AbsolutePositionEmbedder,
9
+ ModulatedTransformerCrossBlock,
10
+ )
11
  from ..modules.spatial import patchify, unpatchify
12
 
13
 
 
15
  """
16
  Embeds scalar timesteps into vector representations.
17
  """
18
+
19
  def __init__(self, hidden_size, frequency_embedding_size=256):
20
  super().__init__()
21
  self.mlp = nn.Sequential(
 
42
  # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
43
  half = dim // 2
44
  freqs = torch.exp(
45
+ -np.log(max_period)
46
+ * torch.arange(start=0, end=half, dtype=torch.float32)
47
+ / half
48
  ).to(device=t.device)
49
  args = t[:, None].float() * freqs[None]
50
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
51
  if dim % 2:
52
+ embedding = torch.cat(
53
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
54
+ )
55
  return embedding
56
 
57
  def forward(self, t):
 
101
  self.t_embedder = TimestepEmbedder(model_channels)
102
  if share_mod:
103
  self.adaLN_modulation = nn.Sequential(
104
+ nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True)
 
105
  )
106
 
107
  if pe_mode == "ape":
108
  pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
109
+ coords = torch.meshgrid(
110
+ *[
111
+ torch.arange(res, device=self.device)
112
+ for res in [resolution // patch_size] * 3
113
+ ],
114
+ indexing="ij",
115
+ )
116
  coords = torch.stack(coords, dim=-1).reshape(-1, 3)
117
  pos_emb = pos_embedder(coords)
118
  self.register_buffer("pos_emb", pos_emb)
119
 
120
  self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
121
+
122
+ self.blocks = nn.ModuleList(
123
+ [
124
+ ModulatedTransformerCrossBlock(
125
+ model_channels,
126
+ cond_channels,
127
+ num_heads=self.num_heads,
128
+ mlp_ratio=self.mlp_ratio,
129
+ attn_mode="full",
130
+ use_checkpoint=self.use_checkpoint,
131
+ use_rope=(pe_mode == "rope"),
132
+ share_mod=share_mod,
133
+ qk_rms_norm=self.qk_rms_norm,
134
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
135
+ )
136
+ for _ in range(num_blocks)
137
+ ]
138
+ )
139
 
140
  self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
141
 
 
169
  torch.nn.init.xavier_uniform_(module.weight)
170
  if module.bias is not None:
171
  nn.init.constant_(module.bias, 0)
172
+
173
  self.apply(_basic_init)
174
 
175
  # Initialize timestep embedding MLP:
 
189
  nn.init.constant_(self.out_layer.weight, 0)
190
  nn.init.constant_(self.out_layer.bias, 0)
191
 
192
+ def forward(
193
+ self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor
194
+ ) -> torch.Tensor:
195
+ assert [*x.shape] == [
196
+ x.shape[0],
197
+ self.in_channels,
198
+ *[self.resolution] * 3,
199
+ ], f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
200
 
201
  h = patchify(x, self.patch_size)
202
  h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
 
215
  h = F.layer_norm(h, h.shape[-1:])
216
  h = self.out_layer(h)
217
 
218
+ h = h.permute(0, 2, 1).view(
219
+ h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3
220
+ )
221
  h = unpatchify(h, self.patch_size).contiguous()
222
 
223
  return h
trellis/models/sparse_structure_vae.py CHANGED
@@ -33,9 +33,15 @@ class ResBlock3d(nn.Module):
33
  self.norm1 = norm_layer(norm_type, channels)
34
  self.norm2 = norm_layer(norm_type, self.out_channels)
35
  self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
36
- self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
37
- self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
38
-
39
  def forward(self, x: torch.Tensor) -> torch.Tensor:
40
  h = self.norm1(x)
41
  h = F.silu(h)
@@ -63,7 +69,9 @@ class DownsampleBlock3d(nn.Module):
63
  if mode == "conv":
64
  self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
65
  elif mode == "avgpool":
66
- assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
 
 
67
 
68
  def forward(self, x: torch.Tensor) -> torch.Tensor:
69
  if hasattr(self, "conv"):
@@ -86,9 +94,11 @@ class UpsampleBlock3d(nn.Module):
86
  self.out_channels = out_channels
87
 
88
  if mode == "conv":
89
- self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
90
  elif mode == "nearest":
91
- assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
 
 
92
 
93
  def forward(self, x: torch.Tensor) -> torch.Tensor:
94
  if hasattr(self, "conv"):
@@ -96,12 +106,12 @@ class UpsampleBlock3d(nn.Module):
96
  return pixel_shuffle_3d(x, 2)
97
  else:
98
  return F.interpolate(x, scale_factor=2, mode="nearest")
99
-
100
 
101
  class SparseStructureEncoder(nn.Module):
102
  """
103
  Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
104
-
105
  Args:
106
  in_channels (int): Channels of the input.
107
  latent_channels (int): Channels of the latent representation.
@@ -111,6 +121,7 @@ class SparseStructureEncoder(nn.Module):
111
  norm_type (Literal["group", "layer"]): Type of normalization layer.
112
  use_fp16 (bool): Whether to use FP16.
113
  """
 
114
  def __init__(
115
  self,
116
  in_channels: int,
@@ -135,24 +146,21 @@ class SparseStructureEncoder(nn.Module):
135
 
136
  self.blocks = nn.ModuleList([])
137
  for i, ch in enumerate(channels):
138
- self.blocks.extend([
139
- ResBlock3d(ch, ch)
140
- for _ in range(num_res_blocks)
141
- ])
142
  if i < len(channels) - 1:
143
- self.blocks.append(
144
- DownsampleBlock3d(ch, channels[i+1])
145
- )
146
-
147
- self.middle_block = nn.Sequential(*[
148
- ResBlock3d(channels[-1], channels[-1])
149
- for _ in range(num_res_blocks_middle)
150
- ])
151
 
152
  self.out_layer = nn.Sequential(
153
  norm_layer(norm_type, channels[-1]),
154
  nn.SiLU(),
155
- nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
156
  )
157
 
158
  if use_fp16:
@@ -183,7 +191,9 @@ class SparseStructureEncoder(nn.Module):
183
  self.blocks.apply(convert_module_to_f32)
184
  self.middle_block.apply(convert_module_to_f32)
185
 
186
- def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
 
 
187
  h = self.input_layer(x)
188
  h = h.type(self.dtype)
189
 
@@ -201,16 +211,16 @@ class SparseStructureEncoder(nn.Module):
201
  z = mean + std * torch.randn_like(std)
202
  else:
203
  z = mean
204
-
205
  if return_raw:
206
  return z, mean, logvar
207
  return z
208
-
209
 
210
  class SparseStructureDecoder(nn.Module):
211
  """
212
  Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
213
-
214
  Args:
215
  out_channels (int): Channels of the output.
216
  latent_channels (int): Channels of the latent representation.
@@ -219,7 +229,8 @@ class SparseStructureDecoder(nn.Module):
219
  num_res_blocks_middle (int): Number of residual blocks in the middle.
220
  norm_type (Literal["group", "layer"]): Type of normalization layer.
221
  use_fp16 (bool): Whether to use FP16.
222
- """
 
223
  def __init__(
224
  self,
225
  out_channels: int,
@@ -242,26 +253,23 @@ class SparseStructureDecoder(nn.Module):
242
 
243
  self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
244
 
245
- self.middle_block = nn.Sequential(*[
246
- ResBlock3d(channels[0], channels[0])
247
- for _ in range(num_res_blocks_middle)
248
- ])
 
 
249
 
250
  self.blocks = nn.ModuleList([])
251
  for i, ch in enumerate(channels):
252
- self.blocks.extend([
253
- ResBlock3d(ch, ch)
254
- for _ in range(num_res_blocks)
255
- ])
256
  if i < len(channels) - 1:
257
- self.blocks.append(
258
- UpsampleBlock3d(ch, channels[i+1])
259
- )
260
 
261
  self.out_layer = nn.Sequential(
262
  norm_layer(norm_type, channels[-1]),
263
  nn.SiLU(),
264
- nn.Conv3d(channels[-1], out_channels, 3, padding=1)
265
  )
266
 
267
  if use_fp16:
@@ -273,7 +281,7 @@ class SparseStructureDecoder(nn.Module):
273
  Return the device of the model.
274
  """
275
  return next(self.parameters()).device
276
-
277
  def convert_to_fp16(self) -> None:
278
  """
279
  Convert the torso of the model to float16.
@@ -291,12 +299,12 @@ class SparseStructureDecoder(nn.Module):
291
  self.dtype = torch.float32
292
  self.blocks.apply(convert_module_to_f32)
293
  self.middle_block.apply(convert_module_to_f32)
294
-
295
  def forward(self, x: torch.Tensor) -> torch.Tensor:
296
  h = self.input_layer(x)
297
-
298
  h = h.type(self.dtype)
299
-
300
  h = self.middle_block(h)
301
  for block in self.blocks:
302
  h = block(h)
 
33
  self.norm1 = norm_layer(norm_type, channels)
34
  self.norm2 = norm_layer(norm_type, self.out_channels)
35
  self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
36
+ self.conv2 = zero_module(
37
+ nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
38
+ )
39
+ self.skip_connection = (
40
+ nn.Conv3d(channels, self.out_channels, 1)
41
+ if channels != self.out_channels
42
+ else nn.Identity()
43
+ )
44
+
45
  def forward(self, x: torch.Tensor) -> torch.Tensor:
46
  h = self.norm1(x)
47
  h = F.silu(h)
 
69
  if mode == "conv":
70
  self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
71
  elif mode == "avgpool":
72
+ assert (
73
+ in_channels == out_channels
74
+ ), "Pooling mode requires in_channels to be equal to out_channels"
75
 
76
  def forward(self, x: torch.Tensor) -> torch.Tensor:
77
  if hasattr(self, "conv"):
 
94
  self.out_channels = out_channels
95
 
96
  if mode == "conv":
97
+ self.conv = nn.Conv3d(in_channels, out_channels * 8, 3, padding=1)
98
  elif mode == "nearest":
99
+ assert (
100
+ in_channels == out_channels
101
+ ), "Nearest mode requires in_channels to be equal to out_channels"
102
 
103
  def forward(self, x: torch.Tensor) -> torch.Tensor:
104
  if hasattr(self, "conv"):
 
106
  return pixel_shuffle_3d(x, 2)
107
  else:
108
  return F.interpolate(x, scale_factor=2, mode="nearest")
109
+
110
 
111
  class SparseStructureEncoder(nn.Module):
112
  """
113
  Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
114
+
115
  Args:
116
  in_channels (int): Channels of the input.
117
  latent_channels (int): Channels of the latent representation.
 
121
  norm_type (Literal["group", "layer"]): Type of normalization layer.
122
  use_fp16 (bool): Whether to use FP16.
123
  """
124
+
125
  def __init__(
126
  self,
127
  in_channels: int,
 
146
 
147
  self.blocks = nn.ModuleList([])
148
  for i, ch in enumerate(channels):
149
+ self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)])
 
 
 
150
  if i < len(channels) - 1:
151
+ self.blocks.append(DownsampleBlock3d(ch, channels[i + 1]))
152
+
153
+ self.middle_block = nn.Sequential(
154
+ *[
155
+ ResBlock3d(channels[-1], channels[-1])
156
+ for _ in range(num_res_blocks_middle)
157
+ ]
158
+ )
159
 
160
  self.out_layer = nn.Sequential(
161
  norm_layer(norm_type, channels[-1]),
162
  nn.SiLU(),
163
+ nn.Conv3d(channels[-1], latent_channels * 2, 3, padding=1),
164
  )
165
 
166
  if use_fp16:
 
191
  self.blocks.apply(convert_module_to_f32)
192
  self.middle_block.apply(convert_module_to_f32)
193
 
194
+ def forward(
195
+ self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False
196
+ ) -> torch.Tensor:
197
  h = self.input_layer(x)
198
  h = h.type(self.dtype)
199
 
 
211
  z = mean + std * torch.randn_like(std)
212
  else:
213
  z = mean
214
+
215
  if return_raw:
216
  return z, mean, logvar
217
  return z
218
+
219
 
220
  class SparseStructureDecoder(nn.Module):
221
  """
222
  Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
223
+
224
  Args:
225
  out_channels (int): Channels of the output.
226
  latent_channels (int): Channels of the latent representation.
 
229
  num_res_blocks_middle (int): Number of residual blocks in the middle.
230
  norm_type (Literal["group", "layer"]): Type of normalization layer.
231
  use_fp16 (bool): Whether to use FP16.
232
+ """
233
+
234
  def __init__(
235
  self,
236
  out_channels: int,
 
253
 
254
  self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
255
 
256
+ self.middle_block = nn.Sequential(
257
+ *[
258
+ ResBlock3d(channels[0], channels[0])
259
+ for _ in range(num_res_blocks_middle)
260
+ ]
261
+ )
262
 
263
  self.blocks = nn.ModuleList([])
264
  for i, ch in enumerate(channels):
265
+ self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)])
 
 
 
266
  if i < len(channels) - 1:
267
+ self.blocks.append(UpsampleBlock3d(ch, channels[i + 1]))
 
 
268
 
269
  self.out_layer = nn.Sequential(
270
  norm_layer(norm_type, channels[-1]),
271
  nn.SiLU(),
272
+ nn.Conv3d(channels[-1], out_channels, 3, padding=1),
273
  )
274
 
275
  if use_fp16:
 
281
  Return the device of the model.
282
  """
283
  return next(self.parameters()).device
284
+
285
  def convert_to_fp16(self) -> None:
286
  """
287
  Convert the torso of the model to float16.
 
299
  self.dtype = torch.float32
300
  self.blocks.apply(convert_module_to_f32)
301
  self.middle_block.apply(convert_module_to_f32)
302
+
303
  def forward(self, x: torch.Tensor) -> torch.Tensor:
304
  h = self.input_layer(x)
305
+
306
  h = h.type(self.dtype)
307
+
308
  h = self.middle_block(h)
309
  for block in self.blocks:
310
  h = block(h)
trellis/models/structured_latent_flow.py CHANGED
@@ -26,18 +26,26 @@ class SparseResBlock3d(nn.Module):
26
  self.out_channels = out_channels or channels
27
  self.downsample = downsample
28
  self.upsample = upsample
29
-
30
- assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
31
 
32
  self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
33
  self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
34
  self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
35
- self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
 
 
36
  self.emb_layers = nn.Sequential(
37
  nn.SiLU(),
38
  nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
39
  )
40
- self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
 
 
 
 
41
  self.updown = None
42
  if self.downsample:
43
  self.updown = sp.SparseDownsample(2)
@@ -63,7 +71,7 @@ class SparseResBlock3d(nn.Module):
63
  h = h + self.skip_connection(x)
64
 
65
  return h
66
-
67
 
68
  class SLatFlowModel(nn.Module):
69
  def __init__(
@@ -109,14 +117,17 @@ class SLatFlowModel(nn.Module):
109
  self.qk_rms_norm_cross = qk_rms_norm_cross
110
  self.dtype = torch.float16 if use_fp16 else torch.float32
111
 
112
- assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
113
- assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
 
 
 
 
114
 
115
  self.t_embedder = TimestepEmbedder(model_channels)
116
  if share_mod:
117
  self.adaLN_modulation = nn.Sequential(
118
- nn.SiLU(),
119
- nn.Linear(model_channels, 6 * model_channels, bias=True)
120
  )
121
 
122
  if pe_mode == "ape":
@@ -124,15 +135,19 @@ class SLatFlowModel(nn.Module):
124
 
125
  self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
126
  self.input_blocks = nn.ModuleList([])
127
- for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
128
- self.input_blocks.extend([
129
- SparseResBlock3d(
130
- chs,
131
- model_channels,
132
- out_channels=chs,
133
- )
134
- for _ in range(num_io_res_blocks-1)
135
- ])
 
 
 
 
136
  self.input_blocks.append(
137
  SparseResBlock3d(
138
  chs,
@@ -141,25 +156,30 @@ class SLatFlowModel(nn.Module):
141
  downsample=True,
142
  )
143
  )
144
-
145
- self.blocks = nn.ModuleList([
146
- ModulatedSparseTransformerCrossBlock(
147
- model_channels,
148
- cond_channels,
149
- num_heads=self.num_heads,
150
- mlp_ratio=self.mlp_ratio,
151
- attn_mode='full',
152
- use_checkpoint=self.use_checkpoint,
153
- use_rope=(pe_mode == "rope"),
154
- share_mod=self.share_mod,
155
- qk_rms_norm=self.qk_rms_norm,
156
- qk_rms_norm_cross=self.qk_rms_norm_cross,
157
- )
158
- for _ in range(num_blocks)
159
- ])
 
 
160
 
161
  self.out_blocks = nn.ModuleList([])
162
- for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
 
 
 
163
  self.out_blocks.append(
164
  SparseResBlock3d(
165
  prev_chs * 2 if self.use_skip_connection else prev_chs,
@@ -168,14 +188,16 @@ class SLatFlowModel(nn.Module):
168
  upsample=True,
169
  )
170
  )
171
- self.out_blocks.extend([
172
- SparseResBlock3d(
173
- chs * 2 if self.use_skip_connection else chs,
174
- model_channels,
175
- out_channels=chs,
176
- )
177
- for _ in range(num_io_res_blocks-1)
178
- ])
 
 
179
  self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
180
 
181
  self.initialize_weights()
@@ -212,6 +234,7 @@ class SLatFlowModel(nn.Module):
212
  torch.nn.init.xavier_uniform_(module.weight)
213
  if module.bias is not None:
214
  nn.init.constant_(module.bias, 0)
 
215
  self.apply(_basic_init)
216
 
217
  # Initialize timestep embedding MLP:
@@ -231,7 +254,9 @@ class SLatFlowModel(nn.Module):
231
  nn.init.constant_(self.out_layer.weight, 0)
232
  nn.init.constant_(self.out_layer.bias, 0)
233
 
234
- def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
 
 
235
  h = self.input_layer(x).type(self.dtype)
236
  t_emb = self.t_embedder(t)
237
  if self.share_mod:
@@ -244,7 +269,7 @@ class SLatFlowModel(nn.Module):
244
  for block in self.input_blocks:
245
  h = block(h, t_emb)
246
  skips.append(h.feats)
247
-
248
  if self.pe_mode == "ape":
249
  h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
250
  for block in self.blocks:
 
26
  self.out_channels = out_channels or channels
27
  self.downsample = downsample
28
  self.upsample = upsample
29
+
30
+ assert not (
31
+ downsample and upsample
32
+ ), "Cannot downsample and upsample at the same time"
33
 
34
  self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
35
  self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
36
  self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
37
+ self.conv2 = zero_module(
38
+ sp.SparseConv3d(self.out_channels, self.out_channels, 3)
39
+ )
40
  self.emb_layers = nn.Sequential(
41
  nn.SiLU(),
42
  nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
43
  )
44
+ self.skip_connection = (
45
+ sp.SparseLinear(channels, self.out_channels)
46
+ if channels != self.out_channels
47
+ else nn.Identity()
48
+ )
49
  self.updown = None
50
  if self.downsample:
51
  self.updown = sp.SparseDownsample(2)
 
71
  h = h + self.skip_connection(x)
72
 
73
  return h
74
+
75
 
76
  class SLatFlowModel(nn.Module):
77
  def __init__(
 
117
  self.qk_rms_norm_cross = qk_rms_norm_cross
118
  self.dtype = torch.float16 if use_fp16 else torch.float32
119
 
120
+ assert int(np.log2(patch_size)) == np.log2(
121
+ patch_size
122
+ ), "Patch size must be a power of 2"
123
+ assert np.log2(patch_size) == len(
124
+ io_block_channels
125
+ ), "Number of IO ResBlocks must match the number of stages"
126
 
127
  self.t_embedder = TimestepEmbedder(model_channels)
128
  if share_mod:
129
  self.adaLN_modulation = nn.Sequential(
130
+ nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True)
 
131
  )
132
 
133
  if pe_mode == "ape":
 
135
 
136
  self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
137
  self.input_blocks = nn.ModuleList([])
138
+ for chs, next_chs in zip(
139
+ io_block_channels, io_block_channels[1:] + [model_channels]
140
+ ):
141
+ self.input_blocks.extend(
142
+ [
143
+ SparseResBlock3d(
144
+ chs,
145
+ model_channels,
146
+ out_channels=chs,
147
+ )
148
+ for _ in range(num_io_res_blocks - 1)
149
+ ]
150
+ )
151
  self.input_blocks.append(
152
  SparseResBlock3d(
153
  chs,
 
156
  downsample=True,
157
  )
158
  )
159
+
160
+ self.blocks = nn.ModuleList(
161
+ [
162
+ ModulatedSparseTransformerCrossBlock(
163
+ model_channels,
164
+ cond_channels,
165
+ num_heads=self.num_heads,
166
+ mlp_ratio=self.mlp_ratio,
167
+ attn_mode="full",
168
+ use_checkpoint=self.use_checkpoint,
169
+ use_rope=(pe_mode == "rope"),
170
+ share_mod=self.share_mod,
171
+ qk_rms_norm=self.qk_rms_norm,
172
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
173
+ )
174
+ for _ in range(num_blocks)
175
+ ]
176
+ )
177
 
178
  self.out_blocks = nn.ModuleList([])
179
+ for chs, prev_chs in zip(
180
+ reversed(io_block_channels),
181
+ [model_channels] + list(reversed(io_block_channels[1:])),
182
+ ):
183
  self.out_blocks.append(
184
  SparseResBlock3d(
185
  prev_chs * 2 if self.use_skip_connection else prev_chs,
 
188
  upsample=True,
189
  )
190
  )
191
+ self.out_blocks.extend(
192
+ [
193
+ SparseResBlock3d(
194
+ chs * 2 if self.use_skip_connection else chs,
195
+ model_channels,
196
+ out_channels=chs,
197
+ )
198
+ for _ in range(num_io_res_blocks - 1)
199
+ ]
200
+ )
201
  self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
202
 
203
  self.initialize_weights()
 
234
  torch.nn.init.xavier_uniform_(module.weight)
235
  if module.bias is not None:
236
  nn.init.constant_(module.bias, 0)
237
+
238
  self.apply(_basic_init)
239
 
240
  # Initialize timestep embedding MLP:
 
254
  nn.init.constant_(self.out_layer.weight, 0)
255
  nn.init.constant_(self.out_layer.bias, 0)
256
 
257
+ def forward(
258
+ self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor
259
+ ) -> sp.SparseTensor:
260
  h = self.input_layer(x).type(self.dtype)
261
  t_emb = self.t_embedder(t)
262
  if self.share_mod:
 
269
  for block in self.input_blocks:
270
  h = block(h, t_emb)
271
  skips.append(h.feats)
272
+
273
  if self.pe_mode == "ape":
274
  h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
275
  for block in self.blocks:
trellis/models/structured_latent_vae/base.py CHANGED
@@ -13,15 +13,23 @@ def block_attn_config(self):
13
  """
14
  for i in range(self.num_blocks):
15
  if self.attn_mode == "shift_window":
16
- yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
17
  elif self.attn_mode == "shift_sequence":
18
- yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
 
 
19
  elif self.attn_mode == "shift_order":
20
  yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
21
  elif self.attn_mode == "full":
22
  yield "full", None, None, None, None
23
  elif self.attn_mode == "swin":
24
- yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
25
 
26
 
27
  class SparseTransformerBase(nn.Module):
@@ -29,6 +37,7 @@ class SparseTransformerBase(nn.Module):
29
  Sparse Transformer without output layers.
30
  Serve as the base class for encoder and decoder.
31
  """
 
32
  def __init__(
33
  self,
34
  in_channels: int,
@@ -37,7 +46,9 @@ class SparseTransformerBase(nn.Module):
37
  num_heads: Optional[int] = None,
38
  num_head_channels: Optional[int] = 64,
39
  mlp_ratio: float = 4.0,
40
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
41
  window_size: Optional[int] = None,
42
  pe_mode: Literal["ape", "rope"] = "ape",
43
  use_fp16: bool = False,
@@ -62,22 +73,26 @@ class SparseTransformerBase(nn.Module):
62
  self.pos_embedder = AbsolutePositionEmbedder(model_channels)
63
 
64
  self.input_layer = sp.SparseLinear(in_channels, model_channels)
65
- self.blocks = nn.ModuleList([
66
- SparseTransformerBlock(
67
- model_channels,
68
- num_heads=self.num_heads,
69
- mlp_ratio=self.mlp_ratio,
70
- attn_mode=attn_mode,
71
- window_size=window_size,
72
- shift_sequence=shift_sequence,
73
- shift_window=shift_window,
74
- serialize_mode=serialize_mode,
75
- use_checkpoint=self.use_checkpoint,
76
- use_rope=(pe_mode == "rope"),
77
- qk_rms_norm=self.qk_rms_norm,
78
- )
79
- for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
80
- ])
 
 
81
 
82
  @property
83
  def device(self) -> torch.device:
@@ -105,6 +120,7 @@ class SparseTransformerBase(nn.Module):
105
  torch.nn.init.xavier_uniform_(module.weight)
106
  if module.bias is not None:
107
  nn.init.constant_(module.bias, 0)
 
108
  self.apply(_basic_init)
109
 
110
  def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
 
13
  """
14
  for i in range(self.num_blocks):
15
  if self.attn_mode == "shift_window":
16
+ yield "serialized", self.window_size, 0, (
17
+ 16 * (i % 2),
18
+ ) * 3, sp.SerializeMode.Z_ORDER
19
  elif self.attn_mode == "shift_sequence":
20
+ yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (
21
+ 0,
22
+ 0,
23
+ 0,
24
+ ), sp.SerializeMode.Z_ORDER
25
  elif self.attn_mode == "shift_order":
26
  yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
27
  elif self.attn_mode == "full":
28
  yield "full", None, None, None, None
29
  elif self.attn_mode == "swin":
30
+ yield "windowed", self.window_size, None, self.window_size // 2 * (
31
+ i % 2
32
+ ), None
33
 
34
 
35
  class SparseTransformerBase(nn.Module):
 
37
  Sparse Transformer without output layers.
38
  Serve as the base class for encoder and decoder.
39
  """
40
+
41
  def __init__(
42
  self,
43
  in_channels: int,
 
46
  num_heads: Optional[int] = None,
47
  num_head_channels: Optional[int] = 64,
48
  mlp_ratio: float = 4.0,
49
+ attn_mode: Literal[
50
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
51
+ ] = "full",
52
  window_size: Optional[int] = None,
53
  pe_mode: Literal["ape", "rope"] = "ape",
54
  use_fp16: bool = False,
 
73
  self.pos_embedder = AbsolutePositionEmbedder(model_channels)
74
 
75
  self.input_layer = sp.SparseLinear(in_channels, model_channels)
76
+ self.blocks = nn.ModuleList(
77
+ [
78
+ SparseTransformerBlock(
79
+ model_channels,
80
+ num_heads=self.num_heads,
81
+ mlp_ratio=self.mlp_ratio,
82
+ attn_mode=attn_mode,
83
+ window_size=window_size,
84
+ shift_sequence=shift_sequence,
85
+ shift_window=shift_window,
86
+ serialize_mode=serialize_mode,
87
+ use_checkpoint=self.use_checkpoint,
88
+ use_rope=(pe_mode == "rope"),
89
+ qk_rms_norm=self.qk_rms_norm,
90
+ )
91
+ for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(
92
+ self
93
+ )
94
+ ]
95
+ )
96
 
97
  @property
98
  def device(self) -> torch.device:
 
120
  torch.nn.init.xavier_uniform_(module.weight)
121
  if module.bias is not None:
122
  nn.init.constant_(module.bias, 0)
123
+
124
  self.apply(_basic_init)
125
 
126
  def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
trellis/models/structured_latent_vae/decoder_gs.py CHANGED
@@ -18,7 +18,9 @@ class SLatGaussianDecoder(SparseTransformerBase):
18
  num_heads: Optional[int] = None,
19
  num_head_channels: Optional[int] = 64,
20
  mlp_ratio: float = 4,
21
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
22
  window_size: int = 8,
23
  pe_mode: Literal["ape", "rope"] = "ape",
24
  use_fp16: bool = False,
@@ -57,26 +59,44 @@ class SLatGaussianDecoder(SparseTransformerBase):
57
  nn.init.constant_(self.out_layer.bias, 0)
58
 
59
  def _build_perturbation(self) -> None:
60
- perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
 
 
61
  perturbation = torch.tensor(perturbation).float() * 2 - 1
62
- perturbation = perturbation / self.rep_config['voxel_size']
63
  perturbation = torch.atanh(perturbation).to(self.device)
64
- self.register_buffer('offset_perturbation', perturbation)
65
 
66
  def _calc_layout(self) -> None:
67
  self.layout = {
68
- '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
69
- '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
70
- '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
71
- '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
72
- '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
 
 
73
  }
74
  start = 0
75
  for k, v in self.layout.items():
76
- v['range'] = (start, start + v['size'])
77
- start += v['size']
78
  self.out_channels = start
79
-
80
  def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
81
  """
82
  Convert a batch of network outputs to 3D representations.
@@ -92,24 +112,35 @@ class SLatGaussianDecoder(SparseTransformerBase):
92
  representation = Gaussian(
93
  sh_degree=0,
94
  aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
95
- mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
96
- scaling_bias = self.rep_config['scaling_bias'],
97
- opacity_bias = self.rep_config['opacity_bias'],
98
- scaling_activation = self.rep_config['scaling_activation']
99
  )
100
  xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
101
  for k, v in self.layout.items():
102
- if k == '_xyz':
103
- offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
104
- offset = offset * self.rep_config['lr'][k]
105
- if self.rep_config['perturb_offset']:
106
  offset = offset + self.offset_perturbation
107
- offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
 
 
108
  _xyz = xyz.unsqueeze(1) + offset
109
  setattr(representation, k, _xyz.flatten(0, 1))
110
  else:
111
- feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
112
- feats = feats * self.rep_config['lr'][k]
 
 
113
  setattr(representation, k, feats)
114
  ret.append(representation)
115
  return ret
 
18
  num_heads: Optional[int] = None,
19
  num_head_channels: Optional[int] = 64,
20
  mlp_ratio: float = 4,
21
+ attn_mode: Literal[
22
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
23
+ ] = "swin",
24
  window_size: int = 8,
25
  pe_mode: Literal["ape", "rope"] = "ape",
26
  use_fp16: bool = False,
 
59
  nn.init.constant_(self.out_layer.bias, 0)
60
 
61
  def _build_perturbation(self) -> None:
62
+ perturbation = [
63
+ hammersley_sequence(3, i, self.rep_config["num_gaussians"])
64
+ for i in range(self.rep_config["num_gaussians"])
65
+ ]
66
  perturbation = torch.tensor(perturbation).float() * 2 - 1
67
+ perturbation = perturbation / self.rep_config["voxel_size"]
68
  perturbation = torch.atanh(perturbation).to(self.device)
69
+ self.register_buffer("offset_perturbation", perturbation)
70
 
71
  def _calc_layout(self) -> None:
72
  self.layout = {
73
+ "_xyz": {
74
+ "shape": (self.rep_config["num_gaussians"], 3),
75
+ "size": self.rep_config["num_gaussians"] * 3,
76
+ },
77
+ "_features_dc": {
78
+ "shape": (self.rep_config["num_gaussians"], 1, 3),
79
+ "size": self.rep_config["num_gaussians"] * 3,
80
+ },
81
+ "_scaling": {
82
+ "shape": (self.rep_config["num_gaussians"], 3),
83
+ "size": self.rep_config["num_gaussians"] * 3,
84
+ },
85
+ "_rotation": {
86
+ "shape": (self.rep_config["num_gaussians"], 4),
87
+ "size": self.rep_config["num_gaussians"] * 4,
88
+ },
89
+ "_opacity": {
90
+ "shape": (self.rep_config["num_gaussians"], 1),
91
+ "size": self.rep_config["num_gaussians"],
92
+ },
93
  }
94
  start = 0
95
  for k, v in self.layout.items():
96
+ v["range"] = (start, start + v["size"])
97
+ start += v["size"]
98
  self.out_channels = start
99
+
100
  def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
101
  """
102
  Convert a batch of network outputs to 3D representations.
 
112
  representation = Gaussian(
113
  sh_degree=0,
114
  aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
115
+ mininum_kernel_size=self.rep_config["3d_filter_kernel_size"],
116
+ scaling_bias=self.rep_config["scaling_bias"],
117
+ opacity_bias=self.rep_config["opacity_bias"],
118
+ scaling_activation=self.rep_config["scaling_activation"],
119
  )
120
  xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
121
  for k, v in self.layout.items():
122
+ if k == "_xyz":
123
+ offset = x.feats[x.layout[i]][
124
+ :, v["range"][0] : v["range"][1]
125
+ ].reshape(-1, *v["shape"])
126
+ offset = offset * self.rep_config["lr"][k]
127
+ if self.rep_config["perturb_offset"]:
128
  offset = offset + self.offset_perturbation
129
+ offset = (
130
+ torch.tanh(offset)
131
+ / self.resolution
132
+ * 0.5
133
+ * self.rep_config["voxel_size"]
134
+ )
135
  _xyz = xyz.unsqueeze(1) + offset
136
  setattr(representation, k, _xyz.flatten(0, 1))
137
  else:
138
+ feats = (
139
+ x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]]
140
+ .reshape(-1, *v["shape"])
141
+ .flatten(0, 1)
142
+ )
143
+ feats = feats * self.rep_config["lr"][k]
144
  setattr(representation, k, feats)
145
  ret.append(representation)
146
  return ret
trellis/models/structured_latent_vae/decoder_mesh.py CHANGED
@@ -19,12 +19,13 @@ class SparseSubdivideBlock3d(nn.Module):
19
  out_channels: if specified, the number of output channels.
20
  num_groups: the number of groups for the group norm.
21
  """
 
22
  def __init__(
23
  self,
24
  channels: int,
25
  resolution: int,
26
  out_channels: Optional[int] = None,
27
- num_groups: int = 32
28
  ):
29
  super().__init__()
30
  self.channels = channels
@@ -33,24 +34,34 @@ class SparseSubdivideBlock3d(nn.Module):
33
  self.out_channels = out_channels or channels
34
 
35
  self.act_layers = nn.Sequential(
36
- sp.SparseGroupNorm32(num_groups, channels),
37
- sp.SparseSiLU()
38
  )
39
-
40
  self.sub = sp.SparseSubdivide()
41
-
42
  self.out_layers = nn.Sequential(
43
- sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
44
  sp.SparseGroupNorm32(num_groups, self.out_channels),
45
  sp.SparseSiLU(),
46
- zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
 
 
47
  )
48
-
49
  if self.out_channels == channels:
50
  self.skip_connection = nn.Identity()
51
  else:
52
- self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
53
-
54
  def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
55
  """
56
  Apply the block to a Tensor, conditioned on a timestep embedding.
@@ -78,7 +89,9 @@ class SLatMeshDecoder(SparseTransformerBase):
78
  num_heads: Optional[int] = None,
79
  num_head_channels: Optional[int] = 64,
80
  mlp_ratio: float = 4,
81
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
82
  window_size: int = 8,
83
  pe_mode: Literal["ape", "rope"] = "ape",
84
  use_fp16: bool = False,
@@ -102,20 +115,24 @@ class SLatMeshDecoder(SparseTransformerBase):
102
  )
103
  self.resolution = resolution
104
  self.rep_config = representation_config
105
- self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
106
  self.out_channels = self.mesh_extractor.feats_channels
107
- self.upsample = nn.ModuleList([
108
- SparseSubdivideBlock3d(
109
- channels=model_channels,
110
- resolution=resolution,
111
- out_channels=model_channels // 4
112
- ),
113
- SparseSubdivideBlock3d(
114
- channels=model_channels // 4,
115
- resolution=resolution * 2,
116
- out_channels=model_channels // 8
117
- )
118
- ])
119
  self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
120
 
121
  self.initialize_weights()
@@ -140,8 +157,8 @@ class SLatMeshDecoder(SparseTransformerBase):
140
  Convert the torso of the model to float32.
141
  """
142
  super().convert_to_fp32()
143
- self.upsample.apply(convert_module_to_f32)
144
-
145
  def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
146
  """
147
  Convert a batch of network outputs to 3D representations.
 
19
  out_channels: if specified, the number of output channels.
20
  num_groups: the number of groups for the group norm.
21
  """
22
+
23
  def __init__(
24
  self,
25
  channels: int,
26
  resolution: int,
27
  out_channels: Optional[int] = None,
28
+ num_groups: int = 32,
29
  ):
30
  super().__init__()
31
  self.channels = channels
 
34
  self.out_channels = out_channels or channels
35
 
36
  self.act_layers = nn.Sequential(
37
+ sp.SparseGroupNorm32(num_groups, channels), sp.SparseSiLU()
 
38
  )
39
+
40
  self.sub = sp.SparseSubdivide()
41
+
42
  self.out_layers = nn.Sequential(
43
+ sp.SparseConv3d(
44
+ channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"
45
+ ),
46
  sp.SparseGroupNorm32(num_groups, self.out_channels),
47
  sp.SparseSiLU(),
48
+ zero_module(
49
+ sp.SparseConv3d(
50
+ self.out_channels,
51
+ self.out_channels,
52
+ 3,
53
+ indice_key=f"res_{self.out_resolution}",
54
+ )
55
+ ),
56
  )
57
+
58
  if self.out_channels == channels:
59
  self.skip_connection = nn.Identity()
60
  else:
61
+ self.skip_connection = sp.SparseConv3d(
62
+ channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}"
63
+ )
64
+
65
  def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
66
  """
67
  Apply the block to a Tensor, conditioned on a timestep embedding.
 
89
  num_heads: Optional[int] = None,
90
  num_head_channels: Optional[int] = 64,
91
  mlp_ratio: float = 4,
92
+ attn_mode: Literal[
93
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
94
+ ] = "swin",
95
  window_size: int = 8,
96
  pe_mode: Literal["ape", "rope"] = "ape",
97
  use_fp16: bool = False,
 
115
  )
116
  self.resolution = resolution
117
  self.rep_config = representation_config
118
+ self.mesh_extractor = SparseFeatures2Mesh(
119
+ res=self.resolution * 4, use_color=self.rep_config.get("use_color", False)
120
+ )
121
  self.out_channels = self.mesh_extractor.feats_channels
122
+ self.upsample = nn.ModuleList(
123
+ [
124
+ SparseSubdivideBlock3d(
125
+ channels=model_channels,
126
+ resolution=resolution,
127
+ out_channels=model_channels // 4,
128
+ ),
129
+ SparseSubdivideBlock3d(
130
+ channels=model_channels // 4,
131
+ resolution=resolution * 2,
132
+ out_channels=model_channels // 8,
133
+ ),
134
+ ]
135
+ )
136
  self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
137
 
138
  self.initialize_weights()
 
157
  Convert the torso of the model to float32.
158
  """
159
  super().convert_to_fp32()
160
+ self.upsample.apply(convert_module_to_f32)
161
+
162
  def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
163
  """
164
  Convert a batch of network outputs to 3D representations.
trellis/models/structured_latent_vae/decoder_rf.py CHANGED
@@ -18,7 +18,9 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):
18
  num_heads: Optional[int] = None,
19
  num_head_channels: Optional[int] = 64,
20
  mlp_ratio: float = 4,
21
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
22
  window_size: int = 8,
23
  pe_mode: Literal["ape", "rope"] = "ape",
24
  use_fp16: bool = False,
@@ -57,16 +59,25 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):
57
 
58
  def _calc_layout(self) -> None:
59
  self.layout = {
60
- 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
61
- 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
62
- 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
 
 
63
  }
64
  start = 0
65
  for k, v in self.layout.items():
66
- v['range'] = (start, start + v['size'])
67
- start += v['size']
68
- self.out_channels = start
69
-
70
  def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
71
  """
72
  Convert a batch of network outputs to 3D representations.
@@ -83,15 +94,28 @@ class SLatRadianceFieldDecoder(SparseTransformerBase):
83
  sh_degree=0,
84
  resolution=self.resolution,
85
  aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
86
- rank=self.rep_config['rank'],
87
- dim=self.rep_config['dim'],
88
- device='cuda',
89
  )
90
  representation.density_shift = 0.0
91
- representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
92
- representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
 
 
93
  for k, v in self.layout.items():
94
- setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
 
 
95
  representation.trivec = representation.trivec + 1
96
  ret.append(representation)
97
  return ret
 
18
  num_heads: Optional[int] = None,
19
  num_head_channels: Optional[int] = 64,
20
  mlp_ratio: float = 4,
21
+ attn_mode: Literal[
22
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
23
+ ] = "swin",
24
  window_size: int = 8,
25
  pe_mode: Literal["ape", "rope"] = "ape",
26
  use_fp16: bool = False,
 
59
 
60
  def _calc_layout(self) -> None:
61
  self.layout = {
62
+ "trivec": {
63
+ "shape": (self.rep_config["rank"], 3, self.rep_config["dim"]),
64
+ "size": self.rep_config["rank"] * 3 * self.rep_config["dim"],
65
+ },
66
+ "density": {
67
+ "shape": (self.rep_config["rank"],),
68
+ "size": self.rep_config["rank"],
69
+ },
70
+ "features_dc": {
71
+ "shape": (self.rep_config["rank"], 1, 3),
72
+ "size": self.rep_config["rank"] * 3,
73
+ },
74
  }
75
  start = 0
76
  for k, v in self.layout.items():
77
+ v["range"] = (start, start + v["size"])
78
+ start += v["size"]
79
+ self.out_channels = start
80
+
81
  def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
82
  """
83
  Convert a batch of network outputs to 3D representations.
 
94
  sh_degree=0,
95
  resolution=self.resolution,
96
  aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
97
+ rank=self.rep_config["rank"],
98
+ dim=self.rep_config["dim"],
99
+ device="cuda",
100
  )
101
  representation.density_shift = 0.0
102
+ representation.position = (
103
+ x.coords[x.layout[i]][:, 1:].float() + 0.5
104
+ ) / self.resolution
105
+ representation.depth = torch.full(
106
+ (representation.position.shape[0], 1),
107
+ int(np.log2(self.resolution)),
108
+ dtype=torch.uint8,
109
+ device="cuda",
110
+ )
111
  for k, v in self.layout.items():
112
+ setattr(
113
+ representation,
114
+ k,
115
+ x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]].reshape(
116
+ -1, *v["shape"]
117
+ ),
118
+ )
119
  representation.trivec = representation.trivec + 1
120
  ret.append(representation)
121
  return ret
trellis/models/structured_latent_vae/encoder.py CHANGED
@@ -17,7 +17,9 @@ class SLatEncoder(SparseTransformerBase):
17
  num_heads: Optional[int] = None,
18
  num_head_channels: Optional[int] = 64,
19
  mlp_ratio: float = 4,
20
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
21
  window_size: int = 8,
22
  pe_mode: Literal["ape", "rope"] = "ape",
23
  use_fp16: bool = False,
@@ -56,7 +58,7 @@ class SLatEncoder(SparseTransformerBase):
56
  h = h.type(x.dtype)
57
  h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
58
  h = self.out_layer(h)
59
-
60
  # Sample from the posterior distribution
61
  mean, logvar = h.feats.chunk(2, dim=-1)
62
  if sample_posterior:
@@ -65,7 +67,7 @@ class SLatEncoder(SparseTransformerBase):
65
  else:
66
  z = mean
67
  z = h.replace(z)
68
-
69
  if return_raw:
70
  return z, mean, logvar
71
  else:
 
17
  num_heads: Optional[int] = None,
18
  num_head_channels: Optional[int] = 64,
19
  mlp_ratio: float = 4,
20
+ attn_mode: Literal[
21
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
22
+ ] = "swin",
23
  window_size: int = 8,
24
  pe_mode: Literal["ape", "rope"] = "ape",
25
  use_fp16: bool = False,
 
58
  h = h.type(x.dtype)
59
  h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
60
  h = self.out_layer(h)
61
+
62
  # Sample from the posterior distribution
63
  mean, logvar = h.feats.chunk(2, dim=-1)
64
  if sample_posterior:
 
67
  else:
68
  z = mean
69
  z = h.replace(z)
70
+
71
  if return_raw:
72
  return z, mean, logvar
73
  else:
trellis/modules/attention/__init__.py CHANGED
@@ -1,32 +1,39 @@
1
  from typing import *
2
 
3
- BACKEND = 'flash_attn'
4
  DEBUG = False
5
 
 
6
  def __from_env():
7
  import os
8
-
9
  global BACKEND
10
  global DEBUG
11
-
12
- env_attn_backend = os.environ.get('ATTN_BACKEND')
13
- env_sttn_debug = os.environ.get('ATTN_DEBUG')
14
-
15
- if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
 
 
16
  BACKEND = env_attn_backend
17
  if env_sttn_debug is not None:
18
- DEBUG = env_sttn_debug == '1'
19
 
20
  print(f"[ATTENTION] Using backend: {BACKEND}")
21
-
22
 
23
  __from_env()
24
-
25
 
26
- def set_backend(backend: Literal['xformers', 'flash_attn']):
 
27
  global BACKEND
28
  BACKEND = backend
29
 
 
30
  def set_debug(debug: bool):
31
  global DEBUG
32
  DEBUG = debug
 
1
  from typing import *
2
 
3
+ BACKEND = "flash_attn"
4
  DEBUG = False
5
 
6
+
7
  def __from_env():
8
  import os
9
+
10
  global BACKEND
11
  global DEBUG
12
+
13
+ env_attn_backend = os.environ.get("ATTN_BACKEND")
14
+ env_sttn_debug = os.environ.get("ATTN_DEBUG")
15
+
16
+ if env_attn_backend is not None and env_attn_backend in [
17
+ "xformers",
18
+ "flash_attn",
19
+ "sdpa",
20
+ "naive",
21
+ ]:
22
  BACKEND = env_attn_backend
23
  if env_sttn_debug is not None:
24
+ DEBUG = env_sttn_debug == "1"
25
 
26
  print(f"[ATTENTION] Using backend: {BACKEND}")
27
+
28
 
29
  __from_env()
 
30
 
31
+
32
+ def set_backend(backend: Literal["xformers", "flash_attn"]):
33
  global BACKEND
34
  BACKEND = backend
35
 
36
+
37
  def set_debug(debug: bool):
38
  global DEBUG
39
  DEBUG = debug
trellis/modules/attention/full_attn.py CHANGED
@@ -3,20 +3,20 @@ import torch
3
  import math
4
  from . import DEBUG, BACKEND
5
 
6
- if BACKEND == 'xformers':
7
  import xformers.ops as xops
8
- elif BACKEND == 'flash_attn':
9
  import flash_attn
10
- elif BACKEND == 'sdpa':
11
  from torch.nn.functional import scaled_dot_product_attention as sdpa
12
- elif BACKEND == 'naive':
13
  pass
14
  else:
15
  raise ValueError(f"Unknown attention backend: {BACKEND}")
16
 
17
 
18
  __all__ = [
19
- 'scaled_dot_product_attention',
20
  ]
21
 
22
 
@@ -24,14 +24,14 @@ def _naive_sdpa(q, k, v):
24
  """
25
  Naive implementation of scaled dot product attention.
26
  """
27
- q = q.permute(0, 2, 1, 3) # [N, H, L, C]
28
- k = k.permute(0, 2, 1, 3) # [N, H, L, C]
29
- v = v.permute(0, 2, 1, 3) # [N, H, L, C]
30
  scale_factor = 1 / math.sqrt(q.size(-1))
31
  attn_weight = q @ k.transpose(-2, -1) * scale_factor
32
  attn_weight = torch.softmax(attn_weight, dim=-1)
33
  out = attn_weight @ v
34
- out = out.permute(0, 2, 1, 3) # [N, L, H, C]
35
  return out
36
 
37
 
@@ -45,6 +45,7 @@ def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
45
  """
46
  ...
47
 
 
48
  @overload
49
  def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
50
  """
@@ -56,8 +57,11 @@ def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Ten
56
  """
57
  ...
58
 
 
59
  @overload
60
- def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
61
  """
62
  Apply scaled dot product attention.
63
 
@@ -71,64 +75,79 @@ def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tens
71
  """
72
  ...
73
 
 
74
  def scaled_dot_product_attention(*args, **kwargs):
75
- arg_names_dict = {
76
- 1: ['qkv'],
77
- 2: ['q', 'kv'],
78
- 3: ['q', 'k', 'v']
79
- }
80
  num_all_args = len(args) + len(kwargs)
81
- assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
82
- for key in arg_names_dict[num_all_args][len(args):]:
83
  assert key in kwargs, f"Missing argument {key}"
84
 
85
  if num_all_args == 1:
86
- qkv = args[0] if len(args) > 0 else kwargs['qkv']
87
- assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
88
  device = qkv.device
89
 
90
  elif num_all_args == 2:
91
- q = args[0] if len(args) > 0 else kwargs['q']
92
- kv = args[1] if len(args) > 1 else kwargs['kv']
93
- assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
94
- assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
95
- assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
 
 
96
  device = q.device
97
 
98
  elif num_all_args == 3:
99
- q = args[0] if len(args) > 0 else kwargs['q']
100
- k = args[1] if len(args) > 1 else kwargs['k']
101
- v = args[2] if len(args) > 2 else kwargs['v']
102
- assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
103
- assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
104
- assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
105
- assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
106
- device = q.device
107
-
108
- if BACKEND == 'xformers':
 
 
109
  if num_all_args == 1:
110
  q, k, v = qkv.unbind(dim=2)
111
  elif num_all_args == 2:
112
  k, v = kv.unbind(dim=2)
113
  out = xops.memory_efficient_attention(q, k, v)
114
- elif BACKEND == 'flash_attn':
115
  if num_all_args == 1:
116
  out = flash_attn.flash_attn_qkvpacked_func(qkv)
117
  elif num_all_args == 2:
118
  out = flash_attn.flash_attn_kvpacked_func(q, kv)
119
  elif num_all_args == 3:
120
  out = flash_attn.flash_attn_func(q, k, v)
121
- elif BACKEND == 'sdpa':
122
  if num_all_args == 1:
123
  q, k, v = qkv.unbind(dim=2)
124
  elif num_all_args == 2:
125
  k, v = kv.unbind(dim=2)
126
- q = q.permute(0, 2, 1, 3) # [N, H, L, C]
127
- k = k.permute(0, 2, 1, 3) # [N, H, L, C]
128
- v = v.permute(0, 2, 1, 3) # [N, H, L, C]
129
- out = sdpa(q, k, v) # [N, H, L, C]
130
- out = out.permute(0, 2, 1, 3) # [N, L, H, C]
131
- elif BACKEND == 'naive':
132
  if num_all_args == 1:
133
  q, k, v = qkv.unbind(dim=2)
134
  elif num_all_args == 2:
@@ -136,5 +155,5 @@ def scaled_dot_product_attention(*args, **kwargs):
136
  out = _naive_sdpa(q, k, v)
137
  else:
138
  raise ValueError(f"Unknown attention module: {BACKEND}")
139
-
140
  return out
 
3
  import math
4
  from . import DEBUG, BACKEND
5
 
6
+ if BACKEND == "xformers":
7
  import xformers.ops as xops
8
+ elif BACKEND == "flash_attn":
9
  import flash_attn
10
+ elif BACKEND == "sdpa":
11
  from torch.nn.functional import scaled_dot_product_attention as sdpa
12
+ elif BACKEND == "naive":
13
  pass
14
  else:
15
  raise ValueError(f"Unknown attention backend: {BACKEND}")
16
 
17
 
18
  __all__ = [
19
+ "scaled_dot_product_attention",
20
  ]
21
 
22
 
 
24
  """
25
  Naive implementation of scaled dot product attention.
26
  """
27
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
28
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
29
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
30
  scale_factor = 1 / math.sqrt(q.size(-1))
31
  attn_weight = q @ k.transpose(-2, -1) * scale_factor
32
  attn_weight = torch.softmax(attn_weight, dim=-1)
33
  out = attn_weight @ v
34
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
35
  return out
36
 
37
 
 
45
  """
46
  ...
47
 
48
+
49
  @overload
50
  def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
51
  """
 
57
  """
58
  ...
59
 
60
+
61
  @overload
62
+ def scaled_dot_product_attention(
63
+ q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
64
+ ) -> torch.Tensor:
65
  """
66
  Apply scaled dot product attention.
67
 
 
75
  """
76
  ...
77
 
78
+
79
  def scaled_dot_product_attention(*args, **kwargs):
80
+ arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]}
 
 
81
  num_all_args = len(args) + len(kwargs)
82
+ assert (
83
+ num_all_args in arg_names_dict
84
+ ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
85
+ for key in arg_names_dict[num_all_args][len(args) :]:
86
  assert key in kwargs, f"Missing argument {key}"
87
 
88
  if num_all_args == 1:
89
+ qkv = args[0] if len(args) > 0 else kwargs["qkv"]
90
+ assert (
91
+ len(qkv.shape) == 5 and qkv.shape[2] == 3
92
+ ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
93
  device = qkv.device
94
 
95
  elif num_all_args == 2:
96
+ q = args[0] if len(args) > 0 else kwargs["q"]
97
+ kv = args[1] if len(args) > 1 else kwargs["kv"]
98
+ assert (
99
+ q.shape[0] == kv.shape[0]
100
+ ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
101
+ assert (
102
+ len(q.shape) == 4
103
+ ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
104
+ assert (
105
+ len(kv.shape) == 5
106
+ ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
107
  device = q.device
108
 
109
  elif num_all_args == 3:
110
+ q = args[0] if len(args) > 0 else kwargs["q"]
111
+ k = args[1] if len(args) > 1 else kwargs["k"]
112
+ v = args[2] if len(args) > 2 else kwargs["v"]
113
+ assert (
114
+ q.shape[0] == k.shape[0] == v.shape[0]
115
+ ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
116
+ assert (
117
+ len(q.shape) == 4
118
+ ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
119
+ assert (
120
+ len(k.shape) == 4
121
+ ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
122
+ assert (
123
+ len(v.shape) == 4
124
+ ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
125
+ device = q.device
126
+
127
+ if BACKEND == "xformers":
128
  if num_all_args == 1:
129
  q, k, v = qkv.unbind(dim=2)
130
  elif num_all_args == 2:
131
  k, v = kv.unbind(dim=2)
132
  out = xops.memory_efficient_attention(q, k, v)
133
+ elif BACKEND == "flash_attn":
134
  if num_all_args == 1:
135
  out = flash_attn.flash_attn_qkvpacked_func(qkv)
136
  elif num_all_args == 2:
137
  out = flash_attn.flash_attn_kvpacked_func(q, kv)
138
  elif num_all_args == 3:
139
  out = flash_attn.flash_attn_func(q, k, v)
140
+ elif BACKEND == "sdpa":
141
  if num_all_args == 1:
142
  q, k, v = qkv.unbind(dim=2)
143
  elif num_all_args == 2:
144
  k, v = kv.unbind(dim=2)
145
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
146
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
147
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
148
+ out = sdpa(q, k, v) # [N, H, L, C]
149
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
150
+ elif BACKEND == "naive":
151
  if num_all_args == 1:
152
  q, k, v = qkv.unbind(dim=2)
153
  elif num_all_args == 2:
 
155
  out = _naive_sdpa(q, k, v)
156
  else:
157
  raise ValueError(f"Unknown attention module: {BACKEND}")
158
+
159
  return out
trellis/modules/attention/modules.py CHANGED
@@ -8,11 +8,11 @@ from .full_attn import scaled_dot_product_attention
8
  class MultiHeadRMSNorm(nn.Module):
9
  def __init__(self, dim: int, heads: int):
10
  super().__init__()
11
- self.scale = dim ** 0.5
12
  self.gamma = nn.Parameter(torch.ones(heads, dim))
13
 
14
  def forward(self, x: torch.Tensor) -> torch.Tensor:
15
- return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
16
 
17
 
18
  class RotaryPositionEmbedder(nn.Module):
@@ -23,21 +23,25 @@ class RotaryPositionEmbedder(nn.Module):
23
  self.in_channels = in_channels
24
  self.freq_dim = hidden_size // in_channels // 2
25
  self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
26
- self.freqs = 1.0 / (10000 ** self.freqs)
27
-
28
  def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
29
  self.freqs = self.freqs.to(indices.device)
30
  phases = torch.outer(indices, self.freqs)
31
  phases = torch.polar(torch.ones_like(phases), phases)
32
  return phases
33
-
34
  def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
35
  x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
36
  x_rotated = x_complex * phases
37
- x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
38
  return x_embed
39
-
40
- def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
41
  """
42
  Args:
43
  q (sp.SparseTensor): [..., N, D] tensor of queries
@@ -48,24 +52,38 @@ class RotaryPositionEmbedder(nn.Module):
48
  indices = torch.arange(q.shape[-2], device=q.device)
49
  if len(q.shape) > 2:
50
  indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
51
-
52
  phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
53
  if phases.shape[1] < self.hidden_size // 2:
54
- phases = torch.cat([phases, torch.polar(
55
- torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
56
- torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
57
- )], dim=-1)
 
 
58
  q_embed = self._rotary_embedding(q, phases)
59
  k_embed = self._rotary_embedding(k, phases)
60
  return q_embed, k_embed
61
-
62
 
63
  class MultiHeadAttention(nn.Module):
64
  def __init__(
65
  self,
66
  channels: int,
67
  num_heads: int,
68
- ctx_channels: Optional[int]=None,
69
  type: Literal["self", "cross"] = "self",
70
  attn_mode: Literal["full", "windowed"] = "full",
71
  window_size: Optional[int] = None,
@@ -78,11 +96,13 @@ class MultiHeadAttention(nn.Module):
78
  assert channels % num_heads == 0
79
  assert type in ["self", "cross"], f"Invalid attention type: {type}"
80
  assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
81
- assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
82
-
 
  if attn_mode == "windowed":
84
  raise NotImplementedError("Windowed attention is not yet implemented")
85
-
86
  self.channels = channels
87
  self.head_dim = channels // num_heads
88
  self.ctx_channels = ctx_channels if ctx_channels is not None else channels
@@ -99,17 +119,22 @@ class MultiHeadAttention(nn.Module):
99
  else:
100
  self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
101
  self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
102
-
103
  if self.qk_rms_norm:
104
  self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
105
  self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
106
-
107
  self.to_out = nn.Linear(channels, channels)
108
 
109
  if use_rope:
110
  self.rope = RotaryPositionEmbedder(channels)
111
-
112
- def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
 
 
113
  B, L, C = x.shape
114
  if self._type == "self":
115
  qkv = self.to_qkv(x)
 
8
  class MultiHeadRMSNorm(nn.Module):
9
  def __init__(self, dim: int, heads: int):
10
  super().__init__()
11
+ self.scale = dim**0.5
12
  self.gamma = nn.Parameter(torch.ones(heads, dim))
13
 
14
  def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return (F.normalize(x.float(), dim=-1) * self.gamma * self.scale).to(x.dtype)
16
 
17
 
18
  class RotaryPositionEmbedder(nn.Module):
 
23
  self.in_channels = in_channels
24
  self.freq_dim = hidden_size // in_channels // 2
25
  self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
26
+ self.freqs = 1.0 / (10000**self.freqs)
27
+
28
  def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
29
  self.freqs = self.freqs.to(indices.device)
30
  phases = torch.outer(indices, self.freqs)
31
  phases = torch.polar(torch.ones_like(phases), phases)
32
  return phases
33
+
34
  def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
35
  x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
36
  x_rotated = x_complex * phases
37
+ x_embed = (
38
+ torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
39
+ )
40
  return x_embed
41
+
42
+ def forward(
43
+ self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None
44
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
45
  """
46
  Args:
47
  q (sp.SparseTensor): [..., N, D] tensor of queries
 
52
  indices = torch.arange(q.shape[-2], device=q.device)
53
  if len(q.shape) > 2:
54
  indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
55
+
56
  phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
57
  if phases.shape[1] < self.hidden_size // 2:
58
+ phases = torch.cat(
59
+ [
60
+ phases,
61
+ torch.polar(
62
+ torch.ones(
63
+ *phases.shape[:-1],
64
+ self.hidden_size // 2 - phases.shape[1],
65
+ device=phases.device,
66
+ ),
67
+ torch.zeros(
68
+ *phases.shape[:-1],
69
+ self.hidden_size // 2 - phases.shape[1],
70
+ device=phases.device,
71
+ ),
72
+ ),
73
+ ],
74
+ dim=-1,
75
+ )
76
  q_embed = self._rotary_embedding(q, phases)
77
  k_embed = self._rotary_embedding(k, phases)
78
  return q_embed, k_embed
79
+
80
 
81
  class MultiHeadAttention(nn.Module):
82
  def __init__(
83
  self,
84
  channels: int,
85
  num_heads: int,
86
+ ctx_channels: Optional[int] = None,
87
  type: Literal["self", "cross"] = "self",
88
  attn_mode: Literal["full", "windowed"] = "full",
89
  window_size: Optional[int] = None,
 
96
  assert channels % num_heads == 0
97
  assert type in ["self", "cross"], f"Invalid attention type: {type}"
98
  assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
99
+ assert (
100
+ type == "self" or attn_mode == "full"
101
+ ), "Cross-attention only supports full attention"
102
+
103
  if attn_mode == "windowed":
104
  raise NotImplementedError("Windowed attention is not yet implemented")
105
+
106
  self.channels = channels
107
  self.head_dim = channels // num_heads
108
  self.ctx_channels = ctx_channels if ctx_channels is not None else channels
 
119
  else:
120
  self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
121
  self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
122
+
123
  if self.qk_rms_norm:
124
  self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
125
  self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
126
+
127
  self.to_out = nn.Linear(channels, channels)
128
 
129
  if use_rope:
130
  self.rope = RotaryPositionEmbedder(channels)
131
+
132
+ def forward(
133
+ self,
134
+ x: torch.Tensor,
135
+ context: Optional[torch.Tensor] = None,
136
+ indices: Optional[torch.Tensor] = None,
137
+ ) -> torch.Tensor:
138
  B, L, C = x.shape
139
  if self._type == "self":
140
  qkv = self.to_qkv(x)
trellis/modules/norm.py CHANGED
@@ -5,21 +5,21 @@ import torch.nn as nn
5
  class LayerNorm32(nn.LayerNorm):
6
  def forward(self, x: torch.Tensor) -> torch.Tensor:
7
  return super().forward(x.float()).type(x.dtype)
8
-
9
 
10
  class GroupNorm32(nn.GroupNorm):
11
  """
12
  A GroupNorm layer that converts to float32 before the forward pass.
13
  """
 
14
  def forward(self, x: torch.Tensor) -> torch.Tensor:
15
  return super().forward(x.float()).type(x.dtype)
16
-
17
-
18
  class ChannelLayerNorm32(LayerNorm32):
19
  def forward(self, x: torch.Tensor) -> torch.Tensor:
20
  DIM = x.dim()
21
  x = x.permute(0, *range(2, DIM), 1).contiguous()
22
  x = super().forward(x)
23
- x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
24
  return x
25
-
 
5
  class LayerNorm32(nn.LayerNorm):
6
  def forward(self, x: torch.Tensor) -> torch.Tensor:
7
  return super().forward(x.float()).type(x.dtype)
8
+
9
 
10
  class GroupNorm32(nn.GroupNorm):
11
  """
12
  A GroupNorm layer that converts to float32 before the forward pass.
13
  """
14
+
15
  def forward(self, x: torch.Tensor) -> torch.Tensor:
16
  return super().forward(x.float()).type(x.dtype)
17
+
18
+
19
  class ChannelLayerNorm32(LayerNorm32):
20
  def forward(self, x: torch.Tensor) -> torch.Tensor:
21
  DIM = x.dim()
22
  x = x.permute(0, *range(2, DIM), 1).contiguous()
23
  x = super().forward(x)
24
+ x = x.permute(0, DIM - 1, *range(1, DIM - 1)).contiguous()
25
  return x
 
trellis/modules/sparse/__init__.py CHANGED
@@ -1,81 +1,88 @@
1
  from typing import *
2
 
3
- BACKEND = 'spconv'
4
  DEBUG = False
5
- ATTN = 'flash_attn'
 
6
 
7
  def __from_env():
8
  import os
9
-
10
  global BACKEND
11
  global DEBUG
12
  global ATTN
13
-
14
- env_sparse_backend = os.environ.get('SPARSE_BACKEND')
15
- env_sparse_debug = os.environ.get('SPARSE_DEBUG')
16
- env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
17
  if env_sparse_attn is None:
18
- env_sparse_attn = os.environ.get('ATTN_BACKEND')
19
 
20
- if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
 
 
21
  BACKEND = env_sparse_backend
22
  if env_sparse_debug is not None:
23
- DEBUG = env_sparse_debug == '1'
24
- if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
25
  ATTN = env_sparse_attn
26
-
27
  print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
28
-
29
 
30
  __from_env()
31
-
32
 
33
- def set_backend(backend: Literal['spconv', 'torchsparse']):
 
34
  global BACKEND
35
  BACKEND = backend
36
 
 
37
  def set_debug(debug: bool):
38
  global DEBUG
39
  DEBUG = debug
40
 
41
- def set_attn(attn: Literal['xformers', 'flash_attn']):
 
42
  global ATTN
43
  ATTN = attn
44
-
45
-
46
  import importlib
47
 
48
  __attributes = {
49
- 'SparseTensor': 'basic',
50
- 'sparse_batch_broadcast': 'basic',
51
- 'sparse_batch_op': 'basic',
52
- 'sparse_cat': 'basic',
53
- 'sparse_unbind': 'basic',
54
- 'SparseGroupNorm': 'norm',
55
- 'SparseLayerNorm': 'norm',
56
- 'SparseGroupNorm32': 'norm',
57
- 'SparseLayerNorm32': 'norm',
58
- 'SparseReLU': 'nonlinearity',
59
- 'SparseSiLU': 'nonlinearity',
60
- 'SparseGELU': 'nonlinearity',
61
- 'SparseActivation': 'nonlinearity',
62
- 'SparseLinear': 'linear',
63
- 'sparse_scaled_dot_product_attention': 'attention',
64
- 'SerializeMode': 'attention',
65
- 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
66
- 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
67
- 'SparseMultiHeadAttention': 'attention',
68
- 'SparseConv3d': 'conv',
69
- 'SparseInverseConv3d': 'conv',
70
- 'SparseDownsample': 'spatial',
71
- 'SparseUpsample': 'spatial',
72
- 'SparseSubdivide' : 'spatial'
73
  }
74
 
75
- __submodules = ['transformer']
76
 
77
  __all__ = list(__attributes.keys()) + __submodules
78
 
 
79
  def __getattr__(name):
80
  if name not in globals():
81
  if name in __attributes:
@@ -91,7 +98,7 @@ def __getattr__(name):
91
 
92
 
93
  # For Pylance
94
- if __name__ == '__main__':
95
  from .basic import *
96
  from .norm import *
97
  from .nonlinearity import *
 
1
  from typing import *
2
 
3
+ BACKEND = "spconv"
4
  DEBUG = False
5
+ ATTN = "flash_attn"
6
+
7
 
8
  def __from_env():
9
  import os
10
+
11
  global BACKEND
12
  global DEBUG
13
  global ATTN
14
+
15
+ env_sparse_backend = os.environ.get("SPARSE_BACKEND")
16
+ env_sparse_debug = os.environ.get("SPARSE_DEBUG")
17
+ env_sparse_attn = os.environ.get("SPARSE_ATTN_BACKEND")
18
  if env_sparse_attn is None:
19
+ env_sparse_attn = os.environ.get("ATTN_BACKEND")
20
 
21
+ if env_sparse_backend is not None and env_sparse_backend in [
22
+ "spconv",
23
+ "torchsparse",
24
+ ]:
25
  BACKEND = env_sparse_backend
26
  if env_sparse_debug is not None:
27
+ DEBUG = env_sparse_debug == "1"
28
+ if env_sparse_attn is not None and env_sparse_attn in ["xformers", "flash_attn"]:
29
  ATTN = env_sparse_attn
30
+
31
  print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
32
+
33
 
34
  __from_env()
 
35
 
36
+
37
+ def set_backend(backend: Literal["spconv", "torchsparse"]):
38
  global BACKEND
39
  BACKEND = backend
40
 
41
+
42
  def set_debug(debug: bool):
43
  global DEBUG
44
  DEBUG = debug
45
 
46
+
47
+ def set_attn(attn: Literal["xformers", "flash_attn"]):
48
  global ATTN
49
  ATTN = attn
50
+
51
+
52
  import importlib
53
 
54
  __attributes = {
55
+ "SparseTensor": "basic",
56
+ "sparse_batch_broadcast": "basic",
57
+ "sparse_batch_op": "basic",
58
+ "sparse_cat": "basic",
59
+ "sparse_unbind": "basic",
60
+ "SparseGroupNorm": "norm",
61
+ "SparseLayerNorm": "norm",
62
+ "SparseGroupNorm32": "norm",
63
+ "SparseLayerNorm32": "norm",
64
+ "SparseReLU": "nonlinearity",
65
+ "SparseSiLU": "nonlinearity",
66
+ "SparseGELU": "nonlinearity",
67
+ "SparseActivation": "nonlinearity",
68
+ "SparseLinear": "linear",
69
+ "sparse_scaled_dot_product_attention": "attention",
70
+ "SerializeMode": "attention",
71
+ "sparse_serialized_scaled_dot_product_self_attention": "attention",
72
+ "sparse_windowed_scaled_dot_product_self_attention": "attention",
73
+ "SparseMultiHeadAttention": "attention",
74
+ "SparseConv3d": "conv",
75
+ "SparseInverseConv3d": "conv",
76
+ "SparseDownsample": "spatial",
77
+ "SparseUpsample": "spatial",
78
+ "SparseSubdivide": "spatial",
79
  }
80
 
81
+ __submodules = ["transformer"]
82
 
83
  __all__ = list(__attributes.keys()) + __submodules
84
 
85
+
86
  def __getattr__(name):
87
  if name not in globals():
88
  if name in __attributes:
 
98
 
99
 
100
  # For Pylance
101
+ if __name__ == "__main__":
102
  from .basic import *
103
  from .norm import *
104
  from .nonlinearity import *
trellis/modules/sparse/attention/full_attn.py CHANGED
@@ -3,16 +3,16 @@ import torch
3
  from .. import SparseTensor
4
  from .. import DEBUG, ATTN
5
 
6
- if ATTN == 'xformers':
7
  import xformers.ops as xops
8
- elif ATTN == 'flash_attn':
9
  import flash_attn
10
  else:
11
  raise ValueError(f"Unknown attention module: {ATTN}")
12
 
13
 
14
  __all__ = [
15
- 'sparse_scaled_dot_product_attention',
16
  ]
17
 
18
 
@@ -26,8 +26,11 @@ def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
26
  """
27
  ...
28
 
 
29
  @overload
30
- def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
 
 
31
  """
32
  Apply scaled dot product attention to a sparse tensor.
33
 
@@ -37,8 +40,11 @@ def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor,
37
  """
38
  ...
39
 
 
40
  @overload
41
- def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
 
 
42
  """
43
  Apply scaled dot product attention to a sparse tensor.
44
 
@@ -48,8 +54,11 @@ def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> to
48
  """
49
  ...
50
 
 
51
  @overload
52
- def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
 
 
53
  """
54
  Apply scaled dot product attention to a sparse tensor.
55
 
@@ -63,8 +72,11 @@ def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: Spa
63
  """
64
  ...
65
 
 
66
  @overload
67
- def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
 
 
68
  """
69
  Apply scaled dot product attention to a sparse tensor.
70
 
@@ -75,8 +87,11 @@ def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: tor
75
  """
76
  ...
77
 
 
78
  @overload
79
- def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
 
 
80
  """
81
  Apply scaled dot product attention to a sparse tensor.
82
 
@@ -87,106 +102,158 @@ def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: Spa
87
  """
88
  ...
89
 
 
90
  def sparse_scaled_dot_product_attention(*args, **kwargs):
91
- arg_names_dict = {
92
- 1: ['qkv'],
93
- 2: ['q', 'kv'],
94
- 3: ['q', 'k', 'v']
95
- }
96
  num_all_args = len(args) + len(kwargs)
97
- assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
98
- for key in arg_names_dict[num_all_args][len(args):]:
 
 
99
  assert key in kwargs, f"Missing argument {key}"
100
 
101
  if num_all_args == 1:
102
- qkv = args[0] if len(args) > 0 else kwargs['qkv']
103
- assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
104
- assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
 
 
105
  device = qkv.device
106
 
107
  s = qkv
108
- q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
 
 
109
  kv_seqlen = q_seqlen
110
- qkv = qkv.feats # [T, 3, H, C]
111
 
112
  elif num_all_args == 2:
113
- q = args[0] if len(args) > 0 else kwargs['q']
114
- kv = args[1] if len(args) > 1 else kwargs['kv']
115
- assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
116
- isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
117
- f"Invalid types, got {type(q)} and {type(kv)}"
118
- assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
 
 
119
  device = q.device
120
 
121
  if isinstance(q, SparseTensor):
122
- assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
 
 
123
  s = q
124
  q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
125
- q = q.feats # [T_Q, H, C]
126
  else:
127
- assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
 
 
128
  s = None
129
  N, L, H, C = q.shape
130
  q_seqlen = [L] * N
131
- q = q.reshape(N * L, H, C) # [T_Q, H, C]
132
 
133
  if isinstance(kv, SparseTensor):
134
- assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
135
- kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
136
- kv = kv.feats # [T_KV, 2, H, C]
 
 
137
  else:
138
- assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
 
 
139
  N, L, _, H, C = kv.shape
140
  kv_seqlen = [L] * N
141
- kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
142
 
143
  elif num_all_args == 3:
144
- q = args[0] if len(args) > 0 else kwargs['q']
145
- k = args[1] if len(args) > 1 else kwargs['k']
146
- v = args[2] if len(args) > 2 else kwargs['v']
147
- assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
148
- isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
149
- f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
150
- assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
 
 
151
  device = q.device
152
 
153
  if isinstance(q, SparseTensor):
154
- assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
 
 
155
  s = q
156
  q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
157
- q = q.feats # [T_Q, H, Ci]
158
  else:
159
- assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
 
 
160
  s = None
161
  N, L, H, CI = q.shape
162
  q_seqlen = [L] * N
163
  q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
164
 
165
  if isinstance(k, SparseTensor):
166
- assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
167
- assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
168
- kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
169
- k = k.feats # [T_KV, H, Ci]
170
- v = v.feats # [T_KV, H, Co]
 
 
171
  else:
172
- assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
173
- assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
 
 
174
  N, L, H, CI, CO = *k.shape, v.shape[-1]
175
  kv_seqlen = [L] * N
176
- k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
177
- v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
178
 
179
  if DEBUG:
180
  if s is not None:
181
  for i in range(s.shape[0]):
182
- assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
 
 
183
  if num_all_args in [2, 3]:
184
- assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
 
 
 
185
  if num_all_args == 3:
186
- assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
187
- assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"
 
 
188
 
189
- if ATTN == 'xformers':
190
  if num_all_args == 1:
191
  q, k, v = qkv.unbind(dim=1)
192
  elif num_all_args == 2:
@@ -196,19 +263,35 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
196
  v = v.unsqueeze(0)
197
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
198
  out = xops.memory_efficient_attention(q, k, v, mask)[0]
199
- elif ATTN == 'flash_attn':
200
- cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
 
 
201
  if num_all_args in [2, 3]:
202
- cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
 
 
203
  if num_all_args == 1:
204
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
 
 
205
  elif num_all_args == 2:
206
- out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
 
 
207
  elif num_all_args == 3:
208
- out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
 
 
209
  else:
210
  raise ValueError(f"Unknown attention module: {ATTN}")
211
-
212
  if s is not None:
213
  return s.replace(out)
214
  else:
 
3
  from .. import SparseTensor
4
  from .. import DEBUG, ATTN
5
 
6
+ if ATTN == "xformers":
7
  import xformers.ops as xops
8
+ elif ATTN == "flash_attn":
9
  import flash_attn
10
  else:
11
  raise ValueError(f"Unknown attention module: {ATTN}")
12
 
13
 
14
  __all__ = [
15
+ "sparse_scaled_dot_product_attention",
16
  ]
17
 
18
 
 
26
  """
27
  ...
28
 
29
+
30
  @overload
31
+ def sparse_scaled_dot_product_attention(
32
+ q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]
33
+ ) -> SparseTensor:
34
  """
35
  Apply scaled dot product attention to a sparse tensor.
36
 
 
40
  """
41
  ...
42
 
43
+
44
  @overload
45
+ def sparse_scaled_dot_product_attention(
46
+ q: torch.Tensor, kv: SparseTensor
47
+ ) -> torch.Tensor:
48
  """
49
  Apply scaled dot product attention to a sparse tensor.
50
 
 
54
  """
55
  ...
56
 
57
+
58
  @overload
59
+ def sparse_scaled_dot_product_attention(
60
+ q: SparseTensor, k: SparseTensor, v: SparseTensor
61
+ ) -> SparseTensor:
62
  """
63
  Apply scaled dot product attention to a sparse tensor.
64
 
 
72
  """
73
  ...
74
 
75
+
76
  @overload
77
+ def sparse_scaled_dot_product_attention(
78
+ q: SparseTensor, k: torch.Tensor, v: torch.Tensor
79
+ ) -> SparseTensor:
80
  """
81
  Apply scaled dot product attention to a sparse tensor.
82
 
 
87
  """
88
  ...
89
 
90
+
91
  @overload
92
+ def sparse_scaled_dot_product_attention(
93
+ q: torch.Tensor, k: SparseTensor, v: SparseTensor
94
+ ) -> torch.Tensor:
95
  """
96
  Apply scaled dot product attention to a sparse tensor.
97
 
 
102
  """
103
  ...
104
 
105
+
106
  def sparse_scaled_dot_product_attention(*args, **kwargs):
107
+ arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]}
 
 
 
 
108
  num_all_args = len(args) + len(kwargs)
109
+ assert (
110
+ num_all_args in arg_names_dict
111
+ ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
112
+ for key in arg_names_dict[num_all_args][len(args) :]:
113
  assert key in kwargs, f"Missing argument {key}"
114
 
115
  if num_all_args == 1:
116
+ qkv = args[0] if len(args) > 0 else kwargs["qkv"]
117
+ assert isinstance(
118
+ qkv, SparseTensor
119
+ ), f"qkv must be a SparseTensor, got {type(qkv)}"
120
+ assert (
121
+ len(qkv.shape) == 4 and qkv.shape[1] == 3
122
+ ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
123
  device = qkv.device
124
 
125
  s = qkv
126
+ q_seqlen = [
127
+ qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])
128
+ ]
129
  kv_seqlen = q_seqlen
130
+ qkv = qkv.feats # [T, 3, H, C]
131
 
132
  elif num_all_args == 2:
133
+ q = args[0] if len(args) > 0 else kwargs["q"]
134
+ kv = args[1] if len(args) > 1 else kwargs["kv"]
135
+ assert (
136
+ isinstance(q, SparseTensor)
137
+ and isinstance(kv, (SparseTensor, torch.Tensor))
138
+ or isinstance(q, torch.Tensor)
139
+ and isinstance(kv, SparseTensor)
140
+ ), f"Invalid types, got {type(q)} and {type(kv)}"
141
+ assert (
142
+ q.shape[0] == kv.shape[0]
143
+ ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
144
  device = q.device
145
 
146
  if isinstance(q, SparseTensor):
147
+ assert (
148
+ len(q.shape) == 3
149
+ ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
150
  s = q
151
  q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
152
+ q = q.feats # [T_Q, H, C]
153
  else:
154
+ assert (
155
+ len(q.shape) == 4
156
+ ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
157
  s = None
158
  N, L, H, C = q.shape
159
  q_seqlen = [L] * N
160
+ q = q.reshape(N * L, H, C) # [T_Q, H, C]
161
 
162
  if isinstance(kv, SparseTensor):
163
+ assert (
164
+ len(kv.shape) == 4 and kv.shape[1] == 2
165
+ ), f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
166
+ kv_seqlen = [
167
+ kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])
168
+ ]
169
+ kv = kv.feats # [T_KV, 2, H, C]
170
  else:
171
+ assert (
172
+ len(kv.shape) == 5
173
+ ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
174
  N, L, _, H, C = kv.shape
175
  kv_seqlen = [L] * N
176
+ kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
177
 
178
  elif num_all_args == 3:
179
+ q = args[0] if len(args) > 0 else kwargs["q"]
180
+ k = args[1] if len(args) > 1 else kwargs["k"]
181
+ v = args[2] if len(args) > 2 else kwargs["v"]
182
+ assert (
183
+ isinstance(q, SparseTensor)
184
+ and isinstance(k, (SparseTensor, torch.Tensor))
185
+ and type(k) == type(v)
186
+ or isinstance(q, torch.Tensor)
187
+ and isinstance(k, SparseTensor)
188
+ and isinstance(v, SparseTensor)
189
+ ), f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
190
+ assert (
191
+ q.shape[0] == k.shape[0] == v.shape[0]
192
+ ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
193
  device = q.device
194
 
195
  if isinstance(q, SparseTensor):
196
+ assert (
197
+ len(q.shape) == 3
198
+ ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
199
  s = q
200
  q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
201
+ q = q.feats # [T_Q, H, Ci]
202
  else:
203
+ assert (
204
+ len(q.shape) == 4
205
+ ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
206
  s = None
207
  N, L, H, CI = q.shape
208
  q_seqlen = [L] * N
209
  q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
210
 
211
  if isinstance(k, SparseTensor):
212
+ assert (
213
+ len(k.shape) == 3
214
+ ), f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
215
+ assert (
216
+ len(v.shape) == 3
217
+ ), f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
218
+ kv_seqlen = [
219
+ k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])
220
+ ]
221
+ k = k.feats # [T_KV, H, Ci]
222
+ v = v.feats # [T_KV, H, Co]
223
  else:
224
+ assert (
225
+ len(k.shape) == 4
226
+ ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
227
+ assert (
228
+ len(v.shape) == 4
229
+ ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
230
  N, L, H, CI, CO = *k.shape, v.shape[-1]
231
  kv_seqlen = [L] * N
232
+ k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
233
+ v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
234
 
235
  if DEBUG:
236
  if s is not None:
237
  for i in range(s.shape[0]):
238
+ assert (
239
+ s.coords[s.layout[i]] == i
240
+ ).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
241
  if num_all_args in [2, 3]:
242
+ assert q.shape[:2] == [
243
+ 1,
244
+ sum(q_seqlen),
245
+ ], f"SparseScaledDotProductSelfAttention: q shape mismatch"
246
  if num_all_args == 3:
247
+ assert k.shape[:2] == [
248
+ 1,
249
+ sum(kv_seqlen),
250
+ ], f"SparseScaledDotProductSelfAttention: k shape mismatch"
251
+ assert v.shape[:2] == [
252
+ 1,
253
+ sum(kv_seqlen),
254
+ ], f"SparseScaledDotProductSelfAttention: v shape mismatch"
255
 
256
+ if ATTN == "xformers":
257
  if num_all_args == 1:
258
  q, k, v = qkv.unbind(dim=1)
259
  elif num_all_args == 2:
 
263
  v = v.unsqueeze(0)
264
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
265
  out = xops.memory_efficient_attention(q, k, v, mask)[0]
266
+ elif ATTN == "flash_attn":
267
+ cu_seqlens_q = (
268
+ torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)])
269
+ .int()
270
+ .to(device)
271
+ )
272
  if num_all_args in [2, 3]:
273
+ cu_seqlens_kv = (
274
+ torch.cat(
275
+ [torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]
276
+ )
277
+ .int()
278
+ .to(device)
279
+ )
280
  if num_all_args == 1:
281
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(
282
+ qkv, cu_seqlens_q, max(q_seqlen)
283
+ )
284
  elif num_all_args == 2:
285
+ out = flash_attn.flash_attn_varlen_kvpacked_func(
286
+ q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)
287
+ )
288
  elif num_all_args == 3:
289
+ out = flash_attn.flash_attn_varlen_func(
290
+ q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)
291
+ )
292
  else:
293
  raise ValueError(f"Unknown attention module: {ATTN}")
294
+
295
  if s is not None:
296
  return s.replace(out)
297
  else:
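Usage note (editor addition, not part of this commit): a minimal sketch of calling the reformatted `sparse_scaled_dot_product_attention` with a packed qkv `SparseTensor`, matching the 1-argument overload above. The import paths, the CUDA device, and the fp16 features (required by the flash_attn backend) are assumptions based on this diff, not verified against the repository.

```python
# Self-attention over a packed qkv SparseTensor of shape [N, *, 3, H, C].
# Assumes a sparse backend (spconv or torchsparse) and an attention backend
# (xformers or flash_attn) are installed and configured via BACKEND/ATTN.
import torch
from trellis.modules.sparse import SparseTensor
from trellis.modules.sparse.attention.full_attn import sparse_scaled_dot_product_attention

N, L, H, C = 2, 16, 4, 32  # batch, tokens per sample, heads, head dim
coords = torch.cat(
    [
        torch.cat([torch.full((L, 1), b), torch.randint(0, 64, (L, 3))], dim=1)
        for b in range(N)
    ],
    dim=0,
).int().cuda()  # [T, 4]: batch index + xyz, contiguous per batch
feats = torch.randn(N * L, 3, H, C, dtype=torch.float16, device="cuda")  # packed q/k/v

qkv = SparseTensor(feats=feats, coords=coords)        # logical shape [N, *, 3, H, C]
out = sparse_scaled_dot_product_attention(qkv)        # SparseTensor with feats [T, H, C]
print(out.feats.shape)
```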
trellis/modules/sparse/attention/modules.py CHANGED
@@ -4,7 +4,10 @@ import torch.nn as nn
  import torch.nn.functional as F
  from .. import SparseTensor
  from .full_attn import sparse_scaled_dot_product_attention
- from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
  from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
  from ...attention import RotaryPositionEmbedder

@@ -12,16 +15,18 @@ from ...attention import RotaryPositionEmbedder
  class SparseMultiHeadRMSNorm(nn.Module):
  def __init__(self, dim: int, heads: int):
  super().__init__()
- self.scale = dim ** 0.5
  self.gamma = nn.Parameter(torch.ones(heads, dim))

- def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
  x_type = x.dtype
  x = x.float()
  if isinstance(x, SparseTensor):
  x = x.replace(F.normalize(x.feats, dim=-1))
  else:
- x = F.normalize(x, dim=-1)
  return (x * self.gamma * self.scale).to(x_type)


@@ -44,9 +49,17 @@ class SparseMultiHeadAttention(nn.Module):
  super().__init__()
  assert channels % num_heads == 0
  assert type in ["self", "cross"], f"Invalid attention type: {type}"
- assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
- assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
- assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
  self.channels = channels
  self.ctx_channels = ctx_channels if ctx_channels is not None else channels
  self.num_heads = num_heads
@@ -64,31 +77,37 @@ class SparseMultiHeadAttention(nn.Module):
  else:
  self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
  self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
-
  if self.qk_rms_norm:
  self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
  self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
-
  self.to_out = nn.Linear(channels, channels)

  if use_rope:
  self.rope = RotaryPositionEmbedder(channels)

  @staticmethod
- def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
  if isinstance(x, SparseTensor):
  return x.replace(module(x.feats))
  else:
  return module(x)

  @staticmethod
- def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
  if isinstance(x, SparseTensor):
  return x.reshape(*shape)
  else:
  return x.reshape(*x.shape[:2], *shape)

- def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
  if isinstance(x, SparseTensor):
  x_feats = x.feats.unsqueeze(0)
  else:
@@ -97,12 +116,16 @@ class SparseMultiHeadAttention(nn.Module):
  return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats

  def _rope(self, qkv: SparseTensor) -> SparseTensor:
- q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
  q, k = self.rope(q, k, qkv.coords[:, 1:])
- qkv = qkv.replace(torch.stack([q, k, v], dim=1))
  return qkv
-
- def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
  if self._type == "self":
  qkv = self._linear(self.to_qkv, x)
  qkv = self._fused_pre(qkv, num_fused=3)
@@ -117,7 +140,11 @@ class SparseMultiHeadAttention(nn.Module):
  h = sparse_scaled_dot_product_attention(qkv)
  elif self.attn_mode == "serialized":
  h = sparse_serialized_scaled_dot_product_self_attention(
- qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
  )
  elif self.attn_mode == "windowed":
  h = sparse_windowed_scaled_dot_product_self_attention(
 
4
  import torch.nn.functional as F
5
  from .. import SparseTensor
6
  from .full_attn import sparse_scaled_dot_product_attention
7
+ from .serialized_attn import (
8
+ SerializeMode,
9
+ sparse_serialized_scaled_dot_product_self_attention,
10
+ )
11
  from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
12
  from ...attention import RotaryPositionEmbedder
13
 
 
15
  class SparseMultiHeadRMSNorm(nn.Module):
16
  def __init__(self, dim: int, heads: int):
17
  super().__init__()
18
+ self.scale = dim**0.5
19
  self.gamma = nn.Parameter(torch.ones(heads, dim))
20
 
21
+ def forward(
22
+ self, x: Union[SparseTensor, torch.Tensor]
23
+ ) -> Union[SparseTensor, torch.Tensor]:
24
  x_type = x.dtype
25
  x = x.float()
26
  if isinstance(x, SparseTensor):
27
  x = x.replace(F.normalize(x.feats, dim=-1))
28
  else:
29
+ x = F.normalize(x, dim=-1)
30
  return (x * self.gamma * self.scale).to(x_type)
31
 
32
 
 
49
  super().__init__()
50
  assert channels % num_heads == 0
51
  assert type in ["self", "cross"], f"Invalid attention type: {type}"
52
+ assert attn_mode in [
53
+ "full",
54
+ "serialized",
55
+ "windowed",
56
+ ], f"Invalid attention mode: {attn_mode}"
57
+ assert (
58
+ type == "self" or attn_mode == "full"
59
+ ), "Cross-attention only supports full attention"
60
+ assert (
61
+ type == "self" or use_rope is False
62
+ ), "Rotary position embeddings only supported for self-attention"
63
  self.channels = channels
64
  self.ctx_channels = ctx_channels if ctx_channels is not None else channels
65
  self.num_heads = num_heads
 
77
  else:
78
  self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
79
  self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
80
+
81
  if self.qk_rms_norm:
82
  self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
83
  self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
84
+
85
  self.to_out = nn.Linear(channels, channels)
86
 
87
  if use_rope:
88
  self.rope = RotaryPositionEmbedder(channels)
89
 
90
  @staticmethod
91
+ def _linear(
92
+ module: nn.Linear, x: Union[SparseTensor, torch.Tensor]
93
+ ) -> Union[SparseTensor, torch.Tensor]:
94
  if isinstance(x, SparseTensor):
95
  return x.replace(module(x.feats))
96
  else:
97
  return module(x)
98
 
99
  @staticmethod
100
+ def _reshape_chs(
101
+ x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]
102
+ ) -> Union[SparseTensor, torch.Tensor]:
103
  if isinstance(x, SparseTensor):
104
  return x.reshape(*shape)
105
  else:
106
  return x.reshape(*x.shape[:2], *shape)
107
 
108
+ def _fused_pre(
109
+ self, x: Union[SparseTensor, torch.Tensor], num_fused: int
110
+ ) -> Union[SparseTensor, torch.Tensor]:
111
  if isinstance(x, SparseTensor):
112
  x_feats = x.feats.unsqueeze(0)
113
  else:
 
116
  return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
117
 
118
  def _rope(self, qkv: SparseTensor) -> SparseTensor:
119
+ q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
120
  q, k = self.rope(q, k, qkv.coords[:, 1:])
121
+ qkv = qkv.replace(torch.stack([q, k, v], dim=1))
122
  return qkv
123
+
124
+ def forward(
125
+ self,
126
+ x: Union[SparseTensor, torch.Tensor],
127
+ context: Optional[Union[SparseTensor, torch.Tensor]] = None,
128
+ ) -> Union[SparseTensor, torch.Tensor]:
129
  if self._type == "self":
130
  qkv = self._linear(self.to_qkv, x)
131
  qkv = self._fused_pre(qkv, num_fused=3)
 
140
  h = sparse_scaled_dot_product_attention(qkv)
141
  elif self.attn_mode == "serialized":
142
  h = sparse_serialized_scaled_dot_product_self_attention(
143
+ qkv,
144
+ self.window_size,
145
+ serialize_mode=self.serialize_mode,
146
+ shift_sequence=self.shift_sequence,
147
+ shift_window=self.shift_window,
148
  )
149
  elif self.attn_mode == "windowed":
150
  h = sparse_windowed_scaled_dot_product_self_attention(
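Usage note (editor addition, not part of this commit): a rough sketch of driving `SparseMultiHeadAttention` in full self-attention mode. The constructor keywords are inferred from the body shown in this diff and may not match the full signature; the CUDA/fp16 setup is likewise an assumption.

```python
# Full self-attention over a sparse feature tensor via the reformatted wrapper.
import torch
from trellis.modules.sparse import SparseTensor
from trellis.modules.sparse.attention.modules import SparseMultiHeadAttention

attn = SparseMultiHeadAttention(channels=128, num_heads=4, attn_mode="full").cuda().half()

L = 32  # tokens in a single batch element
coords = torch.cat([torch.zeros(L, 1), torch.randint(0, 64, (L, 3))], dim=1).int().cuda()
x = SparseTensor(
    feats=torch.randn(L, 128, dtype=torch.float16, device="cuda"),
    coords=coords,  # column 0 is the batch index
)
y = attn(x)  # SparseTensor with feats [L, 128]
```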
trellis/modules/sparse/attention/serialized_attn.py CHANGED
@@ -5,16 +5,16 @@ import math
5
  from .. import SparseTensor
6
  from .. import DEBUG, ATTN
7
 
8
- if ATTN == 'xformers':
9
  import xformers.ops as xops
10
- elif ATTN == 'flash_attn':
11
  import flash_attn
12
  else:
13
  raise ValueError(f"Unknown attention module: {ATTN}")
14
 
15
 
16
  __all__ = [
17
- 'sparse_serialized_scaled_dot_product_self_attention',
18
  ]
19
 
20
 
@@ -29,7 +29,7 @@ SerializeModes = [
29
  SerializeMode.Z_ORDER,
30
  SerializeMode.Z_ORDER_TRANSPOSED,
31
  SerializeMode.HILBERT,
32
- SerializeMode.HILBERT_TRANSPOSED
33
  ]
34
 
35
 
@@ -38,7 +38,7 @@ def calc_serialization(
38
  window_size: int,
39
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
40
  shift_sequence: int = 0,
41
- shift_window: Tuple[int, int, int] = (0, 0, 0)
42
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
43
  """
44
  Calculate serialization and partitioning for a set of coordinates.
@@ -58,32 +58,38 @@ def calc_serialization(
58
  seq_lens = []
59
  seq_batch_indices = []
60
  offsets = [0]
61
-
62
- if 'vox2seq' not in globals():
63
  import vox2seq
64
 
65
  # Serialize the input
66
  serialize_coords = tensor.coords[:, 1:].clone()
67
- serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
 
 
68
  if serialize_mode == SerializeMode.Z_ORDER:
69
- code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
70
  elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
71
- code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
72
  elif serialize_mode == SerializeMode.HILBERT:
73
- code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
74
  elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
75
- code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
76
  else:
77
  raise ValueError(f"Unknown serialize mode: {serialize_mode}")
78
-
79
  for bi, s in enumerate(tensor.layout):
80
  num_points = s.stop - s.start
81
  num_windows = (num_points + window_size - 1) // window_size
82
  valid_window_size = num_points / num_windows
83
- to_ordered = torch.argsort(code[s.start:s.stop])
84
  if num_windows == 1:
85
  fwd_indices.append(to_ordered)
86
- bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
 
 
 
 
87
  fwd_indices[-1] += s.start
88
  bwd_indices[-1] += offsets[-1]
89
  seq_lens.append(num_points)
@@ -92,18 +98,39 @@ def calc_serialization(
92
  else:
93
  # Partition the input
94
  offset = 0
95
- mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
96
- split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
97
- bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
 
 
 
 
 
 
 
 
98
  for i in range(num_windows):
99
  mid = mids[i]
100
  valid_start = split[i]
101
  valid_end = split[i + 1]
102
  padded_start = math.floor(mid - 0.5 * window_size)
103
  padded_end = padded_start + window_size
104
- fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
 
 
 
 
 
105
  offset += valid_start - padded_start
106
- bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
 
 
 
 
 
 
 
 
107
  offset += padded_end - valid_start
108
  fwd_indices[-1] += s.start
109
  seq_lens.extend([window_size] * num_windows)
@@ -115,14 +142,14 @@ def calc_serialization(
115
  bwd_indices = torch.cat(bwd_indices)
116
 
117
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
118
-
119
 
120
  def sparse_serialized_scaled_dot_product_self_attention(
121
  qkv: SparseTensor,
122
  window_size: int,
123
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
124
  shift_sequence: int = 0,
125
- shift_window: Tuple[int, int, int] = (0, 0, 0)
126
  ) -> SparseTensor:
127
  """
128
  Apply serialized scaled dot product self attention to a sparse tensor.
@@ -135,59 +162,86 @@ def sparse_serialized_scaled_dot_product_self_attention(
135
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
136
  shift (int): The shift to use.
137
  """
138
- assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
139
-
140
- serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
141
- serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
 
 
 
 
 
 
142
  if serialization_spatial_cache is None:
143
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
144
- qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
 
 
 
 
 
145
  else:
146
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
 
 
147
 
148
  M = fwd_indices.shape[0]
149
  T = qkv.feats.shape[0]
150
  H = qkv.feats.shape[2]
151
  C = qkv.feats.shape[3]
152
-
153
- qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
154
 
155
  if DEBUG:
156
  start = 0
157
  qkv_coords = qkv.coords[fwd_indices]
158
  for i in range(len(seq_lens)):
159
- assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
 
 
 
 
160
  start += seq_lens[i]
161
 
162
  if all([seq_len == window_size for seq_len in seq_lens]):
163
  B = len(seq_lens)
164
  N = window_size
165
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
166
- if ATTN == 'xformers':
167
- q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
168
- out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
169
- elif ATTN == 'flash_attn':
170
- out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
171
  else:
172
  raise ValueError(f"Unknown attention module: {ATTN}")
173
- out = out.reshape(B * N, H, C) # [M, H, C]
174
  else:
175
- if ATTN == 'xformers':
176
- q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
177
- q = q.unsqueeze(0) # [1, M, H, C]
178
- k = k.unsqueeze(0) # [1, M, H, C]
179
- v = v.unsqueeze(0) # [1, M, H, C]
180
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
181
- out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
182
- elif ATTN == 'flash_attn':
183
- cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
184
- .to(qkv.device).int()
185
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
186
-
187
- out = out[bwd_indices] # [T, H, C]
 
 
 
 
 
 
 
 
188
 
189
  if DEBUG:
190
  qkv_coords = qkv_coords[bwd_indices]
191
- assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
 
 
192
 
193
  return qkv.replace(out)
 
5
  from .. import SparseTensor
6
  from .. import DEBUG, ATTN
7
 
8
+ if ATTN == "xformers":
9
  import xformers.ops as xops
10
+ elif ATTN == "flash_attn":
11
  import flash_attn
12
  else:
13
  raise ValueError(f"Unknown attention module: {ATTN}")
14
 
15
 
16
  __all__ = [
17
+ "sparse_serialized_scaled_dot_product_self_attention",
18
  ]
19
 
20
 
 
29
  SerializeMode.Z_ORDER,
30
  SerializeMode.Z_ORDER_TRANSPOSED,
31
  SerializeMode.HILBERT,
32
+ SerializeMode.HILBERT_TRANSPOSED,
33
  ]
34
 
35
 
 
38
  window_size: int,
39
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
40
  shift_sequence: int = 0,
41
+ shift_window: Tuple[int, int, int] = (0, 0, 0),
42
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
43
  """
44
  Calculate serialization and partitioning for a set of coordinates.
 
58
  seq_lens = []
59
  seq_batch_indices = []
60
  offsets = [0]
61
+
62
+ if "vox2seq" not in globals():
63
  import vox2seq
64
 
65
  # Serialize the input
66
  serialize_coords = tensor.coords[:, 1:].clone()
67
+ serialize_coords += torch.tensor(
68
+ shift_window, dtype=torch.int32, device=tensor.device
69
+ ).reshape(1, 3)
70
  if serialize_mode == SerializeMode.Z_ORDER:
71
+ code = vox2seq.encode(serialize_coords, mode="z_order", permute=[0, 1, 2])
72
  elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
73
+ code = vox2seq.encode(serialize_coords, mode="z_order", permute=[1, 0, 2])
74
  elif serialize_mode == SerializeMode.HILBERT:
75
+ code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[0, 1, 2])
76
  elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
77
+ code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[1, 0, 2])
78
  else:
79
  raise ValueError(f"Unknown serialize mode: {serialize_mode}")
80
+
81
  for bi, s in enumerate(tensor.layout):
82
  num_points = s.stop - s.start
83
  num_windows = (num_points + window_size - 1) // window_size
84
  valid_window_size = num_points / num_windows
85
+ to_ordered = torch.argsort(code[s.start : s.stop])
86
  if num_windows == 1:
87
  fwd_indices.append(to_ordered)
88
+ bwd_indices.append(
89
+ torch.zeros_like(to_ordered).scatter_(
90
+ 0, to_ordered, torch.arange(num_points, device=tensor.device)
91
+ )
92
+ )
93
  fwd_indices[-1] += s.start
94
  bwd_indices[-1] += offsets[-1]
95
  seq_lens.append(num_points)
 
98
  else:
99
  # Partition the input
100
  offset = 0
101
+ mids = [
102
+ (i + 0.5) * valid_window_size + shift_sequence
103
+ for i in range(num_windows)
104
+ ]
105
+ split = [
106
+ math.floor(i * valid_window_size + shift_sequence)
107
+ for i in range(num_windows + 1)
108
+ ]
109
+ bwd_index = torch.zeros(
110
+ (num_points,), dtype=torch.int64, device=tensor.device
111
+ )
112
  for i in range(num_windows):
113
  mid = mids[i]
114
  valid_start = split[i]
115
  valid_end = split[i + 1]
116
  padded_start = math.floor(mid - 0.5 * window_size)
117
  padded_end = padded_start + window_size
118
+ fwd_indices.append(
119
+ to_ordered[
120
+ torch.arange(padded_start, padded_end, device=tensor.device)
121
+ % num_points
122
+ ]
123
+ )
124
  offset += valid_start - padded_start
125
+ bwd_index.scatter_(
126
+ 0,
127
+ fwd_indices[-1][
128
+ valid_start - padded_start : valid_end - padded_start
129
+ ],
130
+ torch.arange(
131
+ offset, offset + valid_end - valid_start, device=tensor.device
132
+ ),
133
+ )
134
  offset += padded_end - valid_start
135
  fwd_indices[-1] += s.start
136
  seq_lens.extend([window_size] * num_windows)
 
142
  bwd_indices = torch.cat(bwd_indices)
143
 
144
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
145
+
146
 
147
  def sparse_serialized_scaled_dot_product_self_attention(
148
  qkv: SparseTensor,
149
  window_size: int,
150
  serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
151
  shift_sequence: int = 0,
152
+ shift_window: Tuple[int, int, int] = (0, 0, 0),
153
  ) -> SparseTensor:
154
  """
155
  Apply serialized scaled dot product self attention to a sparse tensor.
 
162
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
163
  shift (int): The shift to use.
164
  """
165
+ assert (
166
+ len(qkv.shape) == 4 and qkv.shape[1] == 3
167
+ ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
168
+
169
+ serialization_spatial_cache_name = (
170
+ f"serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}"
171
+ )
172
+ serialization_spatial_cache = qkv.get_spatial_cache(
173
+ serialization_spatial_cache_name
174
+ )
175
  if serialization_spatial_cache is None:
176
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(
177
+ qkv, window_size, serialize_mode, shift_sequence, shift_window
178
+ )
179
+ qkv.register_spatial_cache(
180
+ serialization_spatial_cache_name,
181
+ (fwd_indices, bwd_indices, seq_lens, seq_batch_indices),
182
+ )
183
  else:
184
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = (
185
+ serialization_spatial_cache
186
+ )
187
 
188
  M = fwd_indices.shape[0]
189
  T = qkv.feats.shape[0]
190
  H = qkv.feats.shape[2]
191
  C = qkv.feats.shape[3]
192
+
193
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
194
 
195
  if DEBUG:
196
  start = 0
197
  qkv_coords = qkv.coords[fwd_indices]
198
  for i in range(len(seq_lens)):
199
+ assert (
200
+ qkv_coords[start : start + seq_lens[i], 0] == seq_batch_indices[i]
201
+ ).all(), (
202
+ f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
203
+ )
204
  start += seq_lens[i]
205
 
206
  if all([seq_len == window_size for seq_len in seq_lens]):
207
  B = len(seq_lens)
208
  N = window_size
209
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
210
+ if ATTN == "xformers":
211
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
212
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
213
+ elif ATTN == "flash_attn":
214
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
215
  else:
216
  raise ValueError(f"Unknown attention module: {ATTN}")
217
+ out = out.reshape(B * N, H, C) # [M, H, C]
218
  else:
219
+ if ATTN == "xformers":
220
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
221
+ q = q.unsqueeze(0) # [1, M, H, C]
222
+ k = k.unsqueeze(0) # [1, M, H, C]
223
+ v = v.unsqueeze(0) # [1, M, H, C]
224
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
225
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
226
+ elif ATTN == "flash_attn":
227
+ cu_seqlens = (
228
+ torch.cat(
229
+ [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)],
230
+ dim=0,
231
+ )
232
+ .to(qkv.device)
233
+ .int()
234
+ )
235
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(
236
+ qkv_feats, cu_seqlens, max(seq_lens)
237
+ ) # [M, H, C]
238
+
239
+ out = out[bwd_indices] # [T, H, C]
240
 
241
  if DEBUG:
242
  qkv_coords = qkv_coords[bwd_indices]
243
+ assert torch.equal(
244
+ qkv_coords, qkv.coords
245
+ ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
246
 
247
  return qkv.replace(out)
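Illustrative note (editor addition, not part of this commit): the window arithmetic used by `calc_serialization` above, extracted into a standalone sketch so the padding and wrap-around behaviour is easier to see. The helper name `window_spans` is hypothetical.

```python
# Each batch element with num_points tokens is covered by
# ceil(num_points / window_size) windows of exactly window_size tokens,
# centred on evenly spaced "valid" spans and padded (with wrap-around in the
# real code) to the fixed window size.
import math

def window_spans(num_points: int, window_size: int, shift_sequence: int = 0):
    num_windows = (num_points + window_size - 1) // window_size
    valid_window_size = num_points / num_windows
    split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
    spans = []
    for i in range(num_windows):
        mid = (i + 0.5) * valid_window_size + shift_sequence
        padded_start = math.floor(mid - 0.5 * window_size)
        # (valid_start, valid_end, padded_start, padded_end)
        spans.append((split[i], split[i + 1], padded_start, padded_start + window_size))
    return spans

# e.g. 10 points with a window of 4 -> 3 windows, each padded to 4 tokens
print(window_spans(10, 4))
```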
trellis/modules/sparse/attention/windowed_attn.py CHANGED
@@ -4,23 +4,23 @@ import math
4
  from .. import SparseTensor
5
  from .. import DEBUG, ATTN
6
 
7
- if ATTN == 'xformers':
8
  import xformers.ops as xops
9
- elif ATTN == 'flash_attn':
10
  import flash_attn
11
  else:
12
  raise ValueError(f"Unknown attention module: {ATTN}")
13
 
14
 
15
  __all__ = [
16
- 'sparse_windowed_scaled_dot_product_self_attention',
17
  ]
18
 
19
 
20
  def calc_window_partition(
21
  tensor: SparseTensor,
22
  window_size: Union[int, Tuple[int, ...]],
23
- shift_window: Union[int, Tuple[int, ...]] = 0
24
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
25
  """
26
  Calculate serialization and partitioning for a set of coordinates.
@@ -37,33 +37,43 @@ def calc_window_partition(
37
  (List[int]): Sequence batch indices.
38
  """
39
  DIM = tensor.coords.shape[1] - 1
40
- shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
 
 
41
  window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
42
  shifted_coords = tensor.coords.clone().detach()
43
- shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
 
 
44
 
45
  MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
46
  NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
47
  OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
48
 
49
- shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
50
- shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
 
 
 
 
 
51
  fwd_indices = torch.argsort(shifted_indices)
52
  bwd_indices = torch.empty_like(fwd_indices)
53
  bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
54
  seq_lens = torch.bincount(shifted_indices)
55
- seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
 
 
 
56
  mask = seq_lens != 0
57
  seq_lens = seq_lens[mask].tolist()
58
  seq_batch_indices = seq_batch_indices[mask].tolist()
59
 
60
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
61
-
62
 
63
  def sparse_windowed_scaled_dot_product_self_attention(
64
- qkv: SparseTensor,
65
- window_size: int,
66
- shift_window: Tuple[int, int, int] = (0, 0, 0)
67
  ) -> SparseTensor:
68
  """
69
  Apply windowed scaled dot product self attention to a sparse tensor.
@@ -74,62 +84,92 @@ def sparse_windowed_scaled_dot_product_self_attention(
74
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
75
  shift (int): The shift to use.
76
  """
77
- assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
78
-
79
- serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
80
- serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
 
 
 
 
81
  if serialization_spatial_cache is None:
82
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
83
- qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
 
 
 
 
 
84
  else:
85
- fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
 
 
86
 
87
  M = fwd_indices.shape[0]
88
  T = qkv.feats.shape[0]
89
  H = qkv.feats.shape[2]
90
  C = qkv.feats.shape[3]
91
-
92
- qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
93
 
94
  if DEBUG:
95
  start = 0
96
  qkv_coords = qkv.coords[fwd_indices]
97
  for i in range(len(seq_lens)):
98
- seq_coords = qkv_coords[start:start+seq_lens[i]]
99
- assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
100
- assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
101
- f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
 
 
 
 
 
 
 
 
 
102
  start += seq_lens[i]
103
 
104
  if all([seq_len == window_size for seq_len in seq_lens]):
105
  B = len(seq_lens)
106
  N = window_size
107
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
108
- if ATTN == 'xformers':
109
- q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
110
- out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
111
- elif ATTN == 'flash_attn':
112
- out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
113
  else:
114
  raise ValueError(f"Unknown attention module: {ATTN}")
115
- out = out.reshape(B * N, H, C) # [M, H, C]
116
  else:
117
- if ATTN == 'xformers':
118
- q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
119
- q = q.unsqueeze(0) # [1, M, H, C]
120
- k = k.unsqueeze(0) # [1, M, H, C]
121
- v = v.unsqueeze(0) # [1, M, H, C]
122
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
123
- out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
124
- elif ATTN == 'flash_attn':
125
- cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
126
- .to(qkv.device).int()
127
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
128
-
129
- out = out[bwd_indices] # [T, H, C]
 
 
 
 
 
 
 
 
130
 
131
  if DEBUG:
132
  qkv_coords = qkv_coords[bwd_indices]
133
- assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
 
 
134
 
135
  return qkv.replace(out)
 
4
  from .. import SparseTensor
5
  from .. import DEBUG, ATTN
6
 
7
+ if ATTN == "xformers":
8
  import xformers.ops as xops
9
+ elif ATTN == "flash_attn":
10
  import flash_attn
11
  else:
12
  raise ValueError(f"Unknown attention module: {ATTN}")
13
 
14
 
15
  __all__ = [
16
+ "sparse_windowed_scaled_dot_product_self_attention",
17
  ]
18
 
19
 
20
  def calc_window_partition(
21
  tensor: SparseTensor,
22
  window_size: Union[int, Tuple[int, ...]],
23
+ shift_window: Union[int, Tuple[int, ...]] = 0,
24
  ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
25
  """
26
  Calculate serialization and partitioning for a set of coordinates.
 
37
  (List[int]): Sequence batch indices.
38
  """
39
  DIM = tensor.coords.shape[1] - 1
40
+ shift_window = (
41
+ (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
42
+ )
43
  window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
44
  shifted_coords = tensor.coords.clone().detach()
45
+ shifted_coords[:, 1:] += torch.tensor(
46
+ shift_window, device=tensor.device, dtype=torch.int32
47
+ ).unsqueeze(0)
48
 
49
  MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
50
  NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
51
  OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
52
 
53
+ shifted_coords[:, 1:] //= torch.tensor(
54
+ window_size, device=tensor.device, dtype=torch.int32
55
+ ).unsqueeze(0)
56
+ shifted_indices = (
57
+ shifted_coords
58
+ * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)
59
+ ).sum(dim=1)
60
  fwd_indices = torch.argsort(shifted_indices)
61
  bwd_indices = torch.empty_like(fwd_indices)
62
  bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
63
  seq_lens = torch.bincount(shifted_indices)
64
+ seq_batch_indices = (
65
+ torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32)
66
+ // OFFSET[0]
67
+ )
68
  mask = seq_lens != 0
69
  seq_lens = seq_lens[mask].tolist()
70
  seq_batch_indices = seq_batch_indices[mask].tolist()
71
 
72
  return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
73
+
74
 
75
  def sparse_windowed_scaled_dot_product_self_attention(
76
+ qkv: SparseTensor, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)
 
 
77
  ) -> SparseTensor:
78
  """
79
  Apply windowed scaled dot product self attention to a sparse tensor.
 
84
  shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
85
  shift (int): The shift to use.
86
  """
87
+ assert (
88
+ len(qkv.shape) == 4 and qkv.shape[1] == 3
89
+ ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
90
+
91
+ serialization_spatial_cache_name = f"window_partition_{window_size}_{shift_window}"
92
+ serialization_spatial_cache = qkv.get_spatial_cache(
93
+ serialization_spatial_cache_name
94
+ )
95
  if serialization_spatial_cache is None:
96
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(
97
+ qkv, window_size, shift_window
98
+ )
99
+ qkv.register_spatial_cache(
100
+ serialization_spatial_cache_name,
101
+ (fwd_indices, bwd_indices, seq_lens, seq_batch_indices),
102
+ )
103
  else:
104
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = (
105
+ serialization_spatial_cache
106
+ )
107
 
108
  M = fwd_indices.shape[0]
109
  T = qkv.feats.shape[0]
110
  H = qkv.feats.shape[2]
111
  C = qkv.feats.shape[3]
112
+
113
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
114
 
115
  if DEBUG:
116
  start = 0
117
  qkv_coords = qkv.coords[fwd_indices]
118
  for i in range(len(seq_lens)):
119
+ seq_coords = qkv_coords[start : start + seq_lens[i]]
120
+ assert (
121
+ seq_coords[:, 0] == seq_batch_indices[i]
122
+ ).all(), (
123
+ f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
124
+ )
125
+ assert (
126
+ seq_coords[:, 1:].max(dim=0).values
127
+ - seq_coords[:, 1:].min(dim=0).values
128
+ < window_size
129
+ ).all(), (
130
+ f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
131
+ )
132
  start += seq_lens[i]
133
 
134
  if all([seq_len == window_size for seq_len in seq_lens]):
135
  B = len(seq_lens)
136
  N = window_size
137
  qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
138
+ if ATTN == "xformers":
139
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
140
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
141
+ elif ATTN == "flash_attn":
142
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
143
  else:
144
  raise ValueError(f"Unknown attention module: {ATTN}")
145
+ out = out.reshape(B * N, H, C) # [M, H, C]
146
  else:
147
+ if ATTN == "xformers":
148
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
149
+ q = q.unsqueeze(0) # [1, M, H, C]
150
+ k = k.unsqueeze(0) # [1, M, H, C]
151
+ v = v.unsqueeze(0) # [1, M, H, C]
152
  mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
153
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
154
+ elif ATTN == "flash_attn":
155
+ cu_seqlens = (
156
+ torch.cat(
157
+ [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)],
158
+ dim=0,
159
+ )
160
+ .to(qkv.device)
161
+ .int()
162
+ )
163
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(
164
+ qkv_feats, cu_seqlens, max(seq_lens)
165
+ ) # [M, H, C]
166
+
167
+ out = out[bwd_indices] # [T, H, C]
168
 
169
  if DEBUG:
170
  qkv_coords = qkv_coords[bwd_indices]
171
+ assert torch.equal(
172
+ qkv_coords, qkv.coords
173
+ ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
174
 
175
  return qkv.replace(out)
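Illustrative note (editor addition, not part of this commit): a small sketch of the window-index computation performed by `calc_window_partition` above, using hypothetical coordinates.

```python
# Each voxel's (shifted) coordinate is floor-divided by the window size and
# flattened together with the batch index; voxels sharing a window get the
# same key, so the subsequent argsort groups them into attention windows.
import math
import torch

coords = torch.tensor([[0, 1, 2, 3],   # [batch, x, y, z]
                       [0, 9, 2, 3],
                       [1, 1, 2, 3]])
window_size, shift = (8, 8, 8), (0, 0, 0)

shifted = coords.clone()
shifted[:, 1:] += torch.tensor(shift)
max_coords = shifted[:, 1:].max(dim=0).values.tolist()
num_windows = [math.ceil((mc + 1) / ws) for mc, ws in zip(max_coords, window_size)]
offset = torch.cumprod(torch.tensor([1] + num_windows[::-1]), dim=0).tolist()[::-1]

shifted[:, 1:] //= torch.tensor(window_size)
window_idx = (shifted * torch.tensor(offset)).sum(dim=1)
print(window_idx)  # rows 0 and 1 differ (x crosses a window boundary); row 2 is another batch
```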
trellis/modules/sparse/basic.py CHANGED
@@ -2,22 +2,23 @@ from typing import *
2
  import torch
3
  import torch.nn as nn
4
  from . import BACKEND, DEBUG
5
- SparseTensorData = None # Lazy import
 
6
 
7
 
8
  __all__ = [
9
- 'SparseTensor',
10
- 'sparse_batch_broadcast',
11
- 'sparse_batch_op',
12
- 'sparse_cat',
13
- 'sparse_unbind',
14
  ]
15
 
16
 
17
  class SparseTensor:
18
  """
19
  Sparse tensor with support for both torchsparse and spconv backends.
20
-
21
  Parameters:
22
  - feats (torch.Tensor): Features of the sparse tensor.
23
  - coords (torch.Tensor): Coordinates of the sparse tensor.
@@ -29,64 +30,87 @@ class SparseTensor:
29
  - Data corresponding to a same batch should be contiguous.
30
  - Coords should be in [0, 1023]
31
  """
 
32
  @overload
33
- def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
 
 
 
 
 
 
 
34
 
35
  @overload
36
- def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
 
 
 
 
 
 
37
 
38
  def __init__(self, *args, **kwargs):
39
  # Lazy import of sparse tensor backend
40
  global SparseTensorData
41
  if SparseTensorData is None:
42
  import importlib
43
- if BACKEND == 'torchsparse':
44
- SparseTensorData = importlib.import_module('torchsparse').SparseTensor
45
- elif BACKEND == 'spconv':
46
- SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
47
-
 
 
 
48
  method_id = 0
49
  if len(args) != 0:
50
  method_id = 0 if isinstance(args[0], torch.Tensor) else 1
51
  else:
52
- method_id = 1 if 'data' in kwargs else 0
53
 
54
  if method_id == 0:
55
  feats, coords, shape, layout = args + (None,) * (4 - len(args))
56
- if 'feats' in kwargs:
57
- feats = kwargs['feats']
58
- del kwargs['feats']
59
- if 'coords' in kwargs:
60
- coords = kwargs['coords']
61
- del kwargs['coords']
62
- if 'shape' in kwargs:
63
- shape = kwargs['shape']
64
- del kwargs['shape']
65
- if 'layout' in kwargs:
66
- layout = kwargs['layout']
67
- del kwargs['layout']
68
 
69
  if shape is None:
70
  shape = self.__cal_shape(feats, coords)
71
  if layout is None:
72
  layout = self.__cal_layout(coords, shape[0])
73
- if BACKEND == 'torchsparse':
74
  self.data = SparseTensorData(feats, coords, **kwargs)
75
- elif BACKEND == 'spconv':
76
  spatial_shape = list(coords.max(0)[0] + 1)[1:]
77
- self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
 
 
 
 
 
 
78
  self.data._features = feats
79
  elif method_id == 1:
80
  data, shape, layout = args + (None,) * (3 - len(args))
81
- if 'data' in kwargs:
82
- data = kwargs['data']
83
- del kwargs['data']
84
- if 'shape' in kwargs:
85
- shape = kwargs['shape']
86
- del kwargs['shape']
87
- if 'layout' in kwargs:
88
- layout = kwargs['layout']
89
- del kwargs['layout']
90
 
91
  self.data = data
92
  if shape is None:
@@ -96,73 +120,84 @@ class SparseTensor:
96
 
97
  self._shape = shape
98
  self._layout = layout
99
- self._scale = kwargs.get('scale', (1, 1, 1))
100
- self._spatial_cache = kwargs.get('spatial_cache', {})
101
 
102
  if DEBUG:
103
  try:
104
- assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
105
- assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
106
- assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
 
 
 
 
 
 
107
  for i in range(self.shape[0]):
108
- assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
 
 
109
  except Exception as e:
110
- print('Debugging information:')
111
  print(f"- Shape: {self.shape}")
112
  print(f"- Layout: {self.layout}")
113
  print(f"- Scale: {self._scale}")
114
  print(f"- Coords: {self.coords}")
115
  raise e
116
-
117
  def __cal_shape(self, feats, coords):
118
  shape = []
119
  shape.append(coords[:, 0].max().item() + 1)
120
  shape.extend([*feats.shape[1:]])
121
  return torch.Size(shape)
122
-
123
  def __cal_layout(self, coords, batch_size):
124
  seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
125
- offset = torch.cumsum(seq_len, dim=0)
126
- layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
 
 
 
127
  return layout
128
-
129
  @property
130
  def shape(self) -> torch.Size:
131
  return self._shape
132
-
133
  def dim(self) -> int:
134
  return len(self.shape)
135
-
136
  @property
137
  def layout(self) -> List[slice]:
138
  return self._layout
139
 
140
  @property
141
  def feats(self) -> torch.Tensor:
142
- if BACKEND == 'torchsparse':
143
  return self.data.F
144
- elif BACKEND == 'spconv':
145
  return self.data.features
146
-
147
  @feats.setter
148
  def feats(self, value: torch.Tensor):
149
- if BACKEND == 'torchsparse':
150
  self.data.F = value
151
- elif BACKEND == 'spconv':
152
  self.data.features = value
153
 
154
  @property
155
  def coords(self) -> torch.Tensor:
156
- if BACKEND == 'torchsparse':
157
  return self.data.C
158
- elif BACKEND == 'spconv':
159
  return self.data.indices
160
-
161
  @coords.setter
162
  def coords(self, value: torch.Tensor):
163
- if BACKEND == 'torchsparse':
164
  self.data.C = value
165
- elif BACKEND == 'spconv':
166
  self.data.indices = value
167
 
168
  @property
@@ -174,12 +209,16 @@ class SparseTensor:
174
  return self.feats.device
175
 
176
  @overload
177
- def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
178
 
179
  @overload
180
- def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
 
 
 
 
181
 
182
- def to(self, *args, **kwargs) -> 'SparseTensor':
183
  device = None
184
  dtype = None
185
  if len(args) == 2:
@@ -189,13 +228,13 @@ class SparseTensor:
189
  dtype = args[0]
190
  else:
191
  device = args[0]
192
- if 'dtype' in kwargs:
193
  assert dtype is None, "to() received multiple values for argument 'dtype'"
194
- dtype = kwargs['dtype']
195
- if 'device' in kwargs:
196
  assert device is None, "to() received multiple values for argument 'device'"
197
- device = kwargs['device']
198
-
199
  new_feats = self.feats.to(device=device, dtype=dtype)
200
  new_coords = self.coords.to(device=device)
201
  return self.replace(new_feats, new_coords)
@@ -204,46 +243,48 @@ class SparseTensor:
204
  new_feats = self.feats.type(dtype)
205
  return self.replace(new_feats)
206
 
207
- def cpu(self) -> 'SparseTensor':
208
  new_feats = self.feats.cpu()
209
  new_coords = self.coords.cpu()
210
  return self.replace(new_feats, new_coords)
211
-
212
- def cuda(self) -> 'SparseTensor':
213
  new_feats = self.feats.cuda()
214
  new_coords = self.coords.cuda()
215
  return self.replace(new_feats, new_coords)
216
 
217
- def half(self) -> 'SparseTensor':
218
  new_feats = self.feats.half()
219
  return self.replace(new_feats)
220
-
221
- def float(self) -> 'SparseTensor':
222
  new_feats = self.feats.float()
223
  return self.replace(new_feats)
224
-
225
- def detach(self) -> 'SparseTensor':
226
  new_coords = self.coords.detach()
227
  new_feats = self.feats.detach()
228
  return self.replace(new_feats, new_coords)
229
 
230
  def dense(self) -> torch.Tensor:
231
- if BACKEND == 'torchsparse':
232
  return self.data.dense()
233
- elif BACKEND == 'spconv':
234
  return self.data.dense()
235
 
236
- def reshape(self, *shape) -> 'SparseTensor':
237
  new_feats = self.feats.reshape(self.feats.shape[0], *shape)
238
  return self.replace(new_feats)
239
-
240
- def unbind(self, dim: int) -> List['SparseTensor']:
241
  return sparse_unbind(self, dim)
242
 
243
- def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
 
 
244
  new_shape = [self.shape[0]]
245
  new_shape.extend(feats.shape[1:])
246
- if BACKEND == 'torchsparse':
247
  new_data = SparseTensorData(
248
  feats=feats,
249
  coords=self.data.coords if coords is None else coords,
@@ -251,7 +292,7 @@ class SparseTensor:
251
  spatial_range=self.data.spatial_range,
252
  )
253
  new_data._caches = self.data._caches
254
- elif BACKEND == 'spconv':
255
  new_data = SparseTensorData(
256
  self.data.features.reshape(self.data.features.shape[0], -1),
257
  self.data.indices,
@@ -259,7 +300,7 @@ class SparseTensor:
259
  self.data.batch_size,
260
  self.data.grid,
261
  self.data.voxel_num,
262
- self.data.indice_dict
263
  )
264
  new_data._features = feats
265
  new_data.benchmark = self.data.benchmark
@@ -270,26 +311,39 @@ class SparseTensor:
270
  new_data.int8_scale = self.data.int8_scale
271
  if coords is not None:
272
  new_data.indices = coords
273
- new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
 
 
 
 
 
 
274
  return new_tensor
275
 
276
  @staticmethod
277
- def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
278
  N, C = dim
279
  x = torch.arange(aabb[0], aabb[3] + 1)
280
  y = torch.arange(aabb[1], aabb[4] + 1)
281
  z = torch.arange(aabb[2], aabb[5] + 1)
282
- coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
283
- coords = torch.cat([
284
- torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
285
- coords.repeat(N, 1),
286
- ], dim=1).to(dtype=torch.int32, device=device)
 
 
 
 
 
287
  feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
288
  return SparseTensor(feats=feats, coords=coords)
289
 
290
- def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
291
  new_cache = {}
292
- for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
 
 
293
  if k in self._spatial_cache:
294
  new_cache[k] = self._spatial_cache[k]
295
  if k in other._spatial_cache:
@@ -299,10 +353,12 @@ class SparseTensor:
299
  new_cache[k].update(other._spatial_cache[k])
300
  return new_cache
301
 
302
- def __neg__(self) -> 'SparseTensor':
303
  return self.replace(-self.feats)
304
-
305
- def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
 
 
306
  if isinstance(other, torch.Tensor):
307
  try:
308
  other = torch.broadcast_to(other, self.shape)
@@ -317,28 +373,44 @@ class SparseTensor:
317
  new_tensor._spatial_cache = self.__merge_sparse_cache(other)
318
  return new_tensor
319
 
320
- def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
321
  return self.__elemwise__(other, torch.add)
322
 
323
- def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
324
  return self.__elemwise__(other, torch.add)
325
-
326
- def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
327
  return self.__elemwise__(other, torch.sub)
328
-
329
- def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
330
  return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
331
 
332
- def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
333
  return self.__elemwise__(other, torch.mul)
334
 
335
- def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
336
  return self.__elemwise__(other, torch.mul)
337
 
338
- def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
339
  return self.__elemwise__(other, torch.div)
340
 
341
- def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
 
 
342
  return self.__elemwise__(other, lambda x, y: torch.div(y, x))
343
 
344
  def __getitem__(self, idx):
@@ -348,7 +420,9 @@ class SparseTensor:
348
  idx = range(*idx.indices(self.shape[0]))
349
  elif isinstance(idx, torch.Tensor):
350
  if idx.dtype == torch.bool:
351
- assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
 
 
352
  idx = idx.nonzero().squeeze(1)
353
  elif idx.dtype in [torch.int32, torch.int64]:
354
  assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
@@ -356,7 +430,7 @@ class SparseTensor:
356
  raise ValueError(f"Unknown index type: {idx.dtype}")
357
  else:
358
  raise ValueError(f"Unknown index type: {type(idx)}")
359
-
360
  coords = []
361
  feats = []
362
  for new_idx, old_idx in enumerate(idx):
@@ -392,7 +466,7 @@ class SparseTensor:
392
  def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
393
  """
394
  Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
395
-
396
  Args:
397
  input (torch.Tensor): 1D tensor to broadcast.
398
  target (SparseTensor): Sparse tensor to broadcast to.
@@ -405,10 +479,12 @@ def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Te
405
  return broadcasted
406
 
407
 
408
- def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
 
 
409
  """
410
  Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
411
-
412
  Args:
413
  input (torch.Tensor): 1D tensor to broadcast.
414
  target (SparseTensor): Sparse tensor to broadcast to.
@@ -420,7 +496,7 @@ def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = tor
420
  def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
421
  """
422
  Concatenate a list of sparse tensors.
423
-
424
  Args:
425
  inputs (List[SparseTensor]): List of sparse tensors to concatenate.
426
  """
@@ -447,7 +523,7 @@ def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
447
  def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
448
  """
449
  Unbind a sparse tensor along a dimension.
450
-
451
  Args:
452
  input (SparseTensor): Sparse tensor to unbind.
453
  dim (int): Dimension to unbind.
 
2
  import torch
3
  import torch.nn as nn
4
  from . import BACKEND, DEBUG
5
+
6
+ SparseTensorData = None # Lazy import
7
 
8
 
9
  __all__ = [
10
+ "SparseTensor",
11
+ "sparse_batch_broadcast",
12
+ "sparse_batch_op",
13
+ "sparse_cat",
14
+ "sparse_unbind",
15
  ]
16
 
17
 
18
  class SparseTensor:
19
  """
20
  Sparse tensor with support for both torchsparse and spconv backends.
21
+
22
  Parameters:
23
  - feats (torch.Tensor): Features of the sparse tensor.
24
  - coords (torch.Tensor): Coordinates of the sparse tensor.
 
30
  - Data corresponding to a same batch should be contiguous.
31
  - Coords should be in [0, 1023]
32
  """
33
+
34
  @overload
35
+ def __init__(
36
+ self,
37
+ feats: torch.Tensor,
38
+ coords: torch.Tensor,
39
+ shape: Optional[torch.Size] = None,
40
+ layout: Optional[List[slice]] = None,
41
+ **kwargs,
42
+ ): ...
43
 
44
  @overload
45
+ def __init__(
46
+ self,
47
+ data,
48
+ shape: Optional[torch.Size] = None,
49
+ layout: Optional[List[slice]] = None,
50
+ **kwargs,
51
+ ): ...
52
 
53
  def __init__(self, *args, **kwargs):
54
  # Lazy import of sparse tensor backend
55
  global SparseTensorData
56
  if SparseTensorData is None:
57
  import importlib
58
+
59
+ if BACKEND == "torchsparse":
60
+ SparseTensorData = importlib.import_module("torchsparse").SparseTensor
61
+ elif BACKEND == "spconv":
62
+ SparseTensorData = importlib.import_module(
63
+ "spconv.pytorch"
64
+ ).SparseConvTensor
65
+
66
  method_id = 0
67
  if len(args) != 0:
68
  method_id = 0 if isinstance(args[0], torch.Tensor) else 1
69
  else:
70
+ method_id = 1 if "data" in kwargs else 0
71
 
72
  if method_id == 0:
73
  feats, coords, shape, layout = args + (None,) * (4 - len(args))
74
+ if "feats" in kwargs:
75
+ feats = kwargs["feats"]
76
+ del kwargs["feats"]
77
+ if "coords" in kwargs:
78
+ coords = kwargs["coords"]
79
+ del kwargs["coords"]
80
+ if "shape" in kwargs:
81
+ shape = kwargs["shape"]
82
+ del kwargs["shape"]
83
+ if "layout" in kwargs:
84
+ layout = kwargs["layout"]
85
+ del kwargs["layout"]
86
 
87
  if shape is None:
88
  shape = self.__cal_shape(feats, coords)
89
  if layout is None:
90
  layout = self.__cal_layout(coords, shape[0])
91
+ if BACKEND == "torchsparse":
92
  self.data = SparseTensorData(feats, coords, **kwargs)
93
+ elif BACKEND == "spconv":
94
  spatial_shape = list(coords.max(0)[0] + 1)[1:]
95
+ self.data = SparseTensorData(
96
+ feats.reshape(feats.shape[0], -1),
97
+ coords,
98
+ spatial_shape,
99
+ shape[0],
100
+ **kwargs,
101
+ )
102
  self.data._features = feats
103
  elif method_id == 1:
104
  data, shape, layout = args + (None,) * (3 - len(args))
105
+ if "data" in kwargs:
106
+ data = kwargs["data"]
107
+ del kwargs["data"]
108
+ if "shape" in kwargs:
109
+ shape = kwargs["shape"]
110
+ del kwargs["shape"]
111
+ if "layout" in kwargs:
112
+ layout = kwargs["layout"]
113
+ del kwargs["layout"]
114
 
115
  self.data = data
116
  if shape is None:
 
120
 
121
  self._shape = shape
122
  self._layout = layout
123
+ self._scale = kwargs.get("scale", (1, 1, 1))
124
+ self._spatial_cache = kwargs.get("spatial_cache", {})
125
 
126
  if DEBUG:
127
  try:
128
+ assert (
129
+ self.feats.shape[0] == self.coords.shape[0]
130
+ ), f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
131
+ assert self.shape == self.__cal_shape(
132
+ self.feats, self.coords
133
+ ), f"Invalid shape: {self.shape}"
134
+ assert self.layout == self.__cal_layout(
135
+ self.coords, self.shape[0]
136
+ ), f"Invalid layout: {self.layout}"
137
  for i in range(self.shape[0]):
138
+ assert torch.all(
139
+ self.coords[self.layout[i], 0] == i
140
+ ), f"The data of batch {i} is not contiguous"
141
  except Exception as e:
142
+ print("Debugging information:")
143
  print(f"- Shape: {self.shape}")
144
  print(f"- Layout: {self.layout}")
145
  print(f"- Scale: {self._scale}")
146
  print(f"- Coords: {self.coords}")
147
  raise e
148
+
149
  def __cal_shape(self, feats, coords):
150
  shape = []
151
  shape.append(coords[:, 0].max().item() + 1)
152
  shape.extend([*feats.shape[1:]])
153
  return torch.Size(shape)
154
+
155
  def __cal_layout(self, coords, batch_size):
156
  seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
157
+ offset = torch.cumsum(seq_len, dim=0)
158
+ layout = [
159
+ slice((offset[i] - seq_len[i]).item(), offset[i].item())
160
+ for i in range(batch_size)
161
+ ]
162
  return layout
163
+
164
  @property
165
  def shape(self) -> torch.Size:
166
  return self._shape
167
+
168
  def dim(self) -> int:
169
  return len(self.shape)
170
+
171
  @property
172
  def layout(self) -> List[slice]:
173
  return self._layout
174
 
175
  @property
176
  def feats(self) -> torch.Tensor:
177
+ if BACKEND == "torchsparse":
178
  return self.data.F
179
+ elif BACKEND == "spconv":
180
  return self.data.features
181
+
182
  @feats.setter
183
  def feats(self, value: torch.Tensor):
184
+ if BACKEND == "torchsparse":
185
  self.data.F = value
186
+ elif BACKEND == "spconv":
187
  self.data.features = value
188
 
189
  @property
190
  def coords(self) -> torch.Tensor:
191
+ if BACKEND == "torchsparse":
192
  return self.data.C
193
+ elif BACKEND == "spconv":
194
  return self.data.indices
195
+
196
  @coords.setter
197
  def coords(self, value: torch.Tensor):
198
+ if BACKEND == "torchsparse":
199
  self.data.C = value
200
+ elif BACKEND == "spconv":
201
  self.data.indices = value
202
 
203
  @property
 
209
  return self.feats.device
210
 
211
  @overload
212
+ def to(self, dtype: torch.dtype) -> "SparseTensor": ...
213
 
214
  @overload
215
+ def to(
216
+ self,
217
+ device: Optional[Union[str, torch.device]] = None,
218
+ dtype: Optional[torch.dtype] = None,
219
+ ) -> "SparseTensor": ...
220
 
221
+ def to(self, *args, **kwargs) -> "SparseTensor":
222
  device = None
223
  dtype = None
224
  if len(args) == 2:
 
228
  dtype = args[0]
229
  else:
230
  device = args[0]
231
+ if "dtype" in kwargs:
232
  assert dtype is None, "to() received multiple values for argument 'dtype'"
233
+ dtype = kwargs["dtype"]
234
+ if "device" in kwargs:
235
  assert device is None, "to() received multiple values for argument 'device'"
236
+ device = kwargs["device"]
237
+
238
  new_feats = self.feats.to(device=device, dtype=dtype)
239
  new_coords = self.coords.to(device=device)
240
  return self.replace(new_feats, new_coords)
 
243
  new_feats = self.feats.type(dtype)
244
  return self.replace(new_feats)
245
 
246
+ def cpu(self) -> "SparseTensor":
247
  new_feats = self.feats.cpu()
248
  new_coords = self.coords.cpu()
249
  return self.replace(new_feats, new_coords)
250
+
251
+ def cuda(self) -> "SparseTensor":
252
  new_feats = self.feats.cuda()
253
  new_coords = self.coords.cuda()
254
  return self.replace(new_feats, new_coords)
255
 
256
+ def half(self) -> "SparseTensor":
257
  new_feats = self.feats.half()
258
  return self.replace(new_feats)
259
+
260
+ def float(self) -> "SparseTensor":
261
  new_feats = self.feats.float()
262
  return self.replace(new_feats)
263
+
264
+ def detach(self) -> "SparseTensor":
265
  new_coords = self.coords.detach()
266
  new_feats = self.feats.detach()
267
  return self.replace(new_feats, new_coords)
268
 
269
  def dense(self) -> torch.Tensor:
270
+ if BACKEND == "torchsparse":
271
  return self.data.dense()
272
+ elif BACKEND == "spconv":
273
  return self.data.dense()
274
 
275
+ def reshape(self, *shape) -> "SparseTensor":
276
  new_feats = self.feats.reshape(self.feats.shape[0], *shape)
277
  return self.replace(new_feats)
278
+
279
+ def unbind(self, dim: int) -> List["SparseTensor"]:
280
  return sparse_unbind(self, dim)
281
 
282
+ def replace(
283
+ self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None
284
+ ) -> "SparseTensor":
285
  new_shape = [self.shape[0]]
286
  new_shape.extend(feats.shape[1:])
287
+ if BACKEND == "torchsparse":
288
  new_data = SparseTensorData(
289
  feats=feats,
290
  coords=self.data.coords if coords is None else coords,
 
292
  spatial_range=self.data.spatial_range,
293
  )
294
  new_data._caches = self.data._caches
295
+ elif BACKEND == "spconv":
296
  new_data = SparseTensorData(
297
  self.data.features.reshape(self.data.features.shape[0], -1),
298
  self.data.indices,
 
300
  self.data.batch_size,
301
  self.data.grid,
302
  self.data.voxel_num,
303
+ self.data.indice_dict,
304
  )
305
  new_data._features = feats
306
  new_data.benchmark = self.data.benchmark
 
311
  new_data.int8_scale = self.data.int8_scale
312
  if coords is not None:
313
  new_data.indices = coords
314
+ new_tensor = SparseTensor(
315
+ new_data,
316
+ shape=torch.Size(new_shape),
317
+ layout=self.layout,
318
+ scale=self._scale,
319
+ spatial_cache=self._spatial_cache,
320
+ )
321
  return new_tensor
322
 
323
  @staticmethod
324
+ def full(aabb, dim, value, dtype=torch.float32, device=None) -> "SparseTensor":
325
  N, C = dim
326
  x = torch.arange(aabb[0], aabb[3] + 1)
327
  y = torch.arange(aabb[1], aabb[4] + 1)
328
  z = torch.arange(aabb[2], aabb[5] + 1)
329
+ coords = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1).reshape(
330
+ -1, 3
331
+ )
332
+ coords = torch.cat(
333
+ [
334
+ torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
335
+ coords.repeat(N, 1),
336
+ ],
337
+ dim=1,
338
+ ).to(dtype=torch.int32, device=device)
339
  feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
340
  return SparseTensor(feats=feats, coords=coords)
341
 
342
+ def __merge_sparse_cache(self, other: "SparseTensor") -> dict:
343
  new_cache = {}
344
+ for k in set(
345
+ list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())
346
+ ):
347
  if k in self._spatial_cache:
348
  new_cache[k] = self._spatial_cache[k]
349
  if k in other._spatial_cache:
 
353
  new_cache[k].update(other._spatial_cache[k])
354
  return new_cache
355
 
356
+ def __neg__(self) -> "SparseTensor":
357
  return self.replace(-self.feats)
358
+
359
+ def __elemwise__(
360
+ self, other: Union[torch.Tensor, "SparseTensor"], op: callable
361
+ ) -> "SparseTensor":
362
  if isinstance(other, torch.Tensor):
363
  try:
364
  other = torch.broadcast_to(other, self.shape)
 
373
  new_tensor._spatial_cache = self.__merge_sparse_cache(other)
374
  return new_tensor
375
 
376
+ def __add__(
377
+ self, other: Union[torch.Tensor, "SparseTensor", float]
378
+ ) -> "SparseTensor":
379
  return self.__elemwise__(other, torch.add)
380
 
381
+ def __radd__(
382
+ self, other: Union[torch.Tensor, "SparseTensor", float]
383
+ ) -> "SparseTensor":
384
  return self.__elemwise__(other, torch.add)
385
+
386
+ def __sub__(
387
+ self, other: Union[torch.Tensor, "SparseTensor", float]
388
+ ) -> "SparseTensor":
389
  return self.__elemwise__(other, torch.sub)
390
+
391
+ def __rsub__(
392
+ self, other: Union[torch.Tensor, "SparseTensor", float]
393
+ ) -> "SparseTensor":
394
  return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
395
 
396
+ def __mul__(
397
+ self, other: Union[torch.Tensor, "SparseTensor", float]
398
+ ) -> "SparseTensor":
399
  return self.__elemwise__(other, torch.mul)
400
 
401
+ def __rmul__(
402
+ self, other: Union[torch.Tensor, "SparseTensor", float]
403
+ ) -> "SparseTensor":
404
  return self.__elemwise__(other, torch.mul)
405
 
406
+ def __truediv__(
407
+ self, other: Union[torch.Tensor, "SparseTensor", float]
408
+ ) -> "SparseTensor":
409
  return self.__elemwise__(other, torch.div)
410
 
411
+ def __rtruediv__(
412
+ self, other: Union[torch.Tensor, "SparseTensor", float]
413
+ ) -> "SparseTensor":
414
  return self.__elemwise__(other, lambda x, y: torch.div(y, x))
415
 
416
  def __getitem__(self, idx):
 
420
  idx = range(*idx.indices(self.shape[0]))
421
  elif isinstance(idx, torch.Tensor):
422
  if idx.dtype == torch.bool:
423
+ assert idx.shape == (
424
+ self.shape[0],
425
+ ), f"Invalid index shape: {idx.shape}"
426
  idx = idx.nonzero().squeeze(1)
427
  elif idx.dtype in [torch.int32, torch.int64]:
428
  assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
 
430
  raise ValueError(f"Unknown index type: {idx.dtype}")
431
  else:
432
  raise ValueError(f"Unknown index type: {type(idx)}")
433
+
434
  coords = []
435
  feats = []
436
  for new_idx, old_idx in enumerate(idx):
 
466
  def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
467
  """
468
  Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
469
+
470
  Args:
471
  input (torch.Tensor): 1D tensor to broadcast.
472
  target (SparseTensor): Sparse tensor to broadcast to.
 
479
  return broadcasted
480
 
481
 
482
+ def sparse_batch_op(
483
+ input: SparseTensor, other: torch.Tensor, op: callable = torch.add
484
+ ) -> SparseTensor:
485
  """
486
  Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
487
+
488
  Args:
489
  input (torch.Tensor): 1D tensor to broadcast.
490
  target (SparseTensor): Sparse tensor to broadcast to.
 
496
  def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
497
  """
498
  Concatenate a list of sparse tensors.
499
+
500
  Args:
501
  inputs (List[SparseTensor]): List of sparse tensors to concatenate.
502
  """
 
523
  def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
524
  """
525
  Unbind a sparse tensor along a dimension.
526
+
527
  Args:
528
  input (SparseTensor): Sparse tensor to unbind.
529
  dim (int): Dimension to unbind.
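
A minimal usage sketch of the SparseTensor API reformatted above, assuming the trellis package is importable and one of the two backends (torchsparse or spconv) is installed. Coordinates follow the docstring's contract: int32 rows of [batch, x, y, z], with rows of the same batch contiguous.

import torch
from trellis.modules import sparse as sp

# Two batch items, three voxels total; rows of the same batch are contiguous.
coords = torch.tensor(
    [[0, 0, 0, 0],
     [0, 1, 2, 3],
     [1, 4, 4, 4]],
    dtype=torch.int32,
)
feats = torch.randn(3, 8)  # one 8-channel feature vector per voxel

x = sp.SparseTensor(feats=feats, coords=coords)
print(x.shape)   # torch.Size([2, 8]) -> (batch_size, channels)
print(x.layout)  # [slice(0, 2), slice(2, 3)] -> per-batch row ranges

y = x * 2.0 + 1.0             # elementwise ops return new SparseTensors
first = x[0:1]                # slicing selects batch items
both = sp.sparse_cat([x, y])  # concatenate along the batch dimension

The layout slices are what the group-norm and down/upsample layers later in this diff use to address per-batch rows.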
trellis/modules/sparse/conv/__init__.py CHANGED
@@ -1,21 +1,26 @@
1
  from .. import BACKEND
2
 
3
 
4
- SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
 
5
 
6
  def __from_env():
7
  import os
8
-
9
  global SPCONV_ALGO
10
- env_spconv_algo = os.environ.get('SPCONV_ALGO')
11
- if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
 
 
 
 
12
  SPCONV_ALGO = env_spconv_algo
13
  print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
14
-
15
 
16
  __from_env()
17
 
18
- if BACKEND == 'torchsparse':
19
  from .conv_torchsparse import *
20
- elif BACKEND == 'spconv':
21
  from .conv_spconv import *
 
1
  from .. import BACKEND
2
 
3
 
4
+ SPCONV_ALGO = "auto" # 'auto', 'implicit_gemm', 'native'
5
+
6
 
7
  def __from_env():
8
  import os
9
+
10
  global SPCONV_ALGO
11
+ env_spconv_algo = os.environ.get("SPCONV_ALGO")
12
+ if env_spconv_algo is not None and env_spconv_algo in [
13
+ "auto",
14
+ "implicit_gemm",
15
+ "native",
16
+ ]:
17
  SPCONV_ALGO = env_spconv_algo
18
  print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
19
+
20
 
21
  __from_env()
22
 
23
+ if BACKEND == "torchsparse":
24
  from .conv_torchsparse import *
25
+ elif BACKEND == "spconv":
26
  from .conv_spconv import *
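
Since __from_env() reads the SPCONV_ALGO environment variable only once at import time, it has to be set before this package is first imported. A small sketch, assuming a sparse backend is installed so the import succeeds (allowed values are the ones in the allow-list above):

import os

os.environ["SPCONV_ALGO"] = "native"  # one of "auto", "implicit_gemm", "native"

from trellis.modules.sparse import conv as sp_conv  # triggers __from_env()
print(sp_conv.SPCONV_ALGO)  # -> "native"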
trellis/modules/sparse/conv/conv_spconv.py CHANGED
@@ -4,21 +4,54 @@ from .. import SparseTensor
4
  from .. import DEBUG
5
  from . import SPCONV_ALGO
6
 
 
7
  class SparseConv3d(nn.Module):
8
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
 
9
  super(SparseConv3d, self).__init__()
10
- if 'spconv' not in globals():
11
  import spconv.pytorch as spconv
12
  algo = None
13
- if SPCONV_ALGO == 'native':
14
  algo = spconv.ConvAlgo.Native
15
- elif SPCONV_ALGO == 'implicit_gemm':
16
  algo = spconv.ConvAlgo.MaskImplicitGemm
17
  if stride == 1 and (padding is None):
18
- self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
 
 
 
 
 
 
 
 
19
  else:
20
- self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
21
- self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  self.padding = padding
23
 
24
  def forward(self, x: SparseTensor) -> SparseTensor:
@@ -30,42 +63,65 @@ class SparseConv3d(nn.Module):
30
  if spatial_changed and (x.shape[0] != 1):
31
  # spconv with non-1 stride breaks the contiguity of the output tensor; sort by the coords
32
  fwd = new_data.indices[:, 0].argsort()
33
- bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
 
 
34
  sorted_feats = new_data.features[fwd]
35
  sorted_coords = new_data.indices[fwd]
36
  unsorted_data = new_data
37
  new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
38
 
39
  out = SparseTensor(
40
- new_data, shape=torch.Size(new_shape), layout=new_layout,
 
 
41
  scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
42
  spatial_cache=x._spatial_cache,
43
  )
44
 
45
  if spatial_changed and (x.shape[0] != 1):
46
- out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
47
- out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
48
-
 
 
49
  return out
50
 
51
 
52
  class SparseInverseConv3d(nn.Module):
53
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
54
  super(SparseInverseConv3d, self).__init__()
55
- if 'spconv' not in globals():
56
  import spconv.pytorch as spconv
57
- self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
58
- self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
 
 
 
 
 
 
59
 
60
  def forward(self, x: SparseTensor) -> SparseTensor:
61
  spatial_changed = any(s != 1 for s in self.stride)
62
  if spatial_changed:
63
  # recover the original spconv order
64
- data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
65
- bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
66
  data = data.replace_feature(x.feats[bwd])
67
  if DEBUG:
68
- assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
 
 
69
  else:
70
  data = x.data
71
 
@@ -73,7 +129,9 @@ class SparseInverseConv3d(nn.Module):
73
  new_shape = [x.shape[0], self.conv.out_channels]
74
  new_layout = None if spatial_changed else x.layout
75
  out = SparseTensor(
76
- new_data, shape=torch.Size(new_shape), layout=new_layout,
 
 
77
  scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
78
  spatial_cache=x._spatial_cache,
79
  )
 
4
  from .. import DEBUG
5
  from . import SPCONV_ALGO
6
 
7
+
8
  class SparseConv3d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_channels,
12
+ out_channels,
13
+ kernel_size,
14
+ stride=1,
15
+ dilation=1,
16
+ padding=None,
17
+ bias=True,
18
+ indice_key=None,
19
+ ):
20
  super(SparseConv3d, self).__init__()
21
+ if "spconv" not in globals():
22
  import spconv.pytorch as spconv
23
  algo = None
24
+ if SPCONV_ALGO == "native":
25
  algo = spconv.ConvAlgo.Native
26
+ elif SPCONV_ALGO == "implicit_gemm":
27
  algo = spconv.ConvAlgo.MaskImplicitGemm
28
  if stride == 1 and (padding is None):
29
+ self.conv = spconv.SubMConv3d(
30
+ in_channels,
31
+ out_channels,
32
+ kernel_size,
33
+ dilation=dilation,
34
+ bias=bias,
35
+ indice_key=indice_key,
36
+ algo=algo,
37
+ )
38
  else:
39
+ self.conv = spconv.SparseConv3d(
40
+ in_channels,
41
+ out_channels,
42
+ kernel_size,
43
+ stride=stride,
44
+ dilation=dilation,
45
+ padding=padding,
46
+ bias=bias,
47
+ indice_key=indice_key,
48
+ algo=algo,
49
+ )
50
+ self.stride = (
51
+ tuple(stride)
52
+ if isinstance(stride, (list, tuple))
53
+ else (stride, stride, stride)
54
+ )
55
  self.padding = padding
56
 
57
  def forward(self, x: SparseTensor) -> SparseTensor:
 
63
  if spatial_changed and (x.shape[0] != 1):
64
  # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords
65
  fwd = new_data.indices[:, 0].argsort()
66
+ bwd = torch.zeros_like(fwd).scatter_(
67
+ 0, fwd, torch.arange(fwd.shape[0], device=fwd.device)
68
+ )
69
  sorted_feats = new_data.features[fwd]
70
  sorted_coords = new_data.indices[fwd]
71
  unsorted_data = new_data
72
  new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
73
 
74
  out = SparseTensor(
75
+ new_data,
76
+ shape=torch.Size(new_shape),
77
+ layout=new_layout,
78
  scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
79
  spatial_cache=x._spatial_cache,
80
  )
81
 
82
  if spatial_changed and (x.shape[0] != 1):
83
+ out.register_spatial_cache(
84
+ f"conv_{self.stride}_unsorted_data", unsorted_data
85
+ )
86
+ out.register_spatial_cache(f"conv_{self.stride}_sort_bwd", bwd)
87
+
88
  return out
89
 
90
 
91
  class SparseInverseConv3d(nn.Module):
92
+ def __init__(
93
+ self,
94
+ in_channels,
95
+ out_channels,
96
+ kernel_size,
97
+ stride=1,
98
+ dilation=1,
99
+ bias=True,
100
+ indice_key=None,
101
+ ):
102
  super(SparseInverseConv3d, self).__init__()
103
+ if "spconv" not in globals():
104
  import spconv.pytorch as spconv
105
+ self.conv = spconv.SparseInverseConv3d(
106
+ in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key
107
+ )
108
+ self.stride = (
109
+ tuple(stride)
110
+ if isinstance(stride, (list, tuple))
111
+ else (stride, stride, stride)
112
+ )
113
 
114
  def forward(self, x: SparseTensor) -> SparseTensor:
115
  spatial_changed = any(s != 1 for s in self.stride)
116
  if spatial_changed:
117
  # recover the original spconv order
118
+ data = x.get_spatial_cache(f"conv_{self.stride}_unsorted_data")
119
+ bwd = x.get_spatial_cache(f"conv_{self.stride}_sort_bwd")
120
  data = data.replace_feature(x.feats[bwd])
121
  if DEBUG:
122
+ assert torch.equal(
123
+ data.indices, x.coords[bwd]
124
+ ), "Recover the original order failed"
125
  else:
126
  data = x.data
127
 
 
129
  new_shape = [x.shape[0], self.conv.out_channels]
130
  new_layout = None if spatial_changed else x.layout
131
  out = SparseTensor(
132
+ new_data,
133
+ shape=torch.Size(new_shape),
134
+ layout=new_layout,
135
  scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
136
  spatial_cache=x._spatial_cache,
137
  )
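
A hedged sketch of how the two layers above are meant to be paired under the spconv backend: the strided SparseConv3d registers the pre-sort order in the spatial cache (only on the path where batch size > 1), and a SparseInverseConv3d with the same stride and indice_key restores it. The module name, channel sizes, and indice_key below are illustrative, and SparseConv3d is assumed to be re-exported from trellis.modules.sparse as elsewhere in this repo.

import torch.nn as nn
from trellis.modules import sparse as sp

class DownUp(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        # The shared indice_key lets the inverse conv reuse the indices built
        # by the strided conv; the spatial cache restores the row order.
        self.down = sp.SparseConv3d(channels, channels, 3, stride=2, padding=1,
                                    indice_key="down1")
        self.up = sp.SparseInverseConv3d(channels, channels, 3, stride=2,
                                         indice_key="down1")

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        # Expects a batched SparseTensor (batch size > 1 on this spconv path).
        return self.up(self.down(x))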
trellis/modules/sparse/conv/conv_torchsparse.py CHANGED
@@ -4,35 +4,73 @@ from .. import SparseTensor
4
 
5
 
6
  class SparseConv3d(nn.Module):
7
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
8
  super(SparseConv3d, self).__init__()
9
- if 'torchsparse' not in globals():
10
  import torchsparse
11
- self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
 
 
12
 
13
  def forward(self, x: SparseTensor) -> SparseTensor:
14
  out = self.conv(x.data)
15
  new_shape = [x.shape[0], self.conv.out_channels]
16
- out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
 
 
 
 
17
  out._spatial_cache = x._spatial_cache
18
- out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
 
 
19
  return out
20
 
21
 
22
  class SparseInverseConv3d(nn.Module):
23
- def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
 
 
 
 
 
 
 
 
 
24
  super(SparseInverseConv3d, self).__init__()
25
- if 'torchsparse' not in globals():
26
  import torchsparse
27
- self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
 
 
 
 
 
 
 
 
 
28
 
29
  def forward(self, x: SparseTensor) -> SparseTensor:
30
- out = self.conv(x.data)
31
  new_shape = [x.shape[0], self.conv.out_channels]
32
- out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
 
 
 
 
33
  out._spatial_cache = x._spatial_cache
34
- out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
 
 
35
  return out
36
-
37
-
38
-
 
4
 
5
 
6
  class SparseConv3d(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_channels,
10
+ out_channels,
11
+ kernel_size,
12
+ stride=1,
13
+ dilation=1,
14
+ bias=True,
15
+ indice_key=None,
16
+ ):
17
  super(SparseConv3d, self).__init__()
18
+ if "torchsparse" not in globals():
19
  import torchsparse
20
+ self.conv = torchsparse.nn.Conv3d(
21
+ in_channels, out_channels, kernel_size, stride, 0, dilation, bias
22
+ )
23
 
24
  def forward(self, x: SparseTensor) -> SparseTensor:
25
  out = self.conv(x.data)
26
  new_shape = [x.shape[0], self.conv.out_channels]
27
+ out = SparseTensor(
28
+ out,
29
+ shape=torch.Size(new_shape),
30
+ layout=x.layout if all(s == 1 for s in self.conv.stride) else None,
31
+ )
32
  out._spatial_cache = x._spatial_cache
33
+ out._scale = tuple(
34
+ [s * stride for s, stride in zip(x._scale, self.conv.stride)]
35
+ )
36
  return out
37
 
38
 
39
  class SparseInverseConv3d(nn.Module):
40
+ def __init__(
41
+ self,
42
+ in_channels,
43
+ out_channels,
44
+ kernel_size,
45
+ stride=1,
46
+ dilation=1,
47
+ bias=True,
48
+ indice_key=None,
49
+ ):
50
  super(SparseInverseConv3d, self).__init__()
51
+ if "torchsparse" not in globals():
52
  import torchsparse
53
+ self.conv = torchsparse.nn.Conv3d(
54
+ in_channels,
55
+ out_channels,
56
+ kernel_size,
57
+ stride,
58
+ 0,
59
+ dilation,
60
+ bias,
61
+ transposed=True,
62
+ )
63
 
64
  def forward(self, x: SparseTensor) -> SparseTensor:
65
+ out = self.conv(x.data)
66
  new_shape = [x.shape[0], self.conv.out_channels]
67
+ out = SparseTensor(
68
+ out,
69
+ shape=torch.Size(new_shape),
70
+ layout=x.layout if all(s == 1 for s in self.conv.stride) else None,
71
+ )
72
  out._spatial_cache = x._spatial_cache
73
+ out._scale = tuple(
74
+ [s // stride for s, stride in zip(x._scale, self.conv.stride)]
75
+ )
76
  return out
 
 
 
trellis/modules/sparse/linear.py CHANGED
@@ -2,9 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
- __all__ = [
6
- 'SparseLinear'
7
- ]
8
 
9
 
10
  class SparseLinear(nn.Linear):
 
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
+ __all__ = ["SparseLinear"]
 
 
6
 
7
 
8
  class SparseLinear(nn.Linear):
trellis/modules/sparse/nonlinearity.py CHANGED
@@ -2,18 +2,13 @@ import torch
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
- __all__ = [
6
- 'SparseReLU',
7
- 'SparseSiLU',
8
- 'SparseGELU',
9
- 'SparseActivation'
10
- ]
11
 
12
 
13
  class SparseReLU(nn.ReLU):
14
  def forward(self, input: SparseTensor) -> SparseTensor:
15
  return input.replace(super().forward(input.feats))
16
-
17
 
18
  class SparseSiLU(nn.SiLU):
19
  def forward(self, input: SparseTensor) -> SparseTensor:
@@ -32,4 +27,3 @@ class SparseActivation(nn.Module):
32
 
33
  def forward(self, input: SparseTensor) -> SparseTensor:
34
  return input.replace(self.activation(input.feats))
35
-
 
2
  import torch.nn as nn
3
  from . import SparseTensor
4
 
5
+ __all__ = ["SparseReLU", "SparseSiLU", "SparseGELU", "SparseActivation"]
 
 
 
 
 
6
 
7
 
8
  class SparseReLU(nn.ReLU):
9
  def forward(self, input: SparseTensor) -> SparseTensor:
10
  return input.replace(super().forward(input.feats))
11
+
12
 
13
  class SparseSiLU(nn.SiLU):
14
  def forward(self, input: SparseTensor) -> SparseTensor:
 
27
 
28
  def forward(self, input: SparseTensor) -> SparseTensor:
29
  return input.replace(self.activation(input.feats))
 
trellis/modules/sparse/norm.py CHANGED
@@ -4,10 +4,10 @@ from . import SparseTensor
4
  from . import DEBUG
5
 
6
  __all__ = [
7
- 'SparseGroupNorm',
8
- 'SparseLayerNorm',
9
- 'SparseGroupNorm32',
10
- 'SparseLayerNorm32',
11
  ]
12
 
13
 
@@ -19,7 +19,9 @@ class SparseGroupNorm(nn.GroupNorm):
19
  nfeats = torch.zeros_like(input.feats)
20
  for k in range(input.shape[0]):
21
  if DEBUG:
22
- assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
 
 
23
  bfeats = input.feats[input.layout[k]]
24
  bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
25
  bfeats = super().forward(bfeats)
@@ -47,12 +49,15 @@ class SparseGroupNorm32(SparseGroupNorm):
47
  """
48
  A GroupNorm layer that converts to float32 before the forward pass.
49
  """
 
50
  def forward(self, x: SparseTensor) -> SparseTensor:
51
  return super().forward(x.float()).type(x.dtype)
52
 
 
53
  class SparseLayerNorm32(SparseLayerNorm):
54
  """
55
  A LayerNorm layer that converts to float32 before the forward pass.
56
  """
 
57
  def forward(self, x: SparseTensor) -> SparseTensor:
58
  return super().forward(x.float()).type(x.dtype)
 
4
  from . import DEBUG
5
 
6
  __all__ = [
7
+ "SparseGroupNorm",
8
+ "SparseLayerNorm",
9
+ "SparseGroupNorm32",
10
+ "SparseLayerNorm32",
11
  ]
12
 
13
 
 
19
  nfeats = torch.zeros_like(input.feats)
20
  for k in range(input.shape[0]):
21
  if DEBUG:
22
+ assert (
23
+ input.coords[input.layout[k], 0] == k
24
+ ).all(), f"SparseGroupNorm: batch index mismatch"
25
  bfeats = input.feats[input.layout[k]]
26
  bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
27
  bfeats = super().forward(bfeats)
 
49
  """
50
  A GroupNorm layer that converts to float32 before the forward pass.
51
  """
52
+
53
  def forward(self, x: SparseTensor) -> SparseTensor:
54
  return super().forward(x.float()).type(x.dtype)
55
 
56
+
57
  class SparseLayerNorm32(SparseLayerNorm):
58
  """
59
  A LayerNorm layer that converts to float32 before the forward pass.
60
  """
61
+
62
  def forward(self, x: SparseTensor) -> SparseTensor:
63
  return super().forward(x.float()).type(x.dtype)
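
A pure-torch illustration of what the per-batch loop in SparseGroupNorm above computes: each batch item's rows, addressed via the layout slices, are normalized independently with an ordinary GroupNorm. The shapes and layout values here are made up for the example.

import torch
import torch.nn as nn

feats = torch.randn(5, 8)            # 5 voxels, 8 channels, 2 batch items
layout = [slice(0, 3), slice(3, 5)]  # per-batch row ranges (as in SparseTensor.layout)
gn = nn.GroupNorm(num_groups=4, num_channels=8)

nfeats = torch.zeros_like(feats)
for k, sl in enumerate(layout):
    bfeats = feats[sl].permute(1, 0).reshape(1, 8, -1)   # (1, C, N_k)
    nfeats[sl] = gn(bfeats).reshape(8, -1).permute(1, 0)  # back to (N_k, C)
print(nfeats.shape)  # torch.Size([5, 8])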
trellis/modules/sparse/spatial.py CHANGED
@@ -3,11 +3,7 @@ import torch
3
  import torch.nn as nn
4
  from . import SparseTensor
5
 
6
- __all__ = [
7
- 'SparseDownsample',
8
- 'SparseUpsample',
9
- 'SparseSubdivide'
10
- ]
11
 
12
 
13
  class SparseDownsample(nn.Module):
@@ -15,6 +11,7 @@ class SparseDownsample(nn.Module):
15
  Downsample a sparse tensor by a factor of `factor`.
16
  Implemented as average pooling.
17
  """
 
18
  def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
19
  super(SparseDownsample, self).__init__()
20
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
@@ -22,36 +19,47 @@ class SparseDownsample(nn.Module):
22
  def forward(self, input: SparseTensor) -> SparseTensor:
23
  DIM = input.coords.shape[-1] - 1
24
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
25
- assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
 
 
26
 
27
  coord = list(input.coords.unbind(dim=-1))
28
  for i, f in enumerate(factor):
29
- coord[i+1] = coord[i+1] // f
30
 
31
- MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
32
  OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
33
  code = sum([c * o for c, o in zip(coord, OFFSET)])
34
  code, idx = code.unique(return_inverse=True)
35
 
36
  new_feats = torch.scatter_reduce(
37
- torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
 
 
 
 
 
38
  dim=0,
39
  index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
40
  src=input.feats,
41
- reduce='mean'
42
  )
43
  new_coords = torch.stack(
44
- [code // OFFSET[0]] +
45
- [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
46
- dim=-1
 
 
 
 
 
47
  )
48
- out = SparseTensor(new_feats, new_coords, input.shape,)
49
  out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
50
  out._spatial_cache = input._spatial_cache
51
 
52
- out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
53
- out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
54
- out.register_spatial_cache(f'upsample_{factor}_idx', idx)
55
 
56
  return out
57
 
@@ -61,6 +69,7 @@ class SparseUpsample(nn.Module):
61
  Upsample a sparse tensor by a factor of `factor`.
62
  Implemented as nearest neighbor interpolation.
63
  """
 
64
  def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
65
  super(SparseUpsample, self).__init__()
66
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
@@ -68,24 +77,30 @@ class SparseUpsample(nn.Module):
68
  def forward(self, input: SparseTensor) -> SparseTensor:
69
  DIM = input.coords.shape[-1] - 1
70
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
71
- assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
 
 
72
 
73
- new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
74
- new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
75
- idx = input.get_spatial_cache(f'upsample_{factor}_idx')
76
  if any([x is None for x in [new_coords, new_layout, idx]]):
77
- raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
 
 
78
  new_feats = input.feats[idx]
79
  out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
80
  out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
81
  out._spatial_cache = input._spatial_cache
82
  return out
83
-
 
84
  class SparseSubdivide(nn.Module):
85
  """
86
  Upsample a sparse tensor by a factor of `factor`.
87
  Implemented as nearest neighbor interpolation.
88
  """
 
89
  def __init__(self):
90
  super(SparseSubdivide, self).__init__()
91
 
@@ -96,15 +111,20 @@ class SparseSubdivide(nn.Module):
96
  n_coords = torch.nonzero(n_cube)
97
  n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
98
  factor = n_coords.shape[0]
99
- assert factor == 2 ** DIM
100
  # print(n_coords.shape)
101
  new_coords = input.coords.clone()
102
  new_coords[:, 1:] *= 2
103
- new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
104
-
105
- new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
106
- out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
 
 
 
 
 
 
107
  out._scale = input._scale * 2
108
  out._spatial_cache = input._spatial_cache
109
  return out
110
-
 
3
  import torch.nn as nn
4
  from . import SparseTensor
5
 
6
+ __all__ = ["SparseDownsample", "SparseUpsample", "SparseSubdivide"]
 
 
 
 
7
 
8
 
9
  class SparseDownsample(nn.Module):
 
11
  Downsample a sparse tensor by a factor of `factor`.
12
  Implemented as average pooling.
13
  """
14
+
15
  def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
16
  super(SparseDownsample, self).__init__()
17
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
 
19
  def forward(self, input: SparseTensor) -> SparseTensor:
20
  DIM = input.coords.shape[-1] - 1
21
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
22
+ assert DIM == len(
23
+ factor
24
+ ), "Input coordinates must have the same dimension as the downsample factor."
25
 
26
  coord = list(input.coords.unbind(dim=-1))
27
  for i, f in enumerate(factor):
28
+ coord[i + 1] = coord[i + 1] // f
29
 
30
+ MAX = [coord[i + 1].max().item() + 1 for i in range(DIM)]
31
  OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
32
  code = sum([c * o for c, o in zip(coord, OFFSET)])
33
  code, idx = code.unique(return_inverse=True)
34
 
35
  new_feats = torch.scatter_reduce(
36
+ torch.zeros(
37
+ code.shape[0],
38
+ input.feats.shape[1],
39
+ device=input.feats.device,
40
+ dtype=input.feats.dtype,
41
+ ),
42
  dim=0,
43
  index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
44
  src=input.feats,
45
+ reduce="mean",
46
  )
47
  new_coords = torch.stack(
48
+ [code // OFFSET[0]]
49
+ + [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)],
50
+ dim=-1,
51
+ )
52
+ out = SparseTensor(
53
+ new_feats,
54
+ new_coords,
55
+ input.shape,
56
  )
 
57
  out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
58
  out._spatial_cache = input._spatial_cache
59
 
60
+ out.register_spatial_cache(f"upsample_{factor}_coords", input.coords)
61
+ out.register_spatial_cache(f"upsample_{factor}_layout", input.layout)
62
+ out.register_spatial_cache(f"upsample_{factor}_idx", idx)
63
 
64
  return out
65
 
 
69
  Upsample a sparse tensor by a factor of `factor`.
70
  Implemented as nearest neighbor interpolation.
71
  """
72
+
73
  def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
74
  super(SparseUpsample, self).__init__()
75
  self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
 
77
  def forward(self, input: SparseTensor) -> SparseTensor:
78
  DIM = input.coords.shape[-1] - 1
79
  factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
80
+ assert DIM == len(
81
+ factor
82
+ ), "Input coordinates must have the same dimension as the upsample factor."
83
 
84
+ new_coords = input.get_spatial_cache(f"upsample_{factor}_coords")
85
+ new_layout = input.get_spatial_cache(f"upsample_{factor}_layout")
86
+ idx = input.get_spatial_cache(f"upsample_{factor}_idx")
87
  if any([x is None for x in [new_coords, new_layout, idx]]):
88
+ raise ValueError(
89
+ "Upsample cache not found. SparseUpsample must be paired with SparseDownsample."
90
+ )
91
  new_feats = input.feats[idx]
92
  out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
93
  out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
94
  out._spatial_cache = input._spatial_cache
95
  return out
96
+
97
+
98
  class SparseSubdivide(nn.Module):
99
  """
100
  Upsample a sparse tensor by a factor of `factor`.
101
  Implemented as nearest neighbor interpolation.
102
  """
103
+
104
  def __init__(self):
105
  super(SparseSubdivide, self).__init__()
106
 
 
111
  n_coords = torch.nonzero(n_cube)
112
  n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
113
  factor = n_coords.shape[0]
114
+ assert factor == 2**DIM
115
  # print(n_coords.shape)
116
  new_coords = input.coords.clone()
117
  new_coords[:, 1:] *= 2
118
+ new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(
119
+ new_coords.dtype
120
+ )
121
+
122
+ new_feats = input.feats.unsqueeze(1).expand(
123
+ input.feats.shape[0], factor, *input.feats.shape[1:]
124
+ )
125
+ out = SparseTensor(
126
+ new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape
127
+ )
128
  out._scale = input._scale * 2
129
  out._spatial_cache = input._spatial_cache
130
  return out
 
trellis/modules/sparse/transformer/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  from .blocks import *
2
- from .modulated import *
 
1
  from .blocks import *
2
+ from .modulated import *
trellis/modules/sparse/transformer/blocks.py CHANGED
@@ -25,12 +25,15 @@ class SparseTransformerBlock(nn.Module):
25
  """
26
  Sparse Transformer block (MSA + FFN).
27
  """
 
28
  def __init__(
29
  self,
30
  channels: int,
31
  num_heads: int,
32
  mlp_ratio: float = 4.0,
33
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
34
  window_size: Optional[int] = None,
35
  shift_sequence: Optional[int] = None,
36
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -73,7 +76,9 @@ class SparseTransformerBlock(nn.Module):
73
 
74
  def forward(self, x: SparseTensor) -> SparseTensor:
75
  if self.use_checkpoint:
76
- return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
 
 
77
  else:
78
  return self._forward(x)
79
 
@@ -82,13 +87,16 @@ class SparseTransformerCrossBlock(nn.Module):
82
  """
83
  Sparse Transformer cross-attention block (MSA + MCA + FFN).
84
  """
 
85
  def __init__(
86
  self,
87
  channels: int,
88
  ctx_channels: int,
89
  num_heads: int,
90
  mlp_ratio: float = 4.0,
91
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
92
  window_size: Optional[int] = None,
93
  shift_sequence: Optional[int] = None,
94
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -146,6 +154,8 @@ class SparseTransformerCrossBlock(nn.Module):
146
 
147
  def forward(self, x: SparseTensor, context: torch.Tensor):
148
  if self.use_checkpoint:
149
- return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
 
 
150
  else:
151
  return self._forward(x, context)
 
25
  """
26
  Sparse Transformer block (MSA + FFN).
27
  """
28
+
29
  def __init__(
30
  self,
31
  channels: int,
32
  num_heads: int,
33
  mlp_ratio: float = 4.0,
34
+ attn_mode: Literal[
35
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
36
+ ] = "full",
37
  window_size: Optional[int] = None,
38
  shift_sequence: Optional[int] = None,
39
  shift_window: Optional[Tuple[int, int, int]] = None,
 
76
 
77
  def forward(self, x: SparseTensor) -> SparseTensor:
78
  if self.use_checkpoint:
79
+ return torch.utils.checkpoint.checkpoint(
80
+ self._forward, x, use_reentrant=False
81
+ )
82
  else:
83
  return self._forward(x)
84
 
 
87
  """
88
  Sparse Transformer cross-attention block (MSA + MCA + FFN).
89
  """
90
+
91
  def __init__(
92
  self,
93
  channels: int,
94
  ctx_channels: int,
95
  num_heads: int,
96
  mlp_ratio: float = 4.0,
97
+ attn_mode: Literal[
98
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
99
+ ] = "full",
100
  window_size: Optional[int] = None,
101
  shift_sequence: Optional[int] = None,
102
  shift_window: Optional[Tuple[int, int, int]] = None,
 
154
 
155
  def forward(self, x: SparseTensor, context: torch.Tensor):
156
  if self.use_checkpoint:
157
+ return torch.utils.checkpoint.checkpoint(
158
+ self._forward, x, context, use_reentrant=False
159
+ )
160
  else:
161
  return self._forward(x, context)
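
The use_checkpoint branch above trades compute for memory via torch.utils.checkpoint; a minimal standalone illustration of the same pattern, with placeholder block contents:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, use_checkpoint: bool = True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.ffn = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))

    def _forward(self, x):
        return x + self.ffn(x)

    def forward(self, x):
        if self.use_checkpoint:
            # activations inside _forward are recomputed during the backward pass
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)

x = torch.randn(2, 16, requires_grad=True)
Block()(x).sum().backward()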
trellis/modules/sparse/transformer/modulated.py CHANGED
@@ -11,12 +11,15 @@ class ModulatedSparseTransformerBlock(nn.Module):
11
  """
12
  Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
13
  """
 
14
  def __init__(
15
  self,
16
  channels: int,
17
  num_heads: int,
18
  mlp_ratio: float = 4.0,
19
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
20
  window_size: Optional[int] = None,
21
  shift_sequence: Optional[int] = None,
22
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -50,15 +53,18 @@ class ModulatedSparseTransformerBlock(nn.Module):
50
  )
51
  if not share_mod:
52
  self.adaLN_modulation = nn.Sequential(
53
- nn.SiLU(),
54
- nn.Linear(channels, 6 * channels, bias=True)
55
  )
56
 
57
  def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
58
  if self.share_mod:
59
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
 
 
60
  else:
61
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
 
 
62
  h = x.replace(self.norm1(x.feats))
63
  h = h * (1 + scale_msa) + shift_msa
64
  h = self.attn(h)
@@ -73,7 +79,9 @@ class ModulatedSparseTransformerBlock(nn.Module):
73
 
74
  def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
75
  if self.use_checkpoint:
76
- return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
 
 
77
  else:
78
  return self._forward(x, mod)
79
 
@@ -82,13 +90,16 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
82
  """
83
  Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
84
  """
 
85
  def __init__(
86
  self,
87
  channels: int,
88
  ctx_channels: int,
89
  num_heads: int,
90
  mlp_ratio: float = 4.0,
91
- attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
 
 
92
  window_size: Optional[int] = None,
93
  shift_sequence: Optional[int] = None,
94
  shift_window: Optional[Tuple[int, int, int]] = None,
@@ -99,7 +110,6 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
99
  qk_rms_norm_cross: bool = False,
100
  qkv_bias: bool = True,
101
  share_mod: bool = False,
102
-
103
  ):
104
  super().__init__()
105
  self.use_checkpoint = use_checkpoint
@@ -135,15 +145,20 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
135
  )
136
  if not share_mod:
137
  self.adaLN_modulation = nn.Sequential(
138
- nn.SiLU(),
139
- nn.Linear(channels, 6 * channels, bias=True)
140
  )
141
 
142
- def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
 
 
143
  if self.share_mod:
144
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
 
 
145
  else:
146
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
 
 
147
  h = x.replace(self.norm1(x.feats))
148
  h = h * (1 + scale_msa) + shift_msa
149
  h = self.self_attn(h)
@@ -159,8 +174,12 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
159
  x = x + h
160
  return x
161
 
162
- def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
 
 
163
  if self.use_checkpoint:
164
- return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
 
 
165
  else:
166
  return self._forward(x, mod, context)
 
11
  """
12
  Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
13
  """
14
+
15
  def __init__(
16
  self,
17
  channels: int,
18
  num_heads: int,
19
  mlp_ratio: float = 4.0,
20
+ attn_mode: Literal[
21
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
22
+ ] = "full",
23
  window_size: Optional[int] = None,
24
  shift_sequence: Optional[int] = None,
25
  shift_window: Optional[Tuple[int, int, int]] = None,
 
53
  )
54
  if not share_mod:
55
  self.adaLN_modulation = nn.Sequential(
56
+ nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
 
57
  )
58
 
59
  def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
60
  if self.share_mod:
61
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
62
+ 6, dim=1
63
+ )
64
  else:
65
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
66
+ self.adaLN_modulation(mod).chunk(6, dim=1)
67
+ )
68
  h = x.replace(self.norm1(x.feats))
69
  h = h * (1 + scale_msa) + shift_msa
70
  h = self.attn(h)
 
79
 
80
  def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
81
  if self.use_checkpoint:
82
+ return torch.utils.checkpoint.checkpoint(
83
+ self._forward, x, mod, use_reentrant=False
84
+ )
85
  else:
86
  return self._forward(x, mod)
87
 
 
90
  """
91
  Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
92
  """
93
+
94
  def __init__(
95
  self,
96
  channels: int,
97
  ctx_channels: int,
98
  num_heads: int,
99
  mlp_ratio: float = 4.0,
100
+ attn_mode: Literal[
101
+ "full", "shift_window", "shift_sequence", "shift_order", "swin"
102
+ ] = "full",
103
  window_size: Optional[int] = None,
104
  shift_sequence: Optional[int] = None,
105
  shift_window: Optional[Tuple[int, int, int]] = None,
 
110
  qk_rms_norm_cross: bool = False,
111
  qkv_bias: bool = True,
112
  share_mod: bool = False,
 
113
  ):
114
  super().__init__()
115
  self.use_checkpoint = use_checkpoint
 
145
  )
146
  if not share_mod:
147
  self.adaLN_modulation = nn.Sequential(
148
+ nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
 
149
  )
150
 
151
+ def _forward(
152
+ self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor
153
+ ) -> SparseTensor:
154
  if self.share_mod:
155
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
156
+ 6, dim=1
157
+ )
158
  else:
159
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
160
+ self.adaLN_modulation(mod).chunk(6, dim=1)
161
+ )
162
  h = x.replace(self.norm1(x.feats))
163
  h = h * (1 + scale_msa) + shift_msa
164
  h = self.self_attn(h)
 
174
  x = x + h
175
  return x
176
 
177
+ def forward(
178
+ self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor
179
+ ) -> SparseTensor:
180
  if self.use_checkpoint:
181
+ return torch.utils.checkpoint.checkpoint(
182
+ self._forward, x, mod, context, use_reentrant=False
183
+ )
184
  else:
185
  return self._forward(x, mod, context)
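
A standalone sketch of the adaLN conditioning used in these modulated blocks: the conditioning vector is projected to 6 * C values and split into shift/scale/gate triples for the attention and MLP branches. The dense blocks later in this diff unsqueeze over the token dimension as shown here; the sparse blocks instead rely on SparseTensor's batch broadcasting. Attention and MLP are stubbed out.

import torch
import torch.nn as nn

C = 32
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(C, 6 * C, bias=True))
norm1 = nn.LayerNorm(C, elementwise_affine=False)
norm2 = nn.LayerNorm(C, elementwise_affine=False)

x = torch.randn(4, 10, C)   # (batch, tokens, channels)
cond = torch.randn(4, C)    # conditioning vector, e.g. a timestep embedding

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    adaLN_modulation(cond).chunk(6, dim=1)
)
h = norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
# (self-attention over h would be applied here)
x = x + gate_msa.unsqueeze(1) * h

h = norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
# (the feed-forward network would be applied here)
x = x + gate_mlp.unsqueeze(1) * h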
trellis/modules/spatial.py CHANGED
@@ -9,7 +9,7 @@ def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
9
  C_ = C // scale_factor**3
10
  x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
11
  x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
12
- x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
13
  return x
14
 
15
 
@@ -23,11 +23,18 @@ def patchify(x: torch.Tensor, patch_size: int):
23
  """
24
  DIM = x.dim() - 2
25
  for d in range(2, DIM + 2):
26
- assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
 
 
27
 
28
- x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
29
- x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
30
- x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
 
 
 
 
 
31
  return x
32
 
33
 
@@ -40,9 +47,18 @@ def unpatchify(x: torch.Tensor, patch_size: int):
40
  patch_size (int): Patch size
41
  """
42
  DIM = x.dim() - 2
43
- assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
 
 
44
 
45
- x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
 
 
 
 
 
46
  x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
47
- x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
 
 
48
  return x
 
9
  C_ = C // scale_factor**3
10
  x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
11
  x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
12
+ x = x.reshape(B, C_, H * scale_factor, W * scale_factor, D * scale_factor)
13
  return x
14
 
15
 
 
23
  """
24
  DIM = x.dim() - 2
25
  for d in range(2, DIM + 2):
26
+ assert (
27
+ x.shape[d] % patch_size == 0
28
+ ), f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
29
 
30
+ x = x.reshape(
31
+ *x.shape[:2],
32
+ *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []),
33
+ )
34
+ x = x.permute(
35
+ 0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])
36
+ )
37
+ x = x.reshape(x.shape[0], x.shape[1] * (patch_size**DIM), *(x.shape[-DIM:]))
38
  return x
39
 
40
 
 
47
  patch_size (int): Patch size
48
  """
49
  DIM = x.dim() - 2
50
+ assert (
51
+ x.shape[1] % (patch_size**DIM) == 0
52
+ ), f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
53
 
54
+ x = x.reshape(
55
+ x.shape[0],
56
+ x.shape[1] // (patch_size**DIM),
57
+ *([patch_size] * DIM),
58
+ *(x.shape[-DIM:]),
59
+ )
60
  x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
61
+ x = x.reshape(
62
+ x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]
63
+ )
64
  return x
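
A quick shape walk-through for the helpers above, assuming the trellis package is importable: patchify folds each patch_size**DIM patch into the channel dimension, and unpatchify inverts it exactly.

import torch
from trellis.modules.spatial import patchify, unpatchify

x = torch.randn(1, 4, 8, 8, 8)   # (B, C, H, W, D); spatial dims divisible by 2
p = patchify(x, patch_size=2)
print(p.shape)                   # torch.Size([1, 32, 4, 4, 4]) -> C * 2**3 channels
y = unpatchify(p, patch_size=2)
print(torch.equal(x, y))         # True: pure reshape/permute, so the round trip is exact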
trellis/modules/transformer/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
  from .blocks import *
2
- from .modulated import *
 
1
  from .blocks import *
2
+ from .modulated import *
trellis/modules/transformer/blocks.py CHANGED
@@ -9,14 +9,15 @@ class AbsolutePositionEmbedder(nn.Module):
9
  """
10
  Embeds spatial positions into vector representations.
11
  """
 
12
  def __init__(self, channels: int, in_channels: int = 3):
13
  super().__init__()
14
  self.channels = channels
15
  self.in_channels = in_channels
16
  self.freq_dim = channels // in_channels // 2
17
  self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
18
- self.freqs = 1.0 / (10000 ** self.freqs)
19
-
20
  def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
21
  """
22
  Create sinusoidal position embeddings.
@@ -38,11 +39,19 @@ class AbsolutePositionEmbedder(nn.Module):
38
  x (torch.Tensor): (N, D) tensor of spatial positions
39
  """
40
  N, D = x.shape
41
- assert D == self.in_channels, "Input dimension must match number of input channels"
 
 
42
  embed = self._sin_cos_embedding(x.reshape(-1))
43
  embed = embed.reshape(N, -1)
44
  if embed.shape[1] < self.channels:
45
- embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
 
 
 
 
 
 
46
  return embed
47
 
48
 
@@ -63,6 +72,7 @@ class TransformerBlock(nn.Module):
63
  """
64
  Transformer block (MSA + FFN).
65
  """
 
66
  def __init__(
67
  self,
68
  channels: int,
@@ -107,7 +117,9 @@ class TransformerBlock(nn.Module):
107
 
108
  def forward(self, x: torch.Tensor) -> torch.Tensor:
109
  if self.use_checkpoint:
110
- return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
 
 
111
  else:
112
  return self._forward(x)
113
 
@@ -116,6 +128,7 @@ class TransformerCrossBlock(nn.Module):
116
  """
117
  Transformer cross-attention block (MSA + MCA + FFN).
118
  """
 
119
  def __init__(
120
  self,
121
  channels: int,
@@ -176,7 +189,8 @@ class TransformerCrossBlock(nn.Module):
176
 
177
  def forward(self, x: torch.Tensor, context: torch.Tensor):
178
  if self.use_checkpoint:
179
- return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
 
 
180
  else:
181
  return self._forward(x, context)
182
-
 
9
  """
10
  Embeds spatial positions into vector representations.
11
  """
12
+
13
  def __init__(self, channels: int, in_channels: int = 3):
14
  super().__init__()
15
  self.channels = channels
16
  self.in_channels = in_channels
17
  self.freq_dim = channels // in_channels // 2
18
  self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
19
+ self.freqs = 1.0 / (10000**self.freqs)
20
+
21
  def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
22
  """
23
  Create sinusoidal position embeddings.
 
39
  x (torch.Tensor): (N, D) tensor of spatial positions
40
  """
41
  N, D = x.shape
42
+ assert (
43
+ D == self.in_channels
44
+ ), "Input dimension must match number of input channels"
45
  embed = self._sin_cos_embedding(x.reshape(-1))
46
  embed = embed.reshape(N, -1)
47
  if embed.shape[1] < self.channels:
48
+ embed = torch.cat(
49
+ [
50
+ embed,
51
+ torch.zeros(N, self.channels - embed.shape[1], device=embed.device),
52
+ ],
53
+ dim=-1,
54
+ )
55
  return embed
56
 
57
 
 
72
  """
73
  Transformer block (MSA + FFN).
74
  """
75
+
76
  def __init__(
77
  self,
78
  channels: int,
 
117
 
118
  def forward(self, x: torch.Tensor) -> torch.Tensor:
119
  if self.use_checkpoint:
120
+ return torch.utils.checkpoint.checkpoint(
121
+ self._forward, x, use_reentrant=False
122
+ )
123
  else:
124
  return self._forward(x)
125
 
 
128
  """
129
  Transformer cross-attention block (MSA + MCA + FFN).
130
  """
131
+
132
  def __init__(
133
  self,
134
  channels: int,
 
189
 
190
  def forward(self, x: torch.Tensor, context: torch.Tensor):
191
  if self.use_checkpoint:
192
+ return torch.utils.checkpoint.checkpoint(
193
+ self._forward, x, context, use_reentrant=False
194
+ )
195
  else:
196
  return self._forward(x, context)
 
trellis/modules/transformer/modulated.py CHANGED
@@ -10,6 +10,7 @@ class ModulatedTransformerBlock(nn.Module):
10
  """
11
  Transformer block (MSA + FFN) with adaptive layer norm conditioning.
12
  """
 
13
  def __init__(
14
  self,
15
  channels: int,
@@ -45,15 +46,18 @@ class ModulatedTransformerBlock(nn.Module):
45
  )
46
  if not share_mod:
47
  self.adaLN_modulation = nn.Sequential(
48
- nn.SiLU(),
49
- nn.Linear(channels, 6 * channels, bias=True)
50
  )
51
 
52
  def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
53
  if self.share_mod:
54
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
 
 
55
  else:
56
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
 
 
57
  h = self.norm1(x)
58
  h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
59
  h = self.attn(h)
@@ -68,7 +72,9 @@ class ModulatedTransformerBlock(nn.Module):
68
 
69
  def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
70
  if self.use_checkpoint:
71
- return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
 
 
72
  else:
73
  return self._forward(x, mod)
74
 
@@ -77,6 +83,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
77
  """
78
  Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
79
  """
 
80
  def __init__(
81
  self,
82
  channels: int,
@@ -125,15 +132,18 @@ class ModulatedTransformerCrossBlock(nn.Module):
125
  )
126
  if not share_mod:
127
  self.adaLN_modulation = nn.Sequential(
128
- nn.SiLU(),
129
- nn.Linear(channels, 6 * channels, bias=True)
130
  )
131
 
132
  def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
133
  if self.share_mod:
134
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
 
 
135
  else:
136
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
 
 
137
  h = self.norm1(x)
138
  h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
139
  h = self.self_attn(h)
@@ -151,7 +161,8 @@ class ModulatedTransformerCrossBlock(nn.Module):
151
 
152
  def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
153
  if self.use_checkpoint:
154
- return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
 
 
155
  else:
156
  return self._forward(x, mod, context)
157
-
 
10
  """
11
  Transformer block (MSA + FFN) with adaptive layer norm conditioning.
12
  """
13
+
14
  def __init__(
15
  self,
16
  channels: int,
 
46
  )
47
  if not share_mod:
48
  self.adaLN_modulation = nn.Sequential(
49
+ nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
 
50
  )
51
 
52
  def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
53
  if self.share_mod:
54
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
55
+ 6, dim=1
56
+ )
57
  else:
58
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
59
+ self.adaLN_modulation(mod).chunk(6, dim=1)
60
+ )
61
  h = self.norm1(x)
62
  h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
63
  h = self.attn(h)
 
72
 
73
  def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
74
  if self.use_checkpoint:
75
+ return torch.utils.checkpoint.checkpoint(
76
+ self._forward, x, mod, use_reentrant=False
77
+ )
78
  else:
79
  return self._forward(x, mod)
80
 
 
83
  """
84
  Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
85
  """
86
+
87
  def __init__(
88
  self,
89
  channels: int,
 
132
  )
133
  if not share_mod:
134
  self.adaLN_modulation = nn.Sequential(
135
+ nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
 
136
  )
137
 
138
  def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
139
  if self.share_mod:
140
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
141
+ 6, dim=1
142
+ )
143
  else:
144
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
145
+ self.adaLN_modulation(mod).chunk(6, dim=1)
146
+ )
147
  h = self.norm1(x)
148
  h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
149
  h = self.self_attn(h)
 
161
 
162
  def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
163
  if self.use_checkpoint:
164
+ return torch.utils.checkpoint.checkpoint(
165
+ self._forward, x, mod, context, use_reentrant=False
166
+ )
167
  else:
168
  return self._forward(x, mod, context)
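
For reference, a self-contained sketch of the adaptive layer norm split these reformatted lines spell out: one conditioning vector is projected to 6 * channels and chunked into shift/scale/gate values for the attention and MLP branches. Names and sizes below are illustrative only.

import torch
import torch.nn as nn

channels, batch, tokens = 16, 2, 5
adaLN_modulation = nn.Sequential(
    nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True)
)
norm = nn.LayerNorm(channels, elementwise_affine=False)

x = torch.randn(batch, tokens, channels)  # token features
mod = torch.randn(batch, channels)        # conditioning vector

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    adaLN_modulation(mod).chunk(6, dim=1)
)

# Modulate the normalized hidden state, then gate the branch output.
h = norm(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
branch_out = h                            # stand-in for the attention output
x = x + gate_msa.unsqueeze(1) * branch_out
print(x.shape)  # torch.Size([2, 5, 16])
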
 
trellis/modules/utils.py CHANGED
@@ -14,6 +14,7 @@ FP16_MODULES = (
14
  sp.SparseLinear,
15
  )
16
 
 
17
  def convert_module_to_f16(l):
18
  """
19
  Convert primitive modules to float16.
 
14
  sp.SparseLinear,
15
  )
16
 
17
+
18
  def convert_module_to_f16(l):
19
  """
20
  Convert primitive modules to float16.
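
A hedged sketch of what a convert_module_to_f16 helper of this shape does: walk the model with Module.apply and cast the parameters of whitelisted layer types to half precision. The FP16_MODULES tuple below is an assumption for illustration, not the repository's list.

import torch.nn as nn

FP16_MODULES = (nn.Linear, nn.Conv3d)  # assumed whitelist, for illustration only


def convert_module_to_f16(l):
    # Cast only the whitelisted primitive modules; containers are left alone.
    if isinstance(l, FP16_MODULES):
        for p in l.parameters():
            p.data = p.data.half()


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
model.apply(convert_module_to_f16)
print(next(model.parameters()).dtype)  # torch.float16
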
trellis/pipelines/__init__.py CHANGED
@@ -11,14 +11,16 @@ def from_pretrained(path: str):
11
  """
12
  import os
13
  import json
 
14
  is_local = os.path.exists(f"{path}/pipeline.json")
15
 
16
  if is_local:
17
  config_file = f"{path}/pipeline.json"
18
  else:
19
  from huggingface_hub import hf_hub_download
 
20
  config_file = hf_hub_download(path, "pipeline.json")
21
 
22
- with open(config_file, 'r') as f:
23
  config = json.load(f)
24
- return globals()[config['name']].from_pretrained(path)
 
11
  """
12
  import os
13
  import json
14
+
15
  is_local = os.path.exists(f"{path}/pipeline.json")
16
 
17
  if is_local:
18
  config_file = f"{path}/pipeline.json"
19
  else:
20
  from huggingface_hub import hf_hub_download
21
+
22
  config_file = hf_hub_download(path, "pipeline.json")
23
 
24
+ with open(config_file, "r") as f:
25
  config = json.load(f)
26
+ return globals()[config["name"]].from_pretrained(path)
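
A minimal sketch of the dispatch pattern in this from_pretrained: read pipeline.json and instantiate the class named in it (a real implementation falls back to huggingface_hub.hf_hub_download when the path is a repo id rather than a local directory). The registry, config, and directory below are hypothetical stand-ins.

import json
import os


class DemoPipeline:
    @classmethod
    def from_pretrained(cls, path: str):
        return cls()


PIPELINES = {"DemoPipeline": DemoPipeline}  # stand-in for the globals() lookup


def from_pretrained(path: str):
    config_file = os.path.join(path, "pipeline.json")
    with open(config_file, "r") as f:
        config = json.load(f)
    return PIPELINES[config["name"]].from_pretrained(path)


os.makedirs("demo_pipeline", exist_ok=True)
with open("demo_pipeline/pipeline.json", "w") as f:
    json.dump({"name": "DemoPipeline", "args": {}}, f)
print(from_pretrained("demo_pipeline"))
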
trellis/pipelines/base.py CHANGED
@@ -8,6 +8,7 @@ class Pipeline:
8
  """
9
  A base class for pipelines.
10
  """
 
11
  def __init__(
12
  self,
13
  models: dict[str, nn.Module] = None,
@@ -25,20 +26,21 @@ class Pipeline:
25
  """
26
  import os
27
  import json
 
28
  is_local = os.path.exists(f"{path}/pipeline.json")
29
 
30
  if is_local:
31
  config_file = f"{path}/pipeline.json"
32
  else:
33
  from huggingface_hub import hf_hub_download
 
34
  config_file = hf_hub_download(path, "pipeline.json")
35
 
36
- with open(config_file, 'r') as f:
37
- args = json.load(f)['args']
38
 
39
  _models = {
40
- k: models.from_pretrained(f"{path}/{v}")
41
- for k, v in args['models'].items()
42
  }
43
 
44
  new_pipeline = Pipeline(_models)
@@ -48,10 +50,10 @@ class Pipeline:
48
  @property
49
  def device(self) -> torch.device:
50
  for model in self.models.values():
51
- if hasattr(model, 'device'):
52
  return model.device
53
  for model in self.models.values():
54
- if hasattr(model, 'parameters'):
55
  return next(model.parameters()).device
56
  raise RuntimeError("No device found.")
57
 
 
8
  """
9
  A base class for pipelines.
10
  """
11
+
12
  def __init__(
13
  self,
14
  models: dict[str, nn.Module] = None,
 
26
  """
27
  import os
28
  import json
29
+
30
  is_local = os.path.exists(f"{path}/pipeline.json")
31
 
32
  if is_local:
33
  config_file = f"{path}/pipeline.json"
34
  else:
35
  from huggingface_hub import hf_hub_download
36
+
37
  config_file = hf_hub_download(path, "pipeline.json")
38
 
39
+ with open(config_file, "r") as f:
40
+ args = json.load(f)["args"]
41
 
42
  _models = {
43
+ k: models.from_pretrained(f"{path}/{v}") for k, v in args["models"].items()
 
44
  }
45
 
46
  new_pipeline = Pipeline(_models)
 
50
  @property
51
  def device(self) -> torch.device:
52
  for model in self.models.values():
53
+ if hasattr(model, "device"):
54
  return model.device
55
  for model in self.models.values():
56
+ if hasattr(model, "parameters"):
57
  return next(model.parameters()).device
58
  raise RuntimeError("No device found.")
59
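
The device property above resolves a torch.device for a dict of models: prefer an explicit .device attribute, otherwise take the device of the first parameter found. A standalone version of that lookup, with a throwaway model dict:

import torch
import torch.nn as nn


def resolve_device(models: dict) -> torch.device:
    for model in models.values():
        if hasattr(model, "device"):
            return model.device
    for model in models.values():
        if hasattr(model, "parameters"):
            return next(model.parameters()).device
    raise RuntimeError("No device found.")


print(resolve_device({"decoder": nn.Linear(4, 4)}))  # cpu unless moved to GPU
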
 
trellis/pipelines/samplers/__init__.py CHANGED
@@ -1,2 +1,6 @@
1
  from .base import Sampler
2
- from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
1
  from .base import Sampler
2
+ from .flow_euler import (
3
+ FlowEulerSampler,
4
+ FlowEulerCfgSampler,
5
+ FlowEulerGuidanceIntervalSampler,
6
+ )
trellis/pipelines/samplers/base.py CHANGED
@@ -8,13 +8,8 @@ class Sampler(ABC):
8
  """
9
 
10
  @abstractmethod
11
- def sample(
12
- self,
13
- model,
14
- **kwargs
15
- ):
16
  """
17
  Sample from a model.
18
  """
19
  pass
20
-
 
8
  """
9
 
10
  @abstractmethod
11
+ def sample(self, model, **kwargs):
12
  """
13
  Sample from a model.
14
  """
15
  pass
 
trellis/pipelines/samplers/flow_euler.py CHANGED
@@ -15,6 +15,7 @@ class FlowEulerSampler(Sampler):
15
  Args:
16
  sigma_min: The minimum scale of noise in flow.
17
  """
 
18
  def __init__(
19
  self,
20
  sigma_min: float,
@@ -32,11 +33,15 @@ class FlowEulerSampler(Sampler):
32
  def _v_to_xstart_eps(self, x_t, t, v):
33
  assert x_t.shape == v.shape
34
  eps = (1 - t) * v + x_t
35
- x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
 
 
36
  return x_0, eps
37
 
38
  def _inference_model(self, model, x_t, t, cond=None, **kwargs):
39
- t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
 
 
40
  return model(x_t, t, cond, **kwargs)
41
 
42
  def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
@@ -46,17 +51,11 @@ class FlowEulerSampler(Sampler):
46
 
47
  @torch.no_grad()
48
  def sample_once(
49
- self,
50
- model,
51
- x_t,
52
- t: float,
53
- t_prev: float,
54
- cond: Optional[Any] = None,
55
- **kwargs
56
  ):
57
  """
58
  Sample x_{t-1} from the model using Euler method.
59
-
60
  Args:
61
  model: The model to sample from.
62
  x_t: The [N x C x ...] tensor of noisy inputs at time t.
@@ -70,7 +69,9 @@ class FlowEulerSampler(Sampler):
70
  - 'pred_x_prev': x_{t-1}.
71
  - 'pred_x_0': a prediction of x_0.
72
  """
73
- pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
 
 
74
  pred_x_prev = x_t - (t - t_prev) * pred_v
75
  return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
76
 
@@ -87,7 +88,7 @@ class FlowEulerSampler(Sampler):
87
  ):
88
  """
89
  Generate samples from the model using Euler method.
90
-
91
  Args:
92
  model: The model to sample from.
93
  noise: The initial noise tensor.
@@ -121,6 +122,7 @@ class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
121
  """
122
  Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
123
  """
 
124
  @torch.no_grad()
125
  def sample(
126
  self,
@@ -136,7 +138,7 @@ class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
136
  ):
137
  """
138
  Generate samples from the model using Euler method.
139
-
140
  Args:
141
  model: The model to sample from.
142
  noise: The initial noise tensor.
@@ -154,13 +156,24 @@ class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
154
  - 'pred_x_t': a list of prediction of x_t.
155
  - 'pred_x_0': a list of prediction of x_0.
156
  """
157
- return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs)
158
 
159
 
160
  class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
161
  """
162
  Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
163
  """
 
164
  @torch.no_grad()
165
  def sample(
166
  self,
@@ -177,7 +190,7 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa
177
  ):
178
  """
179
  Generate samples from the model using Euler method.
180
-
181
  Args:
182
  model: The model to sample from.
183
  noise: The initial noise tensor.
@@ -196,4 +209,15 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa
196
  - 'pred_x_t': a list of prediction of x_t.
197
  - 'pred_x_0': a list of prediction of x_0.
198
  """
199
- return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
15
  Args:
16
  sigma_min: The minimum scale of noise in flow.
17
  """
18
+
19
  def __init__(
20
  self,
21
  sigma_min: float,
 
33
  def _v_to_xstart_eps(self, x_t, t, v):
34
  assert x_t.shape == v.shape
35
  eps = (1 - t) * v + x_t
36
+ x_0 = (1 - self.sigma_min) * x_t - (
37
+ self.sigma_min + (1 - self.sigma_min) * t
38
+ ) * v
39
  return x_0, eps
40
 
41
  def _inference_model(self, model, x_t, t, cond=None, **kwargs):
42
+ t = torch.tensor(
43
+ [1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32
44
+ )
45
  return model(x_t, t, cond, **kwargs)
46
 
47
  def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
 
51
 
52
  @torch.no_grad()
53
  def sample_once(
54
+ self, model, x_t, t: float, t_prev: float, cond: Optional[Any] = None, **kwargs
55
  ):
56
  """
57
  Sample x_{t-1} from the model using Euler method.
58
+
59
  Args:
60
  model: The model to sample from.
61
  x_t: The [N x C x ...] tensor of noisy inputs at time t.
 
69
  - 'pred_x_prev': x_{t-1}.
70
  - 'pred_x_0': a prediction of x_0.
71
  """
72
+ pred_x_0, pred_eps, pred_v = self._get_model_prediction(
73
+ model, x_t, t, cond, **kwargs
74
+ )
75
  pred_x_prev = x_t - (t - t_prev) * pred_v
76
  return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
77
 
 
88
  ):
89
  """
90
  Generate samples from the model using Euler method.
91
+
92
  Args:
93
  model: The model to sample from.
94
  noise: The initial noise tensor.
 
122
  """
123
  Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
124
  """
125
+
126
  @torch.no_grad()
127
  def sample(
128
  self,
 
138
  ):
139
  """
140
  Generate samples from the model using Euler method.
141
+
142
  Args:
143
  model: The model to sample from.
144
  noise: The initial noise tensor.
 
156
  - 'pred_x_t': a list of prediction of x_t.
157
  - 'pred_x_0': a list of prediction of x_0.
158
  """
159
+ return super().sample(
160
+ model,
161
+ noise,
162
+ cond,
163
+ steps,
164
+ rescale_t,
165
+ verbose,
166
+ neg_cond=neg_cond,
167
+ cfg_strength=cfg_strength,
168
+ **kwargs
169
+ )
170
 
171
 
172
  class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
173
  """
174
  Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
175
  """
176
+
177
  @torch.no_grad()
178
  def sample(
179
  self,
 
190
  ):
191
  """
192
  Generate samples from the model using Euler method.
193
+
194
  Args:
195
  model: The model to sample from.
196
  noise: The initial noise tensor.
 
209
  - 'pred_x_t': a list of prediction of x_t.
210
  - 'pred_x_0': a list of prediction of x_0.
211
  """
212
+ return super().sample(
213
+ model,
214
+ noise,
215
+ cond,
216
+ steps,
217
+ rescale_t,
218
+ verbose,
219
+ neg_cond=neg_cond,
220
+ cfg_strength=cfg_strength,
221
+ cfg_interval=cfg_interval,
222
+ **kwargs
223
+ )
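
A numeric toy of the Euler update these samplers share: the network predicts a velocity v, and with sigma_min = 0 one step is x_{t_prev} = x_t - (t - t_prev) * v. The closed-form velocity below stands in for the real flow model, so 25 steps recover the target exactly.

import torch

x_1 = torch.randn(4)          # pure noise at t = 1
x_0_true = torch.zeros(4)     # toy target sample


def toy_velocity(x_t, t):
    # For a straight-line flow x_t = (1 - t) * x_0 + t * x_1 (sigma_min = 0),
    # the ideal velocity field is constant: x_1 - x_0.
    return x_1 - x_0_true


x_t = x_1.clone()
t_seq = torch.linspace(1.0, 0.0, 26)          # 25 Euler steps from noise to data
for t, t_prev in zip(t_seq[:-1], t_seq[1:]):
    v = toy_velocity(x_t, t)
    x_t = x_t - (t - t_prev) * v              # same update as sample_once above

print(torch.allclose(x_t, x_0_true, atol=1e-5))  # True for this linear toy flow
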
trellis/pipelines/samplers/guidance_interval_mixin.py CHANGED
@@ -6,7 +6,9 @@ class GuidanceIntervalSamplerMixin:
6
  A mixin class for samplers that apply classifier-free guidance with interval.
7
  """
8
 
9
- def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
 
 
10
  if cfg_interval[0] <= t <= cfg_interval[1]:
11
  pred = super()._inference_model(model, x_t, t, cond, **kwargs)
12
  neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
 
6
  A mixin class for samplers that apply classifier-free guidance with interval.
7
  """
8
 
9
+ def _inference_model(
10
+ self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs
11
+ ):
12
  if cfg_interval[0] <= t <= cfg_interval[1]:
13
  pred = super()._inference_model(model, x_t, t, cond, **kwargs)
14
  neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
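
The mixin blends conditional and unconditional predictions only while t lies inside cfg_interval. A tiny standalone version of that rule, with placeholder tensors for the two predictions:

import torch


def guided(pred, neg_pred, t, cfg_strength, cfg_interval):
    lo, hi = cfg_interval
    if lo <= t <= hi:
        # Classifier-free guidance: extrapolate away from the unconditional prediction.
        return (1 + cfg_strength) * pred - cfg_strength * neg_pred
    return pred  # outside the interval, use the conditional prediction as-is


pred, neg_pred = torch.tensor([1.0]), torch.tensor([0.2])
print(guided(pred, neg_pred, t=0.7, cfg_strength=3.0, cfg_interval=(0.5, 1.0)))  # tensor([3.4000])
print(guided(pred, neg_pred, t=0.1, cfg_strength=3.0, cfg_interval=(0.5, 1.0)))  # tensor([1.])
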
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -26,6 +26,7 @@ class TrellisImageTo3DPipeline(Pipeline):
26
  slat_normalization (dict): The normalization parameters for the structured latent.
27
  image_cond_model (str): The name of the image conditioning model.
28
  """
 
29
  def __init__(
30
  self,
31
  models: dict[str, nn.Module] = None,
@@ -53,33 +54,45 @@ class TrellisImageTo3DPipeline(Pipeline):
53
  Args:
54
  path (str): The path to the model. Can be either local path or a Hugging Face repository.
55
  """
56
- pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path)
 
 
57
  new_pipeline = TrellisImageTo3DPipeline()
58
  new_pipeline.__dict__ = pipeline.__dict__
59
  args = pipeline._pretrained_args
60
 
61
- new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
62
- new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
 
 
 
 
63
 
64
- new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
65
- new_pipeline.slat_sampler_params = args['slat_sampler']['params']
 
 
66
 
67
- new_pipeline.slat_normalization = args['slat_normalization']
68
 
69
- new_pipeline._init_image_cond_model(args['image_cond_model'])
70
 
71
  return new_pipeline
72
-
73
  def _init_image_cond_model(self, name: str):
74
  """
75
  Initialize the image conditioning model.
76
  """
77
- dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True)
78
  dinov2_model.eval()
79
- self.models['image_cond_model'] = dinov2_model
80
- transform = transforms.Compose([
81
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
82
- ])
 
 
 
 
83
  self.image_cond_model_transform = transform
84
 
85
  def preprocess_image(self, input: Image.Image) -> Image.Image:
@@ -88,29 +101,42 @@ class TrellisImageTo3DPipeline(Pipeline):
88
  """
89
  # if has alpha channel, use it directly; otherwise, remove background
90
  has_alpha = False
91
- if input.mode == 'RGBA':
92
  alpha = np.array(input)[:, :, 3]
93
  if not np.all(alpha == 255):
94
  has_alpha = True
95
  if has_alpha:
96
  output = input
97
  else:
98
- input = input.convert('RGB')
99
  max_size = max(input.size)
100
  scale = min(1, 1024 / max_size)
101
  if scale < 1:
102
- input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
103
- if getattr(self, 'rembg_session', None) is None:
104
- self.rembg_session = rembg.new_session('u2net')
 
 
 
105
  output = rembg.remove(input, session=self.rembg_session)
106
  output_np = np.array(output)
107
  alpha = output_np[:, :, 3]
108
  bbox = np.argwhere(alpha > 0.8 * 255)
109
- bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
 
 
 
 
 
110
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
111
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
112
  size = int(size * 1.2)
113
- bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
 
 
 
 
 
114
  output = output.crop(bbox) # type: ignore
115
  output = output.resize((518, 518), Image.Resampling.LANCZOS)
116
  output = np.array(output).astype(np.float32) / 255
@@ -119,7 +145,9 @@ class TrellisImageTo3DPipeline(Pipeline):
119
  return output
120
 
121
  @torch.no_grad()
122
- def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor:
 
 
123
  """
124
  Encode the image.
125
 
@@ -132,19 +160,21 @@ class TrellisImageTo3DPipeline(Pipeline):
132
  if isinstance(image, torch.Tensor):
133
  assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
134
  elif isinstance(image, list):
135
- assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
 
 
136
  image = [i.resize((518, 518), Image.LANCZOS) for i in image]
137
- image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
138
  image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
139
  image = torch.stack(image).to(self.device)
140
  else:
141
  raise ValueError(f"Unsupported type of image: {type(image)}")
142
-
143
  image = self.image_cond_model_transform(image).to(self.device)
144
- features = self.models['image_cond_model'](image, is_training=True)['x_prenorm']
145
  patchtokens = F.layer_norm(features, features.shape[-1:])
146
  return patchtokens
147
-
148
  def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
149
  """
150
  Get the conditioning information for the model.
@@ -158,8 +188,8 @@ class TrellisImageTo3DPipeline(Pipeline):
158
  cond = self.encode_image(image)
159
  neg_cond = torch.zeros_like(cond)
160
  return {
161
- 'cond': cond,
162
- 'neg_cond': neg_cond,
163
  }
164
 
165
  def sample_sparse_structure(
@@ -170,35 +200,33 @@ class TrellisImageTo3DPipeline(Pipeline):
170
  ) -> torch.Tensor:
171
  """
172
  Sample sparse structures with the given conditioning.
173
-
174
  Args:
175
  cond (dict): The conditioning information.
176
  num_samples (int): The number of samples to generate.
177
  sampler_params (dict): Additional parameters for the sampler.
178
  """
179
  # Sample occupancy latent
180
- flow_model = self.models['sparse_structure_flow_model']
181
  reso = flow_model.resolution
182
- noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
 
 
183
  sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
184
  z_s = self.sparse_structure_sampler.sample(
185
- flow_model,
186
- noise,
187
- **cond,
188
- **sampler_params,
189
- verbose=True
190
  ).samples
191
-
192
  # Decode occupancy latent
193
- decoder = self.models['sparse_structure_decoder']
194
- coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
195
 
196
  return coords
197
 
198
  def decode_slat(
199
  self,
200
  slat: sp.SparseTensor,
201
- formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
202
  ) -> dict:
203
  """
204
  Decode the structured latent.
@@ -211,14 +239,14 @@ class TrellisImageTo3DPipeline(Pipeline):
211
  dict: The decoded structured latent.
212
  """
213
  ret = {}
214
- if 'mesh' in formats:
215
- ret['mesh'] = self.models['slat_decoder_mesh'](slat)
216
- if 'gaussian' in formats:
217
- ret['gaussian'] = self.models['slat_decoder_gs'](slat)
218
- if 'radiance_field' in formats:
219
- ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
220
  return ret
221
-
222
  def sample_slat(
223
  self,
224
  cond: dict,
@@ -227,31 +255,27 @@ class TrellisImageTo3DPipeline(Pipeline):
227
  ) -> sp.SparseTensor:
228
  """
229
  Sample structured latent with the given conditioning.
230
-
231
  Args:
232
  cond (dict): The conditioning information.
233
  coords (torch.Tensor): The coordinates of the sparse structure.
234
  sampler_params (dict): Additional parameters for the sampler.
235
  """
236
  # Sample structured latent
237
- flow_model = self.models['slat_flow_model']
238
  noise = sp.SparseTensor(
239
  feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
240
  coords=coords,
241
  )
242
  sampler_params = {**self.slat_sampler_params, **sampler_params}
243
  slat = self.slat_sampler.sample(
244
- flow_model,
245
- noise,
246
- **cond,
247
- **sampler_params,
248
- verbose=True
249
  ).samples
250
 
251
- std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
252
- mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
253
  slat = slat * std + mean
254
-
255
  return slat
256
 
257
  @torch.no_grad()
@@ -262,7 +286,7 @@ class TrellisImageTo3DPipeline(Pipeline):
262
  seed: int = 42,
263
  sparse_structure_sampler_params: dict = {},
264
  slat_sampler_params: dict = {},
265
- formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
266
  preprocess_image: bool = True,
267
  ) -> dict:
268
  """
@@ -279,7 +303,9 @@ class TrellisImageTo3DPipeline(Pipeline):
279
  image = self.preprocess_image(image)
280
  cond = self.get_cond([image])
281
  torch.manual_seed(seed)
282
- coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
 
 
283
  slat = self.sample_slat(cond, coords, slat_sampler_params)
284
  return self.decode_slat(slat, formats)
285
 
@@ -289,56 +315,80 @@ class TrellisImageTo3DPipeline(Pipeline):
289
  sampler_name: str,
290
  num_images: int,
291
  num_steps: int,
292
- mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
293
  ):
294
  """
295
  Inject a sampler with multiple images as condition.
296
-
297
  Args:
298
  sampler_name (str): The name of the sampler to inject.
299
  num_images (int): The number of images to condition on.
300
  num_steps (int): The number of steps to run the sampler for.
301
  """
302
  sampler = getattr(self, sampler_name)
303
- setattr(sampler, f'_old_inference_model', sampler._inference_model)
304
 
305
- if mode == 'stochastic':
306
  if num_images > num_steps:
307
- print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
308
- "This may lead to performance degradation.\033[0m")
 
 
309
 
310
  cond_indices = (np.arange(num_steps) % num_images).tolist()
 
311
  def _new_inference_model(self, model, x_t, t, cond, **kwargs):
312
  cond_idx = cond_indices.pop(0)
313
- cond_i = cond[cond_idx:cond_idx+1]
314
  return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)
315
-
316
- elif mode =='multidiffusion':
317
  from .samplers import FlowEulerSampler
318
- def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
319
  if cfg_interval[0] <= t <= cfg_interval[1]:
320
  preds = []
321
  for i in range(len(cond)):
322
- preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
 
 
 
 
323
  pred = sum(preds) / len(preds)
324
- neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
 
 
325
  return (1 + cfg_strength) * pred - cfg_strength * neg_pred
326
  else:
327
  preds = []
328
  for i in range(len(cond)):
329
- preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
 
 
 
 
330
  pred = sum(preds) / len(preds)
331
  return pred
332
-
333
  else:
334
  raise ValueError(f"Unsupported mode: {mode}")
335
-
336
  sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))
337
 
338
  yield
339
 
340
  sampler._inference_model = sampler._old_inference_model
341
- delattr(sampler, f'_old_inference_model')
342
 
343
  @torch.no_grad()
344
  def run_multi_image(
@@ -348,9 +398,9 @@ class TrellisImageTo3DPipeline(Pipeline):
348
  seed: int = 42,
349
  sparse_structure_sampler_params: dict = {},
350
  slat_sampler_params: dict = {},
351
- formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
352
  preprocess_image: bool = True,
353
- mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
354
  ) -> dict:
355
  """
356
  Run the pipeline with multiple images as condition
@@ -365,12 +415,21 @@ class TrellisImageTo3DPipeline(Pipeline):
365
  if preprocess_image:
366
  images = [self.preprocess_image(image) for image in images]
367
  cond = self.get_cond(images)
368
- cond['neg_cond'] = cond['neg_cond'][:1]
369
  torch.manual_seed(seed)
370
- ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps')
371
- with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode):
372
- coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
373
- slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
374
- with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode):
 
 
 
 
 
 
 
 
 
375
  slat = self.sample_slat(cond, coords, slat_sampler_params)
376
  return self.decode_slat(slat, formats)
 
26
  slat_normalization (dict): The normalization parameters for the structured latent.
27
  image_cond_model (str): The name of the image conditioning model.
28
  """
29
+
30
  def __init__(
31
  self,
32
  models: dict[str, nn.Module] = None,
 
54
  Args:
55
  path (str): The path to the model. Can be either local path or a Hugging Face repository.
56
  """
57
+ pipeline = super(
58
+ TrellisImageTo3DPipeline, TrellisImageTo3DPipeline
59
+ ).from_pretrained(path)
60
  new_pipeline = TrellisImageTo3DPipeline()
61
  new_pipeline.__dict__ = pipeline.__dict__
62
  args = pipeline._pretrained_args
63
 
64
+ new_pipeline.sparse_structure_sampler = getattr(
65
+ samplers, args["sparse_structure_sampler"]["name"]
66
+ )(**args["sparse_structure_sampler"]["args"])
67
+ new_pipeline.sparse_structure_sampler_params = args["sparse_structure_sampler"][
68
+ "params"
69
+ ]
70
 
71
+ new_pipeline.slat_sampler = getattr(samplers, args["slat_sampler"]["name"])(
72
+ **args["slat_sampler"]["args"]
73
+ )
74
+ new_pipeline.slat_sampler_params = args["slat_sampler"]["params"]
75
 
76
+ new_pipeline.slat_normalization = args["slat_normalization"]
77
 
78
+ new_pipeline._init_image_cond_model(args["image_cond_model"])
79
 
80
  return new_pipeline
81
+
82
  def _init_image_cond_model(self, name: str):
83
  """
84
  Initialize the image conditioning model.
85
  """
86
+ dinov2_model = torch.hub.load("facebookresearch/dinov2", name, pretrained=True)
87
  dinov2_model.eval()
88
+ self.models["image_cond_model"] = dinov2_model
89
+ transform = transforms.Compose(
90
+ [
91
+ transforms.Normalize(
92
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
93
+ ),
94
+ ]
95
+ )
96
  self.image_cond_model_transform = transform
97
 
98
  def preprocess_image(self, input: Image.Image) -> Image.Image:
 
101
  """
102
  # if has alpha channel, use it directly; otherwise, remove background
103
  has_alpha = False
104
+ if input.mode == "RGBA":
105
  alpha = np.array(input)[:, :, 3]
106
  if not np.all(alpha == 255):
107
  has_alpha = True
108
  if has_alpha:
109
  output = input
110
  else:
111
+ input = input.convert("RGB")
112
  max_size = max(input.size)
113
  scale = min(1, 1024 / max_size)
114
  if scale < 1:
115
+ input = input.resize(
116
+ (int(input.width * scale), int(input.height * scale)),
117
+ Image.Resampling.LANCZOS,
118
+ )
119
+ if getattr(self, "rembg_session", None) is None:
120
+ self.rembg_session = rembg.new_session("u2net")
121
  output = rembg.remove(input, session=self.rembg_session)
122
  output_np = np.array(output)
123
  alpha = output_np[:, :, 3]
124
  bbox = np.argwhere(alpha > 0.8 * 255)
125
+ bbox = (
126
+ np.min(bbox[:, 1]),
127
+ np.min(bbox[:, 0]),
128
+ np.max(bbox[:, 1]),
129
+ np.max(bbox[:, 0]),
130
+ )
131
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
132
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
133
  size = int(size * 1.2)
134
+ bbox = (
135
+ center[0] - size // 2,
136
+ center[1] - size // 2,
137
+ center[0] + size // 2,
138
+ center[1] + size // 2,
139
+ )
140
  output = output.crop(bbox) # type: ignore
141
  output = output.resize((518, 518), Image.Resampling.LANCZOS)
142
  output = np.array(output).astype(np.float32) / 255
 
145
  return output
146
 
147
  @torch.no_grad()
148
+ def encode_image(
149
+ self, image: Union[torch.Tensor, list[Image.Image]]
150
+ ) -> torch.Tensor:
151
  """
152
  Encode the image.
153
 
 
160
  if isinstance(image, torch.Tensor):
161
  assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
162
  elif isinstance(image, list):
163
+ assert all(
164
+ isinstance(i, Image.Image) for i in image
165
+ ), "Image list should be list of PIL images"
166
  image = [i.resize((518, 518), Image.LANCZOS) for i in image]
167
+ image = [np.array(i.convert("RGB")).astype(np.float32) / 255 for i in image]
168
  image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
169
  image = torch.stack(image).to(self.device)
170
  else:
171
  raise ValueError(f"Unsupported type of image: {type(image)}")
172
+
173
  image = self.image_cond_model_transform(image).to(self.device)
174
+ features = self.models["image_cond_model"](image, is_training=True)["x_prenorm"]
175
  patchtokens = F.layer_norm(features, features.shape[-1:])
176
  return patchtokens
177
+
178
  def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
179
  """
180
  Get the conditioning information for the model.
 
188
  cond = self.encode_image(image)
189
  neg_cond = torch.zeros_like(cond)
190
  return {
191
+ "cond": cond,
192
+ "neg_cond": neg_cond,
193
  }
194
 
195
  def sample_sparse_structure(
 
200
  ) -> torch.Tensor:
201
  """
202
  Sample sparse structures with the given conditioning.
203
+
204
  Args:
205
  cond (dict): The conditioning information.
206
  num_samples (int): The number of samples to generate.
207
  sampler_params (dict): Additional parameters for the sampler.
208
  """
209
  # Sample occupancy latent
210
+ flow_model = self.models["sparse_structure_flow_model"]
211
  reso = flow_model.resolution
212
+ noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(
213
+ self.device
214
+ )
215
  sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
216
  z_s = self.sparse_structure_sampler.sample(
217
+ flow_model, noise, **cond, **sampler_params, verbose=True
 
 
 
 
218
  ).samples
219
+
220
  # Decode occupancy latent
221
+ decoder = self.models["sparse_structure_decoder"]
222
+ coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()
223
 
224
  return coords
225
 
226
  def decode_slat(
227
  self,
228
  slat: sp.SparseTensor,
229
+ formats: List[str] = ["mesh", "gaussian", "radiance_field"],
230
  ) -> dict:
231
  """
232
  Decode the structured latent.
 
239
  dict: The decoded structured latent.
240
  """
241
  ret = {}
242
+ if "mesh" in formats:
243
+ ret["mesh"] = self.models["slat_decoder_mesh"](slat)
244
+ if "gaussian" in formats:
245
+ ret["gaussian"] = self.models["slat_decoder_gs"](slat)
246
+ if "radiance_field" in formats:
247
+ ret["radiance_field"] = self.models["slat_decoder_rf"](slat)
248
  return ret
249
+
250
  def sample_slat(
251
  self,
252
  cond: dict,
 
255
  ) -> sp.SparseTensor:
256
  """
257
  Sample structured latent with the given conditioning.
258
+
259
  Args:
260
  cond (dict): The conditioning information.
261
  coords (torch.Tensor): The coordinates of the sparse structure.
262
  sampler_params (dict): Additional parameters for the sampler.
263
  """
264
  # Sample structured latent
265
+ flow_model = self.models["slat_flow_model"]
266
  noise = sp.SparseTensor(
267
  feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
268
  coords=coords,
269
  )
270
  sampler_params = {**self.slat_sampler_params, **sampler_params}
271
  slat = self.slat_sampler.sample(
272
+ flow_model, noise, **cond, **sampler_params, verbose=True
 
 
 
 
273
  ).samples
274
 
275
+ std = torch.tensor(self.slat_normalization["std"])[None].to(slat.device)
276
+ mean = torch.tensor(self.slat_normalization["mean"])[None].to(slat.device)
277
  slat = slat * std + mean
278
+
279
  return slat
280
 
281
  @torch.no_grad()
 
286
  seed: int = 42,
287
  sparse_structure_sampler_params: dict = {},
288
  slat_sampler_params: dict = {},
289
+ formats: List[str] = ["mesh", "gaussian", "radiance_field"],
290
  preprocess_image: bool = True,
291
  ) -> dict:
292
  """
 
303
  image = self.preprocess_image(image)
304
  cond = self.get_cond([image])
305
  torch.manual_seed(seed)
306
+ coords = self.sample_sparse_structure(
307
+ cond, num_samples, sparse_structure_sampler_params
308
+ )
309
  slat = self.sample_slat(cond, coords, slat_sampler_params)
310
  return self.decode_slat(slat, formats)
311
 
 
315
  sampler_name: str,
316
  num_images: int,
317
  num_steps: int,
318
+ mode: Literal["stochastic", "multidiffusion"] = "stochastic",
319
  ):
320
  """
321
  Inject a sampler with multiple images as condition.
322
+
323
  Args:
324
  sampler_name (str): The name of the sampler to inject.
325
  num_images (int): The number of images to condition on.
326
  num_steps (int): The number of steps to run the sampler for.
327
  """
328
  sampler = getattr(self, sampler_name)
329
+ setattr(sampler, f"_old_inference_model", sampler._inference_model)
330
 
331
+ if mode == "stochastic":
332
  if num_images > num_steps:
333
+ print(
334
+ f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
335
+ "This may lead to performance degradation.\033[0m"
336
+ )
337
 
338
  cond_indices = (np.arange(num_steps) % num_images).tolist()
339
+
340
  def _new_inference_model(self, model, x_t, t, cond, **kwargs):
341
  cond_idx = cond_indices.pop(0)
342
+ cond_i = cond[cond_idx : cond_idx + 1]
343
  return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)
344
+
345
+ elif mode == "multidiffusion":
346
  from .samplers import FlowEulerSampler
347
+
348
+ def _new_inference_model(
349
+ self,
350
+ model,
351
+ x_t,
352
+ t,
353
+ cond,
354
+ neg_cond,
355
+ cfg_strength,
356
+ cfg_interval,
357
+ **kwargs,
358
+ ):
359
  if cfg_interval[0] <= t <= cfg_interval[1]:
360
  preds = []
361
  for i in range(len(cond)):
362
+ preds.append(
363
+ FlowEulerSampler._inference_model(
364
+ self, model, x_t, t, cond[i : i + 1], **kwargs
365
+ )
366
+ )
367
  pred = sum(preds) / len(preds)
368
+ neg_pred = FlowEulerSampler._inference_model(
369
+ self, model, x_t, t, neg_cond, **kwargs
370
+ )
371
  return (1 + cfg_strength) * pred - cfg_strength * neg_pred
372
  else:
373
  preds = []
374
  for i in range(len(cond)):
375
+ preds.append(
376
+ FlowEulerSampler._inference_model(
377
+ self, model, x_t, t, cond[i : i + 1], **kwargs
378
+ )
379
+ )
380
  pred = sum(preds) / len(preds)
381
  return pred
382
+
383
  else:
384
  raise ValueError(f"Unsupported mode: {mode}")
385
+
386
  sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))
387
 
388
  yield
389
 
390
  sampler._inference_model = sampler._old_inference_model
391
+ delattr(sampler, f"_old_inference_model")
392
 
393
  @torch.no_grad()
394
  def run_multi_image(
 
398
  seed: int = 42,
399
  sparse_structure_sampler_params: dict = {},
400
  slat_sampler_params: dict = {},
401
+ formats: List[str] = ["mesh", "gaussian", "radiance_field"],
402
  preprocess_image: bool = True,
403
+ mode: Literal["stochastic", "multidiffusion"] = "stochastic",
404
  ) -> dict:
405
  """
406
  Run the pipeline with multiple images as condition
 
415
  if preprocess_image:
416
  images = [self.preprocess_image(image) for image in images]
417
  cond = self.get_cond(images)
418
+ cond["neg_cond"] = cond["neg_cond"][:1]
419
  torch.manual_seed(seed)
420
+ ss_steps = {
421
+ **self.sparse_structure_sampler_params,
422
+ **sparse_structure_sampler_params,
423
+ }.get("steps")
424
+ with self.inject_sampler_multi_image(
425
+ "sparse_structure_sampler", len(images), ss_steps, mode=mode
426
+ ):
427
+ coords = self.sample_sparse_structure(
428
+ cond, num_samples, sparse_structure_sampler_params
429
+ )
430
+ slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get("steps")
431
+ with self.inject_sampler_multi_image(
432
+ "slat_sampler", len(images), slat_steps, mode=mode
433
+ ):
434
  slat = self.sample_slat(cond, coords, slat_sampler_params)
435
  return self.decode_slat(slat, formats)
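
A self-contained sketch of the crop logic inside preprocess_image: take the bounding box of sufficiently opaque pixels, pad it by 20%, crop around its centre, and resize to 518x518. A synthetic RGBA image stands in for a real photo; Image.Resampling.LANCZOS assumes Pillow >= 9.1.

import numpy as np
from PIL import Image

rgba = np.zeros((256, 256, 4), dtype=np.uint8)
rgba[64:192, 96:160] = [255, 0, 0, 255]               # opaque red rectangle
image = Image.fromarray(rgba, mode="RGBA")

alpha = np.array(image)[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)                 # rows of (y, x) indices
left, top = bbox[:, 1].min(), bbox[:, 0].min()
right, bottom = bbox[:, 1].max(), bbox[:, 0].max()
center = ((left + right) / 2, (top + bottom) / 2)
size = int(max(right - left, bottom - top) * 1.2)     # 20% padding
crop_box = (
    int(center[0] - size // 2),
    int(center[1] - size // 2),
    int(center[0] + size // 2),
    int(center[1] + size // 2),
)
output = image.crop(crop_box).resize((518, 518), Image.Resampling.LANCZOS)
print(output.size)  # (518, 518)
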
trellis/renderers/__init__.py CHANGED
@@ -1,15 +1,16 @@
1
  import importlib
2
 
3
  __attributes = {
4
- 'OctreeRenderer': 'octree_renderer',
5
- 'GaussianRenderer': 'gaussian_render',
6
- 'MeshRenderer': 'mesh_renderer',
7
  }
8
 
9
  __submodules = []
10
 
11
  __all__ = list(__attributes.keys()) + __submodules
12
 
 
13
  def __getattr__(name):
14
  if name not in globals():
15
  if name in __attributes:
@@ -25,7 +26,7 @@ def __getattr__(name):
25
 
26
 
27
  # For Pylance
28
- if __name__ == '__main__':
29
  from .octree_renderer import OctreeRenderer
30
  from .gaussian_render import GaussianRenderer
31
- from .mesh_renderer import MeshRenderer
 
1
  import importlib
2
 
3
  __attributes = {
4
+ "OctreeRenderer": "octree_renderer",
5
+ "GaussianRenderer": "gaussian_render",
6
+ "MeshRenderer": "mesh_renderer",
7
  }
8
 
9
  __submodules = []
10
 
11
  __all__ = list(__attributes.keys()) + __submodules
12
 
13
+
14
  def __getattr__(name):
15
  if name not in globals():
16
  if name in __attributes:
 
26
 
27
 
28
  # For Pylance
29
+ if __name__ == "__main__":
30
  from .octree_renderer import OctreeRenderer
31
  from .gaussian_render import GaussianRenderer
32
+ from .mesh_renderer import MeshRenderer
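
This __init__.py relies on PEP 562 module-level __getattr__ for lazy imports. A minimal version of the pattern, with a hypothetical attribute map pointing at the standard library instead of the renderer submodules (drop it into a package __init__.py to try it):

import importlib

__attributes = {"JSONDecoder": "json"}  # exported name -> module that defines it


def __getattr__(name):
    if name in __attributes:
        module = importlib.import_module(__attributes[name])
        obj = getattr(module, name)
        globals()[name] = obj           # cache so the import happens only once
        return obj
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
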
trellis/renderers/gaussian_render.py CHANGED
@@ -3,7 +3,7 @@
3
  # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
  # All rights reserved.
5
  #
6
- # This software is free for non-commercial, research and evaluation use
7
  # under the terms of the LICENSE.md file.
8
  #
9
  # For inquiries contact [email protected]
@@ -20,10 +20,10 @@ from easydict import EasyDict as edict
20
 
21
 
22
  def intrinsics_to_projection(
23
- intrinsics: torch.Tensor,
24
- near: float,
25
- far: float,
26
- ) -> torch.Tensor:
27
  """
28
  OpenCV intrinsics to OpenGL perspective matrix
29
 
@@ -40,25 +40,40 @@ def intrinsics_to_projection(
40
  ret[0, 0] = 2 * fx
41
  ret[1, 1] = 2 * fy
42
  ret[0, 2] = 2 * cx - 1
43
- ret[1, 2] = - 2 * cy + 1
44
  ret[2, 2] = far / (far - near)
45
  ret[2, 3] = near * far / (near - far)
46
- ret[3, 2] = 1.
47
  return ret
48
 
49
 
50
- def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
 
 
 
 
 
 
 
51
  """
52
- Render the scene.
53
-
54
  Background tensor (bg_color) must be on GPU!
55
  """
56
  # lazy import
57
- if 'GaussianRasterizer' not in globals():
58
- from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
59
-
 
 
 
60
  # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
61
- screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
 
 
 
 
 
62
  try:
63
  screenspace_points.retain_grad()
64
  except:
@@ -66,9 +81,13 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali
66
  # Set up rasterization configuration
67
  tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
68
  tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
69
-
70
  kernel_size = pipe.kernel_size
71
- subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
 
 
 
 
72
 
73
  raster_settings = GaussianRasterizationSettings(
74
  image_height=int(viewpoint_camera.image_height),
@@ -84,9 +103,9 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali
84
  sh_degree=pc.active_sh_degree,
85
  campos=viewpoint_camera.camera_center,
86
  prefiltered=False,
87
- debug=pipe.debug
88
  )
89
-
90
  rasterizer = GaussianRasterizer(raster_settings=raster_settings)
91
 
92
  means3D = pc.get_xyz
@@ -110,9 +129,13 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali
110
  colors_precomp = None
111
  if override_color is None:
112
  if pipe.convert_SHs_python:
113
- shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
114
- dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
115
- dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
 
 
 
 
116
  sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
117
  colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
118
  else:
@@ -120,24 +143,28 @@ def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scali
120
  else:
121
  colors_precomp = override_color
122
 
123
- # Rasterize visible Gaussians to image, obtain their radii (on screen).
124
  rendered_image, radii = rasterizer(
125
- means3D = means3D,
126
- means2D = means2D,
127
- shs = shs,
128
- colors_precomp = colors_precomp,
129
- opacities = opacity,
130
- scales = scales,
131
- rotations = rotations,
132
- cov3D_precomp = cov3D_precomp
133
  )
134
 
135
  # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
136
  # They will be excluded from value updates used in the splitting criteria.
137
- return edict({"render": rendered_image,
 
 
138
  "viewspace_points": screenspace_points,
139
- "visibility_filter" : radii > 0,
140
- "radii": radii})
 
 
141
 
142
 
143
  class GaussianRenderer:
@@ -149,30 +176,34 @@ class GaussianRenderer:
149
  """
150
 
151
  def __init__(self, rendering_options={}) -> None:
152
- self.pipe = edict({
153
- "kernel_size": 0.1,
154
- "convert_SHs_python": False,
155
- "compute_cov3D_python": False,
156
- "scale_modifier": 1.0,
157
- "debug": False
158
- })
159
- self.rendering_options = edict({
160
- "resolution": None,
161
- "near": None,
162
- "far": None,
163
- "ssaa": 1,
164
- "bg_color": 'random',
165
- })
 
 
 
 
166
  self.rendering_options.update(rendering_options)
167
  self.bg_color = None
168
-
169
  def render(
170
- self,
171
- gausssian: Gaussian,
172
- extrinsics: torch.Tensor,
173
- intrinsics: torch.Tensor,
174
- colors_overwrite: torch.Tensor = None
175
- ) -> edict:
176
  """
177
  Render the gausssian.
178
 
@@ -190,13 +221,15 @@ class GaussianRenderer:
190
  near = self.rendering_options["near"]
191
  far = self.rendering_options["far"]
192
  ssaa = self.rendering_options["ssaa"]
193
-
194
- if self.rendering_options["bg_color"] == 'random':
195
  self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
196
  if np.random.rand() < 0.5:
197
  self.bg_color += 1
198
  else:
199
- self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
 
 
200
 
201
  view = extrinsics
202
  perspective = intrinsics_to_projection(intrinsics, near, far)
@@ -205,27 +238,40 @@ class GaussianRenderer:
205
  focaly = intrinsics[1, 1]
206
  fovx = 2 * torch.atan(0.5 / focalx)
207
  fovy = 2 * torch.atan(0.5 / focaly)
208
-
209
- camera_dict = edict({
210
- "image_height": resolution * ssaa,
211
- "image_width": resolution * ssaa,
212
- "FoVx": fovx,
213
- "FoVy": fovy,
214
- "znear": near,
215
- "zfar": far,
216
- "world_view_transform": view.T.contiguous(),
217
- "projection_matrix": perspective.T.contiguous(),
218
- "full_proj_transform": (perspective @ view).T.contiguous(),
219
- "camera_center": camera
220
- })
 
 
221
 
222
  # Render
223
- render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)
 
 
 
 
 
 
 
224
 
225
  if ssaa > 1:
226
- render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
227
-
228
- ret = edict({
229
- 'color': render_ret['render']
230
- })
 
 
 
 
231
  return ret
 
3
  # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
  # All rights reserved.
5
  #
6
+ # This software is free for non-commercial, research and evaluation use
7
  # under the terms of the LICENSE.md file.
8
  #
9
  # For inquiries contact [email protected]
 
20
 
21
 
22
  def intrinsics_to_projection(
23
+ intrinsics: torch.Tensor,
24
+ near: float,
25
+ far: float,
26
+ ) -> torch.Tensor:
27
  """
28
  OpenCV intrinsics to OpenGL perspective matrix
29
 
 
40
  ret[0, 0] = 2 * fx
41
  ret[1, 1] = 2 * fy
42
  ret[0, 2] = 2 * cx - 1
43
+ ret[1, 2] = -2 * cy + 1
44
  ret[2, 2] = far / (far - near)
45
  ret[2, 3] = near * far / (near - far)
46
+ ret[3, 2] = 1.0
47
  return ret
48
 
49
 
50
+ def render(
51
+ viewpoint_camera,
52
+ pc: Gaussian,
53
+ pipe,
54
+ bg_color: torch.Tensor,
55
+ scaling_modifier=1.0,
56
+ override_color=None,
57
+ ):
58
  """
59
+ Render the scene.
60
+
61
  Background tensor (bg_color) must be on GPU!
62
  """
63
  # lazy import
64
+ if "GaussianRasterizer" not in globals():
65
+ from diff_gaussian_rasterization import (
66
+ GaussianRasterizer,
67
+ GaussianRasterizationSettings,
68
+ )
69
+
70
  # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
71
+ screenspace_points = (
72
+ torch.zeros_like(
73
+ pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
74
+ )
75
+ + 0
76
+ )
77
  try:
78
  screenspace_points.retain_grad()
79
  except:
 
81
  # Set up rasterization configuration
82
  tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
83
  tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
84
+
85
  kernel_size = pipe.kernel_size
86
+ subpixel_offset = torch.zeros(
87
+ (int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2),
88
+ dtype=torch.float32,
89
+ device="cuda",
90
+ )
91
 
92
  raster_settings = GaussianRasterizationSettings(
93
  image_height=int(viewpoint_camera.image_height),
 
103
  sh_degree=pc.active_sh_degree,
104
  campos=viewpoint_camera.camera_center,
105
  prefiltered=False,
106
+ debug=pipe.debug,
107
  )
108
+
109
  rasterizer = GaussianRasterizer(raster_settings=raster_settings)
110
 
111
  means3D = pc.get_xyz
 
129
  colors_precomp = None
130
  if override_color is None:
131
  if pipe.convert_SHs_python:
132
+ shs_view = pc.get_features.transpose(1, 2).view(
133
+ -1, 3, (pc.max_sh_degree + 1) ** 2
134
+ )
135
+ dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
136
+ pc.get_features.shape[0], 1
137
+ )
138
+ dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
139
  sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
140
  colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
141
  else:
 
143
  else:
144
  colors_precomp = override_color
145
 
146
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
147
  rendered_image, radii = rasterizer(
148
+ means3D=means3D,
149
+ means2D=means2D,
150
+ shs=shs,
151
+ colors_precomp=colors_precomp,
152
+ opacities=opacity,
153
+ scales=scales,
154
+ rotations=rotations,
155
+ cov3D_precomp=cov3D_precomp,
156
  )
157
 
158
  # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
159
  # They will be excluded from value updates used in the splitting criteria.
160
+ return edict(
161
+ {
162
+ "render": rendered_image,
163
  "viewspace_points": screenspace_points,
164
+ "visibility_filter": radii > 0,
165
+ "radii": radii,
166
+ }
167
+ )
168
 
169
 
170
  class GaussianRenderer:
 
176
  """
177
 
178
  def __init__(self, rendering_options={}) -> None:
179
+ self.pipe = edict(
180
+ {
181
+ "kernel_size": 0.1,
182
+ "convert_SHs_python": False,
183
+ "compute_cov3D_python": False,
184
+ "scale_modifier": 1.0,
185
+ "debug": False,
186
+ }
187
+ )
188
+ self.rendering_options = edict(
189
+ {
190
+ "resolution": None,
191
+ "near": None,
192
+ "far": None,
193
+ "ssaa": 1,
194
+ "bg_color": "random",
195
+ }
196
+ )
197
  self.rendering_options.update(rendering_options)
198
  self.bg_color = None
199
+
200
  def render(
201
+ self,
202
+ gausssian: Gaussian,
203
+ extrinsics: torch.Tensor,
204
+ intrinsics: torch.Tensor,
205
+ colors_overwrite: torch.Tensor = None,
206
+ ) -> edict:
207
  """
208
  Render the gausssian.
209
 
 
221
  near = self.rendering_options["near"]
222
  far = self.rendering_options["far"]
223
  ssaa = self.rendering_options["ssaa"]
224
+
225
+ if self.rendering_options["bg_color"] == "random":
226
  self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
227
  if np.random.rand() < 0.5:
228
  self.bg_color += 1
229
  else:
230
+ self.bg_color = torch.tensor(
231
+ self.rendering_options["bg_color"], dtype=torch.float32, device="cuda"
232
+ )
233
 
234
  view = extrinsics
235
  perspective = intrinsics_to_projection(intrinsics, near, far)
 
238
  focaly = intrinsics[1, 1]
239
  fovx = 2 * torch.atan(0.5 / focalx)
240
  fovy = 2 * torch.atan(0.5 / focaly)
241
+
242
+ camera_dict = edict(
243
+ {
244
+ "image_height": resolution * ssaa,
245
+ "image_width": resolution * ssaa,
246
+ "FoVx": fovx,
247
+ "FoVy": fovy,
248
+ "znear": near,
249
+ "zfar": far,
250
+ "world_view_transform": view.T.contiguous(),
251
+ "projection_matrix": perspective.T.contiguous(),
252
+ "full_proj_transform": (perspective @ view).T.contiguous(),
253
+ "camera_center": camera,
254
+ }
255
+ )
256
 
257
  # Render
258
+ render_ret = render(
259
+ camera_dict,
260
+ gausssian,
261
+ self.pipe,
262
+ self.bg_color,
263
+ override_color=colors_overwrite,
264
+ scaling_modifier=self.pipe.scale_modifier,
265
+ )
266
 
267
  if ssaa > 1:
268
+ render_ret.render = F.interpolate(
269
+ render_ret.render[None],
270
+ size=(resolution, resolution),
271
+ mode="bilinear",
272
+ align_corners=False,
273
+ antialias=True,
274
+ ).squeeze()
275
+
276
+ ret = edict({"color": render_ret["render"]})
277
  return ret
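
A standalone copy of intrinsics_to_projection with a usage example. The preamble that reads fx/fy/cx/cy and allocates the 4x4 matrix is reconstructed here (it sits in unchanged context lines not shown in the hunk), and the intrinsics are assumed to be normalized by image size.

import torch


def intrinsics_to_projection(intrinsics: torch.Tensor, near: float, far: float) -> torch.Tensor:
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    ret[0, 0] = 2 * fx                      # x focal, NDC spans [-1, 1]
    ret[1, 1] = 2 * fy
    ret[0, 2] = 2 * cx - 1                  # principal point offset
    ret[1, 2] = -2 * cy + 1                 # y axis flips between OpenCV and OpenGL
    ret[2, 2] = far / (far - near)
    ret[2, 3] = near * far / (near - far)
    ret[3, 2] = 1.0
    return ret


intrinsics = torch.tensor([[1.2, 0.0, 0.5],
                           [0.0, 1.2, 0.5],
                           [0.0, 0.0, 1.0]])
print(intrinsics_to_projection(intrinsics, near=0.1, far=100.0))
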
trellis/renderers/mesh_renderer.py CHANGED
@@ -13,10 +13,10 @@ import torch.nn.functional as F
13
 
14
 
15
  def intrinsics_to_projection(
16
- intrinsics: torch.Tensor,
17
- near: float,
18
- far: float,
19
- ) -> torch.Tensor:
20
  """
21
  OpenCV intrinsics to OpenGL perspective matrix
22
 
@@ -33,10 +33,10 @@ def intrinsics_to_projection(
33
  ret[0, 0] = 2 * fx
34
  ret[1, 1] = 2 * fy
35
  ret[0, 2] = 2 * cx - 1
36
- ret[1, 2] = - 2 * cy + 1
37
  ret[2, 2] = far / (far - near)
38
  ret[2, 3] = near * far / (near - far)
39
- ret[3, 2] = 1.
40
  return ret
41
 
42
 
@@ -47,25 +47,23 @@ class MeshRenderer:
47
  Args:
48
  rendering_options (dict): Rendering options.
49
  glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop.
50
- """
51
- def __init__(self, rendering_options={}, device='cuda'):
52
- self.rendering_options = edict({
53
- "resolution": None,
54
- "near": None,
55
- "far": None,
56
- "ssaa": 1
57
- })
58
  self.rendering_options.update(rendering_options)
59
  self.glctx = dr.RasterizeCudaContext(device=device)
60
- self.device=device
61
-
62
  def render(
63
- self,
64
- mesh : MeshExtractResult,
65
- extrinsics: torch.Tensor,
66
- intrinsics: torch.Tensor,
67
- return_types = ["mask", "normal", "depth"]
68
- ) -> edict:
69
  """
70
  Render the mesh.
71
 
@@ -87,51 +85,80 @@ class MeshRenderer:
87
  near = self.rendering_options["near"]
88
  far = self.rendering_options["far"]
89
  ssaa = self.rendering_options["ssaa"]
90
-
91
  if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
92
- default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device)
93
- ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types}
 
 
 
 
 
 
 
 
 
94
  return ret_dict
95
-
96
  perspective = intrinsics_to_projection(intrinsics, near, far)
97
-
98
  RT = extrinsics.unsqueeze(0)
99
  full_proj = (perspective @ extrinsics).unsqueeze(0)
100
-
101
  vertices = mesh.vertices.unsqueeze(0)
102
 
103
- vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
 
 
104
  vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2))
105
  vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
106
  faces_int = mesh.faces.int()
107
  rast, _ = dr.rasterize(
108
- self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa))
109
-
 
110
  out_dict = edict()
111
  for type in return_types:
112
  img = None
113
- if type == "mask" :
114
- img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
 
 
115
  elif type == "depth":
116
- img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0]
 
 
117
  img = dr.antialias(img, rast, vertices_clip, faces_int)
118
- elif type == "normal" :
119
  img = dr.interpolate(
120
- mesh.face_normal.reshape(1, -1, 3), rast,
121
- torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3)
 
 
 
122
  )[0]
123
  img = dr.antialias(img, rast, vertices_clip, faces_int)
124
  # normalize norm pictures
125
  img = (img + 1) / 2
126
- elif type == "normal_map" :
127
- img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0]
 
 
128
  img = dr.antialias(img, rast, vertices_clip, faces_int)
129
- elif type == "color" :
130
- img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0]
 
 
131
  img = dr.antialias(img, rast, vertices_clip, faces_int)
132
 
133
  if ssaa > 1:
134
- img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
 
 
 
 
 
 
135
  img = img.squeeze()
136
  else:
137
  img = img.permute(0, 3, 1, 2).squeeze()
 
13
 
14
 
15
  def intrinsics_to_projection(
16
+ intrinsics: torch.Tensor,
17
+ near: float,
18
+ far: float,
19
+ ) -> torch.Tensor:
20
  """
21
  OpenCV intrinsics to OpenGL perspective matrix
22
 
 
33
  ret[0, 0] = 2 * fx
34
  ret[1, 1] = 2 * fy
35
  ret[0, 2] = 2 * cx - 1
36
+ ret[1, 2] = -2 * cy + 1
37
  ret[2, 2] = far / (far - near)
38
  ret[2, 3] = near * far / (near - far)
39
+ ret[3, 2] = 1.0
40
  return ret
41
 
42
 
 
47
  Args:
48
  rendering_options (dict): Rendering options.
49
  glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop.
50
+ """
51
+
52
+ def __init__(self, rendering_options={}, device="cuda"):
53
+ self.rendering_options = edict(
54
+ {"resolution": None, "near": None, "far": None, "ssaa": 1}
55
+ )
 
 
56
  self.rendering_options.update(rendering_options)
57
  self.glctx = dr.RasterizeCudaContext(device=device)
58
+ self.device = device
59
+
60
  def render(
61
+ self,
62
+ mesh: MeshExtractResult,
63
+ extrinsics: torch.Tensor,
64
+ intrinsics: torch.Tensor,
65
+ return_types=["mask", "normal", "depth"],
66
+ ) -> edict:
67
  """
68
  Render the mesh.
69
 
 
85
  near = self.rendering_options["near"]
86
  far = self.rendering_options["far"]
87
  ssaa = self.rendering_options["ssaa"]
88
+
89
  if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
90
+ default_img = torch.zeros(
91
+ (1, resolution, resolution, 3), dtype=torch.float32, device=self.device
92
+ )
93
+ ret_dict = {
94
+ k: (
95
+ default_img
96
+ if k in ["normal", "normal_map", "color"]
97
+ else default_img[..., :1]
98
+ )
99
+ for k in return_types
100
+ }
101
  return ret_dict
102
+
103
  perspective = intrinsics_to_projection(intrinsics, near, far)
104
+
105
  RT = extrinsics.unsqueeze(0)
106
  full_proj = (perspective @ extrinsics).unsqueeze(0)
107
+
108
  vertices = mesh.vertices.unsqueeze(0)
109
 
110
+ vertices_homo = torch.cat(
111
+ [vertices, torch.ones_like(vertices[..., :1])], dim=-1
112
+ )
113
  vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2))
114
  vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
115
  faces_int = mesh.faces.int()
116
  rast, _ = dr.rasterize(
117
+ self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa)
118
+ )
119
+
120
  out_dict = edict()
121
  for type in return_types:
122
  img = None
123
+ if type == "mask":
124
+ img = dr.antialias(
125
+ (rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int
126
+ )
127
  elif type == "depth":
128
+ img = dr.interpolate(
129
+ vertices_camera[..., 2:3].contiguous(), rast, faces_int
130
+ )[0]
131
  img = dr.antialias(img, rast, vertices_clip, faces_int)
132
+ elif type == "normal":
133
  img = dr.interpolate(
134
+ mesh.face_normal.reshape(1, -1, 3),
135
+ rast,
136
+ torch.arange(
137
+ mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int
138
+ ).reshape(-1, 3),
139
  )[0]
140
  img = dr.antialias(img, rast, vertices_clip, faces_int)
141
  # normalize norm pictures
142
  img = (img + 1) / 2
143
+ elif type == "normal_map":
144
+ img = dr.interpolate(
145
+ mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int
146
+ )[0]
147
  img = dr.antialias(img, rast, vertices_clip, faces_int)
148
+ elif type == "color":
149
+ img = dr.interpolate(
150
+ mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int
151
+ )[0]
152
  img = dr.antialias(img, rast, vertices_clip, faces_int)
153
 
154
  if ssaa > 1:
155
+ img = F.interpolate(
156
+ img.permute(0, 3, 1, 2),
157
+ (resolution, resolution),
158
+ mode="bilinear",
159
+ align_corners=False,
160
+ antialias=True,
161
+ )
162
  img = img.squeeze()
163
  else:
164
  img = img.permute(0, 3, 1, 2).squeeze()
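The mesh renderer above converts normalized OpenCV intrinsics into an OpenGL-style clip matrix (note the flipped cy term and the +z row that puts camera-space depth into clip w). A minimal, self-contained sketch of that mapping, assuming a centered pinhole camera with normalized focal length 1.0 and that fx, fy, cx, cy are the usual entries of the 3x3 intrinsics; everything below is illustrative, not part of the repository:

import torch


def intrinsics_to_projection(intrinsics: torch.Tensor, near: float, far: float) -> torch.Tensor:
    # Same entry layout as the renderer's helper (assumed: fx, fy, cx, cy taken
    # from the normalized 3x3 intrinsics), producing a 4x4 clip-space matrix.
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    ret[0, 0] = 2 * fx
    ret[1, 1] = 2 * fy
    ret[0, 2] = 2 * cx - 1
    ret[1, 2] = -2 * cy + 1
    ret[2, 2] = far / (far - near)
    ret[2, 3] = near * far / (near - far)
    ret[3, 2] = 1.0
    return ret


# Assumed camera: normalized focal length 1.0 (~53 deg FoV), principal point centered.
intrinsics = torch.tensor([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]])
proj = intrinsics_to_projection(intrinsics, near=0.1, far=100.0)

# A point on the optical axis at depth 2 lands at the screen center after the
# perspective divide, and its clip-space w equals the camera-space depth.
point = proj @ torch.tensor([0.0, 0.0, 2.0, 1.0])
print(point[:2] / point[3], point[3])  # tensor([0., 0.]) tensor(2.)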
trellis/renderers/octree_renderer.py CHANGED
@@ -9,10 +9,10 @@ from ..representations.octree import DfsOctree
9
 
10
 
11
  def intrinsics_to_projection(
12
- intrinsics: torch.Tensor,
13
- near: float,
14
- far: float,
15
- ) -> torch.Tensor:
16
  """
17
  OpenCV intrinsics to OpenGL perspective matrix
18
 
@@ -29,23 +29,38 @@ def intrinsics_to_projection(
29
  ret[0, 0] = 2 * fx
30
  ret[1, 1] = 2 * fy
31
  ret[0, 2] = 2 * cx - 1
32
- ret[1, 2] = - 2 * cy + 1
33
  ret[2, 2] = far / (far - near)
34
  ret[2, 3] = near * far / (near - far)
35
- ret[3, 2] = 1.
36
  return ret
37
 
38
 
39
- def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None):
40
  """
41
- Render the scene.
42
-
43
  Background tensor (bg_color) must be on GPU!
44
  """
45
  # lazy import
46
- if 'OctreeTrivecRasterizer' not in globals():
47
- from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer
48
-
49
  # Set up rasterization configuration
50
  tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
51
  tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
@@ -96,69 +111,73 @@ def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor,
96
  if octree.primitive == "voxel":
97
  renderer = OctreeVoxelRasterizer(raster_settings=raster_settings)
98
  rgb, depth, alpha, distloss = renderer(
99
- positions = positions,
100
- densities = densities,
101
- shs = shs,
102
- colors_precomp = colors_precomp,
103
- depths = depths,
104
- aabb = octree.aabb,
105
- aux = aux,
106
  )
107
- ret['rgb'] = rgb
108
- ret['depth'] = depth
109
- ret['alpha'] = alpha
110
- ret['distloss'] = distloss
111
  elif octree.primitive == "gaussian":
112
  renderer = OctreeGaussianRasterizer(raster_settings=raster_settings)
113
  rgb, depth, alpha = renderer(
114
- positions = positions,
115
- opacities = opacities,
116
- shs = shs,
117
- colors_precomp = colors_precomp,
118
- depths = depths,
119
- aabb = octree.aabb,
120
- aux = aux,
121
  )
122
- ret['rgb'] = rgb
123
- ret['depth'] = depth
124
- ret['alpha'] = alpha
125
  elif octree.primitive == "trivec":
126
- raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1]
 
 
127
  renderer = OctreeTrivecRasterizer(raster_settings=raster_settings)
128
  rgb, depth, alpha, percent_depth = renderer(
129
- positions = positions,
130
- trivecs = trivecs,
131
- densities = densities,
132
- shs = shs,
133
- colors_precomp = colors_precomp,
134
- colors_overwrite = colors_overwrite,
135
- depths = depths,
136
- aabb = octree.aabb,
137
- aux = aux,
138
- halton_sampler = halton_sampler,
139
  )
140
- ret['percent_depth'] = percent_depth
141
- ret['rgb'] = rgb
142
- ret['depth'] = depth
143
- ret['alpha'] = alpha
144
  elif octree.primitive == "decoupoly":
145
- raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1]
 
 
146
  renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings)
147
  rgb, depth, alpha = renderer(
148
- positions = positions,
149
- decoupolys_V = decoupolys_V,
150
- decoupolys_g = decoupolys_g,
151
- densities = densities,
152
- shs = shs,
153
- colors_precomp = colors_precomp,
154
- depths = depths,
155
- aabb = octree.aabb,
156
- aux = aux,
157
  )
158
- ret['rgb'] = rgb
159
- ret['depth'] = depth
160
- ret['alpha'] = alpha
161
-
162
  return ret
163
 
164
 
@@ -174,37 +193,43 @@ class OctreeRenderer:
174
  try:
175
  import diffoctreerast
176
  except ImportError:
177
- print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m")
 
 
178
  self.unsupported = True
179
  else:
180
  self.unsupported = False
181
-
182
- self.pipe = edict({
183
- "with_distloss": False,
184
- "with_aux": False,
185
- "scale_modifier": 1.0,
186
- "used_rank": None,
187
- "jitter": False,
188
- "debug": False,
189
- })
190
- self.rendering_options = edict({
191
- "resolution": None,
192
- "near": None,
193
- "far": None,
194
- "ssaa": 1,
195
- "bg_color": 'random',
196
-         })
197
  self.halton_sampler = qmc.Halton(2, scramble=False)
198
  self.rendering_options.update(rendering_options)
199
  self.bg_color = None
200
-
201
  def render(
202
- self,
203
- octree: DfsOctree,
204
- extrinsics: torch.Tensor,
205
- intrinsics: torch.Tensor,
206
- colors_overwrite: torch.Tensor = None,
207
- ) -> edict:
208
  """
209
  Render the octree.
210
 
@@ -227,27 +252,53 @@ class OctreeRenderer:
227
  near = self.rendering_options["near"]
228
  far = self.rendering_options["far"]
229
  ssaa = self.rendering_options["ssaa"]
230
-
231
  if self.unsupported:
232
  image = np.zeros((512, 512, 3), dtype=np.uint8)
233
- text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0]
 
 
234
  origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2
235
-             image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA)
236
  return {
237
- 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255,
 
238
  }
239
-
240
- if self.rendering_options["bg_color"] == 'random':
241
  self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
242
  if np.random.rand() < 0.5:
243
  self.bg_color += 1
244
  else:
245
- self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
 
 
246
 
247
  if self.pipe["with_aux"]:
248
  aux = {
249
- 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
250
-                 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
251
  }
252
  for k in aux.keys():
253
  aux[k].requires_grad_()
@@ -262,39 +313,77 @@ class OctreeRenderer:
262
  focaly = intrinsics[1, 1]
263
  fovx = 2 * torch.atan(0.5 / focalx)
264
  fovy = 2 * torch.atan(0.5 / focaly)
265
-
266
- camera_dict = edict({
267
- "image_height": resolution * ssaa,
268
- "image_width": resolution * ssaa,
269
- "FoVx": fovx,
270
- "FoVy": fovy,
271
- "znear": near,
272
- "zfar": far,
273
- "world_view_transform": view.T.contiguous(),
274
- "projection_matrix": perspective.T.contiguous(),
275
- "full_proj_transform": (perspective @ view).T.contiguous(),
276
- "camera_center": camera
277
- })
 
 
278
 
279
  # Render
280
-         render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler)
281
 
282
  if ssaa > 1:
283
- render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
284
- render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
285
- render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
286
- if hasattr(render_ret, 'percent_depth'):
287
-                 render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
288
 
289
- ret = edict({
290
- 'color': render_ret.rgb,
291
- 'depth': render_ret.depth,
292
- 'alpha': render_ret.alpha,
293
- })
294
- if self.pipe["with_distloss"] and 'distloss' in render_ret:
295
- ret['distloss'] = render_ret.distloss
 
 
296
  if self.pipe["with_aux"]:
297
- ret['aux'] = aux
298
- if hasattr(render_ret, 'percent_depth'):
299
- ret['percent_depth'] = render_ret.percent_depth
300
  return ret
 
9
 
10
 
11
  def intrinsics_to_projection(
12
+ intrinsics: torch.Tensor,
13
+ near: float,
14
+ far: float,
15
+ ) -> torch.Tensor:
16
  """
17
  OpenCV intrinsics to OpenGL perspective matrix
18
 
 
29
  ret[0, 0] = 2 * fx
30
  ret[1, 1] = 2 * fy
31
  ret[0, 2] = 2 * cx - 1
32
+ ret[1, 2] = -2 * cy + 1
33
  ret[2, 2] = far / (far - near)
34
  ret[2, 3] = near * far / (near - far)
35
+ ret[3, 2] = 1.0
36
  return ret
37
 
38
 
39
+ def render(
40
+ viewpoint_camera,
41
+ octree: DfsOctree,
42
+ pipe,
43
+ bg_color: torch.Tensor,
44
+ scaling_modifier=1.0,
45
+ used_rank=None,
46
+ colors_overwrite=None,
47
+ aux=None,
48
+ halton_sampler=None,
49
+ ):
50
  """
51
+ Render the scene.
52
+
53
  Background tensor (bg_color) must be on GPU!
54
  """
55
  # lazy import
56
+ if "OctreeTrivecRasterizer" not in globals():
57
+ from diffoctreerast import (
58
+ OctreeVoxelRasterizer,
59
+ OctreeGaussianRasterizer,
60
+ OctreeTrivecRasterizer,
61
+ OctreeDecoupolyRasterizer,
62
+ )
63
+
64
  # Set up rasterization configuration
65
  tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
66
  tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
 
111
  if octree.primitive == "voxel":
112
  renderer = OctreeVoxelRasterizer(raster_settings=raster_settings)
113
  rgb, depth, alpha, distloss = renderer(
114
+ positions=positions,
115
+ densities=densities,
116
+ shs=shs,
117
+ colors_precomp=colors_precomp,
118
+ depths=depths,
119
+ aabb=octree.aabb,
120
+ aux=aux,
121
  )
122
+ ret["rgb"] = rgb
123
+ ret["depth"] = depth
124
+ ret["alpha"] = alpha
125
+ ret["distloss"] = distloss
126
  elif octree.primitive == "gaussian":
127
  renderer = OctreeGaussianRasterizer(raster_settings=raster_settings)
128
  rgb, depth, alpha = renderer(
129
+ positions=positions,
130
+ opacities=opacities,
131
+ shs=shs,
132
+ colors_precomp=colors_precomp,
133
+ depths=depths,
134
+ aabb=octree.aabb,
135
+ aux=aux,
136
  )
137
+ ret["rgb"] = rgb
138
+ ret["depth"] = depth
139
+ ret["alpha"] = alpha
140
  elif octree.primitive == "trivec":
141
+ raster_settings.used_rank = (
142
+ used_rank if used_rank is not None else trivecs.shape[1]
143
+ )
144
  renderer = OctreeTrivecRasterizer(raster_settings=raster_settings)
145
  rgb, depth, alpha, percent_depth = renderer(
146
+ positions=positions,
147
+ trivecs=trivecs,
148
+ densities=densities,
149
+ shs=shs,
150
+ colors_precomp=colors_precomp,
151
+ colors_overwrite=colors_overwrite,
152
+ depths=depths,
153
+ aabb=octree.aabb,
154
+ aux=aux,
155
+ halton_sampler=halton_sampler,
156
  )
157
+ ret["percent_depth"] = percent_depth
158
+ ret["rgb"] = rgb
159
+ ret["depth"] = depth
160
+ ret["alpha"] = alpha
161
  elif octree.primitive == "decoupoly":
162
+ raster_settings.used_rank = (
163
+ used_rank if used_rank is not None else decoupolys_V.shape[1]
164
+ )
165
  renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings)
166
  rgb, depth, alpha = renderer(
167
+ positions=positions,
168
+ decoupolys_V=decoupolys_V,
169
+ decoupolys_g=decoupolys_g,
170
+ densities=densities,
171
+ shs=shs,
172
+ colors_precomp=colors_precomp,
173
+ depths=depths,
174
+ aabb=octree.aabb,
175
+ aux=aux,
176
  )
177
+ ret["rgb"] = rgb
178
+ ret["depth"] = depth
179
+ ret["alpha"] = alpha
180
+
181
  return ret
182
 
183
 
 
193
  try:
194
  import diffoctreerast
195
  except ImportError:
196
+ print(
197
+ "\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m"
198
+ )
199
  self.unsupported = True
200
  else:
201
  self.unsupported = False
202
+
203
+ self.pipe = edict(
204
+ {
205
+ "with_distloss": False,
206
+ "with_aux": False,
207
+ "scale_modifier": 1.0,
208
+ "used_rank": None,
209
+ "jitter": False,
210
+ "debug": False,
211
+ }
212
+ )
213
+ self.rendering_options = edict(
214
+ {
215
+ "resolution": None,
216
+ "near": None,
217
+ "far": None,
218
+ "ssaa": 1,
219
+ "bg_color": "random",
220
+ }
221
+ )
222
  self.halton_sampler = qmc.Halton(2, scramble=False)
223
  self.rendering_options.update(rendering_options)
224
  self.bg_color = None
225
+
226
  def render(
227
+ self,
228
+ octree: DfsOctree,
229
+ extrinsics: torch.Tensor,
230
+ intrinsics: torch.Tensor,
231
+ colors_overwrite: torch.Tensor = None,
232
+ ) -> edict:
233
  """
234
  Render the octree.
235
 
 
252
  near = self.rendering_options["near"]
253
  far = self.rendering_options["far"]
254
  ssaa = self.rendering_options["ssaa"]
255
+
256
  if self.unsupported:
257
  image = np.zeros((512, 512, 3), dtype=np.uint8)
258
+ text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[
259
+ 0
260
+ ]
261
  origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2
262
+ image = cv2.putText(
263
+ image,
264
+ "Unsupported",
265
+ origin,
266
+ cv2.FONT_HERSHEY_SIMPLEX,
267
+ 2,
268
+ (255, 255, 255),
269
+ 3,
270
+ cv2.LINE_AA,
271
+ )
272
  return {
273
+ "color": torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
274
+ / 255,
275
  }
276
+
277
+ if self.rendering_options["bg_color"] == "random":
278
  self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
279
  if np.random.rand() < 0.5:
280
  self.bg_color += 1
281
  else:
282
+ self.bg_color = torch.tensor(
283
+ self.rendering_options["bg_color"], dtype=torch.float32, device="cuda"
284
+ )
285
 
286
  if self.pipe["with_aux"]:
287
  aux = {
288
+ "grad_color2": torch.zeros(
289
+ (octree.num_leaf_nodes, 3),
290
+ dtype=torch.float32,
291
+ requires_grad=True,
292
+ device="cuda",
293
+ )
294
+ + 0,
295
+ "contributions": torch.zeros(
296
+ (octree.num_leaf_nodes, 1),
297
+ dtype=torch.float32,
298
+ requires_grad=True,
299
+ device="cuda",
300
+ )
301
+ + 0,
302
  }
303
  for k in aux.keys():
304
  aux[k].requires_grad_()
 
313
  focaly = intrinsics[1, 1]
314
  fovx = 2 * torch.atan(0.5 / focalx)
315
  fovy = 2 * torch.atan(0.5 / focaly)
316
+
317
+ camera_dict = edict(
318
+ {
319
+ "image_height": resolution * ssaa,
320
+ "image_width": resolution * ssaa,
321
+ "FoVx": fovx,
322
+ "FoVy": fovy,
323
+ "znear": near,
324
+ "zfar": far,
325
+ "world_view_transform": view.T.contiguous(),
326
+ "projection_matrix": perspective.T.contiguous(),
327
+ "full_proj_transform": (perspective @ view).T.contiguous(),
328
+ "camera_center": camera,
329
+ }
330
+ )
331
 
332
  # Render
333
+ render_ret = render(
334
+ camera_dict,
335
+ octree,
336
+ self.pipe,
337
+ self.bg_color,
338
+ aux=aux,
339
+ colors_overwrite=colors_overwrite,
340
+ scaling_modifier=self.pipe.scale_modifier,
341
+ used_rank=self.pipe.used_rank,
342
+ halton_sampler=self.halton_sampler,
343
+ )
344
 
345
  if ssaa > 1:
346
+ render_ret.rgb = F.interpolate(
347
+ render_ret.rgb[None],
348
+ size=(resolution, resolution),
349
+ mode="bilinear",
350
+ align_corners=False,
351
+ antialias=True,
352
+ ).squeeze()
353
+ render_ret.depth = F.interpolate(
354
+ render_ret.depth[None, None],
355
+ size=(resolution, resolution),
356
+ mode="bilinear",
357
+ align_corners=False,
358
+ antialias=True,
359
+ ).squeeze()
360
+ render_ret.alpha = F.interpolate(
361
+ render_ret.alpha[None, None],
362
+ size=(resolution, resolution),
363
+ mode="bilinear",
364
+ align_corners=False,
365
+ antialias=True,
366
+ ).squeeze()
367
+ if hasattr(render_ret, "percent_depth"):
368
+ render_ret.percent_depth = F.interpolate(
369
+ render_ret.percent_depth[None, None],
370
+ size=(resolution, resolution),
371
+ mode="bilinear",
372
+ align_corners=False,
373
+ antialias=True,
374
+ ).squeeze()
375
 
376
+ ret = edict(
377
+ {
378
+ "color": render_ret.rgb,
379
+ "depth": render_ret.depth,
380
+ "alpha": render_ret.alpha,
381
+ }
382
+ )
383
+ if self.pipe["with_distloss"] and "distloss" in render_ret:
384
+ ret["distloss"] = render_ret.distloss
385
  if self.pipe["with_aux"]:
386
+ ret["aux"] = aux
387
+ if hasattr(render_ret, "percent_depth"):
388
+ ret["percent_depth"] = render_ret.percent_depth
389
  return ret
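Both renderers share the same supersampling pattern: render at resolution * ssaa, then downsample with antialiased bilinear filtering. A small sketch of just that step, with an assumed tensor standing in for render_ret.rgb:

import torch
import torch.nn.functional as F

resolution, ssaa = 256, 2
rgb_hi = torch.rand(3, resolution * ssaa, resolution * ssaa)  # stand-in for render_ret.rgb

rgb = F.interpolate(
    rgb_hi[None],                     # add the batch dimension, as in the renderer
    size=(resolution, resolution),
    mode="bilinear",
    align_corners=False,
    antialias=True,
).squeeze()                           # back to (3, 256, 256)
print(rgb.shape)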
trellis/renderers/sh_utils.py CHANGED
@@ -30,7 +30,7 @@ C2 = [
30
  -1.0925484305920792,
31
  0.31539156525252005,
32
  -1.0925484305920792,
33
- 0.5462742152960396
34
  ]
35
  C3 = [
36
  -0.5900435899266435,
@@ -39,7 +39,7 @@ C3 = [
39
  0.3731763325901154,
40
  -0.4570457994644658,
41
  1.445305721320277,
42
- -0.5900435899266435
43
  ]
44
  C4 = [
45
  2.5033429417967046,
@@ -51,7 +51,7 @@ C4 = [
51
  0.47308734787878004,
52
  -1.7701307697799304,
53
  0.6258357354491761,
54
- ]
55
 
56
 
57
  def eval_sh(deg, sh, dirs):
@@ -74,45 +74,55 @@ def eval_sh(deg, sh, dirs):
74
  result = C0 * sh[..., 0]
75
  if deg > 0:
76
  x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77
- result = (result -
78
- C1 * y * sh[..., 1] +
79
- C1 * z * sh[..., 2] -
80
- C1 * x * sh[..., 3])
81
 
82
  if deg > 1:
83
  xx, yy, zz = x * x, y * y, z * z
84
  xy, yz, xz = x * y, y * z, x * z
85
- result = (result +
86
- C2[0] * xy * sh[..., 4] +
87
- C2[1] * yz * sh[..., 5] +
88
- C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89
- C2[3] * xz * sh[..., 7] +
90
- C2[4] * (xx - yy) * sh[..., 8])
 
 
91
 
92
  if deg > 2:
93
- result = (result +
94
- C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95
- C3[1] * xy * z * sh[..., 10] +
96
- C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97
- C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98
- C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99
- C3[5] * z * (xx - yy) * sh[..., 14] +
100
- C3[6] * x * (xx - 3 * yy) * sh[..., 15])
 
 
101
 
102
  if deg > 3:
103
- result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104
- C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105
- C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106
- C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107
- C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108
- C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109
- C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110
- C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111
-                 C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112
  return result
113
 
 
114
  def RGB2SH(rgb):
115
  return (rgb - 0.5) / C0
116
 
 
117
  def SH2RGB(sh):
118
- return sh * C0 + 0.5
 
30
  -1.0925484305920792,
31
  0.31539156525252005,
32
  -1.0925484305920792,
33
+ 0.5462742152960396,
34
  ]
35
  C3 = [
36
  -0.5900435899266435,
 
39
  0.3731763325901154,
40
  -0.4570457994644658,
41
  1.445305721320277,
42
+ -0.5900435899266435,
43
  ]
44
  C4 = [
45
  2.5033429417967046,
 
51
  0.47308734787878004,
52
  -1.7701307697799304,
53
  0.6258357354491761,
54
+ ]
55
 
56
 
57
  def eval_sh(deg, sh, dirs):
 
74
  result = C0 * sh[..., 0]
75
  if deg > 0:
76
  x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77
+ result = (
78
+ result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
79
+ )
 
80
 
81
  if deg > 1:
82
  xx, yy, zz = x * x, y * y, z * z
83
  xy, yz, xz = x * y, y * z, x * z
84
+ result = (
85
+ result
86
+ + C2[0] * xy * sh[..., 4]
87
+ + C2[1] * yz * sh[..., 5]
88
+ + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
89
+ + C2[3] * xz * sh[..., 7]
90
+ + C2[4] * (xx - yy) * sh[..., 8]
91
+ )
92
 
93
  if deg > 2:
94
+ result = (
95
+ result
96
+ + C3[0] * y * (3 * xx - yy) * sh[..., 9]
97
+ + C3[1] * xy * z * sh[..., 10]
98
+ + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
99
+ + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
100
+ + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
101
+ + C3[5] * z * (xx - yy) * sh[..., 14]
102
+ + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
103
+ )
104
 
105
  if deg > 3:
106
+ result = (
107
+ result
108
+ + C4[0] * xy * (xx - yy) * sh[..., 16]
109
+ + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
110
+ + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
111
+ + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
112
+ + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
113
+ + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
114
+ + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
115
+ + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
116
+ + C4[8]
117
+ * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
118
+ * sh[..., 24]
119
+ )
120
  return result
121
 
122
+
123
  def RGB2SH(rgb):
124
  return (rgb - 0.5) / C0
125
 
126
+
127
  def SH2RGB(sh):
128
+ return sh * C0 + 0.5
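RGB2SH and SH2RGB above are exact inverses built around the zeroth-order SH constant, and at degree 0 eval_sh reduces to C0 * sh[..., 0]. A short sketch of the round trip, assuming C0 = 1 / (2 * sqrt(pi)) as in standard real spherical harmonics; the color values are illustrative:

import math
import torch

C0 = 0.5 / math.sqrt(math.pi)          # assumed zeroth-order SH constant, matching the module's C0

rgb = torch.tensor([[0.2, 0.5, 0.9]])  # illustrative colors
sh_dc = (rgb - 0.5) / C0               # RGB2SH
assert torch.allclose(sh_dc * C0 + 0.5, rgb)  # SH2RGB restores the input
# eval_sh at degree 0 is just C0 * sh[..., 0], so the DC coefficient encodes a view-independent color.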
trellis/representations/gaussian/__init__.py CHANGED
@@ -1 +1 @@
1
- from .gaussian_model import Gaussian
 
1
+ from .gaussian_model import Gaussian
trellis/representations/gaussian/gaussian_model.py CHANGED
@@ -7,27 +7,27 @@ import utils3d
7
 
8
  class Gaussian:
9
  def __init__(
10
- self,
11
- aabb : list,
12
- sh_degree : int = 0,
13
- mininum_kernel_size : float = 0.0,
14
- scaling_bias : float = 0.01,
15
- opacity_bias : float = 0.1,
16
- scaling_activation : str = "exp",
17
- device='cuda'
18
- ):
19
  self.init_params = {
20
- 'aabb': aabb,
21
- 'sh_degree': sh_degree,
22
- 'mininum_kernel_size': mininum_kernel_size,
23
- 'scaling_bias': scaling_bias,
24
- 'opacity_bias': opacity_bias,
25
- 'scaling_activation': scaling_activation,
26
  }
27
-
28
  self.sh_degree = sh_degree
29
  self.active_sh_degree = sh_degree
30
- self.mininum_kernel_size = mininum_kernel_size
31
  self.scaling_bias = scaling_bias
32
  self.opacity_bias = opacity_bias
33
  self.scaling_activation_type = scaling_activation
@@ -48,7 +48,7 @@ class Gaussian:
48
  actual_covariance = L @ L.transpose(1, 2)
49
  symm = strip_symmetric(actual_covariance)
50
  return symm
51
-
52
  if self.scaling_activation_type == "exp":
53
  self.scaling_activation = torch.exp
54
  self.inverse_scaling_activation = torch.log
@@ -62,74 +62,91 @@ class Gaussian:
62
  self.inverse_opacity_activation = inverse_sigmoid
63
 
64
  self.rotation_activation = torch.nn.functional.normalize
65
-
66
- self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda()
 
 
67
  self.rots_bias = torch.zeros((4)).cuda()
68
  self.rots_bias[0] = 1
69
- self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda()
 
 
70
 
71
  @property
72
  def get_scaling(self):
73
  scales = self.scaling_activation(self._scaling + self.scale_bias)
74
- scales = torch.square(scales) + self.mininum_kernel_size ** 2
75
  scales = torch.sqrt(scales)
76
  return scales
77
-
78
  @property
79
  def get_rotation(self):
80
  return self.rotation_activation(self._rotation + self.rots_bias[None, :])
81
-
82
  @property
83
  def get_xyz(self):
84
  return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]
85
-
86
  @property
87
  def get_features(self):
88
- return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc
89
-
90
  @property
91
  def get_opacity(self):
92
  return self.opacity_activation(self._opacity + self.opacity_bias)
93
-
94
- def get_covariance(self, scaling_modifier = 1):
95
- return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :])
96
-
 
 
97
  def from_scaling(self, scales):
98
- scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
99
  self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias
100
-
101
  def from_rotation(self, rots):
102
  self._rotation = rots - self.rots_bias[None, :]
103
-
104
  def from_xyz(self, xyz):
105
  self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
106
-
107
  def from_features(self, features):
108
  self._features_dc = features
109
-
110
  def from_opacity(self, opacities):
111
  self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
112
 
113
  def construct_list_of_attributes(self):
114
- l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
115
  # All channels except the 3 DC
116
- for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
117
- l.append('f_dc_{}'.format(i))
118
- l.append('opacity')
119
  for i in range(self._scaling.shape[1]):
120
- l.append('scale_{}'.format(i))
121
  for i in range(self._rotation.shape[1]):
122
- l.append('rot_{}'.format(i))
123
  return l
124
-
125
  def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
126
  xyz = self.get_xyz.detach().cpu().numpy()
127
  normals = np.zeros_like(xyz)
128
-         f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
129
  opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
130
  scale = torch.log(self.get_scaling).detach().cpu().numpy()
131
  rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
132
-
133
  if transform is not None:
134
  transform = np.array(transform)
135
  xyz = np.matmul(xyz, transform.T)
@@ -137,20 +154,29 @@ class Gaussian:
137
  rotation = np.matmul(transform, rotation)
138
  rotation = utils3d.numpy.matrix_to_quaternion(rotation)
139
 
140
- dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
 
 
141
 
142
  elements = np.empty(xyz.shape[0], dtype=dtype_full)
143
- attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
 
 
144
  elements[:] = list(map(tuple, attributes))
145
- el = PlyElement.describe(elements, 'vertex')
146
  PlyData([el]).write(path)
147
 
148
  def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
149
  plydata = PlyData.read(path)
150
 
151
- xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
152
- np.asarray(plydata.elements[0]["y"]),
153
- np.asarray(plydata.elements[0]["z"])), axis=1)
 
 
 
 
 
154
  opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
155
 
156
  features_dc = np.zeros((xyz.shape[0], 3, 1))
@@ -159,43 +185,65 @@ class Gaussian:
159
  features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
160
 
161
  if self.sh_degree > 0:
162
- extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
163
- extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
164
-             assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3
165
  features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
166
  for idx, attr_name in enumerate(extra_f_names):
167
  features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
168
  # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
169
- features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
 
 
170
 
171
- scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
172
-         scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
173
  scales = np.zeros((xyz.shape[0], len(scale_names)))
174
  for idx, attr_name in enumerate(scale_names):
175
  scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
176
 
177
- rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
178
- rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
 
 
179
  rots = np.zeros((xyz.shape[0], len(rot_names)))
180
  for idx, attr_name in enumerate(rot_names):
181
  rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
182
-
183
  if transform is not None:
184
  transform = np.array(transform)
185
  xyz = np.matmul(xyz, transform)
186
  rotation = utils3d.numpy.quaternion_to_matrix(rotation)
187
  rotation = np.matmul(rotation, transform)
188
  rotation = utils3d.numpy.matrix_to_quaternion(rotation)
189
-
190
  # convert to actual gaussian attributes
191
  xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
192
-         features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
193
  if self.sh_degree > 0:
194
- features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
195
-         opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
196
  scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
197
  rots = torch.tensor(rots, dtype=torch.float, device=self.device)
198
-
199
  # convert to _hidden attributes
200
  self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
201
  self._features_dc = features_dc
@@ -204,6 +252,10 @@ class Gaussian:
204
  else:
205
  self._features_rest = None
206
  self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
207
-         self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias
208
  self._rotation = rots - self.rots_bias[None, :]
209
-
 
7
 
8
  class Gaussian:
9
  def __init__(
10
+ self,
11
+ aabb: list,
12
+ sh_degree: int = 0,
13
+ mininum_kernel_size: float = 0.0,
14
+ scaling_bias: float = 0.01,
15
+ opacity_bias: float = 0.1,
16
+ scaling_activation: str = "exp",
17
+ device="cuda",
18
+ ):
19
  self.init_params = {
20
+ "aabb": aabb,
21
+ "sh_degree": sh_degree,
22
+ "mininum_kernel_size": mininum_kernel_size,
23
+ "scaling_bias": scaling_bias,
24
+ "opacity_bias": opacity_bias,
25
+ "scaling_activation": scaling_activation,
26
  }
27
+
28
  self.sh_degree = sh_degree
29
  self.active_sh_degree = sh_degree
30
+ self.mininum_kernel_size = mininum_kernel_size
31
  self.scaling_bias = scaling_bias
32
  self.opacity_bias = opacity_bias
33
  self.scaling_activation_type = scaling_activation
 
48
  actual_covariance = L @ L.transpose(1, 2)
49
  symm = strip_symmetric(actual_covariance)
50
  return symm
51
+
52
  if self.scaling_activation_type == "exp":
53
  self.scaling_activation = torch.exp
54
  self.inverse_scaling_activation = torch.log
 
62
  self.inverse_opacity_activation = inverse_sigmoid
63
 
64
  self.rotation_activation = torch.nn.functional.normalize
65
+
66
+ self.scale_bias = self.inverse_scaling_activation(
67
+ torch.tensor(self.scaling_bias)
68
+ ).cuda()
69
  self.rots_bias = torch.zeros((4)).cuda()
70
  self.rots_bias[0] = 1
71
+ self.opacity_bias = self.inverse_opacity_activation(
72
+ torch.tensor(self.opacity_bias)
73
+ ).cuda()
74
 
75
  @property
76
  def get_scaling(self):
77
  scales = self.scaling_activation(self._scaling + self.scale_bias)
78
+ scales = torch.square(scales) + self.mininum_kernel_size**2
79
  scales = torch.sqrt(scales)
80
  return scales
81
+
82
  @property
83
  def get_rotation(self):
84
  return self.rotation_activation(self._rotation + self.rots_bias[None, :])
85
+
86
  @property
87
  def get_xyz(self):
88
  return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]
89
+
90
  @property
91
  def get_features(self):
92
+ return (
93
+ torch.cat((self._features_dc, self._features_rest), dim=2)
94
+ if self._features_rest is not None
95
+ else self._features_dc
96
+ )
97
+
98
  @property
99
  def get_opacity(self):
100
  return self.opacity_activation(self._opacity + self.opacity_bias)
101
+
102
+ def get_covariance(self, scaling_modifier=1):
103
+ return self.covariance_activation(
104
+ self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]
105
+ )
106
+
107
  def from_scaling(self, scales):
108
+ scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size**2)
109
  self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias
110
+
111
  def from_rotation(self, rots):
112
  self._rotation = rots - self.rots_bias[None, :]
113
+
114
  def from_xyz(self, xyz):
115
  self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
116
+
117
  def from_features(self, features):
118
  self._features_dc = features
119
+
120
  def from_opacity(self, opacities):
121
  self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
122
 
123
  def construct_list_of_attributes(self):
124
+ l = ["x", "y", "z", "nx", "ny", "nz"]
125
  # All channels except the 3 DC
126
+ for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
127
+ l.append("f_dc_{}".format(i))
128
+ l.append("opacity")
129
  for i in range(self._scaling.shape[1]):
130
+ l.append("scale_{}".format(i))
131
  for i in range(self._rotation.shape[1]):
132
+ l.append("rot_{}".format(i))
133
  return l
134
+
135
  def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
136
  xyz = self.get_xyz.detach().cpu().numpy()
137
  normals = np.zeros_like(xyz)
138
+ f_dc = (
139
+ self._features_dc.detach()
140
+ .transpose(1, 2)
141
+ .flatten(start_dim=1)
142
+ .contiguous()
143
+ .cpu()
144
+ .numpy()
145
+ )
146
  opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
147
  scale = torch.log(self.get_scaling).detach().cpu().numpy()
148
  rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
149
+
150
  if transform is not None:
151
  transform = np.array(transform)
152
  xyz = np.matmul(xyz, transform.T)
 
154
  rotation = np.matmul(transform, rotation)
155
  rotation = utils3d.numpy.matrix_to_quaternion(rotation)
156
 
157
+ dtype_full = [
158
+ (attribute, "f4") for attribute in self.construct_list_of_attributes()
159
+ ]
160
 
161
  elements = np.empty(xyz.shape[0], dtype=dtype_full)
162
+ attributes = np.concatenate(
163
+ (xyz, normals, f_dc, opacities, scale, rotation), axis=1
164
+ )
165
  elements[:] = list(map(tuple, attributes))
166
+ el = PlyElement.describe(elements, "vertex")
167
  PlyData([el]).write(path)
168
 
169
  def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
170
  plydata = PlyData.read(path)
171
 
172
+ xyz = np.stack(
173
+ (
174
+ np.asarray(plydata.elements[0]["x"]),
175
+ np.asarray(plydata.elements[0]["y"]),
176
+ np.asarray(plydata.elements[0]["z"]),
177
+ ),
178
+ axis=1,
179
+ )
180
  opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
181
 
182
  features_dc = np.zeros((xyz.shape[0], 3, 1))
 
185
  features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
186
 
187
  if self.sh_degree > 0:
188
+ extra_f_names = [
189
+ p.name
190
+ for p in plydata.elements[0].properties
191
+ if p.name.startswith("f_rest_")
192
+ ]
193
+ extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
194
+ assert len(extra_f_names) == 3 * (self.sh_degree + 1) ** 2 - 3
195
  features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
196
  for idx, attr_name in enumerate(extra_f_names):
197
  features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
198
  # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
199
+ features_extra = features_extra.reshape(
200
+ (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)
201
+ )
202
 
203
+ scale_names = [
204
+ p.name
205
+ for p in plydata.elements[0].properties
206
+ if p.name.startswith("scale_")
207
+ ]
208
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
209
  scales = np.zeros((xyz.shape[0], len(scale_names)))
210
  for idx, attr_name in enumerate(scale_names):
211
  scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
212
 
213
+ rot_names = [
214
+ p.name for p in plydata.elements[0].properties if p.name.startswith("rot")
215
+ ]
216
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
217
  rots = np.zeros((xyz.shape[0], len(rot_names)))
218
  for idx, attr_name in enumerate(rot_names):
219
  rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
220
+
221
  if transform is not None:
222
  transform = np.array(transform)
223
  xyz = np.matmul(xyz, transform)
224
  rotation = utils3d.numpy.quaternion_to_matrix(rotation)
225
  rotation = np.matmul(rotation, transform)
226
  rotation = utils3d.numpy.matrix_to_quaternion(rotation)
227
+
228
  # convert to actual gaussian attributes
229
  xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
230
+ features_dc = (
231
+ torch.tensor(features_dc, dtype=torch.float, device=self.device)
232
+ .transpose(1, 2)
233
+ .contiguous()
234
+ )
235
  if self.sh_degree > 0:
236
+ features_extra = (
237
+ torch.tensor(features_extra, dtype=torch.float, device=self.device)
238
+ .transpose(1, 2)
239
+ .contiguous()
240
+ )
241
+ opacities = torch.sigmoid(
242
+ torch.tensor(opacities, dtype=torch.float, device=self.device)
243
+ )
244
  scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
245
  rots = torch.tensor(rots, dtype=torch.float, device=self.device)
246
+
247
  # convert to _hidden attributes
248
  self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
249
  self._features_dc = features_dc
 
252
  else:
253
  self._features_rest = None
254
  self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
255
+ self._scaling = (
256
+ self.inverse_scaling_activation(
257
+ torch.sqrt(torch.square(scales) - self.mininum_kernel_size**2)
258
+ )
259
+ - self.scale_bias
260
+ )
261
  self._rotation = rots - self.rots_bias[None, :]
 
trellis/representations/gaussian/general_utils.py CHANGED
@@ -3,7 +3,7 @@
3
  # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
  # All rights reserved.
5
  #
6
- # This software is free for non-commercial, research and evaluation use
7
  # under the terms of the LICENSE.md file.
8
  #
9
  # For inquiries contact [email protected]
@@ -15,8 +15,10 @@ from datetime import datetime
15
  import numpy as np
16
  import random
17
 
 
18
  def inverse_sigmoid(x):
19
- return torch.log(x/(1-x))
 
20
 
21
  def PILtoTorch(pil_image, resolution):
22
  resized_image_PIL = pil_image.resize(resolution)
@@ -26,6 +28,7 @@ def PILtoTorch(pil_image, resolution):
26
  else:
27
  return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
28
 
 
29
  def get_expon_lr_func(
30
  lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
31
  ):
@@ -61,6 +64,7 @@ def get_expon_lr_func(
61
 
62
  return helper
63
 
 
64
  def strip_lowerdiag(L):
65
  uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
66
 
@@ -72,45 +76,52 @@ def strip_lowerdiag(L):
72
  uncertainty[:, 5] = L[:, 2, 2]
73
  return uncertainty
74
 
 
75
  def strip_symmetric(sym):
76
  return strip_lowerdiag(sym)
77
 
 
78
  def build_rotation(r):
79
- norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
 
 
80
 
81
  q = r / norm[:, None]
82
 
83
- R = torch.zeros((q.size(0), 3, 3), device='cuda')
84
 
85
  r = q[:, 0]
86
  x = q[:, 1]
87
  y = q[:, 2]
88
  z = q[:, 3]
89
 
90
- R[:, 0, 0] = 1 - 2 * (y*y + z*z)
91
- R[:, 0, 1] = 2 * (x*y - r*z)
92
- R[:, 0, 2] = 2 * (x*z + r*y)
93
- R[:, 1, 0] = 2 * (x*y + r*z)
94
- R[:, 1, 1] = 1 - 2 * (x*x + z*z)
95
- R[:, 1, 2] = 2 * (y*z - r*x)
96
- R[:, 2, 0] = 2 * (x*z - r*y)
97
- R[:, 2, 1] = 2 * (y*z + r*x)
98
- R[:, 2, 2] = 1 - 2 * (x*x + y*y)
99
  return R
100
 
 
101
  def build_scaling_rotation(s, r):
102
  L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
103
  R = build_rotation(r)
104
 
105
- L[:,0,0] = s[:,0]
106
- L[:,1,1] = s[:,1]
107
- L[:,2,2] = s[:,2]
108
 
109
  L = R @ L
110
  return L
111
 
 
112
  def safe_state(silent):
113
  old_f = sys.stdout
 
114
  class F:
115
  def __init__(self, silent):
116
  self.silent = silent
@@ -118,7 +129,14 @@ def safe_state(silent):
118
  def write(self, x):
119
  if not self.silent:
120
  if x.endswith("\n"):
121
-                 old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
122
  else:
123
  old_f.write(x)
124
 
 
3
  # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
  # All rights reserved.
5
  #
6
+ # This software is free for non-commercial, research and evaluation use
7
  # under the terms of the LICENSE.md file.
8
  #
9
  # For inquiries contact [email protected]
 
15
  import numpy as np
16
  import random
17
 
18
+
19
  def inverse_sigmoid(x):
20
+ return torch.log(x / (1 - x))
21
+
22
 
23
  def PILtoTorch(pil_image, resolution):
24
  resized_image_PIL = pil_image.resize(resolution)
 
28
  else:
29
  return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
30
 
31
+
32
  def get_expon_lr_func(
33
  lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
34
  ):
 
64
 
65
  return helper
66
 
67
+
68
  def strip_lowerdiag(L):
69
  uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
70
 
 
76
  uncertainty[:, 5] = L[:, 2, 2]
77
  return uncertainty
78
 
79
+
80
  def strip_symmetric(sym):
81
  return strip_lowerdiag(sym)
82
 
83
+
84
  def build_rotation(r):
85
+ norm = torch.sqrt(
86
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
87
+ )
88
 
89
  q = r / norm[:, None]
90
 
91
+ R = torch.zeros((q.size(0), 3, 3), device="cuda")
92
 
93
  r = q[:, 0]
94
  x = q[:, 1]
95
  y = q[:, 2]
96
  z = q[:, 3]
97
 
98
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
99
+ R[:, 0, 1] = 2 * (x * y - r * z)
100
+ R[:, 0, 2] = 2 * (x * z + r * y)
101
+ R[:, 1, 0] = 2 * (x * y + r * z)
102
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
103
+ R[:, 1, 2] = 2 * (y * z - r * x)
104
+ R[:, 2, 0] = 2 * (x * z - r * y)
105
+ R[:, 2, 1] = 2 * (y * z + r * x)
106
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
107
  return R
108
 
109
+
110
  def build_scaling_rotation(s, r):
111
  L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
112
  R = build_rotation(r)
113
 
114
+ L[:, 0, 0] = s[:, 0]
115
+ L[:, 1, 1] = s[:, 1]
116
+ L[:, 2, 2] = s[:, 2]
117
 
118
  L = R @ L
119
  return L
120
 
121
+
122
  def safe_state(silent):
123
  old_f = sys.stdout
124
+
125
  class F:
126
  def __init__(self, silent):
127
  self.silent = silent
 
129
  def write(self, x):
130
  if not self.silent:
131
  if x.endswith("\n"):
132
+ old_f.write(
133
+ x.replace(
134
+ "\n",
135
+ " [{}]\n".format(
136
+ str(datetime.now().strftime("%d/%m %H:%M:%S"))
137
+ ),
138
+ )
139
+ )
140
  else:
141
  old_f.write(x)
142
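build_rotation above normalizes (r, x, y, z) quaternions and expands them into 3x3 rotation matrices on the GPU. A CPU-only sketch of the same expansion, checked on an illustrative identity quaternion; the helper name below is made up:

import torch

def quat_to_matrix(q: torch.Tensor) -> torch.Tensor:
    # Same (r, x, y, z) layout and expansion as build_rotation, minus the hard-coded CUDA device.
    q = q / q.norm(dim=-1, keepdim=True)
    r, x, y, z = q.unbind(-1)
    rows = [
        1 - 2 * (y * y + z * z), 2 * (x * y - r * z),     2 * (x * z + r * y),
        2 * (x * y + r * z),     1 - 2 * (x * x + z * z), 2 * (y * z - r * x),
        2 * (x * z - r * y),     2 * (y * z + r * x),     1 - 2 * (x * x + y * y),
    ]
    return torch.stack(rows, dim=-1).reshape(*q.shape[:-1], 3, 3)

R = quat_to_matrix(torch.tensor([[1.0, 0.0, 0.0, 0.0]]))  # identity quaternion
assert torch.allclose(R, torch.eye(3).expand(1, 3, 3))    # yields the identity matrix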