Skip to content

Commit 6aef3d2

Browse files
authored
Cleanup unit test.
1 parent 2c0ff39 commit 6aef3d2

File tree

1 file changed

+10
-55
lines changed

1 file changed

+10
-55
lines changed

tests/test_rigid_physics.py

Lines changed: 10 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -3395,10 +3395,8 @@ def test_reset_control(robot_path, tol):
33953395

33963396

33973397
@pytest.mark.required
3398-
@pytest.mark.parametrize("n_envs", [0, 3])
3399-
@pytest.mark.parametrize("backend", [gs.cpu])
3400-
def test_joint_get_anchor_pos_and_axis(n_envs, backend, tol):
3401-
"""Test that get_anchor_pos() and get_anchor_axis() work correctly."""
3398+
@pytest.mark.parametrize("n_envs", [0, 2])
3399+
def test_joint_get_anchor_pos_and_axis(n_envs):
34023400
scene = gs.Scene(
34033401
show_viewer=False,
34043402
show_FPS=False,
@@ -3409,58 +3407,15 @@ def test_joint_get_anchor_pos_and_axis(n_envs, backend, tol):
34093407
),
34103408
)
34113409
scene.build(n_envs=n_envs)
3410+
batch_shape = (n_envs,) if n_envs > 0 else ()
34123411

3413-
# Get a non-fixed joint (skip the root joint which is FREE type)
3414-
joint = None
3415-
for j in robot.joints:
3416-
if j.type != gs.JOINT_TYPE.FIXED and j.type != gs.JOINT_TYPE.FREE:
3417-
joint = j
3418-
break
3419-
3420-
assert joint is not None, "No suitable joint found for testing"
3421-
3422-
# Step the simulation to update joint states
3423-
scene.step()
3424-
3425-
# Test get_anchor_pos()
3412+
joint = robot.joints[1]
34263413
anchor_pos = joint.get_anchor_pos()
3414+
assert anchor_pos.shape == (*batch_shape, 3)
3415+
expected_pos = scene.rigid_solver.joints_state.xanchor.to_numpy()
3416+
assert_allclose(anchor_pos, expected_pos[joint.idx], tol=gs.EPS)
34273417

3428-
# Verify shape
3429-
if n_envs == 0:
3430-
assert anchor_pos.shape == (3,), f"Expected shape (3,), got {anchor_pos.shape}"
3431-
else:
3432-
assert anchor_pos.shape == (n_envs, 3), f"Expected shape ({n_envs}, 3), got {anchor_pos.shape}"
3433-
3434-
# Verify values match internal state
3435-
solver = scene.sim.rigid_solver
3436-
if n_envs == 0:
3437-
expected_pos = solver.joints_state.xanchor[joint.idx, 0].to_numpy()
3438-
assert_allclose(anchor_pos.cpu().numpy(), expected_pos, tol=tol)
3439-
else:
3440-
for i_env in range(n_envs):
3441-
expected_pos = solver.joints_state.xanchor[joint.idx, i_env].to_numpy()
3442-
assert_allclose(anchor_pos[i_env].cpu().numpy(), expected_pos, tol=tol)
3443-
3444-
# Test get_anchor_axis()
34453418
anchor_axis = joint.get_anchor_axis()
3446-
3447-
# Verify shape
3448-
if n_envs == 0:
3449-
assert anchor_axis.shape == (3,), f"Expected shape (3,), got {anchor_axis.shape}"
3450-
else:
3451-
assert anchor_axis.shape == (n_envs, 3), f"Expected shape ({n_envs}, 3), got {anchor_axis.shape}"
3452-
3453-
# Verify values match internal state
3454-
if n_envs == 0:
3455-
expected_axis = solver.joints_state.xaxis[joint.idx, 0].to_numpy()
3456-
assert_allclose(anchor_axis.cpu().numpy(), expected_axis, tol=tol)
3457-
else:
3458-
for i_env in range(n_envs):
3459-
expected_axis = solver.joints_state.xaxis[joint.idx, i_env].to_numpy()
3460-
assert_allclose(anchor_axis[i_env].cpu().numpy(), expected_axis, tol=tol)
3461-
3462-
# Verify that calling these methods multiple times doesn't crash
3463-
anchor_pos2 = joint.get_anchor_pos()
3464-
anchor_axis2 = joint.get_anchor_axis()
3465-
assert_allclose(anchor_pos, anchor_pos2, tol=gs.EPS)
3466-
assert_allclose(anchor_axis, anchor_axis2, tol=gs.EPS)
3419+
assert anchor_axis.shape == (*batch_shape, 3)
3420+
expected_axis = scene.rigid_solver.joints_state.xaxis.to_numpy()
3421+
assert_allclose(anchor_axis, expected_axis[joint.idx], tol=gs.EPS)

0 commit comments

Comments (0)