diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -88,6 +88,112 @@
   return success;
 }
 
+/// Computes an inclusive interval [min, max] that contains every value `expr`
+/// can take when each loop dimension `d_i` ranges over [0, loopRanges[i]].
+///
+/// Sums and multiplications by a constant are propagated with interval
+/// arithmetic, so nested forms such as `-(d0 - d1)` are bounded correctly;
+/// `floordiv`, `ceildiv` and `mod` by a positive constant are handled as well.
+/// Returns llvm::None for anything else (symbols, non-constant factors or
+/// divisors) so that callers can fall back to a coarser estimate.
+static Optional<std::pair<int64_t, int64_t>>
+getStaticIndexBounds(AffineExpr expr, ArrayRef<int64_t> loopRanges) {
+  // Floor division for a strictly positive divisor; C++ `/` truncates
+  // towards zero, which is wrong for negative dividends.
+  auto floorDivide = [](int64_t lhs, int64_t rhs) -> int64_t {
+    int64_t quotient = lhs / rhs;
+    return lhs % rhs < 0 ? quotient - 1 : quotient;
+  };
+
+  if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+    unsigned pos = dimExpr.getPosition();
+    if (pos >= loopRanges.size())
+      return llvm::None;
+    return std::make_pair(int64_t(0), loopRanges[pos]);
+  }
+  if (auto cstExpr = expr.dyn_cast<AffineConstantExpr>())
+    return std::make_pair(cstExpr.getValue(), cstExpr.getValue());
+  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+  if (!binExpr)
+    return llvm::None;
+
+  auto lhs = getStaticIndexBounds(binExpr.getLHS(), loopRanges);
+  if (!lhs)
+    return llvm::None;
+  int64_t lo = lhs->first, hi = lhs->second;
+  if (expr.getKind() == AffineExprKind::Add) {
+    auto rhs = getStaticIndexBounds(binExpr.getRHS(), loopRanges);
+    if (!rhs)
+      return llvm::None;
+    return std::make_pair(lo + rhs->first, hi + rhs->second);
+  }
+
+  // Mul, FloorDiv, CeilDiv and Mod are only modeled with a constant RHS,
+  // which is the canonical form of pure affine expressions.
+  auto cstRhs = binExpr.getRHS().dyn_cast<AffineConstantExpr>();
+  if (!cstRhs)
+    return llvm::None;
+  int64_t cst = cstRhs.getValue();
+  switch (expr.getKind()) {
+  case AffineExprKind::Mul:
+    // A negative factor swaps the ends of the interval.
+    if (cst >= 0)
+      return std::make_pair(lo * cst, hi * cst);
+    return std::make_pair(hi * cst, lo * cst);
+  case AffineExprKind::FloorDiv:
+    if (cst <= 0)
+      return llvm::None;
+    return std::make_pair(floorDivide(lo, cst), floorDivide(hi, cst));
+  case AffineExprKind::CeilDiv:
+    if (cst <= 0)
+      return llvm::None;
+    return std::make_pair(-floorDivide(-lo, cst), -floorDivide(-hi, cst));
+  case AffineExprKind::Mod:
+    if (cst <= 0)
+      return llvm::None;
+    // Within one period the operand maps monotonically; once the interval
+    // crosses a multiple of `cst` the result may wrap to any residue.
+    if (floorDivide(lo, cst) != floorDivide(hi, cst))
+      return std::make_pair(int64_t(0), cst - 1);
+    return std::make_pair(lo - floorDivide(lo, cst) * cst,
+                          hi - floorDivide(hi, cst) * cst);
+  default:
+    return llvm::None;
+  }
+}
+
+/// Returns the largest value that result `idx` of `indexingMap` can take
+/// when each loop dimension `d_i` ranges over [0, loopRanges[i]]. Used by the
+/// static bound checker.
+static int64_t getStaticMaximumIndex(AffineMap indexingMap,
+                                     ArrayRef<int64_t> loopRanges,
+                                     unsigned idx) {
+  if (auto bounds =
+          getStaticIndexBounds(indexingMap.getResult(idx), loopRanges))
+    return bounds->second;
+  // Unmodeled expression: fall back to evaluating the map at the two corners
+  // of the iteration space.
+  SmallVector<int64_t> zeros(loopRanges.size(), 0);
+  return std::max(indexingMap.compose(zeros)[idx],
+                  indexingMap.compose(loopRanges)[idx]);
+}
+
+/// Returns the smallest value that result `idx` of `indexingMap` can take
+/// when each loop dimension `d_i` ranges over [0, loopRanges[i]]. Used by the
+/// static bound checker.
+static int64_t getStaticMinimumIndex(AffineMap indexingMap,
+                                     ArrayRef<int64_t> loopRanges,
+                                     unsigned idx) {
+  if (auto bounds =
+          getStaticIndexBounds(indexingMap.getResult(idx), loopRanges))
+    return bounds->first;
+  // Unmodeled expression: fall back to evaluating the map at the two corners
+  // of the iteration space.
+  SmallVector<int64_t> zeros(loopRanges.size(), 0);
+  return std::min(indexingMap.compose(zeros)[idx],
+                  indexingMap.compose(loopRanges)[idx]);
+}
+
 enum MatchContractionResult {
   Success = 0,
   NotLinalgOp,
@@ -457,23 +563,18 @@
   }
 
   // Check if given shapes match to inferred shapes.
