Assign

class paddle.nn.initializer.Assign(value, name=None)

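This interface is a parameter initializer that initializes a parameter from a NumPy array, a Python list, or a Tensor. The values of the value argument are copied into the parameter, so the shape of value should match the shape of the parameter being initialized.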

Parameters

  • value (Tensor|numpy.ndarray|list) - The NumPy array, Python list, or Tensor used to initialize the parameter.

  • name (str, optional) - For details, see Name. Usually there is no need to set it. Default: None.

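Returns

An Assign initializer. When it is applied to a parameter (for example, through the initializer argument of ParamAttr), it fills the parameter with the values from the NumPy array, Python list, or Tensor.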

Code Examples

>>> import paddle
>>> import numpy as np

>>> # NumPy array
>>> data_1 = paddle.ones(shape=[1, 2], dtype='float32')
>>> # The weight of Linear(2, 1) has shape [in_features, out_features] = [2, 1],
>>> # and its bias has shape [out_features] = [1].
>>> weight_attr_1 = paddle.framework.ParamAttr(
...     name="linear_weight_1",
...     initializer=paddle.nn.initializer.Assign(np.array([[2], [2]])))
>>> bias_attr_1 = paddle.framework.ParamAttr(
...     name="linear_bias_1",
...     initializer=paddle.nn.initializer.Assign(np.array([2])))
>>> linear_1 = paddle.nn.Linear(2, 1, weight_attr=weight_attr_1, bias_attr=bias_attr_1)
>>> print(linear_1.weight.numpy())
[[2.]
 [2.]]
>>> print(linear_1.bias.numpy())
[2.]

>>> # 1 * 2 + 1 * 2 + 2 = 6
>>> res_1 = linear_1(data_1)
>>> print(res_1.numpy())
[[6.]]

>>> # Python list
>>> data_2 = paddle.ones(shape=[1, 2], dtype='float32')
>>> weight_attr_2 = paddle.framework.ParamAttr(
...     name="linear_weight_2",
...     initializer=paddle.nn.initializer.Assign([[2], [2]]))
>>> bias_attr_2 = paddle.framework.ParamAttr(
...     name="linear_bias_2",
...     initializer=paddle.nn.initializer.Assign([2]))
>>> linear_2 = paddle.nn.Linear(2, 1, weight_attr=weight_attr_2, bias_attr=bias_attr_2)
>>> print(linear_2.weight.numpy())
[[2.]
 [2.]]
>>> print(linear_2.bias.numpy())
[2.]

>>> res_2 = linear_2(data_2)
>>> print(res_2.numpy())
[[6.]]

>>> # Tensor
>>> data_3 = paddle.ones(shape=[1, 2], dtype='float32')
>>> weight_attr_3 = paddle.framework.ParamAttr(
...     name="linear_weight_3",
...     initializer=paddle.nn.initializer.Assign(paddle.full([2, 1], 2.0)))
>>> bias_attr_3 = paddle.framework.ParamAttr(
...     name="linear_bias_3",
...     initializer=paddle.nn.initializer.Assign(paddle.full([1], 2.0)))
>>> linear_3 = paddle.nn.Linear(2, 1, weight_attr=weight_attr_3, bias_attr=bias_attr_3)
>>> print(linear_3.weight.numpy())
[[2.]
 [2.]]
>>> print(linear_3.bias.numpy())
[2.]

>>> res_3 = linear_3(data_3)
>>> print(res_3.numpy())
[[6.]]
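The result of 6 printed by the examples can be checked by hand. paddle.nn.Linear computes out = x @ W + b, storing W with shape [in_features, out_features]. The sketch below reproduces that forward pass with plain NumPy, assuming a [1, 2] input of ones, a [2, 1] weight whose entries are both 2, and a bias of 2; it does not call Paddle and only illustrates the arithmetic.

```python
import numpy as np

# Input from the examples: one sample with two features, all ones.
x = np.ones((1, 2), dtype=np.float32)

# Weight laid out as [in_features, out_features] = [2, 1] (assumed to match
# the layout paddle.nn.Linear uses), every entry set to 2.
w = np.array([[2], [2]], dtype=np.float32)

# Bias of shape [out_features] = [1], set to 2.
b = np.array([2], dtype=np.float32)

# Linear forward pass: out = x @ W + b  ->  1*2 + 1*2 + 2
out = x @ w + b
print(out)  # [[6.]]
```

Because every example assigns the same values, whether they come from a NumPy array, a list, or a Tensor, all three layers produce this same output.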