Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Vector/VectorOps.cpp
Show First 20 Lines • Show All 61 Lines • ▼ Show 20 Lines | if (auto c = mask.getDefiningOp<ConstantOp>()) { | ||||
} | } | ||||
} else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) { | } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) { | ||||
// Inspect constant mask index. If the index exceeds the | // Inspect constant mask index. If the index exceeds the | ||||
// dimension size, all bits are set. If the index is zero | // dimension size, all bits are set. If the index is zero | ||||
// or less, no bits are set. | // or less, no bits are set. | ||||
ArrayAttr masks = m.mask_dim_sizes(); | ArrayAttr masks = m.mask_dim_sizes(); | ||||
assert(masks.size() == 1); | assert(masks.size() == 1); | ||||
int64_t i = masks[0].cast<IntegerAttr>().getInt(); | int64_t i = masks[0].cast<IntegerAttr>().getInt(); | ||||
int64_t u = m.getType().cast<VectorType>().getDimSize(0); | int64_t u = m.getType().getDimSize(0); | ||||
if (i >= u) | if (i >= u) | ||||
return MaskFormat::AllTrue; | return MaskFormat::AllTrue; | ||||
if (i <= 0) | if (i <= 0) | ||||
return MaskFormat::AllFalse; | return MaskFormat::AllFalse; | ||||
} | } | ||||
return MaskFormat::Unknown; | return MaskFormat::Unknown; | ||||
} | } | ||||
▲ Show 20 Lines • Show All 765 Lines • ▼ Show 20 Lines | |||||
// Fold extractOp with source coming from ShapeCast op. | // Fold extractOp with source coming from ShapeCast op. | ||||
static Value foldExtractFromShapeCast(ExtractOp extractOp) { | static Value foldExtractFromShapeCast(ExtractOp extractOp) { | ||||
auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>(); | auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>(); | ||||
if (!shapeCastOp) | if (!shapeCastOp) | ||||
return Value(); | return Value(); | ||||
// Get the nth dimension size starting from lowest dimension. | // Get the nth dimension size starting from lowest dimension. | ||||
auto getDimReverse = [](VectorType type, int64_t n) { | auto getDimReverse = [](VectorType type, int64_t n) { | ||||
return type.getShape().take_back(n+1).front(); | return type.getShape().take_back(n + 1).front(); | ||||
jpienaar: Nit: clang-format only changed parts to make these easier (you don't have to change, just for… | |||||
}; | }; | ||||
int64_t destinationRank = | int64_t destinationRank = | ||||
extractOp.getType().isa<VectorType>() | extractOp.getType().isa<VectorType>() | ||||
? extractOp.getType().cast<VectorType>().getRank() | ? extractOp.getType().cast<VectorType>().getRank() | ||||
: 0; | : 0; | ||||
if (destinationRank > shapeCastOp.getSourceVectorType().getRank()) | if (destinationRank > shapeCastOp.getSourceVectorType().getRank()) | ||||
return Value(); | return Value(); | ||||
if (destinationRank > 0) { | if (destinationRank > 0) { | ||||
▲ Show 20 Lines • Show All 1,004 Lines • ▼ Show 20 Lines | LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, | ||||
// ConstantOp. | // ConstantOp. | ||||
auto constantOp = | auto constantOp = | ||||
extractStridedSliceOp.vector().getDefiningOp<ConstantOp>(); | extractStridedSliceOp.vector().getDefiningOp<ConstantOp>(); | ||||
if (!constantOp) | if (!constantOp) | ||||
return failure(); | return failure(); | ||||
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>(); | auto dense = constantOp.value().dyn_cast<SplatElementsAttr>(); | ||||
if (!dense) | if (!dense) | ||||
return failure(); | return failure(); | ||||
auto newAttr = DenseElementsAttr::get( | auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(), | ||||
extractStridedSliceOp.getType().cast<VectorType>(), | |||||
dense.getSplatValue()); | dense.getSplatValue()); | ||||
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr); | rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
// Helper that returns a subset of `arrayAttr` as a vector of int64_t. | // Helper that returns a subset of `arrayAttr` as a vector of int64_t. | ||||
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, | static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr, | ||||
unsigned dropFront = 0, | unsigned dropFront = 0, | ||||
▲ Show 20 Lines • Show All 1,263 Lines • Show Last 20 Lines |
Nit: clang-format only changed parts to make these easier (you don't have to change, just for next one)