tumourkit.classification.infer.run_inference

tumourkit.classification.infer.run_inference(model: Module, loader: GraphDataLoader, device: str, num_classes: int, enable_background: bool) → Dict[str, ndarray]

Runs inference using the specified model on the provided data loader.

Parameters:
  • model (nn.Module) – The model used for inference.

  • loader (GraphDataLoader) – The graph data loader.

  • device (str) – The device used for inference (e.g., ‘cpu’ or ‘cuda’).

  • num_classes (int) – The number of classes.

  • enable_background (bool) – Set to True when the model has an extra head to correct extra cells.

Returns:

A dictionary of NumPy arrays holding the predicted probabilities for all the nodes.

Return type:

Dict[str, np.ndarray]
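The sketch below shows, under stated assumptions, how per-node probabilities could be gathered into a Dict[str, np.ndarray] like the one returned here. It assumes the model emits one row of raw logits per node for each graph in the loader and that probabilities come from a row-wise softmax. It is not tumourkit's implementation: collect_node_probabilities, the "probs" key, and the softmax step are hypothetical, and the enable_background extra head is not modelled.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax: turns each node's logits into a probability vector."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def collect_node_probabilities(batches, num_classes: int) -> dict:
    """Hypothetical helper: concatenate node probabilities across graphs.

    `batches` stands in for the model outputs on each graph from the loader:
    an iterable of (num_nodes, num_classes) logit arrays.
    The "probs" key is an assumption, not tumourkit's actual key.
    """
    probs = [softmax(np.asarray(b, dtype=float)) for b in batches]
    if not probs:
        return {"probs": np.empty((0, num_classes))}
    out = np.concatenate(probs, axis=0)  # stack nodes from all graphs
    if out.shape[1] != num_classes:
        raise ValueError(f"expected {num_classes} classes, got {out.shape[1]}")
    return {"probs": out}


# Usage: two toy "graphs" with 2 and 1 nodes, 2 classes each.
logits_per_graph = [np.array([[2.0, 0.0], [0.0, 0.0]]), np.array([[0.0, 1.0]])]
result = collect_node_probabilities(logits_per_graph, num_classes=2)
print(result["probs"].shape)  # (3, 2)
```

Each row of the resulting array sums to 1, so every node carries a full distribution over the num_classes classes.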