diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -17,6 +17,9 @@ namespace mlir { +/// Creates a pass that bufferizes the SCF dialect. +std::unique_ptr createSCFBufferizePass(); + /// Creates a pass that specializes for loop for unrolling and /// vectorization. std::unique_ptr createForLoopSpecializationPass(); diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -11,6 +11,11 @@ include "mlir/Pass/PassBase.td" +def SCFBufferize : FunctionPass<"scf-bufferize"> { + let summary = "Bufferize the scf dialect."; + let constructor = "mlir::createSCFBufferizePass()"; +} + def SCFForLoopSpecialization : FunctionPass<"for-loop-specialization"> { let summary = "Specialize `for` loops for vectorization"; diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -17,7 +17,11 @@ namespace mlir { +class ConversionTarget; +class MLIRContext; +class OwningRewritePatternList; class Region; +class TypeConverter; namespace scf { @@ -42,6 +46,19 @@ /// The old loop is replaced with the new one. void tileParallelLoop(ParallelOp op, llvm::ArrayRef tileSizes); +/// Populates patterns for SCF structural type conversions and sets up the +/// provided ConversionTarget with the appropriate legality configuration for +/// the ops to get converted properly. +/// +/// A "structural" type conversion is one where the underlying ops are +/// completely agnostic to the actual types involved and simply need to update +/// their types. 
An example of this is scf.if -- the scf.if op and the +/// corresponding scf.yield ops need to update their types according to the +/// TypeConverter, but otherwise don't care what type conversions are happening. +void populateSCFStructuralTypeConversionsAndLegality( + MLIRContext *context, TypeConverter &typeConverter, + OwningRewritePatternList &patterns, ConversionTarget &target); + } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h --- a/mlir/include/mlir/Transforms/Bufferize.h +++ b/mlir/include/mlir/Transforms/Bufferize.h @@ -143,6 +143,15 @@ SmallVector decomposeTypeConversions; }; +/// Marks ops used by bufferization for type conversion materializations as +/// "legal" in the given ConversionTarget. +/// +/// This function should be called by all bufferization passes using +/// BufferizeTypeConverter so that materializations work properly. One exception +/// is bufferization passes doing "full" conversions, where it can be desirable +/// for even the materializations to remain illegal so that they are eliminated. +void populateBufferizeMaterializationLegality(ConversionTarget &target); + /// Helper conversion pattern that encapsulates a BufferizeTypeConverter /// instance. template diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -0,0 +1,41 @@ +//===- Bufferize.cpp - scf bufferize pass ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Bufferize.h" +#include "PassDetail.h" +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::scf; + +namespace { +struct SCFBufferizePass : public SCFBufferizeBase { + void runOnFunction() override { + auto func = getOperation(); + auto *context = &getContext(); + + BufferizeTypeConverter typeConverter; + OwningRewritePatternList patterns; + ConversionTarget target(*context); + + populateBufferizeMaterializationLegality(target); + populateSCFStructuralTypeConversionsAndLegality(context, typeConverter, + patterns, target); + if (failed(applyPartialConversion(func, target, patterns))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr mlir::createSCFBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -1,7 +1,9 @@ add_mlir_dialect_library(MLIRSCFTransforms + Bufferize.cpp LoopSpecialization.cpp ParallelLoopFusion.cpp ParallelLoopTiling.cpp + StructuralTypeConversions.cpp Utils.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -0,0 +1,117 @@ +//===- StructuralTypeConversions.cpp - scf structural type conversions ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::scf; + +namespace { +class ConvertForOpTypes : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ForOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // Clone, then retype results and all body arguments after the first. + ForOp newOp = cast(rewriter.clone(*op.getOperation())); + newOp.getOperation()->setOperands(operands); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + auto bodyArgs = newOp.getBody()->getArguments(); + for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + + return success(); + } +}; +} // namespace + +namespace { +class ConvertIfOpTypes : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(IfOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. 
+ // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // TODO: Write this with updateRootInPlace once the conversion infra + // supports source materializations on ops updated in place. + IfOp newOp = cast(rewriter.clone(*op.getOperation())); + newOp.getOperation()->setOperands(operands); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; +} // namespace + +namespace { +// When the result types of a ForOp/IfOp get changed, the operand types of the +// corresponding yield op need to be changed. In order to trigger the +// appropriate type conversions / materializations, we need a dummy pattern. 
+class ConvertYieldOpTypes : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::YieldOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, operands); + return success(); + } +}; +} // namespace + +void mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + MLIRContext *context, TypeConverter &typeConverter, + OwningRewritePatternList &patterns, ConversionTarget &target) { + patterns.insert( + typeConverter, context); + target.addDynamicallyLegalOp([&](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()); + }); + target.addDynamicallyLegalOp([&](scf::YieldOp op) { + // We only have conversions for a subset of ops that use scf.yield + // terminators. + if (!isa(op.getParentOp())) + return true; + return typeConverter.isLegal(op.getOperandTypes()); + }); +} diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -72,6 +72,10 @@ return KeepAsFunctionResult; } +void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) { + target.addLegalOp(); +} + //===----------------------------------------------------------------------===// // BufferizeFuncOpConverter //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/bufferize.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s -scf-bufferize | FileCheck %s + +// CHECK-LABEL: func @if( +// CHECK-SAME: %[[PRED:.*]]: i1, +// CHECK-SAME: %[[TRUE_TENSOR:.*]]: tensor, +// CHECK-SAME: %[[FALSE_TENSOR:.*]]: tensor) -> tensor { +// CHECK: %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref) { +// CHECK: %[[TRUE_MEMREF:.*]] = tensor_to_memref %[[TRUE_TENSOR]] : memref +// CHECK: 
scf.yield %[[TRUE_MEMREF]] : memref +// CHECK: } else { +// CHECK: %[[FALSE_MEMREF:.*]] = tensor_to_memref %[[FALSE_TENSOR]] : memref +// CHECK: scf.yield %[[FALSE_MEMREF]] : memref +// CHECK: } +// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF]] : memref +// CHECK: return %[[RESULT_TENSOR]] : tensor +// CHECK: } +func @if(%pred: i1, %true_val: tensor, %false_val: tensor) -> tensor { + %0 = scf.if %pred -> (tensor) { + scf.yield %true_val : tensor + } else { + scf.yield %false_val : tensor + } + return %0 : tensor +} + +// CHECK-LABEL: func @for( +// CHECK-SAME: %[[TENSOR:.*]]: tensor, +// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, +// CHECK-SAME: %[[STEP:.*]]: index) -> tensor { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref +// CHECK: %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref) { +// CHECK: scf.yield %[[ITER]] : memref +// CHECK: } +// CHECK: %[[VAL_8:.*]] = tensor_load %[[RESULT_MEMREF]] : memref +// CHECK: return %[[VAL_8]] : tensor +// CHECK: } +func @for(%arg0: tensor, %lb: index, %ub: index, %step: index) -> tensor { + %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor { + scf.yield %iter : tensor + } + return %ret : tensor +}