diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -457,47 +457,61 @@ } // Check if given shapes match to inferred shapes. - Optional> loopRanges = linalgOp.getStaticLoopRanges(); - if (!loopRanges) + Optional> endLoopRangeValues = + linalgOp.getStaticLoopRanges(); + if (!endLoopRangeValues) return linalgOp.emitError("unable to find loop range for operation"); + SmallVector startLoopRangeValues((*endLoopRangeValues).size(), 0); // Verify only static cases since we can't get exact dimension sizes and loop // ranges for dynamic cases in this stage. - if (llvm::none_of(*loopRanges, [](int64_t &range) { + if (llvm::none_of(*endLoopRangeValues, [](int64_t &range) { return range == ShapedType::kDynamicSize; })) { - for (int64_t &range : *loopRanges) + for (int64_t &range : *endLoopRangeValues) range -= 1; for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) { - auto indices = indexingMaps[en.index()].compose(*loopRanges); + auto startIndices = + indexingMaps[en.index()].compose(startLoopRangeValues); + auto endIndices = indexingMaps[en.index()].compose(*endLoopRangeValues); for (auto j : llvm::seq(0, en.value().getRank())) { - // Ignore dynamic dimension or the case that the inferred last index is - // zero. The index is increasing or decreasing in Linalg, for example, - // the last index should be `0` or `size-1`. We only check the cases - // that are non-zero because most of cases are increasing and it is too - // expensive to find the shape of decreasing cases. - if (en.value().isDynamicDim(j) || indices[j] == 0) + // Ignore dynamic dimension or the case that the dimension size is 0 + auto shapedDimSize = en.value().getDimSize(j); + if (en.value().isDynamicDim(j) || shapedDimSize == 0) continue; - // The size of shaped operands and inferred dimension size should be - // same. But, for now we check if the inferred sizes are in boundary of - // shaped operands' size or not in case that Affine Expressions are - // complicated such as d0 * 3 + d1 since it is not easy to handle the - // issues. - auto inferredSize = indices[j] + 1; - auto shapedDimSize = en.value().getDimSize(j); + // The first index or last index should be the maximum or the minimum in + // the inferred index ranges since the range is increasing or + // decreasing. The size of dimensions of shaped operands and the maximum + // value + 1 in the inferred range should be the same. But, for now we + // check if the inferred ranges are in boundary of shaped operands' size + // or not in case that Affine Expressions are complicated such as d0 * 3 + // + d1 since it is not easy to handle the issues. + // Found the case that this solution can't check, for example, (d0, d1) + // -> (d1 - d0) + auto inferredDimSize = std::max(startIndices[j], endIndices[j]) + 1; + if (std::min(startIndices[j], endIndices[j]) < 0) { + std::string mapStr; + { + llvm::raw_string_ostream os(mapStr); + os << indexingMaps[en.index()]; + } + return linalgOp.emitError( + "unexpected result less than 0 at expression #") + << j << " in " << mapStr; + } if (indexingMaps[en.index()].getResult(j).dyn_cast()) { - if (inferredSize != shapedDimSize) { + if (inferredDimSize != shapedDimSize) { return linalgOp.emitOpError("inferred shaped operand #") << en.index() << " has shape's dimension #" << j << " to be " - << inferredSize << ", but found " << shapedDimSize; + << inferredDimSize << ", but found " << shapedDimSize; } } else { - if (inferredSize > shapedDimSize) { + if (inferredDimSize > shapedDimSize) { return linalgOp.emitOpError("inferred shaped operand #") << en.index() << " has shape's dimension #" << j - << " to be greater than or equal to " << inferredSize + << " to be greater than or equal to " << inferredDimSize << ", but found " << shapedDimSize; } } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -820,3 +820,22 @@ } : (index, index, index, memref<192xf32>) -> () return } + +// ----- + +#attrs = { + indexing_maps = [ + affine_map<(i) -> (3 - i)>, + affine_map<(i) -> (i)> + ], + iterator_types = ["parallel"] +} + +func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) { + // expected-error @+1 {{unexpected result less than 0 at expression #0 in}} + linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) { + ^bb0(%a: f32, %b: f32): + linalg.yield %a : f32 + } + return +}