diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2512,6 +2512,31 @@
   return success();
 }
 
+/// Tensor-based RAW can bypass storage to tensor.
+///
+static LogicalResult foldRAW(TransferReadOp readOp,
+                             Value &result) {
+  auto writeOp = readOp.source().getDefiningOp<vector::TransferWriteOp>();
+  if (!writeOp)
+    return failure();
+  // TODO: write.source is confusing naming, it should be dest.
+  // Only fold on tensors to take advantage of SSA values.
+  auto rankedTensorType =
+      writeOp.source().getType().dyn_cast<RankedTensorType>();
+  // If not operating on the same tensor types, bail.
+  if (!rankedTensorType || readOp.source().getType() != rankedTensorType)
+    return failure();
+  // If anything disagrees, bail.
+  if (readOp.source() != writeOp.result() ||
+      readOp.indices() != writeOp.indices() ||
+      readOp.permutation_map() != writeOp.permutation_map() ||
+      readOp.mask() != writeOp.mask() ||
+      readOp.in_bounds() != writeOp.in_bounds())
+    return failure();
+  result = writeOp.vector();
+  return success();
+}
+
 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
   /// transfer_read(memrefcast) -> transfer_read
   if (succeeded(foldTransferInBoundsAttribute(*this)))
@@ -2520,6 +2545,9 @@
     return getResult();
   if (succeeded(foldTensorCast(*this)))
     return getResult();
+  Value result;
+  if (succeeded(foldRAW(*this, result)))
+    return result;
   return OpFoldResult();
 }
 
@@ -2664,7 +2692,9 @@
                            [&op](Twine t) { return op.emitOpError(t); });
 }
 
-/// Fold:
+/// Tensor-based WAR can avoid redundant writes.
+/// Fold `%t2 = vector.transfer_write` when it appears in a pattern such as:
+///
 /// ```
 ///    %t1 = ...
 ///    %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
@@ -2681,52 +2711,65 @@
 ///      tensor<static_sizesxf32>, vector<static_sizesxf32>
 ///    %t2 = vector.transfer_write %v, %t1[%c0...], {in_bounds = [true...]} :
 ///      vector<static_sizesxf32>, tensor<static_sizesxf32>
 /// ```
 ///
 /// into:
 ///
 /// ```
 ///    %t0
 /// ```
 ///
 /// The producer of t1 may or may not be DCE'd depending on whether it is a
 /// block argument or has side effects.
-static LogicalResult foldReadInitWrite(TransferWriteOp write,
-                                       ArrayRef<Attribute>,
-                                       SmallVectorImpl<OpFoldResult> &results) {
-  auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
-  // If not operating on tensors, bail.
-  if (!rankedTensorType)
-    return failure();
+static LogicalResult foldWAR(TransferWriteOp writeOp, ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &results) {
+  auto readOp = writeOp.vector().getDefiningOp<vector::TransferReadOp>();
   // If no read, bail.
-  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
-  if (!read)
+  if (!readOp)
     return failure();
-  // For now, only accept minor identity. Future: composition is minor identity.
-  if (!read.permutation_map().isMinorIdentity() ||
-      !write.permutation_map().isMinorIdentity())
+  auto rankedTensorType =
+      writeOp.source().getType().dyn_cast<RankedTensorType>();
+  // If not operating on the same tensor types, bail.
+  if (!rankedTensorType || readOp.source().getType() != rankedTensorType)
     return failure();
-  // Bail on mismatching ranks.
-  if (read.getTransferRank() != write.getTransferRank())
+  // If anything disagrees, bail.
+  // TODO: write.source is confusing naming, it should be dest.
+  if (writeOp.source() != readOp.source() ||
+      writeOp.indices() != readOp.indices() ||
+      writeOp.permutation_map() != readOp.permutation_map() ||
+      // Masks must also agree or the write is not redundant.
+      writeOp.mask() != readOp.mask())
     return failure();
   // Bail on potential out-of-bounds accesses.
-  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
+  if (readOp.hasOutOfBoundsDim() || writeOp.hasOutOfBoundsDim())
     return failure();
-  // Tensor types must be the same.
-  if (read.source().getType() != rankedTensorType)
-    return failure();
   // Vector types must be the same.
-  if (read.getVectorType() != write.getVectorType())
+  if (readOp.getVectorType() != writeOp.getVectorType())
     return failure();
-  // Vector and Tensor shapes must match.
-  if (read.getVectorType().getShape() != rankedTensorType.getShape())
-    return failure();
-  // If any index is nonzero.
-  auto isNotConstantZero = [](Value v) {
-    auto cstOp = v.getDefiningOp<ConstantIndexOp>();
-    return !cstOp || cstOp.getValue() != 0;
-  };
-  if (llvm::any_of(read.indices(), isNotConstantZero) ||
-      llvm::any_of(write.indices(), isNotConstantZero))
-    return failure();
   // Success.
-  results.push_back(read.source());
+  results.push_back(readOp.source());
   return success();
 }
 
+/// Tensor-based WAW can bypass the first storage to tensor.
+static LogicalResult foldWAW(TransferWriteOp writeOp) {
+  auto earlierWriteOp =
+      writeOp.source().getDefiningOp<vector::TransferWriteOp>();
+  if (!earlierWriteOp)
+    return failure();
+  // TODO: write.source is confusing naming, it should be dest.
+  // Only fold on tensors to take advantage of SSA values.
+  auto rankedTensorType =
+      writeOp.source().getType().dyn_cast<RankedTensorType>();
+  // If not operating on the same tensor types, bail.
+  if (!rankedTensorType ||
+      earlierWriteOp.source().getType() != rankedTensorType)
+    return failure();
+  // If anything disagrees, bail.
+  if (writeOp.source() != earlierWriteOp.result() ||
+      writeOp.indices() != earlierWriteOp.indices() ||
+      writeOp.permutation_map() != earlierWriteOp.permutation_map() ||
+      writeOp.mask() != earlierWriteOp.mask() ||
+      writeOp.in_bounds() != earlierWriteOp.in_bounds())
+    return failure();
+  writeOp.sourceMutable().assign(earlierWriteOp.source());
+  return success();
+}
+
 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<OpFoldResult> &results) {
-  if (succeeded(foldReadInitWrite(*this, operands, results)))
+  if (succeeded(foldWAR(*this, operands, results)))
+    return success();
+  if (succeeded(foldWAW(*this)))
     return success();
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return success();