diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -999,18 +999,18 @@
       /*desc=*/[{
         Returns the statically-known loop ranges. Composes
         `getShapesToLoopsMap()` with the result of `getStaticShape`.
-        Returns None if `getShapesToLoopsMap()` fails. Returns
-        ShapeType::kDynamicSize for non-statically-known loop ranges.
+        Returns ShapedType::kDynamicSize for non-statically-known loop ranges.
+        Must only be called on a valid Linalg op, i.e. one for which
+        `getShapesToLoopsMap()` returns a non-null map (this is asserted).
       }],
-      /*retTy=*/"Optional<SmallVector<int64_t, 4>>",
+      /*retTy=*/"SmallVector<int64_t, 4>",
       /*methodName=*/"getStaticLoopRanges",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         SmallVector<int64_t, 4> viewSizes = getStaticShape();
         AffineMap invertedMap = getShapesToLoopsMap();
-        if (!invertedMap)
-          return {};
+        assert(invertedMap && "expected a valid Linalg op to call the method");
         return invertedMap.compose(viewSizes);
       }]
     >,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -732,23 +732,20 @@
   }
 
   // Check if given shapes match to inferred shapes.
-  Optional<SmallVector<int64_t, 4>> endLoopRangeValues =
-      linalgOp.getStaticLoopRanges();
-  if (!endLoopRangeValues)
-    return op->emitOpError("unable to find loop range for operation");
-  SmallVector<int64_t, 4> startLoopRangeValues((*endLoopRangeValues).size(), 0);
+  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
+  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
 
   // Verify only static cases since we can't get exact dimension sizes and loop
   // ranges for dynamic cases in this stage.
-  if (llvm::none_of(*endLoopRangeValues, ShapedType::isDynamic)) {
-    for (int64_t &range : *endLoopRangeValues)
+  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
+    for (int64_t &range : endLoopRangeValues)
       range -= 1;
     for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
       AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
       SmallVector<int64_t, 4> startIndices =
           indexingMap.compose(startLoopRangeValues);
       SmallVector<int64_t, 4> endIndices =
-          indexingMap.compose(*endLoopRangeValues);
+          indexingMap.compose(endLoopRangeValues);
       ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
       for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
         // Ignore dynamic dimension or the case that the dimension size is 0
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -518,12 +518,8 @@
     return failure();
 
   AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
-  Optional<SmallVector<int64_t, 4>> originalLoopRange =
-      linalgOp.getStaticLoopRanges();
-  if (!originalLoopRange)
-    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
-  originalLoopExtent.assign(originalLoopRange->begin(),
-                            originalLoopRange->end());
+  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
+  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
 
   reassociation.clear();
   expandedShapeMap.clear();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -73,12 +73,10 @@
   op.getReductionDims(dims);
   assert(dims.size() == 1);
   unsigned reductionDim = dims[0];
-  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
-  if (!loopRanges)
-    return b.notifyMatchFailure(op, "Cannot analyze loops");
-  int64_t reductionDimSize = (*loopRanges)[reductionDim];
+  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
+  int64_t reductionDimSize = loopRanges[reductionDim];
   if (reductionDimSize == ShapedType::kDynamicSize ||
-      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges->size())
+      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
   SmallVector<Operation *, 4> combinerOps;