diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -26,7 +26,7 @@
 namespace mlir {

 class ShapedTypeComponents;
-using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
+using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;

 /// Adaptor class to abstract the differences between whether value is from
 /// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -211,15 +211,16 @@
   let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Reify the shape of the result of an operation (typically in
-        terms of shape of its operands)
-
-        Insert operations using the given `OpBuilder` that computes
-        the result shape. The `reifiedReturnShapes` is expected to be
-        populated with as many vectors as the number of results of the
-        op. Each of these vectors is expected to be of size equal to
-        rank of the corresponding result. If the shape of a particular
-        result cannot be computed it must be empty.
+        Reify the shape of the result of an operation (typically in terms of the
+        shape of its operands).
+
+        `reifiedReturnShapes` is populated with one vector per op result. Each
+        of those vectors contains an OpFoldResult for each dimension of the
+        shaped type. In case a dimension in the type is static, the
+        corresponding entry is an IntegerAttr. Otherwise, it is a Value. The
+        given builder may be used to insert ops that compute result shapes.
+
+        If the shape of a particular result cannot be computed it must be empty.
       }],
       /*retTy=*/"::mlir::LogicalResult",
      /*methodName=*/"reifyResultShapes",
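For illustration, a minimal sketch of the updated contract, written for a hypothetical op `MyCloneOp` whose single result mirrors the shape of its single operand (not part of this patch):

    // Sketch only (hypothetical op): static dimensions are reported as index
    // IntegerAttrs, dynamic ones as Values created through the builder.
    LogicalResult MyCloneOp::reifyResultShapes(
        OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
      auto resultType = getResult().getType().cast<RankedTensorType>();
      SmallVector<OpFoldResult> shape;
      shape.reserve(resultType.getRank());
      for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
        if (resultType.isDynamicDim(dim))
          shape.push_back(
              builder.createOrFold<tensor::DimOp>(getLoc(), getOperand(), dim));
        else
          shape.push_back(builder.getIndexAttr(resultType.getDimSize(dim)));
      }
      reifiedReturnShapes.push_back(std::move(shape));
      return success();
    }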
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -147,7 +147,7 @@
         resultDims[shapedValue.cast<OpResult>().getResultNumber()];
     for (const auto &dim : enumerate(tensorType.getShape()))
       if (ShapedType::isDynamic(dim.value()))
-        dynamicSizes.push_back(shape[dim.index()]);
+        dynamicSizes.push_back(shape[dim.index()].get<Value>());
     }
   }
 }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -369,13 +369,13 @@

 LogicalResult AllocTensorOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  auto shapes = llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
-        if (isDynamicDim(dim))
-          return getDynamicSize(builder, dim);
-        return builder.create<arith::ConstantIndexOp>(getLoc(),
-                                                      getStaticSize(dim));
-      }));
+  auto shapes = llvm::to_vector<4>(
+      llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
+                      [&](int64_t dim) -> OpFoldResult {
+                        if (isDynamicDim(dim))
+                          return getDynamicSize(builder, dim);
+                        return builder.getIndexAttr(getStaticSize(dim));
+                      }));
   reifiedReturnShapes.emplace_back(std::move(shapes));
   return success();
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -642,13 +642,12 @@
   int64_t pos = 0;
   ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
   for (OpOperand *opOperand : getDpsInitOperands()) {
-    SmallVector<Value> shapes;
+    SmallVector<OpFoldResult> shapes;
     for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
       if (checkDimExpr.visit(shapeExprs[pos]))
         shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
       else
-        shapes.push_back(
-            getValueOrCreateConstantIndexOp(b, loc, allResultDimValues[pos]));
+        shapes.push_back(allResultDimValues[pos]);
       pos++;
     }
     reifiedReturnShapes.emplace_back(std::move(shapes));
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -488,10 +488,9 @@
       return rewriter.notifyMatchFailure(
           padOp, "failed to reify tensor.pad op result shape");

-    SmallVector<OpFoldResult> newShape =
-        getAsOpFoldResult(reifiedShape.front());
     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        padOp.getLoc(), newShape, padOp.getResultType().getElementType());
+        padOp.getLoc(), reifiedShape.front(),
+        padOp.getResultType().getElementType());
     Value replacement =
         rewriter
             .create<linalg::FillOp>(fillOp.getLoc(), ValueRange{padValue},
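The `.get<Value>()` unwraps above rely on the tightened contract: an OpFoldResult is a PointerUnion<Attribute, Value>, and these call sites only inspect entries for dynamic dimensions, which the interface now guarantees to be Values. A sketch of the recurring pattern (hypothetical helper name, not from the patch):

    // Sketch: collect the SSA values for the dynamic dimensions of an op
    // result from shapes reified under the new contract.
    static SmallVector<Value>
    getDynamicSizes(OpResult result, const ReifiedRankedShapedTypeDims &shapes) {
      SmallVector<Value> dynamicSizes;
      auto tensorType = result.getType().cast<RankedTensorType>();
      for (int64_t i = 0; i < tensorType.getRank(); ++i)
        if (tensorType.isDynamicDim(i))
          // Dynamic dims are guaranteed to be the Value alternative.
          dynamicSizes.push_back(
              shapes[result.getResultNumber()][i].get<Value>());
      return dynamicSizes;
    }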
"mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -132,7 +133,8 @@ for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) dynSizes.push_back( - reifiedShape[value.cast().getResultNumber()][i]); + reifiedShape[value.cast().getResultNumber()][i] + .get()); } return dynSizes; } @@ -298,7 +300,7 @@ SmallVector dynamicSizes; for (int64_t i = 0; i < resultType.getRank(); ++i) if (resultType.isDynamicDim(i)) - dynamicSizes.push_back(reifiedShape[0][i]); + dynamicSizes.push_back(reifiedShape[0][i].get()); // If the `padOp` has a nofold attribute and all paddings are known to be 0, // explicitly insert a `linalg.copy`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -75,7 +75,7 @@ // Create the tensor of same size as output of the pad op. RankedTensorType padResultType = padOp.getResultType(); - auto resultSizes = getAsOpFoldResult(resultShape[0]); + auto resultSizes = resultShape[0]; auto emptyTensor = rewriter.create( loc, resultSizes, padResultType.getElementType()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -204,7 +204,7 @@ newOperands.push_back(*paddedOperand); } - SmallVector> reifiedResultShapes; + ReifiedRankedShapedTypeDims reifiedResultShapes; if (failed(cast(opToPad.getOperation()) .reifyResultShapes(rewriter, reifiedResultShapes))) { LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n"); @@ -231,11 +231,10 @@ int64_t rank = paddedResult.getType().cast().getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); SmallVector sizes; - for (Value v : reifiedResultShapes[resultNumber]) - sizes.push_back(getAsOpFoldResult(v)); SmallVector strides(rank, rewriter.getIndexAttr(1)); paddedSubtensorResults.push_back(rewriter.create( - loc, paddedResult, offsets, sizes, strides)); + loc, paddedResult, offsets, reifiedResultShapes[resultNumber], + strides)); } return paddedSubtensorResults; } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -92,7 +93,7 @@ if (!dimIndex) return failure(); - SmallVector> reifiedResultShapes; + ReifiedRankedShapedTypeDims reifiedResultShapes; if (failed( rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes))) return failure(); @@ -106,7 +107,10 @@ static_cast(sourceType.getRank())) return failure(); - rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]); + rewriter.replaceOp(dimOp, + getValueOrCreateConstantIndexOp( + rewriter, dimOp.getLoc(), + reifiedResultShapes[resultNumber][*dimIndex])); return success(); } }; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -38,10 +38,12 @@
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
 /// `src`.
-static OpFoldResult
-getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
-                                    int64_t dimIndex, Value src,
-                                    ArrayRef<AffineMap> reassociationMap) {
+static OpFoldResult getCollapsedOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
+  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+    return builder.getIndexAttr(dstStaticShape[dimIndex]);
+  }
   AffineMap map = reassociationMap[dimIndex];
   unsigned startPos =
       map.getResults().front().cast<AffineDimExpr>().getPosition();
@@ -65,8 +67,8 @@
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
   return llvm::to_vector<4>(llvm::map_range(
       llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
-                                                   reassociation);
+        return getCollapsedOutputDimFromInputShape(
+            builder, loc, dim, src, dstStaticShape, reassociation);
       }));
 }
@@ -77,7 +79,7 @@
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
     llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
+    return builder.getIndexAttr(dstStaticShape[dimIndex]);
   }
   unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
   unsigned startPos = reassociation[sourceDimPos]
@@ -144,11 +146,9 @@
       ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
     auto reshapeOp = cast<ReshapeOpTy>(op);
-    auto resultShape = getReshapeOutputShapeFromInputShape(
+    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getReassociationMaps());
-    reifiedReturnShapes.push_back(
-        getValueOrCreateConstantIndexOp(b, loc, resultShape));
+        reshapeOp.getReassociationMaps()));
     return success();
   }
 };
@@ -165,8 +165,13 @@
     Location loc = padOp.getLoc();
     auto lowPad = padOp.getMixedLowPad();
     auto highPad = padOp.getMixedHighPad();
-    SmallVector<Value> shapes;
+    SmallVector<OpFoldResult> shapes;
     for (auto dim : llvm::seq<int64_t>(0, padOp.getSourceType().getRank())) {
+      if (!padOp.getResultType().isDynamicDim(dim)) {
+        shapes.push_back(b.getIndexAttr(padOp.getResultType().getDimSize(dim)));
+        continue;
+      }
+
       // Shape along each dimension is source dim + low pad + high pad.
       SmallVector<Value> mapOperands;
       mapOperands.push_back(
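The `getI64IntegerAttr` to `getIndexAttr` switch above is not cosmetic: attributes carry their type, so an i64-typed 4 and an index-typed 4 are distinct Attributes, and equality-based comparisons of reified shape entries would silently fail if producers mixed the two. An illustrative check (not from the patch):

    // Illustrative only: same integer value, different attribute types.
    static void attrTypeMatters(MLIRContext *ctx) {
      OpBuilder b(ctx);
      assert(b.getI64IntegerAttr(4) != b.getIndexAttr(4));
    }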
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -87,7 +87,7 @@
       return failure();
     if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
       return failure();
-    mixedSizes = getAsOpFoldResult(reifiedShapes[opResult.getResultNumber()]);
+    mixedSizes = reifiedShapes[opResult.getResultNumber()];
   } else {
     // Static shape: Take static sizes directly.
     for (int64_t sz : tensorType.getShape())
@@ -523,14 +523,13 @@

 LogicalResult EmptyOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
   unsigned ctr = 0;
   for (int64_t i = 0; i < getType().getRank(); ++i) {
     if (getType().isDynamicDim(i)) {
       reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
     } else {
-      reifiedReturnShapes[0][i] =
-          builder.create<arith::ConstantIndexOp>(getLoc(), i);
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
     }
   }
   return success();
@@ -1004,14 +1003,14 @@

 LogicalResult GenerateOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
   int idx = 0;
   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
     if (getType().isDynamicDim(dim)) {
       reifiedReturnShapes[0][dim] = getOperand(idx++);
     } else {
-      reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
-          getLoc(), getType().getDimSize(dim));
+      reifiedReturnShapes[0][dim] =
+          builder.getIndexAttr(getType().getDimSize(dim));
     }
   }
   return success();
@@ -1787,16 +1786,10 @@
   reifiedReturnShapes[0].reserve(getType().getRank());
   SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
   llvm::SmallBitVector droppedDims = getDroppedDims();
-  Location loc = getLoc();
   for (const auto &size : enumerate(mixedSizes)) {
     if (droppedDims.test(size.index()))
       continue;
-    if (auto attr = size.value().dyn_cast<Attribute>()) {
-      reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
-          loc, attr.cast<IntegerAttr>().getInt()));
-      continue;
-    }
-    reifiedReturnShapes[0].push_back(size.value().get<Value>());
+    reifiedReturnShapes[0].push_back(size.value());
   }
   return success();
 }
@@ -2210,7 +2203,7 @@
 LogicalResult InsertSliceOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
     reifiedReturnShapes[0][dim] =
         builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
@@ -3160,7 +3153,7 @@
   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                 "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
-  reifiedReturnShapes.resize(1, SmallVector<Value>(destRank));
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
   for (auto dim : llvm::seq<int64_t>(0, destRank)) {
     reifiedReturnShapes[0][dim] =
         builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -219,7 +219,7 @@
     (void)packOp.reifyResultShapes(b, outputShape);
     resultSizes.assign(sizes.begin(), sizes.end());
     for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
-      resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
+      resultSizes.push_back(outputShape[0][dataTileDim]);

     return success();
   }
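Note that the insert_slice and pack/unpack reifications above still go through `createOrFold<tensor::DimOp>` for every dimension, so even static sizes may surface as constant Values rather than IntegerAttrs. Consumers that prefer the attribute form can normalize with the existing `getAsOpFoldResult` helper; a sketch, assuming `reifiedReturnShapes` has already been populated:

    // Sketch: fold constant Values in reified shapes back to attributes.
    SmallVector<OpFoldResult> normalized;
    for (OpFoldResult ofr : reifiedReturnShapes[0])
      normalized.push_back(
          ofr.is<Value>() ? getAsOpFoldResult(ofr.get<Value>()) : ofr);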
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -33,9 +33,8 @@
         !llvm::hasSingleElement(resultShapes))
       return failure();
     // TODO: Do not drop tensor type encoding.
-    Value emptyTensor =
-        rewriter.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
-                                         reshapeOp.getResultType().getElementType());
+    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, resultShapes[0], reshapeOp.getResultType().getElementType());
     if (emptyTensor.getType() != reshapeOp.getResultType()) {
       rewriter.replaceOpWithNewOp<tensor::CastOp>(
           reshapeOp, reshapeOp.getResultType(), emptyTensor);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -116,8 +116,7 @@
       dyn_cast<ReifyRankedShapedTypeOpInterface>(op.getOperation());
   if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes)))
     return failure();
-  SmallVector<OpFoldResult> collapseShapeOutputShape =
-      getAsOpFoldResult(reifiedShapes[0]);
+  SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
   SmallVector<ReassociationIndices> reassociationIndices =
       op.getReassociationIndices();
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -193,7 +193,7 @@

     // Create the destination tensor using the above values.
     Type elementType = op.getSourceType().getElementType();
-    SmallVector<OpFoldResult> outputShape = getAsOpFoldResult(reifiedShapes[0]);
+    SmallVector<OpFoldResult> outputShape = reifiedShapes[0];
     Value dest = rewriter.create<tensor::EmptyOp>(op->getLoc(), outputShape,
                                                   elementType);
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1244,7 +1244,7 @@
     auto currShape = llvm::to_vector<4>(llvm::map_range(
         llvm::seq<int64_t>(
             0, operand.getType().cast<RankedTensorType>().getRank()),
-        [&](int64_t dim) -> Value {
+        [&](int64_t dim) -> OpFoldResult {
          return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
        }));
    shapes.emplace_back(std::move(currShape));
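Taken together, the payoff is that consumers can feed reified shapes straight into builders that already accept OpFoldResults, without the `getAsOpFoldResult` round-trips deleted throughout this patch. A consumer-side sketch (hypothetical helper; the tensor::EmptyOp builder taking an OpFoldResult size list is existing API):

    // Sketch: build a tensor.empty matching the reified shape of `result`.
    static FailureOr<Value> buildEmptyTensorLike(RewriterBase &rewriter,
                                                 OpResult result) {
      auto reifyOp =
          dyn_cast<ReifyRankedShapedTypeOpInterface>(result.getOwner());
      if (!reifyOp)
        return failure();
      ReifiedRankedShapedTypeDims reifiedShapes;
      if (failed(reifyOp.reifyResultShapes(rewriter, reifiedShapes)))
        return failure();
      auto tensorType = result.getType().cast<RankedTensorType>();
      // Static entries are attributes, so they fold into the result type.
      return rewriter
          .create<tensor::EmptyOp>(result.getOwner()->getLoc(),
                                   reifiedShapes[result.getResultNumber()],
                                   tensorType.getElementType())
          .getResult();
    }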