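jax.numpy.linalg.matrix_norm

jax.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')

Computes the matrix norm of a matrix (or a stack of matrices) x. The norm is taken over the last two axes, so an input of shape (..., M, N) is treated as a stack of M×N matrices.

Parameters:
  x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array of shape (..., M, N).
  keepdims (bool) – if True, the two reduced axes are kept in the result as dimensions of size 1. Defaults to False.
  ord (str) – the type of matrix norm to compute. Defaults to 'fro' (Frobenius norm); 'nuc' selects the nuclear norm. See numpy.linalg.norm for the full set of matrix-norm options.

Return type:
  Array – the norm of each matrix in x, with shape (...), or (..., 1, 1) when keepdims=True.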
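A short usage sketch follows. It shows the default Frobenius norm on a stack of two 2×2 matrices, the nuclear norm via ord='nuc', and the effect of keepdims on the output shape. The input values are illustrative only, and the manual check simply restates the Frobenius definition (square root of the sum of squared entries over the last two axes).

```python
import jax.numpy as jnp

# A stack of two 2x2 matrices: shape (2, 2, 2).
x = jnp.array([[[1.0, 2.0], [3.0, 4.0]],
               [[0.0, 1.0], [1.0, 0.0]]])

# Default ord='fro': one norm per matrix, result shape (2,).
fro = jnp.linalg.matrix_norm(x)

# Nuclear norm: the sum of the singular values of each matrix.
nuc = jnp.linalg.matrix_norm(x, ord='nuc')

# keepdims=True keeps the two reduced axes as size-1 dims: shape (2, 1, 1).
kept = jnp.linalg.matrix_norm(x, keepdims=True)

# Manual Frobenius norm over the last two axes, for comparison.
manual = jnp.sqrt(jnp.sum(x ** 2, axis=(-2, -1)))

print(fro.shape, kept.shape)
```

For the first matrix the Frobenius norm is sqrt(1 + 4 + 9 + 16) = sqrt(30); for the symmetric permutation matrix both singular values equal 1, so its nuclear norm is 2.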