diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -313,7 +313,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast<LinalgOp>(op); - // Expect at least one shaped operand. // This means an op that constructs a tensor out of indices cannot be a // LinalgOp at the moment. For now this will have to be a special op until we @@ -458,30 +457,28 @@ } // Check if given shapes match to inferred shapes. - Optional<SmallVector<int64_t, 4>> startLoopRangeValues = - linalgOp.getStaticLoopRanges(), - endLoopRangeValues = - linalgOp.getStaticLoopRanges(); + Optional<SmallVector<int64_t, 4>> endLoopRangeValues = + linalgOp.getStaticLoopRanges(); if (!endLoopRangeValues) return linalgOp.emitError("unable to find loop range for operation"); + SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0); // Verify only static cases since we can't get exact dimension sizes and loop // ranges for dynamic cases in this stage.
if (llvm::none_of(*endLoopRangeValues, [](int64_t &range) { return range == ShapedType::kDynamicSize; })) { - for (int64_t &range : *startLoopRangeValues) - range = 0; for (int64_t &range : *endLoopRangeValues) range -= 1; for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) { - auto firstIndices = - indexingMaps[en.index()].compose(*startLoopRangeValues); - auto lastIndices = indexingMaps[en.index()].compose(*endLoopRangeValues); + auto startIndices = + indexingMaps[en.index()].compose(startLoopRangeValues); + auto endIndices = indexingMaps[en.index()].compose(*endLoopRangeValues); for (auto j : llvm::seq<unsigned>(0, en.value().getRank())) { - // Ignore dynamic dimension - if (en.value().isDynamicDim(j)) + // Skip dynamic dimensions and dimensions of size 0. + auto shapedDimSize = en.value().getDimSize(j); + if (en.value().isDynamicDim(j) || shapedDimSize == 0) continue; // The first index or last index should be the maximum or the minimum in @@ -491,10 +488,10 @@ // check if the inferred ranges are in boundary of shaped operands' size // or not in case that Affine Expressions are complicated such as d0 * 3 // + d1 since it is not easy to handle the issues.
- auto firstIndex = firstIndices[j]; - auto lastIndex = lastIndices[j]; - auto shapedDimSize = en.value().getDimSize(j); - if (shapedDimSize > 0 && (firstIndex < 0 || lastIndex < 0)) { + // Note: this endpoint-based check can miss out-of-range results for maps + // whose extremes are not at the range endpoints, e.g. (d0, d1) -> (d1 - d0). + auto inferredDimSize = std::max(startIndices[j], endIndices[j]) + 1; + if (std::min(startIndices[j], endIndices[j]) < 0) { std::string mapStr; { llvm::raw_string_ostream os(mapStr); @@ -502,25 +499,19 @@ } return linalgOp.emitError( "unexpected result less than 0 at expression #") - << j << " in affineMap\n" - << mapStr; + << j << " in " << mapStr; } if (indexingMaps[en.index()].getResult(j).dyn_cast<AffineDimExpr>()) { - if (lastIndex + 1 != shapedDimSize) { + if (inferredDimSize != shapedDimSize) { return linalgOp.emitOpError("inferred shaped operand #") << en.index() << " has shape's dimension #" << j << " to be " - << lastIndex + 1 << ", but found " << shapedDimSize; + << inferredDimSize << ", but found " << shapedDimSize; } } else { - if (lastIndex >= shapedDimSize) { - return linalgOp.emitOpError("inferred shaped operand #") - << en.index() << " has shape's dimension #" << j - << " to be greater than or equal to " << lastIndex + 1 - << ", but found " << shapedDimSize; - } else if (firstIndex >= shapedDimSize) { + if (inferredDimSize > shapedDimSize) { return linalgOp.emitOpError("inferred shaped operand #") << en.index() << " has shape's dimension #" << j - << " to be greater than or equal to " << firstIndex + 1 + << " to be greater than or equal to " << inferredDimSize << ", but found " << shapedDimSize; } } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -820,3 +820,23 @@ } : (index, index, index, memref<192xf32>) -> () return } + +// ----- + +#attrs = { + indexing_maps = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (3-i)> + ], + iterator_types = ["parallel"] +} +
+func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) { + // expected-error @+1 {{unexpected result less than 0 at expression #0 in (d0) -> (-d0 + 3)}} + linalg.generic #attrs ins(%A: memref<5xf32>) outs(%B: memref<5xf32>) { + ^bb0(%a: f32, %b: f32): + linalg.yield %b : f32 + } + + return +}