jvp

paddle.incubate.autograd. jvp ( func, xs, v=None ) [源代码]

计算函数 funcxs 处的雅可比矩阵与向量 v 的乘积。

警告

该API目前为Beta版本,函数签名在未来版本可能发生变化。

参数

  • func (Callable) - Python函数,输入参数为 xs , 输出为Tensor或Tensor序列。

  • xs (Tensor|Sequence[Tensor]) - 函数 func 的输入参数,数据类型为Tensor或 Tensor序列。

  • v (Tensor|Sequence[Tensor]|None, 可选) - 用于计算 jvp 的输入向量,形状要求 与 xs 一致。默认值为 None , 即相当于形状与 xs 一致,值全为1的Tensor或 Tensor序列。

返回

  • func_out (Tensor|tuple[Tensor]) - 函数 func(xs) 的输出。

  • jvp (Tensor|tuple[Tensor]) - jvp 计算结果。

代码示例

import paddle


def func(x):
    return paddle.matmul(x, x)


x = paddle.ones(shape=[2, 2], dtype='float32')
_, jvp_result = paddle.incubate.autograd.jvp(func, x)
print(jvp_result)
# Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
#        [[4., 4.],
#         [4., 4.]])
v = paddle.to_tensor([[1.0, 0.0], [0.0, 0.0]])
_, jvp_result = paddle.incubate.autograd.jvp(func, x, v)
print(jvp_result)
# Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=False,
#        [[2., 1.],
#         [1., 0.]])