Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Show First 20 Lines • Show All 606 Lines • ▼ Show 20 Lines | public: | ||||
explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, | explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, | ||||
bool enableIndexOpt) | bool enableIndexOpt) | ||||
: ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), | : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv), | ||||
enableIndexOptimizations(enableIndexOpt) {} | enableIndexOptimizations(enableIndexOpt) {} | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, | matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
auto dstType = op->getResult(0).getType().cast<VectorType>(); | auto dstType = op.getType(); | ||||
int64_t rank = dstType.getRank(); | int64_t rank = dstType.getRank(); | ||||
if (rank == 1) { | if (rank == 1) { | ||||
rewriter.replaceOp( | rewriter.replaceOp( | ||||
op, buildVectorComparison(rewriter, op, enableIndexOptimizations, | op, buildVectorComparison(rewriter, op, enableIndexOptimizations, | ||||
dstType.getDimSize(0), operands[0])); | dstType.getDimSize(0), operands[0])); | ||||
return success(); | return success(); | ||||
} | } | ||||
return failure(); | return failure(); | ||||
▲ Show 20 Lines • Show All 462 Lines • ▼ Show 20 Lines | public: | ||||
using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; | using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands, | matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
auto loc = castOp->getLoc(); | auto loc = castOp->getLoc(); | ||||
MemRefType sourceMemRefType = | MemRefType sourceMemRefType = | ||||
castOp.getOperand().getType().cast<MemRefType>(); | castOp.getOperand().getType().cast<MemRefType>(); | ||||
MemRefType targetMemRefType = | MemRefType targetMemRefType = castOp.getType(); | ||||
castOp.getResult().getType().cast<MemRefType>(); | |||||
// Only static shape casts supported atm. | // Only static shape casts supported atm. | ||||
if (!sourceMemRefType.hasStaticShape() || | if (!sourceMemRefType.hasStaticShape() || | ||||
!targetMemRefType.hasStaticShape()) | !targetMemRefType.hasStaticShape()) | ||||
return failure(); | return failure(); | ||||
auto llvmSourceDescriptorTy = | auto llvmSourceDescriptorTy = | ||||
operands[0].getType().dyn_cast<LLVM::LLVMStructType>(); | operands[0].getType().dyn_cast<LLVM::LLVMStructType>(); | ||||
▲ Show 20 Lines • Show All 350 Lines • ▼ Show 20 Lines | VectorExtractStridedSliceOpConversion(MLIRContext *ctx) | ||||
: OpRewritePattern<ExtractStridedSliceOp>(ctx) { | : OpRewritePattern<ExtractStridedSliceOp>(ctx) { | ||||
// This pattern creates recursive ExtractStridedSliceOp, but the recursion | // This pattern creates recursive ExtractStridedSliceOp, but the recursion | ||||
// is bounded as the rank is strictly decreasing. | // is bounded as the rank is strictly decreasing. | ||||
setHasBoundedRewriteRecursion(); | setHasBoundedRewriteRecursion(); | ||||
} | } | ||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op, | LogicalResult matchAndRewrite(ExtractStridedSliceOp op, | ||||
PatternRewriter &rewriter) const override { | PatternRewriter &rewriter) const override { | ||||
auto dstType = op.getResult().getType().cast<VectorType>(); | auto dstType = op.getType(); | ||||
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); | assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); | ||||
int64_t offset = | int64_t offset = | ||||
op.offsets().getValue().front().cast<IntegerAttr>().getInt(); | op.offsets().getValue().front().cast<IntegerAttr>().getInt(); | ||||
int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); | int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); | ||||
int64_t stride = | int64_t stride = | ||||
op.strides().getValue().front().cast<IntegerAttr>().getInt(); | op.strides().getValue().front().cast<IntegerAttr>().getInt(); | ||||
▲ Show 20 Lines • Show All 77 Lines • Show Last 20 Lines |