diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -16,6 +16,10 @@
 
 #include "mlir/Pass/Pass.h"
 
+namespace mlir {
+class BufferAssignmentTypeConverter;
+} // namespace mlir
+
 namespace mlir {
 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
 /// dialect to be convertible to Standard. For example, `shape.num_elements` get
@@ -36,6 +40,16 @@
                                             MLIRContext *ctx);
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
+// Collects a set of patterns to replace tensors as inputs and outputs to shape
+// operations with buffers. This only modifies the shape operations.
+void populateShapeTypeConversionPatterns(
+    MLIRContext *ctx, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns);
+
+// Creates a pass that replaces tensors involving shape operations with
+// memrefs, using the patterns from populateShapeTypeConversionPatterns.
+std::unique_ptr<FunctionPass> createShapeTensorToMemrefPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -21,4 +21,9 @@
   let constructor = "mlir::createShapeToShapeLowering()";
 }
 
+// TODO(tpopp): Generalize this to allow any type conversions desired.
+def ShapeTensorToMemref : FunctionPass<"shape-tensor-to-memref"> {
+  let summary = "Replace tensors involving shape operations with memrefs";
+  let constructor = "mlir::createShapeTensorToMemrefPass()";
+}
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
+  ShapeTypeConversion.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
@@ -0,0 +1,97 @@
+//===- ShapeTypeConversion.cpp - Shape Type Conversions -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines patterns to convert types of inputs and outputs to shape
+// operations to be memrefs instead of tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/BufferPlacement.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+/// Propagates tensor to memref conversions through shape.assuming ops.
+class TypeConversionAssumingOpConverter
+    : public BufferAssignmentOpConversionPattern<shape::AssumingOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      shape::AssumingOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Convert each result type with the pattern's type converter.
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(assumingOp.getNumResults());
+    for (Value result : assumingOp.getResults())
+      newResultTypes.push_back(converter->convertType(result.getType()));
+
+    // Build the replacement from the remapped witness in `operands` rather
+    // than the original op's operand, as conversion patterns should.
+    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+        assumingOp.getLoc(), newResultTypes, operands.front());
+
+    // Give the new op a placeholder block so it is structurally valid when
+    // replaceOp runs, then move the original body over, replacing it.
+    newAssumingOp.doRegion().push_back(new Block());
+    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
+    newAssumingOp.doRegion().takeBody(assumingOp.doRegion());
+
+    return success();
+  }
+};
+
+struct ShapeTensorToMemrefPass
+    : public ShapeTensorToMemrefBase<ShapeTensorToMemrefPass> {
+  void runOnFunction() override {
+    MLIRContext &ctx = getContext();
+
+    OwningRewritePatternList patterns;
+    BufferAssignmentTypeConverter converter;
+    populateShapeTypeConversionPatterns(&ctx, &converter, &patterns);
+
+    ConversionTarget target(ctx);
+    // shape.assuming is legal once the converter would leave every result
+    // type unchanged; non-tensor results such as `index` stay legal.
+    target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
+      return llvm::all_of(op.getOperation()->getResultTypes(),
+                          [&](Type type) { return converter.isLegal(type); });
+    });
+
+    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Populates `patterns` with the conversion patterns of tensor->memref.
+///
+/// TODO(tpopp): Change this to work generally with any type conversions.
+void mlir::populateShapeTypeConversionPatterns(
+    MLIRContext *context, BufferAssignmentTypeConverter *converter,
+    OwningRewritePatternList *patterns) {
+  patterns->insert<TypeConversionAssumingOpConverter>(context, converter);
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeTensorToMemrefPass construction
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<FunctionPass> mlir::createShapeTensorToMemrefPass() {
+  return std::make_unique<ShapeTensorToMemrefPass>();
+}
diff --git a/mlir/test/Dialect/Shape/shape-type-conversion.mlir b/mlir/test/Dialect/Shape/shape-type-conversion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/shape-type-conversion.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -shape-tensor-to-memref <%s | FileCheck %s
+
+// -----
+// Check that shape.assuming returns a memref.
+//
+// CHECK-LABEL: @shape_assuming_returns_memref
+func @shape_assuming_returns_memref() {
+  %0 = shape.const_witness true
+  // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
+  %1 = shape.assuming %0 -> (tensor<2xf16>) {
+    %2 = "test.source"() : () -> (tensor<2xf16>)
+    shape.assuming_yield %2 : tensor<2xf16>
+  }
+  "test.sink"(%1) : (tensor<2xf16>) -> ()
+  return
+}