Skip to content

nabu.cuda.kernel

[docs] module nabu.cuda.kernel

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
from os import linesep
import numpy as np
from cupy import RawModule
from ..processing.kernel_base import KernelBase
from ..utils import catch_warnings  # TODO use warnings.catch_warnings once python < 3.11 is dropped


class CudaKernel(KernelBase):
    """
    Thin wrapper around one CUDA kernel compiled through cupy.RawModule.

    Parameters
    -----------
    kernel_name: str
        Name of the CUDA kernel to retrieve from the compiled module.
    filename: str, optional
        Path to the file name containing kernels definitions (forwarded to KernelBase).
    src: str, optional
        Source code of kernels definitions (forwarded to KernelBase).
    automation_params: dict, optional
        Automation parameters, forwarded unchanged to KernelBase.
    silent_compilation_warnings: bool, optional
        When true, compilation runs inside catch_warnings(action="ignore").
    extern_c: bool, optional
        When true (default), the source code goes through
        patch_sourcecode_add_externC() before being compiled.
    sourcemodule_kwargs: optional
        Extra arguments to provide to cupy.RawModule().
        A missing (or None) "backend" entry is replaced by "nvcc".
    """

    def __init__(
        self,
        kernel_name,
        filename=None,
        src=None,
        automation_params=None,
        silent_compilation_warnings=False,
        extern_c=True,
        **sourcemodule_kwargs,
    ):
        base_options = {
            "filename": filename,
            "src": src,
            "automation_params": automation_params,
            "silent_compilation_warnings": silent_compilation_warnings,
        }
        super().__init__(kernel_name, **base_options)

        # pycuda/pyopencl add the extern "C" wrapping automatically, cupy does not
        if extern_c:
            self.src = patch_sourcecode_add_externC(self.src, filename=filename)
        self.compile_kernel_source(kernel_name, sourcemodule_kwargs)

    def compile_kernel_source(self, kernel_name, sourcemodule_kwargs):
        """
        Compile self.src with cupy.RawModule and keep a handle on the kernel 'kernel_name'.

        Sets the attributes sourcemodule_kwargs, kernel_name, module and func.
        Note that the 'sourcemodule_kwargs' dict is updated in place with the default backend.
        """
        self.sourcemodule_kwargs = sourcemodule_kwargs
        module_options = self.sourcemodule_kwargs
        # NVCC is the default backend when the caller did not pick one
        if module_options.get("backend", None) is None:
            module_options["backend"] = "nvcc"

        self.kernel_name = kernel_name
        warnings_action = "ignore" if self.silent_compilation_warnings else None
        with catch_warnings(action=warnings_action):  # pylint: disable=E1123
            self.module = RawModule(code=self.src, **module_options)
            self.module.compile()
        self.func = self.module.get_function(kernel_name)

    def follow_device_arr(self, args):
        """Return 'args' untouched."""
        return args

    def call(self, *args, **kwargs):
        """
        Launch the kernel.

        Arguments are first processed by _prepare_call(); the kernel is then invoked
        with the resulting grid, block and arguments. The keyword arguments returned
        by _prepare_call() are not passed to the kernel.
        """
        grid, block, kernel_args, _ = self._prepare_call(*args, **kwargs)
        self.func(grid, block, kernel_args)

    __call__ = call


def patch_sourcecode_add_externC(src_code, filename=None):
    """
    Patch a source code to surround the relevant parts with 'extern C {}' directive
    The NVCC compiler needs this to avoid name mangling.

    The opening 'extern "C" {' line is inserted right after the last '#include' line
    (or at the very top when the source has no '#include'), and a closing '}' line
    is appended at the end of the source.

    Parameters
    ----------
    src_code: str
        Source code of the kernels definitions.
    filename: str, optional
        Name of the file the source code comes from. Only used in the error message.

    Returns
    -------
    modified_src: str
        The patched source code, with lines joined by a line feed.

    Raises
    ------
    ValueError
        If two '#include' lines are separated by anything else than blank lines
        or comments.
    """
    # Split on "\n" and not on os.linesep: text handled by Python uses "\n" on every
    # platform, so splitting on "\r\n" (Windows) would leave the source in one piece.
    lines = src_code.split("\n")
    incl_idx = []  # position of each '#include' in 'lines'
    incl_rank = []  # rank of each '#include' among the lines that are neither blank nor comment

    n_meaningful = 0
    in_block_comment = False
    for line_no, raw_line in enumerate(lines):
        line = raw_line.strip()
        if in_block_comment:
            # Inside a '/* ... */' spanning several lines: skip until it is closed
            in_block_comment = "*/" not in line
            continue
        if not line or line.startswith("//"):
            continue
        if line.startswith("/*"):
            # Stay in "comment mode" if the comment is not closed on the same line
            in_block_comment = "*/" not in line[2:]
            continue
        if line.startswith("#include"):
            incl_idx.append(line_no)
            incl_rank.append(n_meaningful)
        # Every meaningful line is counted, so that a gap between the ranks of two
        # consecutive '#include' reveals that something else sits between them.
        n_meaningful += 1

    if not incl_idx:
        insertion_idx = 0
    else:
        if any(after - before > 1 for before, after in zip(incl_rank, incl_rank[1:])):
            raise ValueError(
                f"'#include' should be grouped on top of the file, not separated by anything else - found #include at lines {incl_idx} in file {filename}"
            )
        insertion_idx = incl_idx[-1] + 1
    lines.insert(insertion_idx, 'extern "C" {')
    lines.append("}")
    return "\n".join(lines)