Skip to content

Commit

Permalink
Fix fx_importer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 committed Oct 16, 2024
1 parent 64f965a commit 3319f4d
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 18 deletions.
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6317,6 +6317,30 @@ def Torch_AtenDotOp : Torch_Op<"aten.dot", [
let hasCanonicalizer = 1;
}

def Torch_AtenOuterOp : Torch_Op<"aten.outer", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::outer : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$vec2
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenOuterOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenOuterOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
15 changes: 15 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7601,6 +7601,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } : (!torch.int, !torch.bool) -> ()\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.outer\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.dot\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -13403,6 +13410,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.outer\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %false = torch.constant.bool false\n"
" %int5 = torch.constant.int 5\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,9 @@ def aten〇numpy_T〡shape(self: List[int]) -> List[int]:
result_shape.insert(0, i)
return result_shape

def aten〇outer〡shape(self: List[int], vec2: List[int]) -> List[int]:
return [self[0], vec2[0]]

@check_shape_function([Invocation(TensorOfShape(3), TensorOfShape(3))])
def aten〇dot〡shape(self: List[int], tensor: List[int]) -> List[int]:
return []
Expand Down Expand Up @@ -4025,6 +4028,14 @@ def aten〇fmin〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tupl
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3,), (4,)]))
def aten〇outer〡dtype(self_rank_dtype: Tuple[int, int], vec2_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
vec2_rank, vec2_dtype = vec2_rank_dtype
ranks: List[Optional[int]] = [self_rank, vec2_rank]
dtypes = [self_dtype, vec2_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4, 3)]) +
# Different width
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
emit("aten::dot : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit("aten::outer : (Tensor, Tensor) -> (Tensor)")
emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)")
emit(
"aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
Expand Down
27 changes: 13 additions & 14 deletions test/python/fx_importer/symbolic_shape_expr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,19 @@ def forward(self, x, y):


@run
# TODO: Enable these checks once the IR generated is same for both nightly and stable Torch version.
# C_HECK-LABEL: test_outer_with_squared_shape
# C_HECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
# C_HECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# C_HECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# C_HECK: %[[I0:.+]] = torch.constant.int 0
# C_HECK: %[[SIZE:.+]] = torch.aten.size.int %[[ARG0]], %[[I0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
# C_HECK: %[[OUTER:.+]] = torch.operator "torch.aten.outer"(%[[ARG0]], %[[ARG0]]) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32>
# C_HECK: torch.bind_symbolic_shape %[[OUTER]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32>
# C_HECK: %[[MUL:.+]] = torch.aten.mul.int %[[SIZE]], %[[SIZE]] : !torch.int, !torch.int -> !torch.int
# C_HECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MUL]] : (!torch.int) -> !torch.list<int>
# C_HECK: %[[VIEW:.+]] = torch.aten.view %[[OUTER]], %[[LIST]] : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
# C_HECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32>
# C_HECK: return %[[VIEW]] : !torch.vtensor<[?],f32>
# CHECK-LABEL: test_outer_with_squared_shape
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: %[[I0:.+]] = torch.constant.int 0
# CHECK: %[[SIZE:.+]] = torch.aten.size.int %[[ARG0]], %[[I0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
# COM: %[[OUTER:.+]] = torch.aten.outer %[[ARG0]], %[[ARG0]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
# CHECK: torch.bind_symbolic_shape %{{.*}}, [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32>
# CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[SIZE]], %[[SIZE]] : !torch.int, !torch.int -> !torch.int
# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MUL]] : (!torch.int) -> !torch.list<int>
# CHECK: %[[VIEW:.+]] = torch.aten.view %{{.*}}, %[[LIST]] : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32>
# CHECK: return %[[VIEW]] : !torch.vtensor<[?],f32>
def test_outer_with_squared_shape():
class OuterWithSquaredShape(torch.nn.Module):
def __init__(self):
Expand Down
6 changes: 2 additions & 4 deletions test/python/fx_importer/v2.3/mutation_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,9 @@ def forward(self, x):
# CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.tensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[arg1_copy:.+]] = torch.copy.to_vtensor %arg1 : !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[arg1_mul:.+]] = torch.aten.mul.Tensor %[[arg1_copy]], %arg0
# COM: %{{.*}} = torch.aten.copy %[[arg1_copy]], %[[arg1_mul]], %false : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.bool -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: torch.overwrite.tensor.contents %{{.*}} overwrites %arg1
# CHECK-DAG: %[[arg0_mul:.+]] = torch.aten.mul.Tensor %arg0, %[[arg1_mul]]
# TODO: Enable these checks once the IR generated is same for both nightly and stable Torch version.
# C_HECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
# C_HECK-DAG: %[[COPY:.+]] = torch.aten.copy %[[arg1_copy]], %[[arg1_mul]], %[[FALSE]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.bool -> !torch.vtensor<[3,4],f32>
# C_HECK-DAG: torch.overwrite.tensor.contents %[[COPY]] overwrites %arg1
# CHECK: return %[[arg0_mul]]
def test_user_input_mutate():
class Basic(nn.Module):
Expand Down

0 comments on commit 3319f4d

Please sign in to comment.