diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1431,7 +1431,6 @@
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     MemRef_Op<mnemonic, !listconcat(traits,
       [Pure, ViewLikeOpInterface])>,
-    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
   code commonExtraClassDeclaration = [{
@@ -1456,10 +1455,6 @@
     Value getViewSource() { return getSrc(); }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1481,8 +1476,8 @@
     Example:
 
    ```mlir
-    %r = memref.expand_shape %0 [[0, 1], [2]]
-        : memref<?x?xf32> into memref<?x5x?xf32>
+    %r = memref.expand_shape %0 [[0, 1], [2]] [%sz0, %sz1, 32]
+        : memref<?x32xf32> into memref<?x?x32xf32>
    ```
 
    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
@@ -1505,29 +1500,47 @@
    there must be a dynamic result dimension in the corresponding reassociation
    group. Same for strides.
 
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value. There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
+
    Note: This op currently assumes that the inner strides are of the
    source/result layout map are the faster-varying ones.
  }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+                   Variadic<Index>:$output_shape,
+                   DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
  let builders = [
    // Builders using ReassociationIndices.
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationIndices>":$reassociation,
-     CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+     "ValueRange":$output_shape, "ArrayRef<int64_t>":$static_output_shape),
    [{
-     build($_builder, $_state, resultType, src, attrs);
-     $_state.addAttribute("reassociation",
-         getReassociationIndicesAttribute($_builder, reassociation));
+     build($_builder, $_state, resultType, src,
+           getReassociationIndicesAttribute($_builder, reassociation),
+           output_shape, static_output_shape);
    }]>,
 
    // Builder using ReassociationExprs.
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationExprs>":$reassociation,
-     CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+     "ValueRange":$output_shape, "ArrayRef<int64_t>":$static_output_shape),
    [{
      auto reassociationMaps =
-         convertReassociationMapsToIndices($_builder, reassociation);
-     build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+         convertReassociationMapsToIndices(reassociation);
+     build($_builder, $_state, resultType, src, reassociationMaps,
+           output_shape, static_output_shape);
    }]>,
 
    // Builder that infers the result layout map. The result shape must be
@@ -1590,6 +1603,12 @@
    source/result layout map are the faster-varying ones.
  }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
  let builders = [
    // Builders for a contracting reshape whose result type is computed from
    // `src` and `reassociation`.
@@ -1601,7 +1620,7 @@
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
    [{
      auto reassociationMaps =
-         convertReassociationMapsToIndices($_builder, reassociation);
+         convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, src, reassociationMaps, attrs);
    }]>,
 
@@ -1619,7 +1638,7 @@
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
    [{
      auto reassociationMaps =
-         convertReassociationMapsToIndices($_builder, reassociation);
+         convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
    }]>
  ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -982,7 +982,6 @@
    Tensor_Op<mnemonic, !listconcat(traits, [
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
      Pure])>,
-    Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
    Results<(outs AnyTensor:$result)> {
  code commonExtraClassDeclaration = [{
@@ -1006,10 +1005,6 @@
    }
  }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
@@ -1022,11 +1017,16 @@
    rank whose sizes are a reassociation of the original `src`.
 
    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.
+    represented with an array of DenseI64ArrayAttr attribute. The reassociation
+    maps, applied to the higher-rank result tensor, must yield the lower-rank
+    operand tensor.
 
-    The verification rule is that the reassociation maps are applied to the
-    result tensor with the higher rank to obtain the operand tensor with the
-    smaller rank.
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value. There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
 
    The operand tensor type of a reshape can be zero-ranked if the result
    tensor type is statically shaped with all dimensions being unit extent. In
@@ -1036,27 +1036,39 @@
    ```mlir
    // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]]
-        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]] [%sz0, %sz1, 32]
+        : tensor<?x32xf32> into tensor<?x?x32xf32>
    ```
  }];
+
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+                   Variadic<Index>:$output_shape,
+                   DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
  let builders = [
    // Builders using ReassociationIndices.
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationIndices>":$reassociation,
-     CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+     "ValueRange":$output_shape, "ArrayRef<int64_t>":$static_output_shape),
    [{
-     build($_builder, $_state, resultType, src, attrs);
-     $_state.addAttribute("reassociation",
-         getReassociationIndicesAttribute($_builder, reassociation));
+     build($_builder, $_state, resultType, src,
+           getReassociationIndicesAttribute($_builder, reassociation),
+           output_shape, static_output_shape);
    }]>,
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationExprs>":$reassociation,
-     CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+     "ValueRange":$output_shape, "ArrayRef<int64_t>":$static_output_shape),
    [{
      auto reassociationMaps =
-         convertReassociationMapsToIndices($_builder, reassociation);
-     build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+         convertReassociationMapsToIndices(reassociation);
+     build($_builder, $_state, resultType, src, reassociationMaps,
+           output_shape, static_output_shape);
    }]>
  ];
 
@@ -1069,6 +1081,7 @@
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
  let summary = "operation to produce a tensor with a smaller rank";
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
  let description = [{
    The `tensor.collapse_shape` op produces a new tensor with a smaller rank
    whose sizes are a reassociation of the original `src`.
@@ -1092,6 +1105,11 @@
        : tensor<?x?x?xf32> into tensor<?x?xf32>
    ```
  }];
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
  let builders = [
    // Builders for a contracting reshape whose result type is computed from
    // `src` and `reassociation`.
@@ -1103,7 +1121,7 @@
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
    [{
      auto reassociationMaps =
-         convertReassociationMapsToIndices($_builder, reassociation);
+         convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, src, reassociationMaps, attrs);
    }]>,
 
@@ -1121,7 +1139,7 @@
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
    [{
      auto reassociationMaps =
-         convertReassociationMapsToIndices($_builder, reassociation);
+         convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
    }]>
  ];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -29,6 +29,29 @@
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
+/// Infer the output shape for a {memref|tensor}.expand_shape when it is
+/// possible to do so.
+///
+/// `inputShape` is the shape of the tensor or memref being expanded as a
+/// sequence of SSA values or constants. `expandedType` is the result type of
+/// the expand_shape operation. `reassociation` is the reassociation denoting
+/// the output dims each input dim is mapped to.
+///
+/// Returns the output shape in `outputShape` and `staticOutputShape`,
+/// following the conventions for the output_shape and static_output_shape
+/// inputs to the expand_shape ops.
+void inferExpandShapeOutputShape(RankedTensorType expandedType,
+                                 ArrayRef<ReassociationIndices> reassociation,
+                                 ArrayRef<OpFoldResult> inputShape,
+                                 SmallVectorImpl<Value> &outputShape,
+                                 SmallVector<int64_t> &staticOutputShape);
+
+void inferExpandShapeOutputShape(RankedTensorType expandedType,
+                                 ArrayRef<ReassociationExprs> reassociation,
+                                 ArrayRef<OpFoldResult> inputShape,
+                                 SmallVectorImpl<Value> &outputShape,
+                                 SmallVector<int64_t> &staticOutputShape);
+
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -61,7 +84,7 @@
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 4> convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+    ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -166,9 +189,11 @@
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
+enum class ReshapeOpKind { kExpand, kCollapse };
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -191,8 +216,14 @@
                                        rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+    if constexpr (opKind == ReshapeOpKind::kExpand)
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+          reshapeOp.getOutputShape(), reshapeOp.getStaticOutputShape());
+    else
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
     return success();
   }
 };
@@ -277,8 +308,12 @@
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
     } else if (srcRank < resultRank) {
-      rewriter.replaceOpWithNewOp<ExpandOpTy>(
-          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+      abort();
+      // inferExpandShapeOutputShape won't work here since it will introduce
+      // additional uses for collapseOp.
+      //
+      // rewriter.replaceOpWithNewOp<ExpandOpTy>(
+      //     collapseOp, resultType, expandOp.getSrc(), composedReassociation);
     } else {
       // Collapses/expansions that do not change the rank are not allowed. Use
       // a cast instead.
@@ -333,7 +368,8 @@
       return failure();
 
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+        expandOp.getOutputShape(), expandOp.getStaticOutputShape());
 
     return success();
   }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -836,8 +836,17 @@
           rewriter.getAffineDimExpr(expandedDim + 1));
     }
 
+    Value result = linalgOp.getResults()[0];
+    SmallVector<OpFoldResult> inputShape =
+        tensor::createDimValues(rewriter, loc, result);
+    SmallVector<Value> outputShape;
+    SmallVector<int64_t> staticOutputShape;
+    inferExpandShapeOutputShape(resultTy.cast<RankedTensorType>(),
+                                reassociationMap, inputShape, outputShape,
+                                staticOutputShape);
+
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        op, resultTy, linalgOp.getResults()[0], reassociationMap);
+        op, resultTy, result, reassociationMap, outputShape, staticOutputShape);
     return success();
   }
@@ -1036,8 +1045,19 @@
       return rewriter.notifyMatchFailure(
           reshape, "tosa.reshape Cannot expand into given shape");
     }
+
+    Value src = adaptor.getOperands()[0];
+    SmallVector<OpFoldResult> srcShape =
+        tensor::createDimValues(rewriter, src.getLoc(), src);
+    SmallVector<Value> outputShape;
+    SmallVector<int64_t> staticOutputShape;
+    inferExpandShapeOutputShape(resultTy.cast<RankedTensorType>(),
+                                reassociationMap, srcShape, outputShape,
+                                staticOutputShape);
+
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
+        reshape, resultTy, src, reassociationMap, outputShape,
+        staticOutputShape);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -453,9 +453,17 @@
       return failure();
 
     Location loc = oldFill.getLoc();
-    auto newInit = rewriter.create<TensorReshapeOp>(
-        loc, reshapeOp.getResultType(), oldFill.output(),
-        reshapeOp.getReassociation());
+    TensorReshapeOp newInit;
+    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value)
+      newInit = rewriter.create<TensorReshapeOp>(
+          loc, reshapeOp.getResultType(), oldFill.output(),
+          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
+          reshapeOp.getStaticOutputShape());
+    else
+      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
+                                                 oldFill.output(),
+                                                 reshapeOp.getReassociation());
     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                         ValueRange{newInit});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -451,8 +452,15 @@
   assert(rankReductionStrategy ==
              RankReductionStrategy::ReassociativeReshape &&
          "unknown rank reduction strategy");
+  SmallVector<OpFoldResult> inputShape =
+      tensor::createDimValues(rewriter, loc, result);
+  SmallVector<Value> outputShape;
+  SmallVector<int64_t> staticOutputShape;
+  inferExpandShapeOutputShape(origResultType, reassociation, inputShape,
+                              outputShape, staticOutputShape);
   return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
-                                                reassociation);
+                                                reassociation, outputShape,
+                                                staticOutputShape);
 }
 
 // Collapse the given value.
@@ -616,8 +624,17 @@
     Location loc = sliceOp.getLoc();
     Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
         loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
+
+    SmallVector<OpFoldResult> newSliceShape =
+        tensor::createDimValues(rewriter, loc, newSlice);
+    SmallVector<Value> outputShape;
+    SmallVector<int64_t> staticOutputShape;
+    inferExpandShapeOutputShape(resultType, *reassociation, newSliceShape,
+                                outputShape, staticOutputShape);
+
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        sliceOp, resultType, newSlice, *reassociation);
+        sliceOp, resultType, newSlice, *reassociation, outputShape,
+        staticOutputShape);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
@@ -774,9 +775,18 @@
                 reassociation,
                 /*isExpandingReshape=*/true)))
         return std::nullopt;
+      Location loc = genericOp.getLoc();
+
+      SmallVector<OpFoldResult> inputShape =
+          tensor::createDimValues(rewriter, loc, opOperand->get());
+      SmallVector<Value> outputShape;
+      SmallVector<int64_t> staticOutputShape;
+
+      inferExpandShapeOutputShape(expandedOperandType, reassociation,
+                                  inputShape, outputShape, staticOutputShape);
       expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
           genericOp.getLoc(), expandedOperandType, opOperand->get(),
-          reassociation));
+          reassociation, outputShape, staticOutputShape));
       continue;
     }
   }
@@ -801,9 +811,20 @@
                 reassociation,
                 /*isExpandingReshape=*/true)))
         return std::nullopt;
+
+      Location loc = genericOp.getLoc();
+
+      Value operand = opOperand->get();
+      SmallVector<OpFoldResult> operandShape =
+          tensor::createDimValues(rewriter, loc, operand);
+      SmallVector<Value> outputShape;
+      SmallVector<int64_t> staticOutputShape;
+      inferExpandShapeOutputShape(expandedOutputType, reassociation,
+                                  operandShape, outputShape, staticOutputShape);
+
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          genericOp.getLoc(), expandedOutputType, opOperand->get(),
-          reassociation));
+          loc, expandedOutputType, operand, reassociation, outputShape,
+          staticOutputShape));
     } else {
       outputs.push_back(opOperand->get());
     }
@@ -1499,8 +1520,19 @@
           genericOp.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+
+      SmallVector<OpFoldResult> collapsedOpShape =
+          tensor::createDimValues(rewriter, loc, collapsedOpResult);
+      SmallVector<Value> outputShape;
+      SmallVector<int64_t> staticOutputShape;
+
+      inferExpandShapeOutputShape(originalResultType.cast<RankedTensorType>(),
+                                  reassociation, collapsedOpShape, outputShape,
+                                  staticOutputShape);
+
       Value result = rewriter.create<tensor::ExpandShapeOp>(
-          loc, originalResultType, collapsedOpResult, reassociation);
+          loc, originalResultType, collapsedOpResult, reassociation,
+          outputShape, staticOutputShape);
       results.push_back(result);
     } else {
       results.push_back(collapsedOpResult);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
"mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" @@ -95,8 +96,17 @@ newConv->setAttr(attr.getName(), attr.getValue()); // Expand dimensions back out to + + Value newConvVal = newConv->getResult(0); + SmallVector newConvShape = + tensor::createDimValues(rewriter, loc, newConvVal); + SmallVector outputShape; + SmallVector staticOutputShape; + inferExpandShapeOutputShape(resultTy, collapsedInitDims, newConvShape, + outputShape, staticOutputShape); rewriter.replaceOpWithNewOp( - operation, resultTy, newConv->getResult(0), collapsedInitDims); + operation, resultTy, result, collapsedInitDims, outputShape, + staticOutputShape); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -109,8 +109,17 @@ Type newType = RankedTensorType::get( newShape, operand->get().getType().cast().getElementType()); + + SmallVector inputShape = + tensor::createDimValues(b, loc, operand->get()); + SmallVector outputShape; + SmallVector staticOutputShape; + inferExpandShapeOutputShape(newType.cast(), reassociation, + inputShape, outputShape, staticOutputShape); + Value newInput = b.create( - loc, newType, operand->get(), reassociation); + loc, newType, operand->get(), reassociation, outputShape, + staticOutputShape); newInputs.push_back(newInput); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2226,7 +2226,8 @@ // Failure of this assertion usually indicates a problem with the source // type, e.g., could not get strides/offset. 
   assert(succeeded(resultType) && "could not compute layout");
-  build(builder, result, *resultType, src, reassociation);
+  build(builder, result, *resultType, src, reassociation, ValueRange{},
+        resultShape);
 }
 
 LogicalResult ExpandShapeOp::verify() {
@@ -2255,14 +2256,28 @@
     return emitOpError("expected expanded type to be ")
            << *expectedResultType << " but found " << resultType;
 
+  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+    return emitOpError("expected number of static shape bounds to be equal to "
+                       "the output rank (")
+           << resultType.getRank() << ") but found "
+           << getStaticOutputShape().size() << " inputs instead";
+
+  if ((int64_t)getOutputShape().size() !=
+      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+    return emitOpError("mismatch in dynamic dims in output_shape and "
+                       "static_output_shape: static_output_shape has ")
+           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+           << " dynamic dims while output_shape has " << getOutputShape().size()
+           << " values";
+
   return success();
 }
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
-      context);
+  results.add<
+      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -2460,9 +2475,10 @@
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
-              CollapseShapeOpMemRefCastFolder>(context);
+  results.add<
+      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
+      CollapseShapeOpMemRefCastFolder>(context);
 }
 
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -490,8 +490,15 @@
         op.getResult().getType().template cast<RankedTensorType>();
     auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-    auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
-                                              op.getReassociation());
+    ReshapeOp reshape;
+    if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
+      reshape = rewriter.create<ReshapeOp>(
+          loc, denseTp, op.getSrc(), op.getReassociation(),
+          op.getOutputShape(), op.getStaticOutputShape());
+    } else {
+      reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+                                           op.getReassociation());
+    }
     Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
     rewriter.replaceOp(op, convert);
     return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1349,6 +1349,20 @@
     return emitOpError("expected rank expansion, but found source rank ")
            << srcType.getRank() << " >= result rank " << resultType.getRank();
 
+  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+    return emitOpError("expected number of static shape dims to be equal to "
+                       "the output rank (")
+           << resultType.getRank() << ") but found "
+           << getStaticOutputShape().size() << " inputs instead";
+
+  if ((int64_t)getOutputShape().size() !=
+      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+    return emitOpError("mismatch in dynamic dims in output_shape and "
+                       "static_output_shape: static_output_shape has ")
+           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+           << " dynamic dims while output_shape has " << getOutputShape().size()
+           << " values";
+
   return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
 }
 
@@ -1520,21 +1534,22 @@
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-              FoldReshapeWithConstant<ExpandShapeOp>,
-              FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
-              FoldDimOfCollapseShape>(context);
+  results.add<
+      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      FoldReshapeWithConstant<ExpandShapeOp>,
+      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+      FoldDimOfCollapseShape>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results
-      .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
-           FoldReshapeWithConstant<CollapseShapeOp>,
-           FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
-          context);
+  results.add<
+      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
+      FoldReshapeWithConstant<CollapseShapeOp>,
+      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
+      context);
 }
 
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -15,6 +15,57 @@
 
 using namespace mlir;
 
+void mlir::inferExpandShapeOutputShape(
+    RankedTensorType expandedType, ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> inputShape, SmallVectorImpl<Value> &outputShape,
+    SmallVector<int64_t> &staticOutputShape) {
+  outputShape.resize(expandedType.getRank());
+  staticOutputShape.resize(expandedType.getRank(), ShapedType::kDynamic);
+
+  for (const auto &it : llvm::enumerate(reassociation)) {
+    ReassociationIndices indexGroup = it.value();
+    int64_t inputIndex = it.index();
+    bool foundDynamic = false;
+    for (int64_t index : indexGroup) {
+      if (ShapedType::isDynamic(expandedType.getDimSize(index))) {
+        assert(!foundDynamic &&
+               "Cannot infer expanded shape with multiple dynamic dims in the "
+               "same reassociation group!");
+        foundDynamic = true;
+        // Call get() under the assumption that we're not casting
+        // dynamism.
+        outputShape[index] = inputShape[inputIndex].get<Value>();
+      }
+    }
+
+    for (int64_t index : indexGroup) {
+      int64_t outputDimSize = expandedType.getDimSize(index);
+      if (ShapedType::isDynamic(outputDimSize))
+        continue;
+      // This restriction can be lifted -- we just need to change the output
+      // shape computation above to insert an integer division operation.
+      assert((!foundDynamic || outputDimSize == 1) &&
+             "Non-unit dimensions not supported in reassociation groups with "
+             "dynamic dims!");
+      staticOutputShape[index] = outputDimSize;
+    }
+  }
+
+  assert(llvm::count(staticOutputShape, ShapedType::kDynamic) ==
+             (outputShape.size() - llvm::count(outputShape, Value{})) &&
+         "Missing output shape entries!");
+  llvm::erase_value(outputShape, Value{});
+}
+
+void mlir::inferExpandShapeOutputShape(
+    RankedTensorType expandedType, ArrayRef<ReassociationExprs> reassociation,
+    ArrayRef<OpFoldResult> inputShape, SmallVectorImpl<Value> &outputShape,
+    SmallVector<int64_t> &staticOutputShape) {
+  inferExpandShapeOutputShape(expandedType,
+                              convertReassociationMapsToIndices(reassociation),
+                              inputShape, outputShape, staticOutputShape);
+}
+
 Optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
@@ -166,7 +217,7 @@
 }
 
 SmallVector<ReassociationIndices, 4> mlir::convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+    ArrayRef<ReassociationExprs> reassociationExprs) {
   SmallVector<ReassociationIndices, 4> reassociationIndices;
   for (const auto &exprs : reassociationExprs) {
     ReassociationIndices indices;
@@ -228,24 +279,16 @@
     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
   unsigned expandedDimStart = 0;
   for (const auto &map : llvm::enumerate(reassociationMaps)) {
-    std::optional<int64_t> dynamicShape;
+    bool foundDynamicShape = false;
     int64_t linearizedStaticShape = 1;
     for (const auto &dim : llvm::enumerate(
              expandedShape.slice(expandedDimStart, map.value().size()))) {
-      if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicShape) {
-          return emitError("invalid to have a single dimension (" +
-                           Twine(map.index()) +
-                           ") expanded into multiple dynamic dims (" +
-                           Twine(expandedDimStart + dynamicShape.value()) +
-                           "," + Twine(expandedDimStart + dim.index()) + ")");
-        }
-        dynamicShape = dim.index();
-      } else {
+      if (ShapedType::isDynamic(dim.value()))
+        foundDynamicShape = true;
+      else
        linearizedStaticShape *= dim.value();
-      }
     }
-    if (dynamicShape) {
+    if (foundDynamicShape) {
       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
         return emitError(
             "expected dimension " + Twine(map.index()) +
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -177,12 +177,26 @@
 func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>) {
   %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
-  %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
+  %1 = tensor.expand_shape %0 [] [1, 1] : tensor<f32> into tensor<1x1xf32>
   return %0, %1 : tensor<f32>, tensor<1x1xf32>
 }
 // CHECK-LABEL: func @tensor_reshape_zero_dim
 // CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
-// CHECK: tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+// CHECK: tensor.expand_shape %{{.*}} [] [1, 1] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor<?x?xf32>, %sz0 : index, %sz1 : index, %sz2 : index)
+    -> (tensor<5x?x?x?xf32>) {
+  %1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] [5, %sz0, %sz1, %sz2] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+  return %1 : tensor<5x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> {
+// CHECK: %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} [5, %arg1, %arg2, %arg3] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+// CHECK: return %expanded : tensor<5x?x?x?xf32>
+// CHECK: }
+
 // -----
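
Reviewer note (not part of the patch): the call sites updated above all repeat the same sequence when building a `tensor.expand_shape` with the new `output_shape`/`static_output_shape` operands. The sketch below condenses that pattern under the same assumptions the diff makes (`tensor::createDimValues` producing the source dims and `inferExpandShapeOutputShape` splitting the result shape); the helper name `makeExpandShape` is illustrative only.

```cpp
// Sketch only -- mirrors the call-site pattern repeated in this patch (e.g. in
// DropUnitDims.cpp and SplitReduction.cpp); not itself part of the change.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

using namespace mlir;

/// Create a tensor.expand_shape of `src` with type `resultType`, inferring the
/// new output_shape / static_output_shape operands from the source dims.
static Value makeExpandShape(OpBuilder &b, Location loc, Value src,
                             RankedTensorType resultType,
                             ArrayRef<ReassociationIndices> reassociation) {
  // Source dims as constants (static) or SSA values (dynamic).
  SmallVector<OpFoldResult> inputShape = tensor::createDimValues(b, loc, src);
  // Split the inferred result shape into dynamic SSA values plus a static
  // array that uses ShapedType::kDynamic as the sentinel for dynamic entries.
  SmallVector<Value> outputShape;
  SmallVector<int64_t> staticOutputShape;
  inferExpandShapeOutputShape(resultType, reassociation, inputShape,
                              outputShape, staticOutputShape);
  return b.create<tensor::ExpandShapeOp>(loc, resultType, src, reassociation,
                                         outputShape, staticOutputShape);
}
```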