diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -660,20 +660,25 @@ /// ConversionPatternRewriter, to see what additional constraints are imposed on /// the use of the PatternRewriter. -/// Apply a partial conversion on the given operations, and all nested +/// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. This method only -/// returns failure if there are unreachable blocks in any of the regions nested -/// within 'ops'. If 'converter' is provided, the signatures of blocks and -/// regions are also converted. +/// returns failure if an op marked illegal fails to legalize. If 'converter' is +/// provided, the signatures of blocks and regions are also converted. +/// If an 'unconvertedOps' set is provided, all operations that are found not +/// to be legalizable to the given 'target' are placed within that set. (Note +/// that if an op explicitly marked as illegal fails to legalize, the conversion +/// terminates and the 'unconvertedOps' set will not necessarily be complete.) LLVM_NODISCARD LogicalResult applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr); + TypeConverter *converter = nullptr, + DenseSet<Operation *> *unconvertedOps = nullptr); LLVM_NODISCARD LogicalResult applyPartialConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr); + TypeConverter *converter = nullptr, + DenseSet<Operation *> *unconvertedOps = nullptr); /// Apply a complete conversion on the given operations, and all nested /// operations.
This method returns failure if the conversion of any operation diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1541,9 +1541,8 @@ explicit OperationConverter(ConversionTarget &target, const OwningRewritePatternList &patterns, OpConversionMode mode, - DenseSet<Operation *> *legalizableOps = nullptr) - : opLegalizer(target, patterns), mode(mode), - legalizableOps(legalizableOps) {} + DenseSet<Operation *> *trackedOps = nullptr) + : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef<Operation *> ops, @@ -1563,9 +1562,11 @@ /// The conversion mode to use when legalizing operations. OpConversionMode mode; - /// A set of pre-existing operations that were found to be legalizable to the - /// target. This field is only used when mode == OpConversionMode::Analysis. - DenseSet<Operation *> *legalizableOps; + /// A set of pre-existing operations. When mode == OpConversionMode::Analysis, + /// this is populated with ops found to be legalizable to the target. + /// When mode == OpConversionMode::Partial, this is populated with ops found + /// *not* to be legalizable to the target; it may be null in that mode. + DenseSet<Operation *> *trackedOps; }; } // end anonymous namespace @@ -1594,17 +1595,22 @@ return op->emitError() << "failed to legalize operation '" << op->getName() << "'"; /// Partial conversions allow conversions to fail iff the operation was not - /// explicitly marked as illegal. - if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op)) - return op->emitError() - << "failed to legalize operation '" << op->getName() - << "' that was explicitly marked illegal"; + /// explicitly marked as illegal. If the user provided a 'trackedOps' + /// set, non-legalizable ops are added to it.
+ if (mode == OpConversionMode::Partial) { + if (opLegalizer.isIllegal(op)) + return op->emitError() + << "failed to legalize operation '" << op->getName() + << "' that was explicitly marked illegal"; + if (trackedOps) + trackedOps->insert(op); + } } else { /// Analysis conversions don't fail if any operations fail to legalize, /// they are only interested in the operations that were successfully /// legalized. if (mode == OpConversionMode::Analysis) - legalizableOps->insert(op); + trackedOps->insert(op); // If legalization succeeded, convert the types any of the blocks within // this operation. @@ -1932,21 +1938,30 @@ // Op Conversion Entry Points //===----------------------------------------------------------------------===// -/// Apply a partial conversion on the given operations, and all nested +/// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as -/// possible, ignoring operations that failed to legalize. +/// possible, ignoring operations that failed to legalize. This method only +/// returns failure if an op marked illegal fails to legalize. If 'converter' is +/// provided, the signatures of blocks and regions are also converted. +/// If an 'unconvertedOps' set is provided, all operations that are found not +/// to be legalizable to the given 'target' are placed within that set. (Note +/// that if an op explicitly marked as illegal fails to legalize, the conversion +/// terminates and the 'unconvertedOps' set will not necessarily be complete.)
LogicalResult mlir::applyPartialConversion( ArrayRef<Operation *> ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, TypeConverter *converter) { - OperationConverter opConverter(target, patterns, OpConversionMode::Partial); + const OwningRewritePatternList &patterns, TypeConverter *converter, + DenseSet<Operation *> *unconvertedOps) { + OperationConverter opConverter(target, patterns, OpConversionMode::Partial, + unconvertedOps); return opConverter.convertOperations(ops, converter); } LogicalResult mlir::applyPartialConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter) { + TypeConverter *converter, + DenseSet<Operation *> *unconvertedOps) { return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, - converter); + converter, unconvertedOps); } /// Apply a complete conversion on the given operations, and all nested diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -4,6 +4,7 @@ func @verifyDirectPattern() -> i32 { // CHECK-NEXT: "test.legal_op_a"() {status = "Success"} %result = "test.illegal_op_a"() : () -> (i32) + // expected-remark@+1 {{op 'std.return' is not legalizable}} return %result : i32 } @@ -11,6 +12,7 @@ func @verifyLargerBenefit() -> i32 { // CHECK-NEXT: "test.legal_op_a"() {status = "Success"} %result = "test.illegal_op_c"() : () -> (i32) + // expected-remark@+1 {{op 'std.return' is not legalizable}} return %result : i32 } @@ -26,7 +28,9 @@ // CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64) func @remap_call_1_to_1(%arg0: i64) { // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> () + // expected-remark@+1 {{op 'std.call' is not legalizable}} call @remap_input_1_to_1(%arg0) : (i64) -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -40,6 +44,7 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) { //
CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32 // CHECK-NEXT: "work"([[CAST]]) : (f32) -> () + // expected-remark@+1 {{op 'work' is not legalizable}} "work"(%arg0) : (f32) -> () } @@ -47,6 +52,7 @@ func @remap_input_to_self(%arg0: index) { // CHECK-NOT: test.cast // CHECK: "work" + // expected-remark@+1 {{op 'work' is not legalizable}} "work"(%arg0) : (index) -> () } @@ -59,12 +65,14 @@ // CHECK-LABEL: func @no_remap_nested func @no_remap_nested() { // CHECK-NEXT: "foo.region" + // expected-remark@+1 {{op 'foo.region' is not legalizable}} "foo.region"() ({ // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64): ^bb0(%i0: i64, %unused: i16, %i1: i64): // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64) "test.invalid"(%i0, %i1) : (i64, i64) -> () }) : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -78,6 +86,7 @@ ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () }) : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -91,6 +100,7 @@ ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () }) {legalizer.should_clone} : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -102,6 +112,7 @@ ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32): "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> () }) : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -109,6 +120,7 @@ func @dropped_input_in_use(%arg: i16, %arg2: i64) { // CHECK-NEXT: "test.cast"{{.*}} : () -> i16 // CHECK-NEXT: "work"{{.*}} : (i16) + // expected-remark@+1 {{op 'work' is not legalizable}} "work"(%arg) : (i16) -> () } @@ -117,6 +129,7 @@ // CHECK-NEXT: return %repl_1 = "test.rewrite"(%arg) : (i8) -> i8 %repl_2 = "test.rewrite"(%repl_1) : (i8) -> i8 + // expected-remark@+1 {{op 'std.return' is not legalizable}} return %repl_2 : i8 } @@ -127,11 +140,13 @@ %0 =
"test.op_with_region_fold"(%arg0) ({ "foo.op_with_region_terminator"() : () -> () }) : (i32) -> (i32) + // expected-remark@+1 {{op 'std.return' is not legalizable}} return %0 : i32 } // CHECK-LABEL: @create_block func @create_block() { + // expected-remark@+1 {{op 'test.container' is not legalizable}} "test.container"() ({ // Check that we created a block with arguments. // CHECK-NOT: test.create_block @@ -140,6 +155,7 @@ "test.create_block"() : () -> () "test.finish"() : () -> () }) : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -147,6 +163,7 @@ func @bounded_recursion() { // CHECK: test.recursive_rewrite 0 test.recursive_rewrite 3 + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -188,13 +205,16 @@ // CHECK-LABEL: @create_illegal_block func @create_illegal_block() { + // expected-remark@+1 {{op 'test.container' is not legalizable}} "test.container"() ({ // Check that we can undo block creation, i.e. that the block was removed. // CHECK: test.create_illegal_block // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + // expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}} "test.create_illegal_block"() : () -> () "test.finish"() : () -> () }) : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -202,6 +222,7 @@ // CHECK-LABEL: @undo_block_arg_replace func @undo_block_arg_replace() { + // expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}} "test.undo_block_arg_replace"() ({ ^bb0(%arg0: i32): // CHECK: ^bb0(%[[ARG:.*]]: i32): @@ -209,5 +230,6 @@ "test.return"(%arg0) : (i32) -> () }) : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -515,8 +515,12 @@ // Handle a partial conversion.
if (mode == ConversionMode::Partial) { - (void)applyPartialConversion(getOperation(), target, patterns, - &converter); + DenseSet<Operation *> unlegalizedOps; + (void)applyPartialConversion(getOperation(), target, patterns, &converter, + &unlegalizedOps); + // Emit a remark for each operation that could not be legalized. + for (auto *op : unlegalizedOps) + op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; return; }