diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -364,6 +364,12 @@
 /// comparison predicates.
 bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute value
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatches that arise because there is no
+/// IndexAttr and IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
 } // end namespace mlir

 #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2928,6 +2928,7 @@
   }];

   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -3026,6 +3027,7 @@
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -59,6 +59,27 @@
     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
 }

+/// Return true if ofr1 and ofr2 are the same integer constant attribute value
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatches that arise because there is no
+/// IndexAttr and IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) {
+  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
+    Attribute attr = ofr.dyn_cast<Attribute>();
+    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
+    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+      return intAttr.getValue().getSExtValue();
+    return llvm::None;
+  };
+  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+  if (cst1 && cst2 && *cst1 == *cst2)
+    return true;
+  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+  return v1 && v2 && v1 == v2;
+}
+
 //===----------------------------------------------------------------------===//
 // StandardOpsDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -3557,6 +3578,34 @@
       context);
 }

+/// Detect a no-op: zero offsets, unit strides and sizes matching `shapedType`.
+static LogicalResult
+foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
+                                           ShapedType shapedType) {
+  OpBuilder b(op.getContext());
+  for (OpFoldResult ofr : op.getMixedOffsets())
+    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
+      return failure();
+  // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
+  // is appropriate.
+  auto shape = shapedType.getShape();
+  for (auto it : llvm::zip(op.getMixedSizes(), shape))
+    if (!isEqualConstantIntOrValue(std::get<0>(it),
+                                   b.getIndexAttr(std::get<1>(it))))
+      return failure();
+  for (OpFoldResult ofr : op.getMixedStrides())
+    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
+      return failure();
+  return success();
+}
+
+OpFoldResult SubTensorOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->source();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // SubTensorInsertOp
 //===----------------------------------------------------------------------===//
@@ -3597,6 +3646,13 @@
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }

+OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->source();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -157,3 +157,22 @@
     memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
 }
+
+// CHECK-LABEL: func @trivial_subtensor
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: subtensor
+// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
+func @trivial_subtensor(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %0 = subtensor %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
+  return %0 : tensor<4x6x16x32xi8>
+}
+
+// CHECK-LABEL: func @trivial_subtensor_insert
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: subtensor
+// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
+func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %0 = subtensor_insert %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
+  return %0 : tensor<4x6x16x32xi8>
+}
+