diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -113,6 +113,10 @@
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
 
+/// Return the size of dimension `dim` of the given tensor value.
+OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
+                          int64_t dim);
+
 /// Return the dimensions of the given tensor value.
 SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
                                         Value value);
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -26,16 +26,6 @@
 SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
                                           Value rankedTensor);
 
-// Returns the tensor extent along dimension `dim` if `rankedTensor` is of
-// `RankedTensorType`. Returns `failure()` otherwise.
-FailureOr<OpFoldResult> createDimValue(OpBuilder &b, Location loc,
-                                       Value rankedTensor, int64_t dim);
-
-// Creates dim ops or constant ops for each dimension of the ranked tensor
-// argument and returns these as values.
-SmallVector<OpFoldResult> createDimValues(OpBuilder &b, Location loc,
-                                          Value rankedTensor);
-
 /// Returns the transposed `rankedTensorType` if `transposeVector` is non-empty.
 /// Fail if `transposeVector` is not a permutation matching the tensor rank.
 FailureOr<RankedTensorType>
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1884,8 +1884,8 @@
   llvm::SmallVector<Value> results;
 
   auto addDynamicDimension = [&](Value source, int64_t dim) {
-    auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
-    if (auto dimValue = llvm::dyn_cast_if_present<Value>(dynamicDim.value()))
+    auto sz = tensor::getMixedSize(builder, loc, source, dim);
+    if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
       results.push_back(dimValue);
   };
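Note on the new API (an illustration, not part of the patch): `getMixedSize` folds static dimensions to index attributes and only materializes a `tensor.dim` op for dynamic ones, which is what the `dyn_cast_if_present<Value>` above relies on. A minimal sketch, assuming a value `t` of type `tensor<4x?xf32>`; the function name `illustrateMixedSizes` is hypothetical:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper; assumes `t` has type tensor<4x?xf32>.
void illustrateMixedSizes(OpBuilder &builder, Location loc, Value t) {
  // Static dim: no IR is created, the result holds IndexAttr 4.
  OpFoldResult size0 = tensor::getMixedSize(builder, loc, t, 0);
  // Dynamic dim: a tensor.dim op is created and its result is returned.
  OpFoldResult size1 = tensor::getMixedSize(builder, loc, t, 1);
  // All dimensions at once, in order: {IndexAttr(4), Value(tensor.dim t, 1)}.
  SmallVector<OpFoldResult> all = tensor::getMixedSizes(builder, loc, t);
  (void)size0; (void)size1; (void)all;
}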
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -367,8 +367,8 @@
   SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
   SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-  SmallVector<OpFoldResult> sizes = tensor::createDimValues(
-      rewriter, op.getLoc(), adaptor.getOperands()[0]);
+  SmallVector<OpFoldResult> sizes =
+      tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);
 
   // Pre-compute the offsets along the axis dimension.
   // The axisOffsets will be of size rank + 1, where the last value
@@ -403,7 +403,7 @@
       loc, resultType.getShape(), resultType.getElementType(), dynDims);
 
   for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
-    auto sizes = tensor::createDimValues(rewriter, op.getLoc(), arg);
+    auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
     offsets[axis] = offset;
     result = rewriter.createOrFold<tensor::InsertSliceOp>(
         loc, arg, result, offsets, sizes, strides);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -310,7 +310,7 @@
       rewriter.setInsertionPointAfterValue(op->get());
       auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
       auto empty = rewriter.create<tensor::EmptyOp>(
-          loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
+          loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
 
       auto [start, end] = genericOp.getDpsInitsPositionRange();
       newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1779,16 +1779,10 @@
       if (definingOp)
         continue;
       modifiedOutput = true;
-      SmallVector<Value> dynamicDims;
-      for (const auto &dim : llvm::enumerate(operandType.getShape())) {
-        if (dim.value() != ShapedType::kDynamic)
-          continue;
-        dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
-            loc, operandVal, dim.index()));
-      }
+      SmallVector<OpFoldResult> mixedSizes =
+          tensor::getMixedSizes(rewriter, loc, operandVal);
       Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, operandType.getShape(), operandType.getElementType(),
-          dynamicDims);
+          loc, mixedSizes, operandType.getElementType());
       op->setOperand(opOperand->getOperandNumber(), emptyTensor);
     }
   }
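The ElementwiseOpFusion change above is the recurring pattern in this patch: a hand-rolled dynamic-dim loop plus the shape-based `tensor.empty` builder collapse into one `getMixedSizes` call feeding the `OpFoldResult`-based builder. A sketch of the resulting idiom; the helper name `createEmptyLike` is hypothetical:

#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Hypothetical helper: create a tensor.empty with the same shape as
// `operandVal`, without spelling out the static/dynamic distinction.
static Value createEmptyLike(OpBuilder &rewriter, Location loc,
                             Value operandVal) {
  auto operandType = cast<RankedTensorType>(operandVal.getType());
  // One call yields IndexAttrs for static dims and tensor.dim values for
  // dynamic ones; the EmptyOp builder accepts the mixed form directly.
  SmallVector<OpFoldResult> mixedSizes =
      tensor::getMixedSizes(rewriter, loc, operandVal);
  return rewriter.create<tensor::EmptyOp>(loc, mixedSizes,
                                          operandType.getElementType());
}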
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -291,7 +291,7 @@
     int64_t dim = oldShape[oldIdx];
     newOutputShape.push_back(dim);
     if (ShapedType::isDynamic(dim))
-      dynamicDims.push_back(b.createOrFold<tensor::DimOp>(
+      dynamicDims.push_back(b.create<tensor::DimOp>(
           loc, linalgOp.getDpsInitOperand(0)->get(), oldIdx));
   }
   Value emptyTensor = b.create<tensor::EmptyOp>(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -512,12 +512,10 @@
        llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
     int outerPos =
         packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
-    OpFoldResult origSize = rewriter.createOrFold<tensor::DimOp>(
-        loc, packOp.getSource(),
-        rewriter.create<arith::ConstantIndexOp>(loc, pos));
-    OpFoldResult outerSize = rewriter.createOrFold<tensor::DimOp>(
-        loc, packOp.getDest(),
-        rewriter.create<arith::ConstantIndexOp>(loc, outerPos));
+    OpFoldResult origSize =
+        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
+    OpFoldResult outerSize =
+        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
     AffineExpr s0, d0, d1;
     bindDims(rewriter.getContext(), d0, d1);
     bindSymbols(rewriter.getContext(), s0);
@@ -1132,8 +1130,8 @@
   SmallVector<int64_t> staticSizes;
   for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
     if (resultType.isDynamicDim(dim)) {
-      auto srcSize = rewriter.createOrFold<tensor::DimOp>(
-          padOp.getLoc(), padOp.getSource(), dim);
+      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
+                                                      padOp.getSource(), dim));
       // Add low and high padding value.
       auto plusLow = rewriter.createOrFold<arith::AddIOp>(
           padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
@@ -1157,15 +1155,8 @@
   // for copying the PadOp source.
   auto sourceType = padOp.getSourceType();
   // Compute size of source of tensor::PadOp.
-  SmallVector<OpFoldResult> srcSizes;
-  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
-    if (sourceType.isDynamicDim(dim)) {
-      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-          padOp.getLoc(), padOp.getSource(), dim));
-    } else {
-      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
-    }
-  }
+  SmallVector<OpFoldResult> srcSizes =
+      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
   // Strides of InsertSliceOp are all 1.
   SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                     rewriter.getIndexAttr(1));
@@ -1459,8 +1450,8 @@
   ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
   for (auto i : llvm::seq<int64_t>(0, destRank)) {
     if (dimAndTileMapping.count(i) || destShape[i] != 1)
-      tileSizes.push_back(getAsOpFoldResult(
-          rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), i)));
+      tileSizes.push_back(
+          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
   }
 
   auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -455,7 +455,7 @@
   SmallVector<OpFoldResult> resultSizesList;
   for (size_t i = 0; i < offsets.size(); i++)
     resultSizesList.push_back(
-        b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
+        tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
   SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
   SmallVector<Value> replacements = yieldTiledValues(
       b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
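In the pack/unpack lowering above, the returned `OpFoldResult`s feed straight into the affine composition helpers, so fully static sizes fold away without ever creating `tensor.dim`, constant, or `affine.apply` ops. A sketch under that assumption; the helper name `computeOuterSize` and the ceildiv computation are illustrative, not code from this patch:

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Hypothetical helper: number of tiles along `pos`, given a mixed tile size.
static OpFoldResult computeOuterSize(OpBuilder &b, Location loc, Value source,
                                     int64_t pos, OpFoldResult tileSize) {
  OpFoldResult origSize = tensor::getMixedSize(b, loc, source, pos);
  AffineExpr d0, s0;
  bindDims(b.getContext(), d0);
  bindSymbols(b.getContext(), s0);
  // outerSize = ceildiv(origSize, tileSize); folds to an attribute when both
  // operands are static, otherwise emits a composed affine.apply.
  return affine::makeComposedFoldedAffineApply(b, loc, d0.ceilDiv(s0),
                                               {origSize, tileSize});
}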
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -46,18 +46,21 @@
   return nullptr;
 }
 
+OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc,
+                                  Value value, int64_t dim) {
+  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
+  if (tensorType.isDynamicDim(dim))
+    return builder.createOrFold<tensor::DimOp>(loc, value, dim);
+
+  return builder.getIndexAttr(tensorType.getDimSize(dim));
+}
+
 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
                                                 Location loc, Value value) {
   auto tensorType = llvm::cast<RankedTensorType>(value.getType());
   SmallVector<OpFoldResult> result;
-  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
-    if (tensorType.isDynamicDim(i)) {
-      Value size = builder.create<tensor::DimOp>(loc, value, i);
-      result.push_back(size);
-    } else {
-      result.push_back(builder.getIndexAttr(tensorType.getDimSize(i)));
-    }
-  }
+  for (int64_t i = 0; i < tensorType.getRank(); ++i)
+    result.push_back(getMixedSize(builder, loc, value, i));
   return result;
 }
 
@@ -2283,15 +2287,7 @@
 LogicalResult InsertSliceOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
-  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
-    if (getType().isDynamicDim(dim)) {
-      reifiedReturnShapes[0][dim] =
-          builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
-    } else {
-      reifiedReturnShapes[0][dim] =
-          builder.getIndexAttr(getType().getDimSize(dim));
-    }
-  }
+  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
   return success();
 }
 
@@ -3254,16 +3250,8 @@
         "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
-  ShapedType resultType = llvm::cast<ShapedType>(op.getResult().getType());
-  for (auto dim : llvm::seq<int64_t>(0, destRank)) {
-    if (resultType.isDynamicDim(dim)) {
-      reifiedReturnShapes[0][dim] =
-          builder.createOrFold<tensor::DimOp>(op.getLoc(), op.getDest(), dim);
-    } else {
-      reifiedReturnShapes[0][dim] =
-          builder.getIndexAttr(resultType.getDimSize(dim));
-    }
-  }
+  reifiedReturnShapes[0] =
+      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
   return success();
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -134,7 +134,7 @@
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
   SmallVector<OpFoldResult> srcDimValues =
-      tensor::createDimValues(b, loc, packOp.getSource());
+      tensor::getMixedSizes(b, loc, packOp.getSource());
   SmallVector<OpFoldResult> inputIndices, inputSizes;
   for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
     using AV = affine::AffineValueExpr;
@@ -502,8 +502,7 @@
     bool hasHighPad = !isConstantIntValue(high, 0);
     auto offset = offsets[dim];
     auto length = sizes[dim];
-    auto srcSize =
-        tensor::createDimValue(b, loc, padOp.getSource(), dim).value();
+    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);
 
     // The new amount of low padding is `low - offset`. Except for the case
     // where none of the low padding is read. In that case, the new amount of
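Both `reifyResultShapes` implementations above reduce to a single assignment. For any destination-style op whose result shape equals its `dest` shape, the interface method can now be written as follows (a sketch; `MyDestOp` is a hypothetical op, not one from this patch):

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

// Hypothetical op whose single result has the same shape as getDest().
LogicalResult MyDestOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  // getMixedSizes produces one OpFoldResult per dimension of the dest.
  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
  return success();
}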
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -26,27 +26,6 @@
 using namespace mlir::affine;
 using namespace mlir::tensor;
 
-/// Get the dimension size of a value of RankedTensor type at the
-static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
-                                    Value rankedTensor, int64_t dimIdx) {
-  RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
-  if (!tensorType.isDynamicDim(dimIdx)) {
-    return b.getIndexAttr(tensorType.getDimSize(dimIdx));
-  }
-  Value idxValue = b.create<arith::ConstantIndexOp>(loc, dimIdx);
-  return b.createOrFold<tensor::DimOp>(loc, rankedTensor, idxValue);
-}
-
-/// Get all the dimension sizes of a value of RankedTensor type.
-static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
-                                                  Value rankedTensor) {
-  SmallVector<OpFoldResult> dimSizes;
-  RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
-  for (unsigned i = 0; i < tensorType.getRank(); i++)
-    dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i));
-  return dimSizes;
-}
-
 /// A tuple that represents (dimension number, dimension value).
 using DimAndIndex = std::tuple<unsigned, Value>;
@@ -123,7 +102,8 @@
   llvm::SmallBitVector slicedDimensions =
       getSlicedDimensions(collapseShapeOutputShape, sliceParams);
 
-  auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc());
+  auto collapseShapeInputShape =
+      tensor::getMixedSizes(b, op.getLoc(), op.getSrc());
 
   SmallVector<Value> tileSizes;
   for (unsigned i = 0; i < sliceParams.size(); i++) {
@@ -193,7 +173,7 @@
   auto one = rewriter.getIndexAttr(1);
   SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
   SmallVector<OpFoldResult> sizes =
-      getShapeDimSizes(rewriter, op.getLoc(), op.getSrc());
+      tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc());
   SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
   auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
       op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -34,9 +34,9 @@
     // Compute the padding width.
     AffineExpr d0;
     bindDims(b.getContext(), d0);
-    auto dimOp = b.createOrFold<tensor::DimOp>(loc, source, en.index());
+    OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
     high[en.index()] =
-        affine::makeComposedAffineApply(b, loc, en.value() - d0, {dimOp})
+        affine::makeComposedAffineApply(b, loc, en.value() - d0, {sz})
             .getResult();
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
@@ -55,35 +55,6 @@
   return dynamicDims;
 }
 
-FailureOr<OpFoldResult> mlir::tensor::createDimValue(OpBuilder &b, Location loc,
-                                                     Value rankedTensor,
-                                                     int64_t dim) {
-  auto tensorTy = dyn_cast<RankedTensorType>(rankedTensor.getType());
-  if (!tensorTy)
-    return failure();
-  auto shape = tensorTy.getShape();
-  if (dim >= static_cast<int64_t>(shape.size()))
-    return failure();
-  if (ShapedType::isDynamic(shape[dim]))
-    return OpFoldResult(b.createOrFold<tensor::DimOp>(loc, rankedTensor, dim));
-  return OpFoldResult(b.getIndexAttr(shape[dim]));
-}
-
-SmallVector<OpFoldResult>
-mlir::tensor::createDimValues(OpBuilder &b, Location loc, Value rankedTensor) {
-  auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
-  SmallVector<OpFoldResult> dims;
-  for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
-    if (ShapedType::isDynamic(en.value())) {
-      dims.push_back(
-          b.createOrFold<tensor::DimOp>(loc, rankedTensor, en.index()));
-    } else {
-      dims.push_back(b.getIndexAttr(en.value()));
-    }
-  }
-  return dims;
-}
-
 FailureOr<RankedTensorType>
 mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
                                     ArrayRef<int64_t> transposeVector)
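Migration note (a sketch, not part of the patch): the deleted `createDimValue` returned `FailureOr<OpFoldResult>` and range-checked `dim`, while `getMixedSize` asserts a ranked tensor via `llvm::cast` and expects a valid dimension index. Callers that need a `Value` rather than an `OpFoldResult` can materialize the attribute case with `getValueOrCreateConstantIndexOp`, mirroring the `getIdxValue(...)` use in Transforms.cpp above; the helper name `getSizeAsValue` is hypothetical:

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Hypothetical helper: always return the dimension size as a Value.
static Value getSizeAsValue(OpBuilder &b, Location loc, Value t, int64_t dim) {
  OpFoldResult sz = tensor::getMixedSize(b, loc, t, dim);
  // Static case: creates an arith.constant of index type from the attribute.
  // Dynamic case: passes the tensor.dim result through unchanged.
  return getValueOrCreateConstantIndexOp(b, loc, sz);
}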