diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h --- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h +++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h @@ -16,6 +16,7 @@ struct LogicalResult; class MLIRContext; class Value; +class Operation; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; @@ -49,6 +50,9 @@ /// are not rewritten by the provided patterns are legal. void configureParallelLoopToGPULegality(ConversionTarget &target); +/// Clean up after applyPartialConversion/applyFullConversion call. +void finalizeParallelLoopToGPUConversion(Operation *op); + } // namespace mlir #endif // MLIR_CONVERSION_SCFTOGPU_SCFTOGPU_H_ diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -37,6 +37,19 @@ using namespace mlir; using namespace mlir::scf; +// Name of internal attribute to mark visited operations during conversion. +// The conversion originally used the following legality criterion: +// * `!parallelOp->getAttr(gpu::getMappingAttrName())` +// But the provided pattern might reject some cases based on more detailed +// analysis of the `mapping` attribute. +// To avoid dialect conversion failure due to a non-converted illegal operation, +// we use this extra Unit attribute as a marker that the operation was checked +// by the pattern and should be considered legal in the following legality +// checks. The `finalizeParallelLoopToGPUConversion` function cleans up +// these extra attributes and is supposed to be called after the dialect +// conversion. +static constexpr StringLiteral kVisitedAttrName = "SCFToGPU_visited"; + // Extract an indexed value from KernelDim3.
static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { switch (pos) { @@ -567,6 +580,9 @@ LogicalResult ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { + // Mark the operation as visited for recursive legality check. + parallelOp->setAttr(kVisitedAttrName, rewriter.getUnitAttr()); + // We can only transform starting at the outer-most loop. Launches inside of // parallel loops are not supported. if (auto parentLoop = parallelOp->getParentOfType()) @@ -649,6 +665,13 @@ void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) { target.addLegalDialect(); target.addDynamicallyLegalOp([](scf::ParallelOp parallelOp) { - return !parallelOp->getAttr(gpu::getMappingAttrName()); + return !parallelOp->getAttr(gpu::getMappingAttrName()) || + parallelOp->getAttr(kVisitedAttrName); + }); +} + +void mlir::finalizeParallelLoopToGPUConversion(Operation *op) { + op->walk([](scf::ParallelOp parallelOp) { + parallelOp->removeAttr(kVisitedAttrName); }); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -55,6 +55,7 @@ if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); + finalizeParallelLoopToGPUConversion(getOperation()); } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -175,6 +175,7 @@ target.addLegalDialect(); target.addDynamicallyLegalDialect( [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addLegalOp(); target.addLegalDialect(); if (failed( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- 
a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1650,7 +1650,13 @@ bool OperationLegalizer::isIllegal(Operation *op) const { // Check if the target explicitly marked this operation as illegal. - return target.getOpAction(op->getName()) == LegalizationAction::Illegal; + if (auto info = target.getOpAction(op->getName())) { + if (*info == LegalizationAction::Dynamic) + return !target.isLegal(op); + return *info == LegalizationAction::Illegal; + } + + return false; } LogicalResult diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir --- a/mlir/test/Transforms/test-legalizer-analysis.mlir +++ b/mlir/test/Transforms/test-legalizer-analysis.mlir @@ -1,5 +1,8 @@ // RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s -// expected-remark@-2 {{op 'builtin.module' is legalizable}} + +// expected-remark@below {{applyAnalysisConversion succeeded}} +// expected-remark@below {{op 'builtin.module' is legalizable}} +builtin.module { // expected-remark@+1 {{op 'builtin.func' is legalizable}} func @test(%arg0: f32) { @@ -16,3 +19,5 @@ // CHECK-LABEL: func @test // CHECK-NEXT: "test.illegal_op_a" // CHECK: "test.invalid" + +} diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -1,5 +1,8 @@ // RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s +// expected-remark@+1 {{applyFullConversion succeeded}} +builtin.module { + // CHECK-LABEL: func @multi_level_mapping func @multi_level_mapping() { // CHECK: "test.type_producer"() : () -> f64 @@ -27,12 +30,17 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 
{{applyFullConversion succeeded}} +builtin.module { + // Test that children of recursively legal operations are ignored. func @recursively_legal_invalid_op() { /// Operation that is statically legal. - module attributes {test.recursively_legal} { + builtin.module attributes {test.recursively_legal} { %ignored = "test.illegal_op_f"() : () -> (i32) } /// Operation that is dynamically legal, i.e. the function has a pattern @@ -45,8 +53,13 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that region cloning can be properly undone. func @test_undo_region_clone() { "test.region"() ({ @@ -59,8 +72,13 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that unknown operations can be dynamically legal. func @test_unknown_dynamically_legal() { "foo.unknown_op"() {test.dynamically_legal} : () -> () @@ -70,8 +88,13 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that region inlining can be properly undone. func @test_undo_region_inline() { "test.region"() ({ @@ -85,8 +108,13 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that multiple block erases can be properly undone. 
func @test_undo_block_erase() { // expected-error@+1 {{failed to legalize operation 'test.region'}} @@ -99,3 +127,18 @@ "test.return"() : () -> () } + +} + +// ----- + +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + +func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () +} + +} diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,5 +1,8 @@ // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s +// expected-remark@+1 {{applyPartialConversion succeeded}} +builtin.module { + // CHECK-LABEL: verifyDirectPattern func @verifyDirectPattern() -> i32 { // CHECK-NEXT: "test.legal_op_a"() {status = "Success"} @@ -171,30 +174,44 @@ return } +} + // ----- +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + func @fail_to_convert_illegal_op() -> i32 { // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} %result = "test.illegal_op_f"() : () -> (i32) return %result : i32 } +} + // ----- +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + func @fail_to_convert_illegal_op_in_region() { // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} "test.region_builder"() : () -> () return } +} + // ----- // Check that the entry block arguments of a region are untouched in the case // of failure. 
-// CHECK-LABEL: func @fail_to_convert_region +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + func @fail_to_convert_region() { - // CHECK-NEXT: "test.region" + // CHECK: "test.region" // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64): "test.region"() ({ ^bb1(%i0: i64): @@ -205,8 +222,13 @@ return } +} + // ----- +// expected-remark@+1 {{applyPartialConversion succeeded}} +builtin.module { + // CHECK-LABEL: @create_illegal_block func @create_illegal_block() { // Check that we can undo block creation, i.e. that the block was removed. @@ -219,8 +241,13 @@ return } +} + // ----- +// expected-remark@+1 {{applyPartialConversion succeeded}} +builtin.module { + // CHECK-LABEL: @undo_block_arg_replace func @undo_block_arg_replace() { // expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}} @@ -235,8 +262,13 @@ return } +} + // ----- +// expected-remark@+1 {{applyPartialConversion succeeded}} +builtin.module { + // The op in this function is rewritten to itself (and thus remains illegal) by // a pattern that removes its second block after adding an operation into it. // Check that we can undo block removal successfully. @@ -256,8 +288,13 @@ }) : () -> () } +} + // ----- +// expected-remark@+1 {{applyPartialConversion succeeded}} +builtin.module { + // The op in this function is attempted to be rewritten to another illegal op // with an attached region containing an invalid terminator. The terminator is // created before the parent op. The deletion should not crash when deleting @@ -271,9 +308,12 @@ return } +} // ----- +// expected-remark@+1 {{applyPartialConversion succeeded}} +builtin.module { // Check that a conversion pattern on `test.blackhole` can mark the producer // for deletion. 
@@ -284,3 +324,18 @@ // expected-remark@+1 {{op 'std.return' is not legalizable}} return } + +} + +// ----- + +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + +func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () +} + +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1317,9 +1317,12 @@ def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>; def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>; def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; +def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>; def LegalOpA : TEST_Op<"legal_op_a">, Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; +def LegalOpC : TEST_Op<"legal_op_c">, + Arguments<(ins I32)>, Results<(outs I32)>; // Check that the conversion infrastructure can properly undo the creation of // operations where an operation was created before its parent, in this case, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -537,6 +537,20 @@ return success(); }; }; + +// This pattern replaces an explicitly illegal op with an explicitly legal op, +// but also creates an op (`std.constant`) that was not added to the target.
+struct TestCreateUnregisteredOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ILLegalOpG op, + PatternRewriter &rewriter) const final { + IntegerAttr attr = rewriter.getI32IntegerAttr(0); + Value val = rewriter.create(op->getLoc(), attr); + rewriter.replaceOpWithNewOp(op, val); + return success(); + }; +}; } // namespace namespace { @@ -607,6 +621,10 @@ TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { TestTypeConverter converter; mlir::RewritePatternSet patterns(&getContext()); @@ -618,8 +636,8 @@ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, - TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( - &getContext()); + TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, + TestCreateUnregisteredOp>(&getContext()); patterns.add(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); @@ -627,7 +645,7 @@ // Define the conversion target used for the test. ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalOp(); target .addIllegalOp(); @@ -641,6 +659,11 @@ converter.isLegal(&op.getBody()); }); + // TestCreateUnregisteredOp creates `std.constant` operation, + // which was not added to target intentionally to test + // correct error code from conversion driver. + target.addDynamicallyLegalOp([](ILLegalOpG) { return false; }); + // Expect the type_producer/type_consumer operations to only operate on f64. target.addDynamicallyLegalOp( [](TestTypeProducerOp op) { return op.getType().isF64(); }); @@ -661,8 +684,12 @@ // Handle a partial conversion. 
if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, std::move(patterns), - &unlegalizedOps); + if (succeeded(applyPartialConversion( + getOperation(), target, std::move(patterns), &unlegalizedOps))) { + getOperation()->emitRemark() << "applyPartialConversion succeeded"; + } else { + getOperation()->emitRemark() << "applyPartialConversion failed"; + } // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; @@ -676,7 +703,12 @@ return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, std::move(patterns)); + if (succeeded(applyFullConversion(getOperation(), target, + std::move(patterns)))) { + getOperation()->emitRemark() << "applyFullConversion succeeded"; + } else { + getOperation()->emitRemark() << "applyFullConversion failed"; + } return; } @@ -685,9 +717,13 @@ // Analyze the convertible operations. DenseSet legalizedOps; - if (failed(applyAnalysisConversion(getOperation(), target, - std::move(patterns), legalizedOps))) + if (succeeded(applyAnalysisConversion(getOperation(), target, + std::move(patterns), legalizedOps))) { + getOperation()->emitRemark() << "applyAnalysisConversion succeeded"; + } else { + getOperation()->emitRemark() << "applyAnalysisConversion failed"; return signalPassFailure(); + } // Emit remarks for each legalizable operation. for (auto *op : legalizedOps)