Index: mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
@@ -0,0 +1,64 @@
+//===- ShapeMappingAnalysis.h - Preserve shape Info ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+#define MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+namespace shape {
+
+/// ShapeMappingValue works as the value of ShapeMappingAnalysis table, where
+/// `funcSymbol` is the symbol of mapping function, and `inputs` are the actual
+/// parameters for the function.
+struct ShapeMappingValue {
+  ShapeMappingValue() = default;
+  ShapeMappingValue(FlatSymbolRefAttr symbol, llvm::SmallVector<Value> &&inps)
+      : funcSymbol(symbol), inputs(std::move(inps)) {}
+
+  FlatSymbolRefAttr funcSymbol;
+  llvm::SmallVector<Value> inputs;
+};
+
+/// ShapeMappingAnalysis is used together with OutlineShapeComputationPass to
+/// preserve the mapping from a Value to its shape function and arguments.
+/// `print` walks a DenseMap, so entries are not emitted in IR order.
+struct ShapeMappingAnalysis {
+  explicit ShapeMappingAnalysis(Operation *op) : operation(op) {
+    (void)operation; // Silence -Wunused-private-field.
+  }
+
+  /// Dumps the shape mapping information to the given stream.
+  void print(raw_ostream &os) const {
+    os << "// ---- Shape Mapping Infomation -----\n";
+    for (const auto &it : shapeMapping) {
+      const ShapeMappingValue &mappingValue = it.second;
+      os << "// - Shape for: " << it.first << "\n";
+      os << "// - Shape function symbol: " << mappingValue.funcSymbol << "\n";
+      os << "// - Shape function arguments: \n";
+      for (Value v : mappingValue.inputs)
+        os << "// - " << v << "\n";
+      os << "// - End of one shape mapping item -----\n\n";
+    }
+  }
+
+  llvm::DenseMap<Value, ShapeMappingValue> shapeMapping;
+
+private:
+  Operation *operation;
+};
+
+} // namespace shape
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
Index: mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
===================================================================
--- mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,7 @@
 namespace mlir {
 class ConversionTarget;
+class ModuleOp;
 class TypeConverter;
 namespace func {
 class FuncOp;
@@ -53,6 +54,10 @@
 // level.
 std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
 
+/// Outline the shape computation part by adding shape.func and populate
+/// the corresponding mapping information into ShapeMappingAnalysis.
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
===================================================================
--- mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,6 +11,94 @@
 
 include "mlir/Pass/PassBase.td"
 
+def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
+  let summary = "Using shape.func to preserve shape computation";
+  let description = [{
+    This pass outlines the shape computation part in high level IR by adding
+    shape.func and populates the corresponding mapping information into
+    ShapeMappingAnalysis. The shape computation part is usually introduced by
+    shape reification, and each single dynamic shape is denoted by shape.with_shape.
+
+    There are two main reasons this shape-outline pass is needed:
+    1. Many passes don't take the shape reification part into consideration.
+       Therefore we need to "remove" the shape reification part temporarily for
+       these passes.
+    2. Sometimes we cannot redo shape reification after converting from dialect
+       A to dialect B, because op-level shape reification is only implemented
+       on A.
+
+    Input:
+    ```mlir
+    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
+      tensor<?x4x?xf32> {
+      %c2 = arith.constant 2 : index
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+      %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+      %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
+      %4 = shape.value_of %3 : tensor<?x4x?xf32>
+      %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
+            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+      %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+      %7 = arith.addi %6, %c2 : index
+      %8 = shape.from_extents %7, %c4, %1 : index, index, index
+      %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
+      %10 = shape.value_of %9 : tensor<?x4x?xf32>
+      return %10 : tensor<?x4x?xf32>
+    }
+    ```
+
+    Output:
+    ```mlir
+    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
+      tensor<?x4x?xf32> {
+      %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+      %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
+            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+      return %1 : tensor<?x4x?xf32>
+    }
+    shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
+      %c2 = arith.constant 2 : index
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      %1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
+      %2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
+      %3 = arith.addi %2, %c2 : index
+      %4 = from_extents %3, %c4, %1 : index, index, index
+      return %4 : !shape.shape
+    }
+    shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
+      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      return %0 : tensor<3xindex>
+    }
+    ```
+
+    For the above example, the shape computation is inlined in the input IR,
+    where it is used for the shapes of two values (test.abs and test.concat).
+    In the output IR, the shape computation part is outlined.
+
+    The shape mapping information will be:
+    // ---- Shape Mapping Infomation -----
+    // - Shape for: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+    // - Shape function symbol: @shape_cal_0
+    // - Shape function arguments:
+    // - <block argument> of type 'tensor<?x4x?xf32>' at index: 0
+    // - End of one shape mapping item -----
+
+    // - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+    // - Shape function symbol: @shape_cal_1
+    // - Shape function arguments:
+    // - <block argument> of type 'tensor<?x4x?xf32>' at index: 0
+    // - End of one shape mapping item -----
+
+  }];
+  let constructor = "mlir::createOutlineShapeComputationPass()";
+  let dependentDialects = ["shape::ShapeDialect"];
+}
+
 def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
   let summary = "Replace all cstr_ ops with a true witness";
   let constructor = "mlir::createRemoveShapeConstraintsPass()";
Index: mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  OutlineShapeComputation.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
 
Index: mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -0,0 +1,319 @@
+//====----- OutlineShapeComputation.cpp -------------------------*- C++-*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+#include <queue>
+#include <string>
+#include <vector>
+
+namespace mlir {
+#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "outline-shape-computation"
+
+using namespace mlir;
+
+namespace {
+
+// A Value is an input of the cluster if it is an operand of an operation in the
+// cluster and its defining operation is not in the cluster. Returned by value:
+// a ValueRange over the local vector would dangle once this function returns.
+SmallVector<Value>
+getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
+  SmallVector<Value> inputs;
+  llvm::SmallDenseSet<Value> inputSet;
+  llvm::SmallDenseSet<Operation *> opSet;
+  for (Operation *op : cluster) {
+    bool inserted = opSet.insert(op).second;
+    (void)inserted;
+    assert(inserted && "cluster contains duplicate operations");
+  }
+
+  for (Operation *op : cluster) {
+    for (Value operand : op->getOperands()) {
+      // Skip operands defined inside the cluster. Block arguments have no
+      // defining op and therefore always count as inputs.
+      if (opSet.contains(operand.getDefiningOp()))
+        continue;
+      if (inputSet.insert(operand).second)
+        inputs.push_back(operand);
+    }
+  }
+  return inputs;
+}
+
+// Create a shape.func representing the shape computation for `shape`.
+// Returns the function and the values to pass as its arguments.
+std::pair<shape::FuncOp, SmallVector<Value>>
+createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
+                      Value shape, StringRef fnName, Location loc) {
+  SmallVector<Value> inputs = getInputsOfCluster(cluster);
+  // An empty cluster means `shape` is a block argument or is not used only for
+  // shape computation; it then becomes the sole argument and the function
+  // simply returns it.
+  if (cluster.empty())
+    inputs.push_back(shape);
+  auto fnType =
+      b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
+  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
+  Block *block = fnOp.addEntryBlock();
+  b.setInsertionPoint(block, block->end());
+  BlockAndValueMapping bvm;
+  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
+    bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
+
+  for (Operation *op : cluster)
+    b.clone(*op, bvm);
+  llvm::SmallVector<Value, 4> fnReturns;
+  fnReturns.push_back(bvm.lookupOrDefault(shape));
+
+  b.create<shape::ReturnOp>(UnknownLoc::get(b.getContext()), fnReturns);
+  fnOp.setPrivate();
+  return std::make_pair(fnOp, std::move(inputs));
+}
+
+// The operations in each cluster are unordered; return them in `funcOp.walk`
+// order so that cloning them into a shape.func keeps defs before uses.
+DenseMap<Value, SmallVector<Operation *, 8>>
+getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
+                   func::FuncOp funcOp) {
+  // Compute all clusters that each operation is in
+  DenseMap<Operation *, SmallVector<Value>> op2Shapes;
+  for (const auto &it : clusters) {
+    Value shape = it.first;
+    const DenseSet<Operation *> &cluster = it.second;
+    for (Operation *cOp : cluster)
+      op2Shapes[cOp].push_back(shape);
+  }
+
+  // Iterate through all operations in order. Get all the clusters `op` belongs
+  // to and construct the new ordered cluster as it traverses.
+  DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
+  funcOp.walk([&](Operation *op) {
+    auto it = op2Shapes.find(op);
+    if (it == op2Shapes.end())
+      return;
+    for (Value shape : it->second)
+      orderedClusters[shape].push_back(op);
+  });
+
+  return orderedClusters;
+}
+
+// For each shape.with_shape on a ranked tensor, outline the computation of its
+// shape into a private shape.func, reusing the function already created for the
+// same shape value, and record the mapping in `shapeMappingAnalysis`.
+void constructShapeFunc(
+    const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
+    DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
+    SymbolTable &symbolTable,
+    DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
+    func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
+  std::string shapeCalculationNamePrefix = "shape_cal_";
+  int shapeCalculationNameIdx = 0;
+  OpBuilder builder(context);
+
+  // Construct (or reuse) a shape function for each with_shape op.
+  for (shape::WithOp withOp : allWithOps) {
+    Value value = withOp.getOperand();
+    Value shape = withOp.getShape();
+    if (!value.getType().isa<RankedTensorType>())
+      continue;
+
+    const SmallVector<Operation *, 8> &cluster = clusters[shape];
+    shape::ShapeMappingValue shapeMappingValue;
+    auto it = dynShape2ShapeFunc.find(shape);
+    if (it == dynShape2ShapeFunc.end()) {
+      std::string name = shapeCalculationNamePrefix +
+                         std::to_string(shapeCalculationNameIdx++);
+      Location loc = value.getLoc();
+      builder.setInsertionPointAfter(funcOp);
+      auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
+      shape::FuncOp shapeFuncOp = pair.first;
+      StringAttr insertedName = symbolTable.insert(shapeFuncOp);
+      shapeMappingValue.funcSymbol =
+          FlatSymbolRefAttr::get(context, insertedName);
+      shapeMappingValue.inputs = std::move(pair.second);
+    } else {
+      shapeMappingValue = it->second;
+    }
+    dynShape2ShapeFunc[shape] = shapeMappingValue;
+    shapeMappingAnalysis.shapeMapping.insert(
+        std::make_pair(value, shapeMappingValue));
+  }
+}
+
+struct OutlineShapeComputationPass
+    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
+  void runOnOperation() override;
+
+private:
+  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
+  void getClusterFromValue(Value shape,
+                           DenseMap<Value, DenseSet<Operation *>> &clusters);
+  DenseMap<Value, SmallVector<Operation *, 8>>
+  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
+                                func::FuncOp funcOp);
+  // Memoized results of calOnlyUsedByWithShapesRecursively for one function.
+  DenseMap<Operation *, bool> onlyUsedByWithShapes;
+};
+
+// Rewrite tensor.dim into shape.shape_of + shape.get_extent.
+class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto shapeOf =
+        rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
+    rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
+                                                    op.getIndex());
+    return success();
+  }
+};
+
+void OutlineShapeComputationPass::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  SymbolTable symbolTable(moduleOp);
+  DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
+  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
+  // Reset stale entries and keep the analysis cached for later passes.
+  shapeMappingAnalysis.shapeMapping.clear();
+  markAnalysesPreserved<shape::ShapeMappingAnalysis>();
+
+  moduleOp.walk([&](func::FuncOp funcOp) {
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet prevPatterns(context);
+    prevPatterns.insert<TensorDimOpRewriter>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
+      return signalPassFailure();
+
+    // initialize class member `onlyUsedByWithShapes`
+    onlyUsedByWithShapes.clear();
+    funcOp.walk([&](Operation *op) {
+      calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
+    });
+    // clang-format off
+    LLVM_DEBUG(llvm::dbgs() << "onlyUsedByWithShapes table: \n";
+               for (auto it : onlyUsedByWithShapes)
+                 if (it.second)
+                   llvm::dbgs() << *(it.first) << ": " << it.second << "\n";);
+    // clang-format on
+
+    // collect all the shape.with_shape ops.
+    std::vector<shape::WithOp> allWithOps;
+    funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
+
+    DenseMap<Value, SmallVector<Operation *, 8>> clusters =
+        constructClustersForEachShape(allWithOps, funcOp);
+    constructShapeFunc(allWithOps, context, clusters, symbolTable,
+                       dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
+
+    // Bypass shape.value_of users of each with_shape; other users are skipped.
+    for (shape::WithOp withOp : allWithOps) {
+      Value value = withOp.getOperand();
+      for (Operation *user : withOp.getResult().getUsers())
+        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user))
+          valueOf.getResult().replaceAllUsesExcept(value, withOp);
+    }
+
+    // Erase the now-dead shape computation (DCE) and fold.
+    if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
+      return signalPassFailure();
+  });
+}
+
+// One cluster per distinct shape of `allWithOps`, ordered as in `funcOp`.
+DenseMap<Value, SmallVector<Operation *, 8>>
+OutlineShapeComputationPass::constructClustersForEachShape(
+    const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
+  DenseMap<Value, DenseSet<Operation *>> clusters;
+  for (shape::WithOp withOp : allWithOps) {
+    Value shape = withOp.getShape();
+    if (!clusters.contains(shape))
+      getClusterFromValue(shape, clusters);
+  }
+  return getOrderedClusters(clusters, funcOp);
+}
+
+// The output of a cluster is the `shape`, and the inputs are the outputs of
+// operations that are not marked in `onlyUsedByWithShapes`.
+void OutlineShapeComputationPass::getClusterFromValue(
+    Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
+  DenseSet<Operation *> cluster;
+  Operation *defOp = shape.getDefiningOp();
+
+  DenseSet<Operation *> visited;
+  std::queue<Operation *> queue;
+  // defOp == nullptr means shape is a block argument (e.g. of the func op)
+  if (defOp != nullptr) {
+    visited.insert(defOp);
+    queue.push(defOp);
+  }
+  while (!queue.empty()) {
+    Operation *op = queue.front();
+    queue.pop();
+    if (onlyUsedByWithShapes.lookup(op)) {
+      cluster.insert(op);
+      for (Value inp : op->getOperands()) {
+        Operation *inpDefOp = inp.getDefiningOp();
+        if (inpDefOp && visited.insert(inpDefOp).second)
+          queue.push(inpDefOp);
+      }
+    }
+  }
+
+  clusters[shape] = std::move(cluster);
+}
+
+// Returns true if `op` is a shape.with_shape whose shape is `prevOutput`, or if
+// every user of `op` eventually only feeds shape operands of shape.with_shape.
+bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
+    Operation *op, Value prevOutput) {
+  auto it = onlyUsedByWithShapes.find(op);
+  if (it != onlyUsedByWithShapes.end())
+    return it->second;
+
+  // Not memoized: the answer depends on which operand `prevOutput` feeds.
+  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
+    return withOp.getShape() == prevOutput;
+
+  if (op->use_empty()) {
+    onlyUsedByWithShapes[op] = false;
+    return false;
+  }
+
+  bool allUsers = true;
+  for (Value oup : op->getResults())
+    for (Operation *user : oup.getUsers())
+      allUsers &= calOnlyUsedByWithShapesRecursively(user, oup);
+
+  onlyUsedByWithShapes[op] = allUsers;
+  return allUsers;
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createOutlineShapeComputationPass() {
+  return std::make_unique<OutlineShapeComputationPass>();
+}
Index: mlir/test/Dialect/Shape/outline-shape-computation.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Shape/outline-shape-computation.mlir
@@ -0,0 +1,205 @@
+// RUN: mlir-opt -outline-shape-computation -split-input-file -allow-unregistered-dialect %s 2>&1 | FileCheck %s
+
+// two dynamic shapes: one is a direct shape.shape_of(arg), the other is computed from it
+func.func @main(%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex>
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+  %2 = "test.abs"(%arg0) : (tensor) -> tensor
+  %3 = shape.with_shape %2, %0 : tensor, tensor<3xindex>
+  %4 = shape.value_of %3 : tensor
+  %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor
+  %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+  %7 = arith.addi %6, %c2 : index
+  %8 = shape.from_extents %7, %c4, %1 : index, index, index
+  %9 = shape.with_shape %5, %8 : tensor, !shape.shape
+  %10 = shape.value_of %9 : tensor
+  return %10 : tensor
+}
+
+// CHECK-LABEL:
func.func @main
+// CHECK-NEXT: %0 = "test.abs"(%arg0) : (tensor) -> tensor
+// CHECK-NEXT: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor
+// CHECK-NEXT: return %1 : tensor
+
+// CHECK: shape.func private @shape_cal_1(%arg0: tensor) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
+// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
+// CHECK-DAG: return %[[V4]] : !shape.shape
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor) -> tensor<3xindex> {
+// CHECK-DAG: %0 = shape_of %arg0 : tensor -> tensor<3xindex>
+// CHECK-DAG: return %0 : tensor<3xindex>
+
+// -----
+
+// two dynamic shapes and they share the same shape.func
+func.func @main(%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex>
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor
+  %3 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+  %4 = arith.addi %3, %c2 : index
+  %5 = shape.from_extents %4, %c4, %1 : index, index, index
+  %6 = shape.with_shape %2, %5 : tensor, !shape.shape
+  %7 = shape.value_of %6 : tensor
+  %8 = "test.abs"(%7) : (tensor) -> tensor
+  %9 = shape.with_shape %8, %5 : tensor, !shape.shape
+  %10 = shape.value_of %9 : tensor
+  return %10 : tensor
+}
+// CHECK-LABEL: func.func @main
+// CHECK-NEXT: %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor
+// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor
+// CHECK-NEXT: return %1 : tensor
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
+// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
+// CHECK-DAG: return %[[V4]] : !shape.shape
+// CHECK-NOT: shape_cal_1
+
+// -----
+
+// there's an internal dynamic shape source, and two other dynamic shapes share it
+func.func @main(%arg0: tensor) -> tensor {
+  %0 = "test.nonzero"(%arg0) : (tensor) -> tensor
+  %1 = shape.shape_of %0 : tensor -> tensor<1xindex>
+  %2 = shape.with_shape %0, %1 : tensor, tensor<1xindex>
+  %3 = shape.value_of %2 : tensor
+  %4 = "test.abs"(%3) : (tensor) -> tensor
+  %5 = shape.with_shape %4, %1 : tensor, tensor<1xindex>
+  %6 = shape.value_of %5 : tensor
+  %7 = "test.negate"(%6) : (tensor) -> tensor
+  %8 = shape.with_shape %7, %1 : tensor, tensor<1xindex>
+  %9 = shape.value_of %8 : tensor
+  return %9 : tensor
+}
+// CHECK-LABEL: func.func @main
+// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor) -> tensor
+// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor
+// CHECK-NEXT: %2 = "test.negate"(%1) : (tensor) -> tensor
+// CHECK-NEXT: return %2 : tensor
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor) -> tensor<1xindex> {
+// CHECK-NEXT: %0 = shape_of %arg0 : tensor -> tensor<1xindex>
+// CHECK-NEXT: return %0 : tensor<1xindex>
+// CHECK-NOT: shape_cal_1
+
+// -----
+
+// there's only a return op in the constructed shape.func
+func.func @main(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor {
+  %0 = "test.nonzero"(%arg0) : (tensor) -> tensor
+  %1 = shape.with_shape %0, %arg1 : tensor, tensor<1xindex>
+  %2 = shape.value_of %1 : tensor
+  return %2 : tensor
+}
+// CHECK-LABEL: func.func @main(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor {
+// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor) -> tensor
+// CHECK-NEXT: return %0 : tensor
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor<1xindex>) -> tensor<1xindex> {
+// CHECK-NEXT: return %arg0 : tensor<1xindex>
+
+// -----
+
+// shape computation part interleaves with general computation
+func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, index) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex>
+  %1 = shape.shape_of %arg1 : tensor -> tensor<3xindex>
+  %2 = shape.shape_of %arg2 : tensor -> tensor<3xindex>
+  %3 = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor
+  %4 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+  %5 = shape.get_extent %1, %c0 : tensor<3xindex>, index -> index
+  %6 = shape.get_extent %2, %c0 : tensor<3xindex>, index -> index
+  %7 = arith.addi %4, %5 : index
+  %8 = arith.addi %7, %6 : index
+  %9 = shape.from_extents %8, %c4, %c5 : index, index, index
+  %10 = shape.with_shape %3, %9 : tensor, !shape.shape
+  %11 = shape.value_of %10 : tensor
+  return %11, %7 : tensor, index
+}
+// CHECK-LABEL: func.func @main
+// CHECK-DAG: %[[V0:.*]] = shape.shape_of %arg0 : tensor -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex>
+// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor
+// CHECK-DAG: %[[V3:.*]] = shape.get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V4:.*]] = shape.get_extent %[[V1]], %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : index
+// CHECK-DAG: return %[[V2]], %[[V5]] : tensor, index
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor, %arg1: index, %arg2: index) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<3xindex>
+// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %arg1 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V2:.*]] = arith.addi %arg2, %[[V1]] : index
+// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %c5 : index, index, index
+// CHECK-DAG: return %[[V3]] : !shape.shape
+
+// -----
+
+// there are multiple reused shape computations
+func.func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor, tensor, tensor) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %0 = shape.shape_of %arg0 : tensor -> tensor<2xindex>
+  %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex>
+  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor) -> tensor
+  %3 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor) -> tensor
+  %4 = shape.get_extent %0, %c0 : tensor<2xindex>, index -> index
+  %5 = shape.get_extent %1, %c0 : tensor<2xindex>, index -> index
+  %6 = arith.addi %4, %5 : index
+  %7 = shape.from_extents %6, %c4 : index, index
+  %8 = shape.with_shape %2, %7 : tensor, !shape.shape
+  %9 = shape.with_shape %3, %7 : tensor, !shape.shape
+  %10 = shape.value_of %8 : tensor
+  %11 = shape.value_of %9 : tensor
+  %12 = "test.concat"(%arg0, %2) {axis = 0 : i64} : (tensor, tensor) -> tensor
+  %13 = "test.concat"(%arg0, %3) {axis = 0 : i64} : (tensor, tensor) -> tensor
+  %14 = arith.addi %6, %4 : index
+  %15 = shape.from_extents %14, %c4 : index, index
+  %16 = shape.with_shape %12, %15 : tensor, !shape.shape
+  %17 = shape.with_shape %13, %15 : tensor, !shape.shape
+  %18 = shape.value_of %16 : tensor
+  %19 = shape.value_of %17 : tensor
+  return %10, %11, %18, %19 : tensor, tensor, tensor, tensor
+}
+// CHECK-LABEL: func.func @main
+// CHECK-DAG: %[[V0:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor) -> tensor
+// CHECK-DAG: %[[V1:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor) -> tensor
+// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %[[V0]]) {axis = 0 : i64} : (tensor, tensor) -> tensor
+// CHECK-DAG: %[[V3:.*]] = "test.concat"(%arg0, %[[V1]]) {axis = 0 : i64} : (tensor, tensor) -> tensor
+// CHECK-DAG: return %[[V0]], %[[V1]], %[[V2]], %[[V3]] : tensor, tensor, tensor, tensor
+
+// CHECK: shape.func private @shape_cal_1(%arg0: tensor, %arg1: tensor) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<2xindex>
+// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor -> tensor<2xindex>
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
+// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V4]], %[[V2]] : index
+// CHECK-DAG: %[[V6:.*]] = from_extents %[[V5]], %c4 : index, index
+// CHECK-DAG: return %[[V6]] : !shape.shape
+
+// CHECK: shape.func private @shape_cal_0(%arg0: tensor, %arg1: tensor) -> !shape.shape {
+// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<2xindex>
+// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor -> tensor<2xindex>
+// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
+// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
+// CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
+// CHECK-DAG: return %[[V5]] : !shape.shape
Index: mlir/test/lib/Dialect/Shape/CMakeLists.txt
===================================================================
--- mlir/test/lib/Dialect/Shape/CMakeLists.txt
+++ mlir/test/lib/Dialect/Shape/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRShapeTestPasses
   TestShapeFunctions.cpp
