@@ -3395,10 +3395,8 @@ def test_reset_control(robot_path, tol):
33953395
33963396
33973397@pytest .mark .required
3398- @pytest .mark .parametrize ("n_envs" , [0 , 3 ])
3399- @pytest .mark .parametrize ("backend" , [gs .cpu ])
3400- def test_joint_get_anchor_pos_and_axis (n_envs , backend , tol ):
3401- """Test that get_anchor_pos() and get_anchor_axis() work correctly."""
3398+ @pytest .mark .parametrize ("n_envs" , [0 , 2 ])
3399+ def test_joint_get_anchor_pos_and_axis (n_envs ):
34023400 scene = gs .Scene (
34033401 show_viewer = False ,
34043402 show_FPS = False ,
@@ -3409,58 +3407,15 @@ def test_joint_get_anchor_pos_and_axis(n_envs, backend, tol):
34093407 ),
34103408 )
34113409 scene .build (n_envs = n_envs )
3410+ batch_shape = (n_envs ,) if n_envs > 0 else ()
34123411
3413- # Get a non-fixed joint (skip the root joint which is FREE type)
3414- joint = None
3415- for j in robot .joints :
3416- if j .type != gs .JOINT_TYPE .FIXED and j .type != gs .JOINT_TYPE .FREE :
3417- joint = j
3418- break
3419-
3420- assert joint is not None , "No suitable joint found for testing"
3421-
3422- # Step the simulation to update joint states
3423- scene .step ()
3424-
3425- # Test get_anchor_pos()
3412+ joint = robot .joints [1 ]
34263413 anchor_pos = joint .get_anchor_pos ()
3414+ assert anchor_pos .shape == (* batch_shape , 3 )
3415+ expected_pos = scene .rigid_solver .joints_state .xanchor .to_numpy ()
3416+ assert_allclose (anchor_pos , expected_pos [joint .idx ], tol = gs .EPS )
34273417
3428- # Verify shape
3429- if n_envs == 0 :
3430- assert anchor_pos .shape == (3 ,), f"Expected shape (3,), got { anchor_pos .shape } "
3431- else :
3432- assert anchor_pos .shape == (n_envs , 3 ), f"Expected shape ({ n_envs } , 3), got { anchor_pos .shape } "
3433-
3434- # Verify values match internal state
3435- solver = scene .sim .rigid_solver
3436- if n_envs == 0 :
3437- expected_pos = solver .joints_state .xanchor [joint .idx , 0 ].to_numpy ()
3438- assert_allclose (anchor_pos .cpu ().numpy (), expected_pos , tol = tol )
3439- else :
3440- for i_env in range (n_envs ):
3441- expected_pos = solver .joints_state .xanchor [joint .idx , i_env ].to_numpy ()
3442- assert_allclose (anchor_pos [i_env ].cpu ().numpy (), expected_pos , tol = tol )
3443-
3444- # Test get_anchor_axis()
34453418 anchor_axis = joint .get_anchor_axis ()
3446-
3447- # Verify shape
3448- if n_envs == 0 :
3449- assert anchor_axis .shape == (3 ,), f"Expected shape (3,), got { anchor_axis .shape } "
3450- else :
3451- assert anchor_axis .shape == (n_envs , 3 ), f"Expected shape ({ n_envs } , 3), got { anchor_axis .shape } "
3452-
3453- # Verify values match internal state
3454- if n_envs == 0 :
3455- expected_axis = solver .joints_state .xaxis [joint .idx , 0 ].to_numpy ()
3456- assert_allclose (anchor_axis .cpu ().numpy (), expected_axis , tol = tol )
3457- else :
3458- for i_env in range (n_envs ):
3459- expected_axis = solver .joints_state .xaxis [joint .idx , i_env ].to_numpy ()
3460- assert_allclose (anchor_axis [i_env ].cpu ().numpy (), expected_axis , tol = tol )
3461-
3462- # Verify that calling these methods multiple times doesn't crash
3463- anchor_pos2 = joint .get_anchor_pos ()
3464- anchor_axis2 = joint .get_anchor_axis ()
3465- assert_allclose (anchor_pos , anchor_pos2 , tol = gs .EPS )
3466- assert_allclose (anchor_axis , anchor_axis2 , tol = gs .EPS )
3419+ assert anchor_axis .shape == (* batch_shape , 3 )
3420+ expected_axis = scene .rigid_solver .joints_state .xaxis .to_numpy ()
3421+ assert_allclose (anchor_axis , expected_axis [joint .idx ], tol = gs .EPS )
0 commit comments