diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -104,6 +104,11 @@
   /// dimensional identifiers.
   bool isIdentity() const;
 
+  /// Returns true if this affine map is an identity affine map on the symbol
+  /// identifiers, i.e. (d0, ..., dm)[s0, ..., sn] -> (s0, ..., sn). The
+  /// dimensional identifiers, if any, do not appear in the results.
+  bool isSymbolIdentity() const;
+
   /// Returns true if this affine map is a minor identity, i.e. an identity
   /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
   bool isMinorIdentity() const;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3042,6 +3042,12 @@
   SmallVector<int64_t, 2> results;
   auto foldedMap = op.getMap().partialConstantFold(operands, &results);
 
+  // A map (d0, ..., dm)[s0] -> (s0) just forwards its single symbol operand.
+  // Map operands list all dimension operands before the symbol operands, so
+  // that symbol operand is at index getNumDims(), not necessarily 0.
+  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
+    return op.getOperand(foldedMap.getNumDims());
+
   // If some of the map results are not constant, try changing the map in-place.
   if (results.empty()) {
     // If the map is the same, report that folding did not happen.
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -276,6 +276,18 @@
   return true;
 }
 
+bool AffineMap::isSymbolIdentity() const {
+  if (getNumSymbols() != getNumResults())
+    return false;
+  ArrayRef<AffineExpr> results = getResults();
+  for (unsigned i = 0, numSymbols = getNumSymbols(); i < numSymbols; ++i) {
+    auto expr = results[i].dyn_cast<AffineSymbolExpr>();
+    if (!expr || expr.getPosition() != i)
+      return false;
+  }
+  return true;
+}
+
 bool AffineMap::isEmpty() const {
   return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
 }
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1219,3 +1219,28 @@
   "test.foo"(%1) : (index) -> ()
   return
 }
+
+// -----
+// CHECK-LABEL: func @min.oneval(%arg0: index)
+func.func @min.oneval(%arg0: index) -> index {
+  %min = affine.min affine_map<()[s0] -> (s0)> ()[%arg0]
+  // CHECK: return %arg0 : index
+  return %min: index
+}
+
+// -----
+// CHECK-LABEL: func @max.oneval(%arg0: index)
+func.func @max.oneval(%arg0: index) -> index {
+  %max = affine.max affine_map<()[s0] -> (s0)> ()[%arg0]
+  // CHECK: return %arg0 : index
+  return %max: index
+}
+
+// -----
+// The forwarded symbol operand follows the dimension operands.
+// CHECK-LABEL: func @min.oneval.dim(%arg0: index, %arg1: index)
+func.func @min.oneval.dim(%arg0: index, %arg1: index) -> index {
+  %min = affine.min affine_map<(d0)[s0] -> (s0)> (%arg0)[%arg1]
+  // CHECK: return %arg1 : index
+  return %min: index
+}