Update abstract_interp_lib_gen.py

vivekkhandelwal1 committed Oct 15, 2024
1 parent a1f5a69 commit 9652494

Showing 5 changed files with 17 additions and 79 deletions.
51 changes: 9 additions & 42 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

@@ -13813,63 +13813,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.ne.int %2#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %8 : !torch.int\n"
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %7 : !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
39 changes: 5 additions & 34 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -4349,18 +4349,7 @@ def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tupl
     return promote_dtypes(ranks, dtypes)
 
 @check_dtype_function(
-    # _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) +
-    # Different width
-    [Invocation(TensorOfShape(4, 3, dtype=torch.float32),
-                TensorOfShape(4, 3, dtype=torch.float64),
-                TensorOfShape(4, 3, dtype=torch.float32)),
-     # Different type
-     Invocation(TensorOfShape(4, 3, dtype=torch.float32),
-                TensorOfShape(4, 3, dtype=torch.float32),
-                TensorOfShape(4, 3, dtype=torch.int32)),
-     Invocation(TensorOfShape(4, 3, dtype=torch.int32),
-                TensorOfShape(4, 3, dtype=torch.float32),
-                TensorOfShape(4, 3, dtype=torch.float32))])
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]))
 def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     end_rank, end_dtype = end_rank_dtype
@@ -4371,28 +4360,17 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp
     return promote_dtypes(ranks, dtypes)
 
 @check_dtype_function(
-    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5) +
-    # Different width
-    [Invocation(TensorOfShape(4, 3, dtype=torch.float32),
-                TensorOfShape(4, 3, dtype=torch.float64),
-                weight=0.5),
-     # Different type
-     Invocation(TensorOfShape(4, 3, dtype=torch.int32),
-                TensorOfShape(4, 3, dtype=torch.float32),
-                weight=0.5),
-     Invocation(TensorOfShape(4, 3, dtype=torch.float32),
-                TensorOfShape(4, 3, dtype=torch.float32),
-                weight=2)])
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5))
 def aten〇lerp〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight: Union[int, float, complex]) -> int:
     self_rank, self_dtype = self_rank_dtype
     end_rank, end_dtype = end_rank_dtype
 
-    ranks: List[Optional[int]] = [self_rank, end_rank, None]
-    dtypes = [self_dtype, end_dtype, get_dtype_of_scalar(weight)]
+    ranks: List[Optional[int]] = [self_rank, end_rank]
+    dtypes = [self_dtype, end_dtype]
     return promote_dtypes(ranks, dtypes)
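The net effect for aten.lerp.Scalar: the scalar weight no longer participates in dtype promotion; only self and end do. A minimal sketch of the new rule using torch.promote_types (an illustration of the promotion arithmetic only, not the library's promote_dtypes helper, which also weighs tensor ranks against scalars):

    import torch

    # New rule: only the two tensor operands drive the result dtype.
    self_dtype, end_dtype = torch.float16, torch.float32
    print(torch.promote_types(self_dtype, end_dtype))  # torch.float32

    # Old rule (removed above): get_dtype_of_scalar(weight) was a third
    # promotion participant, so a float weight could widen the result dtype.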
 
 @check_dtype_function(
-    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) +
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) +
     # Different width
     [Invocation(TensorOfShape(3, 3, dtype=torch.float32),
                 TensorOfShape(3, 3, dtype=torch.float64),
@@ -4409,16 +4387,11 @@ def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype:
     tensor1_rank, tensor1_dtype = tensor1_rank_dtype
     tensor2_rank, tensor2_dtype = tensor2_rank_dtype
 
-    assert self_dtype != torch.bool
-    assert tensor1_dtype != torch.bool
-    assert tensor2_dtype != torch.bool
-
     ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank]
     dtypes = [self_dtype, tensor1_dtype, tensor2_dtype]
     return promote_dtypes(ranks, dtypes)
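aten.addcmul now accepts torch.bool operands: the three assert guards are gone (in the generated C++ above, %int11 is the ScalarType code for torch.bool), leaving a plain three-way promotion. A hedged sketch of the resulting computation, again via torch.promote_types as a stand-in for the library's promote_dtypes helper:

    import functools
    import torch

    # Three-way promotion, bool included; previously bool inputs tripped
    # an AssertionError before promotion was reached.
    dtypes = [torch.bool, torch.int32, torch.float16]
    print(functools.reduce(torch.promote_types, dtypes))  # torch.float16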
 
 @check_dtype_function(
-    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) +
     # Different width
     [Invocation(TensorOfShape(3, 3, dtype=torch.float32),
                 TensorOfShape(3, 3, dtype=torch.float64),
@@ -4438,8 +4411,6 @@ def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype:
     ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank]
     dtypes = [self_dtype, tensor1_dtype, tensor2_dtype]
     result = promote_dtypes(ranks, dtypes)
-    if is_integer_dtype(result):
-        return torch.float32
     return result
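aten.addcdiv previously forced an integer promotion result up to float32 (%int6 in the generated C++ is the ScalarType code for float32); the promoted dtype is now returned unchanged. A sketch of the behavioral difference for all-integer operands, under the same torch.promote_types stand-in as above:

    import functools
    import torch

    dtypes = [torch.int32, torch.int32, torch.int32]
    result = functools.reduce(torch.promote_types, dtypes)
    print(result)  # torch.int32; the old rule reported torch.float32 here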
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
2 changes: 1 addition & 1 deletion pytorch-hash.txt
@@ -1 +1 @@
-2fcfb44b7e518b35c3de74bdd85fe7c836c81d4b
+ec8499a174317b85b6c6fe98eb99a266b590cef8
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.6.0.dev20241014
+torch==2.6.0.dev20241015
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.20.0.dev20241014
+torchvision==0.20.0.dev20241015
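These pins track the PyTorch nightly channel, and pytorch-hash.txt records the upstream PyTorch commit the pinned nightly corresponds to, so the three files move together. The requirements files are plain pip inputs (the -f line adds the nightly wheel index and --pre admits pre-release versions), so a typical way to consume them from the repository root is:

    pip install -r pytorch-requirements.txt -r torchvision-requirements.txt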
