diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h
--- a/mlir/include/mlir/IR/IntegerSet.h
+++ b/mlir/include/mlir/IR/IntegerSet.h
@@ -75,6 +75,7 @@
 
   explicit operator bool() { return set; }
   bool operator==(IntegerSet other) const { return set == other.set; }
+  bool operator!=(IntegerSet other) const { return set != other.set; }
 
   unsigned getNumDims() const;
   unsigned getNumSymbols() const;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2492,23 +2492,45 @@
                     withElseRegion);
 }
 
+/// Compose any affine.apply ops feeding into `operands` of the integer set
+/// `set` by composing the maps of such affine.apply ops with the integer
+/// set constraints. `set` and `operands` are updated in place; both are left
+/// untouched when no operand is defined by an affine.apply op.
+static void composeSetAndOperands(IntegerSet &set,
+                                  SmallVectorImpl<Value> &operands) {
+  // Check if any composition is possible before building the equivalent map.
+  if (llvm::none_of(operands,
+                    [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
+    return;
+
+  // We will simply reuse the API of the map composition by viewing the LHSs of
+  // the equalities and inequalities of `set` as the affine exprs of an affine
+  // map. Convert to equivalent map, compose, and convert back to set.
+  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
+                            set.getConstraints(), set.getContext());
+  composeAffineMapAndOperands(&map, &operands);
+  // This relies on the composition keeping the number and order of results
+  // unchanged, so that the eq flags of the original set still line up.
+  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
+                        set.getEqFlags());
+}
+
 /// Canonicalize an affine if op's conditional (integer set + operands).
 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
                                SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
+  composeSetAndOperands(set, operands);
   canonicalizeSetAndOperands(&set, &operands);
 
-  // Any canonicalization change always leads to either a reduction in the
-  // number of operands or a change in the number of symbolic operands
-  // (promotion of dims to symbols).
-  if (operands.size() < getIntegerSet().getNumInputs() ||
-      set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
-    setConditional(set, operands);
-    return success();
-  }
+  // Check if the canonicalization or composition led to any change. Operand
+  // and symbol counts alone cannot tell: composing an affine.apply may rewrite
+  // the set and swap operands while keeping all those counts the same.
+  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
+    return failure();
 
-  return failure();
+  setConditional(set, operands);
+  return success();
 }
 
 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -622,6 +622,28 @@
 
 // -----
 
+// CHECK-DAG: #[[$SET:.*]] = affine_set<(d0, d1)[s0] : (d0 - 1 >= 0, d1 - 1 == 0, -d0 + s0 + 10 >= 0)>
+
+// CHECK-LABEL: func @canonicalize_affine_if_compose_apply
+// CHECK-SAME: %[[N:.*]]: index
+func.func @canonicalize_affine_if_compose_apply(%N: index) {
+  %M = affine.apply affine_map<()[s0] -> (s0 + 10)> ()[%N]
+  // CHECK-NEXT: affine.for %[[I:.*]] =
+  affine.for %i = 0 to 1024 {
+    // CHECK-NEXT: affine.for %[[J:.*]] =
+    affine.for %j = 0 to 100 {
+      %j_ = affine.apply affine_map<(d0)[] -> (d0 + 1)> (%j)
+      // CHECK-NEXT: affine.if #[[$SET]](%[[I]], %[[J]])[%[[N]]]
+      affine.if affine_set<(d0, d1)[s0] : (d0 - 1 >= 0, d1 - 2 == 0, -d0 + s0 >= 0)>(%i, %j_)[%M] {
+        "test.foo"() : ()->()
+      }
+    }
+  }
+  return
+}
+
+// -----
+
 // CHECK-DAG: #[[$LBMAP:.*]] = affine_map<()[s0] -> (0, s0)>
 // CHECK-DAG: #[[$UBMAP:.*]] = affine_map<()[s0] -> (1024, s0 * 2)>
 