diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -561,31 +561,45 @@
   OpBuilder opBuilder(op);
   int64_t origLoopStep = op.getStep();
 
-  // Calculate upperBound for normalized loop.
-  SmallVector<Value, 4> ubOperands;
   AffineBound lb = op.getLowerBound();
+  AffineMap canonicalizedLbMap = lb.getMap();
+  SmallVector<Value, 4> origLbOperands;
+  llvm::append_range(origLbOperands, lb.getOperands());
+
   AffineBound ub = op.getUpperBound();
+  AffineMap canonicalizedUbMap = ub.getMap();
+  SmallVector<Value, 4> origUbOperands;
+  llvm::append_range(origUbOperands, ub.getOperands());
+
+  // Calculate upperBound for normalized loop.
+  SmallVector<Value, 4> ubOperands;
   ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands());
-  AffineMap origLbMap = lb.getMap();
-  AffineMap origUbMap = ub.getMap();
+
+  // Canonicalize the bound maps together with copies of their operands (this
+  // e.g. folds constant operands into the maps). Afterwards the maps' dims and
+  // symbols index into origLbOperands/origUbOperands, not the op's operands.
+  canonicalizeMapAndOperands(&canonicalizedLbMap, &origLbOperands);
+  canonicalizeMapAndOperands(&canonicalizedUbMap, &origUbOperands);
 
   // Add dimension operands from upper/lower bound.
-  for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
-    ubOperands.push_back(ub.getOperand(j));
-  for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
-    ubOperands.push_back(lb.getOperand(j));
+  for (unsigned j = 0, e = canonicalizedUbMap.getNumDims(); j < e; ++j)
+    ubOperands.push_back(origUbOperands[j]);
+  for (unsigned j = 0, e = canonicalizedLbMap.getNumDims(); j < e; ++j)
+    ubOperands.push_back(origLbOperands[j]);
 
   // Add symbol operands from upper/lower bound.
-  for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
-    ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
-  for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
-    ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
+  for (unsigned j = 0, e = canonicalizedUbMap.getNumSymbols(); j < e; ++j)
+    ubOperands.push_back(origUbOperands[canonicalizedUbMap.getNumDims() + j]);
+  for (unsigned j = 0, e = canonicalizedLbMap.getNumSymbols(); j < e; ++j)
+    ubOperands.push_back(origLbOperands[canonicalizedLbMap.getNumDims() + j]);
 
   // Add original result expressions from lower/upper bound map.
-  SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(),
-                                         origLbMap.getResults().end());
-  SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(),
-                                         origUbMap.getResults().end());
+  SmallVector<AffineExpr, 1> origLbExprs(
+      canonicalizedLbMap.getResults().begin(),
+      canonicalizedLbMap.getResults().end());
+  SmallVector<AffineExpr, 2> origUbExprs(
+      canonicalizedUbMap.getResults().begin(),
+      canonicalizedUbMap.getResults().end());
   SmallVector<AffineExpr, 4> newUbExprs;
 
   // The original upperBound can have more than one result. For the new
@@ -605,14 +619,14 @@
 
   // Construct newUbMap.
-  AffineMap newUbMap =
-      AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(),
-                     origLbMap.getNumSymbols() + origUbMap.getNumSymbols(),
-                     newUbExprs, opBuilder.getContext());
+  AffineMap newUbMap = AffineMap::get(
+      canonicalizedLbMap.getNumDims() + canonicalizedUbMap.getNumDims(),
+      canonicalizedLbMap.getNumSymbols() + canonicalizedUbMap.getNumSymbols(),
+      newUbExprs, opBuilder.getContext());
   canonicalizeMapAndOperands(&newUbMap, &ubOperands);
 
-  SmallVector<Value, 4> lbOperands(lb.getOperands().begin(),
-                                   lb.getOperands().begin() +
-                                       lb.getMap().getNumDims());
+  SmallVector<Value, 4> lbOperands(origLbOperands.begin(),
+                                   origLbOperands.begin() +
+                                       canonicalizedLbMap.getNumDims());
 
   // Normalize the loop.
   op.setUpperBound(ubOperands, newUbMap);
@@ -625,13 +639,16 @@
   // Add an extra dim operand for loopIV.
   lbOperands.push_back(op.getInductionVar());
   // Add symbol operands from lower bound.
-  for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
-    lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
+  for (unsigned j = 0, e = canonicalizedLbMap.getNumSymbols(); j < e; ++j)
+    lbOperands.push_back(origLbOperands[canonicalizedLbMap.getNumDims() + j]);
 
-  AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims());
-  AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0);
-  AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
-                                   origLbMap.getNumSymbols(), newIVExpr);
+  AffineExpr origIVExpr =
+      opBuilder.getAffineDimExpr(canonicalizedLbMap.getNumDims());
+  AffineExpr newIVExpr =
+      origIVExpr * origLoopStep + canonicalizedLbMap.getResult(0);
+  AffineMap ivMap =
+      AffineMap::get(canonicalizedLbMap.getNumDims() + 1,
+                     canonicalizedLbMap.getNumSymbols(), newIVExpr);
   canonicalizeMapAndOperands(&ivMap, &lbOperands);
   Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
   op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
diff --git a/mlir/test/Dialect/Affine/affine-loop-normalize.mlir b/mlir/test/Dialect/Affine/affine-loop-normalize.mlir
--- a/mlir/test/Dialect/Affine/affine-loop-normalize.mlir
+++ b/mlir/test/Dialect/Affine/affine-loop-normalize.mlir
@@ -213,3 +213,50 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @constant_lower_bound
+func.func @constant_lower_bound() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.for %j = %c0 to %c1 step %c1 {
+    // CHECK: affine.for %[[ARG0:.*]] = 0 to 1 {
+    affine.for %i = %c0 to %c1 {
+      // CHECK-NEXT: %[[IV:.*]] = affine.apply #map(%[[ARG0]])
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @ensure_canonicalize_constant_bound
+func.func @ensure_canonicalize_constant_bound() {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  // CHECK: affine.for %[[ARG0:.*]] = 0 to 2 {
+  affine.for %i = %c0 to %c2 {
+    // CHECK-NEXT: %[[IV:.*]] = affine.apply #map(%[[ARG0]])
+    "test.foo"(%i) : (index) -> ()
+  }
+  return
+}
+
+// -----
+
+// Canonicalization drops the unused first operand (%N) of the lower bound
+// map, so the normalized bounds and IV must be built from the canonicalized
+// operand list (just %i) rather than from the op's original operands.
+// CHECK-LABEL: func @canonicalize_dim_operand_bound
+func.func @canonicalize_dim_operand_bound(%N: index) {
+  // CHECK: affine.for %[[I:.*]] = 0 to 10 {
+  affine.for %i = 0 to 10 {
+    // CHECK-NEXT: affine.for %[[J:.*]] = 0 to #{{.*}}(%[[I]]) {
+    affine.for %j = affine_map<(d0, d1) -> (d1)>(%N, %i) to 20 step 2 {
+      // CHECK-NEXT: affine.apply #{{.*}}(%[[I]], %[[J]])
+      "test.foo"(%j) : (index) -> ()
+    }
+  }
+  return
+}