diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -699,6 +699,20 @@
         return getBlock()->getArgument(opOperand->getOperandNumber());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the operand for a `blockArgument`.
+      }],
+      /*retTy=*/"OpOperand *",
+      /*methodName=*/"getTiedOpOperand",
+      /*args=*/(ins "BlockArgument":$blockArgument),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(blockArgument.getOwner() == getBlock());
+        return &this->getOperation()->getOpOperand(
+            blockArgument.getArgNumber());
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the input or output indexing map for `opOperand`.
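
Note: `getTiedOpOperand` is the inverse of the existing accessor whose default implementation appears in the context lines above, which maps an `OpOperand` to its payload block argument (`getTiedBlockArgument`). A minimal usage sketch, assuming a `linalg::LinalgOp` handle named `linalgOp` and a payload `BlockArgument` named `blockArg` (both names are illustrative, not part of the patch):

  // Recover the operand that a payload block argument is tied to ...
  OpOperand *tiedOperand = linalgOp.getTiedOpOperand(blockArg);
  // ... and map it back; the two accessors round-trip to the same argument.
  assert(linalgOp.getTiedBlockArgument(tiedOperand) == blockArg);
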
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -45,6 +45,10 @@
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);

+/// Populate patterns for splitting a `LinalgOp` with multiple statements in
+/// its payload into multiple `GenericOp`s that each have a single statement.
+void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   Bufferize.cpp
   CodegenStrategy.cpp
   ConstantFold.cpp
+  DecomposeLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -0,0 +1,252 @@
+//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Pattern to decompose a GenericOp that has more than two statements
+/// into one GenericOp with the first statement (i.e. peeled operation), and
+/// a second GenericOp with the remaining statements (i.e. residual operations).
+///
+/// - The result of the first GenericOp has the same shape as the iteration
+///   space of the GenericOp. The number of results is the same as the number
+///   of results of the peeled operation. The element type of the results of
+///   the first operation is the same as the type of the results of the peeled
+///   operation.
+/// - The residual operation gets new `ins` operands that are the results of
+///   the peeled operation. If any operand of the residual operation becomes
+///   dead, it is expected to be dropped by further canonicalization.
+/// - If a result of the peeled operation was yielded by the original
+///   GenericOp, the uses of the corresponding result will be replaced with the
+///   result of the first GenericOp created.
+struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Helper method to create a generic op for the peeled scalar operation. The
+  /// created op has an empty region.
+  GenericOp createPeeledGenericOp(GenericOp genericOp,
+                                  PatternRewriter &rewriter) const;
+
+  /// Helper method to create a generic op for the residual scalar operation.
+  /// The created op has the same region as the original op.
+  GenericOp createResidualGenericOp(GenericOp genericOp,
+                                    GenericOp peeledGenericOp,
+                                    PatternRewriter &rewriter) const;
+};
+} // namespace
+
+/// Helper method to compute the range of a generic op.
+static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
+                                                       GenericOp op) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+  auto allShapesSizes =
+      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
+  AffineMap map = op.getShapesToLoopsMap();
+  return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes));
+}
+
+GenericOp
+DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
+                                         PatternRewriter &rewriter) const {
+  Block *body = genericOp.getBody();
+  Operation *peeledScalarOperation = &(*body->begin());
+  SmallVector<AffineMap> peeledGenericOpIndexingMaps;
+  for (auto operand : genericOp.getInputOperands()) {
+    peeledGenericOpIndexingMaps.push_back(
+        genericOp.getTiedIndexingMap(operand));
+  }
+
+  /// Compute the loop ranges for the operation. This is the shape of the
+  /// result of the generic op for the peeled operation.
+  Location loc = genericOp.getLoc();
+  SmallVector<OpFoldResult> resultShape =
+      getGenericOpLoopRange(rewriter, genericOp);
+  SmallVector<Value> initValues;
+  SmallVector<Type> resultTypes;
+  AffineMap identityMap = rewriter.getMultiDimIdentityMap(resultShape.size());
+  for (auto scalarResult : peeledScalarOperation->getResults()) {
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultShape, scalarResult.getType());
+    initValues.push_back(initTensor);
+    resultTypes.push_back(initTensor.getType());
+    peeledGenericOpIndexingMaps.push_back(identityMap);
+  }
+
+  /// Create the peeled generic op with an empty body.
+  auto indexingMapAttr =
+      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
+  return rewriter.create<GenericOp>(
+      loc, resultTypes, genericOp.inputs(), initValues, indexingMapAttr,
+      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+GenericOp
+DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
+                                           GenericOp peeledGenericOp,
+                                           PatternRewriter &rewriter) const {
+  /// Append all results of the peeled GenericOp as `ins` operands of the
+  /// residual generic op.
+  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
+      llvm::map_range(genericOp.getInputOperands(),
+                      [](OpOperand *operand) { return operand->get(); }));
+  auto peeledGenericOpResults = llvm::to_vector(llvm::map_range(
+      peeledGenericOp.getResults(), [](OpResult r) -> Value { return r; }));
+  residualGenericOpOperands.append(peeledGenericOpResults);
+
+  /// Add identity maps for the newly added operands.
+  AffineMap identityMap =
+      rewriter.getMultiDimIdentityMap(genericOp.getNumLoops());
+  auto indexingMaps = llvm::to_vector(
+      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
+        return genericOp.getTiedIndexingMap(operand);
+      }));
+  indexingMaps.resize(indexingMaps.size() + peeledGenericOpResults.size(),
+                      identityMap);
+  for (OpOperand *outOperand : genericOp.getOutputOperands())
+    indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));
+
+  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
+  return rewriter.create<GenericOp>(
+      genericOp->getLoc(), genericOp->getResultTypes(),
+      residualGenericOpOperands, genericOp.outputs(), indexingMapAttr,
+      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+LogicalResult
+DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
+                                   PatternRewriter &rewriter) const {
+  /// For now only match on operations where the iterator types are all
+  /// parallel.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "unhandled decomposition of operation "
+                                       "with non-parallel iterator types");
+  }
+  // TODO: this could be generalized to handle `linalg.generic` with buffer
+  // operands too but requires allocation for intermediates. Punt on this for
+  // now.
+  if (!genericOp.hasTensorSemantics()) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "only operations with tensor semantics are handled");
+  }
+
+  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+        return genericOp.payloadUsesValueFromOperand(outOperand);
+      })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "unhandled decomposition of generic op with use of out "
+                   "operand value in payload");
+  }
+
+  /// If the op has only a single statement (apart from the yield), do nothing.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() <= 2) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "body has less than 3 statements");
+  }
+
+  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
+  GenericOp residualGenericOp =
+      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
+
+  /// Move the first statement of the original operation into the body of the
+  /// generic op for the peeled operation.
+  Block *peeledGenericOpBody = peeledGenericOp.getBody();
+  Block *residualGenericOpBody = residualGenericOp.getBody();
+  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
+         "expected split generic ops to have empty region");
+  peeledGenericOpBody->getOperations().splice(
+      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
+  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
+                                                body->getOperations());
+
+  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
+  {
+    // Yield all the results of the peeled scalar operation.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
+    rewriter.create<YieldOp>(genericOp.getLoc(),
+                             peeledScalarOperation->getResults());
+  }
+
+  /// In the split operations, replace uses of the block arguments of the
+  /// original operation with the block arguments of the newly created
+  /// operations.
+  unsigned origNumInputs = genericOp.getNumInputs();
+  for (auto inputBlockArg :
+       llvm::enumerate(genericOp.getBody()->getArguments())) {
+    Value residualOpReplacementArg =
+        residualGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        residualOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == residualGenericOpBody;
+        });
+
+    Value peeledOpReplacementArg =
+        peeledGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        peeledOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == peeledGenericOpBody;
+        });
+  }
+
+  /// Before fixing up the residual operation, track which values are yielded.
+  /// If any of those are results of the peeled scalar operation, the uses of
+  /// the corresponding results have to be remapped to the results of the
+  /// generic op for the peeled operation.
+  auto yieldOp = residualGenericOpBody->getTerminator();
+  SmallVector<Value> replacements;
+  for (auto yieldValue : llvm::enumerate(yieldOp->getOperands())) {
+    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
+    if (!opr || opr.getOwner() != peeledScalarOperation) {
+      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
+    } else {
+      replacements.push_back(peeledGenericOp->getResult(opr.getResultNumber()));
+    }
+  }
+
+  /// Update all uses of the peeled scalar operation results in the residual op
+  /// with the newly added block arguments.
+  {
+    SmallVector<Value> scalarReplacements;
+    unsigned peeledGenericOpNumResults = peeledGenericOp->getNumResults();
+    scalarReplacements.reserve(peeledGenericOpNumResults);
+    for (auto num : llvm::seq<unsigned>(0, peeledGenericOpNumResults))
+      scalarReplacements.push_back(
+          residualGenericOpBody->getArgument(num + origNumInputs));
+    bool allUsesReplaced = false;
+    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
+                                  residualGenericOpBody, &allUsesReplaced);
+    assert(!allUsesReplaced &&
+           "peeled scalar operation is erased when it wasn't expected to be");
+  }
+
+  // Replace the original operation with the results of the newly created
+  // operations.
+  rewriter.replaceOp(genericOp, replacements);
+  return success();
+}
+
+void mlir::linalg::populateDecomposeLinalgOpsPattern(
+    RewritePatternSet &patterns) {
+  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
+}
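
For reference on the `body->getOperations().size() <= 2` early exit in matchAndRewrite above: a generic op whose payload is a single scalar statement plus the `linalg.yield` terminator is deliberately left untouched. A hand-written illustration (not a test from this patch; all SSA names are made up):

  %single = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%lhs, %rhs : tensor<?xf32>, tensor<?xf32>)
      outs(%out : tensor<?xf32>) {
    ^bb0(%l : f32, %r : f32, %o : f32):
      %add = arith.addf %l, %r : f32
      linalg.yield %add : f32
    } -> tensor<?xf32>

Here the body holds exactly two operations (the `arith.addf` and the terminator), so the pattern reports a match failure instead of splitting the op.
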
diff --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir
@@ -0,0 +1,22 @@
+func.func @simple_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
+  %init2 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %result:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>)
+      outs(%init1, %init2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %0 = arith.addf %b0, %b1 : f32
+      %1 = arith.mulf %0, %b2 : f32
+      linalg.yield %0, %1 : f32, f32
+    } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %result#0, %result#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
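
As added here, decompose-ops.mlir contains only the input IR. For orientation, the following is a hand-written sketch (not FileCheck output from this patch) of roughly what `DecomposeLinalgOp` is expected to produce for `@simple_op`, before canonicalization drops dead operands; all SSA names are illustrative:

  %peeled_init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
  %peeled = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>)
      outs(%peeled_init : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
      %p = arith.addf %b0, %b1 : f32
      linalg.yield %p : f32
    } -> tensor<?x?xf32>
  %residual:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                       affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1, %arg2, %peeled
          : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>)
      outs(%init1, %init2 : tensor<?x?xf32>, tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32, %b5 : f32):
      %m = arith.mulf %b3, %b2 : f32
      linalg.yield %b3, %m : f32, f32
    } -> (tensor<?x?xf32>, tensor<?x?xf32>)

The peeled op carries the first scalar statement (`arith.addf`) and yields it into a new identity-mapped init tensor; the residual op takes that result as an extra `ins` operand (its `%b3`), so the `arith.mulf` and the original yield now read the peeled value through a block argument. Operands that become dead in either op (for example `%arg2` in the peeled op) are left for later canonicalization to drop, as the pattern's documentation notes.
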
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
   TestLinalgCodegenStrategy.cpp
+  TestLinalgDecomposeOps.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
@@ -0,0 +1,56 @@
+//===- TestLinalgDecomposeOps.cpp - Test Linalg decomposition ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing decomposition of Linalg ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestLinalgDecomposeOps
+    : public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
+
+  TestLinalgDecomposeOps() = default;
+  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-decompose-ops-pattern";
+  }
+  StringRef getDescription() const final {
+    return "Test Linalg decomposition patterns";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    RewritePatternSet decompositionPatterns(context);
+    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+    if (failed(applyPatternsAndFoldGreedily(
+            getOperation(), std::move(decompositionPatterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgDecomposeOps() {
+  PassRegistration<TestLinalgDecomposeOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -87,6 +87,7 @@
 void registerTestInterfaces();
 void registerTestLastModifiedPass();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgDecomposeOps();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
@@ -186,6 +187,7 @@
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgCodegenStrategy();
+  mlir::test::registerTestLinalgDecomposeOps();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgFusionTransforms();
   mlir::test::registerTestLinalgTensorFusionTransforms();
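
Since decompose-ops.mlir is added without a RUN line, the new pattern would presumably be exercised through the test pass registered above, along the lines of the following invocation (the flag comes from `getArgument()`; the eventual lit/FileCheck wiring is not part of the patch as shown):

  mlir-opt -test-linalg-decompose-ops-pattern mlir/test/Dialect/Linalg/decompose-ops.mlir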