diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -581,6 +581,9 @@
         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
+    // Check that we are extracting a scalar and not a sub-vector.
+    if (isa<VectorType>(extractOp.getResult().getType()))
+      return failure();
     // If multiple uses are not allowed, check if xfer has a single use.
     if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
       return failure();
@@ -622,6 +625,7 @@
   void rewrite(vector::ExtractElementOp extractOp,
                PatternRewriter &rewriter) const override {
     // Construct scalar load.
+    auto loc = extractOp.getLoc();
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
@@ -629,13 +633,13 @@
       AffineExpr sym0, sym1;
       bindSymbols(extractOp.getContext(), sym0, sym1);
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, extractOp.getLoc(), sym0 + sym1,
+          rewriter, loc, sym0 + sym1,
           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
       if (ofr.is<Value>()) {
         newIndices[newIndices.size() - 1] = ofr.get<Value>();
       } else {
         newIndices[newIndices.size() - 1] =
-            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+            rewriter.create<arith::ConstantIndexOp>(loc,
                                                     *getConstantIntValue(ofr));
       }
     }
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -136,3 +136,20 @@
   return %1, %2 : f32, f32
 }
 
+// -----
+
+// Check that patterns don't trigger for a sub-vector (not scalar) extraction.
+// CHECK-LABEL: func @subvector_extract(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?xf32>, %[[idx:.*]]: index
+//   CHECK-NOT:   memref.load
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]], %[[idx]]]
+//       CHECK:   %[[e0:.*]] = vector.extract %[[r]][0]
+//       CHECK:   return %[[e0]]
+func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32>
+  %1 = vector.extract %0[0] : vector<8x16xf32>
+  return %1 : vector<16xf32>
+}