Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Show First 20 Lines • Show All 1,478 Lines • ▼ Show 20 Lines | OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { | ||||
Operation *definingOp = memrefOrTensor().getDefiningOp(); | Operation *definingOp = memrefOrTensor().getDefiningOp(); | ||||
// dim(tensor_load(memref)) -> dim(memref) | // dim(tensor_load(memref)) -> dim(memref) | ||||
if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) { | if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) { | ||||
setOperand(0, tensorLoadOp.memref()); | setOperand(0, tensorLoadOp.memref()); | ||||
return getResult(); | return getResult(); | ||||
} | } | ||||
// Fold dim to the operand of dynamic_tensor_from_elements. | |||||
if (auto fromElements = | |||||
dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) { | |||||
auto resultType = | |||||
fromElements.getResult().getType().cast<RankedTensorType>(); | |||||
// The case where the type encodes the size of the dimension is handled | |||||
// above. | |||||
assert(resultType.getShape()[index.getInt()] == | |||||
RankedTensorType::kDynamicSize); | |||||
// Find the operand of the fromElements that corresponds to this index. | |||||
auto dynExtents = fromElements.dynamicExtents().begin(); | |||||
for (auto dim : resultType.getShape().take_front(index.getInt())) | |||||
if (dim == RankedTensorType::kDynamicSize) | |||||
dynExtents++; | |||||
return Value{*dynExtents}; | |||||
} | |||||
// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. | // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. | ||||
auto memrefType = argTy.dyn_cast<MemRefType>(); | auto memrefType = argTy.dyn_cast<MemRefType>(); | ||||
if (!memrefType) | if (!memrefType) | ||||
return {}; | return {}; | ||||
// The size at the given index is now known to be a dynamic size of a memref. | // The size at the given index is now known to be a dynamic size of a memref. | ||||
unsigned unsignedIndex = index.getValue().getZExtValue(); | unsigned unsignedIndex = index.getValue().getZExtValue(); | ||||
if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) | if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp)) | ||||
▲ Show 20 Lines • Show All 2,938 Lines • Show Last 20 Lines |