diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -313,6 +313,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); + // Expect at least one shaped operand. // This means an op that constructs a tensor out of indices cannot be a // LinalgOp at the moment. For now this will have to be a special op until we @@ -457,47 +458,69 @@ } // Check if given shapes match to inferred shapes. - Optional> loopRanges = linalgOp.getStaticLoopRanges(); - if (!loopRanges) + Optional> startLoopRangeValues = + linalgOp.getStaticLoopRanges(), + endLoopRangeValues = + linalgOp.getStaticLoopRanges(); + if (!endLoopRangeValues) return linalgOp.emitError("unable to find loop range for operation"); // Verify only static cases since we can't get exact dimension sizes and loop // ranges for dynamic cases in this stage. - if (llvm::none_of(*loopRanges, [](int64_t &range) { + if (llvm::none_of(*endLoopRangeValues, [](int64_t &range) { return range == ShapedType::kDynamicSize; })) { - for (int64_t &range : *loopRanges) + for (int64_t &range : *startLoopRangeValues) + range = 0; + for (int64_t &range : *endLoopRangeValues) range -= 1; for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) { - auto indices = indexingMaps[en.index()].compose(*loopRanges); + auto firstIndices = + indexingMaps[en.index()].compose(*startLoopRangeValues); + auto lastIndices = indexingMaps[en.index()].compose(*endLoopRangeValues); for (auto j : llvm::seq(0, en.value().getRank())) { - // Ignore dynamic dimension or the case that the inferred last index is - // zero. The index is increasing or decreasing in Linalg, for example, - // the last index should be `0` or `size-1`. We only check the cases - // that are non-zero because most of cases are increasing and it is too - // expensive to find the shape of decreasing cases. - if (en.value().isDynamicDim(j) || indices[j] == 0) + // Ignore dynamic dimension + if (en.value().isDynamicDim(j)) continue; - // The size of shaped operands and inferred dimension size should be - // same. But, for now we check if the inferred sizes are in boundary of - // shaped operands' size or not in case that Affine Expressions are - // complicated such as d0 * 3 + d1 since it is not easy to handle the - // issues. - auto inferredSize = indices[j] + 1; + // The first index or last index should be the maximum or the minimum in + // the inferred index ranges since the range is increasing or + // decreasing. The size of dimensions of shaped operands and the maximum + // value + 1 in the inferred range should be the same. But, for now we + // check if the inferred ranges are in boundary of shaped operands' size + // or not in case that Affine Expressions are complicated such as d0 * 3 + // + d1 since it is not easy to handle the issues. + auto firstIndex = firstIndices[j]; + auto lastIndex = lastIndices[j]; auto shapedDimSize = en.value().getDimSize(j); + if (shapedDimSize > 0 && (firstIndex < 0 || lastIndex < 0)) { + std::string mapStr; + { + llvm::raw_string_ostream os(mapStr); + os << indexingMaps[en.index()]; + } + return linalgOp.emitError( + "unexpected result less than 0 at expression #") + << j << " in affineMap\n" + << mapStr; + } if (indexingMaps[en.index()].getResult(j).dyn_cast()) { - if (inferredSize != shapedDimSize) { + if (lastIndex + 1 != shapedDimSize) { return linalgOp.emitOpError("inferred shaped operand #") << en.index() << " has shape's dimension #" << j << " to be " - << inferredSize << ", but found " << shapedDimSize; + << lastIndex + 1 << ", but found " << shapedDimSize; } } else { - if (inferredSize > shapedDimSize) { + if (lastIndex >= shapedDimSize) { + return linalgOp.emitOpError("inferred shaped operand #") + << en.index() << " has shape's dimension #" << j + << " to be greater than or equal to " << lastIndex + 1 + << ", but found " << shapedDimSize; + } else if (firstIndex >= shapedDimSize) { return linalgOp.emitOpError("inferred shaped operand #") << en.index() << " has shape's dimension #" << j - << " to be greater than or equal to " << inferredSize + << " to be greater than or equal to " << firstIndex + 1 << ", but found " << shapedDimSize; } }