Skip to content

Commit

Permalink
[lib/conversion] Create seed only if needed in `convert-torch-convers… (
Browse files Browse the repository at this point in the history
#3926)

…ion-to-mlprogram` pass

This PR changes `convert-torch-conversion-to-mlprogram` pass
implementation by moving seed generation inside `ConvertGetNextSeedOp`
pattern.
Previously, global seed was being created by this pass, even when its
only consumer `torch_c.get_next_seed` op is not present in the IR. This
pass is part of Torch->Linalg conversion pipeline. Always creating
global seed created an issue for the case when downstream compiler
doesn't expect/support `ml_program` dialect in linalg on tensor IR
format. However, when starting torch IR has `torch_c.get_next_seed` op,
`ml_program` will still be present and will need to be handled by
downstream compilers.
  • Loading branch information
shelkesagar29 authored Jan 14, 2025
1 parent 62eb38b commit 040aec9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

// Check for global seed and create if it doesn't exist.
auto module = op->getParentOfType<ModuleOp>();
OpBuilder b(module.getBodyRegion());
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
return failure();

// Generate sequence for getting the next seed with LCG step:
// nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
Expand Down Expand Up @@ -115,11 +122,6 @@ class ConvertTorchConversionToMLProgram
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

auto module = getOperation();
OpBuilder b(module.getBodyRegion());
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
signalPassFailure();

RewritePatternSet patterns(context);
target.addIllegalOp<GetNextSeedOp>();
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/TorchConversionToMLProgram/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ module {
return %seed : i64
}
}

// -----

module {
func.func @no_seed_needed(%arg0: tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
return %0 : !torch.vtensor<[2,3],f32>
}
}

// CHECK-NOT: ml_program.global
// CHECK-LABEL: @no_seed_needed
// CHECK-NEXT: torch_c.from_builtin_tensor
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ module {
func.func private @f7() -> i64
}

// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-NOT: @global_seed

0 comments on commit 040aec9

Please sign in to comment.