diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -604,6 +604,20 @@
   }];
 }
 
+def Shape_ValueOfOp : Shape_Op<"value_of", [NoSideEffect]> {
+  let summary = "Returns the value of a !shape.value_shape operand";
+
+  let description = [{
+    The operation takes a !shape.value_shape argument and returns the value
+    component of the (value, shape) tuple that the argument represents, as a
+    tensor.
+  }];
+
+  let arguments = (ins Shape_ValueShapeType:$arg);
+  let results = (outs AnyShaped:$result);
+
+  let assemblyFormat = "$arg attr-dict `:` type($result)";
+}
+
 def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
   DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
 ]> {
@@ -686,7 +700,7 @@
   }];
 
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
-                       Shape_ShapeType:$shape);
+                       Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_ValueShapeType:$result);
 
   let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
@@ -1060,7 +1074,20 @@
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+    CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
+  >];
   let extraClassDeclaration = [{
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         ArrayRef<NamedAttribute> attrs = {});
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         Operation::dialect_attr_range attrs);
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         ArrayRef<NamedAttribute> attrs,
+                         ArrayRef<DictionaryAttr> argAttrs);
     //===------------------------------------------------------------------===//
     // CallableOpInterface
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,7 @@
 namespace mlir {
 class ConversionTarget;
+class ModuleOp;
 class TypeConverter;
 namespace func {
 class FuncOp;
@@ -50,6 +51,15 @@
 // level.
 std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
 
+// Outline the shape computation part by adding shape.func ops and symbols
+// mapping to them.
+// The symbols are placed in the encoding attribute of RankedTensorType. If
+// the shape of a tensor equals a dynamic shape source, the encoding is a
+// single FlatSymbolRefAttr. Otherwise it is an ArrayAttr whose first element
+// names the corresponding shape function and whose remaining elements name
+// its arguments.
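+// For illustration (names as produced by this pass on a typical input): a
+// value whose shape equals the shape of the first function argument is tagged
+// tensor<?x4x?xf32, @arg_0>, while a value whose shape is computed by the
+// outlined function @shape_cal_0 from that same source is tagged
+// tensor<?x4x?xf32, [@shape_cal_0, @arg_0]>.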
+std::unique_ptr<OperationPass<ModuleOp>>
+createOutlineShapeComputationPass(const std::string &entryFunc = "");
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,6 +11,17 @@
 include "mlir/Pass/PassBase.td"
 
+def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
+  let summary = "Use shape.func to preserve shape computation";
+  let constructor = "mlir::createOutlineShapeComputationPass()";
+  let options = [
+    Option<"entryFunc", "entry-func", "std::string",
+           /*default=*/"\"main\"",
+           "Used to specify the entry function.">,
+  ];
+  let dependentDialects = ["shape::ShapeDialect"];
+}
+
 def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
   let summary = "Replace all cstr_ ops with a true witness";
   let constructor = "mlir::createRemoveShapeConstraintsPass()";
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1215,6 +1215,43 @@
 // FuncOp
 //===----------------------------------------------------------------------===//
 
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs) {
+  OpBuilder builder(location->getContext());
+  OperationState state(location, getOperationName());
+  FuncOp::build(builder, state, name, type, attrs);
+  return cast<FuncOp>(Operation::create(state));
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      Operation::dialect_attr_range attrs) {
+  SmallVector<NamedAttribute, 8> attrRef(attrs);
+  return create(location, name, type, llvm::makeArrayRef(attrRef));
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs,
+                      ArrayRef<DictionaryAttr> argAttrs) {
+  FuncOp func = create(location, name, type, attrs);
+  func.setAllArgAttrs(argAttrs);
+  return func;
+}
+
+void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs,
+                   ArrayRef<DictionaryAttr> argAttrs) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(name));
+  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+                     TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addRegion();
+
+  if (argAttrs.empty())
+    return;
+  assert(type.getNumInputs() == argAttrs.size());
+  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                                /*resultAttrs=*/llvm::None);
+}
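+
+// Note: the create(...) overloads above only construct the operation; the
+// caller is responsible for inserting it into a block, for instance via
+// SymbolTable::insert, which also uniques the symbol name.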
+
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
                           ArrayRef<Type> results,
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  OutlineShapeComputation.cpp
   RemoveShapeConstraints.cpp
   ShapeToShapeLowering.cpp
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -0,0 +1,387 @@
+//===- OutlineShapeComputation.cpp ------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
+#include <queue>
+#include <unordered_set>
+#include <vector>
+
+#define DEBUG_TYPE "outline-shape-computation"
+
+using namespace mlir;
+
+namespace {
+
+// A Value is an input of the cluster if it is an operand of an operation in
+// the cluster and its defining operation is not in the cluster.
+SmallVector<Value>
+getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
+  SmallVector<Value> inputs;
+  llvm::SmallDenseSet<Value> inputSet;
+  llvm::SmallDenseSet<Operation *> opSet;
+  for (Operation *op : cluster) {
+    bool inserted = opSet.insert(op).second;
+    (void)inserted;
+    assert(inserted && "cluster contains duplicate operations");
+  }
+
+  for (Operation *op : cluster) {
+    for (Value operand : op->getOperands()) {
+      Operation *operandOp = operand.getDefiningOp();
+      if (opSet.find(operandOp) != opSet.end()) {
+        // Skip if the defining op is inside the cluster.
+        continue;
+      }
+      if (inputSet.insert(operand).second)
+        inputs.push_back(operand);
+    }
+  }
+  return inputs;
+}
+
+// Create a shape.func representing the shape computation for \p shape.
+std::pair<shape::FuncOp, SmallVector<Value>>
+createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
+                      Value shape, StringRef fnName) {
+  assert(!cluster.empty() && "there must be at least one op in the cluster");
+
+  SmallVector<Value> inputs = getInputsOfCluster(cluster);
+  SmallVector<Type, 1> outputTypes{shape.getType()};
+  SmallVector<Type> inputTypes = llvm::to_vector(
+      llvm::map_range(inputs, [](Value inp) { return inp.getType(); }));
+
+  auto fnType = b.getFunctionType(inputTypes, outputTypes);
+  b.setInsertionPointAfter(cluster[0]->getParentOp());
+  shape::FuncOp fnOp = b.create<shape::FuncOp>(UnknownLoc::get(b.getContext()),
+                                               fnName, fnType);
+  Block *block = fnOp.addEntryBlock();
+  b.setInsertionPoint(block, block->end());
+  // Clone the cluster into the new function, remapping the cluster inputs to
+  // the corresponding block arguments.
+  BlockAndValueMapping bvm;
+  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
+    bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
+  for (Operation *op : cluster)
+    b.clone(*op, bvm);
+  llvm::SmallVector<Value, 1> fnReturns;
+  fnReturns.push_back(bvm.lookupOrDefault(shape));
+
+  b.create<shape::ReturnOp>(UnknownLoc::get(b.getContext()), fnReturns);
+  fnOp.setPrivate();
+  return std::make_pair(fnOp, inputs);
+}
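+
+// For illustration: for a cluster that computes a shape from a single
+// tensor<3xindex> input, the outlined function produced here looks like
+//
+//   shape.func private @shape_cal_0(%arg0: tensor<3xindex>) -> !shape.shape {
+//     ...
+//     return %shape : !shape.shape
+//   }
+//
+// (see outline-shape-computation.mlir for a complete example).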
+
+// The operations in a cluster may be in arbitrary order, which would be
+// inconvenient when creating the shape.func op. Re-emit each cluster's
+// operations in the order in which they appear in \p funcOp.
+DenseMap<Value, SmallVector<Operation *, 8>>
+getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
+                   func::FuncOp funcOp) {
+  // Record, for every operation, which shapes' clusters it belongs to.
+  DenseMap<Operation *, SmallVector<Value>> op2Shapes;
+  for (const auto &it : clusters) {
+    Value shape = it.first;
+    const DenseSet<Operation *> &cluster = it.second;
+    for (Operation *cOp : cluster)
+      op2Shapes[cOp].push_back(shape);
+  }
+
+  // Walk the function in order and append each operation to its clusters, so
+  // every cluster ends up topologically sorted.
+  DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
+  funcOp.walk([&](Operation *op) {
+    auto it = op2Shapes.find(op);
+    if (it != op2Shapes.end()) {
+      Operation *cOp = it->first;
+      for (Value shape : it->second)
+        orderedClusters[shape].push_back(cOp);
+    }
+  });
+
+  return orderedClusters;
+}
+
+// Increment \p idx until an unused symbol name is found, and record the name
+// in \p usedSymbolNames.
+std::string
+getNextAvailableSymbolName(const std::string &prefix, int &idx,
+                           std::unordered_set<std::string> &usedSymbolNames) {
+  std::string name = prefix + std::to_string(idx++);
+  while (usedSymbolNames.count(name))
+    name = prefix + std::to_string(idx++);
+  usedSymbolNames.insert(name);
+  return name;
+}
+
+// Return the argument index if \p shape is the output of a
+// shape.shape_of(func_arg); otherwise return -1.
+int getShapeOfFuncArgIdx(Value shape, func::FuncOp funcOp) {
+  auto shapeOfOp = shape.getDefiningOp<shape::ShapeOfOp>();
+  if (shapeOfOp == nullptr)
+    return -1;
+  Value inp = shapeOfOp.getArg();
+  for (int i = 0; i < static_cast<int>(funcOp.getNumArguments()); ++i) {
+    if (funcOp.getArgument(i) == inp)
+      return i;
+  }
+  return -1;
+}
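+
+// For each shape.with_shape op, record how the shape of its value operand is
+// computed: either tag the value's tensor type with the symbol of the dynamic
+// shape source it equals, or outline the shape computation into a shape.func
+// and tag the type with [shape_func_symbol, arg_symbols...].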
+void constructShapeFunc(const std::vector<shape::WithOp> &allWithOps,
+                        MLIRContext *context,
+                        DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
+                        SymbolTable &symbolTable, func::FuncOp funcOp) {
+  std::unordered_set<std::string> usedSymbolNames;
+  DenseMap<Value, FlatSymbolRefAttr> dynShapeSrc2Symbol;
+  std::string dynamicSourceNamePrefix = "s";
+  int dynamicSourceNameIdx = 0;
+  std::string shapeCalculationNamePrefix = "shape_cal_";
+  int shapeCalculationNameIdx = 0;
+  OpBuilder builder(context);
+
+  auto getOrConstructSymbolFromShape = [&](Value shape) -> FlatSymbolRefAttr {
+    auto symbolIt = dynShapeSrc2Symbol.find(shape);
+    if (symbolIt != dynShapeSrc2Symbol.end())
+      return symbolIt->second;
+
+    std::string name;
+    int index = getShapeOfFuncArgIdx(shape, funcOp);
+    if (index >= 0) {
+      // Name the symbol after the function argument it is the shape of.
+      name = "arg_" + std::to_string(index);
+    } else {
+      name = getNextAvailableSymbolName(dynamicSourceNamePrefix,
+                                        dynamicSourceNameIdx, usedSymbolNames);
+    }
+    auto symbol = FlatSymbolRefAttr::get(context, name);
+    dynShapeSrc2Symbol[shape] = symbol;
+    return symbol;
+  };
+
+  // Construct a shape function or a symbol for each cluster.
+  for (shape::WithOp withOp : allWithOps) {
+    Value value = withOp.getOperand();
+    Value shape = withOp.getShape();
+    auto rankedType = value.getType().dyn_cast<RankedTensorType>();
+    if (rankedType == nullptr)
+      continue;
+    const SmallVector<Operation *, 8> &cluster = clusters[shape];
+    // The cluster is empty when the shape equals a dynamic shape source.
+    if (cluster.empty()) {
+      FlatSymbolRefAttr symbol = getOrConstructSymbolFromShape(shape);
+      value.setType(RankedTensorType::get(rankedType.getShape(),
+                                          rankedType.getElementType(), symbol));
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Symbol for " << shape << ": " << symbol << "\n");
+    } else {
+      SmallVector<Attribute> symbols;
+      std::string name = getNextAvailableSymbolName(
+          shapeCalculationNamePrefix, shapeCalculationNameIdx, usedSymbolNames);
+      auto pair = createFuncFromCluster(builder, cluster, shape, name);
+      const SmallVector<Value> &inputs = pair.second;
+      shape::FuncOp shapeFuncOp = pair.first;
+      StringAttr insertedName = symbolTable.insert(shapeFuncOp);
+      auto symbol = FlatSymbolRefAttr::get(context, insertedName);
+      symbols.push_back(symbol);
+      for (Value inp : inputs) {
+        FlatSymbolRefAttr argSymbol = getOrConstructSymbolFromShape(inp);
+        symbols.push_back(argSymbol);
+      }
+      auto arrayAttr = ArrayAttr::get(context, symbols);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Symbol for " << shape << ": " << arrayAttr << "\n");
+      value.setType(RankedTensorType::get(
+          rankedType.getShape(), rankedType.getElementType(), arrayAttr));
+    }
+  }
+}
+
+struct OutlineShapeComputationPass
+    : public OutlineShapeComputationBase<OutlineShapeComputationPass> {
+
+  OutlineShapeComputationPass(const std::string &entryFunc)
+      : OutlineShapeComputationBase() {
+    this->entryFunc = entryFunc;
+  }
+
+  void runOnOperation() override;
+
+private:
+  bool calOnlyUsedByWithShapesRecursively(Operation *op);
+
+  void getClusterFromValue(Value shape,
+                           DenseMap<Value, DenseSet<Operation *>> &clusters);
+
+  DenseMap<Value, SmallVector<Operation *, 8>>
+  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
+                                func::FuncOp funcOp);
+
+  // Cache for calOnlyUsedByWithShapesRecursively.
+  DenseMap<Operation *, bool> onlyUsedByWithShapes_;
+};
+
+struct TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto shapeOf =
+        rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
+    rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
+                                                    op.getIndex());
+    return success();
+  }
+};
+
+void OutlineShapeComputationPass::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  SymbolTable symbolTable(moduleOp);
+
+  moduleOp.walk([&](func::FuncOp funcOp) {
+    if (funcOp.getName() != entryFunc)
+      return;
+
+    // Rewrite tensor.dim ops to shape.get_extent(shape.shape_of(...)) so that
+    // all shape computation is expressed in the shape dialect.
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet prevPatterns(context);
+    prevPatterns.insert<TensorDimOpRewriter>(context);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
+      return signalPassFailure();
+
+    // Initialize the class member \p onlyUsedByWithShapes_.
+    onlyUsedByWithShapes_.clear();
+    funcOp.walk(
+        [&](Operation *op) { calOnlyUsedByWithShapesRecursively(op); });
+
+    // Collect all the shape.with_shape ops.
+    std::vector<shape::WithOp> allWithOps;
+    funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
+
+    DenseMap<Value, SmallVector<Operation *, 8>> clusters =
+        constructClustersForEachShape(allWithOps, funcOp);
+    constructShapeFunc(allWithOps, context, clusters, symbolTable, funcOp);
+
+    // Replace uses of shape.value_of(shape.with_shape(v, s)) with v; the
+    // shape information now lives in the type encoding.
+    for (shape::WithOp withOp : allWithOps) {
+      Value value = withOp.getOperand();
+      for (Operation *user : withOp.getResult().getUsers()) {
+        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user))
+          valueOf.getResult().replaceAllUsesExcept(value, withOp);
+      }
+    }
+
+    // DCE: fold away the now-unused shape computation in the entry function.
+    if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
+      return signalPassFailure();
+
+    // The result types may have changed, so update the function type.
+    funcOp.setType(
+        FunctionType::get(context, funcOp.front().getArgumentTypes(),
+                          funcOp.front().getTerminator()->getOperandTypes()));
+  });
+}
+
+DenseMap<Value, SmallVector<Operation *, 8>>
+OutlineShapeComputationPass::constructClustersForEachShape(
+    const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
+  DenseMap<Value, DenseSet<Operation *>> clusters;
+  for (shape::WithOp withOp : allWithOps) {
+    Value shape = withOp.getShape();
+    if (clusters.count(shape) == 0)
+      getClusterFromValue(shape, clusters);
+  }
+  return getOrderedClusters(clusters, funcOp);
+}
+
+// The output of a cluster is \p shape, and its inputs are either results of
+// shape.shape_of ops or function arguments.
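+// In other words, a cluster is the backward slice of \p shape, with the
+// traversal stopped at shape.shape_of ops whose operands are not exclusively
+// used by shape.with_shape ops; the results of such shape.shape_of ops become
+// inputs of the cluster.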
+void OutlineShapeComputationPass::getClusterFromValue(
+    Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
+  DenseSet<Operation *> cluster;
+
+  Operation *defOp = shape.getDefiningOp();
+  // defOp == nullptr means the shape is an argument of the func op.
+  if (nullptr == defOp)
+    return;
+
+  // BFS backwards from the defining op of \p shape.
+  DenseSet<Operation *> visited;
+  std::queue<Operation *> queue;
+  visited.insert(defOp);
+  queue.push(defOp);
+  while (!queue.empty()) {
+    Operation *op = queue.front();
+    queue.pop();
+    if (op->getNumOperands() == 0) {
+      cluster.insert(op);
+    } else if (llvm::isa<shape::ShapeOfOp>(op) &&
+               !onlyUsedByWithShapes_.lookup(
+                   op->getOperand(0).getDefiningOp())) {
+      // Stop when the op is a shape.shape_of whose operand isn't used
+      // exclusively by shape.with_shape ops; its result becomes an input of
+      // the cluster.
+      continue;
+    } else {
+      cluster.insert(op);
+      for (Value inp : op->getOperands()) {
+        Operation *inpDefOp = inp.getDefiningOp();
+        if (nullptr != inpDefOp && !visited.contains(inpDefOp)) {
+          visited.insert(inpDefOp);
+          queue.push(inpDefOp);
+        }
+      }
+    }
+  }
+
+  clusters[shape] = std::move(cluster);
+}
+
+// Check whether an operation is used, directly or transitively, only by
+// shape.with_shape ops. The result is cached in \p onlyUsedByWithShapes_.
+bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
+    Operation *op) {
+  auto it = onlyUsedByWithShapes_.find(op);
+  if (it != onlyUsedByWithShapes_.end())
+    return it->second;
+
+  if (llvm::isa<shape::WithOp>(op)) {
+    onlyUsedByWithShapes_[op] = true;
+    return true;
+  }
+
+  if (op->use_empty()) {
+    onlyUsedByWithShapes_[op] = false;
+    return false;
+  }
+
+  bool allUsers = true;
+  for (Operation *user : op->getUsers())
+    allUsers &= calOnlyUsedByWithShapesRecursively(user);
+
+  onlyUsedByWithShapes_[op] = allUsers;
+  return allUsers;
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createOutlineShapeComputationPass(const std::string &entryFunc) {
+  return std::make_unique<OutlineShapeComputationPass>(entryFunc);
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Shape/Transforms/PassDetail.h b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/Shape/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h
@@ -10,6 +10,7 @@
 #define DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -22,6 +23,10 @@
 class MemRefDialect;
 } // namespace memref
 
+namespace shape {
+class ShapeDialect;
+} // namespace shape
+
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -97,6 +97,11 @@
   return %0 : tensor<?xindex>
 }
 
+func.func @test_value_of(%arg0: !shape.value_shape) -> tensor<?xf32> {
+  %0 = shape.value_of %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
 func.func @test_constraints() {
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
@@ -250,6 +255,7 @@
   %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   return %2 : !shape.shape
 }
+
 func.func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
   %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
   %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
@@ -257,6 +263,12 @@
   return %2 : !shape.shape
 }
 
+func.func @shape_with_shape_extent_tensor_type(%a : tensor<?x?x?xf32>, %b : !shape.value_shape) -> !shape.value_shape {
+  %0 = shape.shape_of %a : tensor<?x?x?xf32> -> tensor<3xindex>
+  %1 = shape.with_shape %b, %0 : !shape.value_shape, tensor<3xindex>
+  return %1 : !shape.value_shape
+}
+
 func.func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
     -> !shape.shape {
   %result = shape.any %a, %b, %c
diff --git a/mlir/test/Dialect/Shape/outline-shape-computation.mlir b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -outline-shape-computation="entry-func=main" -split-input-file %s | FileCheck %s
+
+
+func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
+  %2 = "tosa.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
+  %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
+  %4 = shape.value_of %3 : tensor<?x4x?xf32>
+  %5 = "tosa.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
+  %7 = arith.addi %6, %c2 : index
+  %8 = shape.from_extents %7, %c4, %1 : index, index, index
+  %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
+  %10 = shape.value_of %9 : tensor<?x4x?xf32>
+  return %10 : tensor<?x4x?xf32>
+}
+
+// CHECK-LABEL: func.func @main
+// CHECK-SAME: (%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32, [@shape_cal_0, @arg_0]>
+// CHECK-NEXT: %0 = "tosa.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32, @arg_0>
+// CHECK-NEXT: %1 = "tosa.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32, @arg_0>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32, [@shape_cal_0, @arg_0]>
+// CHECK-NEXT: return %1 : tensor<?x4x?xf32, [@shape_cal_0, @arg_0]>
+
+// CHECK-LABEL: shape.func private @shape_cal_0
+// CHECK-SAME: (%arg0: tensor<3xindex>) -> !shape.shape
+// CHECK-DAG: %[[V0:.*]] = get_extent %arg0, %c2 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V1:.*]] = get_extent %arg0, %c0 : tensor<3xindex>, index -> index
+// CHECK-DAG: %[[V2:.*]] = arith.addi %[[V1]], %c2 : index
+// CHECK-DAG: %[[V3:.*]] = from_extents %[[V2]], %c4, %[[V0]] : index, index, index
+// CHECK-DAG: return %[[V3]] : !shape.shape