diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h --- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h +++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h @@ -16,6 +16,7 @@ struct LogicalResult; class MLIRContext; class Value; +class Operation; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; @@ -49,6 +50,9 @@ /// are not rewritten by the provided patterns are legal. void configureParallelLoopToGPULegality(ConversionTarget &target); +/// Clean up after applyPartialConversion/applyFullConversion call. +void finalizeParallelLoopToGPUConversion(Operation *op); + } // namespace mlir #endif // MLIR_CONVERSION_SCFTOGPU_SCFTOGPU_H_ diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -37,6 +37,24 @@ using namespace mlir; using namespace mlir::scf; +// Name of internal attribute to mark visited operations during conversion. +// +// NOTE: The conversion originally used the following legality criteria: +// `!parallelOp->hasAttr(gpu::getMappingAttrName())` +// But the provided pattern might reject some cases based on more detailed +// analysis of the `mapping` attribute. +// To avoid dialect conversion failure due to non-converted illegal operation +// we use this extra Unit attribute as a marker, that the operation was checked +// by the pattern and should be considered as legal in the following legality +// checks. The `finalizeParallelLoopToGPUConversion` function performs clean up +// of these extra attributes and is supposed to be called after the dialect +// conversion. +// +// TODO: Implement a cleaner solution, factoring out the "matching" logic +// from the pattern and its callees into a separate function that can be called +// from both the pattern and the op legality check. 
+static constexpr StringLiteral kVisitedAttrName = "SCFToGPU_visited"; + // Extract an indexed value from KernelDim3. static Value getDim3Value(const gpu::KernelDim3 &dim3, unsigned pos) { switch (pos) { @@ -567,6 +585,9 @@ LogicalResult ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { + // Mark the operation as visited for recursive legality check. + parallelOp->setAttr(kVisitedAttrName, rewriter.getUnitAttr()); + // We can only transform starting at the outer-most loop. Launches inside of // parallel loops are not supported. if (auto parentLoop = parallelOp->getParentOfType()) @@ -649,6 +670,13 @@ void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) { target.addLegalDialect(); target.addDynamicallyLegalOp([](scf::ParallelOp parallelOp) { - return !parallelOp->getAttr(gpu::getMappingAttrName()); + return !parallelOp->hasAttr(gpu::getMappingAttrName()) || + parallelOp->hasAttr(kVisitedAttrName); + }); +} + +void mlir::finalizeParallelLoopToGPUConversion(Operation *op) { + op->walk([](scf::ParallelOp parallelOp) { + parallelOp->removeAttr(kVisitedAttrName); }); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -55,6 +55,7 @@ if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); + finalizeParallelLoopToGPUConversion(getOperation()); } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -175,6 +175,7 @@ target.addLegalDialect(); target.addDynamicallyLegalDialect( [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addLegalOp(); target.addLegalDialect(); if (failed( diff --git 
a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1650,7 +1650,13 @@ bool OperationLegalizer::isIllegal(Operation *op) const { // Check if the target explicitly marked this operation as illegal. - return target.getOpAction(op->getName()) == LegalizationAction::Illegal; + if (auto info = target.getOpAction(op->getName())) { + if (*info == LegalizationAction::Dynamic) + return !target.isLegal(op); + return *info == LegalizationAction::Illegal; + } + + return false; } LogicalResult diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -47,55 +47,88 @@ // ----- -// Test that region cloning can be properly undone. -func @test_undo_region_clone() { - "test.region"() ({ - ^bb1(%i0: i64): - "test.invalid"(%i0) : (i64) -> () - }) {legalizer.should_clone} : () -> () - - // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} - %ignored = "test.illegal_op_f"() : () -> (i32) - "test.return"() : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that region cloning can be properly undone. + func @test_undo_region_clone() { + "test.region"() ({ + ^bb1(%i0: i64): + "test.invalid"(%i0) : (i64) -> () + }) {legalizer.should_clone} : () -> () + + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} + %ignored = "test.illegal_op_f"() : () -> (i32) + "test.return"() : () -> () + } + } // ----- -// Test that unknown operations can be dynamically legal. -func @test_unknown_dynamically_legal() { - "foo.unknown_op"() {test.dynamically_legal} : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that unknown operations can be dynamically legal. 
+ func @test_unknown_dynamically_legal() { + "foo.unknown_op"() {test.dynamically_legal} : () -> () + + // expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}} + "foo.unknown_op"() {} : () -> () + "test.return"() : () -> () + } - // expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}} - "foo.unknown_op"() {} : () -> () - "test.return"() : () -> () } // ----- -// Test that region inlining can be properly undone. -func @test_undo_region_inline() { - "test.region"() ({ - ^bb1(%i0: i64): - // expected-error@+1 {{failed to legalize operation 'std.br'}} - br ^bb2(%i0 : i64) - ^bb2(%i1: i64): - "test.invalid"(%i1) : (i64) -> () - }) {} : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that region inlining can be properly undone. + func @test_undo_region_inline() { + "test.region"() ({ + ^bb1(%i0: i64): + // expected-error@+1 {{failed to legalize operation 'std.br'}} + br ^bb2(%i0 : i64) + ^bb2(%i1: i64): + "test.invalid"(%i1) : (i64) -> () + }) {} : () -> () + + "test.return"() : () -> () + } - "test.return"() : () -> () } // ----- -// Test that multiple block erases can be properly undone. -func @test_undo_block_erase() { - // expected-error@+1 {{failed to legalize operation 'test.region'}} - "test.region"() ({ - ^bb1(%i0: i64): - br ^bb2(%i0 : i64) - ^bb2(%i1: i64): - "test.invalid"(%i1) : (i64) -> () - }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> () +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + // Test that multiple block erases can be properly undone. 
+ func @test_undo_block_erase() { + // expected-error@+1 {{failed to legalize operation 'test.region'}} + "test.region"() ({ + ^bb1(%i0: i64): + br ^bb2(%i0 : i64) + ^bb2(%i1: i64): + "test.invalid"(%i1) : (i64) -> () + }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> () + + "test.return"() : () -> () + } + +} + +// ----- + +// expected-remark@+1 {{applyFullConversion failed}} +builtin.module { + + func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () + } - "test.return"() : () -> () } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -173,18 +173,28 @@ // ----- -func @fail_to_convert_illegal_op() -> i32 { - // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} - %result = "test.illegal_op_f"() : () -> (i32) - return %result : i32 +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @fail_to_convert_illegal_op() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_f'}} + %result = "test.illegal_op_f"() : () -> (i32) + return %result : i32 + } + } // ----- -func @fail_to_convert_illegal_op_in_region() { - // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} - "test.region_builder"() : () -> () - return +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @fail_to_convert_illegal_op_in_region() { + // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} + "test.region_builder"() : () -> () + return + } + } // ----- @@ -192,17 +202,21 @@ // Check that the entry block arguments of a region are untouched in the case // of failure. 
-// CHECK-LABEL: func @fail_to_convert_region -func @fail_to_convert_region() { - // CHECK-NEXT: "test.region" - // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64): - "test.region"() ({ - ^bb1(%i0: i64): - // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} - "test.region_builder"() : () -> () - "test.valid"() : () -> () - }) : () -> () - return +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @fail_to_convert_region() { + // CHECK: "test.region" + // CHECK-NEXT: ^bb{{.*}}(%{{.*}}: i64): + "test.region"() ({ + ^bb1(%i0: i64): + // expected-error@+1 {{failed to legalize operation 'test.region_builder'}} + "test.region_builder"() : () -> () + "test.valid"() : () -> () + }) : () -> () + return + } + } // ----- @@ -271,10 +285,8 @@ return } - // ----- - // Check that a conversion pattern on `test.blackhole` can mark the producer // for deletion. // CHECK-LABEL: @blackhole @@ -284,3 +296,16 @@ // expected-remark@+1 {{op 'std.return' is not legalizable}} return } + +// ----- + +// expected-remark@+1 {{applyPartialConversion failed}} +builtin.module { + + func @create_unregistered_op_in_pattern() -> i32 { + // expected-error@+1 {{failed to legalize operation 'test.illegal_op_g'}} + %0 = "test.illegal_op_g"() : () -> (i32) + "test.return"(%0) : (i32) -> () + } + +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1415,9 +1415,12 @@ def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>; def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>; def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>; +def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>; def LegalOpA : TEST_Op<"legal_op_a">, Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>; def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>; +def LegalOpC : TEST_Op<"legal_op_c">, + Arguments<(ins 
I32)>, Results<(outs I32)>; // Check that the conversion infrastructure can properly undo the creation of // operations where an operation was created before its parent, in this case, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -562,6 +562,20 @@ return success(); }; }; + +// This pattern replaces explicitly illegal op with explicitly legal op, +// but in addition creates unregistered operation. +struct TestCreateUnregisteredOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ILLegalOpG op, + PatternRewriter &rewriter) const final { + IntegerAttr attr = rewriter.getI32IntegerAttr(0); + Value val = rewriter.create(op->getLoc(), attr); + rewriter.replaceOpWithNewOp(op, val); + return success(); + }; +}; } // namespace namespace { @@ -632,6 +646,10 @@ TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + void runOnOperation() override { TestTypeConverter converter; mlir::RewritePatternSet patterns(&getContext()); @@ -643,8 +661,8 @@ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, - TestNestedOpCreationUndoRewrite, TestReplaceEraseOp>( - &getContext()); + TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, + TestCreateUnregisteredOp>(&getContext()); patterns.add(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); @@ -652,7 +670,7 @@ // Define the conversion target used for the test. 
ConversionTarget target(getContext()); target.addLegalOp(); - target.addLegalOp(); target .addIllegalOp(); @@ -666,6 +684,11 @@ converter.isLegal(&op.getBody()); }); + // TestCreateUnregisteredOp creates `std.constant` operation, + // which was not added to target intentionally to test + // correct error code from conversion driver. + target.addDynamicallyLegalOp([](ILLegalOpG) { return false; }); + // Expect the type_producer/type_consumer operations to only operate on f64. target.addDynamicallyLegalOp( [](TestTypeProducerOp op) { return op.getType().isF64(); }); @@ -686,8 +709,10 @@ // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, std::move(patterns), - &unlegalizedOps); + if (failed(applyPartialConversion( + getOperation(), target, std::move(patterns), &unlegalizedOps))) { + getOperation()->emitRemark() << "applyPartialConversion failed"; + } // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; @@ -701,7 +726,10 @@ return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, std::move(patterns)); + if (failed(applyFullConversion(getOperation(), target, + std::move(patterns)))) { + getOperation()->emitRemark() << "applyFullConversion failed"; + } return; }