diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
@@ -25,6 +26,99 @@
 
 #define DEBUG_TYPE "affine-analysis"
 
+/// A utility function to check if a value is defined at the top level of
+/// `region` or is an argument of `region`. A value of index type defined at the
+/// top level of an `AffineScope` region is always a valid symbol for all
+/// uses in that region.
+static bool isTopLevelValue(Value value, Region *region) {
+  if (auto arg = value.dyn_cast<BlockArgument>())
+    return arg.getParentRegion() == region;
+  return value.getDefiningOp()->getParentRegion() == region;
+}
+
+/// Checks if `value` known to be a legal affine dimension or symbol in `src`
+/// region remains legal if the operation that uses it is inlined into `dest`
+/// with the given value mapping. `legalityCheck` is either `isValidDim` or
+/// `isValidSymbol`, depending on the value being required to remain a valid
+/// dimension or symbol.
+static bool
+remainsLegalAfterInline(Value value, Region *src, Region *dest,
+                        const BlockAndValueMapping &mapping,
+                        function_ref<bool(Value, Region *)> legalityCheck) {
+  // If the value is a valid dimension for any other reason than being
+  // a top-level value, it will remain valid: constants get inlined
+  // with the function, transitive affine applies also get inlined and
+  // will be checked themselves, etc.
+  if (!isTopLevelValue(value, src))
+    return true;
+
+  // If it's a top-level value because it's a block operand, i.e. a
+  // function argument, check whether the value replacing it after
+  // inlining still satisfies `legalityCheck` in the new region.
+  if (value.isa<BlockArgument>())
+    return legalityCheck(mapping.lookup(value), dest);
+
+  // If it's a top-level value because it's defined in the region,
+  // it can only be inlined if the defining op is a constant or a
+  // `dim`, which can appear anywhere and be valid, since the defining
+  // op won't be top-level anymore after inlining.
+  Attribute operandCst;
+  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
+         value.getDefiningOp<DimOp>();
+}
+
+/// Checks if all values known to be legal affine dimensions or symbols in `src`
+/// remain so if their respective users are inlined into `dest`.
+static bool
+remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
+                        const BlockAndValueMapping &mapping,
+                        function_ref<bool(Value, Region *)> legalityCheck) {
+  return llvm::all_of(values, [&](Value v) {
+    return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
+  });
+}
+
+/// Checks if an affine read or write operation remains legal after inlining
+/// from `src` to `dest`.
+template <typename OpTy>
+static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
+                                    const BlockAndValueMapping &mapping) {
+  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
+                                AffineWriteOpInterface>::value,
+                "only ops with affine read/write interface are supported");
+
+  AffineMap map = op.getAffineMap();
+  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
+  ValueRange symbolOperands =
+      op.getMapOperands().take_back(map.getNumSymbols());
+  if (!remainsLegalAfterInline(
+          dimOperands, src, dest, mapping,
+          static_cast<bool (*)(Value, Region *)>(isValidDim)))
+    return false;
+  if (!remainsLegalAfterInline(
+          symbolOperands, src, dest, mapping,
+          static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
+    return false;
+  return true;
+}
+
+/// Checks if an affine apply operation remains legal after inlining from `src`
+/// to `dest`.
+template <>
+bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
+                             const BlockAndValueMapping &mapping) {
+  // If it's a valid dimension, we need to check that it remains so.
+  if (isValidDim(op.getResult(), src))
+    return remainsLegalAfterInline(
+        op.getMapOperands(), src, dest, mapping,
+        static_cast<bool (*)(Value, Region *)>(isValidDim));
+
+  // Otherwise it must be a valid symbol, check that it remains so.
+  return remainsLegalAfterInline(
+      op.getMapOperands(), src, dest, mapping,
+      static_cast<bool (*)(Value, Region *)>(isValidSymbol));
+}
+
 //===----------------------------------------------------------------------===//
 // AffineDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -41,22 +135,62 @@
 
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
+  /// 'wouldBeCloned' is set if the region is cloned into its new location
+  /// rather than moved, indicating there may be other users.
   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                        BlockAndValueMapping &valueMapping) const final {
-    // Conservatively don't allow inlining into affine structures.
-    return false;
+    // We can inline into affine loops and conditionals if this doesn't break
+    // affine value categorization rules.
+    Operation *destOp = dest->getParentOp();
+    if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
+      return false;
+
+    // Multi-block regions cannot be inlined into affine constructs, all of
+    // which require single-block regions.
+    if (!llvm::hasSingleElement(*src))
+      return false;
+
+    // Side-effecting operations that the affine dialect cannot understand
+    // should not be inlined.
+    Block &srcBlock = src->front();
+    for (Operation &op : srcBlock) {
+      // Ops with no side effects are fine.
+      if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
+        if (iface.hasNoEffect())
+          continue;
+      }
+
+      // Assuming the inlined region is valid, we only need to check if the
+      // inlining would change it.
+      bool remainsValid =
+          llvm::TypeSwitch<Operation *, bool>(&op)
+              .Case<AffineApplyOp, AffineReadOpInterface,
+                    AffineWriteOpInterface>([&](auto op) {
+                return remainsLegalAfterInline(op, src, dest, valueMapping);
+              })
+              .Default([](Operation *) {
+                // Conservatively disallow inlining ops we cannot reason about.
+                return false;
+              });
+
+      if (!remainsValid)
+        return false;
+    }
+
+    return true;
   }
 
   /// Returns true if the given operation 'op', that is registered to this
   /// dialect, can be inlined into the given region, false otherwise.
   bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                        BlockAndValueMapping &valueMapping) const final {
-    // Always allow inlining affine operations into the top-level region of a
-    // function. There are some edge cases when inlining *into* affine
-    // structures, but that is handled in the other 'isLegalToInline' hook
-    // above.
-    // TODO: We should be able to inline into other regions than functions.
-    return isa<FuncOp>(region->getParentOp());
+    // Always allow inlining affine operations into a region that is marked as
+    // affine scope, or into affine loops and conditionals. There are some edge
+    // cases when inlining *into* affine structures, but that is handled in the
+    // other 'isLegalToInline' hook above.
+    Operation *parentOp = region->getParentOp();
+    return parentOp->hasTrait<OpTrait::AffineScope>() ||
+           isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
   }
 
   /// Affine regions should be analyzed recursively.
