diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -635,6 +635,7 @@
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
+  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
 
   let extraClassDeclaration = [{
     // Inline the region into the region containing the AssumingOp and delete
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -17,7 +17,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
-class BufferizeTypeConverter;
+class ConversionTarget;
+class TypeConverter;
 } // namespace mlir
 
 namespace mlir {
@@ -40,9 +41,21 @@
                                             MLIRContext *ctx);
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
-void populateShapeTypeConversionPatterns(MLIRContext *ctx,
-                                         BufferizeTypeConverter &converter,
-                                         OwningRewritePatternList &patterns);
+/// Populates patterns for shape dialect structural type conversions and sets up
+/// the provided ConversionTarget with the appropriate legality configuration
+/// for the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types consistently. An example of this is shape.assuming -- the
+/// shape.assuming op and the corresponding shape.assuming_yield op need to have
+/// consistent types, but the exact types don't matter. So all that we need to
+/// do for a structural type conversion is to update both of their types
+/// consistently to the new types prescribed by the TypeConverter.
+void populateShapeStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target);
+
 // Bufferizes shape dialect ops.
 //
 // Note that most shape dialect ops must be converted to std before
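
To make the doc comment above concrete: with a tensor-to-memref converter like
the one used by the bufferize pass below, the structural conversion only
rewrites the types that tie shape.assuming to its shape.assuming_yield. A
before/after sketch, mirroring the test updated at the end of this patch (the
tensor_to_memref materialization is inserted by the conversion framework via
the BufferizeTypeConverter's materializations, not by these patterns):

  // Before:
  %1 = shape.assuming %w -> (tensor<2xf16>) {
    %2 = "test.source"() : () -> tensor<2xf16>
    shape.assuming_yield %2 : tensor<2xf16>
  }

  // After:
  %1 = shape.assuming %w -> (memref<2xf16>) {
    %2 = "test.source"() : () -> tensor<2xf16>
    %3 = tensor_to_memref %2 : memref<2xf16>
    shape.assuming_yield %3 : memref<2xf16>
  }
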
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -8,82 +8,30 @@
 
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
-using namespace mlir::shape;
 
 namespace {
-// Propagate tensor to memref conversions through shape.assuming ops.
-class TypeConversionAssumingOpConverter
-    : public BufferizeOpConversionPattern<shape::AssumingOp> {
-public:
-  using BufferizeOpConversionPattern<
-      shape::AssumingOp>::BufferizeOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(shape::AssumingOp assumingOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Type, 2> newResultTypes;
-    newResultTypes.reserve(assumingOp.getNumResults());
-    for (auto result : assumingOp.getResults()) {
-      auto originalType = result.getType();
-      Type convertedType = converter.convertType(originalType);
-      newResultTypes.push_back(convertedType);
-    }
-
-    auto newAssumingOp = rewriter.create<shape::AssumingOp>(
-        assumingOp.getLoc(), newResultTypes, assumingOp.witness());
-
-    rewriter.replaceOp(assumingOp, newAssumingOp.getResults());
-    rewriter.inlineRegionBefore(assumingOp.doRegion(), newAssumingOp.doRegion(),
-                                newAssumingOp.doRegion().end());
-
-    return success();
-  }
-};
-
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
 
     OwningRewritePatternList patterns;
-    BufferizeTypeConverter converter;
-    populateShapeTypeConversionPatterns(&ctx, converter, patterns);
-
+    BufferizeTypeConverter typeConverter;
     ConversionTarget target(getContext());
-    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
-    target.addDynamicallyLegalOp<shape::AssumingOp>([&](shape::AssumingOp op) {
-      return std::all_of(op.result_type_begin(), op.result_type_end(),
-                         isMemRefType);
-    });
+    populateBufferizeMaterializationLegality(target);
+    populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
+                                                      patterns, target);
 
-    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }
 };
-
 } // namespace
 
-/// Populates `patterns` with the conversion patterns of tensor->memref.
-//
-// TODO: Change this to work generally with any type conversions.
-void mlir::populateShapeTypeConversionPatterns(
-    MLIRContext *context, BufferizeTypeConverter &converter,
-    OwningRewritePatternList &patterns) {
-  patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeBufferizePass construction
-//===----------------------------------------------------------------------===//
-
 std::unique_ptr<FunctionPass> mlir::createShapeBufferizePass() {
   return std::make_unique<ShapeBufferizePass>();
 }
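
The new entry point is meant to be reusable outside of -shape-bufferize as
well. A minimal sketch of a hypothetical downstream pass (the pass name and
the identity TypeConverter are illustrative only, not part of this patch),
following the same structure as ShapeBufferizePass above:

  // Hypothetical pass reusing the shape structural conversions with a
  // custom TypeConverter.
  struct MyShapeTypeConversionPass
      : public PassWrapper<MyShapeTypeConversionPass, FunctionPass> {
    void runOnFunction() override {
      MLIRContext *ctx = &getContext();
      // Illustrative identity converter; a real pass would install its own
      // type mappings and materializations here.
      TypeConverter typeConverter;
      typeConverter.addConversion([](Type type) { return type; });

      OwningRewritePatternList patterns;
      ConversionTarget target(*ctx);
      populateShapeStructuralTypeConversionsAndLegality(ctx, typeConverter,
                                                        patterns, target);
      if (failed(applyPartialConversion(getFunction(), target, patterns)))
        signalPassFailure();
    }
  };
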
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   Bufferize.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
+  StructuralTypeConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,71 @@
+//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SmallVector<Type, 2> newResultTypes;
+    newResultTypes.reserve(op.getNumResults());
+    for (auto result : op.getResults()) {
+      auto originalType = result.getType();
+      Type convertedType = getTypeConverter()->convertType(originalType);
+      newResultTypes.push_back(convertedType);
+    }
+
+    auto newAssumingOp =
+        rewriter.create<AssumingOp>(op.getLoc(), newResultTypes, op.witness());
+
+    rewriter.replaceOp(op, newAssumingOp.getResults());
+    rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(),
+                                newAssumingOp.doRegion().end());
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertAssumingYieldOpTypes
+    : public OpConversionPattern<AssumingYieldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateShapeStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
+      typeConverter, context);
+  target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
+    return typeConverter.isLegal(op.getResultTypes());
+  });
+  target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
+    return typeConverter.isLegal(op.getOperandTypes());
+  });
+}
diff --git a/mlir/test/Dialect/Shape/bufferize.mlir b/mlir/test/Dialect/Shape/bufferize.mlir
--- a/mlir/test/Dialect/Shape/bufferize.mlir
+++ b/mlir/test/Dialect/Shape/bufferize.mlir
@@ -1,12 +1,20 @@
 // RUN: mlir-opt -split-input-file -shape-bufferize <%s | FileCheck %s
 
 // -----
-// Check that shape.assuming returns a memref.
-//
-// CHECK-LABEL: @shape_assuming_returns_memref
-func @shape_assuming_returns_memref() {
+
+// CHECK-LABEL: func @shape_assuming() {
+// CHECK: %[[WTRUE:.*]] = shape.const_witness true
+// CHECK: %[[MEMREF:.*]] = shape.assuming %[[WTRUE]] -> (memref<2xf16>) {
+// CHECK: %[[TENSOR_VAL:.*]] = "test.source"() : () -> tensor<2xf16>
+// CHECK: %[[YIELDED_MEMREF:.*]] = tensor_to_memref %[[TENSOR_VAL]] : memref<2xf16>
+// CHECK: shape.assuming_yield %[[YIELDED_MEMREF]] : memref<2xf16>
+// CHECK: }
+// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<2xf16>
+// CHECK: "test.sink"(%[[TENSOR]]) : (tensor<2xf16>) -> ()
+// CHECK: return
+// CHECK: }
+func @shape_assuming() {
   %0 = shape.const_witness true
-  // CHECK: shape.assuming %{{.*}} -> (memref<2xf16>) {
   %1 = shape.assuming %0 -> (tensor<2xf16>) {
     %2 = "test.source"() : () -> (tensor<2xf16>)
     shape.assuming_yield %2 : tensor<2xf16>
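
For reference, the RUN line above means the updated test can also be exercised
by hand outside of lit (assuming a built mlir-opt and FileCheck; under lit, %s
expands to the test file path):

  mlir-opt -split-input-file -shape-bufferize \
      <mlir/test/Dialect/Shape/bufferize.mlir \
    | FileCheck mlir/test/Dialect/Shape/bufferize.mlir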