diff --git a/jax/lax_linalg.py b/jax/lax_linalg.py index e782ced27fb8..fb9893e3a9c0 100644 --- a/jax/lax_linalg.py +++ b/jax/lax_linalg.py @@ -66,10 +66,18 @@ def qr(x, full_matrices=True): return q, r def svd(x, full_matrices=True, compute_uv=True): - s, u, v = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv) + """Singular value decomposition. + + Returns the singular values if compute_uv is False, otherwise returns a triple + containing the left singular vectors, the singular values and the adjoint of + the right singular vectors. + """ + result = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv) if compute_uv: + s, u, v = result return u, s, v else: + s, = result return s def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False, @@ -847,9 +855,8 @@ def _qr_cpu_gpu_translation_rule(geqrf_impl, orgqr_impl, c, operand, # Singular value decomposition def svd_impl(operand, full_matrices, compute_uv): - s, u, vt = xla.apply_primitive(svd_p, operand, full_matrices=full_matrices, - compute_uv=compute_uv) - return s, u, vt + return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices, + compute_uv=compute_uv) def svd_translation_rule(c, operand, full_matrices, compute_uv): shape = c.get_shape(operand).dimensions() @@ -861,7 +868,12 @@ def svd_translation_rule(c, operand, full_matrices, compute_uv): if not full_matrices and m != n: u = xops.SliceInDim(u, 0, min(m, n), stride=1, dimno=len(shape) - 1) vt = xops.SliceInDim(vt, 0, min(m, n), stride=1, dimno=len(shape) - 2) - return xops.Tuple(c, [s, u, vt]) + + if not compute_uv: + return xops.Tuple(c, [s]) + else: + return xops.Tuple(c, [s, u, vt]) + def svd_abstract_eval(operand, full_matrices, compute_uv): if isinstance(operand, ShapedArray): @@ -872,11 +884,14 @@ def svd_abstract_eval(operand, full_matrices, compute_uv): m = operand.shape[-2] n = operand.shape[-1] s = ShapedArray(batch_dims + (min(m, n),), lax.lax._complex_basetype(operand.dtype)) - u = ShapedArray(batch_dims + (m, m if full_matrices else min(m, n)), operand.dtype) - vt = ShapedArray(batch_dims + (n if full_matrices else min(m, n), n), operand.dtype) + if compute_uv: + u = ShapedArray(batch_dims + (m, m if full_matrices else min(m, n)), operand.dtype) + vt = ShapedArray(batch_dims + (n if full_matrices else min(m, n), n), operand.dtype) + return s, u, vt + else: + return s, else: raise NotImplementedError - return s, u, vt def svd_jvp_rule(primals, tangents, full_matrices, compute_uv): A, = primals @@ -893,6 +908,10 @@ def svd_jvp_rule(primals, tangents, full_matrices, compute_uv): s_dim = s[..., None, :] dS = jnp.matmul(jnp.matmul(Ut, dA), V) ds = jnp.real(jnp.diagonal(dS, 0, -2, -1)) + + if not compute_uv: + return (s,), (ds,) + F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) + jnp.eye(k, dtype=A.dtype)) F = F - jnp.eye(k, dtype=A.dtype) dSS = s_dim * dS @@ -917,18 +936,28 @@ def _svd_cpu_gpu_translation_rule(gesvd_impl, c, operand, full_matrices, compute ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32))) s = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1,)), s, _nan_like(c, s)) - u = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), u, - _nan_like(c, u)) - vt = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vt, - _nan_like(c, vt)) - return xops.Tuple(c, [s, u, vt]) + + result = [s] + + if compute_uv: + u = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), u, + _nan_like(c, u)) + vt = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), vt, + _nan_like(c, vt)) + result += [u, vt] + + return xops.Tuple(c, result) def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) outs = svd_p.bind(x, full_matrices=full_matrices, compute_uv=compute_uv) - return outs, (0, 0, 0) + + if compute_uv: + return outs, (0, 0, 0) + else: + return outs, (0,) svd_p = Primitive('svd') svd_p.multiple_results = True