diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h --- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h +++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h @@ -12,10 +12,7 @@ #include "mlir/Support/LLVM.h" namespace mlir { -class AffineExpr; class AffineForOp; -class AffineMap; -class AffineParallelOp; class Location; struct LogicalResult; class OpBuilder; @@ -26,18 +23,6 @@ class RewritePatternSet; -/// Emit code that computes the given affine expression using standard -/// arithmetic operations applied to the provided dimension and symbol values. -Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, - ValueRange dimValues, ValueRange symbolValues); - -/// Create a sequence of operations that implement the `affineMap` applied to -/// the given `operands` (as it it were an AffineApplyOp). -Optional> expandAffineMap(OpBuilder &builder, - Location loc, - AffineMap affineMap, - ValueRange operands); - /// Collect a set of patterns to convert from the Affine dialect to the Standard /// dialect, in particular convert structured affine control flow into CFG /// branch-based control flow. diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -287,6 +287,18 @@ void createAffineComputationSlice(Operation *opInst, SmallVectorImpl *sliceOps); +/// Emit code that computes the given affine expression using standard +/// arithmetic operations applied to the provided dimension and symbol values. +Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, + ValueRange dimValues, ValueRange symbolValues); + +/// Create a sequence of operations that implement the `affineMap` applied to +/// the given `operands` (as if it were an AffineApplyOp). 
+Optional> expandAffineMap(OpBuilder &builder, + Location loc, + AffineMap affineMap, + ValueRange operands); + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_UTILS_H diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -15,14 +15,12 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" @@ -32,213 +30,6 @@ using namespace mlir; using namespace mlir::vector; -namespace { -/// Visit affine expressions recursively and build the sequence of operations -/// that correspond to it. Visitation functions return an Value of the -/// expression subtree they visited or `nullptr` on error. -class AffineApplyExpander - : public AffineExprVisitor { -public: - /// This internal class expects arguments to be non-null, checks must be - /// performed at the call site. 
- AffineApplyExpander(OpBuilder &builder, ValueRange dimValues, - ValueRange symbolValues, Location loc) - : builder(builder), dimValues(dimValues), symbolValues(symbolValues), - loc(loc) {} - - template - Value buildBinaryExpr(AffineBinaryOpExpr expr) { - auto lhs = visit(expr.getLHS()); - auto rhs = visit(expr.getRHS()); - if (!lhs || !rhs) - return nullptr; - auto op = builder.create(loc, lhs, rhs); - return op.getResult(); - } - - Value visitAddExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); - } - - Value visitMulExpr(AffineBinaryOpExpr expr) { - return buildBinaryExpr(expr); - } - - /// Euclidean modulo operation: negative RHS is not allowed. - /// Remainder of the euclidean integer division is always non-negative. - /// - /// Implemented as - /// - /// a mod b = - /// let remainder = srem a, b; - /// negative = a < 0 in - /// select negative, remainder + b, remainder. - Value visitModExpr(AffineBinaryOpExpr expr) { - auto rhsConst = expr.getRHS().dyn_cast(); - if (!rhsConst) { - emitError( - loc, - "semi-affine expressions (modulo by non-const) are not supported"); - return nullptr; - } - if (rhsConst.getValue() <= 0) { - emitError(loc, "modulo by non-positive value is not supported"); - return nullptr; - } - - auto lhs = visit(expr.getLHS()); - auto rhs = visit(expr.getRHS()); - assert(lhs && rhs && "unexpected affine expr lowering failure"); - - Value remainder = builder.create(loc, lhs, rhs); - Value zeroCst = builder.create(loc, 0); - Value isRemainderNegative = builder.create( - loc, arith::CmpIPredicate::slt, remainder, zeroCst); - Value correctedRemainder = - builder.create(loc, remainder, rhs); - Value result = builder.create( - loc, isRemainderNegative, correctedRemainder, remainder); - return result; - } - - /// Floor division operation (rounds towards negative infinity). 
- /// - /// For positive divisors, it can be implemented without branching and with a - /// single division operation as - /// - /// a floordiv b = - /// let negative = a < 0 in - /// let absolute = negative ? -a - 1 : a in - /// let quotient = absolute / b in - /// negative ? -quotient - 1 : quotient - Value visitFloorDivExpr(AffineBinaryOpExpr expr) { - auto rhsConst = expr.getRHS().dyn_cast(); - if (!rhsConst) { - emitError( - loc, - "semi-affine expressions (division by non-const) are not supported"); - return nullptr; - } - if (rhsConst.getValue() <= 0) { - emitError(loc, "division by non-positive value is not supported"); - return nullptr; - } - - auto lhs = visit(expr.getLHS()); - auto rhs = visit(expr.getRHS()); - assert(lhs && rhs && "unexpected affine expr lowering failure"); - - Value zeroCst = builder.create(loc, 0); - Value noneCst = builder.create(loc, -1); - Value negative = builder.create( - loc, arith::CmpIPredicate::slt, lhs, zeroCst); - Value negatedDecremented = builder.create(loc, noneCst, lhs); - Value dividend = - builder.create(loc, negative, negatedDecremented, lhs); - Value quotient = builder.create(loc, dividend, rhs); - Value correctedQuotient = - builder.create(loc, noneCst, quotient); - Value result = builder.create(loc, negative, - correctedQuotient, quotient); - return result; - } - - /// Ceiling division operation (rounds towards positive infinity). - /// - /// For positive divisors, it can be implemented without branching and with a - /// single division operation as - /// - /// a ceildiv b = - /// let negative = a <= 0 in - /// let absolute = negative ? -a : a - 1 in - /// let quotient = absolute / b in - /// negative ? 
-quotient : quotient + 1 - Value visitCeilDivExpr(AffineBinaryOpExpr expr) { - auto rhsConst = expr.getRHS().dyn_cast(); - if (!rhsConst) { - emitError(loc) << "semi-affine expressions (division by non-const) are " - "not supported"; - return nullptr; - } - if (rhsConst.getValue() <= 0) { - emitError(loc, "division by non-positive value is not supported"); - return nullptr; - } - auto lhs = visit(expr.getLHS()); - auto rhs = visit(expr.getRHS()); - assert(lhs && rhs && "unexpected affine expr lowering failure"); - - Value zeroCst = builder.create(loc, 0); - Value oneCst = builder.create(loc, 1); - Value nonPositive = builder.create( - loc, arith::CmpIPredicate::sle, lhs, zeroCst); - Value negated = builder.create(loc, zeroCst, lhs); - Value decremented = builder.create(loc, lhs, oneCst); - Value dividend = - builder.create(loc, nonPositive, negated, decremented); - Value quotient = builder.create(loc, dividend, rhs); - Value negatedQuotient = - builder.create(loc, zeroCst, quotient); - Value incrementedQuotient = - builder.create(loc, quotient, oneCst); - Value result = builder.create( - loc, nonPositive, negatedQuotient, incrementedQuotient); - return result; - } - - Value visitConstantExpr(AffineConstantExpr expr) { - auto op = builder.create(loc, expr.getValue()); - return op.getResult(); - } - - Value visitDimExpr(AffineDimExpr expr) { - assert(expr.getPosition() < dimValues.size() && - "affine dim position out of range"); - return dimValues[expr.getPosition()]; - } - - Value visitSymbolExpr(AffineSymbolExpr expr) { - assert(expr.getPosition() < symbolValues.size() && - "symbol dim position out of range"); - return symbolValues[expr.getPosition()]; - } - -private: - OpBuilder &builder; - ValueRange dimValues; - ValueRange symbolValues; - - Location loc; -}; -} // namespace - -/// Create a sequence of operations that implement the `expr` applied to the -/// given dimension and symbol values. 
-mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, - AffineExpr expr, ValueRange dimValues, - ValueRange symbolValues) { - return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); -} - -/// Create a sequence of operations that implement the `affineMap` applied to -/// the given `operands` (as it it were an AffineApplyOp). -Optional> mlir::expandAffineMap(OpBuilder &builder, - Location loc, - AffineMap affineMap, - ValueRange operands) { - auto numDims = affineMap.getNumDims(); - auto expanded = llvm::to_vector<8>( - llvm::map_range(affineMap.getResults(), - [numDims, &builder, loc, operands](AffineExpr expr) { - return expandAffineExpr(builder, loc, expr, - operands.take_front(numDims), - operands.drop_front(numDims)); - })); - if (llvm::all_of(expanded, [](Value v) { return v; })) - return expanded; - return None; -} - /// Given a range of values, emit the code that reduces them with "min" or "max" /// depending on the provided comparison predicate. 
The predicate defines which /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the diff --git a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/AffineToStandard/CMakeLists.txt @@ -12,12 +12,13 @@ LINK_LIBS PUBLIC MLIRAffine + MLIRAffineUtils MLIRArithmetic + MLIRIR MLIRMemRef MLIRSCF MLIRPass MLIRStandard MLIRTransforms - MLIRIR MLIRVector ) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/IntegerSet.h" @@ -27,6 +28,213 @@ using namespace mlir; +namespace { +/// Visit affine expressions recursively and build the sequence of operations +/// that correspond to it. Visitation functions return a Value of the +/// expression subtree they visited or `nullptr` on error. +class AffineApplyExpander + : public AffineExprVisitor { +public: + /// This internal class expects arguments to be non-null, checks must be + /// performed at the call site. 
+ AffineApplyExpander(OpBuilder &builder, ValueRange dimValues, + ValueRange symbolValues, Location loc) + : builder(builder), dimValues(dimValues), symbolValues(symbolValues), + loc(loc) {} + + template + Value buildBinaryExpr(AffineBinaryOpExpr expr) { + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + if (!lhs || !rhs) + return nullptr; + auto op = builder.create(loc, lhs, rhs); + return op.getResult(); + } + + Value visitAddExpr(AffineBinaryOpExpr expr) { + return buildBinaryExpr(expr); + } + + Value visitMulExpr(AffineBinaryOpExpr expr) { + return buildBinaryExpr(expr); + } + + /// Euclidean modulo operation: negative RHS is not allowed. + /// Remainder of the euclidean integer division is always non-negative. + /// + /// Implemented as + /// + /// a mod b = + /// let remainder = srem a, b; + /// negative = remainder < 0 in + /// select negative, remainder + b, remainder. + Value visitModExpr(AffineBinaryOpExpr expr) { + auto rhsConst = expr.getRHS().dyn_cast(); + if (!rhsConst) { + emitError( + loc, + "semi-affine expressions (modulo by non-const) are not supported"); + return nullptr; + } + if (rhsConst.getValue() <= 0) { + emitError(loc, "modulo by non-positive value is not supported"); + return nullptr; + } + + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + assert(lhs && rhs && "unexpected affine expr lowering failure"); + + Value remainder = builder.create(loc, lhs, rhs); + Value zeroCst = builder.create(loc, 0); + Value isRemainderNegative = builder.create( + loc, arith::CmpIPredicate::slt, remainder, zeroCst); + Value correctedRemainder = + builder.create(loc, remainder, rhs); + Value result = builder.create( + loc, isRemainderNegative, correctedRemainder, remainder); + return result; + } + + /// Floor division operation (rounds towards negative infinity). 
+ /// + /// For positive divisors, it can be implemented without branching and with a + /// single division operation as + /// + /// a floordiv b = + /// let negative = a < 0 in + /// let absolute = negative ? -a - 1 : a in + /// let quotient = absolute / b in + /// negative ? -quotient - 1 : quotient + Value visitFloorDivExpr(AffineBinaryOpExpr expr) { + auto rhsConst = expr.getRHS().dyn_cast(); + if (!rhsConst) { + emitError( + loc, + "semi-affine expressions (division by non-const) are not supported"); + return nullptr; + } + if (rhsConst.getValue() <= 0) { + emitError(loc, "division by non-positive value is not supported"); + return nullptr; + } + + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + assert(lhs && rhs && "unexpected affine expr lowering failure"); + + Value zeroCst = builder.create(loc, 0); + Value noneCst = builder.create(loc, -1); + Value negative = builder.create( + loc, arith::CmpIPredicate::slt, lhs, zeroCst); + Value negatedDecremented = builder.create(loc, noneCst, lhs); + Value dividend = + builder.create(loc, negative, negatedDecremented, lhs); + Value quotient = builder.create(loc, dividend, rhs); + Value correctedQuotient = + builder.create(loc, noneCst, quotient); + Value result = builder.create(loc, negative, + correctedQuotient, quotient); + return result; + } + + /// Ceiling division operation (rounds towards positive infinity). + /// + /// For positive divisors, it can be implemented without branching and with a + /// single division operation as + /// + /// a ceildiv b = + /// let negative = a <= 0 in + /// let absolute = negative ? -a : a - 1 in + /// let quotient = absolute / b in + /// negative ? 
-quotient : quotient + 1 + Value visitCeilDivExpr(AffineBinaryOpExpr expr) { + auto rhsConst = expr.getRHS().dyn_cast(); + if (!rhsConst) { + emitError(loc) << "semi-affine expressions (division by non-const) are " + "not supported"; + return nullptr; + } + if (rhsConst.getValue() <= 0) { + emitError(loc, "division by non-positive value is not supported"); + return nullptr; + } + auto lhs = visit(expr.getLHS()); + auto rhs = visit(expr.getRHS()); + assert(lhs && rhs && "unexpected affine expr lowering failure"); + + Value zeroCst = builder.create(loc, 0); + Value oneCst = builder.create(loc, 1); + Value nonPositive = builder.create( + loc, arith::CmpIPredicate::sle, lhs, zeroCst); + Value negated = builder.create(loc, zeroCst, lhs); + Value decremented = builder.create(loc, lhs, oneCst); + Value dividend = + builder.create(loc, nonPositive, negated, decremented); + Value quotient = builder.create(loc, dividend, rhs); + Value negatedQuotient = + builder.create(loc, zeroCst, quotient); + Value incrementedQuotient = + builder.create(loc, quotient, oneCst); + Value result = builder.create( + loc, nonPositive, negatedQuotient, incrementedQuotient); + return result; + } + + Value visitConstantExpr(AffineConstantExpr expr) { + auto op = builder.create(loc, expr.getValue()); + return op.getResult(); + } + + Value visitDimExpr(AffineDimExpr expr) { + assert(expr.getPosition() < dimValues.size() && + "affine dim position out of range"); + return dimValues[expr.getPosition()]; + } + + Value visitSymbolExpr(AffineSymbolExpr expr) { + assert(expr.getPosition() < symbolValues.size() && + "symbol dim position out of range"); + return symbolValues[expr.getPosition()]; + } + +private: + OpBuilder &builder; + ValueRange dimValues; + ValueRange symbolValues; + + Location loc; +}; +} // namespace + +/// Create a sequence of operations that implement the `expr` applied to the +/// given dimension and symbol values. 
+mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, + AffineExpr expr, ValueRange dimValues, + ValueRange symbolValues) { + return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); +} + +/// Create a sequence of operations that implement the `affineMap` applied to +/// the given `operands` (as if it were an AffineApplyOp). +Optional> mlir::expandAffineMap(OpBuilder &builder, + Location loc, + AffineMap affineMap, + ValueRange operands) { + auto numDims = affineMap.getNumDims(); + auto expanded = llvm::to_vector<8>( + llvm::map_range(affineMap.getResults(), + [numDims, &builder, loc, operands](AffineExpr expr) { + return expandAffineExpr(builder, loc, expr, + operands.take_front(numDims), + operands.drop_front(numDims)); + })); + if (llvm::all_of(expanded, [](Value v) { return v; })) + return expanded; + return None; +} + /// Promotes the `then` or the `else` block of `ifOp` (depending on whether /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards /// the rest of the op.