Source code for nabu.opencl.kernel

import pyopencl.array as parray
from pyopencl import Program, CommandQueue, kernel_work_group_info
from ..utils import deprecation_warning
from ..processing.kernel_base import KernelBase


class OpenCLKernel(KernelBase):
    """
    Helper class that wraps OpenCL kernel through pyopencl.

    Parameters
    -----------
    kernel_name: str
        Name of the OpenCL kernel.
    ctx: pyopencl.Context
        OpenCL context to use.
    queue: pyopencl.CommandQueue
        OpenCL queue to use. If provided, will use this queue's context instead of 'ctx'
    filename: str, optional
        Path to the file name containing kernels definitions
    src: str, optional
        Source code of kernels definitions
    automation_params: dict, optional
        Automation parameters, forwarded unchanged to the parent class constructor
    build_kwargs: optional
        Extra arguments to provide to pyopencl.Program.build(),
    """

    def __init__(
        self,
        kernel_name,
        ctx,
        queue=None,
        filename=None,
        src=None,
        automation_params=None,
        **build_kwargs,
    ):
        super().__init__(kernel_name, filename=filename, src=src, automation_params=automation_params)
        if queue is not None:
            # An explicit queue takes precedence: its context replaces 'ctx'.
            self.ctx = queue.context
            self.queue = queue
        else:
            self.ctx = ctx
            self.queue = None
        self.compile_kernel_source(kernel_name, build_kwargs)
        self.get_kernel()

    def compile_kernel_source(self, kernel_name, build_kwargs):
        """
        Build the OpenCL program from 'self.src' in the context 'self.ctx'.

        Parameters
        ----------
        kernel_name: str
            Name of the kernel; stored in 'self.kernel_name'.
        build_kwargs: dict
            Keyword arguments passed to pyopencl.Program.build();
            stored in 'self.build_kwargs'.
        """
        # NOTE(review): 'self.src' is not set in this class; presumably the
        # parent constructor provides it - confirm in KernelBase.
        self.build_kwargs = build_kwargs
        self.kernel_name = kernel_name
        self.program = Program(self.ctx, self.src).build(**self.build_kwargs)

    def get_kernel(self):
        """
        Look up 'self.kernel_name' in the built program and store the result
        in 'self.kernel'.

        Raises
        ------
        ValueError
            If the program has no kernel with this function name.
        """
        self.kernel = None
        for kern in self.program.all_kernels():
            if kern.function_name == self.kernel_name:
                self.kernel = kern
        if self.kernel is None:
            raise ValueError(
                "Could not find a kernel with function name '%s'. Available are: %s"
                % (self.kernel_name, self.program.kernel_names)
            )

    # overwrite parent method
    def guess_block_size(self, shape):
        """
        Suggest a 3-element work-group size for data of the given shape.

        The suggestion is derived from the first device of the context: its
        maximum work-group size and the kernel's preferred work-group size
        multiple.

        Parameters
        ----------
        shape: tuple
            Shape of the data to process. Only its length is used.

        Returns
        -------
        wg: tuple
            Work-group size as a tuple of three integers.
        """
        device = self.ctx.devices[0]
        wg_max = device.max_work_group_size
        wg_multiple = self.kernel.get_work_group_info(
            kernel_work_group_info.PREFERRED_WORK_GROUP_SIZE_MULTIPLE, device
        )
        ndim = len(shape)
        # Try to have workgroup relatively well-balanced in all dimensions,
        # with more work items in x > y > z
        if ndim == 1:
            return (wg_max, 1, 1)
        w = (wg_max // wg_multiple, wg_multiple)
        wg = (w if w[0] > w[1] else w[::-1]) + (1,)
        # Move a factor 2 from x and a factor 4 from y into z (depth 8).
        # Skip the split when y is smaller than 4: the floor division would
        # yield a work-group dimension of 0, which is not a valid size.
        if ndim == 3 and wg[1] >= 4:
            wg = (wg[0] // 2, wg[1] // 4, 8)
        return wg

    def get_block_grid(self, *args, **kwargs):
        """
        Extract the local and global work sizes from keyword arguments.

        Both the current names ('local_size', 'global_size') and the
        deprecated ones ('block', 'grid') are accepted. With the deprecated
        names, the global size is computed as grid * block, element-wise.
        An explicit 'global_size' / 'local_size' overrides the value derived
        from 'grid' / 'block'.

        The result is also stored in 'self.last_block_size' and
        'self.last_grid_size'.

        Returns
        -------
        local_size: tuple or None
        global_size: tuple

        Raises
        ------
        ValueError
            If no global size can be determined, or if 'grid' is given
            without 'block'.
        """
        global_size = None
        # COMPAT.
        block = kwargs.pop("block", None)
        if block is not None:
            deprecation_warning("Please use 'local_size' instead of 'block'")
        grid = kwargs.pop("grid", None)
        if grid is not None:
            deprecation_warning("Please use 'global_size' instead of 'grid'")
            if block is None:
                # The global size is grid * block: it cannot be computed
                # from 'grid' alone.
                raise ValueError("Need to define 'block' along with 'grid' for kernel '%s'" % self.kernel_name)
            global_size = tuple(g * b for g, b in zip(grid, block))
        #
        global_size = kwargs.pop("global_size", global_size)
        local_size = kwargs.pop("local_size", block)
        if global_size is None:
            raise ValueError("Need to define global_size for kernel '%s'" % self.kernel_name)
        if len(global_size) == 2 and local_size is not None and len(local_size) == 3:
            # Drop the third component so that both sizes have the same length
            local_size = local_size[:-1]
        # TODO check that last dim is 1
        self.last_block_size = local_size
        self.last_grid_size = global_size
        return local_size, global_size

    def follow_device_arr(self, args):
        """
        Replace each pyopencl.array.Array in 'args' by its 'data' attribute.

        Other arguments are passed through unchanged.

        Returns
        -------
        args: tuple
        """
        return tuple(arg.data if isinstance(arg, parray.Array) else arg for arg in args)

    def call(self, *args, **kwargs):
        """
        Launch the kernel.

        Parameters
        ----------
        args:
            Kernel arguments. The first one may be a pyopencl.CommandQueue,
            in which case it is used for this call instead of 'self.queue'.
        kwargs:
            Must allow the work sizes to be determined ('global_size' and
            optionally 'local_size', or the deprecated 'grid' / 'block').

        Returns
        -------
        The value returned by the pyopencl kernel call.

        Raises
        ------
        ValueError
            If no queue is passed and none was given at construction.
        """
        # Check that 'args' is non-empty before looking at its first element,
        # so that a keyword-only call falls back on self.queue.
        if args and isinstance(args[0], CommandQueue):
            queue = args[0]
            args = args[1:]
        else:
            queue = self.queue
            if queue is None:
                raise ValueError(
                    "First argument must be a pyopencl queue - otherwise provide OpenCLKernel(..., queue=queue)"
                )
        # NOTE(review): _prepare_call() is not defined in this class; it is
        # presumably inherited from KernelBase - confirm its return order.
        global_size, local_size, args, kwargs = self._prepare_call(*args, **kwargs)
        # The sizing keywords must not be forwarded to the kernel itself
        for key in ("global_size", "local_size", "grid", "block"):
            kwargs.pop(key, None)
        return self.kernel(queue, global_size, local_size, *args, **kwargs)

    __call__ = call