Index: mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h
@@ -0,0 +1,61 @@
+//===- ShapeMappingAnalysis.h - Preserve shape Info ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+#define MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+namespace shape {
+
+struct ShapeMappingValue {
+  ShapeMappingValue() = default;
+  ShapeMappingValue(FlatSymbolRefAttr symbol, llvm::SmallVector<Value> &&inps)
+      : funcSymbol(symbol), inputs(inps) {}
+
+  FlatSymbolRefAttr funcSymbol;
+  llvm::SmallVector<Value> inputs;
+};
+
+/// ShapeMappingAnalysis is used together with OutlineShapeComputationPass to
+/// preserve the mapping from a Value to its corresponding shape function and
+/// the arguments of that function.
+struct ShapeMappingAnalysis {
+
+  ShapeMappingAnalysis(Operation *op) : operation(op) {}
+
+  /// Dumps the shape mapping information to the given stream.
+  void print(raw_ostream &os) const {
+    os << "// ---- Shape Mapping Information -----\n";
+    for (const auto &it : shapeMapping) {
+      const ShapeMappingValue &mappingValue = it.second;
+      os << "// - Shape for: " << it.first << "\n";
+      os << "// - Shape function symbol: " << mappingValue.funcSymbol << "\n";
+      os << "// - Shape function arguments: \n";
+      for (Value v : mappingValue.inputs) {
+        os << "//     - " << v << "\n";
+      }
+      os << "// - End of one shape mapping item -----\n\n";
+    }
+  }
+
+  llvm::DenseMap<Value, ShapeMappingValue> shapeMapping;
+
+private:
+  Operation *operation;
+};
+
+} // namespace shape
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SHAPE_ANALYSIS_SHAPEMAPPINGANALYSIS_H_
Index: mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
===================================================================
--- mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -58,10 +58,9 @@
 // level.
 std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
 
-// Outline the shape computation part by adding shape.func and corresponding
-// symbols mapping to it.
-std::unique_ptr<OperationPass<ModuleOp>>
-createOutlineShapeComputationPass(const std::string &entryFunc = "");
+/// Outline the shape computation part by adding shape.func and populate the
+/// corresponding mapping information into ShapeMappingAnalysis.
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
 
 //===----------------------------------------------------------------------===//
 // Registration
Index: mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
===================================================================
--- mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -15,9 +15,9 @@
   let summary = "Using shape.func to preserve shape computation";
   let description = [{
     This pass outlines the shape computation part in high level IR by adding
-    shape.func and corresponding symbols mapping to it. The shape computation
-    part is usually introduced by shape reification, and each single dynamic
-    shape is denoted by shape.with_shape.
+    shape.func and populates the corresponding mapping information into
+    ShapeMappingAnalysis. The shape computation part is usually introduced by
+    shape reification, and each single dynamic shape is denoted by shape.with_shape.
 
     There're two main reasons this shape-outline pass is needed:
     1. Many passes don't take shape reification part into consideration.
@@ -26,17 +26,6 @@
     2. Sometimes we cannot redo shape reification after converting from dialect
     A to dialect B. Because op-level shape reification is only implemented on A.
-
-    // TODO: what if the `encoding` attribute before the pass?
-    The symbol(s) is placed in the `encoding` attribute of RankedTensorType. If
-    the shape of the tensor is or equals to a dynamic shape source, it is a
-    single FlatSymbolRefAttr. Else it is a ArrayAttr, where the first element
-    represents corresponding shape function name, the others represents its
-    arguments.
-
-    // TODO: support all functions after lenient call is supported.
-    Currently this pass is only limited to entry function. And internal
-    func.call is also not supported.
 
     Input:
     ```mlir
@@ -64,39 +53,49 @@
     Output
     ```mlir
    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
-      tensor {
-      %0 = "test.abs"(%arg0) : (tensor) -> tensor
-      %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor
-      return %1 : tensor
+      tensor<?x4x?xf32> {
+      %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+      %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
+        tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+      return %1 : tensor<?x4x?xf32>
    }
-    shape.func private @shape_cal_0(%arg0: tensor<3xindex>) -> !shape.shape {
+    shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
      %c2 = arith.constant 2 : index
      %c0 = arith.constant 0 : index
      %c4 = arith.constant 4 : index
-      %0 = get_extent %arg0, %c2 : tensor<3xindex>, index -> index
-      %1 = get_extent %arg0, %c0 : tensor<3xindex>, index -> index
-      %2 = arith.addi %1, %c2 : index
-      %3 = from_extents %2, %c4, %0 : index, index, index
-      return %3 : !shape.shape
+      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      %1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
+      %2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
+      %3 = arith.addi %2, %c2 : index
+      %4 = from_extents %3, %c4, %1 : index, index, index
+      return %4 : !shape.shape
+    }
+    shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
+      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+      return %0 : tensor<3xindex>
+    }
    ```
 
    For the above example, the shape computation is inlined in the input IR,
    which is used for two values' (test.abs and test.concat) shape. And the shape
-    compuatation part is outlined in the output IR.
+    computation part is outlined in the output IR.
+
+    And the shape mapping information will be:
+      // ---- Shape Mapping Information -----
+      // - Shape for: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+      // - Shape function symbol: @shape_cal_0
+      // - Shape function arguments:
+      //     - <block argument> of type 'tensor<?x4x?xf32>' at index: 0
+      // - End of one shape mapping item -----
+
+      // - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+      // - Shape function symbol: @shape_cal_1
+      // - Shape function arguments:
+      //     - <block argument> of type 'tensor<?x4x?xf32>' at index: 0
+      // - End of one shape mapping item -----
+
  }];
 
  let constructor = "mlir::createOutlineShapeComputationPass()";
-  let options = [
-   Option<"entryFunc", "entry-func", "std::string",
-           /*default=*/"\"main\"",
-           "Used to specify the entry-function.">,
-  ];
  let dependentDialects = ["shape::ShapeDialect"];
 }
Index: mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
===================================================================
--- mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -63,19 +64,21 @@
 std::pair<shape::FuncOp, SmallVector<Value>>
 createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *> &cluster,
                       Value shape, StringRef fnName, Location loc) {
-  assert(!cluster.empty() &&
-         "There must be at least one element in the cluster!");
   ValueRange inputs = getInputsOfCluster(cluster);
-  auto fnType = b.getFunctionType(inputs.getTypes(), shape.getType());
-  b.setInsertionPointAfter(cluster[0]->getParentOp());
+  auto fnType = cluster.empty()
+                    ? b.getFunctionType(shape.getType(), shape.getType())
+                    : b.getFunctionType(inputs.getTypes(), shape.getType());
  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
  Block *block = fnOp.addEntryBlock();
  b.setInsertionPoint(block, block->end());
  BlockAndValueMapping bvm;
-  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments())) {
-    bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
-  }
+  if (cluster.empty())
+    bvm.map(shape, fnOp.getArgument(0));
+  else
+    for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments())) {
+      bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
+    }
+
  for (Operation *op : cluster) {
    b.clone(*op, bvm);
  }
@@ -118,119 +121,52 @@
  return orderedClusters;
 }
 
-// Increment `idx` until find the next available symbol name
-std::string
-getNextAvailableSymbolName(const std::string &prefix, int &idx,
-                           std::unordered_set<std::string> &usedSymbolNames) {
-  std::string name = prefix + std::to_string(idx++);
-
-  while (usedSymbolNames.count(name)) {
-    name = prefix + std::to_string(idx++);
-  }
-  usedSymbolNames.insert(name);
-  return name;
-}
-
-// return argument index if `shape` is the output of a
-// shape.shape_of(func_arg), else return -1.
-llvm::Optional<int> getShapeOfFuncArgIdx(Value shape, func::FuncOp funcOp) {
-  shape::ShapeOfOp shapeOfOp = shape.getDefiningOp<shape::ShapeOfOp>();
-  if (shapeOfOp == nullptr)
-    return llvm::None;
-  Value inp = shapeOfOp.getArg();
-  for (int i = 0; i < int(funcOp.getNumArguments()); ++i) {
-    if (funcOp.getArgument(i) == inp)
-      return i;
-  }
-
-  return llvm::None;
-}
-
 void constructShapeFunc(
    const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
    DenseMap<Value, SmallVector<Operation *>> &clusters,
-    SymbolTable &symbolTable, func::FuncOp funcOp,
-    std::unordered_set<std::string> &usedSymbolNames,
-    DenseMap<Value, FlatSymbolRefAttr> &dynShapeSrc2Symbol,
-    DenseMap<Value, SmallVector<FlatSymbolRefAttr>> &dynShape2ShapeFunc) {
-  std::string dynamicSourceNamePrefix = "s";
-  int dynamicSourceNameIdx = 0;
+    SymbolTable &symbolTable,
+    DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
+    func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
  std::string shapeCalculationNamePrefix = "shape_cal_";
  int shapeCalculationNameIdx = 0;
  OpBuilder builder(context);
 
-  auto getOrConstructSymbolFromShape = [&](Value shape) {
-    auto symbolIt = dynShapeSrc2Symbol.find(shape);
-    if (symbolIt == dynShapeSrc2Symbol.end()) {
-      std::string name;
-      llvm::Optional<int> index = getShapeOfFuncArgIdx(shape, funcOp);
-      if (index.has_value())
-        name = "arg_" + std::to_string(index.value());
-      else
-        name = getNextAvailableSymbolName(
-            dynamicSourceNamePrefix, dynamicSourceNameIdx, usedSymbolNames);
-
-      auto symbol = FlatSymbolRefAttr::get(context, name);
-      dynShapeSrc2Symbol[shape] = symbol;
-      return symbol;
-    } else {
-      return symbolIt->second;
-    }
-  };
-
-  // Construct a shape function or a symbol for each cluster
+  // Construct a shape function
  for (shape::WithOp withOp : allWithOps) {
    Value value = withOp.getOperand();
    Value shape = withOp.getShape();
    RankedTensorType rankedType = value.getType().dyn_cast<RankedTensorType>();
    if (rankedType == nullptr)
      continue;
+
    const SmallVector<Operation *> &cluster = clusters[shape];
-    // The cluster is empty when the shape is equal to a dynamic shape source
-    if (cluster.empty()) {
-      FlatSymbolRefAttr symbol = getOrConstructSymbolFromShape(shape);
-      value.setType(RankedTensorType::get(rankedType.getShape(),
-                                          rankedType.getElementType(), symbol));
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Symbol for " << shape << ": " << symbol << "\n");
+    shape::ShapeMappingValue shapeMappingValue;
+    auto it = dynShape2ShapeFunc.find(shape);
+    if (it == dynShape2ShapeFunc.end()) {
+      std::string name = shapeCalculationNamePrefix +
+                         std::to_string(shapeCalculationNameIdx++);
+      Location loc = value.getLoc();
+      builder.setInsertionPointAfter(funcOp);
+      auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
+      const SmallVector<Value> &inputs = pair.second;
+      shape::FuncOp shapeFuncOp = pair.first;
+      StringAttr insertedName = symbolTable.insert(shapeFuncOp);
+      auto symbol = FlatSymbolRefAttr::get(context, insertedName);
+
+      shapeMappingValue.funcSymbol = symbol;
+      shapeMappingValue.inputs = inputs;
    } else {
-      auto it = dynShape2ShapeFunc.find(shape);
-      SmallVector<FlatSymbolRefAttr> symbols;
-      if (it == dynShape2ShapeFunc.end()) {
-        std::string name = getNextAvailableSymbolName(
-            shapeCalculationNamePrefix, shapeCalculationNameIdx,
-            usedSymbolNames);
-        Location loc = value.getLoc();
-        auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
-        const SmallVector<Value> &inputs = pair.second;
-        shape::FuncOp shapeFuncOp = pair.first;
-        StringAttr insertedName = symbolTable.insert(shapeFuncOp);
-        auto symbol = FlatSymbolRefAttr::get(context, insertedName);
-        symbols.push_back(symbol);
-        for (Value inp : inputs) {
-          FlatSymbolRefAttr argSymbol = getOrConstructSymbolFromShape(inp);
-          symbols.push_back(argSymbol);
-        }
-        dynShape2ShapeFunc[shape] = symbols;
-      } else {
-        symbols = it->second;
-      }
-      auto arrayAttr = ArrayAttr::get(context, symbols);
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Symbol for " << shape << ": " << arrayAttr << "\n");
-      value.setType(RankedTensorType::get(
-          rankedType.getShape(), rankedType.getElementType(), arrayAttr));
+      shapeMappingValue = it->second;
    }
+    dynShape2ShapeFunc[shape] = shapeMappingValue;
+    shapeMappingAnalysis.shapeMapping.insert(
+        std::make_pair(value, shapeMappingValue));
  }
 }
 
 struct OutlineShapeComputationPass
    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
-  OutlineShapeComputationPass(const std::string &entryFunc) {
-    this->entryFunc = entryFunc;
-  }
-
  void runOnOperation() override;
 
 private:
@@ -259,15 +195,10 @@
 void OutlineShapeComputationPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  SymbolTable symbolTable(moduleOp);
-  // `usedSymbolNames` makes sure the created symbol is unique at Moudle level
-  std::unordered_set<std::string> usedSymbolNames;
-  DenseMap<Value, FlatSymbolRefAttr> dynShapeSrc2Symbol;
-  DenseMap<Value, SmallVector<FlatSymbolRefAttr>> dynShape2ShapeFunc;
+  DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
+  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
 
  moduleOp.walk([&](func::FuncOp funcOp) {
-    if (funcOp.getName() != entryFunc)
-      return;
-
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet prevPatterns(context);
    prevPatterns.insert(context);
@@ -292,8 +223,8 @@
    DenseMap<Value, SmallVector<Operation *>> clusters =
        constructClustersForEachShape(allWithOps, funcOp);
 
-    constructShapeFunc(allWithOps, context, clusters, symbolTable, funcOp,
-                       usedSymbolNames, dynShapeSrc2Symbol, dynShape2ShapeFunc);
+    constructShapeFunc(allWithOps, context, clusters, symbolTable,
+                       dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
 
    for (shape::WithOp withOp : allWithOps) {
      Value value = withOp.getOperand();
@@ -307,9 +238,7 @@
    if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
      return signalPassFailure();
 
-    funcOp.setType(
-        FunctionType::get(context, funcOp.front().getArgumentTypes(),
-                          funcOp.front().getTerminator()->getOperandTypes()));
+    markAllAnalysesPreserved();
  });
 }
 
@@ -325,33 +254,24 @@
  return getOrderedClusters(clusters, funcOp);
 }
 
-// The output of a cluster is the `shape`, and the inputs are either the result
-// of shape.shape_of or function argument.
+// The output of a cluster is the `shape`, and the inputs are the outputs of
+// operations that are not in `onlyUsedByWithShapes_`.
 void OutlineShapeComputationPass::getClusterFromValue(
    Value shape,
    DenseMap<Value, DenseSet<Operation *>> &clusters) {
  DenseSet<Operation *> cluster;
 
  Operation *defOp = shape.getDefiningOp();
-  // defOp == nullptr means shape is the argument of the func op
-  if (nullptr == defOp) {
-    return;
-  }
  DenseSet<Operation *> visited;
  std::queue<Operation *> queue;
-  visited.insert(defOp);
-  queue.push(defOp);
+  // defOp == nullptr means shape is the argument of the func op
+  if (defOp != nullptr) {
+    visited.insert(defOp);
+    queue.push(defOp);
+  }
  while (!queue.empty()) {
    Operation *op = queue.front();
    queue.pop();
-    if (op->getNumOperands() == 0) {
-      cluster.insert(op);
-    } else if (llvm::isa<shape::ShapeOfOp>(op) &&
-               onlyUsedByWithShapes_.count(op)) {
-      // Stop when the op is type of shape.shape_of and only used by
-      // shape.with_shape ops
-      continue;
-    } else {
+    if (onlyUsedByWithShapes_.lookup(op)) {
      cluster.insert(op);
      for (Value inp : op->getOperands()) {
        Operation *inpDefOp = inp.getDefiningOp();
@@ -396,6 +316,6 @@
 }
 } // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createOutlineShapeComputationPass(const std::string &entryFunc) {
-  return std::make_unique<OutlineShapeComputationPass>(entryFunc);
+mlir::createOutlineShapeComputationPass() {
+  return std::make_unique<OutlineShapeComputationPass>();
 }
Index: mlir/test/Dialect/Shape/outline-shape-computation.mlir
===================================================================
--- mlir/test/Dialect/Shape/outline-shape-computation.mlir
+++ mlir/test/Dialect/Shape/outline-shape-computation.mlir
@@ -1,7 +1,16 @@
-// RUN: mlir-opt -outline-shape-computation="entry-func=main" -split-input-file -allow-unregistered-dialect %s | FileCheck %s
+// RUN: mlir-opt -outline-shape-computation -test-print-shape-mapping -split-input-file -allow-unregistered-dialect %s 2>&1 | FileCheck %s
 
-// two dynamic shapes: a trivial symbol and a shape.func
+// two dynamic shapes: one comes directly from shape.shape_of(arg), the other from an outlined shape function
 func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  // CHECK-DAG: Shape for: %0 = "test.abs"
+  // CHECK-DAG: Shape function symbol: @shape_cal_0
+  // CHECK-DAG: Shape function arguments:
+  // CHECK-DAG{LITERAL}: <block argument> of type 'tensor<?x4x?xf32>' at index: 0
+
+  // CHECK-DAG: Shape for: %1 = "test.concat"
+  // CHECK-DAG: Shape function symbol: @shape_cal_1
+  // CHECK-DAG: Shape function arguments:
+  // CHECK-DAG{LITERAL}: <block argument> of type 'tensor<?x4x?xf32>' at index: 0
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
@@ -20,18 +29,21 @@
 }
 
 // CHECK-LABEL: func.func @main
-// CHECK-SAME: (%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor
-// CHECK-NEXT: %0 = "test.abs"(%arg0) : (tensor) -> tensor
-// CHECK-NEXT: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor
-// CHECK-NEXT: return %1 : tensor
-
-// CHECK-LABEL: shape.func private @shape_cal_0
-// CHECK-SAME: (%arg0: tensor<3xindex>) -> !shape.shape
-// CHECK-DAG: %[[V0:.*]] = get_extent %arg0, %c2 : tensor<3xindex>, index -> index
-// CHECK-DAG: %[[V1:.*]] = get_extent %arg0, %c0 : tensor<3xindex>, index -> index
-// CHECK-DAG: %[[V2:.*]] = arith.addi %[[V1]], %c2 : index
-// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %[[V0]] : index, index, index
-// CHECK-DAG: return %[[V3]] : !shape.shape
+// CHECK-NEXT: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+// CHECK-NEXT: return %1 : tensor<?x4x?xf32>
+
+// CHECK-DAG: shape.func private @shape_cal_1(%arg0:
tensor) -> !shape.shape { +// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<3xindex> +// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index +// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index +// CHECK-DAG: return %[[V4]] : !shape.shape + +// CHECK-DAG: shape.func private @shape_cal_0(%arg0: tensor) -> tensor<3xindex> { +// CHECK-DAG: %0 = shape_of %arg0 : tensor -> tensor<3xindex> +// CHECK-DAG: return %0 : tensor<3xindex> // ----- @@ -54,19 +66,18 @@ return %10 : tensor } // CHECK-LABEL: func.func @main -// CHECK-SAME: (%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor -// CHECK-NEXT: %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor -// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor -// CHECK-NEXT: return %1 : tensor - -// CHECK-LABEL: shape.func private @shape_cal_0 -// CHECK-SAME: (%arg0: tensor<3xindex>) -> !shape.shape -// CHECK-DAG: %[[V0:.*]] = get_extent %arg0, %c2 : tensor<3xindex>, index -> index -// CHECK-DAG: %[[V1:.*]] = get_extent %arg0, %c0 : tensor<3xindex>, index -> index -// CHECK-DAG: %[[V2:.*]] = arith.addi %[[V1]], %c2 : index -// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %[[V0]] : index, index, index -// CHECK-DAG: return %[[V3]] : !shape.shape +// CHECK-NEXT: %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor +// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor +// CHECK-NEXT: return %1 : tensor +// CHECK: shape.func private @shape_cal_0(%arg0: tensor) -> !shape.shape { +// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<3xindex> +// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V3:.*]] = arith.addi %[[V2]], %c2 : index +// CHECK-DAG: %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index +// CHECK-DAG: return %4 : !shape.shape +// CHECK-NOT: shape_cal_1 // ----- @@ -85,8 +96,119 @@ return %9 : tensor } // CHECK-LABEL: func.func @main -// CHECK-SAME: (%arg0: tensor) -> tensor -// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor) -> tensor -// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor -// CHECK-NEXT: %2 = "test.negate"(%1) : (tensor) -> tensor -// CHECK-NEXT: return %2 : tensor +// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor) -> tensor +// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor +// CHECK-NEXT: %2 = "test.negate"(%1) : (tensor) -> tensor +// CHECK-NEXT: return %2 : tensor + +// CHECK: shape.func private @shape_cal_0(%arg0: tensor) -> tensor<1xindex> { +// CHECK-NEXT: %0 = shape_of %arg0 : tensor -> tensor<1xindex> +// CHECK-NEXT: return %0 : tensor<1xindex> +// CHECK-NOT: shape_cal_1 + +// ----- + +// there's only a return op in the constructed shape.func +func.func @main(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor { + %0 = "test.nonzero"(%arg0) : (tensor) -> tensor + %1 = shape.with_shape %0, %arg1 : tensor, tensor<1xindex> + %2 = shape.value_of %1 : tensor + return %2 : tensor +} +// CHECK-LABEL: func.func @main(%arg0: tensor, %arg1: tensor<1xindex>) -> tensor { +// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor) -> tensor +// CHECK-NEXT: return %0 : tensor + +// CHECK: shape.func private @shape_cal_0(%arg0: tensor<1xindex>) -> tensor<1xindex> { +// 
CHECK-NEXT: return %arg0 : tensor<1xindex> + +// ----- + +// shape computation part interleaves with general computation +func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, index) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex> + %1 = shape.shape_of %arg1 : tensor -> tensor<3xindex> + %2 = shape.shape_of %arg2 : tensor -> tensor<3xindex> + %3 = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor + %4 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index + %5 = shape.get_extent %1, %c0 : tensor<3xindex>, index -> index + %6 = shape.get_extent %2, %c0 : tensor<3xindex>, index -> index + %7 = arith.addi %4, %5 : index + %8 = arith.addi %7, %6 : index + %9 = shape.from_extents %8, %c4, %c5 : index, index, index + %10 = shape.with_shape %3, %9 : tensor, !shape.shape + %11 = shape.value_of %10 : tensor + return %11, %7 : tensor, index +} +// CHECK-LABEL: func.func @main +// CHECK-DAG: %[[V0:.*]] = shape.shape_of %arg0 : tensor -> tensor<3xindex> +// CHECK-DAG: %[[V1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex> +// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor +// CHECK-DAG: %[[V3:.*]] = shape.get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V4:.*]] = shape.get_extent %[[V1]], %c0 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : index +// CHECK-DAG: return %[[V2]], %[[V5]] : tensor, index + +// CHECK: shape.func private @shape_cal_0(%arg0: tensor, %arg1: index, %arg2: index) -> !shape.shape { +// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<3xindex> +// CHECK-DAG: %[[V1:.*]] = get_extent %[[V0]], %arg1 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V2:.*]] = arith.addi %arg2, %[[V1]] : index +// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %c5 : index, index, index +// CHECK-DAG: return %[[V3]] : !shape.shape + +// ----- + +// there're multiple reused shape computations +func.func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor, tensor, tensor) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %0 = shape.shape_of %arg0 : tensor -> tensor<2xindex> + %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex> + %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor) -> tensor + %3 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor) -> tensor + %4 = shape.get_extent %0, %c0 : tensor<2xindex>, index -> index + %5 = shape.get_extent %1, %c0 : tensor<2xindex>, index -> index + %6 = arith.addi %4, %5 : index + %7 = shape.from_extents %6, %c4 : index, index + %8 = shape.with_shape %2, %7 : tensor, !shape.shape + %9 = shape.with_shape %3, %7 : tensor, !shape.shape + %10 = shape.value_of %8 : tensor + %11 = shape.value_of %9 : tensor + %12 = "test.concat"(%arg0, %2) {axis = 0 : i64} : (tensor, tensor) -> tensor + %13 = "test.concat"(%arg0, %3) {axis = 0 : i64} : (tensor, tensor) -> tensor + %14 = arith.addi %6, %4 : index + %15 = shape.from_extents %14, %c4 : index, index + %16 = shape.with_shape %12, %15 : tensor, !shape.shape + %17 = shape.with_shape %13, %15 : tensor, !shape.shape + %18 = shape.value_of %16 : tensor + %19 = shape.value_of %17 : tensor + return %10, %11, %18, %19 : tensor, tensor, tensor, tensor +} +// CHECK-LABEL: func.func @main +// CHECK-DAG: %[[V0:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : 
i64} : (tensor, tensor) -> tensor +// CHECK-DAG: %[[V1:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor) -> tensor +// CHECK-DAG: %[[V2:.*]] = "test.concat"(%arg0, %[[V0]]) {axis = 0 : i64} : (tensor, tensor) -> tensor +// CHECK-DAG: %[[V3:.*]] = "test.concat"(%arg0, %[[V1]]) {axis = 0 : i64} : (tensor, tensor) -> tensor +// CHECK-DAG: return %[[V0]], %[[V1]], %[[V2]], %[[V3]] : tensor, tensor, tensor, tensor + +// CHECK: shape.func private @shape_cal_1(%arg0: tensor, %arg1: tensor) -> !shape.shape { +// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<2xindex> +// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor -> tensor<2xindex> +// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index +// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index +// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index +// CHECK-DAG: %[[V5:.*]] = arith.addi %[[V4]], %[[V2]] : index +// CHECK-DAG: %[[V6:.*]] = from_extents %[[V5]], %c4 : index, index +// CHECK-DAG: return %[[V6]] : !shape.shape + +// CHECK: shape.func private @shape_cal_0(%arg0: tensor, %arg1: tensor) -> !shape.shape { +// CHECK-DAG: %[[V0:.*]] = shape_of %arg0 : tensor -> tensor<2xindex> +// CHECK-DAG: %[[V1:.*]] = shape_of %arg1 : tensor -> tensor<2xindex> +// CHECK-DAG: %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index +// CHECK-DAG: %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index +// CHECK-DAG: %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index +// CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index +// CHECK-DAG: return %[[V5]] : !shape.shape Index: mlir/test/lib/Dialect/Shape/CMakeLists.txt =================================================================== --- mlir/test/lib/Dialect/Shape/CMakeLists.txt +++ mlir/test/lib/Dialect/Shape/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRShapeTestPasses TestShapeFunctions.cpp + TestShapeMappingAnalysis.cpp EXCLUDE_FROM_LIBMLIR @@ -11,6 +12,7 @@ LINK_LIBS PUBLIC MLIRIR MLIRPass + MLIRShapeOpsTransforms MLIRShapeDialect MLIRSupport ) Index: mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp =================================================================== --- /dev/null +++ mlir/test/lib/Dialect/Shape/TestShapeMappingAnalysis.cpp @@ -0,0 +1,43 @@ +//===- TestShapeMappingInfo.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestShapeMappingPass
+    : public PassWrapper<TestShapeMappingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShapeMappingPass)
+
+  StringRef getArgument() const final { return "test-print-shape-mapping"; }
+  StringRef getDescription() const final {
+    return "Print the contents of the constructed shape mapping information.";
+  }
+  void runOnOperation() override {
+    llvm::Optional<std::reference_wrapper<shape::ShapeMappingAnalysis>>
+        maybeAnalysis = getCachedAnalysis<shape::ShapeMappingAnalysis>();
+    if (maybeAnalysis.hasValue())
+      maybeAnalysis.value().get().print(llvm::errs());
+    else
+      llvm::errs() << "No cached ShapeMappingAnalysis exists.";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestShapeMappingPass() {
+  PassRegistration<TestShapeMappingPass>();
+}
+} // namespace test
+} // namespace mlir
Index: mlir/tools/mlir-opt/mlir-opt.cpp
===================================================================
--- mlir/tools/mlir-opt/mlir-opt.cpp
+++ mlir/tools/mlir-opt/mlir-opt.cpp
@@ -109,6 +109,7 @@
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorTransforms();
 void registerTestTilingInterface();
@@ -208,6 +209,7 @@
   mlir::test::registerTestPDLLPasses();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTilingInterface();
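Usage sketch (not part of the patch): the snippet below is a minimal, hypothetical consumer showing how a later pass in the same pipeline could query the ShapeMappingAnalysis that OutlineShapeComputationPass now populates, instead of reading `encoding` attributes. The pass name `ReportDynamicShapesPass` and the emitted remark are invented for illustration; only `ShapeMappingAnalysis`, `ShapeMappingValue`, `getCachedAnalysis`, and `markAllAnalysesPreserved` come from this patch.

//===- ReportDynamicShapesPass (illustrative sketch only) ----------------===//
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
// Hypothetical downstream pass: it only illustrates how the mapping recorded
// by -outline-shape-computation can be consumed later in the pipeline.
struct ReportDynamicShapesPass
    : public PassWrapper<ReportDynamicShapesPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportDynamicShapesPass)

  void runOnOperation() override {
    ModuleOp module = getOperation();
    // The mapping only exists if -outline-shape-computation ran earlier and
    // the cached analysis has not been invalidated since.
    llvm::Optional<std::reference_wrapper<shape::ShapeMappingAnalysis>>
        maybeAnalysis = getCachedAnalysis<shape::ShapeMappingAnalysis>();
    if (!maybeAnalysis.hasValue()) {
      markAllAnalysesPreserved();
      return;
    }
    shape::ShapeMappingAnalysis &analysis = maybeAnalysis.value().get();

    for (auto &entry : analysis.shapeMapping) {
      Value dynamicValue = entry.first;
      const shape::ShapeMappingValue &mapping = entry.second;
      // Resolve the outlined shape.func named by the recorded symbol.
      auto shapeFn = module.lookupSymbol<shape::FuncOp>(mapping.funcSymbol);
      if (!shapeFn)
        continue;
      // `mapping.inputs` holds the SSA values that have to be fed to the
      // shape function, in argument order.
      emitRemark(dynamicValue.getLoc())
          << "shape computed by " << mapping.funcSymbol << " from "
          << mapping.inputs.size() << " value(s)";
    }
    // This pass only reads IR and the analysis, so keep everything cached.
    markAllAnalysesPreserved();
  }
};
} // namespace

Because the analysis lives on the ModuleOp and OutlineShapeComputationPass calls markAllAnalysesPreserved(), such a consumer works when scheduled directly after it; any intervening pass that does not preserve analyses would drop the cached mapping.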