diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -16,6 +16,10 @@
 
 #include "mlir/Pass/Pass.h"
 
+namespace mlir {
+class BufferAssignmentTypeConverter;
+}
+
 namespace mlir {
 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
 /// dialect to be convertible to Standard. For example, `shape.num_elements` get
@@ -36,6 +40,13 @@
                                           MLIRContext *ctx);
 std::unique_ptr<Pass> createRemoveShapeConstraintsPass();
 
+// Collects a set of patterns that replace the tensor inputs and outputs of
+// shape operations with buffers. This only modifies the shape operations.
+void populateShapeBufferizePatterns(MLIRContext *ctx,
+                                    BufferAssignmentTypeConverter *converter,
+                                    OwningRewritePatternList *patterns);
+std::unique_ptr<Pass> createShapeBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -21,4 +21,8 @@
   let constructor = "mlir::createShapeToShapeLowering()";
 }
 
+def ShapeBufferize : FunctionPass<"shape-bufferize"> {
+  let summary = "Replace tensors involving shape operations with memrefs";
+  let constructor = "mlir::createShapeBufferizePass()";
+}
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   RemoveShapeConstraints.cpp
+  ShapeBufferize.cpp
   ShapeToShapeLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeBufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeBufferize.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeBufferize.cpp
@@ -0,0 +1,98 @@
+//=====- ShapeBufferize.cpp - Shape Buffer Assignment ---------*- C++ -*-=====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines patterns that convert the tensor-typed inputs and
+// outputs of shape operations into memrefs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/BufferPlacement.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+// Propagate tensor to memref conversions through shape.assuming ops.
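+// The pattern below rebuilds each shape.assuming with result types chosen
+// by the BufferAssignmentTypeConverter (tensors become memrefs) and then
+// moves the original region into the new op, leaving its contents untouched.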
+class BufferizeAssumingOpConverter
+    : public BufferAssignmentOpConversionPattern<shape::AssumingOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      shape::AssumingOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(assumingOp.getNumResults());
+    for (auto result : assumingOp.getResults()) {
+      auto originalType = result.getType();
+      Type convertedType = converter->convertType(originalType);
+      newResultTypes.push_back(convertedType);
+    }
+
+    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+        assumingOp.getLoc(), newResultTypes, assumingOp.witness());
+
+    // Handle the region transfer carefully here to avoid tripping assertions
+    // that require both operations to be valid at replacement time.
+    newAssumingOp.doRegion().push_back(new Block());
+    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
+    newAssumingOp.doRegion().takeBody(assumingOp.doRegion());
+
+    return success();
+  }
+};
+
+struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void ShapeBufferizePass::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
+  OwningRewritePatternList patterns;
+  BufferAssignmentTypeConverter converter;
+  populateShapeBufferizePatterns(&ctx, &converter, &patterns);
+
+  ConversionTarget target(getContext());
+  auto isMemRefType = [](Type type) { return type.isa<MemRefType>(); };
+
+  target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
+    return std::all_of(op.result_type_begin(), op.result_type_end(),
+                       isMemRefType);
+  });
+
+  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+/// Populates `patterns` with the conversion patterns of tensor->memref.
+void mlir::populateShapeBufferizePatterns(
+    MLIRContext *context, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns) {
+  patterns->insert<BufferizeAssumingOpConverter>(context, converter);
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeBufferizePass construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> mlir::createShapeBufferizePass() {
+  return std::make_unique<ShapeBufferizePass>();
+}
diff --git a/mlir/test/Dialect/Shape/shape-bufferize.mlir b/mlir/test/Dialect/Shape/shape-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/shape-bufferize.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -shape-bufferize <%s | FileCheck %s --dump-input=fail
+
+// -----
+// Check that shape.assuming returns a memref.
+//
+// CHECK-LABEL: @f
+func @f() {
+  %0 = shape.const_witness true
+  // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
+  %1 = shape.assuming %0 -> (tensor<2xf16>) {
+    %2 = "test.source"() : () -> (tensor<2xf16>)
+    shape.assuming_yield %2 : tensor<2xf16>
+  }
+  "test.sink"(%1) : (tensor<2xf16>) -> ()
+  return
+}
+
+
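
A minimal usage sketch for context (not part of the patch): only `createShapeBufferizePass` and the `-shape-bufferize` flag come from this change. The helper name, the pass-manager setup, and the include paths (current as of this patch and liable to move) are illustrative assumptions.

```cpp
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: schedule the new FunctionPass on every function in
// the module by nesting it under FuncOp, mirroring `mlir-opt -shape-bufferize`.
void addShapeBufferize(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::createShapeBufferizePass());
}
```

The new test exercises the same pipeline from the command line. Note that the pass uses `applyPartialConversion` with only `shape.assuming` marked (dynamically) illegal, so surrounding ops such as `test.source` and `test.sink` are left untouched even though they still produce or consume tensors.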