@@ -101,16 +235,6 @@
   return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
 }
 
-/// A utility function to check if a value is defined at the top level of
-/// `region` or is an argument of `region`. A value of index type defined at the
-/// top level of a `AffineScope` region is always a valid symbol for all
-/// uses in that region.
-static bool isTopLevelValue(Value value, Region *region) {
-  if (auto arg = value.dyn_cast<BlockArgument>())
-    return arg.getParentRegion() == region;
-  return value.getDefiningOp()->getParentRegion() == region;
-}
-
 /// Returns the closest region enclosing `op` that is held by an operation with
 /// trait `AffineScope`; `nullptr` if there is no such region.
 // TODO: getAffineScope should be publicly exposed for affine passes/utilities.
diff --git a/mlir/test/Dialect/Affine/inlining.mlir b/mlir/test/Dialect/Affine/inlining.mlir
--- a/mlir/test/Dialect/Affine/inlining.mlir
+++ b/mlir/test/Dialect/Affine/inlining.mlir
@@ -54,16 +54,77 @@
 
 // -----
 
-// Test that calls are not inlined into affine structures.
+// Test that calls are inlined into affine structures.
 func @func_noop() {
   return
 }
 
-// CHECK-LABEL: func @not_inline_into_affine_ops
-func @not_inline_into_affine_ops() {
-  // CHECK: call @func_noop
+// CHECK-LABEL: func @inline_into_affine_ops
+func @inline_into_affine_ops() {
+  // CHECK-NOT: call @func_noop
   affine.for %i = 1 to 10 {
     call @func_noop() : () -> ()
   }
   return
 }
+
+// -----
+
+// Test that calls with dimension arguments are properly inlined.
+func @func_dim(%arg0: index, %arg1: memref<?xf32>) {
+  affine.load %arg1[%arg0] : memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: @inline_dimension
+// CHECK: (%[[ARG0:.*]]: memref<?xf32>)
+func @inline_dimension(%arg0: memref<?xf32>) {
+  // CHECK: affine.for %[[IV:.*]] =
+  affine.for %i = 1 to 42 {
+    // CHECK-NOT: call @func_dim
+    // CHECK: affine.load %[[ARG0]][%[[IV]]]
+    call @func_dim(%i, %arg0) : (index, memref<?xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// Test that calls with vector operations are also inlined.
+func @func_vector_dim(%arg0: index, %arg1: memref<32xf32>) {
+  affine.vector_load %arg1[%arg0] : memref<32xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @inline_dimension_vector
+// CHECK: (%[[ARG0:.*]]: memref<32xf32>)
+func @inline_dimension_vector(%arg0: memref<32xf32>) {
+  // CHECK: affine.for %[[IV:.*]] =
+  affine.for %i = 1 to 42 {
+    // CHECK-NOT: call @func_vector_dim
+    // CHECK: affine.vector_load %[[ARG0]][%[[IV]]]
+    call @func_vector_dim(%i, %arg0) : (index, memref<32xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// Test that calls that would result in violation of affine value
+// categorization (top-level values stop being top-level) are not inlined.
+func private @get_index() -> index
+
+func @func_top_level(%arg0: memref<?xf32>) {
+  %0 = call @get_index() : () -> index
+  affine.load %arg0[%0] : memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: @no_inline_not_top_level
+func @no_inline_not_top_level(%arg0: memref<?xf32>) {
+  affine.for %i = 1 to 42 {
+    // CHECK: call @func_top_level
+    call @func_top_level(%arg0) : (memref<?xf32>) -> ()
+  }
+  return
+}