Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Show First 20 Lines • Show All 3,552 Lines • ▼ Show 20 Lines | |||||
OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) { | OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) { | ||||
if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) | if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>()) | ||||
if (tensorLoad.memref().getType() == getType()) | if (tensorLoad.memref().getType() == getType()) | ||||
return tensorLoad.memref(); | return tensorLoad.memref(); | ||||
return {}; | return {}; | ||||
} | } | ||||
namespace { | |||||
/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast. | |||||
struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> { | |||||
using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern; | |||||
LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef, | |||||
PatternRewriter &rewriter) const final { | |||||
auto tensorCastOperand = | |||||
tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>(); | |||||
if (!tensorCastOperand) | |||||
return failure(); | |||||
auto srcTensorType = | |||||
tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>(); | |||||
if (!srcTensorType) | |||||
return failure(); | |||||
auto memrefType = MemRefType::get(srcTensorType.getShape(), | |||||
srcTensorType.getElementType()); | |||||
Value memref = rewriter.create<TensorToMemrefOp>( | |||||
tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand()); | |||||
rewriter.replaceOpWithNewOp<MemRefCastOp>( | |||||
Lint: Pre-merge checks: clang-format: please reformat the code
```
- rewriter.replaceOpWithNewOp<MemRefCastOp>… | |||||
tensorToMemRef, tensorToMemRef.getType(), memref); | |||||
return success(); | |||||
} | |||||
}; | |||||
} // namespace | |||||
void TensorToMemrefOp::getCanonicalizationPatterns( | |||||
OwningRewritePatternList &results, MLIRContext *context) { | |||||
results.insert<TensorCastToMemref>(context); | |||||
} | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// TransposeOp | // TransposeOp | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
/// Build a strided memref type by applying `permutationMap` tp `memRefType`. | /// Build a strided memref type by applying `permutationMap` tp `memRefType`. | ||||
static MemRefType inferTransposeResultType(MemRefType memRefType, | static MemRefType inferTransposeResultType(MemRefType memRefType, | ||||
AffineMap permutationMap) { | AffineMap permutationMap) { | ||||
auto rank = memRefType.getRank(); | auto rank = memRefType.getRank(); | ||||
▲ Show 20 Lines • Show All 367 Lines • Show Last 20 Lines |
clang-format: please reformat the code