Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ jobs:
- uses: actions/checkout@v4
- name: "Main Script"
run: |
EXTRA_INSTALL="mypy pytest types-colorama types-Pygments"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0

build_py_project_in_conda_env
python -m pip install mypy
./run-mypy.sh

pytest:
Expand Down
3 changes: 1 addition & 2 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,10 @@ Ruff:

Mypy:
script: |
EXTRA_INSTALL="pybind11 numpy"
EXTRA_INSTALL="mypy pybind11 numpy types-colorama types-Pygments"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
python -m pip install mypy
./run-mypy.sh
tags:
- python3
Expand Down
2 changes: 1 addition & 1 deletion loopy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def __init__(
kwargs = _apply_legacy_map(self._legacy_options_map, kwargs)

try:
import colorama # noqa
import colorama # noqa: F401
except ImportError:
allow_terminal_colors_def = False
else:
Expand Down
23 changes: 16 additions & 7 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,9 +795,12 @@ def known_callables(self):
# {{{ code generation

def get_function_definition(
self, codegen_state: CodeGenerationState,
self,
codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult,
schedule_index: int, function_decl: Generable, function_body: Generable
schedule_index: int,
function_decl: Generable,
function_body: Generable
) -> Generable:
kernel = codegen_state.kernel
assert kernel.linearization is not None
Expand Down Expand Up @@ -825,16 +828,23 @@ def get_function_definition(
tv.initializer is not None):
assert tv.read_only

decl: Generable = self.wrap_global_constant(
decl = self.wrap_global_constant(
self.get_temporary_var_declarator(codegen_state, tv))

if tv.initializer is not None:
decl = Initializer(decl, generate_array_literal(
init_decl = Initializer(decl, generate_array_literal(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary because `Initializer` takes a `Declarator`, not a `Generable`, which is what the forced `decl: Generable` annotation above declared.

codegen_state, tv, tv.initializer))
else:
init_decl = decl

result.append(decl)
result.append(init_decl)

assert isinstance(function_decl, FunctionDeclarationWrapper)
if not isinstance(function_body, Block):
function_body = Block([function_body])

fbody = FunctionBody(function_decl, function_body)
Comment on lines +842 to 846
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This `if` check could live in cgen instead, but the current cgen docs say that `FunctionBody` takes a `Block`:
https://github.com/inducer/cgen/blob/738bdb6ea330de63dc1e476025b8ca693b535443/cgen/__init__.py#L992-L1006


if not result:
return fbody
else:
Expand Down Expand Up @@ -1338,8 +1348,7 @@ def map_expression(self, expr):

def map_function_decl_wrapper(self, node):
self.decls.append(node.subdecl)
return super()\
.map_function_decl_wrapper(node)
return super().map_function_decl_wrapper(node)


def generate_header(kernel, codegen_result=None):
Expand Down
9 changes: 3 additions & 6 deletions loopy/target/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Literal, Sequence

import numpy as np

Expand Down Expand Up @@ -766,12 +766,9 @@ def get_constant_arg_declarator(self, arg: ConstantArg) -> Declarator:

def get_image_arg_declarator(
self, arg: ImageArg, is_written: bool) -> Declarator:
if is_written:
mode = "w"
else:
mode = "r"

from cgen.opencl import CLImage

mode: Literal["r", "w"] = "w" if is_written else "r"
return CLImage(arg.num_target_axes(), mode, arg.name)

# }}}
Expand Down
21 changes: 16 additions & 5 deletions loopy/target/pyopencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,9 +1026,12 @@ class PyOpenCLCASTBuilder(OpenCLCASTBuilder):
# {{{ function decl/def, with arg overflow handling

def get_function_definition(
self, codegen_state: CodeGenerationState,
self,
codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult,
schedule_index: int, function_decl: Generable, function_body: Generable,
schedule_index: int,
function_decl: Generable,
function_body: Generable,
) -> Generable:
assert isinstance(function_body, Block)
kernel = codegen_state.kernel
Expand Down Expand Up @@ -1057,15 +1060,17 @@ def get_function_definition(
tv.initializer is not None):
assert tv.read_only

decl: Generable = self.wrap_global_constant(
decl = self.wrap_global_constant(
self.get_temporary_var_declarator(codegen_state, tv))

if tv.initializer is not None:
from loopy.target.c import generate_array_literal
decl = Initializer(decl, generate_array_literal(
init_decl = Initializer(decl, generate_array_literal(
codegen_state, tv, tv.initializer))
else:
init_decl = decl

result.append(decl)
result.append(init_decl)

# {{{ unpack overflow args

Expand All @@ -1091,6 +1096,12 @@ def get_function_definition(

# }}}

from loopy.target.c import FunctionDeclarationWrapper

assert isinstance(function_decl, FunctionDeclarationWrapper)
if not isinstance(function_body, Block):
function_body = Block([function_body])

fbody = FunctionBody(function_decl, function_body)
if not result:
return fbody
Expand Down
11 changes: 3 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,12 @@ module = [

[[tool.mypy.overrides]]
module = [
"IPython.*",
"fparser.*",
"islpy.*",
"pymbolic.*",
"genpy.*",
"pyopencl.*",
"colorama.*",
"codepy.*",
"mako.*",
"fparser.*",
"ply.*",
"pygments.*",
"IPython.*",
"pyopencl.*",
]
ignore_missing_imports = true

Expand Down
Loading