-
-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
794 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from .finch_logic import ( | ||
Aggregate, | ||
Alias, | ||
Deferred, | ||
Field, | ||
Immediate, | ||
MapJoin, | ||
Plan, | ||
Produces, | ||
Query, | ||
Reformat, | ||
Relabel, | ||
Reorder, | ||
Subquery, | ||
Table, | ||
) | ||
from .optimize import optimize, propagate_map_queries | ||
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk | ||
|
||
__all__ = [ | ||
"Aggregate", | ||
"Alias", | ||
"Deferred", | ||
"Field", | ||
"Immediate", | ||
"MapJoin", | ||
"Plan", | ||
"Produces", | ||
"Query", | ||
"Reformat", | ||
"Relabel", | ||
"Reorder", | ||
"Subquery", | ||
"Table", | ||
"optimize", | ||
"propagate_map_queries", | ||
"PostOrderDFS", | ||
"PostWalk", | ||
"PreWalk", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from collections.abc import Hashable | ||
from textwrap import dedent | ||
from typing import Any | ||
|
||
from .finch_logic import ( | ||
Alias, | ||
Deferred, | ||
Field, | ||
Immediate, | ||
LogicNode, | ||
MapJoin, | ||
Query, | ||
Reformat, | ||
Relabel, | ||
Reorder, | ||
Subquery, | ||
Table, | ||
) | ||
|
||
|
||
def get_or_insert(dictionary: dict[Hashable, Any], key: Hashable, default: Any) -> Any: | ||
if key in dictionary: | ||
return dictionary[key] | ||
dictionary[key] = default | ||
return default | ||
|
||
|
||
def get_structure(node: LogicNode, fields: dict[str, LogicNode], aliases: dict[str, LogicNode]) -> LogicNode: | ||
match node: | ||
case Field(name) | Alias(name): | ||
return get_or_insert(fields, name, Immediate(len(fields) + len(aliases))) | ||
case Subquery(Alias(name) as lhs, arg): | ||
if name in aliases: | ||
return aliases[name] | ||
return Subquery(get_structure(lhs, fields, aliases), get_structure(arg, fields, aliases)) | ||
case Table(tns, idxs): | ||
return Table(Immediate(type(tns.val)), tuple(get_structure(idx, fields, aliases) for idx in idxs)) | ||
case any if any.is_tree(): | ||
return any.from_arguments(*[get_structure(arg, fields, aliases) for arg in any.get_arguments()]) | ||
case _: | ||
return node | ||
|
||
|
||
class PointwiseLowerer: | ||
def __init__(self): | ||
self.bound_idxs = [] | ||
|
||
def __call__(self, ex): | ||
match ex: | ||
case MapJoin(Immediate(val), args): | ||
return f":({val}({','.join([self(arg) for arg in args])}))" | ||
case Reorder(Relabel(Alias(name), idxs_1), idxs_2): | ||
self.bound_idxs.append(idxs_1) | ||
return f":({name}[{','.join([idx.name if idx in idxs_2 else 1 for idx in idxs_1])}])" | ||
case Reorder(Immediate(val), _): | ||
return val | ||
case Immediate(val): | ||
return val | ||
case _: | ||
raise Exception(f"Unrecognized logic: {ex}") | ||
|
||
|
||
def compile_pointwise_logic(ex: LogicNode) -> tuple: | ||
ctx = PointwiseLowerer() | ||
code = ctx(ex) | ||
return (code, ctx.bound_idxs) | ||
|
||
|
||
def compile_logic_constant(ex): | ||
match ex: | ||
case Immediate(val): | ||
return val | ||
case Deferred(ex, type_): | ||
return f":({ex}::{type_})" | ||
case _: | ||
raise Exception(f"Invalid constant: {ex}") | ||
|
||
|
||
def intersect(x1: tuple, x2: tuple) -> tuple: | ||
return tuple(x for x in x1 if x in x2) | ||
|
||
|
||
def with_subsequence(x1: tuple, x2: tuple) -> tuple: | ||
res = list(x2) | ||
indices = [idx for idx, val in enumerate(x2) if val in x1] | ||
for idx, i in enumerate(indices): | ||
res[i] = x1[idx] | ||
return tuple(res) | ||
|
||
|
||
class LogicLowerer: | ||
def __init__(self, mode: str = "fast"): | ||
self.mode = mode | ||
|
||
def __call__(self, ex): | ||
match ex: | ||
case Query(Alias(name), Table(tns, _)): | ||
return f":({name} = {compile_logic_constant(tns)})" | ||
|
||
case Query(Alias(_) as lhs, Reformat(tns, Reorder(Relabel(Alias(_) as arg, idxs_1), idxs_2))): | ||
loop_idxs = [idx.name for idx in with_subsequence(intersect(idxs_1, idxs_2), idxs_2)] | ||
lhs_idxs = [idx.name for idx in idxs_2] | ||
(rhs, rhs_idxs) = compile_pointwise_logic(Reorder(Relabel(arg, idxs_1), idxs_2)) | ||
body = f":({lhs.name}[{','.join(lhs_idxs)}] = {rhs})" | ||
for idx in loop_idxs: | ||
if Field(idx) in rhs_idxs: | ||
body = f":(for {idx} = _ \n {body} end)" | ||
elif idx in lhs_idxs: | ||
body = f":(for {idx} = 1:1 \n {body} end)" | ||
|
||
result = f"""\ | ||
quote | ||
{lhs.name} = {compile_logic_constant(tns)} | ||
@finch mode = {self.mode} begin | ||
{lhs.name} .= {tns.fill_value} | ||
{body} | ||
return {lhs.name} | ||
end | ||
end | ||
""" | ||
return dedent(result) | ||
|
||
# TODO: ... | ||
|
||
case _: | ||
raise Exception(f"Unrecognized logic: {ex}") | ||
|
||
|
||
class LogicCompiler: | ||
def __call__(self, prgm): | ||
prgm = format_queries(prgm, True) # noqa: F821 | ||
return LogicLowerer()(prgm) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from .compiler import LogicCompiler | ||
|
||
|
||
class LogicExecutor: | ||
def __init__(self, ctx, verbose=False): | ||
self.ctx: LogicCompiler = ctx | ||
self.codes = {} | ||
self.verbose = verbose | ||
|
||
def __call__(self, prgm): | ||
prgm_structure = prgm | ||
if prgm_structure not in self.codes: | ||
thunk = logic_executor_code(self.ctx, prgm) | ||
self.codes[prgm_structure] = eval(thunk), thunk | ||
|
||
f, code = self.codes[prgm_structure] | ||
if self.verbose: | ||
print(code) | ||
return f(prgm) | ||
|
||
|
||
def logic_executor_code(ctx, prgm): | ||
pass | ||
Oops, something went wrong.