Skip to content

Commit

Permalink
Initial draft
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Jan 8, 2025
1 parent ada7871 commit 219b11f
Show file tree
Hide file tree
Showing 11 changed files with 794 additions and 0 deletions.
1 change: 1 addition & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from enum import Enum

from . import scheduler # noqa: F401
from ._version import __version__, __version_tuple__ # noqa: F401

__array_api_version__ = "2022.12"
Expand Down
40 changes: 40 additions & 0 deletions sparse/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from .finch_logic import (
Aggregate,
Alias,
Deferred,
Field,
Immediate,
MapJoin,
Plan,
Produces,
Query,
Reformat,
Relabel,
Reorder,
Subquery,
Table,
)
from .optimize import optimize, propagate_map_queries
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk

__all__ = [
"Aggregate",
"Alias",
"Deferred",
"Field",
"Immediate",
"MapJoin",
"Plan",
"Produces",
"Query",
"Reformat",
"Relabel",
"Reorder",
"Subquery",
"Table",
"optimize",
"propagate_map_queries",
"PostOrderDFS",
"PostWalk",
"PreWalk",
]
132 changes: 132 additions & 0 deletions sparse/scheduler/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from collections.abc import Hashable
from textwrap import dedent
from typing import Any

from .finch_logic import (
Alias,
Deferred,
Field,
Immediate,
LogicNode,
MapJoin,
Query,
Reformat,
Relabel,
Reorder,
Subquery,
Table,
)


def get_or_insert(dictionary: dict[Hashable, Any], key: Hashable, default: Any) -> Any:
if key in dictionary:
return dictionary[key]
dictionary[key] = default
return default

Check warning on line 25 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L22-L25

Added lines #L22 - L25 were not covered by tests


def get_structure(node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode]) -> LogicNode:
match node:
case Field(name) | Alias(name):
return get_or_insert(fields, name, Immediate(len(fields) + len(aliases)))
case Subquery(Alias(name) as lhs, arg):
if name in aliases:
return aliases[name]
return Subquery(get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases))
case Table(tns, idxs):
return Table(Immediate(type(tns.val)), tuple(get_structure(idx, fields, aliases) for idx in idxs))
case any if any.is_tree():
return any.from_arguments(*[get_structure(arg, fields, aliases) for arg in any.get_arguments()])
case _:
return node

Check warning on line 41 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L29-L41

Added lines #L29 - L41 were not covered by tests


class PointwiseLowerer:
def __init__(self):
self.bound_idxs = []

Check warning on line 46 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L46

Added line #L46 was not covered by tests

def __call__(self, ex):
match ex:
case MapJoin(Immediate(val), args):
return f":({val}({','.join([self(arg) for arg in args])}))"
case Reorder(Relabel(Alias(name), idxs_1), idxs_2):
self.bound_idxs.append(idxs_1)
return f":({name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])"
case Reorder(Immediate(val), _):
return val
case Immediate(val):
return val
case _:
raise Exception(f"Unrecognized logic: {ex}")

Check warning on line 60 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L49-L60

Added lines #L49 - L60 were not covered by tests


def compile_pointwise_logic(ex: LogicNode) -> tuple:
ctx = PointwiseLowerer()
code = ctx(ex)
return (code, ctx.bound_idxs)

Check warning on line 66 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L64-L66

Added lines #L64 - L66 were not covered by tests


def compile_logic_constant(ex):
match ex:
case Immediate(val):
return val
case Deferred(ex, type_):
return f":({ex}::{type_})"
case _:
raise Exception(f"Invalid constant: {ex}")

Check warning on line 76 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L70-L76

Added lines #L70 - L76 were not covered by tests


def intersect(x1: tuple, x2: tuple) -> tuple:
return tuple(x for x in x1 if x in x2)

Check warning on line 80 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L80

Added line #L80 was not covered by tests


def with_subsequence(x1: tuple, x2: tuple) -> tuple:
res = list(x2)
indices = [idx for idx, val in enumerate(x2) if val in x1]
for idx, i in enumerate(indices):
res[i] = x1[idx]
return tuple(res)

Check warning on line 88 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L84-L88

Added lines #L84 - L88 were not covered by tests


class LogicLowerer:
def __init__(self, mode: str = "fast"):
self.mode = mode

Check warning on line 93 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L93

Added line #L93 was not covered by tests

def __call__(self, ex):
match ex:
case Query(Alias(name), Table(tns, _)):
return f":({name} = {compile_logic_constant(tns)})"

Check warning on line 98 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L96-L98

Added lines #L96 - L98 were not covered by tests

case Query(Alias(_) as lhs, Reformat(tns, Reorder(Relabel(Alias(_) as arg, idxs_1), idxs_2))):
loop_idxs = [idx.name for idx in with_subsequence(intersect(idxs_1, idxs_2), idxs_2)]
lhs_idxs = [idx.name for idx in idxs_2]
(rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2))
body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})"
for idx in loop_idxs:
if Field(idx) in rhs_idxs:
body = f":(for {idx} = _ \n {body} end)"
elif idx in lhs_idxs:
body = f":(for {idx} = 1:1 \n {body} end)"

Check warning on line 109 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L100-L109

Added lines #L100 - L109 were not covered by tests

result = f"""\

Check warning on line 111 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L111

Added line #L111 was not covered by tests
quote
{lhs.name} = {compile_logic_constant(tns)}
@finch mode = {self.mode} begin
{lhs.name} .= {tns.fill_value}
{body}
return {lhs.name}
end
end
"""
return dedent(result)

Check warning on line 121 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L121

Added line #L121 was not covered by tests

# TODO: ...

case _:
raise Exception(f"Unrecognized logic: {ex}")

Check warning on line 126 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L125-L126

Added lines #L125 - L126 were not covered by tests


class LogicCompiler:
def __call__(self, prgm):
prgm = format_queries(prgm, True) # noqa: F821
return LogicLowerer()(prgm)

Check warning on line 132 in sparse/scheduler/compiler.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/compiler.py#L131-L132

Added lines #L131 - L132 were not covered by tests
23 changes: 23 additions & 0 deletions sparse/scheduler/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .compiler import LogicCompiler


class LogicExecutor:
def __init__(self, ctx, verbose=False):
self.ctx: LogicCompiler = ctx
self.codes = {}
self.verbose = verbose

def __call__(self, prgm):
prgm_structure = prgm
if prgm_structure not in self.codes:
thunk = logic_executor_code(self.ctx, prgm)
self.codes[prgm_structure] = eval(thunk), thunk

Check warning on line 14 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L11-L14

Added lines #L11 - L14 were not covered by tests

f, code = self.codes[prgm_structure]
if self.verbose:
print(code)
return f(prgm)

Check warning on line 19 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L16-L19

Added lines #L16 - L19 were not covered by tests


def logic_executor_code(ctx, prgm):
pass

Check warning on line 23 in sparse/scheduler/executor.py

View check run for this annotation

Codecov / codecov/patch

sparse/scheduler/executor.py#L23

Added line #L23 was not covered by tests
Loading

0 comments on commit 219b11f

Please sign in to comment.