-  Optional<SmallVector<int64_t, 4>> endLoopRangeValues =
-      linalgOp.getStaticLoopRanges();
-  if (!endLoopRangeValues)
+  Optional<SmallVector<int64_t, 4>> loopRanges = linalgOp.getStaticLoopRanges();
+  if (!loopRanges)
     return linalgOp.emitError("unable to find loop range for operation");
-  SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0);
 
   // Verify only static cases since we can't get exact dimension sizes and loop
   // ranges for dynamic cases in this stage.
-  if (llvm::none_of(*endLoopRangeValues, [](int64_t &range) {
+  if (llvm::none_of(*loopRanges, [](int64_t &range) {
         return range == ShapedType::kDynamicSize;
       })) {
-    for (int64_t &range : *endLoopRangeValues)
+    for (int64_t &range : *loopRanges)
       range -= 1;
     for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) {
-      auto startIndices =
-          indexingMaps[en.index()].compose(startLoopRangeValues);
-      auto endIndices = indexingMaps[en.index()].compose(*endLoopRangeValues);
       for (auto j : llvm::seq<unsigned>(0, en.value().getRank())) {
 
         // Ignore dynamic dimension or the case that the dimension size is 0
@@ -481,17 +582,15 @@
         if (en.value().isDynamicDim(j) || shapedDimSize == 0)
           continue;
 
-        // The first index or last index should be the maximum or the minimum in
-        // the inferred index ranges since the range is increasing or
-        // decreasing. The size of dimensions of shaped operands and the maximum
-        // value + 1 in the inferred range should be the same. But, for now we
+        // The size of dimensions of shaped operands and the maximum
+        // index + 1 in the inferred range should be the same. But, for now we
         // check if the inferred ranges are in boundary of shaped operands' size
         // or not in case that Affine Expressions are complicated such as d0 * 3
         // + d1 since it is not easy to handle the issues.
-        // Found the case that this solution can't check, for example, (d0, d1)
-        // -> (d1 - d0)
-        auto inferredDimSize = std::max(startIndices[j], endIndices[j]) + 1;
-        if (std::min(startIndices[j], endIndices[j]) < 0) {
+        auto inferredDimSize =
+            getStaticMaximumIndex(indexingMaps[en.index()], *loopRanges, j) + 1;
+        if (getStaticMinimumIndex(indexingMaps[en.index()], *loopRanges, j) <
+            0) {
           std::string mapStr;
           {
             llvm::raw_string_ostream os(mapStr);
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -839,3 +839,43 @@
   }
   return
 }
+
+// -----
+
+#attrs = {
+  indexing_maps = [
+    affine_map<(i, j) -> (-i + j)>,
+    affine_map<(i, j) -> (j)>,
+    affine_map<(i, j) -> (i)>
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+func @accessing_index_less_than_0(%A: memref<2xf32>, %B: memref<2xf32>, %C: memref<2xf32>) {
+  // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
+  linalg.generic #attrs ins(%A, %B: memref<2xf32>, memref<2xf32>) outs(%C: memref<2xf32>) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+    linalg.yield %b : f32
+  }
+  return
+}
+
+// -----
+
+#attrs = {
+  indexing_maps = [
+    affine_map<(i, j) -> (-(i - j))>,
+    affine_map<(i, j) -> (j)>,
+    affine_map<(i, j) -> (i)>
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+func @negated_difference_less_than_0(%A: memref<2xf32>, %B: memref<2xf32>, %C: memref<2xf32>) {
+  // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
+  linalg.generic #attrs ins(%A, %B: memref<2xf32>, memref<2xf32>) outs(%C: memref<2xf32>) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+    linalg.yield %b : f32
+  }
+  return
+}