Summary of this patch (derived only from the hunks below; review notes added,
hunk text unchanged):

* DialectConversion.cpp: OperationLegalizer::isIllegal is renamed to
  isKnownIllegal. When the target's action for an op is
  LegalizationAction::Dynamic it now returns !target.isLegal(op) instead of
  false, so in OpConversionMode::Partial the driver emits "failed to legalize
  operation '...' that was explicitly marked illegal" also for ops whose
  dynamic-legality callback rejects them, not only for ops whose action is
  LegalizationAction::Illegal.
* SCFToGPU: ParallelToGpuLaunchLowering::matchAndRewrite now sets the unit
  attribute named by gpu::getVisitedAttrName() ("GPU._visited_") on every
  scf.parallel it is invoked on, before its "outer-most loop only" check, and
  configureParallelLoopToGPULegality treats an scf.parallel carrying that
  attribute as legal. Presumably this keeps a mapped loop the pattern declines
  to convert from being reported as illegal by the stricter check above
  (TODO confirm). The new mlir::finalizeParallelLoopToGPUConversion(Operation *)
  walks the IR and removes the attribute from every scf.parallel; SCFToGPUPass
  calls it after applyPartialConversion whether or not the conversion failed.
* Tensor Bufferize.cpp: adds one target.addLegalOp call (its template
  argument list is missing from this copy).
* Tests: TestPatterns.cpp now emits an "applyPartialConversion failed" or
  "applyFullConversion failed" remark on getOperation() when the conversion
  returns failure, and the failing cases in test-legalizer.mlir and
  test-legalizer-full.mlir are wrapped in an explicit builtin.module whose line
  carries the matching expected-remark. The new TestCreateUnregisteredOp
  pattern matches test.illegal_op_g (made dynamically illegal by a callback
  that always returns false), creates another op and replaces the match with a
  new op; judging by its comment and the TestOps.td additions these are
  test.unregistered_op and test.legal_op_c, but the template arguments are
  stripped here (verify).

NOTE(review): this copy is damaged and will not apply as-is with `git apply`
or `patch`. Newlines inside and between hunks have been collapsed, and every
`<...>` span appears to have been removed (e.g. `getParentOfType()`,
`addLegalDialect()`, `static_cast(...)`, `OpRewritePattern` without template
arguments). In the TestPatterns.cpp hunk at -627 the `+` counterpart of the
removed `target.addLegalOp` line seems to have been swallowed together with the
stripped text. Regenerate the patch from version control instead of repairing
it by hand.

NOTE(review) on the change itself, to confirm before landing:
* The visited attribute is set with parallelOp->setAttr(...) directly rather
  than through the rewriter, so it survives until
  finalizeParallelLoopToGPUConversion removes it. Any other user of
  configureParallelLoopToGPULegality must now call
  finalizeParallelLoopToGPUConversion too, or "GPU._visited_" is left in the
  output IR.
