diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -53,6 +53,26 @@ let constructor = "createTosaMakeBroadcastablePass()"; } +def TosaPartition : Pass<"tosa-partition", "ModuleOp"> { + let summary = "Outline TOSA Conv2D ops and adjacent element-wise ops"; + let description = [{ + Create outlined kernels from tosa::Conv2D ops (or similar) and surrounding + elementwise ops. + }]; + + let options = [ + ListOption<"anchorOps", "anchor-ops", "std::string", + "One or more operations to be used as focus of partitioned " + "kernels", + "llvm::cl::ZeroOrMore">, + Option<"partitionTagOpt", "partition-tag", "std::string", + /*default=*/"\"kernel\"", "Attribute for outlined functions">, + Option<"trailingOnly", "trailing-only", "bool", /*default=*/"false", + "Don't gather ops ahead of anchor op"> + ]; + let dependentDialects = ["tosa::TosaDialect"]; +} + def TosaOptionalDecompositions : Pass<"tosa-optional-decompositions", "func::FuncOp"> { let summary = "Applies Tosa operations optional decompositions"; diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ TosaLayerwiseConstantFoldPass.cpp TosaMakeBroadcastable.cpp TosaOptionalDecompositions.cpp + TosaPartition.cpp TosaValidation.cpp ADDITIONAL_HEADER_DIRS @@ -16,6 +17,7 @@ MLIRTosaPassIncGen LINK_LIBS PUBLIC + MLIRTransforms MLIRFuncDialect MLIRPass MLIRTosaDialect diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaPartition.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaPartition.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaPartition.cpp @@ -0,0 +1,610 @@ +//===- TosaPartition.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Replace conv2d followed by elementwise op with call to function containing +// them. Generalised, outline any anchor op, all its trailing elementwise ops, +// and all its leading elementwise ops. (Where "elementwise" itself is +// generalised to include transpose and reshape ops and certain constant ops.) 
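+//
+// For example (a sketch mirroring the tests below; types and names are
+// abbreviated), this:
+//
+//   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {...}
+//   %1 = "tosa.abs"(%0)
+//   return %1
+//
+// becomes a call to a private function carrying the partition-tag
+// attribute:
+//
+//   func.func private @fn__part_0(%a0, %a1, %a2) attributes {kernel} {
+//     %0 = "tosa.conv2d"(%a0, %a1, %a2) {...}
+//     %1 = "tosa.abs"(%0)
+//     return %1
+//   }
+//   ...
+//   %1 = call @fn__part_0(%arg0, %arg1, %arg2)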
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include <deque>
+#include <string>
+#include <vector>
+
+using llvm::SmallVector;
+
+// TODO(kdrewnia): Make it so list options can have defaults, and then get rid
+// of needing to set defaults here.
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAPARTITION
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+class TosaPartitionPass
+    : public tosa::impl::TosaPartitionBase<TosaPartitionPass> {
+public:
+  using tosa::impl::TosaPartitionBase<TosaPartitionPass>::TosaPartitionBase;
+
+  bool isAnchorOp(Operation *op);
+  bool isTransposeOp(Operation *op) const;
+  bool isLeadingOp(Operation *op) const;
+  bool isTrailingOp(Operation *op) const;
+  StringRef partitionTag() const;
+  void traceInputs(Operation *op, SetVector<Operation *> &predecessors,
+                   SetVector<Value> &inputNodes);
+  void runOnOperation() override;
+};
+
+// Tosa ops can broadcast values along axes, which allows for
+// element-wise operations without fully-matching dimensions. The
+// Elementwise trait is strict about matching dimensions, but
+// broadcastable ops are also element-wise, and we know that an
+// additional set of ops are also element-wise.
+bool isElementwiseOp(Operation *op) {
+  return op->hasTrait<OpTrait::Elementwise>() ||
+         op->hasTrait<OpTrait::ResultsBroadcastableShape>() ||
+         // clang-format off
+         isa<tosa::CastOp,
+             tosa::ClampOp,
+             tosa::SigmoidOp,
+             tosa::TanhOp
+            >(op);
+  // clang-format on
+}
+
+bool isFuseableOp(Operation *op) { return isElementwiseOp(op); }
+
+bool isZeroAttribute(Attribute value) {
+  if (auto intValue = value.dyn_cast<IntegerAttr>())
+    return intValue.getValue().isNullValue();
+  if (auto fpValue = value.dyn_cast<FloatAttr>())
+    return fpValue.getValue().isZero();
+  if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
+    return isZeroAttribute(splatValue.getSplatValue<Attribute>());
+  if (auto elementsValue = value.dyn_cast<ElementsAttr>())
+    return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
+  if (auto arrayValue = value.dyn_cast<ArrayAttr>())
+    return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
+  return false;
+}
+
+bool isConstantZero(Operation *op) {
+  if (auto cst = dyn_cast<arith::ConstantOp>(op))
+    return isZeroAttribute(cst.getValue());
+  if (auto cst = dyn_cast<tosa::ConstOp>(op))
+    return isZeroAttribute(cst->getAttr("value"));
+  return false;
+}
+
+bool isSmallishConstant(Operation *op) {
+  if (mlir::detail::isConstantLike(op))
+    // In TOSA, should always be tensor and thus shaped, but just in case.
+    if (auto cstType = op->getResult(0).getType().dyn_cast<ShapedType>())
+      if (cstType.hasStaticShape() && cstType.getNumElements() <= 8)
+        return true;
+  return false;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Inspired by / adapted from outlineIfOp() in SCF/Transforms/Utils.cpp
+// and mergeIdenticalBlocks() in Utils/RegionUtils.cpp.
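+//
+// At a high level, the code below proceeds as follows:
+//   1. Walk each func.func that isn't already an outlined kernel, looking
+//      for anchor ops.
+//   2. For each anchor, gather leading ops (traceInputs()) and trailing
+//      element-wise ops (the worklist loop in runOnOperation()).
+//   3. Summarize the gathered ops as an OutliningCandidate; candidates that
+//      hash and compare equal reuse the same outlined function.
+//   4. outlinePartitionOps() clones the ops into a new private function and
+//      replaces them with a func.call.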
+
+// The OutliningCandidate structure is designed to identify identical kernels
+// as they're being gathered and before creating any.
+
+struct OutliningCandidate {
+  OutliningCandidate(Operation *anchorOp_, ArrayRef<Operation *> &trailingOps_,
+                     ArrayRef<Operation *> &leadingOps_,
+                     ArrayRef<Value> &params_, ArrayRef<Value> &returnVals_,
+                     StringRef partFnName_);
+
+  unsigned addOp(Operation *op, unsigned orderIt);
+
+  Operation *anchorOp;
+  SmallVector<Operation *> trailingOps;
+  SmallVector<Operation *> leadingOps;
+  SmallVector<Value> params;
+  SmallVector<Value> returnVals;
+  std::string partFnName;
+  llvm::hash_code hash;
+  func::FuncOp function;
+
+  /// Return the order index for the given value that is within the block of
+  /// this data.
+  unsigned getOrderOf(Value value) const;
+
+  /// A map of result-producing operations to their relative orders within
+  /// this block. The order of an operation is the number of defined values
+  /// that are produced within the block before this operation.
+  DenseMap<Operation *, unsigned> opOrderIndex;
+};
+
+unsigned OutliningCandidate::addOp(Operation *op, unsigned orderIt) {
+  if (unsigned numResults = op->getNumResults()) {
+    opOrderIndex.try_emplace(op, orderIt);
+    orderIt += numResults;
+  }
+
+  auto opHash = OperationEquivalence::computeHash(
+      op, OperationEquivalence::ignoreHashValue,
+      OperationEquivalence::ignoreHashValue,
+      OperationEquivalence::IgnoreLocations);
+  hash = llvm::hash_combine(hash, opHash);
+
+  return orderIt;
+}
+
+OutliningCandidate::OutliningCandidate(Operation *anchorOp_,
+                                       ArrayRef<Operation *> &trailingOps_,
+                                       ArrayRef<Operation *> &leadingOps_,
+                                       ArrayRef<Value> &params_,
+                                       ArrayRef<Value> &returnVals_,
+                                       StringRef partFnName_)
+    : anchorOp(anchorOp_), partFnName(partFnName_), hash(0), function(nullptr) {
+  // We'll need to grab the cloned ops to avoid use-after-free.
+  for (auto *op : trailingOps_) {
+    trailingOps.push_back(op);
+  }
+  for (auto *op : leadingOps_) {
+    leadingOps.push_back(op);
+  }
+  for (auto val : params_) {
+    params.push_back(val);
+  }
+  for (auto val : returnVals_) {
+    returnVals.push_back(val);
+  }
+
+  unsigned orderIt = params.size();
+  for (auto *op : leadingOps) {
+    orderIt = addOp(op, orderIt);
+  }
+  orderIt = addOp(anchorOp, orderIt);
+  for (auto *op : trailingOps) {
+    orderIt = addOp(op, orderIt);
+  }
+}
+
+unsigned OutliningCandidate::getOrderOf(Value value) const {
+  // Arguments use the argument number as the order index.
+  if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+    return arg.getArgNumber();
+  for (unsigned i = 0; i < params.size(); i++) {
+    if (params[i] == value)
+      return i;
+  }
+
+  // Otherwise, the result order is offset from the parent op's order.
+  auto *definingOp = value.getDefiningOp();
+  if (definingOp) {
+    auto opOrderIt = opOrderIndex.find(definingOp);
+    // Candidate arguments have a definingOp that won't be in opOrderIndex,
+    // but those are handled by the params loop above.
+    assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
+    return opOrderIt->second + value.cast<OpResult>().getResultNumber();
+  }
+
+  return 0;
+}
+
+bool opsMatch(Operation *lhs, Operation *rhs, OutliningCandidate &one,
+              OutliningCandidate &two) {
+  // Check that the operations are equivalent.
+  if (!OperationEquivalence::isEquivalentTo(
+          lhs, rhs, OperationEquivalence::ignoreValueEquivalence,
+          OperationEquivalence::ignoreValueEquivalence,
+          OperationEquivalence::Flags::IgnoreLocations))
+    return false;
+
+  // Compare the operands of the two operations. If the operand is within
+  // the block, it must refer to the same operation.
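+  // For example, two candidates that each compute conv2d -> abs -> add only
+  // match if each add reads its conv2d and abs results at the same relative
+  // positions, as computed by getOrderOf().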
+  if (lhs->getNumOperands() != rhs->getNumOperands())
+    return false;
+  auto lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
+
+  for (int operand : llvm::seq<int>(0, lhs->getNumOperands())) {
+    Value lhsOperand = lhsOperands[operand];
+    Value rhsOperand = rhsOperands[operand];
+    if (lhsOperand == rhsOperand)
+      continue;
+    // Check that the types of the operands match.
+    if (lhsOperand.getType() != rhsOperand.getType())
+      return false;
+
+    // Otherwise, these operands must have the same logical order within the
+    // parent block.
+    if (one.getOrderOf(lhsOperand) != two.getOrderOf(rhsOperand))
+      return false;
+  }
+
+  return true;
+}
+
+bool outliningCandidatesEquivalent(OutliningCandidate &one,
+                                   OutliningCandidate &two) {
+  if (one.hash != two.hash)
+    return false;
+
+  if (one.params.size() != two.params.size())
+    return false;
+
+  for (unsigned i = 0; i < one.params.size(); i++)
+    if (one.params[i].getType() != two.params[i].getType())
+      return false;
+
+  for (auto ops : llvm::zip(one.leadingOps, two.leadingOps))
+    if (!opsMatch(std::get<0>(ops), std::get<1>(ops), one, two))
+      return false;
+
+  if (!opsMatch(one.anchorOp, two.anchorOp, one, two))
+    return false;
+
+  for (auto ops : llvm::zip(one.trailingOps, two.trailingOps))
+    if (!opsMatch(std::get<0>(ops), std::get<1>(ops), one, two))
+      return false;
+
+  return true;
+}
+
+OutliningCandidate *
+findOutliningCandidate(OutliningCandidate &newCandidate,
+                       std::vector<OutliningCandidate> &candidates) {
+  for (auto &candidate : candidates)
+    if (outliningCandidatesEquivalent(candidate, newCandidate))
+      return &candidate;
+
+  return nullptr;
+}
+
+// Given a convolution op and its fuse-able trailing and leading ops, move
+// them into a separate function.
+void outlinePartitionOps(Operation *anchorOp,
+                         ArrayRef<Operation *> trailingOps,
+                         ArrayRef<Operation *> leadingOps,
+                         ArrayRef<Value> params, ArrayRef<Value> returnVals,
+                         StringRef partFnName, StringRef attrName,
+                         std::vector<OutliningCandidate> &candidates) {
+  ValueRange values(params);
+  OpBuilder b(anchorOp);
+  Location loc = anchorOp->getLoc();
+  func::FuncOp outlinedFunc;
+
+  // ------------------------------------------------------------
+  // Merging part.
+
+  OutliningCandidate newCandidate(anchorOp, trailingOps, leadingOps, params,
+                                  returnVals, partFnName);
+
+  if (OutliningCandidate *found =
+          findOutliningCandidate(newCandidate, candidates)) {
+    // Matches one we already have.
+    outlinedFunc = found->function;
+  } else {
+    // ------------------------------------------------------------
+    // Construction part.
+
+    // Insert outlined function before current function.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(anchorOp->getParentOfType<func::FuncOp>());
+
+    // Make a FuncOp whose signature comes from the params' types and the
+    // returnVals' types.
+    MLIRContext *ctx = anchorOp->getContext();
+    ValueRange results(returnVals);
+    FunctionType type =
+        FunctionType::get(ctx, values.getTypes(), results.getTypes());
+    SmallVector<NamedAttribute> kernelAttrs{
+        b.getNamedAttr(attrName, b.getUnitAttr()),
+    };
+    outlinedFunc = b.create<func::FuncOp>(
+        loc, partFnName, type, ArrayRef<NamedAttribute>(kernelAttrs));
+    outlinedFunc->setAttr("sym_visibility", StringAttr::get(ctx, "private"));
+    newCandidate.function = outlinedFunc;
+
+    // Clone leadingOps, anchorOp, and trailingOps into the body of the new
+    // function, while also updating the comparison details for future
+    // candidates.
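+    // Note that traceInputs() gathered leadingOps outward from the anchor
+    // (consumers before their producers), so they are cloned in
+    // llvm::reverse() order to keep defs ahead of uses, then re-reversed to
+    // restore the order used when comparing candidates; bvm maps each
+    // original value to its clone, seeded with the new function's arguments.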
+    b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
+    BlockAndValueMapping bvm;
+    for (auto it : llvm::zip(values, outlinedFunc.getArguments()))
+      bvm.map(std::get<0>(it), std::get<1>(it));
+
+    newCandidate.leadingOps.clear();
+    for (auto *op : llvm::reverse(leadingOps)) {
+      newCandidate.leadingOps.push_back(b.clone(*op, bvm));
+      newCandidate.opOrderIndex[newCandidate.leadingOps.back()] =
+          newCandidate.opOrderIndex[op];
+    }
+    std::reverse(newCandidate.leadingOps.begin(),
+                 newCandidate.leadingOps.end());
+
+    newCandidate.anchorOp = b.clone(*anchorOp, bvm);
+    newCandidate.opOrderIndex[newCandidate.anchorOp] =
+        newCandidate.opOrderIndex[anchorOp];
+
+    newCandidate.trailingOps.clear();
+    for (auto *op : trailingOps) {
+      // All operands should already be in bvm.
+      assert(llvm::all_of(op->getOperands(),
+                          [&](Value v) { return bvm.lookupOrNull(v); }));
+      newCandidate.trailingOps.push_back(b.clone(*op, bvm));
+      newCandidate.opOrderIndex[newCandidate.trailingOps.back()] =
+          newCandidate.opOrderIndex[op];
+    }
+
+    // Make ReturnOp from trailingOps' results.
+    SmallVector<Value> returnOperands;
+    for (auto val : returnVals) {
+      returnOperands.push_back(bvm.lookup(val));
+    }
+    // Can't also supply return types, because it'll see a mismatch
+    // in numbers where there isn't one.
+    b.create<func::ReturnOp>(loc, returnOperands);
+
+    candidates.push_back(newCandidate);
+  }
+
+  // ------------------------------------------------------------
+  // Replacement part.
+
+  // Replace anchorOp, trailingOps, and leadingOps with a CallOp to the new
+  // function.
+  Operation *lastOp = anchorOp;
+  if (!trailingOps.empty())
+    lastOp = trailingOps[trailingOps.size() - 1];
+  b.setInsertionPointAfter(lastOp);
+  func::CallOp callOp = b.create<func::CallOp>(loc, outlinedFunc, values);
+
+  for (auto it : llvm::zip(returnVals, callOp->getResults())) {
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+
+  // Erase the ops we outlined, which should be safe now.
+  for (auto &op : llvm::make_early_inc_range(llvm::reverse(trailingOps)))
+    if (op->use_empty())
+      op->erase();
+  assert(anchorOp->use_empty() && "expected 'op' to have no uses");
+  anchorOp->erase();
+  for (auto &op : llvm::make_early_inc_range(leadingOps))
+    if (op->use_empty())
+      op->erase();
+}
+
+} // namespace
+
+bool TosaPartitionPass::isAnchorOp(Operation *op) {
+  if (anchorOps.empty()) // ListOption doesn't have a default value.
+    anchorOps = {"tosa.conv2d", "tosa.matmul", "tosa.depthwise_conv2d",
+                 "tosa.fully_connected"};
+  return llvm::is_contained(anchorOps, op->getName().getIdentifier().str());
+}
+
+bool TosaPartitionPass::isTransposeOp(Operation *op) const {
+  return isa<tosa::TransposeOp>(op);
+}
+
+bool TosaPartitionPass::isLeadingOp(Operation *op) const {
+  return isConstantZero(op) || isSmallishConstant(op) || isTransposeOp(op) ||
+         (!trailingOnly && isFuseableOp(op));
+}
+
+bool TosaPartitionPass::isTrailingOp(Operation *op) const {
+  return isa<tosa::TransposeOp, tosa::ReshapeOp>(op) || isFuseableOp(op);
+}
+
+StringRef TosaPartitionPass::partitionTag() const { return partitionTagOpt; }
+
+void TosaPartitionPass::traceInputs(Operation *op,
+                                    SetVector<Operation *> &predecessors,
+                                    SetVector<Value> &inputNodes) {
+  for (const auto &opnd : op->getOperands()) {
+    Operation *usedOp = opnd.getDefiningOp();
+    if (usedOp && (isTransposeOp(op) ? isSmallishConstant(usedOp)
+                                     : isLeadingOp(usedOp))) {
+      // If the op is already present, move it to reflect its new use.
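+      // (SetVector can't reorder an element in place, so remove-then-insert
+      // pushes it to the back.)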
+      if (predecessors.contains(usedOp))
+        predecessors.remove(usedOp);
+      predecessors.insert(usedOp);
+      if (!mlir::detail::isConstantLike(usedOp)) {
+        // Depth-first.
+        traceInputs(usedOp, predecessors, inputNodes);
+      }
+    } else if (!predecessors.contains(usedOp)) {
+      // Special-cased constants aren't inputs.
+      inputNodes.insert(opnd);
+    }
+  }
+}
+
+// Inspired by / adapted from TestSCFIfUtilsPass in
+// test/lib/Transforms/TestSCFUtils.cpp.
+void TosaPartitionPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  auto funcOps = module.getOps<func::FuncOp>();
+  for (auto func : llvm::make_early_inc_range(funcOps)) {
+    // Don't partition a kernel; it may already be partitioned.
+    if (func->hasAttr(partitionTag()))
+      continue;
+
+    int count = 0;
+    // (Keep the candidates list per-function; a module-level list led to
+    // node mismatches and unexpected uses.)
+    std::vector<OutliningCandidate> candidates;
+    auto callback = [&](Operation *op) {
+      if (!isAnchorOp(op))
+        return WalkResult::advance();
+      Operation *anchorOp = op;
+      auto strCount = std::string("__part_") + std::to_string(count++);
+
+      // Given a Conv2DOp (or other anchor op), gather all the
+      // element-wise ops that are reachable from its results,
+      // contiguously.
+      //
+      // The ops after the anchor are "trailing" ops.
+      //
+      // inputNodes gathers what will become the parameters of the
+      // outlined function; initially it's the anchor's arguments,
+      // and it accumulates arguments to other ops that don't come
+      // from inside the outlined function.
+      //
+      // resultNodes will become the results of the outlined function.
+      // It starts with the anchor's result(s) and gains the results
+      // of each new trailingOp. When all a resultNode's users can be
+      // determined to lie within the outlined function, it's removed
+      // from the set.
+      //
+      // These are SetVectors because we test with contains() a lot,
+      // but still want to preserve order.
+      SetVector<Operation *> trailingOps;
+      SetVector<Value> inputNodes;
+      SetVector<Value> resultNodes(anchorOp->getResults().begin(),
+                                   anchorOp->getResults().end());
+
+      // Grab a useful set of leading ops, like we do for trailing.
+      SetVector<Operation *> leadingOps;
+      traceInputs(anchorOp, leadingOps, inputNodes);
+
+      DominanceInfo domInfo(func);
+      // We pull from the front, so use a deque rather than a vector.
+      std::deque<Operation *> worklist;
+
+      worklist.push_back(anchorOp);
+      while (!worklist.empty()) {
+        Operation *op = worklist.front();
+        worklist.pop_front();
+        for (auto *userOp : op->getUsers()) {
+          if (isTrailingOp(userOp)) {
+            bool skip = false;
+            // First criterion is that the op is element-wise. Second
+            // criterion is that the op dominates all the users of the
+            // accumulated results of the outlined function. In other words,
+            // we can't take an op that comes "after" a user of the result
+            // from the eventual call, because the call needs to dominate all
+            // its users.
+            for (const Value &val : resultNodes) {
+              for (auto *user : val.getDefiningOp()->getUsers()) {
+                if (user != userOp &&
+                    !domInfo.properlyDominates(userOp, user)) {
+                  skip = true;
+                }
+              }
+            }
+
+            // Third criterion: TransposeOps can be acceptable trailing ops
+            // while also being acceptable leading ops. Let's prefer the
+            // latter by insisting that a trailing TransposeOp directly uses
+            // an anchor op.
+            if (isa<tosa::TransposeOp>(userOp)) {
+              auto *firstOp = userOp->getOpOperand(0).get().getDefiningOp();
+              if (!firstOp || !isAnchorOp(firstOp))
+                skip = true;
+            }
+
+            // userOp is acceptable. Keep it as a trailingOp, put it on the
+            // worklist. Add its operands to inputNodes unless they come
+            // from other trailingOps (indicated by being in resultNodes).
+            // If all the users of any resultNode are in trailingOps, there's
+            // no need to return it, so remove it from resultNodes. Finally,
+            // add all userOp's results to resultNodes.
+            if (!skip) {
+              // Also accept small constant inputs to userOp.
+              for (Value opnd : userOp->getOperands()) {
+                auto *defOp = opnd.getDefiningOp();
+                if (defOp &&
+                    (isConstantZero(defOp) || isSmallishConstant(defOp)))
+                  trailingOps.insert(defOp);
+              }
+              // General case.
+              trailingOps.insert(userOp);
+              worklist.push_back(userOp);
+              for (Value opnd : userOp->getOperands())
+                if (!resultNodes.contains(opnd) &&
+                    !trailingOps.contains(opnd.getDefiningOp()))
+                  inputNodes.insert(opnd);
+              for (const Value &val : resultNodes)
+                if (llvm::all_of(val.getUsers(), [&](Operation *u) {
+                      return trailingOps.contains(u);
+                    }))
+                  resultNodes.remove(val);
+              for (auto res : userOp->getResults())
+                resultNodes.insert(res);
+            }
+          }
+        }
+      }
+
+      // Make the outlined function from the ops we've gathered.
+      outlinePartitionOps(anchorOp, trailingOps.getArrayRef(),
+                          leadingOps.getArrayRef(), inputNodes.getArrayRef(),
+                          resultNodes.getArrayRef(),
+                          std::string(func.getSymName()) + strCount,
+                          partitionTag(), candidates);
+      // Outlining will erase nodes and thus perturb the walk, so
+      // signal interrupted to exit it and restart.
+      return WalkResult::interrupt();
+    };
+
+    // Walk until we've outlined all the anchor ops we can.
+    while (func.walk(callback).wasInterrupted()) {
+    }
+  }
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-partition-options.mlir b/mlir/test/Dialect/Tosa/tosa-partition-options.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-partition-options.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt --tosa-partition %s | FileCheck %s
+// RUN: mlir-opt --tosa-partition=partition-tag=one %s | FileCheck %s --check-prefix=ONE
+// RUN: mlir-opt --tosa-partition='anchor-ops=tosa.depthwise_conv2d partition-tag=two' %s | FileCheck %s --check-prefix=TWO
+// RUN: mlir-opt --tosa-partition='anchor-ops=tosa.depthwise_conv2d trailing-only partition-tag=three' %s | FileCheck %s --check-prefix=THREE
+// RUN: mlir-opt --tosa-partition='anchor-ops=tosa.conv2d partition-tag=four' %s | FileCheck %s --check-prefix=FOUR
+
+// RUN: mlir-opt --test-tosa-partition-options=default %s | FileCheck %s --check-prefix=CHECK
+// RUN: mlir-opt --test-tosa-partition-options=depthwise-only %s | FileCheck %s --check-prefix=TWO
+// RUN: mlir-opt --test-tosa-partition-options=conv-only %s | FileCheck %s --check-prefix=FOUR
+// RUN: mlir-opt --test-tosa-partition-options=attr-one %s | FileCheck %s --check-prefix=ONE
+// RUN: mlir-opt --test-tosa-partition-options=nofront-arg %s | FileCheck %s --check-prefix=THREE
+
+// CHECK-LABEL: func private @test_fusion8__part_0
+// CHECK-SAME: attributes {{{.*}}kernel}
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: tosa.transpose
+// CHECK-NEXT: tosa.depthwise_conv2d
+// CHECK-NEXT: tosa.abs
+// CHECK-NEXT: tosa.add
+// CHECK-NEXT: return
+// CHECK: func private @test_fusion8__part_1
+// CHECK-NEXT: tosa.conv2d
+// CHECK-NEXT: return
+// CHECK: func @test_fusion8
+// CHECK: call @test_fusion8__part_1
+// CHECK: call @test_fusion8__part_0
+
+// ONE-LABEL: func private @test_fusion8__part_0
+// ONE-SAME: attributes {{{.*}}one}
+// ONE-NEXT: arith.constant
+// ONE-NEXT: tosa.transpose
+// ONE-NEXT: tosa.depthwise_conv2d
+// ONE-NEXT: tosa.abs
+// ONE-NEXT: tosa.add
+// ONE-NEXT: return
+// ONE: func @test_fusion8
+// ONE: call @test_fusion8__part_0
+
+// TWO-LABEL: func private @test_fusion8__part_0
+// TWO-NEXT: arith.constant
+// TWO-NEXT: tosa.transpose
+// TWO-NEXT: tosa.depthwise_conv2d
+// TWO-NEXT: tosa.abs
+// TWO-NEXT: tosa.add
+// TWO-NEXT: return
+// TWO: func @test_fusion8
+// TWO: tosa.conv2d
+// TWO: call @test_fusion8__part_0
+
+// THREE-LABEL: func private @test_fusion8__part_0
+// THREE-NEXT: arith.constant
+// THREE-NEXT: tosa.transpose
+// THREE-NEXT: tosa.depthwise_conv2d
+// THREE-NEXT: tosa.abs
+// THREE-NEXT: tosa.add
+// THREE-NEXT: return
+// THREE: func @test_fusion8
+// THREE: tosa.conv2d
+// THREE: call @test_fusion8__part_0
+
+// FOUR-LABEL: func private @test_fusion8__part_0
+// FOUR-SAME: attributes {{{.*}}four}
+// FOUR-NEXT: tosa.conv2d
+// FOUR-NEXT: tosa.add
+// FOUR-NEXT: return
+// FOUR: func @test_fusion8
+// FOUR: call @test_fusion8__part_0
+
+func.func @test_fusion8(%arg0: tensor<128x32x32x8xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>, %arg3: tensor<128x8x32x32xf32>, %arg4: tensor<128x8x3x3xf32>, %arg5: tensor<8xf32>) -> tensor<128x128x30x30xf32> {
+  %cst = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi64>
+  %0 = "tosa.transpose"(%arg0, %cst) {changing_layout_root = false} : (tensor<128x32x32x8xf32>, tensor<4xi64>) -> tensor<128x8x32x32xf32>
+  %1 = "tosa.depthwise_conv2d"(%0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32>
+  %2 = "tosa.conv2d"(%arg3, %arg4, %arg5) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32>
+  %3 = "tosa.abs"(%1) {} : (tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32>
+  %4 = "tosa.add"(%3, %2) {} : (tensor<128x128x30x30xf32>, tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32>
+  return %4 : tensor<128x128x30x30xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-partition-run.mlir b/mlir/test/Dialect/Tosa/tosa-partition-run.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-partition-run.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),tosa-to-tensor,tosa-to-arith,one-shot-bufferize{allow-return-allocs bufferize-function-boundaries},func.func(buffer-deallocation),func.func(convert-linalg-to-loops),lower-affine,convert-linalg-to-llvm,convert-scf-to-cf,convert-math-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN:   | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN:   > %t1
+//
+// RUN: cat %t1 | FileCheck %s
+// CHECK: Unranked Memref
+// CHECK-SAME: sizes = [5, 10, 8, 18] strides = [1440, 144, 18, 1]
+// CHECK-NEXT: [-0.399353
+// CHECK-NEXT: [-0.399353
+// CHECK-NEXT: [-0.399353
+//
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(tosa-partition,func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),tosa-to-tensor,tosa-to-arith,one-shot-bufferize{allow-return-allocs bufferize-function-boundaries},func.func(buffer-deallocation),func.func(convert-linalg-to-loops),lower-affine,convert-linalg-to-llvm,convert-scf-to-cf,convert-math-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN:   | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN:   > %t2
+//
+// 
RUN: diff --ignore-matching-lines='Unranked Memref' %t1 %t2 + +module attributes {torch.debug_module_name = "Conv2dNoPaddingModule"} { + func.func private @printMemrefF32(tensor<*xf32>) + func.func private @printNewline() + func.func @main() { + %0 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %2 = "tosa.const"() {value = dense<0.000000e+00> : tensor<10xf32>} : () -> tensor<10xf32> + %3 = "tosa.const"() {value = dense<1> : tensor<5x2x10x20xi32>} : () -> tensor<5x2x10x20xi32> + %4 = "tosa.const"() {value = dense<"0x2E4CE7BABE79013E42A646BE27A031BEC9EBB9BDC771813DE50699BB015F3F3E7D5AABBC74777F3D3EE291BD7AC53DBD029566BE0BD91FBED4FCC6BD880D0F3CE4D5BE3D2BD2103E94A023BEB234D2BDDF54AF3DF56B483EDFAF46BDA39C343EBF9C1BBD4750CC3C818B5A3ED6E65FBED9F117BE3B6A74BDDE29BCBDD388503EAE711CBE0936DEBD8D9F28BE2B0C62BE1EE40CBEC9784F3ED265D73DD5F5E93D62184B3C4F7BF7BD4A56233D115B61BEE0652EBE11DBF8BD5A48183E94830D3E4915D6BD4C570BBC225D1A3E6EF16F3E0F95BF3D6F6C023DF7D3213EE11C0EBE40E7333D86203BBE5C4827BE435DF9BDC06ADA3D5221C23D3DF80EBE25D5913DD97F043EAEB5F3BC8A72133C39B25F3DCABB153E69C0673E3DFF39BED1E6B0BD8AB6BD3D5BFA473E5008523E00F7543ECA22403DB7E151BE36A0B13CDDFE16BE85EF60BE1D72563E0F85373E31C370BEE6B3343D2FA322BD93DF1EBDB4F7DCBDDDA1B93D36F50EBED8F5B03D511DF43D3AC82C3E657EB43DB2E16EBEA4911CBE4907F13D31114A3DC7473CBEB2FA0ABEF40E633ED0A1223E2F7AD2BD55FC72BD11EB65BEB7D28ABBF4C135BE892C3ABEDAC754BC9CF4103D0FB0C5BD93370F3E60E012BE86005B3E4367253EB6884BBE314870BD8A402E3CB5DB0C3D08FA643DAA6EBD3DB851673C3888EBBDBE6AE43D028667BE1F0E0FBE75AD71BD7523EBBDF6DEA8BDBBD245BEEB5C4DBD475E4E3DB93C1DBEAD2E46BCF2C62C3E857EC6BC27A7D63B0493A6BCE862433D3577193E63A0643ECB46193E4D26653ED8A48BBC6ED158BE81D8E4BDBA57243E0345C8BAEFEEEFBDD2F438BE5DE061BE86B54BBE74D343BDE05C043E187D023E41C668BE348E163EA4DD3CBE6B1A4CBDA2BAC3BD33F539BD718E3DBD669558BE0B6650BE2D1217BD60C3473B6A49DBBDEED6B53DBF3C59BE2D4F82BC8341543E9BE5C4BDB2F2593E77D1AE3D32D159BE10B5183EE3CFDEBC1D7DD7BDED00413E890A43BE"> : tensor<10x2x3x3xf32>} : () -> tensor<10x2x3x3xf32> + %5 = "tosa.cast"(%3) : (tensor<5x2x10x20xi32>) -> tensor<5x2x10x20xf32> + %6 = "tosa.transpose"(%5, %1) : (tensor<5x2x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x2xf32> + %7 = "tosa.transpose"(%4, %1) : (tensor<10x2x3x3xf32>, tensor<4xi32>) -> tensor<10x3x3x2xf32> + %8 = "tosa.conv2d"(%6, %7, %2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>) -> tensor<5x8x18x10xf32> + %9 = "tosa.transpose"(%8, %0) : (tensor<5x8x18x10xf32>, tensor<4xi32>) -> tensor<5x10x8x18xf32> + %10 = bufferization.alloc_tensor() copy(%9) : tensor<5x10x8x18xf32> + %11 = tensor.cast %10 : tensor<5x10x8x18xf32> to tensor<*xf32> + call @printMemrefF32(%11) : (tensor<*xf32>) -> () + call @printNewline() : () -> () + return + } +} diff --git a/mlir/test/Dialect/Tosa/tosa-partition-transpose.mlir b/mlir/test/Dialect/Tosa/tosa-partition-transpose.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-partition-transpose.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt --tosa-partition %s | FileCheck %s +// CHECK-LABEL: func private @forward__part_0 +// CHECK-NEXT: tosa.const +// CHECK-NEXT: tosa.const +// CHECK-NEXT: tosa.transpose +// CHECK-NEXT: tosa.transpose +// CHECK-NEXT: tosa.conv2d +// CHECK-NEXT: tosa.const +// CHECK-NEXT: tosa.transpose +// CHECK: return +// CHECK-LABEL: func private @forward__part_1 +// CHECK-NEXT: tosa.const 
+// CHECK-NEXT: tosa.const +// CHECK-NEXT: tosa.transpose +// CHECK-NEXT: tosa.transpose +// CHECK-NEXT: tosa.conv2d +// CHECK-NEXT: tosa.const +// CHECK-NEXT: tosa.transpose +// CHECK-NEXT: return + +module attributes {torch.debug_module_name = "ResNet"} { + func.func @forward(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32> { + %21 = "tosa.const"() {value = dense<"0x538ED43C"> : tensor<64x64x3x3xf32>} : () -> tensor<64x64x3x3xf32> + %22 = "tosa.const"() {value = dense<"0x14E76B3D"> : tensor<64x64x3x3xf32>} : () -> tensor<64x64x3x3xf32> + %24 = "tosa.const"() {value = dense<0.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> + %25 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %26 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + %35 = "tosa.const"() {value = dense<[[[4.350550e-01]], [[0.204353958]], [[0.234371066]], [[0.555869102]], [[0.9625808]], [[0.34837237]], [[0.0870912894]], [[0.685126185]], [[0.471447408]], [[1.26418459]], [[0.15185073]], [[0.672973871]], [[0.242970988]], [[0.557663679]], [[0.870111048]], [[0.241860405]], [[0.20524618]], [[0.814857721]], [[0.304025054]], [[0.261729032]], [[0.805968225]], [[0.800655066]], [[1.55810392]], [[0.240396887]], [[0.444482625]], [[0.676496982]], [[0.55617404]], [[0.937783837]], [[0.258368522]], [[0.317292869]], [[0.0961778983]], [[0.411815882]], [[0.519685268]], [[0.976713955]], [[1.27033126]], [[0.890840947]], [[0.360923588]], [[0.222713485]], [[1.15876949]], [[1.59651864]], [[0.406010926]], [[0.255918652]], [[0.176323771]], [[0.279654771]], [[0.375710368]], [[0.128194466]], [[1.82796717]], [[0.314452082]], [[0.741928279]], [[0.21285513]], [[0.812224507]], [[0.46600005]], [[0.406509399]], [[0.491434872]], [[4.813950e-01]], [[0.169687942]], [[0.400027037]], [[3.866580e-01]], [[0.149923742]], [[0.413703024]], [[0.0671055093]], [[0.83032006]], [[0.243440315]], [[0.344939411]]]> : tensor<64x1x1xf32>} : () -> tensor<64x1x1xf32> + %36 = "tosa.const"() {value = dense<[[[[-0.433190078]], [[-0.175694913]], [[0.0307474807]], [[-0.705806791]], [[-1.63644612]], [[-0.798920393]], [[-0.0678167045]], [[-0.195550278]], [[-1.12603867]], [[-9.577850e-01]], [[0.00303830206]], [[-1.82651103]], [[-0.0393386371]], [[-0.867985606]], [[-1.10618353]], [[-0.635920107]], [[-0.98719418]], [[-0.577786744]], [[-1.33491731]], [[-0.340780795]], [[-1.19819415]], [[-1.60575163]], [[-2.17018104]], [[-0.881360471]], [[-0.817468405]], [[-0.695106149]], [[0.654207646]], [[-1.64223909]], [[0.281131387]], [[0.316296548]], [[-0.412329644]], [[-1.40228522]], [[-1.5043844]], [[-2.5030787]], [[-2.15804982]], [[-1.36450112]], [[-0.857909083]], [[-0.220595926]], [[-2.55477524]], [[-2.26948547]], [[-0.160925135]], [[-0.855172097]], [[0.528860927]], [[1.34920466]], [[-0.938154459]], [[-0.335565716]], [[-2.91677713]], [[-1.59672499]], [[-1.88746727]], [[-1.61662126]], [[-1.94432163]], [[-2.01948833]], [[-0.967084765]], [[-1.38807428]], [[-1.88363433]], [[0.186851829]], [[-1.3487308]], [[-0.459332734]], [[-0.454171181]], [[-0.903211176]], [[-0.0767973289]], [[-1.77187669]], [[1.24841809]], [[-0.913857877]]]]> : tensor<1x64x1x1xf32>} : () -> tensor<1x64x1x1xf32> + %37 = "tosa.const"() {value = dense<[[[[0.309043467]], [[0.214711159]], [[0.236576676]], [[0.425931394]], [[0.513686061]], [[0.21810919]], [[0.220440537]], [[0.229953438]], [[0.264005542]], [[0.269451439]], [[0.213777632]], [[0.460190892]], [[0.266088396]], [[0.231897846]], [[0.390032232]], [[0.238854751]], [[0.266045094]], [[0.363448173]], 
[[0.347412556]], [[0.247700587]], [[3.284530e-01]], [[0.53493458]], [[0.644020081]], [[0.227493554]], [[0.448180854]], [[0.30781728]], [[0.260412276]], [[0.465068072]], [[0.217946246]], [[0.285815388]], [[0.342561722]], [[0.441996783]], [[0.444976747]], [[0.450018674]], [[0.551550686]], [[0.509206712]], [[0.256445497]], [[0.263418853]], [[0.566351056]], [[0.640956103]], [[0.222756818]], [[0.198601693]], [[0.245954156]], [[0.224201113]], [[0.214311257]], [[0.198216185]], [[0.636765599]], [[0.310554832]], [[0.504927635]], [[0.240274549]], [[0.306493819]], [[0.375986129]], [[0.379380524]], [[0.428064138]], [[0.299081922]], [[0.332632065]], [[0.259646714]], [[0.334532797]], [[2.005660e-01]], [[0.435055643]], [[0.168260247]], [[0.514937878]], [[0.262882739]], [[0.325370371]]]]> : tensor<1x64x1x1xf32>} : () -> tensor<1x64x1x1xf32> + %38 = "tosa.const"() {value = dense<[[[[0.165660679]], [[0.241959468]], [[0.177964702]], [[-0.0431376398]], [[-0.205331057]], [[0.159750536]], [[0.292935789]], [[0.0911777243]], [[0.1116466]], [[0.0883733555]], [[0.110413767]], [[-0.203523606]], [[0.153928623]], [[8.571950e-02]], [[-0.109398432]], [[0.0654498711]], [[0.0765945539]], [[-0.206687212]], [[-0.02118567]], [[0.139637291]], [[0.0401249304]], [[-0.282749802]], [[-0.325705916]], [[-0.00346554257]], [[-0.437301368]], [[-0.124844968]], [[0.128242895]], [[-0.0873934776]], [[0.119922049]], [[-0.0829298421]], [[-0.531535208]], [[-7.803100e-02]], [[-0.387627274]], [[-0.0546562374]], [[-0.181554914]], [[-0.188849777]], [[0.132011667]], [[0.00313512608]], [[-0.269714326]], [[-0.298389643]], [[0.139352217]], [[0.259656608]], [[0.137245387]], [[0.00526427291]], [[0.0131502384]], [[0.329455644]], [[-0.271506935]], [[-0.0186743364]], [[-0.246697694]], [[0.157942355]], [[0.0164602138]], [[-0.0890493542]], [[-0.190278217]], [[-0.078674376]], [[0.170012966]], [[-0.483182877]], [[0.0619214252]], [[-0.0677097216]], [[0.312465042]], [[-0.506395757]], [[0.31381321]], [[-0.261734873]], [[-0.154493123]], [[0.00628100662]]]]> : tensor<1x64x1x1xf32>} : () -> tensor<1x64x1x1xf32> + %103 = "tosa.const"() {value = dense<9.99999974E-6> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> + %122 = "tosa.transpose"(%arg0, %25) : (tensor<1x64x56x56xf32>, tensor<4xi32>) -> tensor<1x56x56x64xf32> + %123 = "tosa.transpose"(%22, %25) : (tensor<64x64x3x3xf32>, tensor<4xi32>) -> tensor<64x3x3x64xf32> + %124 = "tosa.conv2d"(%122, %123, %24) {dilation = [1, 1], pad = [1, 1, 1, 1], stride = [1, 1]} : (tensor<1x56x56x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>) -> tensor<1x56x56x64xf32> + %125 = "tosa.transpose"(%124, %26) : (tensor<1x56x56x64xf32>, tensor<4xi32>) -> tensor<1x64x56x56xf32> + %126 = "tosa.sub"(%125, %36) : (tensor<1x64x56x56xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x56x56xf32> + %127 = "tosa.add"(%35, %103) : (tensor<64x1x1xf32>, tensor<1x1x1xf32>) -> tensor<64x1x1xf32> + %128 = "tosa.rsqrt"(%127) : (tensor<64x1x1xf32>) -> tensor<64x1x1xf32> + %129 = "tosa.reshape"(%128) {new_shape = [1, 64, 1, 1]} : (tensor<64x1x1xf32>) -> tensor<1x64x1x1xf32> + %130 = "tosa.mul"(%126, %129) {shift = 0 : i32} : (tensor<1x64x56x56xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x56x56xf32> + %131 = "tosa.mul"(%130, %37) {shift = 0 : i32} : (tensor<1x64x56x56xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x56x56xf32> + %132 = "tosa.add"(%131, %38) : (tensor<1x64x56x56xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x56x56xf32> + %133 = "tosa.clamp"(%132) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : 
(tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+    %134 = "tosa.transpose"(%133, %25) : (tensor<1x64x56x56xf32>, tensor<4xi32>) -> tensor<1x56x56x64xf32>
+    %135 = "tosa.transpose"(%21, %25) : (tensor<64x64x3x3xf32>, tensor<4xi32>) -> tensor<64x3x3x64xf32>
+    %136 = "tosa.conv2d"(%134, %135, %24) {dilation = [1, 1], pad = [1, 1, 1, 1], stride = [1, 1]} : (tensor<1x56x56x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>) -> tensor<1x56x56x64xf32>
+    %137 = "tosa.transpose"(%136, %26) : (tensor<1x56x56x64xf32>, tensor<4xi32>) -> tensor<1x64x56x56xf32>
+    return %137 : tensor<1x64x56x56xf32>
+  }
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-partition.mlir b/mlir/test/Dialect/Tosa/tosa-partition.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-partition.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt --split-input-file --tosa-partition %s -verify-each=0 -o - | FileCheck %s
+
+// CHECK-LABEL: func private @test_fusion__part_0
+// CHECK: tosa.conv2d
+// CHECK: tosa.abs
+// CHECK: return
+// CHECK: func @test_fusion
+// CHECK: call @test_fusion__part_0
+func.func @test_fusion(%arg0: tensor<128x8x32x32xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>) -> tensor<128x128x30x30xf32> {
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32>
+  %1 = "tosa.abs"(%0) {} : (tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32>
+  return %1 : tensor<128x128x30x30xf32>
+}
+
+
+// CHECK-LABEL: func private @test_fusion2__part_0
+// CHECK: tosa.conv2d
+// CHECK: tosa.negate
+// CHECK: return
+// CHECK: func @test_fusion2
+// CHECK: call @test_fusion2__part_0
+func.func @test_fusion2(%arg0: tensor<128x8x32x32xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>) -> tensor<128x128x30x30xf32> {
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32>
+  %1 = "tosa.negate"(%0) {} : (tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32>
+  return %1 : tensor<128x128x30x30xf32>
+}
+
+
+// CHECK-LABEL: func private @test_fusion3__part_0
+// CHECK: tosa.conv2d
+// CHECK: tosa.abs
+// CHECK: tosa.negate
+// CHECK: return
+// CHECK: func @test_fusion3
+// CHECK: call @test_fusion3__part_0
+func.func @test_fusion3(%arg0: tensor<128x8x32x32xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>) -> tensor<128x128x30x30xf32> {
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32>
+  %1 = "tosa.abs"(%0) {} : (tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32>
+  %2 = "tosa.negate"(%1) {} : (tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32>
+  return %2 : tensor<128x128x30x30xf32>
+}
+
+
+// CHECK-LABEL: func private @test_fusion4__part_0
+// CHECK: tosa.conv2d
+// CHECK: tosa.abs
+// NOTE: This test used to absorb the tosa.add, too, but doesn't now.
+// CHECK: return +// CHECK: func @test_fusion4 +// CHECK: call @test_fusion4__part_0 +func.func @test_fusion4(%arg0: tensor<128x8x32x32xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>) -> tensor<128x128x30x30xf32> { + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32> + %1 = "tosa.abs"(%0) {} : (tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32> + %2 = "tosa.add"(%0, %1) {} : (tensor<128x128x30x30xf32>, tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32> + return %2 : tensor<128x128x30x30xf32> +} + + +// CHECK-LABEL: func private @test_fusion5__part_0 +// CHECK-NEXT: tosa.conv2d +// CHECK-NEXT: tosa.abs +// CHECK-NEXT: tosa.add +// CHECK-NEXT: return +// CHECK: func private @test_fusion5__part_1 +// CHECK-NEXT: tosa.conv2d +// CHECK-NEXT: return +// CHECK: func @test_fusion5 +// CHECK-NEXT: call @test_fusion5__part_1 +// CHECK-NEXT: call @test_fusion5__part_0 +func.func @test_fusion5(%arg0: tensor<128x8x32x32xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>, %arg3: tensor<128x8x32x32xf32>, %arg4: tensor<128x8x3x3xf32>, %arg5: tensor<8xf32>) -> tensor<128x128x30x30xf32> { + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32> + %1 = "tosa.conv2d"(%arg3, %arg4, %arg5) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32> + %2 = "tosa.abs"(%0) {} : (tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32> + %3 = "tosa.add"(%2, %1) {} : (tensor<128x128x30x30xf32>, tensor<128x128x30x30xf32>) -> tensor<128x128x30x30xf32> + return %3 : tensor<128x128x30x30xf32> +} + + +// CHECK-LABEL: func private @test_fusion6__part_0 +// CHECK-NEXT: tosa.conv2d +// CHECK-NEXT: return +// CHECK: func @test_fusion6 +// CHECK-NEXT: call @test_fusion6__part_0 +// CHECK-NEXT: return +func.func @test_fusion6(%arg0: tensor<128x8x32x32xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>) -> tensor<128x128x30x30xf32> { + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32> + return %0 : tensor<128x128x30x30xf32> +} + +// CHECK-LABEL: func private @test_fusion7__part_0 +// CHECK-NEXT: tosa.abs +// CHECK-NEXT: tosa.conv2d +// CHECK-NEXT: return +// CHECK: func @test_fusion7 +// CHECK-NEXT: call @test_fusion7__part_0 +// CHECK-NEXT: return +func.func @test_fusion7(%arg0: tensor<128x8x32x32xf32>, %arg1: tensor<128x8x3x3xf32>, %arg2: tensor<8xf32>) -> tensor<128x128x30x30xf32> { + %0 = "tosa.abs"(%arg0) {} : (tensor<128x8x32x32xf32>) -> tensor<128x8x32x32xf32> + %1 = "tosa.conv2d"(%0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<128x8x32x32xf32>, tensor<128x8x3x3xf32>, tensor<8xf32>) -> tensor<128x128x30x30xf32> + return %1 : tensor<128x128x30x30xf32> +} diff --git a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt @@ -14,4 +14,5 @@ MLIRPass MLIRTosaDialect MLIRTransformUtils + MLIRTosaTransforms ) diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp --- 
a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -18,7 +18,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
 
 #define PASS_NAME "tosa-test-quant-utils"
 
@@ -205,3 +207,78 @@
   PassRegistration<TosaTestQuantUtilAPI>();
 }
 } // namespace mlir
+
+namespace {
+
+class TestTosaPartitionOptionsPass
+    : public PassWrapper<TestTosaPartitionOptionsPass,
+                         OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTosaPartitionOptionsPass)
+
+  StringRef getArgument() const final { return "test-tosa-partition-options"; }
+  StringRef getDescription() const final {
+    return "Tests the programmatic interface to --tosa-partition options.";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tosa::TosaDialect>();
+  }
+
+  TestTosaPartitionOptionsPass() = default;
+  TestTosaPartitionOptionsPass(const TestTosaPartitionOptionsPass &) {}
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+    PassManager pm(module.getContext(), mlir::PassManager::Nesting::Implicit);
+    if (defaultCase) {
+      pm.addPass(createTosaPartition());
+    } else if (depthwiseOnly) {
+      SmallVector<std::string> anchors = {"tosa.depthwise_conv2d"};
+      TosaPartitionOptions options;
+      options.anchorOps = anchors;
+      pm.addPass(createTosaPartition(options));
+    } else if (convOnly) {
+      SmallVector<std::string> anchors = {"tosa.conv2d"};
+      TosaPartitionOptions options;
+      options.anchorOps = anchors;
+      options.partitionTagOpt = "four";
+      pm.addPass(createTosaPartition(options));
+    } else if (attrOne) {
+      // TODO: Once list options can have defaults, use that.
+      SmallVector<std::string> anchors = {"tosa.conv2d", "tosa.matmul",
                                          "tosa.depthwise_conv2d"};
+      TosaPartitionOptions options;
+      options.anchorOps = anchors;
+      options.partitionTagOpt = "one";
+      pm.addPass(createTosaPartition(options));
+    } else if (nofrontArg) {
+      SmallVector<std::string> anchors = {"tosa.depthwise_conv2d"};
+      TosaPartitionOptions options;
+      options.anchorOps = anchors;
+      options.trailingOnly = true;
+      options.partitionTagOpt = "three";
+      pm.addPass(createTosaPartition(options));
+    }
+
+    if (failed(pm.run(module)))
+      signalPassFailure();
+  }
+
+  Option<bool> defaultCase{*this, "default", llvm::cl::desc("Default.")};
+  Option<bool> depthwiseOnly{*this, "depthwise-only",
+                             llvm::cl::desc("Depthwise only.")};
+  Option<bool> convOnly{*this, "conv-only", llvm::cl::desc("Only conv2d.")};
+  Option<bool> attrOne{*this, "attr-one",
+                       llvm::cl::desc("Attribute-name 'one'.")};
+  Option<bool> nofrontArg{*this, "nofront-arg",
+                          llvm::cl::desc("Nofront as arg.")};
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestTosaPartitionOptionsPass() {
+  PassRegistration<TestTosaPartitionOptionsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -37,6 +37,7 @@
   MLIRTestRewrite
   MLIRTestTransformDialect
   MLIRTestTransforms
+  MLIRTosaTestPasses
   MLIRTilingInterfaceTestPasses
   MLIRVectorTestPasses
   MLIRLLVMTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@
 void registerTestTensorTransforms();
 void registerTestTilingInterface();
 void registerTestTopologicalSortAnalysisPass();
+void registerTestTosaPartitionOptionsPass();
 void registerTestTransformDialectEraseSchedulePass();
 void 
registerTestTransformDialectInterpreterPass(); void registerTestVectorLowerings(); @@ -157,8 +158,8 @@ registerTestSpirvEntryPointABIPass(); registerTestSpirvModuleCombinerPass(); registerTestTraitsPass(); - registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + registerVectorizerTestPass(); mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); @@ -223,6 +224,7 @@ mlir::test::registerTestTensorTransforms(); mlir::test::registerTestTilingInterface(); mlir::test::registerTestTopologicalSortAnalysisPass(); + mlir::test::registerTestTosaPartitionOptionsPass(); mlir::test::registerTestTransformDialectEraseSchedulePass(); mlir::test::registerTestTransformDialectInterpreterPass(); mlir::test::registerTestVectorLowerings();