PyTorch Neural Network Interface (leaptorch)

class leaptorch.BackProjectorFunctionCPU(*args, **kwargs)
static backward(ctx, grad_output)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None is passed in for non-tensor outputs of forward()), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor, or is a Tensor that does not require grad, you can return None as the gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
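To make this contract concrete, here is a minimal sketch of a custom torch.autograd.Function in the style of the projector functions in this module. It assumes the projector can be represented by a small dense matrix A, which the real LEAP projectors are not; the class name MatrixProjectorFunction is hypothetical. forward() computes A x, and backward() returns the adjoint A^T applied to the output gradient for x and None for the constant matrix argument.

```python
import torch


class MatrixProjectorFunction(torch.autograd.Function):
    """Hypothetical stand-in: forward projection y = A x with a dense matrix A."""

    @staticmethod
    def forward(ctx, x, A):
        # Save A so backward can apply its transpose (the adjoint operator).
        ctx.save_for_backward(A)
        return A @ x

    @staticmethod
    def backward(ctx, grad_output):
        (A,) = ctx.saved_tensors
        grad_x = None
        if ctx.needs_input_grad[0]:
            # Vector-Jacobian product: the gradient w.r.t. x is A^T times the output gradient.
            grad_x = A.t() @ grad_output
        # One return value per forward() input; A is treated as a constant, so return None.
        return grad_x, None


torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)
y = MatrixProjectorFunction.apply(x, A)
y.sum().backward()
print(torch.allclose(x.grad, A.t() @ torch.ones(5, dtype=torch.float64)))
```

The same pattern (forward operator in forward(), its adjoint in backward()) is why a projector and a backprojector are natural gradient partners of each other.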

static forward(ctx, input, proj, vol, param_id)

Define the forward of the custom autograd Function.

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

  • See the PyTorch "Extending torch.autograd" notes (combined or separate forward() and setup_context()) for more details

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to set up the ctx object. output is the output of forward(), and inputs is a tuple of the inputs to forward().

  • See the PyTorch "Extending torch.autograd" notes for more details

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced, for backward compatibility). Instead, tensors should be saved with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp), or with ctx.save_for_forward() if they are intended to be used in jvp.
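A short sketch of Usage 2, assuming PyTorch 2.0 or later (where setup_context() is available). The class name ScaledMatrixFunction is hypothetical: forward() takes no ctx, setup_context() saves the tensor input with ctx.save_for_backward() and stores a plain Python scalar directly on ctx, and backward() returns one gradient per forward() input.

```python
import torch


class ScaledMatrixFunction(torch.autograd.Function):
    """Hypothetical example of Usage 2: forward() takes no ctx; setup_context() fills it."""

    @staticmethod
    def forward(x, A, scale):
        return scale * (A @ x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, A, scale = inputs
        # Tensors go through save_for_backward; plain Python values can live on ctx.
        ctx.save_for_backward(A)
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_output):
        (A,) = ctx.saved_tensors
        # One gradient per forward() input: x, A (treated as constant), scale (not a tensor).
        return ctx.scale * (A.t() @ grad_output), None, None


A = torch.eye(3)
x = torch.ones(3, requires_grad=True)
y = ScaledMatrixFunction.apply(x, A, 2.0)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```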

class leaptorch.BackProjectorFunctionGPU(*args, **kwargs)
static backward(ctx, grad_output)

Inherited torch.autograd.Function.backward docstring; identical to BackProjectorFunctionCPU.backward.

static forward(ctx, input, proj, vol, param_id)

Inherited torch.autograd.Function.forward docstring; identical to BackProjectorFunctionCPU.forward.

class leaptorch.BaseProjector(use_static=False, use_gpu=False, gpu_device=None, batch_size=1)
allocate_batch_data()

Allocates the projection and volume batch data. This data is used internally by the class and should be treated as a private member variable.

forward(input)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
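The note applies to any torch.nn.Module, including the classes in this module. This small sketch uses torch.nn.Identity as a stand-in (so it runs without LEAP installed) to show that a forward hook fires when the module is called as module(x) but not when module.forward(x) is called directly.

```python
import torch
from torch import nn

calls = []
layer = nn.Identity()
# A forward hook runs only when the module itself is called, not via .forward().
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.zeros(2)
layer(x)          # runs forward() and the registered hook
layer.forward(x)  # runs forward() only; the hook is silently skipped
print(calls)      # ['hook']
```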

get_projection_dim()

Returns the shape of the CT projection dimensions

get_volume_dim()

Returns the shape of the CT volume dimensions

load_param(param_fn, param_type=0)

Loads the LEAP parameters from file; same as leapctype.tomographicModels.load_param

print_parameters()

Prints the CT geometry and CT volume parameters to the screen

save_param(param_fn)

Saves the LEAP parameters to file; same as leapctype.tomographicModels.save_param

set_conebeam(numAngles, numRows, numCols, pixelHeight, pixelWidth, centerRow, centerCol, phis, sod, sdd, tau=0.0, helicalPitch=0.0)

Sets the parameters for a cone-beam CT geometry

The origin of the coordinate system is always at the center of rotation. This function is the same as leapctype.tomographicModels.set_conebeam, except that it also allocates the batch data for the projections (see also allocate_batch_data)

Parameters:
  • numAngles (int) – number of projection angles

  • numRows (int) – number of rows in the x-ray detector

  • numCols (int) – number of columns in the x-ray detector

  • pixelHeight (float) – the detector pixel pitch (i.e., pixel size) between detector rows, measured in mm

  • pixelWidth (float) – the detector pixel pitch (i.e., pixel size) between detector columns, measured in mm

  • centerRow (float) – the detector pixel row index for the ray that passes from the source, through the origin, and hits the detector

  • centerCol (float) – the detector pixel column index for the ray that passes from the source, through the origin, and hits the detector

  • phis (float32 numpy array) – a numpy array for specifying the angles of each projection, measured in degrees

  • sod (float) – source to object distance, measured in mm; this can also be viewed as the source to center of rotation distance

  • sdd (float) – source to detector distance, measured in mm

  • tau (float) – center of rotation offset

  • helicalPitch (float) – the helical pitch (mm/radians)

Returns:

True if the parameters were valid, False otherwise
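As an illustration of the inputs set_conebeam expects, the sketch below builds a float32 phis array of evenly spaced view angles in degrees and a detector center for a centered detector. All numeric values are hypothetical, and the centering formula assumes zero-based pixel indices; the actual call is left as a comment because it needs leaptorch and a Projector object named proj.

```python
import numpy as np

# Hypothetical scan values chosen for illustration only.
numAngles = 720
numRows, numCols = 256, 512
pixelHeight = pixelWidth = 0.4      # mm
sod, sdd = 1100.0, 1400.0           # mm

# Evenly spaced view angles over a full rotation, in degrees, as float32.
phis = np.linspace(0.0, 360.0, numAngles, endpoint=False).astype(np.float32)

# Assumed convention: zero-based pixel indices, so a detector centered on the
# central ray has its center halfway between the first and last index.
centerRow = 0.5 * (numRows - 1)
centerCol = 0.5 * (numCols - 1)

# With leaptorch installed, a Projector would then be configured with:
# proj.set_conebeam(numAngles, numRows, numCols, pixelHeight, pixelWidth,
#                   centerRow, centerCol, phis, sod, sdd)
```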

set_default_volume(scale=1.0)

Sets the default volume parameters

The default volume parameters are those that fill the field of view of the CT system and use the native voxel sizes. This function is the same as leapctype.tomographicModels.set_default_volume, except that it also allocates the batch data for the volume (see also allocate_batch_data)

Parameters:

scale (float) – factor by which the native voxel size is multiplied, creating denser (scale < 1) or sparser (scale > 1) voxel representations (not recommended for fast reconstruction)

Returns:

True if the operation was successful, False otherwise (failure usually means the CT geometry has not been set yet)
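To show what "native voxel size" and scale plausibly mean, here is a sketch under the assumption that the native voxel pitch is the detector pixel pitch demagnified to the center of rotation (pixelWidth * sod / sdd) for fan- and cone-beam, and the pixel pitch itself for parallel-beam. That is the usual convention but is not confirmed by this page; the helper name native_voxel_size is hypothetical.

```python
def native_voxel_size(pixelWidth, sod=None, sdd=None, scale=1.0):
    """Hypothetical helper: voxel pitch (mm) matched to the detector sampling.

    For divergent-beam (fan/cone) geometries the detector pitch is demagnified
    to the center of rotation by sod/sdd; for parallel-beam (sod/sdd omitted)
    no demagnification applies. scale then multiplies the result.
    """
    if sod is None or sdd is None:
        return scale * pixelWidth
    return scale * pixelWidth * sod / sdd


print(native_voxel_size(0.4, sod=1000.0, sdd=1250.0))
print(native_voxel_size(0.4, sod=1000.0, sdd=1250.0, scale=2.0))
```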

set_fanbeam(numAngles, numRows, numCols, pixelHeight, pixelWidth, centerRow, centerCol, phis, sod, sdd, tau=0.0)

Sets the parameters for a fan-beam CT geometry

The origin of the coordinate system is always at the center of rotation. This function is the same as leapctype.tomographicModels.set_fanbeam, except that it also allocates the batch data for the projections (see also allocate_batch_data)

Parameters:
  • numAngles (int) – number of projection angles

  • numRows (int) – number of rows in the x-ray detector

  • numCols (int) – number of columns in the x-ray detector

  • pixelHeight (float) – the detector pixel pitch (i.e., pixel size) between detector rows, measured in mm

  • pixelWidth (float) – the detector pixel pitch (i.e., pixel size) between detector columns, measured in mm

  • centerRow (float) – the detector pixel row index for the ray that passes from the source, through the origin, and hits the detector

  • centerCol (float) – the detector pixel column index for the ray that passes from the source, through the origin, and hits the detector

  • phis (float32 numpy array) – a numpy array for specifying the angles of each projection, measured in degrees

  • sod (float) – source to object distance, measured in mm; this can also be viewed as the source to center of rotation distance

  • sdd (float) – source to detector distance, measured in mm

  • tau (float) – center of rotation offset

Returns:

True if the parameters were valid, False otherwise

set_gpu(which)

Sets the primary GPU number to be used by LEAP

set_gpus(listofgpus)

Sets the list of GPUs (by number) to be used by LEAP

set_modularbeam(numAngles, numRows, numCols, pixelHeight, pixelWidth, sourcePositions, detectorCenters, rowVec, colVec)

Sets the parameters for a modular-beam CT geometry

The origin of the coordinate system is always at the center of rotation. This function is the same as leapctype.tomographicModels.set_modularbeam, except that it also allocates the batch data for the projections (see also allocate_batch_data)

Parameters:
  • numAngles (int) – number of projection angles

  • numRows (int) – number of rows in the x-ray detector

  • numCols (int) – number of columns in the x-ray detector

  • pixelHeight (float) – the detector pixel pitch (i.e., pixel size) between detector rows, measured in mm

  • pixelWidth (float) – the detector pixel pitch (i.e., pixel size) between detector columns, measured in mm

  • sourcePositions ((numAngles X 3) numpy array) – the (x,y,z) position of each x-ray source

  • detectorCenters ((numAngles X 3) numpy array) – the (x,y,z) position of the center of the front face of each detector

  • rowVec ((numAngles X 3) numpy array) – the (x,y,z) unit vector pointing along the positive detector row direction

  • colVec ((numAngles X 3) numpy array) – the (x,y,z) unit vector pointing along the positive detector column direction

Returns:

True if the parameters were valid, False otherwise
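A sketch of how the four (numAngles X 3) arrays might be built for a circular cone-beam orbit expressed as a modular-beam geometry. The helper circular_to_modular is hypothetical, and it assumes the source orbits in the xy-plane with phi measured from the +x axis, the detector faces the source at distance sdd, rowVec points along increasing row index (+z), and colVec is tangent to the orbit. LEAP's own angle and orientation conventions may differ, so compare against set_conebeam before relying on it.

```python
import numpy as np


def circular_to_modular(phis_deg, sod, sdd):
    """Hypothetical helper: modular-beam arrays for a circular cone-beam orbit.

    Assumed conventions (may not match LEAP): source in the xy-plane at radius sod,
    phi measured from +x, rows advancing along +z, columns tangent to the orbit.
    """
    phi = np.deg2rad(np.asarray(phis_deg, dtype=np.float64))
    radial = np.stack([np.cos(phi), np.sin(phi), np.zeros_like(phi)], axis=1)
    sourcePositions = sod * radial
    # The detector center lies on the far side of the origin, sdd away from the source.
    detectorCenters = -(sdd - sod) * radial
    rowVec = np.tile([0.0, 0.0, 1.0], (phi.size, 1))
    colVec = np.stack([-np.sin(phi), np.cos(phi), np.zeros_like(phi)], axis=1)
    return (sourcePositions.astype(np.float32), detectorCenters.astype(np.float32),
            rowVec.astype(np.float32), colVec.astype(np.float32))


src, det, rv, cv = circular_to_modular([0.0, 90.0], 1000.0, 1500.0)
print(src.shape, det.shape, rv.shape, cv.shape)
```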

set_parallelbeam(numAngles, numRows, numCols, pixelHeight, pixelWidth, centerRow, centerCol, phis)

Sets the parameters for a parallel-beam CT geometry

The origin of the coordinate system is always at the center of rotation. This function is the same as leapctype.tomographicModels.set_parallelbeam, except that it also allocates the batch data for the projections (see also allocate_batch_data)

Parameters:
  • numAngles (int) – number of projection angles

  • numRows (int) – number of rows in the x-ray detector

  • numCols (int) – number of columns in the x-ray detector

  • pixelHeight (float) – the detector pixel pitch (i.e., pixel size) between detector rows, measured in mm

  • pixelWidth (float) – the detector pixel pitch (i.e., pixel size) between detector columns, measured in mm

  • centerRow (float) – the detector pixel row index for the ray that passes from the source, through the origin, and hits the detector

  • centerCol (float) – the detector pixel column index for the ray that passes from the source, through the origin, and hits the detector

  • phis (float32 numpy array) – a numpy array for specifying the angles of each projection, measured in degrees

Returns:

True if the parameters were valid, False otherwise

set_volume(numX, numY, numZ, voxelWidth, voxelHeight, offsetX=0.0, offsetY=0.0, offsetZ=0.0)

Sets the CT volume parameters

This function is the same as leapctype.tomographicModels.set_volume, except that it also allocates the batch data for the volume (see also allocate_batch_data)

Parameters:
  • numX (int) – number of voxels in the x-dimension

  • numY (int) – number of voxels in the y-dimension

  • numZ (int) – number of voxels in the z-dimension

  • voxelWidth (float) – voxel pitch (size) in the x and y dimensions, measured in mm

  • voxelHeight (float) – voxel pitch (size) in the z dimension, measured in mm

  • offsetX (float) – shift the volume in the x-dimension, measured in mm

  • offsetY (float) – shift the volume in the y-dimension, measured in mm

  • offsetZ (float) – shift the volume in the z-dimension, measured in mm

Returns:

True if the parameters were valid, False otherwise
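To see how the voxel counts, pitches, and offsets combine, the sketch below computes voxel-center coordinates along one axis. It assumes the grid is centered on the origin (the center of rotation) and then shifted by the offset; this is an assumption about LEAP's indexing, and the helper voxel_centers is hypothetical. The same function applies to x and y with voxelWidth and to z with voxelHeight.

```python
import numpy as np


def voxel_centers(num, voxelSize, offset=0.0):
    """Hypothetical helper: voxel-center coordinates (mm) along one axis.

    Assumes the grid is centered on the origin and then shifted by offset,
    so the covered extent is num * voxelSize.
    """
    return offset + voxelSize * (np.arange(num) - 0.5 * (num - 1))


print(voxel_centers(4, 0.5))               # x-axis with numX=4, voxelWidth=0.5
print(voxel_centers(3, 1.0, offset=10.0))  # z-axis shifted by offsetZ=10
```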

class leaptorch.FBP(forward_FBP=True, use_static=False, use_gpu=False, gpu_device=None, batch_size=1)
forward(input)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class leaptorch.FBPFunctionCPU(*args, **kwargs)
static backward(ctx, grad_output)

Inherited torch.autograd.Function.backward docstring; identical to BackProjectorFunctionCPU.backward.

static forward(ctx, input, proj, vol, param_id)

Inherited torch.autograd.Function.forward docstring; identical to BackProjectorFunctionCPU.forward.

class leaptorch.FBPFunctionGPU(*args, **kwargs)
static backward(ctx, grad_output)

Inherited torch.autograd.Function.backward docstring; identical to BackProjectorFunctionCPU.backward.

static forward(ctx, input, proj, vol, param_id)

Inherited torch.autograd.Function.forward docstring; identical to BackProjectorFunctionCPU.forward.

class leaptorch.FBPReverseFunctionCPU(*args, **kwargs)
static backward(ctx, grad_output)

Inherited torch.autograd.Function.backward docstring; identical to BackProjectorFunctionCPU.backward.

static forward(ctx, input, proj, vol, param_id)

Inherited torch.autograd.Function.forward docstring; identical to BackProjectorFunctionCPU.forward.

class leaptorch.FBPReverseFunctionGPU(*args, **kwargs)
static backward(ctx, grad_output)

Inherited torch.autograd.Function.backward docstring; identical to BackProjectorFunctionCPU.backward.

static forward(ctx, input, proj, vol, param_id)

Inherited torch.autograd.Function.forward docstring; identical to BackProjectorFunctionCPU.forward.

class leaptorch.Projector(forward_project=True, use_static=False, use_gpu=False, gpu_device=None, batch_size=1)

Python class for PyTorch binding of LEAP

Note that leapct is a member variable of this class; it is an instance of the leapctype.tomographicModels class.

Thus all tomography functions can be accessed through (object of this class).leapct.XXX

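Usage Example:

import torch
from leaptorch import Projector

# Use the first GPU when one is available
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

proj = Projector(forward_project=True, use_static=True, use_gpu=use_cuda, gpu_device=device)
proj.set_conebeam(…)
proj.set_default_volume(…)
…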

fbp(input)

Performs Filtered Backprojection (FBP) reconstruction of any CT geometry on the batch data

forward(input)

Performs the forward model on the batch data (forward projection if forward_project=True, backprojection otherwise)
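Because forward(input) is differentiable through the autograd Functions in this module, a Projector can be placed inside an optimization loop. The sketch below shows iterative reconstruction by gradient descent on a least-squares data mismatch. To stay runnable without LEAP it uses a hypothetical MatrixProjector module (a dense matrix) as a stand-in; with LEAP installed, a configured leaptorch.Projector would play that role, although its expected batch tensor shape is not specified on this page.

```python
import torch
from torch import nn


class MatrixProjector(nn.Module):
    """Hypothetical stand-in for a configured leaptorch.Projector (a dense matrix)."""

    def __init__(self, A):
        super().__init__()
        self.register_buffer("A", A)

    def forward(self, input):
        # (batch, numVoxels) -> (batch, numMeasurements)
        return input @ self.A.t()


torch.manual_seed(0)
projector = MatrixProjector(torch.randn(40, 20))
x_true = torch.rand(1, 20)
measured = projector(x_true)  # simulated projection data

# Reconstruct by gradient descent on the mean squared data mismatch.
x = torch.zeros(1, 20, requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.1)
initial_loss = torch.mean((projector(x) - measured) ** 2).item()
for _ in range(500):
    optimizer.zero_grad()
    loss = torch.mean((projector(x) - measured) ** 2)
    loss.backward()
    optimizer.step()
print(initial_loss, loss.item())
```

Plain SGD on a least-squares loss is used here only for clarity; any torch.optim optimizer or regularized loss could be substituted.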

class leaptorch.ProjectorFunctionCPU(*args, **kwargs)
static backward(ctx, grad_output)

Inherited torch.autograd.Function.backward docstring; identical to BackProjectorFunctionCPU.backward.

static forward(ctx, input, proj, vol, param_id)

Inherited torch.autograd.Function.forward docstring; identical to BackProjectorFunctionCPU.forward.

class leaptorch.ProjectorFunctionGPU(*args, **kwargs)
static backward(ctx, grad_output)

Inherited torch.autograd.Function.backward docstring; identical to BackProjectorFunctionCPU.backward.

static forward(ctx, input, proj, vol, param_id)

Inherited torch.autograd.Function.forward docstring; identical to BackProjectorFunctionCPU.forward.