[mlir][shape] Bufferize shape.assuming ops via BufferizableOpInterface

Replace the dialect-conversion based structural type conversion patterns for
shape.assuming / shape.assuming_yield with external models of
BufferizableOpInterface, and switch the shape-bufferize pass to
bufferization::bufferizeOp using partial bufferization options restricted to
the shape dialect. StructuralTypeConversions.cpp is deleted, so it is also
removed from the CMake source list.

diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace shape {
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace shape
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -40,21 +40,6 @@
 void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
 std::unique_ptr<OperationPass<FuncOp>> createRemoveShapeConstraintsPass();
 
-/// Populates patterns for shape dialect structural type conversions and sets up
-/// the provided ConversionTarget with the appropriate legality configuration
-/// for the ops to get converted properly.
-///
-/// A "structural" type conversion is one where the underlying ops are
-/// completely agnostic to the actual types involved and simply need to update
-/// their types consistently. An example of this is shape.assuming -- the
-/// shape.assuming op and the corresponding shape.assuming_yield op need to have
-/// consistent types, but the exact types don't matter. So all that we need to
-/// do for a structural type conversion is to update both of their types
-/// consistently to the new types prescribed by the TypeConverter.
-void populateShapeStructuralTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target);
-
 // Bufferizes shape dialect ops.
 //
 // Note that most shape dialect ops must be converted to std before
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,169 @@
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::shape;
+
+namespace mlir {
+namespace shape {
+namespace {
+
+/// Bufferization of shape.assuming.
+struct AssumingOpInterface
+    : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
+                                                    shape::AssumingOp> {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const BufferizationState &state) const {
+    // AssumingOps do not have tensor OpOperands. The yielded value can be any
+    // SSA value that is in scope. To allow for use-def chain traversal through
+    // AssumingOps in the analysis, the corresponding yield value is considered
+    // to be aliasing with the result.
+    auto assumingOp = cast<shape::AssumingOp>(op);
+    size_t resultNum = std::distance(op->getOpResults().begin(),
+                                     llvm::find(op->getOpResults(), opResult));
+    // TODO: Support multiple blocks.
+    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
+           "expected exactly 1 block");
+    auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
+        assumingOp.getDoRegion().front().getTerminator());
+    assert(yieldOp && "expected shape.assuming_yield terminator");
+    return {&yieldOp->getOpOperand(resultNum)};
+  }
+
+  // TODO: For better bufferization results, this could return `true` only if
+  // there is a memory write in the region.
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     const BufferizationState &state) const {
+    // Similar to scf.if, results of this op are always considered memory writes
+    // in the analysis. This is a useful pattern for all ops that have tensor
+    // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
+    // implemented in terms of `bufferizesToMemoryWrite`, which does not work on
+    // ops without OpOperands.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto assumingOp = cast<shape::AssumingOp>(op);
+
+    // Compute new result types.
+    SmallVector<Type> newResultTypes;
+    for (Type type : assumingOp->getResultTypes()) {
+      if (auto tensorType = type.dyn_cast<TensorType>()) {
+        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+      } else {
+        newResultTypes.push_back(type);
+      }
+    }
+
+    // Create new op and move over region.
+    auto newOp = rewriter.create<shape::AssumingOp>(
+        op->getLoc(), newResultTypes, assumingOp.getWitness());
+    newOp.getDoRegion().takeBody(assumingOp.getRegion());
+
+    // Update terminator.
+    assert(newOp.getDoRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    Block *newBlock = &newOp.getDoRegion().front();
+    auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> newYieldValues;
+    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
+      Value val = it.value();
+      if (val.getType().isa<TensorType>()) {
+        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
+            yieldOp.getLoc(), newResultTypes[it.index()], val));
+      } else {
+        newYieldValues.push_back(val);
+      }
+    }
+    rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
+                                                        newYieldValues);
+
+    // Update all uses of the old op.
+    rewriter.setInsertionPointAfter(newOp);
+    SmallVector<Value> newResults;
+    for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
+      if (it.value().isa<TensorType>()) {
+        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
+            assumingOp.getLoc(), newOp->getResult(it.index())));
+      } else {
+        newResults.push_back(newOp->getResult(it.index()));
+      }
+    }
+
+    // Replace old op.
+    rewriter.replaceOp(assumingOp, newResults);
+
+    return success();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+};
+
+/// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
+/// ops, so this is for analysis only.
+struct AssumingYieldOpInterface
+    : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
+                                                    shape::AssumingYieldOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    assert(isa<shape::AssumingOp>(op->getParentOp()) &&
+           "expected that parent is an AssumingOp");
+    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
+                            const BufferizationState &state) const {
+    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
+    // may be generated inside the block. We should not return/yield allocations
+    // when possible.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    // Op is bufferized as part of AssumingOp.
+    return failure();
+  }
+};
+
+} // namespace
+} // namespace shape
+} // namespace mlir
+
+void mlir::shape::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addOpInterface<shape::AssumingOp, AssumingOpInterface>();
+  registry.addOpInterface<shape::AssumingYieldOp, AssumingYieldOpInterface>();
+}
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -8,30 +8,32 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
+using namespace bufferization;
 
 namespace {
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnOperation() override {
-    MLIRContext &ctx = getContext();
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.allowDialectInFilter<shape::ShapeDialect>();
 
-    RewritePatternSet patterns(&ctx);
-    bufferization::BufferizeTypeConverter typeConverter;
-    ConversionTarget target(ctx);
-
-    bufferization::populateBufferizeMaterializationLegality(target);
-    populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
-                                                      target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
   }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect,
+                    memref::MemRefDialect>();
+    shape::registerBufferizableOpInterfaceExternalModels(registry);
+  }
 };
 } // namespace
 
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
+  BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
-  StructuralTypeConversions.cpp
 