+  TestShapeMappingAnalysis.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -11,6 +12,7 @@
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRPass
   MLIRShapeDialect
+  MLIRShapeOpsTransforms
   MLIRSupport
   )
Index: mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp
===================================================================
--- /dev/null
+++ mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp
@@ -0,0 +1,43 @@
+//===- TestShapeMappingAnalysis.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Prints the ShapeMappingAnalysis cached on the module by an earlier pass.
+struct TestShapeMappingPass
+    : public PassWrapper<TestShapeMappingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShapeMappingPass)
+
+  StringRef getArgument() const final { return "test-print-shape-mapping"; }
+  StringRef getDescription() const final {
+    return "Print the contents of a constructed shape mapping information.";
+  }
+  void runOnOperation() override {
+    llvm::Optional<std::reference_wrapper<shape::ShapeMappingAnalysis>>
+        maybeAnalysis = getCachedAnalysis<shape::ShapeMappingAnalysis>();
+    if (maybeAnalysis.has_value())
+      maybeAnalysis.value().get().print(llvm::errs());
+    else
+      llvm::errs() << "No cached ShapeMappingAnalysis existed.\n";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestShapeMappingPass() {
+  PassRegistration<TestShapeMappingPass>();
+}
+} // namespace test
+} // namespace mlir
Index: mlir/tools/mlir-opt/mlir-opt.cpp
===================================================================
--- mlir/tools/mlir-opt/mlir-opt.cpp
+++ mlir/tools/mlir-opt/mlir-opt.cpp
@@ -109,6 +109,7 @@
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
 void registerTestTilingInterface();
@@ -208,6 +209,7 @@
   mlir::test::registerTestPDLLPasses();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTilingInterface();