tumourkit.classification.infer.run_inference
- tumourkit.classification.infer.run_inference(model: Module, loader: GraphDataLoader, device: str, num_classes: int, enable_background: bool) → Dict[str, ndarray]
Runs inference using the specified model on the provided data loader.
- Parameters:
model (nn.Module) – The model used for inference.
loader (GraphDataLoader) – The graph data loader.
device (str) – The device used for inference (e.g., ‘cpu’ or ‘cuda’).
num_classes (int) – The number of classes.
enable_background (bool) – Enable when the model has an extra head to correct extra cells.
- Returns:
The probabilities for all the nodes.
- Return type:
Dict[str, np.ndarray]
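As a rough illustration of how the returned dictionary might be consumed, the sketch below turns one entry of a `Dict[str, np.ndarray]` into per-node class predictions. The key name `"class_probs"`, the helper `predicted_classes`, and the assumed array layouts (one row per node with one column per class, or a single probability per node in the binary case) are hypothetical and not taken from the library; check the actual keys and shapes returned by `run_inference` before relying on this.

```python
import numpy as np


def predicted_classes(outputs, key):
    """Return one predicted class index per node for a single dict entry.

    `outputs` stands in for the Dict[str, np.ndarray] returned by
    run_inference; `key` is whichever entry holds class probabilities
    (the key names here are hypothetical).
    """
    probs = np.asarray(outputs[key])
    if probs.ndim == 1:
        # Assumption: a 1-D array holds the positive-class probability per node.
        return (probs > 0.5).astype(int)
    # Assumption: a 2-D array has shape (num_nodes, num_classes).
    return probs.argmax(axis=1)


# Stand-in data in place of a real run_inference result (two nodes, three classes).
fake_outputs = {"class_probs": np.array([[0.1, 0.7, 0.2],
                                         [0.8, 0.1, 0.1]])}
labels = predicted_classes(fake_outputs, "class_probs")
```

Taking the arg-max is only one possible decision rule; a thresholded or calibrated rule may suit some datasets better.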