diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -185,9 +185,9 @@ if (!convOp.padding()) return im(imIdx); + auto *context = ScopedContext::getContext(); ValueHandle zeroIndex = std_constant_index(0); - SmallVector conds = { - std_constant_int(/*value=*/1, /*width=*/1)}; + SmallVector conds; SmallVector clampedImIdx; for (auto iter : llvm::enumerate(imIdx)) { int idx = iter.index(); @@ -201,13 +201,16 @@ using edsc::op::operator<; using edsc::op::operator>=; using edsc::op::operator||; - conds.push_back(conds.back() || (dim < zeroIndex)); - ValueHandle bound = std_dim(convOp.input(), idx); - conds.push_back(conds.back() || (dim >= bound)); + ValueHandle leftOutOfBound = dim < zeroIndex; + if (conds.empty()) + conds.push_back(leftOutOfBound); + else + conds.push_back(conds.back() || leftOutOfBound); + ValueHandle rightBound = std_dim(convOp.input(), idx); + conds.push_back(conds.back() || (dim >= rightBound)); // When padding is involed, the indices will only be shifted to negative, // so having a max op is enough. - auto *context = ScopedContext::getContext(); auto maxMap = AffineMap::get(/*dimCount=*/1, 0, {getAffineDimExpr(/*position=*/0, context), getAffineConstantExpr(0, context)}); @@ -219,7 +222,8 @@ Type type = convOp.input().getType().cast().getElementType(); ValueHandle zero = std_constant(type, b.getZeroAttr(type)); ValueHandle readInput = im(clampedImIdx); - return std_select(conds.back(), zero, readInput); + return conds.empty() ? readInput + : std_select(conds.back(), zero, readInput); } static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) {