diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -88,6 +88,83 @@
   return success;
 }
 
+/// Marks in `decreasingDims` every dimension of `expr` along which `expr`
+/// decreases as the dimension grows. `negated` is the parity of negations
+/// accumulated on the path from the root result expression down to `expr`: a
+/// multiplication, floordiv or ceildiv by a negative constant flips it for the
+/// non-constant operand, while additions and mods propagate it unchanged. A
+/// dimension is marked as soon as one of its occurrences is reached with odd
+/// parity, so dimensions used in both directions count as decreasing.
+static void collectDecreasingDims(AffineExpr expr, bool negated,
+                                  SmallVectorImpl<bool> &decreasingDims) {
+  if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+    if (negated)
+      decreasingDims[dimExpr.getPosition()] = true;
+    return;
+  }
+  // Constants and symbols do not reference any dimension.
+  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+  if (!binExpr)
+    return;
+
+  auto isNegativeConstant = [](AffineExpr operand) {
+    auto cst = operand.dyn_cast<AffineConstantExpr>();
+    return cst && cst.getValue() < 0;
+  };
+  AffineExpr lhs = binExpr.getLHS();
+  AffineExpr rhs = binExpr.getRHS();
+  switch (expr.getKind()) {
+  case AffineExprKind::Mul:
+    // A negative constant factor reverses the direction of the other operand.
+    collectDecreasingDims(lhs, negated != isNegativeConstant(rhs),
+                          decreasingDims);
+    collectDecreasingDims(rhs, negated != isNegativeConstant(lhs),
+                          decreasingDims);
+    return;
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+    // A negative constant divisor reverses the direction of the dividend.
+    collectDecreasingDims(lhs, negated != isNegativeConstant(rhs),
+                          decreasingDims);
+    return;
+  default:
+    // Add and Mod keep the direction of both operands.
+    collectDecreasingDims(lhs, negated, decreasingDims);
+    collectDecreasingDims(rhs, negated, decreasingDims);
+    return;
+  }
+}
+
+/// Returns the {minimum, maximum} pair of values taken by result `idx` of
+/// `indexingMap` when every loop dimension `i` ranges over
+/// [0, maxLoopIndices[i]]. Both values are evaluated at corners of that
+/// iteration space: a dimension along which the result increases is set to 0
+/// for the minimum and to its last index for the maximum, and the other way
+/// around for a decreasing dimension. This is exact when each dimension moves
+/// the result in a single direction; when a dimension is used in both
+/// directions or under a `mod`, the returned bounds are only an estimate.
+static std::pair<int64_t, int64_t>
+getStaticIndexBounds(AffineMap indexingMap, ArrayRef<int64_t> maxLoopIndices,
+                     unsigned idx) {
+  assert(maxLoopIndices.size() == indexingMap.getNumDims() &&
+         "expected one loop bound per indexing map dimension");
+  SmallVector<bool, 4> decreasingDims(indexingMap.getNumDims(), false);
+  collectDecreasingDims(indexingMap.getResult(idx), /*negated=*/false,
+                        decreasingDims);
+
+  SmallVector<int64_t, 4> minPoint(maxLoopIndices.size(), 0);
+  SmallVector<int64_t, 4> maxPoint(maxLoopIndices.begin(),
+                                   maxLoopIndices.end());
+  for (unsigned dim = 0, e = maxLoopIndices.size(); dim < e; ++dim) {
+    if (!decreasingDims[dim])
+      continue;
+    minPoint[dim] = maxLoopIndices[dim];
+    maxPoint[dim] = 0;
+  }
+  return {indexingMap.compose(minPoint)[idx],
+          indexingMap.compose(maxPoint)[idx]};
+}
+
 enum MatchContractionResult {
   Success = 0,
   NotLinalgOp,
@@ -469,35 +546,42 @@
     for (int64_t &range : *loopRanges)
       range -= 1;
     for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) {
-      auto indices = indexingMaps[en.index()].compose(*loopRanges);
       for (auto j : llvm::seq<unsigned>(0, en.value().getRank())) {
 
-        // Ignore dynamic dimension or the case that the inferred last index is
-        // zero. The index is increasing or decreasing in Linalg, for example,
-        // the last index should be `0` or `size-1`. We only check the cases
-        // that are non-zero because most of cases are increasing and it is too
-        // expensive to find the shape of decreasing cases.
-        if (en.value().isDynamicDim(j) || indices[j] == 0)
+        // Ignore dynamic dimensions and dimensions of size 0.
+        auto shapedDimSize = en.value().getDimSize(j);
+        if (en.value().isDynamicDim(j) || shapedDimSize == 0)
           continue;
 
-        // The size of shaped operands and inferred dimension size should be
-        // same. But, for now we check if the inferred sizes are in boundary of
-        // shaped operands' size or not in case that Affine Expressions are
-        // complicated such as d0 * 3 + d1 since it is not easy to handle the
-        // issues.
-        auto inferredSize = indices[j] + 1;
-        auto shapedDimSize = en.value().getDimSize(j);
+        // Every index produced by this result must lie in [0, shapedDimSize).
+        // A plain dimension result must cover the operand dimension exactly;
+        // for composite expressions such as d0 * 3 + d1 we only require the
+        // inferred range to fit inside the operand, since the exact size is
+        // not easy to derive.
+        std::pair<int64_t, int64_t> bounds =
+            getStaticIndexBounds(indexingMaps[en.index()], *loopRanges, j);
+        if (bounds.first < 0) {
+          std::string mapStr;
+          {
+            llvm::raw_string_ostream os(mapStr);
+            os << indexingMaps[en.index()];
+          }
+          return linalgOp.emitOpError(
+                     "unexpected result less than 0 at expression #")
+                 << j << " in " << mapStr;
+        }
+        int64_t inferredDimSize = bounds.second + 1;
         if (indexingMaps[en.index()].getResult(j).dyn_cast<AffineDimExpr>()) {
-          if (inferredSize != shapedDimSize) {
+          if (inferredDimSize != shapedDimSize) {
             return linalgOp.emitOpError("inferred shaped operand #")
                    << en.index() << " has shape's dimension #" << j << " to be "
-                   << inferredSize << ", but found " << shapedDimSize;
+                   << inferredDimSize << ", but found " << shapedDimSize;
           }
         } else {
-          if (inferredSize > shapedDimSize) {
+          if (inferredDimSize > shapedDimSize) {
             return linalgOp.emitOpError("inferred shaped operand #")
                    << en.index() << " has shape's dimension #" << j
-                   << " to be greater than or equal to " << inferredSize
+                   << " to be greater than or equal to " << inferredDimSize
                    << ", but found " << shapedDimSize;
           }
         }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -820,3 +820,61 @@
   } : (index, index, index, memref<192xf32>) -> ()
   return
 }
+
+// -----
+
+#attrs = {
+  indexing_maps = [
+    affine_map<(i) -> (3 - i)>,
+    affine_map<(i) -> (i)>
+  ],
+  iterator_types = ["parallel"]
+}
+
+func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
+  // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
+  linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) {
+  ^bb0(%a: f32, %b: f32):
+    linalg.yield %a : f32
+  }
+  return
+}
+
+// -----
+
+#attrs = {
+  indexing_maps = [
+    affine_map<(i, j) -> (-i + j)>,
+    affine_map<(i, j) -> (j)>,
+    affine_map<(i, j) -> (i)>
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+func @accessing_index_less_than_0(%A: memref<2xf32>, %B: memref<2xf32>, %C: memref<2xf32>) {
+  // expected-error @+1 {{unexpected result less than 0 at expression #0 in}}
+  linalg.generic #attrs ins(%A, %B: memref<2xf32>, memref<2xf32>) outs(%C: memref<2xf32>) {
+  ^bb0(%a: f32, %b: f32, %c: f32):
+    linalg.yield %b : f32
+  }
+  return
+}
+
+// -----
+
+#attrs = {
+  indexing_maps = [
+    affine_map<(i) -> ((i * -1) floordiv 2 * -1)>,
+    affine_map<(i) -> (i)>
+  ],
+  iterator_types = ["parallel"]
+}
+
+func @double_negation_out_of_bound(%A: memref<2xf32>, %B: memref<5xf32>) {
+  // expected-error @+1 {{inferred shaped operand #0 has shape's dimension #0 to be greater than or equal to 3, but found 2}}
+  linalg.generic #attrs ins(%A: memref<2xf32>) outs(%B: memref<5xf32>) {
+  ^bb0(%a: f32, %b: f32):
+    linalg.yield %a : f32
+  }
+  return
+}