diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -138,6 +138,68 @@
   }
 };
 
+/// Rewrites an `scf.for` op whose operands and/or results are subject to a
+/// non-identity 1:N type conversion. The lower bound, upper bound, and step
+/// must each convert to exactly one value; otherwise the pattern bails out.
+class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> {
+public:
+  using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion() &&
+        !resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // If the lower-bound, upper-bound, or step were expanded, abort the
+    // conversion. This conversion does not know what to do in such cases.
+    ValueRange lbs = adaptor.getLowerBound();
+    ValueRange ubs = adaptor.getUpperBound();
+    ValueRange steps = adaptor.getStep();
+    if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1)
+      return rewriter.notifyMatchFailure(
+          forOp, "index operands converted to multiple values");
+
+    // Compute the converted signature of the body block *before* touching
+    // the IR: a pattern must not modify the IR if it returns failure.
+    Block *block = forOp.getBody();
+    OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+                                                   bodyTypeMapping)))
+      return failure();
+
+    // Construct the new for-op. The first three flat operands are the
+    // (single-valued, see above) lower bound, upper bound, and step; the
+    // remaining ones are the converted init values.
+    Location loc = forOp.getLoc();
+    ValueRange newInits = adaptor.getFlatOperands().drop_front(3);
+    auto newOp =
+        rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits);
+    newOp->setAttrs(forOp->getAttrs());
+
+    // We do not need the block created by the builder; the old body is moved
+    // over below.
+    rewriter.eraseBlock(newOp.getBody());
+
+    // Perform signature conversion on the old body block.
+    rewriter.applySignatureConversion(block, bodyTypeMapping);
+
+    // Splice the old body region into the new for-op.
+    Region &dstRegion = newOp.getBodyRegion();
+    rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end());
+
+    // Replace the old results according to the 1:N result mapping.
+    rewriter.replaceOp(forOp, newOp.getResults(), resultMapping);
+
+    return success();
+  }
+};
+
 namespace mlir {
 namespace scf {
 
@@ -146,6 +208,7 @@
   patterns.add<
       // clang-format off
       ConvertTypesInSCFConditionOp,
+      ConvertTypesInSCFForOp,
       ConvertTypesInSCFIfOp,
       ConvertTypesInSCFWhileOp,
       ConvertTypesInSCFYieldOp
diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
--- a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
@@ -116,3 +116,68 @@
   }
   return %0 : tuple<tuple<>, i1>
 }
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield.
+
+// CHECK-LABEL: func.func @for_operands_results(
+// CHECK-SAME:                                   %[[ARG0:.*]]: i1,
+// CHECK-SAME:                                   %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-NEXT:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT:    %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT:    %[[C10:.+]] = arith.constant 10 : index
+// CHECK-NEXT:    %[[OUT:.+]]:2 = scf.for %{{.+}} = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER0:.+]] = %[[ARG0]], %[[ITER1:.+]] = %[[ARG1]]) -> (i1, i2) {
+// CHECK-NEXT:      scf.yield %[[ITER0]], %[[ITER1]] : i1, i2
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[OUT]]#0, %[[OUT]]#1 : i1, i2
+
+func.func @for_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1, tuple<i2>> {
+    scf.yield %acc : tuple<tuple<>, i1, tuple<i2>>
+  }
+
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: 1:N conversion through scf.for, with materializations around unconverted ops inside and after the loop.
+
+// CHECK-LABEL: func.func @for_tuple_ops(
+// CHECK-SAME:                            %[[ARG0:.+]]: i1) -> i1 {
+// CHECK-NEXT:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT:    %[[C1:.+]] = arith.constant 1 : index
+// CHECK-NEXT:    %[[C10:.+]] = arith.constant 10 : index
+// CHECK-NEXT:    %[[FOR:.+]] = scf.for %{{.+}} = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER:.+]] = %[[ARG0]]) -> (i1) {
+// CHECK-NEXT:      %[[V1:.+]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:      %[[V2:.+]] = "test.make_tuple"(%[[V1]], %[[ITER]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V3:.+]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V4:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V5:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V5]] : i1
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[V6:.+]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:    %[[V7:.+]] = "test.make_tuple"(%[[V6]], %[[FOR]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:    %[[V8:.+]] = "test.op"(%[[V7]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:    %[[V9:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:    %[[V10:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:    return %[[V10]] : i1
+
+func.func @for_tuple_ops(%arg0: tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1> {
+    %1 = "test.op"(%acc) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+
+  %1 = "test.op"(%0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+  return %1 : tuple<tuple<>, i1>
+}