From 040aec90557a2ef649e8f79244a1aa0a91736922 Mon Sep 17 00:00:00 2001 From: Sagar Shelke Date: Tue, 14 Jan 2025 09:45:36 -0800 Subject: [PATCH] =?UTF-8?q?[lib/conversion]=20Create=20seed=20only=20if=20?= =?UTF-8?q?needed=20in=20`convert-torch-convers=E2=80=A6=20(#3926)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion-to-mlprogram` pass This PR changes `convert-torch-conversion-to-mlprogram` pass implementation by moving seed generation inside `ConvertGetNextSeedOp` pattern. Previously, global seed was being created by this pass, even when its only consumer `torch_c.get_next_seed` op is not present in the IR. This pass is part of Torch->Linalg conversion pipeline. Always creating global seed created an issue for the case when downstream compiler doesn't expect/support `ml_program` dialect in linalg on tensor IR format. However, when starting torch IR has `torch_c.get_next_seed` op, `ml_program` will still be present and will need to be handled by downstream compilers. --- .../TorchConversionToMLProgram.cpp | 12 +++++++----- .../TorchConversionToMLProgram/basic.mlir | 13 +++++++++++++ .../multiple_functions.mlir | 2 +- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index ddb6e5a5fdac..ddcfab78ac8f 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -59,6 +59,13 @@ class ConvertGetNextSeedOp : public OpConversionPattern { matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); + + // Check for global seed and create if it doesn't exist. + auto module = op->getParentOfType(); + OpBuilder b(module.getBodyRegion()); + if (failed(getOrCreateGlobalVariableForSeed(b, module))) + return failure(); + // Generate sequence for getting the next seed with LCG step: // nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64. // Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator. @@ -115,11 +122,6 @@ class ConvertTorchConversionToMLProgram typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - auto module = getOperation(); - OpBuilder b(module.getBodyRegion()); - if (failed(getOrCreateGlobalVariableForSeed(b, module))) - signalPassFailure(); - RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/test/Conversion/TorchConversionToMLProgram/basic.mlir b/test/Conversion/TorchConversionToMLProgram/basic.mlir index c7fb38e1c5b0..262ada6f283d 100644 --- a/test/Conversion/TorchConversionToMLProgram/basic.mlir +++ b/test/Conversion/TorchConversionToMLProgram/basic.mlir @@ -17,3 +17,16 @@ module { return %seed : i64 } } + +// ----- + +module { + func.func @no_seed_needed(%arg0: tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> { + %0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> + } +} + +// CHECK-NOT: ml_program.global +// CHECK-LABEL: @no_seed_needed +// CHECK-NEXT: torch_c.from_builtin_tensor diff --git a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir index 8ef04d95166e..da2424fc3ba2 100644 --- a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir +++ b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir @@ -11,5 +11,5 @@ module { func.func private @f7() -> i64 } -// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor +// CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor // CHECK-NOT: @global_seed