What is the PyTorch .detach() method?
PyTorch's detach() is a method of the torch.Tensor class.
tensor.detach() returns a new tensor that shares storage with tensor but does not require gradients. tensor.clone() returns a copy of tensor, with its own storage, that keeps the original tensor's requires_grad field.
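The short sketch below checks those properties directly. It assumes a recent PyTorch release, and the tensor values and names (a, d, c) are illustrative only. data_ptr() is used to compare the underlying memory, and the in-place write shows what "shares storage" means in practice.

```python
import torch

# A leaf tensor that tracks gradients (values chosen for illustration)
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

d = a.detach()  # shares a's storage, does not require gradients
c = a.clone()   # copies the data into new storage, stays in a's graph

print(d.requires_grad)               # False
print(d.data_ptr() == a.data_ptr())  # True: same memory
print(c.requires_grad)               # True: follows a.requires_grad
print(c.data_ptr() == a.data_ptr())  # False: clone owns its own memory
print(c.grad_fn is not None)         # True: the clone is recorded in the graph

# Shared storage means an in-place write through d is also visible in a
d[0] = 100.0
print(a[0].item())  # 100.0
print(c[0].item())  # 1.0: c kept the original values
```

Because of that shared memory, in-place edits on a detached tensor can silently change the original. Use clone() when you need an independent copy.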
Use detach() when you want to remove a tensor from a computation graph, and clone() when you want to copy a tensor while keeping the copy part of the computation graph it came from.
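The sketch below contrasts the two methods in terms of gradient flow. It assumes a recent PyTorch release, and the tensor x and the loss expressions are made up for illustration. Gradients pass back through clone(), while a detached tensor acts like a constant.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# clone(): the copy is still in the graph, so gradients flow back to x
loss_clone = (x.clone() * 3).sum()
loss_clone.backward()
grad_via_clone = x.grad.clone()
print(grad_via_clone.tolist())  # [3.0, 3.0, 3.0]

x.grad = None  # reset the accumulated gradient before the next backward pass

# detach(): the copy is cut out of the graph and treated as a constant,
# so d/dx of x * const is just const (the values of x)
loss_detach = (x * x.detach()).sum()
loss_detach.backward()
print(x.grad.tolist())  # [1.0, 2.0, 3.0]
```

Without the detach() call, the second loss would be (x * x).sum(), and its gradient would be 2 * x, that is [2.0, 4.0, 6.0].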
Let's see this in an example, first without detach():
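import torch
import torchviz
# Both branches, y and z, are computed from X and tracked by autograd
X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
y = X**2
z = X**2
result = (y+z).sum()
# Draw the autograd graph and save it as Attached.png
torchviz.make_dot(result).render('Attached', format='png')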
And now the same computation with detach():
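# Same computation, but z is built from a detached copy of X
X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
y = X**2
z = X.detach()**2
result = (y+z).sum()
# Use a different file name so this graph does not overwrite Attached.png
torchviz.make_dot(result).render('Detached', format='png')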
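As you can see, the branch that computes X.detach()**2 (that is, z) is no longer tracked in the graph. This is reflected in the gradient of result: after result.backward(), X.grad only receives the contribution of y, so it equals 2*X (all 2s here) instead of 4*X (all 4s) as in the attached version.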
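To confirm the gradient claim numerically, the sketch below reruns both versions of the example and calls backward(). It assumes a recent PyTorch release. The helper grad_of_sum and its flag detach_second_branch are hypothetical names introduced only for this comparison.

```python
import torch

def grad_of_sum(detach_second_branch: bool) -> torch.Tensor:
    """Hypothetical helper: rebuild the example and return X.grad."""
    X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
    y = X**2
    # Optionally cut the second branch out of the graph
    z = (X.detach() if detach_second_branch else X)**2
    result = (y + z).sum()
    result.backward()
    return X.grad

attached = grad_of_sum(False)
detached = grad_of_sum(True)
print(attached[0, 0].item())  # 4.0: both branches contribute 2*X
print(detached[0, 0].item())  # 2.0: only y contributes
```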