Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Show First 20 Lines • Show All 207 Lines • ▼ Show 20 Lines | auto resultCast = | ||||
rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc); | rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc); | ||||
rewriter.replaceOp(alloc, {resultCast}); | rewriter.replaceOp(alloc, {resultCast}); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
/// Fold alloc operations with no users or only store and dealloc uses. | /// Fold alloc operations with no users or only store and dealloc uses. | ||||
template <typename T> | template <typename T> struct SimplifyDeadAlloc : public OpRewritePattern<T> { | ||||
struct SimplifyDeadAlloc : public OpRewritePattern<T> { | |||||
using OpRewritePattern<T>::OpRewritePattern; | using OpRewritePattern<T>::OpRewritePattern; | ||||
LogicalResult matchAndRewrite(T alloc, | LogicalResult matchAndRewrite(T alloc, | ||||
PatternRewriter &rewriter) const override { | PatternRewriter &rewriter) const override { | ||||
if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { | if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { | ||||
if (auto storeOp = dyn_cast<StoreOp>(op)) | if (auto storeOp = dyn_cast<StoreOp>(op)) | ||||
return storeOp.getValue() == alloc; | return storeOp.getValue() == alloc; | ||||
return !isa<DeallocOp>(op); | return !isa<DeallocOp>(op); | ||||
▲ Show 20 Lines • Show All 341 Lines • ▼ Show 20 Lines | if (inputs.size() != 1 || outputs.size() != 1) | ||||
return false; | return false; | ||||
Type a = inputs.front(), b = outputs.front(); | Type a = inputs.front(), b = outputs.front(); | ||||
auto aT = a.dyn_cast<MemRefType>(); | auto aT = a.dyn_cast<MemRefType>(); | ||||
auto bT = b.dyn_cast<MemRefType>(); | auto bT = b.dyn_cast<MemRefType>(); | ||||
auto uaT = a.dyn_cast<UnrankedMemRefType>(); | auto uaT = a.dyn_cast<UnrankedMemRefType>(); | ||||
auto ubT = b.dyn_cast<UnrankedMemRefType>(); | auto ubT = b.dyn_cast<UnrankedMemRefType>(); | ||||
// Strips signed/unsigned bit from integer types. | |||||
auto stripSign = [](Type type) -> Type { | |||||
if (auto integer = type.dyn_cast<IntegerType>()) | |||||
return IntegerType::get(type.getContext(), integer.getWidth()); | |||||
return type; | |||||
}; | |||||
if (aT && bT) { | if (aT && bT) { | ||||
if (aT.getElementType() != bT.getElementType()) | if (stripSign(aT.getElementType()) != stripSign(bT.getElementType())) | ||||
return false; | return false; | ||||
if (aT.getLayout() != bT.getLayout()) { | if (aT.getLayout() != bT.getLayout()) { | ||||
int64_t aOffset, bOffset; | int64_t aOffset, bOffset; | ||||
SmallVector<int64_t, 4> aStrides, bStrides; | SmallVector<int64_t, 4> aStrides, bStrides; | ||||
if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || | if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || | ||||
failed(getStridesAndOffset(bT, bStrides, bOffset)) || | failed(getStridesAndOffset(bT, bStrides, bOffset)) || | ||||
aStrides.size() != bStrides.size()) | aStrides.size() != bStrides.size()) | ||||
return false; | return false; | ||||
Show All 31 Lines | if (aT && bT) { | ||||
if (!bT && !ubT) | if (!bT && !ubT) | ||||
return false; | return false; | ||||
// Unranked to unranked casting is unsupported | // Unranked to unranked casting is unsupported | ||||
if (uaT && ubT) | if (uaT && ubT) | ||||
return false; | return false; | ||||
auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); | auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); | ||||
auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); | auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); | ||||
if (aEltType != bEltType) | if (stripSign(aEltType) != stripSign(bEltType)) | ||||
return false; | return false; | ||||
auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); | auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); | ||||
auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); | auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); | ||||
return aMemSpace == bMemSpace; | return aMemSpace == bMemSpace; | ||||
} | } | ||||
return false; | return false; | ||||
▲ Show 20 Lines • Show All 2,325 Lines • Show Last 20 Lines |