diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -699,6 +699,20 @@
         return getBlock()->getArgument(opOperand->getOperandNumber());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the operand tied to the given `blockArgument`.
+      }],
+      /*retTy=*/"OpOperand *",
+      /*methodName=*/"getTiedOpOperand",
+      /*args=*/(ins "BlockArgument":$blockArgument),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(blockArgument.getOwner() == getBlock());
+        return &this->getOperation()->getOpOperand(
+            blockArgument.getArgNumber());
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the input or output indexing map for `opOperand`.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -45,6 +45,10 @@
 void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                      const LinalgTilingOptions &options);
 
+/// Populate patterns for splitting a `LinalgOp` with multiple statements in
+/// its payload into multiple `GenericOp` ops that each have a single statement.
+void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);
+
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops; it assumes high-D convolution ops
 /// were decomposed previously.
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -94,11 +94,16 @@
 
   bool operator!() const { return impl == nullptr; }
 
-  template <typename U> bool isa() const;
-  template <typename First, typename Second, typename... Rest> bool isa() const;
-  template <typename U> U dyn_cast() const;
-  template <typename U> U dyn_cast_or_null() const;
-  template <typename U> U cast() const;
+  template <typename U>
+  bool isa() const;
+  template <typename First, typename Second, typename... Rest>
+  bool isa() const;
+  template <typename U>
+  U dyn_cast() const;
+  template <typename U>
+  U dyn_cast_or_null() const;
+  template <typename U>
+  U cast() const;
 
   // Support type casting Type to itself.
   static bool classof(Type) { return true; }
@@ -124,6 +129,8 @@
   bool isF128() const;
 
+  /// Return true if this is an integer type (with any width).
+  bool isInteger() const;
   /// Return true if this is an integer type with the specified width.
   bool isInteger(unsigned width) const;
   /// Return true if this is a signless integer type (with the specified width).
   bool isSignlessInteger() const;
@@ -243,7 +249,8 @@
   return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
 }
 
-template <typename U> bool Type::isa() const {
+template <typename U>
+bool Type::isa() const {
   assert(impl && "isa<> used on a null type.");
   return U::classof(*this);
 }
@@ -253,13 +260,16 @@
   return isa<First>() || isa<Second, Rest...>();
 }
 
-template <typename U> U Type::dyn_cast() const {
+template <typename U>
+U Type::dyn_cast() const {
   return isa<U>() ? U(impl) : U(nullptr);
 }
 
-template <typename U> U Type::dyn_cast_or_null() const {
+template <typename U>
+U Type::dyn_cast_or_null() const {
   return (impl && isa<U>()) ? U(impl) : U(nullptr);
 }
 
-template <typename U> U Type::cast() const {
+template <typename U>
+U Type::cast() const {
   assert(isa<U>());
   return U(impl);
 }
@@ -269,7 +279,8 @@
 namespace llvm {
 
 // Type hash just like pointers.
-template <> struct DenseMapInfo<mlir::Type> {
+template <>
+struct DenseMapInfo<mlir::Type> {
   static mlir::Type getEmptyKey() {
     auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
@@ -296,7 +307,8 @@
 };
 
 /// We align TypeStorage by 8, so allow LLVM to steal the low bits.
-template <> struct PointerLikeTypeTraits<mlir::Type> {
+template <>
+struct PointerLikeTypeTraits<mlir::Type> {
 public:
   static inline void *getAsVoidPointer(mlir::Type I) {
     return const_cast<void *>(I.getAsOpaquePointer());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   Bufferize.cpp
   CodegenStrategy.cpp
   ConstantFold.cpp
+  DecomposeLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -0,0 +1,295 @@
+//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Pattern to decompose a GenericOp that has more than two statements (the
+/// yield included) into one GenericOp with the first statement (i.e. the
+/// peeled operation), and a second GenericOp with the remaining statements
+/// (i.e. the residual operations).
+///
+/// - The results of the first GenericOp have the same shape as the iteration
+///   space of the GenericOp. The number of results, and their element types,
+///   match the results of the peeled operation.
+/// - The residual operation gets new `ins` operands that are the results of
+///   the peeled operation. If any operand of the residual operation becomes
+///   dead as a consequence, it is expected to be dropped by later
+///   canonicalizations.
+/// - If a result of the peeled operation was yielded by the original
+///   GenericOp, the uses of the corresponding result are replaced with the
+///   result of the first GenericOp created.
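+///
+/// For example, for a generic op whose payload is
+///
+///   ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, ...):
+///     %0 = arith.addf %b0, %b1 : f32
+///     %1 = arith.mulf %0, %b2 : f32
+///     linalg.yield %0, %1 : f32, f32
+///
+/// the `addf` is peeled into its own generic op, and the residual generic op
+/// computes the `mulf`, reading the peeled result through an extra `ins`
+/// operand (the exact IR further depends on follow-up canonicalizations).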
+struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Helper method to create a generic op for the peeled scalar operation. The
+  /// created op has an empty region.
+  GenericOp createPeeledGenericOp(GenericOp genericOp,
+                                  PatternRewriter &rewriter) const;
+
+  /// Helper method to create a generic op for the residual scalar operations.
+  /// The created op also has an empty region; the remaining statements of the
+  /// original op are moved into it by the caller.
+  GenericOp createResidualGenericOp(GenericOp genericOp,
+                                    GenericOp peeledGenericOp,
+                                    PatternRewriter &rewriter) const;
+};
+} // namespace
+
+/// Helper method to compute the loop ranges of a generic op.
+static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b,
+                                                       GenericOp op) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  Location loc = op.getLoc();
+  auto allShapesSizes =
+      cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc);
+  AffineMap map = op.getShapesToLoopsMap();
+  return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes));
+}
+
+/// Helper method to create a zero constant of the given scalar `elementType`.
+static Value getZero(OpBuilder &b, Location loc, Type elementType) {
+  assert(elementType.isIntOrIndexOrFloat() &&
+         "expected scalar type while computing zero value");
+  if (elementType.isInteger())
+    return b.create<arith::ConstantIntOp>(loc, 0, elementType);
+  if (elementType.isIndex())
+    return b.create<arith::ConstantIndexOp>(loc, 0);
+  // Assume float.
+  auto floatType = elementType.cast<FloatType>();
+  return b.create<arith::ConstantFloatOp>(
+      loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
+}
+
+GenericOp
+DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
+                                         PatternRewriter &rewriter) const {
+  Block *body = genericOp.getBody();
+  Operation *peeledScalarOperation = &(*body->begin());
+  SmallVector<AffineMap> peeledGenericOpIndexingMaps =
+      genericOp.getIndexingMaps();
+
+  /// Compute the loop ranges for the operation. This is the shape of the
+  /// result of the generic op for the peeled operation.
+  Location loc = genericOp.getLoc();
+  SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp);
+  SmallVector<Value> newInitValues;
+  SmallVector<Type> newResultTypes;
+  AffineMap identityMap = rewriter.getMultiDimIdentityMap(domain.size());
+  for (auto scalarResult : peeledScalarOperation->getResults()) {
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, domain, scalarResult.getType());
+    newInitValues.push_back(initTensor);
+    newResultTypes.push_back(initTensor.getType());
+    peeledGenericOpIndexingMaps.push_back(identityMap);
+  }
+
+  /// Create the peeled generic op with an empty body.
+  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
+  outsOperands.append(newInitValues.begin(), newInitValues.end());
+  SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
+  resultTypes.append(newResultTypes.begin(), newResultTypes.end());
+  auto indexingMapAttr =
+      rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps);
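+  // The peeled generic op reuses the original outs (with the new inits
+  // appended), so it also produces all results of the original op. Its body
+  // is created empty here; matchAndRewrite splices the peeled statement into
+  // it and then builds the yield.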
+  return rewriter.create<GenericOp>(
+      loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr,
+      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+GenericOp
+DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
+                                           GenericOp peeledGenericOp,
+                                           PatternRewriter &rewriter) const {
+  /// Append the newly added results of the peeled generic op as `ins` operands
+  /// of the residual generic op.
+  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
+      llvm::map_range(genericOp.getInputOperands(),
+                      [](OpOperand *operand) { return operand->get(); }));
+  unsigned origNumResults = genericOp.getNumResults();
+  SmallVector<Value> extraIns;
+  for (auto resultNum :
+       llvm::seq<unsigned>(origNumResults, peeledGenericOp.getNumResults()))
+    extraIns.push_back(peeledGenericOp->getResult(resultNum));
+  residualGenericOpOperands.append(extraIns);
+
+  /// Add identity maps for the newly added operands.
+  AffineMap identityMap =
+      rewriter.getMultiDimIdentityMap(genericOp.getNumLoops());
+  auto indexingMaps = llvm::to_vector(
+      llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) {
+        return genericOp.getTiedIndexingMap(operand);
+      }));
+  indexingMaps.resize(indexingMaps.size() + extraIns.size(), identityMap);
+  for (OpOperand *outOperand : genericOp.getOutputOperands())
+    indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand));
+
+  auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps);
+  return rewriter.create<GenericOp>(
+      genericOp->getLoc(), genericOp->getResultTypes(),
+      residualGenericOpOperands, genericOp.outputs(), indexingMapAttr,
+      genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr,
+      [](OpBuilder, Location, ValueRange) {});
+}
+
+LogicalResult
+DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
+                                   PatternRewriter &rewriter) const {
+  /// For now, only match operations where all iterator types are parallel.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "unhandled decomposition of operation "
+                                       "with non-parallel iterator types");
+  }
+  // TODO: this could be generalized to handle `linalg.generic` with buffer
+  // operands too, but that requires allocation for intermediates. Punt on
+  // this for now.
+  if (!genericOp.hasTensorSemantics()) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "only operations with tensor semantics are handled");
+  }
+
+  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+        return genericOp.payloadUsesValueFromOperand(outOperand);
+      })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "unhandled decomposition of generic op with use of out "
+                   "operand value in payload");
+  }
+
+  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
+        return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
+      })) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "unhandled decomposition of generic op with out operand "
+                   "not accessed using a permutation");
+  }
+
+  /// If the op has only a single statement (apart from the yield), do nothing.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() <= 2) {
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "body has fewer than 3 statements");
+  }
+
+  /// Check that the peeled statement has a scalar element type.
+  if (llvm::any_of(body->getOperations().begin()->getResultTypes(),
+                   [](Type t) { return !t.isIntOrIndexOrFloat(); })) {
+    return rewriter.notifyMatchFailure(
+        &(*body->getOperations().begin()),
+        "expected return type to be only int, index or float");
+  }
+
+  GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter);
+  GenericOp residualGenericOp =
+      createResidualGenericOp(genericOp, peeledGenericOp, rewriter);
+
+  /// Move the first statement of the original operation into the body of the
+  /// generic op for the peeled operation.
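+  /// The remaining statements, including the original yield, move into the
+  /// body of the generic op for the residual operations; a fresh yield for
+  /// the peeled op is built afterwards.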
+  Block *peeledGenericOpBody = peeledGenericOp.getBody();
+  Block *residualGenericOpBody = residualGenericOp.getBody();
+  assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() &&
+         "expected split generic ops to have empty region");
+  peeledGenericOpBody->getOperations().splice(
+      peeledGenericOpBody->begin(), body->getOperations(), body->begin());
+  residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(),
+                                                body->getOperations());
+
+  Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin());
+  auto *yieldOp = residualGenericOpBody->getTerminator();
+  {
+    // Yield all the results of the peeled scalar operation.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToEnd(peeledGenericOpBody);
+    SmallVector<Value> yieldedVals;
+    for (auto origYield : yieldOp->getOperands()) {
+      if (origYield.getDefiningOp() == peeledScalarOperation) {
+        yieldedVals.push_back(origYield);
+      } else {
+        yieldedVals.push_back(
+            getZero(rewriter, genericOp.getLoc(), origYield.getType()));
+      }
+    }
+    yieldedVals.append(llvm::to_vector(
+        llvm::map_range(peeledScalarOperation->getResults(),
+                        [](OpResult opr) -> Value { return opr; })));
+    rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals);
+  }
+
+  /// In the split operations, replace uses of block arguments of the original
+  /// operation with the block arguments of the newly created operations.
+  unsigned origNumInputs = genericOp.getNumInputs();
+  for (auto inputBlockArg :
+       llvm::enumerate(genericOp.getBody()->getArguments())) {
+    Value residualOpReplacementArg =
+        residualGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        residualOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == residualGenericOpBody;
+        });
+
+    Value peeledOpReplacementArg =
+        peeledGenericOpBody->getArgument(inputBlockArg.index());
+    inputBlockArg.value().replaceUsesWithIf(
+        peeledOpReplacementArg, [&](OpOperand &use) {
+          return use.getOwner()->getBlock() == peeledGenericOpBody;
+        });
+  }
+
+  /// Before fixing up the residual operation, track which values are yielded.
+  /// If any of them come from the peeled scalar operation, the uses of the
+  /// corresponding result have to be remapped to the result of the generic op
+  /// for the peeled operation.
+  SmallVector<Value> replacements;
+  for (auto yieldValue : llvm::enumerate(yieldOp->getOperands())) {
+    OpResult opr = yieldValue.value().dyn_cast<OpResult>();
+    if (!opr || opr.getOwner() != peeledScalarOperation)
+      replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
+    else
+      replacements.push_back(peeledGenericOp->getResult(yieldValue.index()));
+  }
+
+  /// Update all uses of the peeled scalar operation results in the residual op
+  /// with the newly added block arguments.
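+  /// (Those results enter the residual op through the `ins` operands appended
+  /// after the original inputs, so the replacement block arguments start at
+  /// index `origNumInputs`.)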
+  {
+    SmallVector<Value> scalarReplacements;
+    unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults();
+    scalarReplacements.reserve(peeledScalarOpNumResults);
+    for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults))
+      scalarReplacements.push_back(
+          residualGenericOpBody->getArgument(num + origNumInputs));
+    bool allUsesReplaced = false;
+    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
+                                  residualGenericOpBody, &allUsesReplaced);
+    assert(!allUsesReplaced &&
+           "peeled scalar operation erased when it wasn't expected to be");
+  }
+
+  // Replace the original operation.
+  rewriter.replaceOp(genericOp, replacements);
+  return success();
+}
+
+void mlir::linalg::populateDecomposeLinalgOpsPattern(
+    RewritePatternSet &patterns) {
+  patterns.insert<DecomposeLinalgOp>(patterns.getContext());
+}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -27,6 +27,7 @@
 
 bool Type::isIndex() const { return isa<IndexType>(); }
 
+bool Type::isInteger() const { return isa<IntegerType>(); }
 /// Return true if this is an integer type with the specified width.
 bool Type::isInteger(unsigned width) const {
   if (auto intTy = dyn_cast<IntegerType>())
diff --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt -test-linalg-decompose-ops -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-linalg-decompose-ops -canonicalize -cse -split-input-file %s | FileCheck %s --check-prefix=WITHCANONICALIZE
+
+func.func @simple_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
+  %init2 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %result:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>)
+      outs(%init1, %init2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %0 = arith.addf %b0, %b1 : f32
+      %1 = arith.mulf %0, %b2 : f32
+      linalg.yield %0, %1 : f32, f32
+  } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %result#0, %result#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
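+// The `addf` (whose result is also yielded) is peeled into the first generic
+// op; the `mulf` remains in the residual generic op, which reads the peeled
+// result through the extra `ins` operand added for it.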
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: func @simple_op(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
+// CHECK-DAG:   %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+// CHECK-DAG:   %[[GENERIC1:.+]] = linalg.generic
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-SAME:       ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
+// CHECK-SAME:       outs(%[[INIT2]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B0:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B2:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B3:[a-zA-Z0-9]+]]: f32):
+// CHECK-NEXT:     %[[S0:.+]] = arith.addf %[[B0]], %[[B1]]
+// CHECK-NEXT:     linalg.yield %[[S0]]
+// CHECK:   %[[GENERIC2:.+]]:2 = linalg.generic
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]], #[[MAP0]]]
+// CHECK-SAME:       ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[GENERIC1]] :
+// CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[B4:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B5:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B6:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B7:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B8:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[B9:[a-zA-Z0-9]+]]: f32):
+// CHECK-NEXT:     %[[S1:.+]] = arith.mulf %[[B7]], %[[B6]]
+// CHECK-NEXT:     linalg.yield %[[S1]]
+// CHECK:   return %[[GENERIC1]], %[[GENERIC2]]#1
+
+// -----
+
+func.func @simple_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %init2 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %result:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>)
+      outs(%init2, %init2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %0 = arith.addf %b0, %b1 : f32
+      %1 = arith.mulf %0, %b2 : f32
+      linalg.yield %0, %1 : f32, f32
+  } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %result#0, %result#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
   TestLinalgCodegenStrategy.cpp
+  TestLinalgDecomposeOps.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
@@ -0,0 +1,54 @@
+//===- TestLinalgDecomposeOps.cpp - Test Linalg decomposition ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing decomposition of Linalg ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestLinalgDecomposeOps
+    : public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
+
+  TestLinalgDecomposeOps() = default;
+  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect>();
+  }
+  StringRef getArgument() const final { return "test-linalg-decompose-ops"; }
+  StringRef getDescription() const final {
+    return "Test Linalg decomposition patterns";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    RewritePatternSet decompositionPatterns(context);
+    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+    if (failed(applyPatternsAndFoldGreedily(
+            getOperation(), std::move(decompositionPatterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgDecomposeOps() {
+  PassRegistration<TestLinalgDecomposeOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -87,6 +87,7 @@
 void registerTestInterfaces();
 void registerTestLastModifiedPass();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgDecomposeOps();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
@@ -186,6 +187,7 @@
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgCodegenStrategy();
+  mlir::test::registerTestLinalgDecomposeOps();
  mlir::test::registerTestLinalgElementwiseFusion();
  mlir::test::registerTestLinalgFusionTransforms();
  mlir::test::registerTestLinalgTensorFusionTransforms();