diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -699,6 +699,20 @@ return getBlock()->getArgument(opOperand->getOperandNumber()); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the operand for a `blockArgument`. + }], + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedOpOperand", + /*args=*/(ins "BlockArgument":$blockArgument), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(blockArgument.getOwner() == getBlock()); + return &this->getOperation()->getOpOperand( + blockArgument.getArgNumber()); + }] + >, InterfaceMethod< /*desc=*/[{ Return the input or output indexing map for `opOperand`. diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -45,6 +45,10 @@ void populatePadTensorTilingPatterns(RewritePatternSet &patterns, const LinalgTilingOptions &options); +/// Populate patterns for splitting a `LinalgOp` with multiple statements within +/// its payload into multiple `GenericOp` that have a single statement. +void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns); + /// Populate patterns for vectorizing low-D convolution ops. This is a step in /// progressive lowering for convolution ops, it assume high-D convolution ops /// were decomposed previously. 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -94,11 +94,16 @@ bool operator!() const { return impl == nullptr; } - template bool isa() const; - template bool isa() const; - template U dyn_cast() const; - template U dyn_cast_or_null() const; - template U cast() const; + template + bool isa() const; + template + bool isa() const; + template + U dyn_cast() const; + template + U dyn_cast_or_null() const; + template + U cast() const; // Support type casting Type to itself. static bool classof(Type) { return true; } @@ -124,6 +129,7 @@ bool isF128() const; /// Return true if this is an integer type with the specified width. + bool isInteger() const; bool isInteger(unsigned width) const; /// Return true if this is a signless integer type (with the specified width). bool isSignlessInteger() const; @@ -243,7 +249,8 @@ return DenseMapInfo::getHashValue(arg.impl); } -template bool Type::isa() const { +template +bool Type::isa() const { assert(impl && "isa<> used on a null type."); return U::classof(*this); } @@ -253,13 +260,16 @@ return isa() || isa(); } -template U Type::dyn_cast() const { +template +U Type::dyn_cast() const { return isa() ? U(impl) : U(nullptr); } -template U Type::dyn_cast_or_null() const { +template +U Type::dyn_cast_or_null() const { return (impl && isa()) ? U(impl) : U(nullptr); } -template U Type::cast() const { +template +U Type::cast() const { assert(isa()); return U(impl); } @@ -269,7 +279,8 @@ namespace llvm { // Type hash just like pointers. -template <> struct DenseMapInfo { +template <> +struct DenseMapInfo { static mlir::Type getEmptyKey() { auto *pointer = llvm::DenseMapInfo::getEmptyKey(); return mlir::Type(static_cast(pointer)); @@ -296,7 +307,8 @@ }; /// We align TypeStorage by 8, so allow LLVM to steal the low bits. 
-template <> struct PointerLikeTypeTraits { +template <> +struct PointerLikeTypeTraits { public: static inline void *getAsVoidPointer(mlir::Type I) { return const_cast(I.getAsOpaquePointer()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ Bufferize.cpp CodegenStrategy.cpp ConstantFold.cpp + DecomposeLinalgOps.cpp Detensorize.cpp DropUnitDims.cpp ElementwiseOpFusion.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -0,0 +1,391 @@ +//===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { + +/// Pattern to decompose a GenericOp that has more than two statements +/// into one GenericOp with the first statement (i.e. peeled operation), and +/// a second GenericOp with the remaining statements (i.e. residual operations). + +/// - The result of the first GenericOp has the same shape as the iteration +/// space of the GenericOp. The body of the op yields as many values as the +/// original op plus all the results of the peeled operation. +/// - The second GenericOp has as many operands as the original operation plus +/// all the results of the first Generic Op. 
 It has the same number of yields as +/// the original op. +/// - If the result of the peeled operation was yielded by the original +/// GenericOp the uses of the corresponding results will be replaced with the +/// result of the first GenericOp created. +/// +/// Example +/// +/// ```mlir +/// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) +/// outs(%init0, %init1 : ...) { +/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...): +/// %0 = %b0, %b1 : ... +/// %1 = %0, %b2 : ... +/// linalg.yield %0, %1 : ... +/// } -> (..., ...) +/// return %result#0, %result#1 +/// ``` +/// +/// gets split into +/// +/// ```mlir +/// %init = linalg.init_tensor ... +/// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) +/// outs(%init0, %init1, %init : ...) { +/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): +/// %0 = %b0, %b1 : ... +/// linalg.yield %0, %..., %0 : ... +/// } -> (..., ..., ...) +/// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...) +/// outs(%init0, %init1 : ...) { +/// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): +/// %1 = %b3, %b2 : ... +/// linalg.yield %..., %1 : ... +/// } -> (..., ...) +/// return %op0#0, %op1#1 +/// ``` +/// +/// After canonicalization this is expected to be +/// +/// ```mlir +/// %init = linalg.init_tensor ... +/// %op0 = linalg.generic ... ins(%arg0, %arg1 : ...) +/// outs(%init : ...) { +/// ^bb0(%b0: ... , %b1: ... , %b2: ...): +/// %0 = %b0, %b1 : ... +/// linalg.yield %0 : ... +/// } -> ... +/// %op1 = linalg.generic ... ins(%arg2, %op0 : ...) +/// outs(%init1 : ...) { +/// ^bb0(%b0: ... , %b1: ... , %b2: ...): +/// %1 = %b1, %b0 : ... +/// linalg.yield %1 : ... +/// } -> ... 
+/// return %op0, %op1 +/// ``` +struct DecomposeLinalgOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override; + +private: + /// Helper method to create a generic op for the peeled scalar operation. The + /// created op has an empty region. + GenericOp createPeeledGenericOp(GenericOp genericOp, + PatternRewriter &rewriter) const; + + /// Helper method to create a generic op for the residual scalar operation. + /// The created op has the same region as the original op. + GenericOp createResidualGenericOp(GenericOp genericOp, + GenericOp peeledGenericOp, + PatternRewriter &rewriter) const; +}; +} // namespace + +/// Helper method to compute the range of a generic op. +static SmallVector getGenericOpLoopRange(OpBuilder &b, + GenericOp op) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + Location loc = op.getLoc(); + auto allShapesSizes = + cast(op.getOperation()).createFlatListOfOperandDims(b, loc); + AffineMap map = op.getShapesToLoopsMap(); + return getAsOpFoldResult(applyMapToValues(b, loc, map, allShapesSizes)); +} + +/// Permute `values` by `map`: result[i] = values[map.getDimPosition(i)]. +static SmallVector permuteValues(ArrayRef values, + AffineMap map) { + assert(map.isPermutation()); + SmallVector permutedValues(values.size()); + for (auto position : + llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) { + return expr.cast().getPosition(); + }))) { + permutedValues[position.index()] = values[position.value()]; + } + return permutedValues; +} + +/// Get zero value for an element type. +static Value getZero(OpBuilder &b, Location loc, Type elementType) { + assert(elementType.isIntOrIndexOrFloat() && + "expected scalar type while computing zero value"); + if (elementType.isInteger()) { + return b.create(loc, 0, elementType); + } + if (elementType.isIndex()) { + return b.create(loc, 0); + } + // Assume float. 
+ auto floatType = elementType.cast(); + return b.create( + loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); +} + +GenericOp +DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, + PatternRewriter &rewriter) const { + Block *body = genericOp.getBody(); + Operation *peeledScalarOperation = &(*body->begin()); + SmallVector peeledGenericOpIndexingMaps = + genericOp.getIndexingMaps(); + + /// Compute the loop ranges for operation. This is the shape of the result of + /// the generic op for the peeled operation. + Location loc = genericOp.getLoc(); + SmallVector domain = getGenericOpLoopRange(rewriter, genericOp); + SmallVector newInitValues; + SmallVector newResultTypes; + + /// The indexing map to use for the new results is obtained by + /// - Check if the result is yielded. If so use the same indexing map as the + /// corresponding output + /// - Identity indexing map if the result is not yielded. + Operation *yieldOp = body->getTerminator(); + auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap { + OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr; + for (OpOperand &use : scalarOpResult.getUses()) { + if (use.getOwner() != yieldOp) + continue; + if (!firstUseInYield) + firstUseInYield = &use; + OpResult genericOpResult = + genericOp.getResult(use.getOperandNumber()).cast(); + AffineMap indexingMap = + genericOp.getTiedIndexingMapForResult(genericOpResult); + if (indexingMap.isIdentity()) + identityUseInYield = &use; + } + if (identityUseInYield || !firstUseInYield) + return rewriter.getMultiDimIdentityMap(domain.size()); + OpResult genericOpResult = + genericOp.getResult(firstUseInYield->getOperandNumber()) + .cast(); + return genericOp.getTiedIndexingMapForResult(genericOpResult); + }; + + for (auto scalarResult : peeledScalarOperation->getResults()) { + AffineMap resultIndexingMap = getResultIndexingMap(scalarResult); + SmallVector initSize = + permuteValues(domain, resultIndexingMap); + Value 
initTensor = rewriter.create( + loc, initSize, scalarResult.getType()); + newInitValues.push_back(initTensor); + newResultTypes.push_back(initTensor.getType()); + peeledGenericOpIndexingMaps.push_back(resultIndexingMap); + } + + /// Create the peeled generic op with an empty body. + SmallVector outsOperands = genericOp.getOutputOperands(); + outsOperands.append(newInitValues.begin(), newInitValues.end()); + SmallVector resultTypes = llvm::to_vector(genericOp.getResultTypes()); + resultTypes.append(newResultTypes.begin(), newResultTypes.end()); + auto indexingMapAttr = + rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps); + return rewriter.create( + loc, resultTypes, genericOp.inputs(), outsOperands, indexingMapAttr, + genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr, + [](OpBuilder, Location, ValueRange) {}); +} + +GenericOp +DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, + GenericOp peeledGenericOp, + PatternRewriter &rewriter) const { + /// Append all results from the peeledGenericOps as `ins` operand for the + /// residual generic op. + SmallVector residualGenericOpOperands = llvm::to_vector( + llvm::map_range(genericOp.getInputOperands(), + [](OpOperand *operand) { return operand->get(); })); + unsigned origNumResults = genericOp.getNumResults(); + unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults(); + SmallVector extraIns; + for (auto resultNum : + llvm::seq(origNumResults, peeledGenericOpNumResults)) + extraIns.push_back(peeledGenericOp->getResult(resultNum)); + residualGenericOpOperands.append(extraIns); + + /// Add indexing maps for the newly added operands. Use the same map + /// as those used for the new results of the peeledGenericOp. 
+ auto indexingMaps = llvm::to_vector( + llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) { + return genericOp.getTiedIndexingMap(operand); + })); + for (auto resultNum : + llvm::seq(origNumResults, peeledGenericOpNumResults)) { + OpResult result = peeledGenericOp.getResult(resultNum).cast(); + indexingMaps.push_back(peeledGenericOp.getTiedIndexingMapForResult(result)); + } + for (OpOperand *outOperand : genericOp.getOutputOperands()) + indexingMaps.push_back(genericOp.getTiedIndexingMap(outOperand)); + + auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); + return rewriter.create( + genericOp->getLoc(), genericOp->getResultTypes(), + residualGenericOpOperands, genericOp.outputs(), indexingMapAttr, + genericOp.iterator_types(), /*doc=*/nullptr, /*libraryCall=*/nullptr, + [](OpBuilder, Location, ValueRange) {}); +} + +LogicalResult +DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const { + /// For now only match on operations where the iterator types are all parallel + if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) { + return rewriter.notifyMatchFailure(genericOp, + "unhandled decomposition of operation " + "with non-parallel iterator types"); + } + // TODO: this could be generalized to handle `linalg.generic` with buffer + // operands too but requires allocation for intermediates. Punt on this for + // now. 
+ if (!genericOp.hasTensorSemantics()) { + return rewriter.notifyMatchFailure( + genericOp, "only operations with tensor semantics are handled"); + } + + if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) { + return genericOp.payloadUsesValueFromOperand(outOperand); + })) { + return rewriter.notifyMatchFailure( + genericOp, "unhandled decomposition of generic op with use of out " + "operand value in payload"); + } + + if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) { + return !genericOp.getTiedIndexingMap(outOperand).isPermutation(); + })) { + return rewriter.notifyMatchFailure( + genericOp, "unhandled decomposition of generic op with out operand not " + "accessed using a permutation"); + } + + /// If the op has only a single statement (apart from the yield), do nothing. + Block *body = genericOp.getBody(); + if (body->getOperations().size() <= 2) { + return rewriter.notifyMatchFailure(genericOp, + "body has less than 3 statements"); + } + + /// Check that the peeled statement has a scalar element type. + if (llvm::any_of(body->getOperations().begin()->getResultTypes(), + [](Type t) { return !t.isIntOrIndexOrFloat(); })) { + return rewriter.notifyMatchFailure( + &(*body->getOperations().begin()), + "expected return type to be only int, index or float"); + } + + GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter); + GenericOp residualGenericOp = + createResidualGenericOp(genericOp, peeledGenericOp, rewriter); + + /// Move the first statement of the original operation into the body of the + /// generic op for the peeled operation. 
+ Block *peeledGenericOpBody = peeledGenericOp.getBody(); + Block *residualGenericOpBody = residualGenericOp.getBody(); + assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() && + "expected split generic ops to have empty region"); + peeledGenericOpBody->getOperations().splice( + peeledGenericOpBody->begin(), body->getOperations(), body->begin()); + residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(), + body->getOperations()); + + Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin()); + auto yieldOp = residualGenericOpBody->getTerminator(); + { + // Yield all the result of the peeled scalar operation. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointToEnd(peeledGenericOpBody); + SmallVector yieldedVals; + for (auto origYield : yieldOp->getOperands()) { + if (origYield.getDefiningOp() == peeledScalarOperation) { + yieldedVals.push_back(origYield); + } else { + yieldedVals.push_back( + getZero(rewriter, genericOp.getLoc(), origYield.getType())); + } + } + yieldedVals.append(llvm::to_vector( + llvm::map_range(peeledScalarOperation->getResults(), + [](OpResult opr) -> Value { return opr; }))); + rewriter.create(genericOp.getLoc(), yieldedVals); + } + + /// In the split operations, replace block arguments uses that refer to + /// original operation to the block arguments of the newly created operation. 
+ unsigned origNumInputs = genericOp.getNumInputs(); + for (auto inputBlockArg : + llvm::enumerate(genericOp.getBody()->getArguments())) { + Value residualOpReplacementArg = + residualGenericOpBody->getArgument(inputBlockArg.index()); + inputBlockArg.value().replaceUsesWithIf( + residualOpReplacementArg, [&](OpOperand &use) { + return use.getOwner()->getBlock() == residualGenericOpBody; + }); + + Value peeledOpReplacementArg = + peeledGenericOpBody->getArgument(inputBlockArg.index()); + inputBlockArg.value().replaceUsesWithIf( + peeledOpReplacementArg, [&](OpOperand &use) { + return use.getOwner()->getBlock() == peeledGenericOpBody; + }); + } + + /// Before fixing up the residual operation, track what values are yielded. If + /// any of those are from the peeled scalar operation, the uses of the + /// corresponding result have to be remapped to result of the generic op for + /// the peeled operation. + SmallVector replacements; + for (auto yieldValue : llvm::enumerate(yieldOp->getOperands())) { + OpResult opr = yieldValue.value().dyn_cast(); + if (!opr || opr.getOwner() != peeledScalarOperation) + replacements.push_back(residualGenericOp.getResult(yieldValue.index())); + else + replacements.push_back(peeledGenericOp->getResult(yieldValue.index())); + } + + /// Update all uses of the peeled scalar operation results in the residual op + /// to the newly added arguments. 
+ { + SmallVector scalarReplacements; + unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults(); + scalarReplacements.reserve(peeledScalarOpNumResults); + for (auto num : llvm::seq(0, peeledScalarOpNumResults)) + scalarReplacements.push_back( + residualGenericOpBody->getArgument(num + origNumInputs)); + bool allUsesReplaced = false; + rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements, + residualGenericOpBody, &allUsesReplaced); + assert(!allUsesReplaced && + "peeled scalar operation is erased when it wasnt expected to be"); + } + + // Replace the original operation + rewriter.replaceOp(genericOp, replacements); + return success(); +} + +void mlir::linalg::populateDecomposeLinalgOpsPattern( + RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -27,6 +27,7 @@ bool Type::isIndex() const { return isa(); } +bool Type::isInteger() const { return isa(); } /// Return true if this is an integer type with the specified width. 
bool Type::isInteger(unsigned width) const { if (auto intTy = dyn_cast()) diff --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir @@ -0,0 +1,326 @@ +// RUN: mlir-opt -test-linalg-decompose-ops -cse -split-input-file %s | FileCheck %s +// RUN: mlir-opt -test-linalg-decompose-ops -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK + +func.func @simple_op(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) + -> (tensor, tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %init1 = linalg.init_tensor [%d1, %d0] : tensor + %init2 = linalg.init_tensor [%d0, %d1] : tensor + %result:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1, %arg2 : tensor, tensor, tensor) + outs(%init1, %init2 : tensor, tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) : + %0 = arith.addf %b0, %b1 : f32 + %1 = arith.mulf %0, %b2 : f32 + linalg.yield %0, %1 : f32, f32 + } -> (tensor, tensor) + return %result#0, %result#1 : tensor, tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK: func @simple_op( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], 
%[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] +// CHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[GENERIC1:.+]]:3 = linalg.generic +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP0]], #[[MAP3]]] +// CHECK-SAME: ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : +// CHECK-SAME: outs(%[[INIT1]], %[[INIT2]], %[[INIT1]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B0:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B1:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B2:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B3:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B4:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B5:[a-zA-Z0-9]+]]: f32): +// CHECK-NEXT: %[[S0:.+]] = arith.addf %[[B0]], %[[B1]] +// CHECK-NEXT: linalg.yield %[[S0]], %{{[a-zA-Z0-9]+}}, %[[S0]] +// CHECK: %[[GENERIC2:.+]]:2 = linalg.generic +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP3]], #[[MAP0]]] +// CHECK-SAME: ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[GENERIC1]]#2 : +// CHECK-SAME: outs(%[[INIT1]], %[[INIT2]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B6:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B7:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B8:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B9:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B10:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B11:[a-zA-Z0-9]+]]: f32): +// CHECK-NEXT: %[[S1:.+]] = arith.mulf %[[B9]], %[[B8]] +// CHECK-NEXT: linalg.yield %[[B9]], %[[S1]] +// CHECK: return %[[GENERIC1]]#0, %[[GENERIC2]]#1 + +// With cse + canonicalization + +// CANONICALIZECHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CANONICALIZECHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)> +// CANONICALIZECHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CANONICALIZECHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1)> +// CANONICALIZECHECK: func @simple_op( +// 
CANONICALIZECHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CANONICALIZECHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CANONICALIZECHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CANONICALIZECHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CANONICALIZECHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CANONICALIZECHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CANONICALIZECHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CANONICALIZECHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] +// CANONICALIZECHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CANONICALIZECHECK-DAG: %[[GENERIC1:.+]] = linalg.generic +// CANONICALIZECHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CANONICALIZECHECK-SAME: ["parallel", "parallel"] +// CANONICALIZECHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : +// CANONICALIZECHECK-SAME: outs(%[[INIT1]] : +// CANONICALIZECHECK-NEXT: ^bb0( +// CANONICALIZECHECK-SAME: %[[B0:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B1:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B2:[a-zA-Z0-9]+]]: f32): +// CANONICALIZECHECK-NEXT: %[[S0:.+]] = arith.addf %[[B0]], %[[B1]] +// CANONICALIZECHECK-NEXT: linalg.yield %[[S0]] +// CANONICALIZECHECK: %[[GENERIC2:.+]] = linalg.generic +// CANONICALIZECHECK-SAME: [#[[MAP3]], #[[MAP2]], #[[MAP0]]] +// CANONICALIZECHECK-SAME: ["parallel", "parallel"] +// CANONICALIZECHECK-SAME: ins(%[[ARG2]], %[[GENERIC1]] : +// CANONICALIZECHECK-SAME: outs(%[[INIT2]] : +// CANONICALIZECHECK-NEXT: ^bb0( +// CANONICALIZECHECK-SAME: %[[B3:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B4:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B5:[a-zA-Z0-9]+]]: f32): +// CANONICALIZECHECK-NEXT: %[[S1:.+]] = arith.mulf %[[B4]], %[[B3]] +// CANONICALIZECHECK-NEXT: linalg.yield %[[S1]] +// CANONICALIZECHECK: return %[[GENERIC1]]#0, %[[GENERIC2]] + + +// ----- + +func.func @simple_op_permuted_outputs(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) + -> (tensor, tensor, tensor) { + %c0 = 
arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg0, %c1 : tensor + %init1 = linalg.init_tensor [%d1, %d0] : tensor + %init2 = linalg.init_tensor [%d0, %d1] : tensor + %result:3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>, + affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1, %arg2 : tensor, tensor, tensor) + outs(%init1, %init2, %init2 : tensor, tensor, tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32, %b5 : f32) : + %0 = arith.addf %b0, %b1 : f32 + %1 = arith.mulf %0, %b2 : f32 + linalg.yield %0, %1, %0 : f32, f32, f32 + } -> (tensor, tensor, tensor) + return %result#0, %result#1, %result#2 : tensor, tensor, tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK: func @simple_op_permuted_outputs( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] +// CHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[GENERIC1:.+]]:4 = linalg.generic +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP0]], #[[MAP0]], #[[MAP0]]] +// CHECK-SAME: ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : +// CHECK-SAME: outs(%[[INIT1]], %[[INIT2]], 
%[[INIT2]], %[[INIT2]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B0:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B1:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B2:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B3:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B4:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B5:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B6:[a-zA-Z0-9]+]]: f32): +// CHECK-NEXT: %[[S0:.+]] = arith.addf %[[B0]], %[[B1]] +// CHECK-NEXT: linalg.yield %[[S0]], %{{[a-zA-Z0-9]+}}, %[[S0]] +// CHECK: %[[GENERIC2:.+]]:3 = linalg.generic +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]], #[[MAP0]], #[[MAP0]]] +// CHECK-SAME: ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[GENERIC1]]#3 : +// CHECK-SAME: outs(%[[INIT1]], %[[INIT2]], %[[INIT2]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B7:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B8:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B9:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B10:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B11:[a-zA-Z0-9]+]]: f32 +// CHECK-SAME: %[[B12:[a-zA-Z0-9]+]]: f32): +// CHECK-NEXT: %[[S1:.+]] = arith.mulf %[[B10]], %[[B9]] +// CHECK-NEXT: linalg.yield %[[B10]], %[[S1]], %[[B10]] +// CHECK: return %[[GENERIC1]]#0, %[[GENERIC2]]#1, %[[GENERIC1]]#2 + +// CANONICALIZECHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CANONICALIZECHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)> +// CANONICALIZECHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CANONICALIZECHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1)> +// CANONICALIZECHECK: func @simple_op_permuted_outputs( +// CANONICALIZECHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CANONICALIZECHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CANONICALIZECHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CANONICALIZECHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CANONICALIZECHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CANONICALIZECHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CANONICALIZECHECK-DAG: 
 %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CANONICALIZECHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]] +// CANONICALIZECHECK-DAG: %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CANONICALIZECHECK-DAG: %[[GENERIC1:.+]]:2 = linalg.generic +// CANONICALIZECHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]] +// CANONICALIZECHECK-SAME: ["parallel", "parallel"] +// CANONICALIZECHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : +// CANONICALIZECHECK-SAME: outs(%[[INIT1]], %[[INIT2]] : +// CANONICALIZECHECK-NEXT: ^bb0( +// CANONICALIZECHECK-SAME: %[[B0:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B1:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B2:[a-zA-Z0-9]+]]: f32): +// CANONICALIZECHECK-NEXT: %[[S0:.+]] = arith.addf %[[B0]], %[[B1]] +// CANONICALIZECHECK-NEXT: linalg.yield %[[S0]], %[[S0]] +// CANONICALIZECHECK: %[[GENERIC2:.+]] = linalg.generic +// CANONICALIZECHECK-SAME: [#[[MAP3]], #[[MAP0]], #[[MAP0]]] +// CANONICALIZECHECK-SAME: ["parallel", "parallel"] +// CANONICALIZECHECK-SAME: ins(%[[ARG2]], %[[GENERIC1]]#1 : +// CANONICALIZECHECK-SAME: outs(%[[INIT2]] : +// CANONICALIZECHECK-NEXT: ^bb0( +// CANONICALIZECHECK-SAME: %[[B4:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B5:[a-zA-Z0-9]+]]: f32 +// CANONICALIZECHECK-SAME: %[[B6:[a-zA-Z0-9]+]]: f32): +// CANONICALIZECHECK-NEXT: %[[S1:.+]] = arith.mulf %[[B5]], %[[B4]] +// CANONICALIZECHECK-NEXT: linalg.yield %[[S1]] +// CANONICALIZECHECK: return %[[GENERIC1]]#0, %[[GENERIC2]], %[[GENERIC1]]#1 + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> +#map2 = affine_map<(d0, d1) -> (d1, d0)> +func.func @multi_statement(%arg0 : tensor<10x20xf32>, %arg1 : tensor<10xi32>) -> tensor<20x10xf64> { + %init = linalg.init_tensor [20, 10] : tensor<20x10xf64> + %0 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %arg1 : tensor<10x20xf32>, tensor<10xi32>) + outs(%init : tensor<20x10xf64>) 
{ + ^bb0(%b0 : f32, %b1 : i32, %b2 : f64): + %1 = arith.sitofp %b1 : i32 to f64 + %2 = arith.extf %b0 : f32 to f64 + %3 = arith.addf %1, %2 : f64 + linalg.yield %3 : f64 + } -> tensor<20x10xf64> + return %0 : tensor<20x10xf64> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK: func @multi_statement( +// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<10xi32>) +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [20, 10] : tensor<20x10xf64> +// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf64> +// CHECK: %[[GENERIC0:.+]]:2 = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : +// CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B0:.+]]: f32 +// CHECK-SAME: %[[B1:.+]]: i32 +// CHECK-SAME: %[[B2:[a-zA-Z0-9]+]]: f64 +// CHECK-SAME: %[[B3:.+]]: f64 +// CHECK-NEXT: %[[S0:.+]] = arith.sitofp %[[B1]] : i32 to f64 +// CHECK-NEXT: linalg.yield %{{.+}}, %[[S0]] +// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]], #[[MAP2]], #[[MAP0]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[GENERIC0]]#1 : +// CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B4:.+]]: f32 +// CHECK-SAME: %[[B5:.+]]: i32 +// CHECK-SAME: %[[B6:[a-zA-Z0-9]+]]: f64 +// CHECK-SAME: %[[B7:[a-zA-Z0-9]+]]: f64 +// CHECK-SAME: %[[B8:.+]]: f64 +// CHECK-NEXT: %[[S1:.+]] = arith.extf %[[B4]] : f32 to f64 +// CHECK-NEXT: linalg.yield %{{.+}}, %[[S1]] +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]], #[[MAP0]], #[[MAP2]]] +// CHECK-SAME: 
iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[GENERIC0]]#1, %[[GENERIC1]]#1 : +// CHECK-SAME: outs(%[[INIT0]] : +// CHECK-NEXT: ^bb0( +// CHECK-SAME: %[[B9:.+]]: f32 +// CHECK-SAME: %[[B10:.+]]: i32 +// CHECK-SAME: %[[B11:[a-zA-Z0-9]+]]: f64 +// CHECK-SAME: %[[B12:[a-zA-Z0-9]+]]: f64 +// CHECK-SAME: %[[B13:.+]]: f64 +// CHECK-NEXT: %[[S2:.+]] = arith.addf %[[B11]], %[[B12]] : f64 +// CHECK-NEXT: linalg.yield %[[S2]] +// CHECK: return %[[GENERIC2]] + +// CANONICALIZECHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)> +// CANONICALIZECHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CANONICALIZECHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CANONICALIZECHECK: func @multi_statement( +// CANONICALIZECHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf32> +// CANONICALIZECHECK-SAME: %[[ARG1:.+]]: tensor<10xi32>) +// CANONICALIZECHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [20, 10] : tensor<20x10xf64> +// CANONICALIZECHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [10, 20] : tensor<10x20xf64> +// CANONICALIZECHECK: %[[GENERIC0:.+]] = linalg.generic +// CANONICALIZECHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CANONICALIZECHECK-SAME: iterator_types = ["parallel", "parallel"] +// CANONICALIZECHECK-SAME: ins(%[[ARG1]] : +// CANONICALIZECHECK-SAME: outs(%[[INIT1]] : +// CANONICALIZECHECK-NEXT: ^bb0( +// CANONICALIZECHECK-SAME: %[[B0:.+]]: i32 +// CANONICALIZECHECK-SAME: %[[B1:.+]]: f64 +// CANONICALIZECHECK-NEXT: %[[S0:.+]] = arith.sitofp %[[B0]] : i32 to f64 +// CANONICALIZECHECK-NEXT: linalg.yield %[[S0]] +// CANONICALIZECHECK: %[[GENERIC1:.+]] = linalg.generic +// CANONICALIZECHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]]] +// CANONICALIZECHECK-SAME: iterator_types = ["parallel", "parallel"] +// CANONICALIZECHECK-SAME: ins(%[[ARG0]] : +// CANONICALIZECHECK-SAME: outs(%[[INIT1]] : +// CANONICALIZECHECK-NEXT: ^bb0( +// CANONICALIZECHECK-SAME: %[[B2:.+]]: f32 +// CANONICALIZECHECK-SAME: %[[B3:.+]]: f64 
+// CANONICALIZECHECK-NEXT: %[[S1:.+]] = arith.extf %[[B2]] : f32 to f64
+// CANONICALIZECHECK-NEXT: linalg.yield %[[S1]]
+// CANONICALIZECHECK: %[[GENERIC2:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP1]], #[[MAP2]]]
+// CANONICALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME: ins(%[[GENERIC0]], %[[GENERIC1]] :
+// CANONICALIZECHECK-SAME: outs(%[[INIT0]] :
+// CANONICALIZECHECK-NEXT: ^bb0(
+// CANONICALIZECHECK-SAME: %[[B4:[a-zA-Z0-9]+]]: f64
+// CANONICALIZECHECK-SAME: %[[B5:[a-zA-Z0-9]+]]: f64
+// CANONICALIZECHECK-SAME: %[[B6:.+]]: f64
+// CANONICALIZECHECK-NEXT: %[[S2:.+]] = arith.addf %[[B4]], %[[B5]] : f64
+// CANONICALIZECHECK-NEXT: linalg.yield %[[S2]]
+// CANONICALIZECHECK: return %[[GENERIC2]]
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
   TestLinalgCodegenStrategy.cpp
