diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -26,6 +26,11 @@
 SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
                                           Value rankedTensor);
 
+// Returns the tensor extent along dimension `dim` if `rankedTensor` is of
+// `RankedTensorType`. Returns `failure()` otherwise.
+FailureOr<OpFoldResult> createDimValue(OpBuilder &b, Location loc,
+                                       Value rankedTensor, int64_t dim);
+
 // Creates dim ops or constant ops for each dimension of the ranked tensor
 // argument and returns these as values.
 SmallVector<OpFoldResult> createDimValues(OpBuilder &b, Location loc,
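
A minimal usage sketch for the new `createDimValue` helper (not part of the patch; `b`, `loc`, and `someTensor` are illustrative names assumed to be in scope):

    // Fetch the extent of dimension 0 as an OpFoldResult: an IntegerAttr for
    // a static dimension, or the result of a (folded) tensor.dim op for a
    // dynamic one. Fails if `someTensor` is not a ranked tensor.
    FailureOr<OpFoldResult> extent =
        tensor::createDimValue(b, loc, someTensor, /*dim=*/0);
    if (failed(extent))
      return failure();
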
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -462,25 +462,25 @@
   bindDims(b.getContext(), dim0, dim1);
   // Add two integers.
   auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
-  auto add = [&](Value v1, Value v2) {
-    return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
+  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
+    return makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
   };
   // Subtract two integers.
   auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
-  auto sub = [&](Value v1, Value v2) {
-    return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
   };
   // Take the minimum of two integers.
   auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
-  auto min = [&](Value v1, Value v2) {
-    return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
+  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
+    return makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
   };
   // Take the maximum of two integers.
-  auto max = [&](Value v1, Value v2) {
-    return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
+  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
+    return makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
   };
   // Zero index-typed integer.
-  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  OpFoldResult zero = b.getIndexAttr(0);
 
   // Helper function for filling static/dynamic low/high padding indices
   // vectors of PadOp.
@@ -496,8 +496,7 @@
 
   // Compute new offsets, lengths, low padding, high padding.
   SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
-  SmallVector<Value> newLows, newHighs;
-  SmallVector<int64_t> staticNewLows, staticNewHighs;
+  SmallVector<OpFoldResult> newLows, newHighs;
   // Set to true if the original data source is not read at all.
   bool hasZeroLen = false;
   // Same as hasZeroLen, but for dynamic dimension sizes. This condition
@@ -506,23 +505,22 @@
 
   int64_t rank = padOp.getSourceType().getRank();
   for (unsigned dim = 0; dim < rank; ++dim) {
-    auto low =
-        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
-    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
-    auto high =
-        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
-    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
-    auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
-    auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
-    auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim);
+    auto low = padOp.getMixedLowPad()[dim];
+    bool hasLowPad = !isConstantIntValue(low, 0);
+    auto high = padOp.getMixedHighPad()[dim];
+    bool hasHighPad = !isConstantIntValue(high, 0);
+    auto offset = offsets[dim];
+    auto length = sizes[dim];
+    auto srcSize =
+        tensor::createDimValue(b, loc, padOp.getSource(), dim).value();
 
     // The new amount of low padding is `low - offset`. Except for the case
     // where none of the low padding is read. In that case, the new amount of
     // low padding is zero.
     //
     // Optimization: If low = 0, then newLow = 0.
-    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
-    appendIndex(newLow, newLows, staticNewLows);
+    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
+    newLows.push_back(newLow);
 
     // Start reading the data from position `offset - low`. Since the original
     // read may have started in the low padding zone, this value could be
@@ -536,9 +534,10 @@
     // no data from the source.)
     //
     // Optimization: If low = 0, then the formula can be simplified.
-    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
-                                : min(offset, srcSize);
-    newOffsets.push_back(getAsOpFoldResult(newOffset));
+    OpFoldResult newOffset = hasLowPad
+                                 ? min(max(sub(offset, low), zero), srcSize)
+                                 : min(offset, srcSize);
+    newOffsets.push_back(newOffset);
 
     // The original ExtractSliceOp was reading until position `offset +
     // length`. Therefore, the corresponding position within the source tensor
@@ -559,19 +558,21 @@
     // The new ExtractSliceOp length is `endLoc - newOffset`.
     //
     // Optimization: If low = 0, then the formula can be simplified.
-    Value endLoc = hasLowPad
-                       ? min(max(add(sub(offset, low), length), zero), srcSize)
-                       : min(add(offset, length), srcSize);
-    Value newLength = sub(endLoc, newOffset);
-    newLengths.push_back(getAsOpFoldResult(newLength));
+    OpFoldResult endLoc =
+        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
+                  : min(add(offset, length), srcSize);
+    OpFoldResult newLength = sub(endLoc, newOffset);
+    newLengths.push_back(newLength);
 
     // Check if newLength is zero. In that case, no SubTensorOp should be
     // executed.
-    if (auto newLengthInt = getConstantIntValue(newLength)) {
-      hasZeroLen |= *newLengthInt == 0;
-    } else {
-      Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                            newLength, zero);
+    if (isConstantIntValue(newLength, 0)) {
+      hasZeroLen = true;
+    } else if (!hasZeroLen) {
+      Value check = b.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq,
+          getValueOrCreateConstantIndexOp(b, loc, newLength),
+          getValueOrCreateConstantIndexOp(b, loc, zero));
       dynHasZeroLenCond =
           dynHasZeroLenCond
               ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
@@ -582,8 +583,9 @@
     // so that the result has the same length as the original ExtractSliceOp.
     // As an optimization, if the original high padding is zero, then the new
     // high padding must also be zero.
-    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
-    appendIndex(newHigh, newHighs, staticNewHighs);
+    OpFoldResult newHigh =
+        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
+    newHighs.push_back(newHigh);
 
     // Only unit stride supported.
     newStrides.push_back(b.getIndexAttr(1));
@@ -597,7 +599,10 @@
       RankedTensorType::get(shape, padOp.getResultType().getElementType());
 
   // Insert cast to ensure that types match. (May be folded away.)
-  auto castResult = [&](Value val) -> Operation * {
+  auto castResult = [&](Operation *op) -> Operation * {
+    Value val = op->getResult(0);
+    if (resultType == val.getType())
+      return op;
     return b.create<tensor::CastOp>(loc, resultType, val);
   };
 
@@ -618,10 +623,9 @@
   // the result shape of the new SliceOp has a zero dimension.
   auto createPadOfExtractSlice = [&]() {
     // Create pad(extract_slice(x)).
-    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
+    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
         loc, padOp.getSource(), newOffsets, newLengths, newStrides);
-    auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
-                                    staticNewHighs, newLows, newHighs);
+    auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs);
 
     // Copy region to new PadOp.
     IRMapping bvm;
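
A sketch of what the switch to the `makeComposedFolded*` helpers buys (illustrative, not part of the patch; assumes a builder `b` and location `loc` in scope): when every operand is a constant attribute, the helpers fold at build time and return an `Attribute`, so no `affine.apply` op is created and no `arith.constant` operands need to be materialized first, as the old `Value`-based lambdas required.

    AffineExpr d0, d1;
    bindDims(b.getContext(), d0, d1);
    auto addMap = AffineMap::get(2, 0, {d0 + d1});
    OpFoldResult two = b.getIndexAttr(2);
    OpFoldResult three = b.getIndexAttr(3);
    // Folds at build time to an IntegerAttr holding 5; nothing is inserted
    // into the IR. A non-constant operand would instead yield the Value
    // result of a composed affine.apply op.
    OpFoldResult five =
        makeComposedFoldedAffineApply(b, loc, addMap, {two, three});
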
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -51,6 +51,20 @@
   return dynamicDims;
 }
 
+FailureOr<OpFoldResult> mlir::tensor::createDimValue(OpBuilder &b, Location loc,
+                                                     Value rankedTensor,
+                                                     int64_t dim) {
+  auto tensorTy = rankedTensor.getType().dyn_cast<RankedTensorType>();
+  if (!tensorTy)
+    return failure();
+  auto shape = tensorTy.getShape();
+  if (dim >= static_cast<int64_t>(shape.size()))
+    return failure();
+  if (ShapedType::isDynamic(shape[dim]))
+    return OpFoldResult(b.createOrFold<tensor::DimOp>(loc, rankedTensor, dim));
+  return OpFoldResult(b.getIndexAttr(shape[dim]));
+}
+
 SmallVector<OpFoldResult> mlir::tensor::createDimValues(OpBuilder &b,
                                                         Location loc,
                                                         Value rankedTensor) {
   auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
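
For completeness, a sketch of how a consumer typically unpacks the `OpFoldResult` produced by `createDimValue` (illustrative helper, not part of the patch):

    // OpFoldResult is a PointerUnion<Attribute, Value>, so the static and
    // dynamic cases can be distinguished after the fact.
    void consumeExtent(OpFoldResult extent) {
      if (auto attr = extent.dyn_cast<Attribute>()) {
        // Static extent: known at compile time; no IR was emitted for it.
        int64_t size = attr.cast<IntegerAttr>().getInt();
        (void)size;
      } else {
        // Dynamic extent: an SSA value, typically the result of tensor.dim.
        Value size = extent.get<Value>();
        (void)size;
      }
    }
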