Prune

Motivation

We want to support running inference, training, and checkpointing in one ProgramDesc. To do so, we implement a void Prune(const ProgramDesc* input, ProgramDesc* output) function, which takes a ProgramDesc and generates a pruned ProgramDesc.

Challenge

Pruning needs to support both variables and operators as evaluation targets. Consider the following situations.

# Case 1: run the forward pass.
cost_np = session.run(target=cost)
# Case 2: run the backward pass.
opts_np, _ = session.run(target=[cost, opt])
# Case 3: run checkpointing
_ = session.run(target=checkpoint)

Solution

To support evaluation of operators, we add an is_target field to OpDesc.

message OpDesc {
  required string type = 3;
  repeated Var inputs = 1;
  repeated Var outputs = 2;
  repeated Attr attrs = 4;
  optional bool is_target = 5 [ default = false ];
};

To support evaluation of variables, we add fetch_op. For each variable in the target, we insert a fetch_op into the ProgramDesc with that variable as the fetch_op's input. We then mark the fetch_op as a target.
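The fetch_op insertion described above can be sketched as follows. This is a minimal illustration, not the actual implementation: the Var, OpDesc, and ProgramDesc structs are simplified stand-ins for the protobuf-generated classes, and the AppendFetchOps name, the "fetch" op type string, and the "X" input slot name are assumptions made for the example.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-ins for the protobuf messages (assumption: the real
// classes are generated from the .proto definitions and use accessors).
struct Var {
  std::string parameter;               // slot name, e.g. "X"
  std::vector<std::string> arguments;  // variable names bound to the slot
};

struct OpDesc {
  std::string type;
  std::vector<Var> inputs;
  std::vector<Var> outputs;
  bool is_target = false;
};

struct ProgramDesc {
  std::vector<OpDesc> ops;
};

// Hypothetical helper: append one fetch op per target variable and mark
// each appended op as a target, so variable targets reduce to op targets.
void AppendFetchOps(const std::vector<std::string>& target_vars,
                    ProgramDesc* program) {
  assert(program != nullptr);
  for (const std::string& name : target_vars) {
    OpDesc fetch;
    fetch.type = "fetch";                      // assumed op type name
    fetch.inputs.push_back(Var{"X", {name}});  // assumed slot name "X"
    fetch.is_target = true;
    program->ops.push_back(fetch);
  }
}
```

With this reduction, the pruning pass only has to reason about target operators; a variable target is just the input of a target fetch_op.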

Algorithm

If an operator needs to be run, it must fall into one of the following cases:

  1. It is a target.
  2. Some other op that must run depends on it, meaning one of its outputs is that op's input.

The first case can be checked by op_desc.is_target(). The second case can be implemented as

bool HasDependentVar(const OpDesc& op_desc, const std::set<std::string>& dependent_vars) {
  for (auto& var : op_desc.outputs()) {
    for (auto& argu : var.arguments()) {
      if (dependent_vars.count(argu) != 0) {
        return true;
      }
    }
  }
  return false;
}
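To try the dependency check in isolation, here is a self-contained sketch that uses plain structs in place of the protobuf-generated classes (an assumption, so fields are accessed directly rather than through accessors). The AddInputsToDependentVars helper and the DemoDependencyCheck function are hypothetical names added for illustration; they show one plausible way the dependent_vars set gets populated once an op is kept.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Simplified stand-ins for the protobuf messages (assumption).
struct Var {
  std::string parameter;
  std::vector<std::string> arguments;
};

struct OpDesc {
  std::string type;
  std::vector<Var> inputs;
  std::vector<Var> outputs;
};

// True if any output of the op is needed by an op that is already kept.
bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
  for (const Var& var : op_desc.outputs) {
    for (const std::string& argu : var.arguments) {
      if (dependent_vars.count(argu) != 0) {
        return true;
      }
    }
  }
  return false;
}

// Hypothetical helper: once an op is kept, its inputs become
// dependencies that earlier ops may have to produce.
void AddInputsToDependentVars(const OpDesc& op_desc,
                              std::set<std::string>* dependent_vars) {
  assert(dependent_vars != nullptr);
  for (const Var& var : op_desc.inputs) {
    dependent_vars->insert(var.arguments.begin(), var.arguments.end());
  }
}

// Small usage example: "mul" produces "hidden", which a kept op consumes;
// "accuracy" produces "acc", which nothing needs.
bool DemoDependencyCheck() {
  OpDesc mul;
  mul.type = "mul";
  mul.outputs.push_back(Var{"Out", {"hidden"}});

  OpDesc accuracy;
  accuracy.type = "accuracy";
  accuracy.outputs.push_back(Var{"Out", {"acc"}});

  OpDesc kept;  // an op already known to be needed
  kept.type = "softmax";
  kept.inputs.push_back(Var{"X", {"hidden"}});

  std::set<std::string> dependent_vars;
  AddInputsToDependentVars(kept, &dependent_vars);
  return HasDependentVar(mul, dependent_vars) &&
         !HasDependentVar(accuracy, dependent_vars);
}
```

In this example only mul passes the check, because its output hidden is consumed by an op that is already kept.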

Then the whole algorithm can be implemented as the following code.