Index: mlir/include/mlir/Dialect/Shape/Transforms/Passes.h =================================================================== --- mlir/include/mlir/Dialect/Shape/Transforms/Passes.h +++ mlir/include/mlir/Dialect/Shape/Transforms/Passes.h @@ -18,6 +18,7 @@ namespace mlir { class ConversionTarget; +class ModuleOp; class TypeConverter; namespace func { class FuncOp; @@ -50,6 +51,11 @@ // level. std::unique_ptr> createShapeBufferizePass(); +// Outline the shape computation of function `entryFunc` into shape.func ops +// and record symbols referring to them in the tensors' `encoding` attribute. +std::unique_ptr> +createOutlineShapeComputationPass(const std::string &entryFunc = ""); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// Index: mlir/include/mlir/Dialect/Shape/Transforms/Passes.td =================================================================== --- mlir/include/mlir/Dialect/Shape/Transforms/Passes.td +++ mlir/include/mlir/Dialect/Shape/Transforms/Passes.td @@ -11,6 +11,95 @@ include "mlir/Pass/PassBase.td" +def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> { + let summary = "Using shape.func to preserve shape computation"; + let description = [{ + This pass outlines the shape computation part in high level IR by adding + shape.func and corresponding symbols mapping to it. The shape computation + part is usually introduced by shape reification, and each single dynamic + shape is denoted by shape.with_shape. + + There are two main reasons this shape-outline pass is needed: + 1. Many passes don't take shape reification part into consideration. + Therefore we need to "remove" the shape reification part temporarily for + these passes. + 2. Sometimes we cannot redo shape reification after converting from dialect + A to dialect B, because op-level shape reification is only implemented + on A. + + // TODO: what if the tensor already has an `encoding` attribute before the pass?
+ The symbol(s) are placed in the `encoding` attribute of RankedTensorType. If + the shape of the tensor is, or equals, a dynamic shape source, it is a + single FlatSymbolRefAttr. Otherwise it is an ArrayAttr whose first element + is the name of the corresponding shape function and whose other elements + are its arguments. + + // TODO: support all functions after lenient call is supported. + Currently this pass is limited to the entry function, and internal + func.call is not supported either. + + Input: + ```mlir + func.func @main(%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> + tensor { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex> + %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index + %2 = "test.abs"(%arg0) : (tensor) -> tensor + %3 = shape.with_shape %2, %0 : tensor, tensor<3xindex> + %4 = shape.value_of %3 : tensor + %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor, + tensor<2x4x?xf32>) -> tensor + %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index + %7 = arith.addi %6, %c2 : index + %8 = shape.from_extents %7, %c4, %1 : index, index, index + %9 = shape.with_shape %5, %8 : tensor, !shape.shape + %10 = shape.value_of %9 : tensor + return %10 : tensor + } + ``` + + Output: + ```mlir + func.func @main(%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> + tensor { + %0 = "test.abs"(%arg0) : (tensor) -> tensor + %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor + return %1 : tensor + } + shape.func private @shape_cal_0(%arg0: tensor<3xindex>) -> !shape.shape { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %0 = get_extent %arg0, %c2 : tensor<3xindex>, index -> index + %1 = get_extent %arg0, %c0 : tensor<3xindex>, index -> index + %2 = arith.addi %1, %c2 : index + %3 = from_extents %2, %c4, %0 : index, index, index + return %3 : !shape.shape + } + ``` + + For
the above example, the shape computation is inlined in the input IR, + where it computes the shapes of two values (the results of test.abs and + test.concat). In the output IR the shape computation is outlined. The + test.abs result's shape equals that of %arg0, so its encoding is the simple + symbol @arg_0. The test.concat result's shape is bound to a shape.func: the + first symbol, @shape_cal_0, is the function name, and the second symbol, + @arg_0, is its argument. + }]; + let constructor = "mlir::createOutlineShapeComputationPass()"; + let options = [ + Option<"entryFunc", "entry-func", "std::string", + /*default=*/"\"main\"", + "Used to specify the entry-function.">, + ]; + let dependentDialects = ["shape::ShapeDialect"]; +} + def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> { let summary = "Replace all cstr_ ops with a true witness"; let constructor = "mlir::createRemoveShapeConstraintsPass()"; Index: mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt +++ mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + OutlineShapeComputation.cpp RemoveShapeConstraints.cpp ShapeToShapeLowering.cpp Index: mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp @@ -0,0 +1,398 @@ +//====----- OutlineShapeComputation.cpp -------------------------*- C++-*--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Debug.h" +#include +#include +#include + +#define DEBUG_TYPE "outline-shape-computation" + +using namespace mlir; + +namespace { + +// A Value is an input of the cluster if it is an operand of an operation in the +// cluster and its defining operation is not in the cluster. +ValueRange +getInputsOfCluster(const llvm::SmallVector &cluster) { + SmallVector inputs; + llvm::SmallDenseSet inputSet; + llvm::SmallDenseSet opSet; + for (Operation *op : cluster) { + bool inserted = opSet.insert(op).second; + (void)inserted; + assert(inserted && "cluster contains duplicate operations"); + } + + for (Operation *op : cluster) { + for (Value operand : op->getOperands()) { + Operation *operandOp = operand.getDefiningOp(); + if (opSet.find(operandOp) != opSet.end()) { + // skip if defining op is in the cluster + continue; + } + if (inputSet.insert(operand).second) + inputs.push_back(operand); + } + } + return inputs; +} + +// Create a shape.func representing the shape computation for `shape`. 
+std::pair> +createFuncFromCluster(OpBuilder &b, const SmallVector &cluster, + Value shape, StringRef fnName, Location loc) { + assert(!cluster.empty() && + "There must be at least one element in the cluster!"); + + ValueRange inputs = getInputsOfCluster(cluster); + auto fnType = b.getFunctionType(inputs.getTypes(), shape.getType()); + b.setInsertionPointAfter(cluster[0]->getParentOp()); + shape::FuncOp fnOp = b.create(loc, fnName, fnType); + Block *block = fnOp.addEntryBlock(); + b.setInsertionPoint(block, block->end()); + BlockAndValueMapping bvm; + for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments())) { + bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg)); + } + for (Operation *op : cluster) { + b.clone(*op, bvm); + } + llvm::SmallVector fnReturns; + fnReturns.push_back(bvm.lookupOrDefault(shape)); + + b.create(UnknownLoc::get(b.getContext()), fnReturns); + fnOp.setPrivate(); + return std::make_pair(fnOp, inputs); +} + +// The operations in the cluster might be unsorted, which could be inconvenient +// when creating shape.func op. +DenseMap> +getOrderedClusters(const DenseMap> &clusters, + func::FuncOp funcOp) { + // Compute all clusters that each operation is in + DenseMap> op2Shapes; + for (auto it : clusters) { + Value shape = it.first; + const DenseSet &cluster = it.second; + for (Operation *cOp : cluster) { + op2Shapes[cOp].push_back(shape); + } + } + + // Iterate through all operations in order. Get all the clusters `cOp` belongs + // to and construct the new ordered cluster as it traverses. 
+ DenseMap> orderedClusters; + funcOp.walk([&](Operation *op) { + auto it = op2Shapes.find(op); + if (it != op2Shapes.end()) { + Operation *cOp = it->first; + for (Value shape : it->second) { + orderedClusters[shape].push_back(cOp); + } + } + }); + + return orderedClusters; +} + +// Increment `idx` until find the next available symbol name +std::string +getNextAvailableSymbolName(const std::string &prefix, int &idx, + std::unordered_set &usedSymbolNames) { + std::string name = prefix + std::to_string(idx++); + + while (usedSymbolNames.count(name)) { + name = prefix + std::to_string(idx++); + } + usedSymbolNames.insert(name); + return name; +} + +// return argument index if `shape` is the output of a +// shape.shape_of(func_arg), else return -1. +llvm::Optional getShapeOfFuncArgIdx(Value shape, func::FuncOp funcOp) { + shape::ShapeOfOp shapeOfOp = shape.getDefiningOp(); + if (shapeOfOp == nullptr) + return llvm::None; + Value inp = shapeOfOp.getArg(); + for (int i = 0; i < int(funcOp.getNumArguments()); ++i) { + if (funcOp.getArgument(i) == inp) + return i; + } + + return llvm::None; +} + +void constructShapeFunc( + const std::vector &allWithOps, MLIRContext *context, + DenseMap> &clusters, + SymbolTable &symbolTable, func::FuncOp funcOp, + std::unordered_set &usedSymbolNames, + DenseMap &dynShapeSrc2Symbol, + DenseMap> &dynShape2ShapeFunc) { + std::string dynamicSourceNamePrefix = "s"; + int dynamicSourceNameIdx = 0; + std::string shapeCalculationNamePrefix = "shape_cal_"; + int shapeCalculationNameIdx = 0; + OpBuilder builder(context); + + auto getOrConstructSymbolFromShape = [&](Value shape) { + auto symbolIt = dynShapeSrc2Symbol.find(shape); + if (symbolIt == dynShapeSrc2Symbol.end()) { + std::string name; + llvm::Optional index = getShapeOfFuncArgIdx(shape, funcOp); + if (index.has_value()) + name = "arg_" + std::to_string(index.value()); + else + name = getNextAvailableSymbolName( + dynamicSourceNamePrefix, dynamicSourceNameIdx, usedSymbolNames); + + auto 
symbol = FlatSymbolRefAttr::get(context, name); + dynShapeSrc2Symbol[shape] = symbol; + return symbol; + } else { + return symbolIt->second; + } + }; + + // Construct a shape function or a symbol for each cluster + for (shape::WithOp withOp : allWithOps) { + Value value = withOp.getOperand(); + Value shape = withOp.getShape(); + RankedTensorType rankedType = value.getType().dyn_cast(); + if (rankedType == nullptr) + continue; + const SmallVector &cluster = clusters[shape]; + // The cluster is empty when the shape is equal to a dynamic shape source + if (cluster.empty()) { + FlatSymbolRefAttr symbol = getOrConstructSymbolFromShape(shape); + value.setType(RankedTensorType::get(rankedType.getShape(), + rankedType.getElementType(), symbol)); + LLVM_DEBUG(llvm::dbgs() + << "Symbol for " << shape << ": " << symbol << "\n"); + } else { + auto it = dynShape2ShapeFunc.find(shape); + SmallVector symbols; + if (it == dynShape2ShapeFunc.end()) { + std::string name = getNextAvailableSymbolName( + shapeCalculationNamePrefix, shapeCalculationNameIdx, + usedSymbolNames); + Location loc = value.getLoc(); + auto pair = createFuncFromCluster(builder, cluster, shape, name, loc); + const SmallVector &inputs = pair.second; + shape::FuncOp shapeFuncOp = pair.first; + StringAttr insertedName = symbolTable.insert(shapeFuncOp); + auto symbol = FlatSymbolRefAttr::get(context, insertedName); + symbols.push_back(symbol); + for (Value inp : inputs) { + FlatSymbolRefAttr argSymbol = getOrConstructSymbolFromShape(inp); + symbols.push_back(argSymbol); + } + dynShape2ShapeFunc[shape] = symbols; + } else { + symbols = it->second; + } + auto arrayAttr = ArrayAttr::get(context, symbols); + LLVM_DEBUG(llvm::dbgs() + << "Symbol for " << shape << ": " << arrayAttr << "\n"); + value.setType(RankedTensorType::get( + rankedType.getShape(), rankedType.getElementType(), arrayAttr)); + } + } +} + +struct OutlineShapeComputationPass + : public OutlineShapeComputationBase { + + OutlineShapeComputationPass(const 
std::string &entryFunc) + : OutlineShapeComputationBase() { + this->entryFunc = entryFunc; + } + + void runOnOperation() override; + +private: + bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput); + void getClusterFromValue(Value shape, + DenseMap> &clusters); + DenseMap> + constructClustersForEachShape(const std::vector &allWithOps, + func::FuncOp funcOp); + DenseMap onlyUsedByWithShapes_; +}; + +class TensorDimOpRewriter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::DimOp op, + PatternRewriter &rewriter) const override { + auto shapeOf = + rewriter.create(op.getLoc(), op.getSource()); + rewriter.replaceOpWithNewOp(op, op.getType(), shapeOf, + op.getIndex()); + return success(); + } +}; + +void OutlineShapeComputationPass::runOnOperation() { + ModuleOp moduleOp = getOperation(); + SymbolTable symbolTable(moduleOp); + // `usedSymbolNames` makes sure the created symbol is unique at Moudle level + std::unordered_set usedSymbolNames; + DenseMap dynShapeSrc2Symbol; + DenseMap> dynShape2ShapeFunc; + + moduleOp.walk([&](func::FuncOp funcOp) { + if (funcOp.getName() != entryFunc) + return; + + MLIRContext *context = funcOp.getContext(); + RewritePatternSet prevPatterns(context); + prevPatterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns)))) + return signalPassFailure(); + + // initialize class member `onlyUsedByWithShapes_` + onlyUsedByWithShapes_.clear(); + funcOp.walk([&](Operation *op) { + calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr); + }); + // clang-format off + LLVM_DEBUG(llvm::dbgs() << "onlyUsedByWithShapes_ table: \n"; + for (auto it : onlyUsedByWithShapes_) + if (it.second) + llvm::dbgs() << *(it.first) << ": " << it.second << "\n";); + // clang-format on + + // collect all the shape.with_shape ops. 
+ std::vector allWithOps; + funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); }); + + DenseMap> clusters = + constructClustersForEachShape(allWithOps, funcOp); + constructShapeFunc(allWithOps, context, clusters, symbolTable, funcOp, + usedSymbolNames, dynShapeSrc2Symbol, dynShape2ShapeFunc); + + for (shape::WithOp withOp : allWithOps) { + Value value = withOp.getOperand(); + for (Operation *user : withOp.getResult().getUsers()) { + if (Value valueOf = llvm::dyn_cast(user)) + valueOf.replaceAllUsesExcept(value, withOp); + } + } + + // dce + if (failed(applyPatternsAndFoldGreedily(funcOp, {}))) + return signalPassFailure(); + + funcOp.setType( + FunctionType::get(context, funcOp.front().getArgumentTypes(), + funcOp.front().getTerminator()->getOperandTypes())); + }); +} + +DenseMap> +OutlineShapeComputationPass::constructClustersForEachShape( + const std::vector &allWithOps, func::FuncOp funcOp) { + DenseMap> clusters; + for (shape::WithOp withOp : allWithOps) { + Value shape = withOp.getShape(); + if (clusters.count(shape) == 0) + getClusterFromValue(shape, clusters); + } + return getOrderedClusters(clusters, funcOp); +} + +// The output of a cluster is the `shape`, and the inputs are either the result +// of shape.shape_of or function argument. 
+void OutlineShapeComputationPass::getClusterFromValue( + Value shape, DenseMap> &clusters) { + DenseSet cluster; + + Operation *defOp = shape.getDefiningOp(); + // defOp == nullptr means shape is the argument of the func op + if (nullptr == defOp) { + return; + } + + DenseSet visited; + std::queue queue; + visited.insert(defOp); + queue.push(defOp); + while (!queue.empty()) { + Operation *op = queue.front(); + queue.pop(); + if (op->getNumOperands() == 0) { + cluster.insert(op); + } else if (llvm::isa(op) && + onlyUsedByWithShapes_.count(op)) { + // Stop when the op is type of shape.shape_of and only used by + // shape.with_shape ops + continue; + } else { + cluster.insert(op); + for (Value inp : op->getOperands()) { + Operation *inpDefOp = inp.getDefiningOp(); + if (nullptr != inpDefOp && !visited.contains(inpDefOp)) { + visited.insert(inpDefOp); + queue.push(inpDefOp); + } + } + } + } + + clusters[shape] = std::move(cluster); +} + +// return true if `op` is a shape.with_shape, or all the users' of `op` +// eventually point to the shape operand of shape.with_shape ops +bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively( + Operation *op, Value prevOutput) { + auto it = onlyUsedByWithShapes_.find(op); + if (it != onlyUsedByWithShapes_.end()) + return it->second; + + if (auto withOp = llvm::dyn_cast(op)) { + bool retVal = withOp.getShape() == prevOutput; + return retVal; + } + + if (op->use_empty()) { + onlyUsedByWithShapes_[op] = false; + return false; + } + + bool allUsers = true; + for (Value oup : op->getResults()) + for (Operation *user : oup.getUsers()) + allUsers &= calOnlyUsedByWithShapesRecursively(user, oup); + + onlyUsedByWithShapes_[op] = allUsers; + return allUsers; +} + +} // namespace + +std::unique_ptr> +mlir::createOutlineShapeComputationPass(const std::string &entryFunc) { + return std::make_unique(entryFunc); +} Index: mlir/lib/Dialect/Shape/Transforms/PassDetail.h 
=================================================================== --- mlir/lib/Dialect/Shape/Transforms/PassDetail.h +++ mlir/lib/Dialect/Shape/Transforms/PassDetail.h @@ -10,6 +10,7 @@ #define DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_ #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -22,6 +23,10 @@ class MemRefDialect; } // namespace memref +namespace shape { +class ShapeDialect; +} // namespace shape + #define GEN_PASS_CLASSES #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" Index: mlir/test/Dialect/Shape/outline-shape-computation.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Shape/outline-shape-computation.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt -outline-shape-computation="entry-func=main" -split-input-file -allow-unregistered-dialect %s | FileCheck %s + +// Two dynamic shapes: one encoded as a trivial symbol, the other as a shape.func. +func.func @main(%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex> + %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index + %2 = "test.abs"(%arg0) : (tensor) -> tensor + %3 = shape.with_shape %2, %0 : tensor, tensor<3xindex> + %4 = shape.value_of %3 : tensor + %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor + %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index + %7 = arith.addi %6, %c2 : index + %8 = shape.from_extents %7, %c4, %1 : index, index, index + %9 = shape.with_shape %5, %8 : tensor, !shape.shape + %10 = shape.value_of %9 : tensor + return %10 : tensor +} + +// CHECK-LABEL: func.func @main +// CHECK-SAME: (%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor +// CHECK-NEXT: %0 = "test.abs"(%arg0) : (tensor) -> tensor +// CHECK-NEXT: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor,
tensor<2x4x?xf32>) -> tensor +// CHECK-NEXT: return %1 : tensor + +// CHECK-LABEL: shape.func private @shape_cal_0 +// CHECK-SAME: (%arg0: tensor<3xindex>) -> !shape.shape +// CHECK-DAG: %[[V0:.*]] = get_extent %arg0, %c2 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V1:.*]] = get_extent %arg0, %c0 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V2:.*]] = arith.addi %[[V1]], %c2 : index +// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %[[V0]] : index, index, index +// CHECK-DAG: return %[[V3]] : !shape.shape + +// ----- + +// Two dynamic shapes that share the same shape.func. +func.func @main(%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor { + %c2 = arith.constant 2 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex> + %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index + %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor + %3 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index + %4 = arith.addi %3, %c2 : index + %5 = shape.from_extents %4, %c4, %1 : index, index, index + %6 = shape.with_shape %2, %5 : tensor, !shape.shape + %7 = shape.value_of %6 : tensor + %8 = "test.abs"(%7) : (tensor) -> tensor + %9 = shape.with_shape %8, %5 : tensor, !shape.shape + %10 = shape.value_of %9 : tensor + return %10 : tensor +} +// CHECK-LABEL: func.func @main +// CHECK-SAME: (%arg0: tensor, %arg1: tensor<2x4x?xf32>) -> tensor +// CHECK-NEXT: %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor, tensor<2x4x?xf32>) -> tensor +// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor +// CHECK-NEXT: return %1 : tensor + +// CHECK-LABEL: shape.func private @shape_cal_0 +// CHECK-SAME: (%arg0: tensor<3xindex>) -> !shape.shape +// CHECK-DAG: %[[V0:.*]] = get_extent %arg0, %c2 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V1:.*]] = get_extent %arg0, %c0 : tensor<3xindex>, index -> index +// CHECK-DAG: %[[V2:.*]] =
arith.addi %[[V1]], %c2 : index +// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %[[V0]] : index, index, index +// CHECK-DAG: return %[[V3]] : !shape.shape + + +// ----- + +// An internal dynamic shape source (the test.nonzero result) whose shape is shared by two other dynamic shapes. +func.func @main(%arg0: tensor) -> tensor { + %0 = "test.nonzero"(%arg0) : (tensor) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor<1xindex> + %2 = shape.with_shape %0, %1 : tensor, tensor<1xindex> + %3 = shape.value_of %2 : tensor + %4 = "test.abs"(%3) : (tensor) -> tensor + %5 = shape.with_shape %4, %1 : tensor, tensor<1xindex> + %6 = shape.value_of %5 : tensor + %7 = "test.negate"(%6) : (tensor) -> tensor + %8 = shape.with_shape %7, %1 : tensor, tensor<1xindex> + %9 = shape.value_of %8 : tensor + return %9 : tensor +} +// CHECK-LABEL: func.func @main +// CHECK-SAME: (%arg0: tensor) -> tensor +// CHECK-NEXT: %0 = "test.nonzero"(%arg0) : (tensor) -> tensor +// CHECK-NEXT: %1 = "test.abs"(%0) : (tensor) -> tensor +// CHECK-NEXT: %2 = "test.negate"(%1) : (tensor) -> tensor +// CHECK-NEXT: return %2 : tensor