* The comment kept at the top of isKnownIllegal ("Check if the target
  explicitly marked this operation as illegal.") no longer covers the new
  Dynamic branch.
* getVisitedAttrName() is declared in the public GPU ParallelLoopMapper.h
  header even though its own doc comment calls the attribute internal.
* The comment on TestCreateUnregisteredOp says it creates an "unregistered
  operation", yet TestOps.td in this same patch defines UnregisteredOp as
  TEST_Op<"unregistered_op">: it is registered in the test dialect and merely
  not declared legal on the conversion target.
* The new matchAndRewrite body is followed by a stray `;` (`};`), which
  -Wextra-semi flags.

diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h --- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h +++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h @@ -16,6 +16,7 @@ struct LogicalResult; class MLIRContext; class Value; +class Operation; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; @@ -49,6 +50,9 @@ /// are not rewritten by the provided patterns are legal. void configureParallelLoopToGPULegality(ConversionTarget &target); +/// Clean up after applyPartialConversion/applyFullConversion call. +void finalizeParallelLoopToGPUConversion(Operation *op); + } // namespace mlir #endif // MLIR_CONVERSION_SCFTOGPU_SCFTOGPU_H_ diff --git a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h --- a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h +++ b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapper.h @@ -41,6 +41,9 @@ /// Name of the mapping attribute produced by loop mappers. StringRef getMappingAttrName(); +/// Name of internal attribute to mark visited operations during conversion. +StringRef getVisitedAttrName(); + /// Get the value of the processor in the ParallelLoopDimMapping attribute. inline Processor getProcessor(ParallelLoopDimMapping attr) { return static_cast(attr.processor().getInt()); diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -567,6 +567,9 @@ LogicalResult ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { + // Mark the operation as visited for recursive legality check. + parallelOp->setAttr(gpu::getVisitedAttrName(), UnitAttr::get(getContext())); + // We can only transform starting at the outer-most loop. Launches inside of // parallel loops are not supported. 
if (auto parentLoop = parallelOp->getParentOfType()) @@ -649,6 +652,13 @@ void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) { target.addLegalDialect(); target.addDynamicallyLegalOp([](scf::ParallelOp parallelOp) { - return !parallelOp->getAttr(gpu::getMappingAttrName()); + return !parallelOp->getAttr(gpu::getMappingAttrName()) || + parallelOp->getAttr(gpu::getVisitedAttrName()); + }); +} + +void mlir::finalizeParallelLoopToGPUConversion(Operation *op) { + op->walk([](scf::ParallelOp parallelOp) { + parallelOp->removeAttr(gpu::getVisitedAttrName()); }); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -55,6 +55,7 @@ if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); + finalizeParallelLoopToGPUConversion(getOperation()); } }; diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp --- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp @@ -29,6 +29,7 @@ namespace gpu { StringRef getMappingAttrName() { return "mapping"; } +StringRef getVisitedAttrName() { return "GPU._visited_"; } ParallelLoopDimMapping getParallelLoopDimMappingAttr(Processor processor, AffineMap map, diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -175,6 +175,7 @@ target.addLegalDialect(); target.addDynamicallyLegalDialect( [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addLegalOp(); target.addLegalDialect(); if (failed( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- 
a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1544,7 +1544,7 @@ const FrozenRewritePatternSet &patterns); /// Returns true if the given operation is known to be illegal on the target. - bool isIllegal(Operation *op) const; + bool isKnownIllegal(Operation *op) const; /// Attempt to legalize the given operation. Returns success if the operation /// was legalized, failure otherwise. @@ -1648,9 +1648,15 @@ computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); } -bool OperationLegalizer::isIllegal(Operation *op) const { +bool OperationLegalizer::isKnownIllegal(Operation *op) const { // Check if the target explicitly marked this operation as illegal. - return target.getOpAction(op->getName()) == LegalizationAction::Illegal; + if (auto info = target.getOpAction(op->getName())) { + if (*info == LegalizationAction::Dynamic) + return !target.isLegal(op); + return *info == LegalizationAction::Illegal; + } + + return false; } LogicalResult @@ -2222,7 +2228,7 @@ // explicitly marked as illegal. If the user provided a nonlegalizableOps // set, non-legalizable ops are included. if (mode == OpConversionMode::Partial) { - if (opLegalizer.isIllegal(op)) + if (opLegalizer.isKnownIllegal(op)) return op->emitError() << "failed to legalize operation '" << op->getName() << "' that was explicitly marked illegal"; diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -47,6 +47,9 @@ // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that region cloning can be properly undone. 
func @test_undo_region_clone() { "test.region"() ({ @@ -59,8 +62,13 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that unknown operations can be dynamically legal. func @test_unknown_dynamically_legal() { "foo.unknown_op"() {test.dynamically_legal} : () -> () @@ -70,8 +78,13 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that region inlining can be properly undone. func @test_undo_region_inline() { "test.region"() ({ @@ -85,8 +98,13 @@ "test.return"() : () -> () } +} + // ----- +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + // Test that multiple block erases can be properly undone. func @test_undo_block_erase() { // expected-error@+1 {{failed to legalize operation 'test.region'}} @@ -99,3 +117,18 @@ "test.return"() : () -> () } + +} + +// ----- + +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + +func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () +} + +} diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -173,28 +173,40 @@ // ----- +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + func @fail_to_convert_illegal_op() -> i32 { // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} %result = "test.illegal_op_f"() : () -> (i32) return %result : i32 } +} + // ----- +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + func @fail_to_convert_illegal_op_in_region() { // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} "test.region_builder"() : () -> () return } +} + // ----- // 
Check that the entry block arguments of a region are untouched in the case // of failure. -// CHECK-LABEL: func @fail_to_convert_region +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + func @fail_to_convert_region() { - // CHECK-NEXT: "test.region" + // CHECK: "test.region" // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64): "test.region"() ({ ^bb1(%i0: i64): @@ -205,6 +217,8 @@ return } +} + // ----- // CHECK-LABEL: @create_illegal_block @@ -271,10 +285,8 @@ return } - // ----- - // Check that a conversion pattern on `test.blackhole` can mark the producer // for deletion. // CHECK-LABEL: @blackhole @@ -284,3 +296,16 @@ // expected-remark@+1 {{op 'std.return' is not legalizable}} return } + +// ----- + +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + +func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () +} + +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1317,9 +1317,13 @@ def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>; def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>; def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; +def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>; def LegalOpA : TEST_Op<"legal_op_a">, Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; +def LegalOpC : TEST_Op<"legal_op_c">, + Arguments<(ins I32)>, Results<(outs I32)>; +def UnregisteredOp : TEST_Op<"unregistered_op">, Results<(outs I32)>; // Check that the conversion infrastructure can properly undo the creation of // operations where an operation was created before its parent, in this case, diff --git 
a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -537,6 +537,19 @@ return success(); }; }; + +// This pattern replaces explicitly illegal op with explicitly legal op, +// but in addition creates unregistered operation. +struct TestCreateUnregisteredOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ILLegalOpG op, + PatternRewriter &rewriter) const final { + Value val = rewriter.create(op->getLoc()); + rewriter.replaceOpWithNewOp(op, val); + return success(); + }; +}; } // namespace namespace { @@ -618,8 +631,8 @@ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, - TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( - &getContext()); + TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, + TestCreateUnregisteredOp>(&getContext()); patterns.add(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); @@ -627,7 +640,7 @@ // Define the conversion target used for the test. ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalOp(); target .addIllegalOp(); @@ -640,6 +653,7 @@ return converter.isSignatureLegal(op.getType()) && converter.isLegal(&op.getBody()); }); + target.addDynamicallyLegalOp([](ILLegalOpG) { return false; }); // Expect the type_producer/type_consumer operations to only operate on f64. target.addDynamicallyLegalOp( @@ -661,8 +675,10 @@ // Handle a partial conversion. 
if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, std::move(patterns), - &unlegalizedOps); + if (failed(applyPartialConversion( + getOperation(), target, std::move(patterns), &unlegalizedOps))) { + getOperation()->emitRemark() << "applyPartialConversion failed"; + } // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; @@ -676,7 +692,10 @@ return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, std::move(patterns)); + if (failed(applyFullConversion(getOperation(), target, + std::move(patterns)))) { + getOperation()->emitRemark() << "applyFullConversion failed"; + } return; }