diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -245,6 +245,215 @@
   return static_cast<ImplType *>(expr)->position;
 }
 
+/// Returns true if `expr` is divisible by the symbol at position `symbolPos`.
+///
+/// `opKind` is the kind of the operation (mod, floordiv or ceildiv) whose
+/// dividend is tested and selects between two notions of divisibility:
+///  - `AffineExprKind::Mod`: `expr` has to be an exact multiple of the symbol,
+///    i.e. be built from constant 0 and the symbol using additions of exact
+///    multiples, products with an exact multiple factor, and modulos of exact
+///    multiples.
+///  - `AffineExprKind::FloorDiv` / `AffineExprKind::CeilDiv`: additionally, a
+///    nested division of the same kind whose dividend is divisible qualifies,
+///    by the commutativity of floordiv and ceildiv with positive divisors:
+///      (e1 floordiv e2) floordiv e3 = (e1 floordiv e3) floordiv e2
+///      (e1 ceildiv e2) ceildiv e3 = (e1 ceildiv e3) ceildiv e2
+///    Such a nested division is not an exact multiple of the symbol, so it may
+///    only be one term of a sum whose other term is an exact multiple; it
+///    never qualifies as a factor of a product or an operand of a modulo.
+static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
+                                AffineExprKind opKind) {
+  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
+  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
+          opKind == AffineExprKind::CeilDiv) &&
+         "unexpected opKind");
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant:
+    // Zero is the only constant that is a multiple of an arbitrary symbol.
+    return expr.cast<AffineConstantExpr>().getValue() == 0;
+  case AffineExprKind::DimId:
+    return false;
+  case AffineExprKind::SymbolId:
+    return expr.cast<AffineSymbolExpr>().getPosition() == symbolPos;
+  // A sum is divisible if both operands are, but at most one of them may rely
+  // on the floordiv/ceildiv commutativity: (A + s * B) floordiv s equals
+  // (A floordiv s) + B, whereas (A + B) floordiv s differs in general from
+  // (A floordiv s) + (B floordiv s), e.g. for A = B = 1 and s = 2.
+  case AffineExprKind::Add: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = binaryExpr.getLHS();
+    AffineExpr rhs = binaryExpr.getRHS();
+    bool lhsExact = isDivisibleBySymbol(lhs, symbolPos, AffineExprKind::Mod);
+    bool rhsExact = isDivisibleBySymbol(rhs, symbolPos, AffineExprKind::Mod);
+    if (lhsExact && rhsExact)
+      return true;
+    if (opKind == AffineExprKind::Mod)
+      return false;
+    if (lhsExact)
+      return isDivisibleBySymbol(rhs, symbolPos, opKind);
+    if (rhsExact)
+      return isDivisibleBySymbol(lhs, symbolPos, opKind);
+    return false;
+  }
+  // Both operands have to be exact multiples of the given symbol. Consider the
+  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`:
+  // this is a division by s1 and both operands of the modulo are floordivs of
+  // multiples of s1, but the modulo is not always divisible by s1. Hence the
+  // operands are checked with `AffineExprKind::Mod`.
+  case AffineExprKind::Mod: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
+                               AffineExprKind::Mod) &&
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
+                               AffineExprKind::Mod);
+  }
+  // A product is divisible if either operand is an exact multiple of the
+  // symbol. A nested floordiv/ceildiv operand does not qualify: for s0 = 2,
+  // (((s0 * 2) ceildiv 4) * 5) ceildiv s0 is 3 but ((2 ceildiv 4) * 5) is 5.
+  case AffineExprKind::Mul: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
+                               AffineExprKind::Mod) ||
+           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
+                               AffineExprKind::Mod);
+  }
+  // Floordiv and ceildiv are divisible by the given symbol when the first
+  // operand is divisible and the kind of `expr` is the same as `opKind`, by
+  // the commutativity property stated above. It fails if the operations
+  // differ, e.g. (e1 ceildiv e2) floordiv e3 can not be simplified.
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv: {
+    if (opKind != expr.getKind())
+      return false;
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Divides `expr` by the symbol at position `symbolPos`, where `opKind` is the
+/// kind of the operation (mod, floordiv or ceildiv) being simplified. Callers
+/// must check `isDivisibleBySymbol(expr, symbolPos, opKind)` first; a null
+/// expression is returned whenever that divisibility condition does not hold.
+static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
+                                 AffineExprKind opKind) {
+  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
+  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
+          opKind == AffineExprKind::CeilDiv) &&
+         "unexpected opKind");
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant:
+    if (expr.cast<AffineConstantExpr>().getValue() != 0)
+      return nullptr;
+    return getAffineConstantExpr(0, expr.getContext());
+  case AffineExprKind::DimId:
+    return nullptr;
+  case AffineExprKind::SymbolId:
+    // Only the divisor itself divides to one; any other symbol is not
+    // divisible and must not be folded to a constant.
+    if (expr.cast<AffineSymbolExpr>().getPosition() != symbolPos)
+      return nullptr;
+    return getAffineConstantExpr(1, expr.getContext());
+  // Dividing both operands by the given symbol.
+  case AffineExprKind::Add: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind);
+    AffineExpr rhs = symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
+    if (!lhs || !rhs)
+      return nullptr;
+    return lhs + rhs;
+  }
+  // Dividing both operands, which are exact multiples, by the given symbol.
+  case AffineExprKind::Mod: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs =
+        symbolicDivide(binaryExpr.getLHS(), symbolPos, AffineExprKind::Mod);
+    AffineExpr rhs =
+        symbolicDivide(binaryExpr.getRHS(), symbolPos, AffineExprKind::Mod);
+    if (!lhs || !rhs)
+      return nullptr;
+    return lhs % rhs;
+  }
+  // Dividing the operand that is an exact multiple of the given symbol.
+  case AffineExprKind::Mul: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = binaryExpr.getLHS();
+    AffineExpr rhs = binaryExpr.getRHS();
+    bool divideLHS = isDivisibleBySymbol(lhs, symbolPos, AffineExprKind::Mod);
+    AffineExpr quotient = symbolicDivide(divideLHS ? lhs : rhs, symbolPos,
+                                         AffineExprKind::Mod);
+    if (!quotient)
+      return nullptr;
+    return divideLHS ? quotient * rhs : lhs * quotient;
+  }
+  // Dividing the first operand only, by the commutativity of nested divisions
+  // of the same kind.
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv: {
+    if (opKind != expr.getKind())
+      return nullptr;
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs =
+        symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind());
+    if (!lhs)
+      return nullptr;
+    return getAffineBinaryOpExpr(expr.getKind(), lhs, binaryExpr.getRHS());
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
+/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
+/// operations when the second operand simplifies to a symbol and the first
+/// operand is divisible by that symbol. It can be applied to any semi-affine
+/// expression. Returned expression can either be a semi-affine or pure affine
+/// expression, and is never null.
+static AffineExpr simplifySemiAffine(AffineExpr expr) {
+  switch (expr.getKind()) {
+  case AffineExprKind::Constant:
+  case AffineExprKind::DimId:
+  case AffineExprKind::SymbolId:
+    return expr;
+  case AffineExprKind::Add:
+  case AffineExprKind::Mul: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    return getAffineBinaryOpExpr(expr.getKind(),
+                                 simplifySemiAffine(binaryExpr.getLHS()),
+                                 simplifySemiAffine(binaryExpr.getRHS()));
+  }
+  // Check if the simplification of the second operand is a symbol, and the
+  // simplification of the first operand is divisible by it. If the operation
+  // is a modulo, a constant zero expression is returned. In the case of
+  // floordiv and ceildiv, the simplified first operand is divided by the
+  // symbol. Otherwise, only the operands are simplified.
+  case AffineExprKind::FloorDiv:
+  case AffineExprKind::CeilDiv:
+  case AffineExprKind::Mod: {
+    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
+    AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
+    AffineSymbolExpr symbolExpr = sRHS.dyn_cast<AffineSymbolExpr>();
+    // Divisibility is checked on `sLHS`, the expression that is actually
+    // divided: simplification can change the structure of the first operand,
+    // e.g. `(s1 * s0) floordiv s0` becomes `s1`, which is not divisible by s0.
+    if (!symbolExpr ||
+        !isDivisibleBySymbol(sLHS, symbolExpr.getPosition(), expr.getKind()))
+      return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
+    if (expr.getKind() == AffineExprKind::Mod)
+      return getAffineConstantExpr(0, expr.getContext());
+    AffineExpr quotient =
+        symbolicDivide(sLHS, symbolExpr.getPosition(), expr.getKind());
+    // Non-null given the check above; falling back keeps the result
+    // well-formed should the two helpers ever disagree.
+    if (!quotient)
+      return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
+    return quotient;
+  }
+  }
+  llvm_unreachable("Unknown AffineExpr");
+}
+
 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
                                        MLIRContext *context) {
   auto assignCtx = [context](AffineDimExprStorage *storage) {
@@ -878,8 +1087,9 @@
 /// Simplify the affine expression by flattening it and reconstructing it.
 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                                     unsigned numSymbols) {
-  // TODO: only pure affine for now. The simplification here can
-  // be extended to semi-affine maps in the future.
+  // Simplify semi-affine expressions separately.
+  if (!expr.isPureAffine())
+    expr = simplifySemiAffine(expr);
   if (!expr.isPureAffine())
     return expr;
 
diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -281,3 +281,67 @@
   %out = affine.load %in[] : memref<f32>
   return %out : f32
 }
+
+// -----
+
+// Tests the simplification of a semi-affine expression in various cases.
+// CHECK-DAG: #[[$map0:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<()[s0, s1] -> (-(s1 floordiv s0) + 42)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<()[s0, s1] -> (s1 floordiv s0)>
+
+// Tests the simplification of a semi-affine expression with a modulo operation on a floordiv and multiplication.
+// CHECK-LABEL: func @semiaffine_mod
+func @semiaffine_mod(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * s0) mod s0)> (%arg0)[%arg1]
+  // CHECK: %[[CST:.*]] = constant 0
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a nested floordiv and a floordiv on modulo operation.
+// CHECK-LABEL: func @semiaffine_floordiv
+func @semiaffine_floordiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + ((2 * s0) mod (3 * s0))) floordiv s0)> (%arg0)[%arg1]
+  // CHECK: affine.apply #[[$map0]]()[%arg1, %arg0]
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a ceildiv operation and a division of constant 0 by a symbol.
+// CHECK-LABEL: func @semiaffine_ceildiv
+func @semiaffine_ceildiv(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((-((d0 floordiv s0) * s0) + s0 * 42 + ((5-5) floordiv s0)) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK: affine.apply #[[$map1]]()[%arg1, %arg0]
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a nested ceildiv operation and further simplifications after performing ceildiv.
+// CHECK-LABEL: func @semiaffine_composite_floor
+func @semiaffine_composite_floor(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->((((s0 * 2) ceildiv 4) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK: %[[CST:.*]] = constant 43
+  return %a : index
+}
+
+// Tests that a ceildiv nested under a multiplication is not treated as divisible by the symbol: for s0 = 2, (((s0 * 2) ceildiv 4) * 5 + s0 * 42) ceildiv s0 is 45, not 47.
+// CHECK-LABEL: func @semiaffine_ceildiv_under_mul
+func @semiaffine_ceildiv_under_mul(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(((((s0 * 2) ceildiv 4) * 5) + s0 * 42) ceildiv s0)> (%arg0)[%arg1]
+  // CHECK-NOT: constant
+  // CHECK: affine.apply
+  return %a : index
+}
+
+// Tests that a first operand which is no longer divisible once simplified is not divided: ((s1 * s0) floordiv s0) floordiv s0 is s1 floordiv s0.
+// CHECK-LABEL: func @semiaffine_nested_floordiv_not_divisible
+func @semiaffine_nested_floordiv_not_divisible(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<()[s0, s1] ->(((s1 * s0) floordiv s0) floordiv s0)> ()[%arg1, %arg0]
+  // CHECK: affine.apply #[[$map2]]()[%arg1, %arg0]
+  return %a : index
+}
+
+// Tests the simplification of a semi-affine expression with a modulo operation with a second operand that simplifies to symbol.
+// CHECK-LABEL: func @semiaffine_unsimplified_symbol
+func @semiaffine_unsimplified_symbol(%arg0: index, %arg1: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0] ->(s0 mod (2 * s0 - s0))> (%arg0)[%arg1]
+  // CHECK: %[[CST:.*]] = constant 0
+  return %a : index
+}