diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -568,18 +568,25 @@
   ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands());
   AffineMap origLbMap = lb.getMap();
   AffineMap origUbMap = ub.getMap();
+  // setUpperBound/setLowerBound below rewrite the op's operand list, after
+  // which the AffineBound views `lb`/`ub` may no longer index the right
+  // operands, so read every bound operand from these owned copies instead.
+  SmallVector<Value, 4> origLbOperands;
+  llvm::append_range(origLbOperands, lb.getOperands());
+  SmallVector<Value, 4> origUbOperands;
+  llvm::append_range(origUbOperands, ub.getOperands());
 
   // Add dimension operands from upper/lower bound.
   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
-    ubOperands.push_back(ub.getOperand(j));
+    ubOperands.push_back(origUbOperands[j]);
   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
-    ubOperands.push_back(lb.getOperand(j));
+    ubOperands.push_back(origLbOperands[j]);
 
   // Add symbol operands from upper/lower bound.
   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
-    ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
+    ubOperands.push_back(origUbOperands[origUbMap.getNumDims() + j]);
   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
-    ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
+    ubOperands.push_back(origLbOperands[origLbMap.getNumDims() + j]);
 
   // Add original result expressions from lower/upper bound map.
   SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(),
@@ -610,9 +617,9 @@
                      newUbExprs, opBuilder.getContext());
   canonicalizeMapAndOperands(&newUbMap, &ubOperands);
 
-  SmallVector<Value, 4> lbOperands(lb.getOperands().begin(),
-                                   lb.getOperands().begin() +
-                                       lb.getMap().getNumDims());
+  SmallVector<Value, 4> lbOperands(origLbOperands.begin(),
+                                   origLbOperands.begin() +
+                                       origLbMap.getNumDims());
 
   // Normalize the loop.
   op.setUpperBound(ubOperands, newUbMap);
@@ -626,9 +633,9 @@
   lbOperands.push_back(op.getInductionVar());
   // Add symbol operands from lower bound.
   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
-    lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
+    lbOperands.push_back(origLbOperands[origLbMap.getNumDims() + j]);
 
-  AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims());
+  AffineExpr origIVExpr = opBuilder.getAffineDimExpr(origLbMap.getNumDims());
   AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0);
   AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
                                    origLbMap.getNumSymbols(), newIVExpr);
diff --git a/mlir/test/Dialect/Affine/affine-loop-normalize.mlir b/mlir/test/Dialect/Affine/affine-loop-normalize.mlir
--- a/mlir/test/Dialect/Affine/affine-loop-normalize.mlir
+++ b/mlir/test/Dialect/Affine/affine-loop-normalize.mlir
@@ -213,3 +213,82 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @constant_lower_bound
+func.func @constant_lower_bound() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.for %j = %c0 to %c1 step %c1 {
+    // CHECK: affine.for %[[ARG0:.*]] =
+    affine.for %i = %c0 to %c1 {
+      // CHECK-NEXT: %{{.*}} = affine.apply #map{{.*}}(%[[ARG0]])
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-DAG: [[$UB_MAP:#map[0-9]*]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK-DAG: [[$IV_MAP:#map[0-9]*]] = affine_map<(d0) -> (d0 * 4)>
+
+// CHECK-LABEL: func @upper_bound_by_symbol
+func.func @upper_bound_by_symbol(%arg0: index, %arg1: index) {
+  // CHECK: affine.for %[[ARG0:.*]] = 0 to [[$UB_MAP]]()[%arg{{.*}}] {
+  affine.for %i = 0 to affine_map<()[s0, s1] -> (s0)>()[%arg0, %arg1] step 4 {
+    // CHECK-NEXT: %[[IV:.*]] = affine.apply [[$IV_MAP]](%[[ARG0]])
+    // CHECK-NEXT: "test.foo"(%[[IV]]) : (index) -> ()
+    "test.foo"(%i) : (index) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-DAG: [[$UB_MAP:#map[0-9]*]] = affine_map<()[s0] -> ((-s0 + 10) ceildiv 4)>
+// CHECK-DAG: [[$IV_MAP:#map[0-9]*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
+
+// CHECK-LABEL: func @lower_bound_by_symbol
+func.func @lower_bound_by_symbol(%arg0: index, %arg1: index) {
+  // CHECK: affine.for %[[ARG0:.*]] = 0 to [[$UB_MAP]]()[%arg{{.*}}] {
+  affine.for %i = affine_map<()[s0, s1] -> (s0)>()[%arg0, %arg1] to 10 step 4 {
+    // CHECK-NEXT: %[[IV:.*]] = affine.apply [[$IV_MAP]](%[[ARG0]])[%arg{{.*}}]
+    // CHECK-NEXT: "test.foo"(%[[IV]]) : (index) -> ()
+    "test.foo"(%i) : (index) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-DAG: [[$UB_MAP:#map[0-9]*]] = affine_map<()[s0] -> (s0 ceildiv 4)>
+// CHECK-DAG: [[$IV_MAP:#map[0-9]*]] = affine_map<(d0) -> (d0 * 4)>
+
+// CHECK-LABEL: func @upper_bound_by_dim
+func.func @upper_bound_by_dim(%arg0: index, %arg1: index) {
+  // CHECK: affine.for %[[ARG0:.*]] = 0 to [[$UB_MAP]]()[%arg{{.*}}] {
+  affine.for %i = 0 to affine_map<(d0, d1) -> (d0)>(%arg0, %arg1) step 4 {
+    // CHECK-NEXT: %[[IV:.*]] = affine.apply [[$IV_MAP]](%[[ARG0]])
+    // CHECK-NEXT: "test.foo"(%[[IV]]) : (index) -> ()
+    "test.foo"(%i) : (index) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-DAG: [[$UB_MAP:#map[0-9]*]] = affine_map<()[s0] -> ((-s0 + 10) ceildiv 4)>
+// CHECK-DAG: [[$IV_MAP:#map[0-9]*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
+
+// CHECK-LABEL: func @lower_bound_by_dim
+func.func @lower_bound_by_dim(%arg0: index, %arg1: index) {
+  // CHECK: affine.for %[[ARG0:.*]] = 0 to [[$UB_MAP]]()[%arg{{.*}}] {
+  affine.for %i = affine_map<(d0, d1) -> (d0)>(%arg0, %arg1) to 10 step 4 {
+    // CHECK-NEXT: %[[IV:.*]] = affine.apply [[$IV_MAP]](%[[ARG0]])[%arg{{.*}}]
+    // CHECK-NEXT: "test.foo"(%[[IV]]) : (index) -> ()
+    "test.foo"(%i) : (index) -> ()
+  }
+  return
+}