@@ -14,6 +14,7 @@
 target_link_libraries(MLIRShapeOpsTransforms
   PUBLIC
   MLIRArithmetic
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRMemRef
diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
deleted file mode 100644
--- a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
+++ /dev/null
@@ -1,70 +0,0 @@
-//===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::shape;
-
-namespace {
-class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(AssumingOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    SmallVector<Type, 2> newResultTypes;
-    newResultTypes.reserve(op.getNumResults());
-    for (auto result : op.getResults()) {
-      auto originalType = result.getType();
-      Type convertedType = getTypeConverter()->convertType(originalType);
-      newResultTypes.push_back(convertedType);
-    }
-
-    auto newAssumingOp = rewriter.create<AssumingOp>(
-        op.getLoc(), newResultTypes, op.getWitness());
-    rewriter.inlineRegionBefore(op.getDoRegion(), newAssumingOp.getDoRegion(),
-                                newAssumingOp.getDoRegion().end());
-    rewriter.replaceOp(op, newAssumingOp.getResults());
-
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-class ConvertAssumingYieldOpTypes
-    : public OpConversionPattern<AssumingYieldOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(AssumingYieldOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, adaptor.getOperands());
-    return success();
-  }
-};
-} // namespace
-
-void mlir::populateShapeStructuralTypeConversionsAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
-  patterns.add<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
-      typeConverter, patterns.getContext());
-  target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
-    return typeConverter.isLegal(op.getResultTypes());
-  });
-  target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
-    return typeConverter.isLegal(op.getOperandTypes());
-  });
-}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2713,7 +2713,10 @@
         "lib/Dialect/Shape/Transforms/*.cpp",
         "lib/Dialect/Shape/Transforms/*.h",
     ]),
-    hdrs = ["include/mlir/Dialect/Shape/Transforms/Passes.h"],
+    hdrs = [
+        "include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h",
+        "include/mlir/Dialect/Shape/Transforms/Passes.h",
+    ],
     includes = ["include"],
     deps = [
         ":ArithmeticDialect",