Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 91 additions & 14 deletions source/isaaclab/isaaclab/sim/prims/xform_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,24 @@ def _to_numpy(self, value: torch.Tensor | np.ndarray | Sequence[float] | None) -
else:
return np.array(value)

@staticmethod
def _world_to_local_tq(prim: Usd.Prim, world_t: Gf.Vec3d, world_q: Gf.Quatd) -> tuple[Gf.Vec3d, Gf.Quatd]:
"""Convert desired world (t,q) into local (t,q) wrt prim's parent."""
time = Usd.TimeCode.Default()
parent = prim.GetParent()
if parent and parent.IsValid():
parent_w = UsdGeom.Xformable(parent).ComputeLocalToWorldTransform(time)
else:
parent_w = Gf.Matrix4d(1.0)

world_m = Gf.Matrix4d(1.0)
world_m.SetRotate(world_q)
world_m.SetTranslateOnly(world_t)

local_m = parent_w.GetInverse() * world_m
local_tf = Gf.Transform(local_m)
return local_tf.GetTranslation(), local_tf.GetRotation().GetQuat()

def set_world_poses(
self,
positions: torch.Tensor | np.ndarray | Sequence[float] | None = None,
Expand Down Expand Up @@ -282,19 +300,26 @@ def set_world_poses(
orient_op = xformable.AddXformOp(UsdGeom.XformOp.TypeOrient, UsdGeom.XformOp.PrecisionDouble)

# Set position
if pos_np is not None:
# Convert numpy values to Python floats for USD
translate_op.Set(Gf.Vec3d(float(pos_np[idx, 0]), float(pos_np[idx, 1]), float(pos_np[idx, 2])))

# Set orientation
if orient_np is not None:
# Convert numpy values to Python floats for USD
w = float(orient_np[idx, 0])
x = float(orient_np[idx, 1])
y = float(orient_np[idx, 2])
z = float(orient_np[idx, 3])
quat = Gf.Quatd(w, Gf.Vec3d(x, y, z))
orient_op.Set(quat)
if pos_np is not None or orient_np is not None:
current_world = xformable.ComputeLocalToWorldTransform(Usd.TimeCode.Default())

if pos_np is not None:
world_t = Gf.Vec3d(float(pos_np[idx, 0]), float(pos_np[idx, 1]), float(pos_np[idx, 2]))
else:
world_t = current_world.ExtractTranslation()

if orient_np is not None:
w = float(orient_np[idx, 0])
x = float(orient_np[idx, 1])
y = float(orient_np[idx, 2])
z = float(orient_np[idx, 3])
world_q = Gf.Quatd(w, Gf.Vec3d(x, y, z))
else:
world_q = current_world.ExtractRotation().GetQuat()

local_t, local_q = self._world_to_local_tq(prim, world_t, world_q)
translate_op.Set(local_t)
orient_op.Set(local_q)

def set_local_poses(
self,
Expand All @@ -310,7 +335,59 @@ def set_local_poses(
indices: Indices of prims to update. If None, all prims are updated.
"""
# For local poses, we use the same method since USD xform ops are inherently local
self.set_world_poses(positions=translations, orientations=orientations, indices=indices)
# Convert to numpy
trans_np = self._to_numpy(translations)
orient_np = self._to_numpy(orientations)
indices_np = self._to_numpy(indices)

# Determine which prims to update
if indices_np is None:
prim_indices = range(self._count)
else:
prim_indices = indices_np.astype(int)

# Broadcast if needed
if trans_np is not None:
if trans_np.ndim == 1:
trans_np = np.tile(trans_np, (len(prim_indices), 1))

if orient_np is not None:
if orient_np.ndim == 1:
orient_np = np.tile(orient_np, (len(prim_indices), 1))

# Update each prim
for idx, prim_idx in enumerate(prim_indices):
prim = self._prims[prim_idx]
xformable = UsdGeom.Xformable(prim)

# Get or create the translate op
translate_attr = prim.GetAttribute("xformOp:translate")
if translate_attr:
translate_op = UsdGeom.XformOp(translate_attr)
else:
translate_op = xformable.AddXformOp(UsdGeom.XformOp.TypeTranslate, UsdGeom.XformOp.PrecisionDouble)

# Get or create the orient op
orient_attr = prim.GetAttribute("xformOp:orient")
if orient_attr:
orient_op = UsdGeom.XformOp(orient_attr)
else:
orient_op = xformable.AddXformOp(UsdGeom.XformOp.TypeOrient, UsdGeom.XformOp.PrecisionDouble)

# Set translation
if trans_np is not None:
# Convert numpy values to Python floats for USD
translate_op.Set(Gf.Vec3d(float(trans_np[idx, 0]), float(trans_np[idx, 1]), float(trans_np[idx, 2])))

# Set orientation
if orient_np is not None:
# Convert numpy values to Python floats for USD
w = float(orient_np[idx, 0])
x = float(orient_np[idx, 1])
y = float(orient_np[idx, 2])
z = float(orient_np[idx, 3])
quat = Gf.Quatd(w, Gf.Vec3d(x, y, z))
orient_op.Set(quat)

def set_local_scales(
self,
Expand Down
Loading