import re
from typing import Callable, List

import torch
from torch import Tensor

__all__: List[str] = []
 

class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"

        pattern = (
            optional_ws
            + "template"
            + optional_ws + template_params
            + optional_ws + return_type
            + required_ws + function_name
            + optional_ws + function_params
            + optional_ws + function_body
            + optional_ws
        )

        result = re.match(pattern, code_string, re.DOTALL)  # DOTALL for matching multiline

        if result is None:
            raise Exception(f"Couldn't parse code, please check correctness:\n {code_string}")

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]
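
# A minimal sketch of what _CodeParser extracts from a kernel string (the
# sample kernel below is hypothetical, chosen only to illustrate the named
# regex groups above):
#
#   parsed = _CodeParser("template <typename T> T my_op(T x) { return x + x; }")
#   parsed.template_params   # "<typename T>"
#   parsed.return_type       # "T"
#   parsed.function_name     # "my_op"  (used as the kernel entry point)
#   parsed.function_params   # "(T x)"
#   parsed.function_body     # "{ return x + x; }"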
 

class _JittedFunction:
    def __init__(self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs):
        self.code_string = code_string

        assert return_by_ref or num_outputs == 1, "Return by value only works for a single output."
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follows torch.cuda's lazy initialization behavior:
        # checking CUDA availability is deferred until the function is invoked.
        assert self.is_cuda_available, "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

        # Start from the defaults declared at creation time; call-time kwargs
        # may override them, but may not introduce new names.
        expanded_kwargs = self.kwargs_dict.copy()
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} is not declared in function definition")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs)
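
# A minimal usage sketch of the kwargs override behavior above (the kernel
# string and tensor names are hypothetical):
#
#   fn = _JittedFunction(code_string, return_by_ref=False, num_outputs=1, alpha=1.0)
#   fn(a, b)              # runs with the declared default alpha=1.0
#   fn(a, b, alpha=2.0)   # overrides alpha for this call only
#   fn(a, b, beta=3.0)    # raises KeyError: beta was not declared at creation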
 

def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
    into an elementwise kernel template, and compiled on the fly. The compiled kernel will be cached in memory, as
    well as in a local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for generated function

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with python registration to override an operator's cuda kernel.
    The following example overrides gelu's cuda kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output.

    .. warning::
        All input tensors must live on a CUDA device.

    """
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
 

def _create_multi_output_jit_fn(code_string: str, num_outputs: int, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return its value by reference.
        num_outputs (int): number of outputs returned by the kernel
        kwargs (Dict, optional): Keyword arguments for generated function

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs.

    """
    return _JittedFunction(code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs)
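
# A hedged sketch of a kernel with two outputs (kernel and variable names are
# hypothetical; assumes the launcher returns one tensor per declared output):
#
#   code_string = """
#   template <typename T>
#   void polar_to_cartesian(T r, T theta, T& x, T& y) {
#       x = r * ::cos(theta);
#       y = r * ::sin(theta);
#   }
#   """
#   polar = _create_multi_output_jit_fn(code_string, num_outputs=2)
#   r = torch.rand(3, device='cuda')
#   theta = torch.rand(3, device='cuda')
#   x, y = polar(r, theta)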
 
 