diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1493,7 +1493,6 @@ class MemRef_ReassociativeReshapeOp traits = []> : MemRef_Op, - Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyStridedMemRef:$result)>{ code commonExtraClassDeclaration = [{ @@ -1518,10 +1517,6 @@ Value getViewSource() { return getSrc(); } }]; - let assemblyFormat = [{ - $src $reassociation attr-dict `:` type($src) `into` type($result) - }]; - let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -1543,14 +1538,10 @@ Example: ```mlir - %r = memref.expand_shape %0 [[0, 1], [2]] - : memref into memref + %r = memref.expand_shape %0 [[0, 1], [2]] [%sz0, %sz1, 32] + : memref into memref ``` - At most one dimension of a reassociation group (e.g., [0, 1] above) may be - dynamic in the result type. Otherwise, the op would be ambiguous, as it - would not be clear how the source dimension is extended. - If an op can be statically proven to be invalid (e.g, an expansion from `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If it cannot statically be proven invalid (e.g., the full example above; it is @@ -1567,29 +1558,49 @@ there must be a dynamic result dimension in the corresponding reassociation group. Same for strides. + The representation for the output shape supports a partially-static + specification via attributes specified through the `static_output_shape` + argument. A special sentinel value `ShapedType::kDynamic` encodes that the + corresponding entry has a dynamic value. There must be exactly as many SSA + inputs in `output_shape` as there are `ShapedType::kDynamic` entries in + `static_output_shape`. + Note: This op currently assumes that the inner strides are of the source/result layout map are the faster-varying ones. }]; + let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation, + Variadic:$output_shape, + DenseI64ArrayAttr:$static_output_shape); + + let assemblyFormat = [{ + $src $reassociation `output_shape` + custom($output_shape, $static_output_shape) attr-dict `:` + type($src) `into` type($result) + }]; + let builders = [ // Builders using ReassociationIndices. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape), [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); + auto [staticOutputShape, dynamicOutputShape] = + decomposeMixedValues(SmallVector(outputShape)); + build($_builder, $_state, resultType, src, + getReassociationIndicesAttribute($_builder, reassociation), + dynamicOutputShape, staticOutputShape); }]>, // Builder using ReassociationExprs. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); + convertReassociationMapsToIndices(reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, + outputShape); }]>, // Builder that infers the result layout map. 
The result shape must be @@ -1602,6 +1613,14 @@ static FailureOr computeExpandedType( MemRefType srcType, ArrayRef resultShape, ArrayRef reassociation); + + // Infer the output shape for a memref.expand_shape when it is possible + // to do so. + static LogicalResult inferOutputShape( + OpBuilder &b, Location loc, MemRefType expandedType, + ArrayRef reassociation, + ArrayRef inputShape, + SmallVectorImpl &outputShape); }]; let hasVerifier = 1; @@ -1652,6 +1671,12 @@ source/result layout map are the faster-varying ones. }]; + let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation); + + let assemblyFormat = [{ + $src $reassociation attr-dict `:` type($src) `into` type($result) + }]; + let builders = [ // Builders for a contracting reshape whose result type is computed from // `src` and `reassociation`. @@ -1663,7 +1688,7 @@ CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, src, reassociationMaps, attrs); }]>, @@ -1681,7 +1706,7 @@ CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, resultType, src, reassociationMaps, attrs); }]> ]; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -970,7 +970,6 @@ Tensor_Op, Pure])>, - Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { code commonExtraClassDeclaration = [{ @@ -994,10 +993,6 @@ } }]; - let assemblyFormat = [{ - $src $reassociation attr-dict `:` type($src) `into` type($result) - }]; - let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -1010,11 +1005,16 @@ rank whose sizes are a reassociation of the original `src`. A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of DenseI64ArrayAttr attribute. + represented with an array of DenseI64ArrayAttr attribute. The reassociation + maps applied to the result tensor with the higher rank must result in the + operand tensor with the smaller rank. - The verification rule is that the reassociation maps are applied to the - result tensor with the higher rank to obtain the operand tensor with the - smaller rank. + The representation for the output shape supports a partially-static + specification via attributes specified through the `static_output_shape` + argument. A special sentinel value `ShapedType::kDynamic` encodes that the + corresponding entry has a dynamic value. There must be exactly as many SSA + inputs in `output_shape` as there are `ShapedType::kDynamic` entries in + `static_output_shape`. The operand tensor type of a reshape can be zero-ranked if the result tensor type is statically shaped with all dimensions being unit extent. 
In @@ -1024,32 +1024,54 @@ ```mlir // Dimension expansion i -> (i', j') and (k) -> (k') - %b = tensor.expand_shape %a [[0, 1], [2]] - : tensor into tensor + %b = tensor.expand_shape %a [[0, 1], [2]] [%sz0, %sz1, 32] + : tensor into tensor ``` }]; + + let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation, + Variadic:$output_shape, + DenseI64ArrayAttr:$static_output_shape); + + let assemblyFormat = [{ + $src $reassociation `output_shape` + custom($output_shape, $static_output_shape) attr-dict `:` + type($src) `into` type($result) + }]; + let builders = [ // Builders using ReassociationIndices. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape), [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); + auto [staticOutputShape, dynamicOutputShape] = + decomposeMixedValues(SmallVector(outputShape)); + build($_builder, $_state, resultType, src, + getReassociationIndicesAttribute($_builder, reassociation), + dynamicOutputShape, staticOutputShape); }]>, OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), + "ArrayRef":$outputShape), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); + convertReassociationMapsToIndices(reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, + outputShape); }]> ]; let extraClassDeclaration = commonExtraClassDeclaration # [{ int64_t getCorrespondingSourceDim(int64_t resultDim); + + // Infer the output shape for a tensor.expand_shape when it is possible + // to do so. + static LogicalResult inferOutputShape( + OpBuilder &b, Location loc, RankedTensorType expandedType, + ArrayRef reassociation, + ArrayRef inputShape, + SmallVectorImpl &outputShape); }]; let hasVerifier = 1; @@ -1057,6 +1079,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { let summary = "operation to produce a tensor with a smaller rank"; + let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation); let description = [{ The `tensor.collapse_shape` op produces a new tensor with a smaller rank whose sizes are a reassociation of the original `src`. @@ -1080,6 +1103,11 @@ : tensor into tensor ``` }]; + + let assemblyFormat = [{ + $src $reassociation attr-dict `:` type($src) `into` type($result) + }]; + let builders = [ // Builders for a contracting reshape whose result type is computed from // `src` and `reassociation`. @@ -1091,7 +1119,7 @@ CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, src, reassociationMaps, attrs); }]>, @@ -1109,7 +1137,7 @@ CArg<"ArrayRef", "{}">:$attrs), [{ auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); + convertReassociationMapsToIndices(reassociation); build($_builder, $_state, resultType, src, reassociationMaps, attrs); }]> ]; diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -30,6 +30,29 @@ /// Attribute name for the ArrayAttr which encodes reassociation indices. 
constexpr StringRef getReassociationAttrName() { return "reassociation"; } +// Infer the output shape for a {memref|tensor}.expand_shape when it is possible +// to do so. +// +// Note: This should *only* be used to implement +// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces. +// If you need to infer the output shape you should use the static method of +// `ExpandShapeOp` instead of calling this. +// +// `inputShape` is the shape of the tensor or memref being expanded as a +// sequence of SSA values or constants. `expandedType` is the output shape of +// the expand_shape operation. `reassociation` is the reassociation denoting +// the output dims each input dim is mapped to. +// +// Returns the output shape in `outputShape` and `staticOutputShape`, following +// the conventions for the output_shape and static_output_shape inputs to the +// expand_shape ops. +LogicalResult +inferExpandShapeOutputShape(OpBuilder &b, Location loc, + RankedTensorType expandedType, + ArrayRef reassociation, + ArrayRef inputShape, + SmallVectorImpl &outputShape); + /// Compose reassociation maps that are used in pair of reshape ops where one /// is a producer and other is the consumer. Only valid to use this method when /// both the producer and consumer are collapsing dimensions or both are @@ -62,7 +85,7 @@ /// Convert Array> to Array>. SmallVector convertReassociationMapsToIndices( - OpBuilder &b, ArrayRef reassociationExprs); + ArrayRef reassociationExprs); /// Return the reassociations maps to use to reshape given the source type and /// the target type when possible. Return std::nullopt when this computation @@ -167,9 +190,11 @@ /// Returns true iff the type is a MemRefType and has a non-identity layout. bool hasNonIdentityLayout(Type type); +enum class ReshapeOpKind { kExpand, kCollapse }; + /// Pattern to collapse producer/consumer reshape ops that are both collapsing /// dimensions or are both expanding dimensions. -template +template struct ComposeReassociativeReshapeOps : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, @@ -192,8 +217,18 @@ rewriter.getContext()); if (!reassociationIndices) return failure(); - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices); + + if constexpr (opKind == ReshapeOpKind::kExpand) { + SmallVector outputShape( + getMixedValues(reshapeOp.getStaticOutputShape(), + reshapeOp.getOutputShape(), rewriter)); + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices, + outputShape); + } else { + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices); + } return success(); } }; @@ -226,7 +261,8 @@ // /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1` /// `reassociation_2` and produce `expand_shape`. 
-template +template struct ComposeCollapseOfExpandOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CollapseOpTy collapseOp, @@ -278,8 +314,31 @@ rewriter.replaceOpWithNewOp( collapseOp, resultType, expandOp.getSrc(), composedReassociation); } else if (srcRank < resultRank) { + auto tensorType = expandOp.getSrc().getType().template cast(); + SmallVector inputShape; + for (int64_t i = 0; i < tensorType.getRank(); ++i) { + if (tensorType.isDynamicDim(i)) { + Value size = + rewriter.create(expandOp.getLoc(), expandOp.getSrc(), i); + inputShape.push_back(size); + } else { + inputShape.push_back(rewriter.getIndexAttr(tensorType.getDimSize(i))); + } + } + + SmallVector outputShape; + if (failed(ExpandOpTy::inferOutputShape( + rewriter, collapseOp.getLoc(), + collapseOp.getType().template cast(), + composedReassociation, inputShape, outputShape))) { + return rewriter.notifyMatchFailure( + collapseOp, + "unable to infer output shape argument for tensor.expand_shape"); + } + rewriter.replaceOpWithNewOp( - collapseOp, resultType, expandOp.getSrc(), composedReassociation); + collapseOp, resultType, expandOp.getSrc(), composedReassociation, + outputShape); } else { // Collapses/expansions that do not change the rank are not allowed. Use // a cast instead. @@ -333,8 +392,11 @@ if (!composedReassociation) return failure(); + SmallVector outputShape(getMixedValues( + expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter)); rewriter.replaceOpWithNewOp( - expandOp, resultType, collapseOp.getSrc(), *composedReassociation); + expandOp, resultType, collapseOp.getSrc(), *composedReassociation, + outputShape); return success(); } diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -109,9 +109,8 @@ /// Decompose a vector of mixed static or dynamic values into the /// corresponding pair of arrays. This is the inverse function of /// `getMixedValues`. 
-std::pair> -decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues); +std::pair, SmallVector> +decomposeMixedValues(const SmallVectorImpl &mixedValues); } // namespace mlir diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" @@ -1159,8 +1158,19 @@ b.create(loc, value); }); + Value result = generic.getResults()[0]; + SmallVector inputShape = + tensor::getMixedSizes(rewriter, loc, result); + SmallVector outputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, resultTy.cast(), + convertReassociationMapsToIndices(reassociationMap), inputShape, + outputShape))) { + return rewriter.notifyMatchFailure( + op, "unable to infer output shape argument for tensor.expand_shape"); + } rewriter.replaceOpWithNewOp( - op, resultTy, generic.getResults()[0], reassociationMap); + op, resultTy, result, reassociationMap, outputShape); return success(); } }; diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -197,8 +197,21 @@ return rewriter.notifyMatchFailure( reshape, "tosa.reshape Cannot expand into given shape"); } + + Value input = adaptor.getInput1(); + SmallVector inputShape = + tensor::getMixedSizes(rewriter, reshape.getLoc(), input); + SmallVector outputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, reshape.getLoc(), resultTy.cast(), + convertReassociationMapsToIndices(reassociationMap), inputShape, + outputShape))) { + return rewriter.notifyMatchFailure( + reshape, + "unable to infer output shape argument for tensor.expand_shape"); + } rewriter.replaceOpWithNewOp( - reshape, resultTy, adaptor.getOperands()[0], reassociationMap); + reshape, resultTy, input, reassociationMap, outputShape); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -454,9 +454,17 @@ return failure(); Location loc = oldFill.getLoc(); - auto newInit = rewriter.create( - loc, reshapeOp.getResultType(), oldFill.output(), - reshapeOp.getReassociation()); + TensorReshapeOp newInit; + if constexpr (std::is_same::value) + + newInit = rewriter.create( + loc, reshapeOp.getResultType(), oldFill.output(), + reshapeOp.getReassociation(), reshapeOp.getOutputShape(), + reshapeOp.getStaticOutputShape()); + else + newInit = rewriter.create(loc, reshapeOp.getResultType(), + oldFill.output(), + reshapeOp.getReassociation()); rewriter.replaceOpWithNewOp(reshapeOp, ValueRange{oldFill.value()}, ValueRange{newInit}); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -881,10 +881,21 @@ DBGSNL(); DBGS() << "collapsed type: " << collapsed; 
DBGSNL();); // 5. Expand from the padded result to the stripMinedShape. + SmallVector padOpResultShape = + tensor::getMixedSizes(rewriter, loc, padOp.getResult()); + auto expandShapeResultType = + RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); + SmallVector outputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, expandShapeResultType, packingMetadata.reassociations, + padOpResultShape, outputShape))) { + return rewriter.notifyMatchFailure( + packOp, + "unable to infer output shape argument for tensor.expand_shape"); + } auto reshapeOp = rewriter.create( - loc, - RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), - padOp.getResult(), packingMetadata.reassociations); + loc, expandShapeResultType, padOp.getResult(), + packingMetadata.reassociations, outputShape); // 6. Transpose stripMinedShape to packedShape. SmallVector insertPositionsToLastDimsPerm = computePermutationVector( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -181,8 +181,18 @@ result = genericOp.getResults().front(); } + SmallVector inputShape = + tensor::getMixedSizes(rewriter, loc, result); + SmallVector expandShapeOutputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, outputType.cast(), + outputReassocIndices, inputShape, expandShapeOutputShape))) { + return rewriter.notifyMatchFailure( + convOp, + "unable to infer output shape argument for tensor.expand_shape"); + } auto reshapedResult = rewriter.create( - loc, outputType, result, outputReassocIndices); + loc, outputType, result, outputReassocIndices, expandShapeOutputShape); rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); @@ -329,9 +339,22 @@ SmallVector batchMatVecReassociationIndice = {{0, 1}, {2, 3}}; - Value batchMatVecResultReshaped = rewriter.create( - loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0), - batchMatVecReassociationIndice); + Value result = batchMatVecResult.getResult(0); + SmallVector inputShape = + tensor::getMixedSizes(rewriter, loc, result); + SmallVector expandShapeOutputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, + transposedOutputTensor.getType().cast(), + batchMatVecReassociationIndice, inputShape, + expandShapeOutputShape))) { + return rewriter.notifyMatchFailure( + convOp, + "unable to infer output shape argument for tensor.expand_shape"); + } + auto batchMatVecResultReshaped = rewriter.create( + loc, transposedOutputTensor.getType(), result, + batchMatVecReassociationIndice, expandShapeOutputShape); Value transposedResult = transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1}); @@ -480,8 +503,18 @@ result = genericOp.getResults().front(); } + SmallVector inputShape = + tensor::getMixedSizes(rewriter, loc, result); + SmallVector expandShapeOutputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, outputType.cast(), + outputReassocIndices, inputShape, expandShapeOutputShape))) { + return rewriter.notifyMatchFailure( + convOp, + "unable to infer output shape argument for tensor.expand_shape"); + } auto reshapedResult = rewriter.create( - loc, outputType, result, outputReassocIndices); + loc, outputType, result, outputReassocIndices, expandShapeOutputShape); rewriter.replaceOp(convOp, ArrayRef{reshapedResult}); diff --git 
a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" @@ -433,9 +434,9 @@ rankReductionStrategy(rankReductionStrategy) {} // Expand the given value. - Value expandValue(Value result, Value origOutput, - ArrayRef reassociation, Location loc, - PatternRewriter &rewriter) const { + FailureOr expandValue(Value result, Value origOutput, + ArrayRef reassociation, + Location loc, PatternRewriter &rewriter) const { // There are no results for memref outputs. auto origResultType = origOutput.getType().cast(); if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { @@ -451,8 +452,18 @@ assert(rankReductionStrategy == RankReductionStrategy::ReassociativeReshape && "unknown rank reduction strategy"); - return rewriter.create(loc, origResultType, result, - reassociation); + SmallVector inputShape = + tensor::getMixedSizes(rewriter, loc, result); + SmallVector outputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, origResultType, reassociation, inputShape, + outputShape))) { + return failure(); + } + return rewriter + .create(loc, origResultType, result, + reassociation, outputShape) + .getResult(); } // Collapse the given value. @@ -578,8 +589,14 @@ resultReplacements.push_back(result.value()); continue; } - resultReplacements.push_back(expandValue( - result.value(), origOutput, reassociations[index], loc, rewriter)); + + FailureOr expandedValue = expandValue( + result.value(), origOutput, reassociations[index], loc, rewriter); + if (failed(expandedValue)) { + return rewriter.notifyMatchFailure(genericOp, + "unable to expand result"); + } + resultReplacements.push_back(*expandedValue); } rewriter.replaceOp(genericOp, resultReplacements); @@ -616,8 +633,20 @@ Location loc = sliceOp.getLoc(); Value newSlice = rewriter.create( loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides); + + SmallVector newSliceShape = + tensor::getMixedSizes(rewriter, loc, newSlice); + SmallVector outputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, resultType, *reassociation, newSliceShape, + outputShape))) { + return rewriter.notifyMatchFailure( + sliceOp, + "unable to infer output shape argument for tensor.expand_shape"); + } + rewriter.replaceOpWithNewOp( - sliceOp, resultType, newSlice, *reassociation); + sliceOp, resultType, newSlice, *reassociation, outputShape); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -792,9 +792,23 @@ reassociation, /*isExpandingReshape=*/true))) return std::nullopt; + Location loc = genericOp.getLoc(); + + SmallVector inputShape = + tensor::getMixedSizes(rewriter, loc, opOperand->get()); + SmallVector outputShape; + + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, expandedOperandType, reassociation, inputShape, + outputShape))) { + (void)rewriter.notifyMatchFailure( + 
genericOp,
+          "unable to infer output shape argument for tensor.expand_shape");
+      return std::nullopt;
+    }
     expandedOpOperands.push_back(rewriter.create(
         genericOp.getLoc(), expandedOperandType, opOperand->get(),
-        reassociation));
+        reassociation, outputShape));
     continue;
   }
 }
@@ -819,9 +833,24 @@
           reassociation, /*isExpandingReshape=*/true)))
         return std::nullopt;
+
+      Location loc = genericOp.getLoc();
+
+      Value operand = opOperand->get();
+      SmallVector<OpFoldResult> operandShape =
+          tensor::getMixedSizes(rewriter, loc, operand);
+      SmallVector<OpFoldResult> outputShape;
+      if (failed(tensor::ExpandShapeOp::inferOutputShape(
+              rewriter, loc, expandedOutputType, reassociation, operandShape,
+              outputShape))) {
+        (void)rewriter.notifyMatchFailure(
+            genericOp,
+            "unable to infer output shape argument for tensor.expand_shape");
+        return std::nullopt;
+      }
+
       outputs.push_back(rewriter.create(
-          genericOp.getLoc(), expandedOutputType, opOperand->get(),
-          reassociation));
+          loc, expandedOutputType, operand, reassociation, outputShape));
     } else {
       outputs.push_back(opOperand->get());
     }
@@ -1519,8 +1548,22 @@
           genericOp.getIndexingMapMatchingResult(originalResult.value());
       SmallVector reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+
+      SmallVector<OpFoldResult> collapsedOpShape =
+          tensor::getMixedSizes(rewriter, loc, collapsedOpResult);
+      SmallVector<OpFoldResult> outputShape;
+
+      if (failed(tensor::ExpandShapeOp::inferOutputShape(
+              rewriter, loc, originalResultType.cast<RankedTensorType>(),
+              reassociation, collapsedOpShape, outputShape))) {
+        return rewriter.notifyMatchFailure(
+            genericOp,
+            "unable to infer output shape argument for tensor.expand_shape");
+      }
+
       Value result = rewriter.create(
-          loc, originalResultType, collapsedOpResult, reassociation);
+          loc, originalResultType, collapsedOpResult, reassociation,
+          outputShape);
       results.push_back(result);
     } else {
       results.push_back(collapsedOpResult);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -95,8 +95,20 @@
     newConv->setAttr(attr.getName(), attr.getValue());
   // Expand dimensions back out to
+
+  Value newConvVal = newConv->getResult(0);
+  SmallVector<OpFoldResult> newConvShape =
+      tensor::getMixedSizes(rewriter, loc, newConvVal);
+  SmallVector<OpFoldResult> outputShape;
+  if (failed(tensor::ExpandShapeOp::inferOutputShape(
+          rewriter, loc, resultTy, collapsedInitDims, newConvShape,
+          outputShape))) {
+    return rewriter.notifyMatchFailure(
+        operation,
+        "unable to infer output shape argument for tensor.expand_shape");
+  }
   rewriter.replaceOpWithNewOp(
-      operation, resultTy, newConv->getResult(0), collapsedInitDims);
+      operation, resultTy, newConvVal, collapsedInitDims, outputShape);
   return success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -114,8 +114,19 @@
       Type newType = RankedTensorType::get(
           newShape, operand->get().getType().cast().getElementType());
+
+      SmallVector<OpFoldResult> inputShape =
+          tensor::getMixedSizes(b, loc, operand->get());
+      SmallVector<OpFoldResult> outputShape;
+      if (failed(tensor::ExpandShapeOp::inferOutputShape(
+              b, loc, newType.cast<RankedTensorType>(), reassociation,
+              inputShape, outputShape))) {
+        return b.notifyMatchFailure(
+            op, "unable to infer output shape argument for tensor.expand_shape");
+      }
+
       Value newInput = b.create(
-
loc, newType, operand->get(), reassociation); + loc, newType, operand->get(), reassociation, outputShape); newInputs.push_back(newInput); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2250,6 +2250,18 @@ srcType.getMemorySpace()); } +LogicalResult +ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc, + MemRefType expandedType, + ArrayRef reassociation, + ArrayRef inputShape, + SmallVectorImpl &outputShape) { + auto expandedTensorType = + getTensorTypeFromMemRefType(expandedType).cast(); + return inferExpandShapeOutputShape(b, loc, expandedTensorType, reassociation, + inputShape, outputShape); +} + void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, ArrayRef resultShape, Value src, ArrayRef reassociation) { @@ -2260,7 +2272,9 @@ // Failure of this assertion usually indicates a problem with the source // type, e.g., could not get strides/offset. assert(succeeded(resultType) && "could not compute layout"); - build(builder, result, *resultType, src, reassociation); + SmallVector outputShape( + getMixedValues(resultShape, ValueRange{}, builder)); + build(builder, result, *resultType, src, reassociation, outputShape); } LogicalResult ExpandShapeOp::verify() { @@ -2289,14 +2303,28 @@ return emitOpError("expected expanded type to be ") << *expectedResultType << " but found " << resultType; + if ((int64_t)getStaticOutputShape().size() != resultType.getRank()) + return emitOpError("expected number of static shape bounds to be equal to " + "the output rank (") + << resultType.getRank() << ") but found " + << getStaticOutputShape().size() << " inputs instead"; + + if ((int64_t)getOutputShape().size() != + llvm::count(getStaticOutputShape(), ShapedType::kDynamic)) + return emitOpError("mismatch in dynamic dims in output_shape and " + "static_output_shape: static_output_shape has ") + << llvm::count(getStaticOutputShape(), ShapedType::kDynamic) + << " dynamic dims while output_shape has " << getOutputShape().size() + << " values"; + return success(); } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ComposeExpandOfCollapseOp>( - context); + results.add< + ComposeReassociativeReshapeOps, + ComposeExpandOfCollapseOp>(context); } /// Compute the layout map after collapsing a given source MemRef type with the @@ -2494,9 +2522,11 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ComposeCollapseOfExpandOp, - CollapseShapeOpMemRefCastFolder>(context); + results.add< + ComposeReassociativeReshapeOps, + ComposeCollapseOfExpandOp, + CollapseShapeOpMemRefCastFolder>(context); } OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1276,11 +1276,20 @@ flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf); } + SmallVector flatBufShape = + memref::getMixedSizes(rewriter, loc, flatBuf); + MemRefType expandShapeResultType = + MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()); + auto reassociation = ArrayRef{ReassociationIndices{0, 1}}; + SmallVector outputShape; + if 
(failed(memref::ExpandShapeOp::inferOutputShape( + rewriter, loc, expandShapeResultType, reassociation, flatBufShape, + outputShape))) { + return rewriter.notifyMatchFailure( + op, "unable to infer output shape argument for tensor.expand_shape"); + } Value coordinatesBuf = rewriter.create( - loc, - MemRefType::get(coordinatesTp.getShape(), - coordinatesTp.getElementType()), - flatBuf, ArrayRef{ReassociationIndices{0, 1}}); + loc, expandShapeResultType, flatBuf, reassociation, outputShape); // Converts MemRefs back to Tensors. Value values = rewriter.create(loc, valuesBuf); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -453,8 +453,15 @@ auto rtp = getRankedTensorType(op.getResult()); auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); - auto reshape = rewriter.create(loc, denseTp, op.getSrc(), - op.getReassociation()); + ReshapeOp reshape; + if constexpr (std::is_same::value) { + reshape = rewriter.create( + loc, denseTp, op.getSrc(), op.getReassociation(), + op.getOutputShape(), op.getStaticOutputShape()); + } else { + reshape = rewriter.create(loc, denseTp, op.getSrc(), + op.getReassociation()); + } Value convert = rewriter.create(loc, rtp, reshape); rewriter.replaceOp(op, convert); return success(); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1242,6 +1242,16 @@ llvm_unreachable("could not find reassociation group"); } +LogicalResult +ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc, + RankedTensorType expandedType, + ArrayRef reassociation, + ArrayRef inputShape, + SmallVectorImpl &outputShape) { + return inferExpandShapeOutputShape(b, loc, expandedType, reassociation, + inputShape, outputShape); +} + SmallVector CollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } @@ -1343,6 +1353,20 @@ return emitOpError("expected rank expansion, but found source rank ") << srcType.getRank() << " >= result rank " << resultType.getRank(); + if ((int64_t)getStaticOutputShape().size() != resultType.getRank()) + return emitOpError("expected number of static shape dims to be equal to " + "the output rank (") + << resultType.getRank() << ") but found " + << getStaticOutputShape().size() << " inputs instead"; + + if ((int64_t)getOutputShape().size() != + llvm::count(getStaticOutputShape(), ShapedType::kDynamic)) + return emitOpError("mismatch in dynamic dims in output_shape and " + "static_output_shape: static_output_shape has ") + << llvm::count(getStaticOutputShape(), ShapedType::kDynamic) + << " dynamic dims while output_shape has " << getOutputShape().size() + << " values"; + return verifyTensorReshapeOp(*this, getResultType(), getSrcType()); } @@ -1532,23 +1556,25 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - ComposeExpandOfCollapseOp, - FoldReshapeWithConstant, - FoldReshapeWithSplat, - FoldReshapeWithFromElements, FoldDimOfExpandShape, - FoldDimOfCollapseShape>(context); + results.add< + ComposeReassociativeReshapeOps, + ComposeExpandOfCollapseOp, + FoldReshapeWithConstant, + FoldReshapeWithSplat, + FoldReshapeWithFromElements, FoldDimOfExpandShape, + 
FoldDimOfCollapseShape>(context); } void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add, - ComposeCollapseOfExpandOp, - FoldReshapeWithConstant, - FoldReshapeWithSplat, - FoldReshapeWithFromElements, FoldCollapseOfCastOp>( - context); + results.add< + ComposeReassociativeReshapeOps, + ComposeCollapseOfExpandOp, + FoldReshapeWithConstant, + FoldReshapeWithSplat, + FoldReshapeWithFromElements, FoldCollapseOfCastOp>( + context); } OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { @@ -3114,12 +3140,25 @@ struct SimplifyPackToExandShape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - Value insertExpand(RewriterBase &rewriter, Location loc, Value operand, - Type newOperandType, ArrayAttr reassociation) const { + FailureOr + insertExpand(RewriterBase &rewriter, Location loc, Value operand, + Type newOperandType, + ArrayRef reassociation) const { if (operand.getType() == newOperandType) return operand; - return rewriter.create(loc, newOperandType, operand, - reassociation); + + SmallVector inputShape = + tensor::getMixedSizes(rewriter, loc, operand); + SmallVector outputShape; + if (failed(tensor::ExpandShapeOp::inferOutputShape( + rewriter, loc, newOperandType.cast(), + reassociation, inputShape, outputShape))) { + return failure(); + } + return rewriter + .create(loc, newOperandType, operand, + reassociation, outputShape) + .getResult(); } LogicalResult matchAndRewrite(PackOp packOp, @@ -3132,10 +3171,14 @@ getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) return failure(); - Value expanded = insertExpand( - rewriter, packOp.getLoc(), packOp.getSource(), destType, - getReassociationIndicesAttribute(rewriter, *reassociation)); - rewriter.replaceOp(packOp, expanded); + FailureOr expanded = + insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType, + *reassociation); + if (failed(expanded)) { + return rewriter.notifyMatchFailure( + packOp, "unable to expand source of tensor.pack"); + } + rewriter.replaceOp(packOp, *expanded); return success(); } }; diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -16,6 +17,76 @@ using namespace mlir; +LogicalResult +mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, + RankedTensorType expandedType, + ArrayRef reassociation, + ArrayRef inputShape, + SmallVectorImpl &outputShape) { + outputShape.clear(); + SmallVector outputShapeValues; + SmallVector outputShapeInts; + // For zero-rank inputs, all dims in result shape are unit extent. + if (inputShape.empty()) { + outputShapeInts.resize(expandedType.getRank(), 1); + outputShape.assign(getMixedValues(outputShapeInts, outputShapeValues, b)); + return success(); + } + + outputShapeValues.resize(expandedType.getRank()); + outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic); + + for (const auto &it : llvm::enumerate(reassociation)) { + ReassociationIndices indexGroup = it.value(); + + int64_t indexGroupStaticSizesProductInt = 1; + bool foundDynamic = false; + for (int64_t index : indexGroup) { + int64_t outputDimSize = expandedType.getDimSize(index); + // Cannot infer expanded shape with multiple dynamic dims in the + // same reassociation group! 
+ if (ShapedType::isDynamic(outputDimSize)) { + if (foundDynamic) + return failure(); + foundDynamic = true; + } else { + indexGroupStaticSizesProductInt *= outputDimSize; + } + } + Value indexGroupStaticSizesProduct = + b.create(loc, indexGroupStaticSizesProductInt); + + int64_t inputIndex = it.index(); + for (int64_t index : indexGroup) { + if (ShapedType::isDynamic(expandedType.getDimSize(index))) { + // Call get() under the assumption that we're not casting + // dynamism. + Value indexGroupSize = inputShape[inputIndex].get(); + Value dynamicDimSize = b.createOrFold( + loc, indexGroupSize, indexGroupStaticSizesProduct); + outputShapeValues[index] = dynamicDimSize; + } + } + + for (int64_t index : indexGroup) { + int64_t outputDimSize = expandedType.getDimSize(index); + if (ShapedType::isDynamic(outputDimSize)) + continue; + outputShapeInts[index] = outputDimSize; + } + } + + assert(static_cast( + llvm::count(outputShapeInts, ShapedType::kDynamic)) == + (outputShapeValues.size() - + llvm::count(outputShapeValues, Value{})) && + "Missing output shape entries!"); + llvm::erase_value(outputShapeValues, Value{}); + + outputShape.assign(getMixedValues(outputShapeInts, outputShapeValues, b)); + return success(); +} + std::optional> mlir::getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType) { @@ -168,7 +239,7 @@ } SmallVector mlir::convertReassociationMapsToIndices( - OpBuilder &b, ArrayRef reassociationExprs) { + ArrayRef reassociationExprs) { SmallVector reassociationIndices; for (const auto &exprs : reassociationExprs) { ReassociationIndices indices; @@ -230,24 +301,17 @@ ArrayRef reassociationMaps, bool isExpandingReshape) { unsigned expandedDimStart = 0; for (const auto &map : llvm::enumerate(reassociationMaps)) { - std::optional dynamicShape; + bool foundDynamicShape = false; int64_t linearizedStaticShape = 1; + for (const auto &dim : llvm::enumerate( expandedShape.slice(expandedDimStart, map.value().size()))) { - if (ShapedType::isDynamic(dim.value())) { - if (isExpandingReshape && dynamicShape) { - return emitError("invalid to have a single dimension (" + - Twine(map.index()) + - ") expanded into multiple dynamic dims (" + - Twine(expandedDimStart + dynamicShape.value()) + - "," + Twine(expandedDimStart + dim.index()) + ")"); - } - dynamicShape = dim.index(); - } else { + if (ShapedType::isDynamic(dim.value())) + foundDynamicShape = true; + else linearizedStaticShape *= dim.value(); - } } - if (dynamicShape) { + if (foundDynamicShape) { if (!ShapedType::isDynamic(collapsedShape[map.index()])) { return emitError( "expected dimension " + Twine(map.index()) + diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -166,9 +166,8 @@ /// Decompose a vector of mixed static or dynamic values into the corresponding /// pair of arrays. This is the inverse function of `getMixedValues`. 
-std::pair>
-decomposeMixedValues(Builder &b,
-                     const SmallVectorImpl &mixedValues) {
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
   SmallVector staticValues;
   SmallVector dynamicValues;
   for (const auto &it : mixedValues) {
@@ -179,7 +178,7 @@
       dynamicValues.push_back(it.get());
     }
   }
-  return {b.getI64ArrayAttr(staticValues), dynamicValues};
+  return {staticValues, dynamicValues};
 }
 } // namespace mlir
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -177,12 +177,26 @@
 func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor) -> (tensor, tensor<1x1xf32>) {
   %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor
-  %1 = tensor.expand_shape %0 [] : tensor into tensor<1x1xf32>
+  %1 = tensor.expand_shape %0 [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
   return %0, %1 : tensor, tensor<1x1xf32>
 }
 // CHECK-LABEL: func @tensor_reshape_zero_dim
 // CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor
-// CHECK: tensor.expand_shape %{{.*}} [] : tensor into tensor<1x1xf32>
+// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor<?x?xf32>, %sz0 : index, %sz1 : index, %sz2 : index)
+    -> (tensor<5x?x?x?xf32>) {
+  %1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [5, %sz0, %sz1, %sz2] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+  return %1 : tensor<5x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> {
+// CHECK: %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} output_shape [5, %arg1, %arg2, %arg3] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+// CHECK: return %expanded : tensor<5x?x?x?xf32>
+// CHECK: }
+
 // -----
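As a usage sketch (not part of the patch itself): with the new `output_shape` operands, a rewrite pattern typically computes the mixed static/dynamic source sizes, asks `ExpandShapeOp::inferOutputShape` for the expanded sizes, and then calls the updated builder. The helper name `expandToRank3` and the concrete shapes below are illustrative assumptions; the `inferOutputShape` hook and builder signature are the ones introduced above.

```cpp
// Illustrative sketch only -- not part of this patch. Assumes the
// ExpandShapeOp::inferOutputShape hook and the (resultType, src,
// reassociation, outputShape) builder added above.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

using namespace mlir;

// Hypothetical helper: expand a rank-2 value (e.g. tensor<?x32xf32>) into a
// rank-3 result (e.g. tensor<?x8x32xf32>) with reassociation [[0, 1], [2]].
static FailureOr<Value> expandToRank3(OpBuilder &b, Location loc, Value src,
                                      RankedTensorType expandedType) {
  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};

  // Sizes of the collapsed source, as a mix of attributes and SSA values.
  SmallVector<OpFoldResult> inputShape = tensor::getMixedSizes(b, loc, src);

  // Infer the expanded sizes: for the group [0, 1] the single dynamic result
  // dim becomes dim(src, 0) divided by the static factor 8.
  SmallVector<OpFoldResult> outputShape;
  if (failed(tensor::ExpandShapeOp::inferOutputShape(
          b, loc, expandedType, reassociation, inputShape, outputShape)))
    return failure();

  // The new builder splits `outputShape` into the variadic output_shape
  // operands and the static_output_shape attribute internally.
  return b
      .create<tensor::ExpandShapeOp>(loc, expandedType, src, reassociation,
                                     outputShape)
      .getResult();
}
```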
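The signature change to `decomposeMixedValues` (dropping the `Builder` parameter and returning plain integers) is what lets the builders above populate `static_output_shape` directly. A minimal sketch of the new convention, with hypothetical values:

```cpp
// Illustrative sketch only -- not part of this patch. Shows the updated
// decomposeMixedValues convention: static entries come back as int64_t, with
// ShapedType::kDynamic marking the positions covered by SSA values.
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static void splitMixedSizes(OpBuilder &b, Value dynSize) {
  SmallVector<OpFoldResult> mixed = {b.getIndexAttr(4), dynSize,
                                     b.getIndexAttr(8)};
  auto [staticSizes, dynamicSizes] = decomposeMixedValues(mixed);
  // staticSizes  == {4, ShapedType::kDynamic, 8}
  // dynamicSizes == {dynSize}
  (void)staticSizes;
  (void)dynamicSizes;
}
```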