diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -567,7 +567,7 @@
 /// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
 /// Rewriting such IR (replacing one vector load with multiple scalar loads) may
 /// negatively affect performance.
-class FoldScalarExtractOfTransferRead
+class RewriteScalarExtractElementOfTransferRead
     : public OpRewritePattern<vector::ExtractElementOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -618,17 +618,79 @@
   }
 };
 
-/// Rewrite scalar transfer_write(broadcast) to memref.store.
-class FoldScalarTransferWriteOfBroadcast
-    : public OpRewritePattern<vector::TransferWriteOp> {
+/// Rewrite extract(transfer_read) to memref.load.
+///
+/// Rewrite only if the extract op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extract %0[0] : vector<1024xf32>
+/// %2 = vector.extract %0[5] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class RewriteScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Only match scalar extracts.
+    if (extractOp.getType().isa<VectorType>())
+      return failure();
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    if (!xferOp)
+      return failure();
+    // xfer result must have a single use. Otherwise, it may be better to
+    // perform a vector load.
+    if (!extractOp.getVector().hasOneUse())
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Cannot rewrite if the indices may be out of bounds. The starting point is
+    // always inbounds, so we don't care in case of 0d transfers.
+    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+      return failure();
+    // Construct scalar load.
+    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+                                  xferOp.getIndices().end());
+    for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
+      int64_t offset = it.value().cast<IntegerAttr>().getInt();
+      int64_t idx =
+          newIndices.size() - extractOp.getPosition().size() + it.index();
+      OpFoldResult ofr = makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(),
+          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+      if (ofr.is<Value>()) {
+        newIndices[idx] = ofr.get<Value>();
+      } else {
+        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
+            extractOp.getLoc(), *getConstantIntValue(ofr));
+      }
+    }
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+                                                  newIndices);
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    }
+    return success();
+  }
+};
+
+/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
+/// to memref.store.
+class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                 PatternRewriter &rewriter) const override {
     // Must be a scalar write.
     auto vecType = xferOp.getVectorType();
-    if (vecType.getRank() != 0 &&
-        (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
       return failure();
     // Mask not supported.
     if (xferOp.getMask())
@@ -636,19 +698,25 @@
     // Map not supported.
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
-    // Must be a broadcast of a scalar.
-    auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
-    if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
-      return failure();
+    // Only float and integer element types are supported.
+    Value scalar;
+    if (vecType.getRank() == 0) {
+      // vector.extract does not support vector<f32> etc., so use
+      // vector.extractelement instead.
+      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
+                                                         xferOp.getVector());
+    } else {
+      SmallVector<int64_t> pos(vecType.getRank(), 0);
+      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
+                                                  xferOp.getVector(), pos);
+    }
     // Construct a scalar store.
     if (xferOp.getSource().getType().isa<MemRefType>()) {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(
-          xferOp, broadcastOp.getSource(), xferOp.getSource(),
-          xferOp.getIndices());
+          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
     } else {
       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
-          xferOp, broadcastOp.getSource(), xferOp.getSource(),
-          xferOp.getIndices());
+          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
     }
     return success();
   }
@@ -673,9 +741,9 @@
 
 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns
-      .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
-          patterns.getContext(), benefit);
+  patterns.add<RewriteScalarExtractElementOfTransferRead,
+               RewriteScalarExtractOfTransferRead, RewriteScalarWrite>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -44,7 +44,9 @@
 
 // CHECK-LABEL: func @transfer_write_0d(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
+//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
+//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
   %0 = vector.broadcast %f : f32 to vector<f32>
   vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -66,10 +68,43 @@
 
 // CHECK-LABEL: func @tensor_transfer_write_0d(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
+//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
+//       CHECK:   %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
 //       CHECK:   return %[[r]]
 func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
   %0 = vector.broadcast %f : f32 to vector<f32>
   %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK-LABEL: func @transfer_read_2d_extract(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+//       CHECK:   %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]]]
+//       CHECK:   %[[added1:.*]] = affine.apply #[[$map1]]()[%[[idx]]]
+//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]], %[[added1]]]
+//       CHECK:   return %[[r]]
+func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %m[%idx, %idx, %idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?x?x?xf32>, vector<10x5xf32>
+  %1 = vector.extract %0[8, 1] : vector<10x5xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_arith_constant(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
+//       CHECK:   %[[extract:.*]] = vector.extract %[[cst]][0, 0] : vector<1x1xf32>
+//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
+  %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
+  vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
+  return
+}
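Usage note (not part of the patch): the sketch below is a minimal, assumed example of how a downstream pass might apply the patterns registered by populateScalarVectorTransferLoweringPatterns through the greedy rewrite driver. The helper name runScalarTransferLowering is hypothetical, and the header that declares populateScalarVectorTransferLoweringPatterns is assumed to be pulled in alongside the includes shown; only the signature visible in the diff above is relied on.

  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
  // Assumed: the Vector transforms header declaring
  // mlir::vector::populateScalarVectorTransferLoweringPatterns is also included.

  // Hypothetical helper, for illustration only.
  static void runScalarTransferLowering(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    // Registers RewriteScalarExtractElementOfTransferRead,
    // RewriteScalarExtractOfTransferRead and RewriteScalarWrite (see above).
    mlir::vector::populateScalarVectorTransferLoweringPatterns(patterns,
                                                               /*benefit=*/1);
    // Apply the patterns greedily; the result only reports whether the
    // fixpoint was reached, so it is safe to ignore here.
    (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
  }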