+  TestLinalgDecomposeOps.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
@@ -0,0 +1,70 @@
+//===- TestLinalgDecomposeOps.cpp - Test Linalg decomposition ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing decomposition of Linalg ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test-only pass that greedily applies the patterns registered by
+/// `linalg::populateDecomposeLinalgOpsPattern` to the operation it runs on.
+///
+/// NOTE(review): the template arguments of `PassWrapper`, `registry.insert`
+/// and `PassRegistration` were missing from this patch text; they are restored
+/// here from the includes and the CRTP convention. Confirm against the tree.
+struct TestLinalgDecomposeOps
+    : public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
+
+  TestLinalgDecomposeOps() = default;
+  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass)
+      : PassWrapper(pass) {}
+
+  /// Insert the dialects this pass depends on into `registry`.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect>();
+  }
+
+  /// Argument string identifying this pass.
+  StringRef getArgument() const final { return "test-linalg-decompose-ops"; }
+
+  /// Human-readable description of the pass.
+  StringRef getDescription() const final {
+    return "Test Linalg decomposition patterns";
+  }
+
+  /// Apply the decomposition patterns with the greedy driver; signal pass
+  /// failure if the driver reports failure.
+  void runOnOperation() override {
+    MLIRContext *context = &this->getContext();
+    RewritePatternSet decompositionPatterns(context);
+    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+    if (failed(applyPatternsAndFoldGreedily(
+            getOperation(), std::move(decompositionPatterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Register `TestLinalgDecomposeOps`; called from mlir-opt's test-pass
+/// registration.
+void registerTestLinalgDecomposeOps() {
+  PassRegistration<TestLinalgDecomposeOps>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -87,6 +87,7 @@
 void registerTestInterfaces();
 void registerTestLastModifiedPass();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgDecomposeOps();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
@@ -186,6 +187,7 @@
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgCodegenStrategy();
+
mlir::test::registerTestLinalgDecomposeOps(); mlir::test::registerTestLinalgElementwiseFusion(); mlir::test::registerTestLinalgFusionTransforms(); mlir::test::registerTestLinalgTensorFusionTransforms();