diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -120,6 +120,12 @@
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target);
 
+/// Populates the provided pattern set with patterns that do 1:N type
+/// conversions on scf.if and scf.while ops (including their terminators).
+/// This is intended to be used with applyPartialOneToNConversion.
+void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
+                                                RewritePatternSet &patterns);
+
 /// Options to dictate how loops should be pipelined.
 struct PipeliningOption {
   /// Lambda returning all the operation in the forOp, with their stage, in the
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
+  OneToNTypeConversion.cpp
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
@@ -0,0 +1,182 @@
+//===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The patterns in this file are heavily inspired by (and copied from)
+// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
+// type conversions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+
+/// Converts the result types of an scf.if op: creates a new scf.if with the
+/// converted result types and moves both regions of the original op into it.
+class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(IfOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping & /*operandMapping*/,
+                  const OneToNTypeMapping &resultMapping,
+                  const ValueRange /*convertedOperands*/) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Create new IfOp.
+    TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
+    auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
+                                       op.getCondition(), true);
+    newOp->setAttrs(op->getAttrs());
+
+    // We do not need the empty blocks created by rewriter.
+    rewriter.eraseBlock(newOp.elseBlock());
+    rewriter.eraseBlock(newOp.thenBlock());
+
+    // Inline the blocks of the original operation. Their scf.yield
+    // terminators are converted separately by ConvertTypesInSCFYieldOp.
+    rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
+                                newOp.getThenRegion().end());
+    rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
+                                newOp.getElseRegion().end());
+
+    rewriter.replaceOp(op, SmallVector<Value>(newOp->getResults()),
+                       resultMapping);
+    return success();
+  }
+};
+
+/// Converts the operand and result types of an scf.while op: creates a new
+/// scf.while with the converted types, converts the signatures of its "before"
+/// and "after" blocks accordingly, and moves both regions into the new op.
+class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const ValueRange convertedOperands) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!operandMapping.hasNonIdentityConversion() &&
+        !resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Create new WhileOp.
+    TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
+
+    auto newOp =
+        rewriter.create<WhileOp>(loc, convertedResultTypes, convertedOperands);
+    newOp->setAttrs(op->getAttrs());
+
+    // Update block signatures: the "before" block (region 0) is converted
+    // like the op's operands, the "after" block (region 1) like its results.
+    std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping,
+                                                      resultMapping};
+    for (unsigned int i : {0u, 1u}) {
+      Region &region = op.getRegion(i);
+      Block *block = &region.front();
+
+      rewriter.applySignatureConversion(block, blockMappings[i]);
+
+      // Move updated region to new WhileOp.
+      Region &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(region, dstRegion, dstRegion.end());
+    }
+
+    rewriter.replaceOp(op, SmallVector<Value>(newOp->getResults()),
+                       resultMapping);
+    return success();
+  }
+};
+
+/// Converts the operand types of an scf.yield op that terminates a region of
+/// an scf.if or scf.while op, i.e., of an op handled by the patterns above.
+class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping & /*resultMapping*/,
+                  const ValueRange convertedOperands) const override {
+    // Only convert yields whose parent op is handled by the patterns in this
+    // file. Changing the operands of a yield in any other op (e.g., scf.for)
+    // would make them disagree with the types of that op.
+    if (!isa<IfOp, WhileOp>(op->getParentOp()))
+      return failure();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return success();
+  }
+};
+
+/// Converts the operand types of the scf.condition op that terminates the
+/// "before" region of an scf.while op.
+class ConvertTypesInSCFConditionOp
+    : public OneToNOpConversionPattern<ConditionOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping & /*resultMapping*/,
+                  const ValueRange convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return success();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace scf {
+
+void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter,
+                                                RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertTypesInSCFConditionOp,
+      ConvertTypesInSCFIfOp,
+      ConvertTypesInSCFWhileOp,
+      ConvertTypesInSCFYieldOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}
+
+} // namespace scf
+} // namespace mlir
diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \
+// RUN: | FileCheck %s
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield.
+
+// CHECK-LABEL: func.func @if_result(
+// CHECK-SAME:                       %[[ARG0:.*]]: i1,
+// CHECK-SAME:                       %[[ARG1:.*]]: i2,
+// CHECK-SAME:                       %[[ARG2:.*]]: i1) -> (i1, i2) {
+// CHECK-NEXT:    %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) {
+// CHECK-NEXT:      scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+  %0 = scf.if %arg1 -> (tuple<tuple<>, i1, tuple<i2>>) {
+    scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+  } else {
+    scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+  }
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield and unconverted ops inside have proper materializations.
+
+// CHECK-LABEL: func.func @if_tuple_ops(
+// CHECK-SAME:                          %[[ARG0:.*]]: i1,
+// CHECK-SAME:                          %[[ARG1:.*]]: i1) -> i1 {
+// CHECK-NEXT:    %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:    %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:    %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) {
+// CHECK-NEXT:      %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V5]] : i1
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V8]] : i1
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V2]] : i1
+func.func @if_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+  %0 = scf.if %arg1 -> (tuple<tuple<>, i1>) {
+    %1 = "test.op"(%arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  } else {
+    %1 = "test.source"() : () -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+  return %0 : tuple<tuple<>, i1>
+}
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and scf.yield.
+
+// CHECK-LABEL: func.func @while_operands_results(
+// CHECK-SAME:                                    %[[ARG0:.*]]: i1,
+// CHECK-SAME:                                    %[[ARG1:.*]]: i2,
+// CHECK-SAME:                                    %[[ARG2:.*]]: i1) -> (i1, i2) {
+// CHECK-NEXT:    %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) {
+// CHECK-NEXT:      scf.condition(%[[ARG2]]) %[[ARG3]], %[[ARG4]] : i1, i2
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2):
+// CHECK-NEXT:      scf.yield %[[ARG5]], %[[ARG6]] : i1, i2
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+  %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
+    scf.condition(%arg1) %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  } do {
+  ^bb0(%arg2: tuple<tuple<>, i1, tuple<i2>>):
+    scf.yield %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  }
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and unconverted ops inside have proper materializations.
+
+// CHECK-LABEL: func.func @while_tuple_ops(
+// CHECK-SAME:                             %[[ARG0:.*]]: i1,
+// CHECK-SAME:                             %[[ARG1:.*]]: i1) -> i1 {
+// CHECK-NEXT:    %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 {
+// CHECK-NEXT:      %[[V1:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:      %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.condition(%[[ARG1]]) %[[V5]] : i1
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG3:.*]]: i1):
+// CHECK-NEXT:      %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V8]] : i1
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]] : i1
+func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+  %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
+    %1 = "test.op"(%arg2) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.condition(%arg1) %1 : tuple<tuple<>, i1>
+  } do {
+  ^bb0(%arg2: tuple<tuple<>, i1>):
+    %1 = "test.source"() : () -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+  return %0 : tuple<tuple<>, i1>
+}
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
@@ -7,6 +7,9 @@
   MLIRFuncDialect
   MLIRFuncTransforms
   MLIRIR
+  MLIRPass
+  MLIRSCFDialect
+  MLIRSCFTransforms
   MLIRTestDialect
   MLIRTransformUtils
   )
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -8,6 +8,7 @@
 
 #include "TestDialect.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 
@@ -43,6 +44,10 @@
                               llvm::cl::desc("Enable conversion on func ops"),
                               llvm::cl::init(false)};
 
+  Option<bool> convertSCFOps{*this, "convert-scf-ops",
+                             llvm::cl::desc("Enable conversion on scf ops"),
+                             llvm::cl::init(false)};
+
   Option<bool> convertTupleOps{*this, "convert-tuple-ops",
                                llvm::cl::desc("Enable conversion on tuple ops"),
                                llvm::cl::init(false)};
@@ -237,6 +242,8 @@
     populateDecomposeTuplesTestPatterns(typeConverter, patterns);
   if (convertFuncOps)
     populateFuncTypeConversionPatterns(typeConverter, patterns);
+  if (convertSCFOps)
+    scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns);
 
   // Run conversion.
   if (failed(applyPartialOneToNConversion(module, typeConverter,