diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -17,7 +17,10 @@
 
 namespace mlir {
 
+class MLIRContext;
+class OwningRewritePatternList;
 class Region;
+class TypeConverter;
 
 namespace scf {
 
@@ -42,6 +45,13 @@
 /// The old loop is replaced with the new one.
 void tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
 
+/// Populates `patterns` with conversion patterns that rewrite, in place, the
+/// result types of scf.if and scf.for (and the scf.for region iter_args) to
+/// the types produced by `typeConverter`.
+void populateSCFStructuralTypeConversions(MLIRContext *context,
+                                          TypeConverter &typeConverter,
+                                          OwningRewritePatternList &patterns);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   LoopSpecialization.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
+  StructuralTypeConversions.cpp
   Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,80 @@
+//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+/// Converts the result types of an scf.for, together with the types of its
+/// region iter_args, through the pattern's type converter. The op is updated
+/// in place and `operands` is not consulted, so its operands are untouched.
+class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ForOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 6> newResultTypes;
+    for (auto type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
+    rewriter.updateRootInPlace(op, [&] {
+      for (auto t : llvm::zip(op.getResults(), newResultTypes))
+        std::get<0>(t).setType(std::get<1>(t));
+      // Skip the first body argument (the loop induction variable); the
+      // remaining arguments are retyped in step with the results.
+      auto bodyArgs = op.getBody()->getArguments();
+      for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes))
+        std::get<0>(t).setType(std::get<1>(t));
+    });
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Converts the result types of an scf.if through the pattern's type
+/// converter, updating the op in place. `operands` is not consulted.
+class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 6> newResultTypes;
+    for (auto type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
+    rewriter.updateRootInPlace(op, [&] {
+      for (auto t : llvm::zip(op.getResults(), newResultTypes))
+        std::get<0>(t).setType(std::get<1>(t));
+    });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::scf::populateSCFStructuralTypeConversions(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<ConvertForOpTypes>(typeConverter, context);
+  patterns.insert<ConvertIfOpTypes>(typeConverter, context);
+}
diff --git a/mlir/test/Transforms/bufferize-structural-ops.mlir b/mlir/test/Transforms/bufferize-structural-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/bufferize-structural-ops.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -test-bufferize-structural-ops -split-input-file %s | FileCheck %s
+
+// Basic cases.
+
+// CHECK-LABEL: func @identity(%arg0: memref<f32>) -> memref<f32> {
+// CHECK-NEXT: return %arg0 : memref<f32>
+func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @bb_arg(%arg0: memref<f32>) -> memref<f32> {
+// CHECK-NEXT: br ^bb1(%arg0 : memref<f32>)
+// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<f32>):
+// CHECK-NEXT: return %[[BBARG]] : memref<f32>
+func @bb_arg(%arg0: tensor<f32>) -> tensor<f32> {
+  br ^bb1(%arg0: tensor<f32>)
+^bb1(%bbarg: tensor<f32>):
+  return %bbarg : tensor<f32>
+}
+
+// CHECK-LABEL: func @if(%arg0: i1, %arg1: memref<f32>, %arg2: memref<f32>) -> memref<f32> {
+// CHECK-NEXT: %[[RET:.*]] = scf.if %arg0 -> (memref<f32>) {
+// CHECK-NEXT: scf.yield %arg1 : memref<f32>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: scf.yield %arg2 : memref<f32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RET]] : memref<f32>
+func @if(%pred: i1, %true_val: tensor<f32>, %false_val: tensor<f32>) -> tensor<f32> {
+  %0 = scf.if %pred -> (tensor<f32>) {
+    scf.yield %true_val : tensor<f32>
+  } else {
+    scf.yield %false_val : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @for(%arg0: memref<f32>, %arg1: index, %arg2: index, %arg3: index) -> memref<f32> {
+// CHECK-NEXT: %[[RET:.*]] = scf.for %arg4 = %arg1 to %arg2 step %arg3 iter_args(%arg5 = %arg0) -> (memref<f32>) {
+// CHECK-NEXT: scf.yield %arg5 : memref<f32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RET]] : memref<f32>
+// CHECK-NEXT: }
+func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
+  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
+    scf.yield %iter : tensor<f32>
+  }
+  return %ret : tensor<f32>
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   TestAllReduceLowering.cpp
   TestAffineLoopParametricTiling.cpp
   TestBufferPlacement.cpp
+  TestBufferizeStructuralOps.cpp
   TestExpandTanh.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
diff --git a/mlir/test/lib/Transforms/TestBufferizeStructuralOps.cpp b/mlir/test/lib/Transforms/TestBufferizeStructuralOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestBufferizeStructuralOps.cpp
@@ -0,0 +1,89 @@
+//===- TestBufferizeStructuralOps.cpp ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A "structural" type conversion is one where the underlying ops are completely
+// agnostic to the actual types involved, and simply need to update their types
+// accordingly.
+//
+// An example of this is scf.if -- the values yielded by the regions merely
+// need to match the result types. The scf.if op itself just needs to update
+// its types accordingly with how the body is modified.
+//
+// Due to a limitation of the dialect conversion infrastructure, all structural
+// conversions need to be done as a single mega-conversion-pass that does a full
+// conversion of types. The reason is that conversion patterns apply to the
+// results of the region-bearing op itself -- but the operands of yield-like
+// terminators inside the regions are not updated, causing verifier errors.
+// There doesn't seem to be a way to trigger the dialect conversion framework to
+// correctly update them.
+//
+// TODO: Should every structural type conversion just open-code updating its
+// yield-like terminators by calling
+// `getTypeConverter()->materializeTargetConversion(...)`?
+//
+// That is why we have a single standalone test pass instead of composable
+// individual (non-test) passes like we do for non-structural bufferization
+// patterns (which are composable because source/target materializations are
+// automatically inserted by the dialect conversion framework).
+//
+// Pragmatically, we use bufferization type conversion since it has a
+// well-established type conversion infrastructure in-tree, and this pass is
+// useful documentation for how downstream users can utilize this functionality,
+// given the limitations described above.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Bufferize.h"
+
+using namespace mlir;
+
+namespace {
+struct TestBufferizeStructuralOpsPass
+    : mlir::PassWrapper<TestBufferizeStructuralOpsPass,
+                        OperationPass<ModuleOp>> {
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+
+    OwningRewritePatternList patterns;
+
+    ConversionTarget target(*context);
+
+    // All ops whose results are not tensor types are legal.
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return llvm::all_of(op->getResultTypes(),
+                          [](Type type) { return !type.isa<TensorType>(); });
+    });
+
+    populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+    target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+
+    scf::populateSCFStructuralTypeConversions(context, typeConverter, patterns);
+
+    if (failed(applyFullConversion(func, target, patterns)))
+      return signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestBufferizeStructuralOpsPass() {
+  PassRegistration<TestBufferizeStructuralOpsPass>(
+      "test-bufferize-structural-ops",
+      "Tests the bufferization of 'structural ops' that cannot be composed "
+      "into individual passes");
+}
+} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -45,6 +45,7 @@
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
 void registerTestBufferPlacementPreparationPass();
+void registerTestBufferizeStructuralOpsPass();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
 void registerTestConvVectorization();
@@ -111,6 +112,7 @@
 #endif
   registerTestAffineLoopParametricTilingPass();
   registerTestBufferPlacementPreparationPass();
+  registerTestBufferizeStructuralOpsPass();
   registerTestDominancePass();
   registerTestDynamicPipelinePass();
   registerTestFunc();