diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -124,6 +124,20 @@
   }
 }
 
+/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
+/// is set to true. Does not return a Value if the transfer op is not 1D or
+/// if the transfer op does not have a mask.
+template <typename OpTy>
+static Value maybeGenerateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
+  if (xferOp.getVectorType().getRank() != 1)
+    return Value();
+  if (!xferOp.mask())
+    return Value();
+
+  auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+  return vector_extract_element(xferOp.mask(), ivI32).value;
+}
+
 /// Helper function for TransferOpConversion and TransferOp1dConversion.
 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
 /// specified dimension `dim` with the loop iteration variable `iv`.
@@ -141,6 +155,10 @@
 ///      (out-of-bounds case)
 ///    }
 /// ```
+///
+/// If the transfer is 1D and has a mask, this function generates a more
+/// complex check that also accounts for potentially masked-out elements.
+///
 /// This function variant returns the value returned by `inBoundsCase` or
 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
 /// `resultTypes`.
@@ -151,13 +169,29 @@
     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
   bool hasRetVal = !resultTypes.empty();
+  Value cond; // Condition to be built...
+
+  // Condition check 1: Access in-bounds?
   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
     auto memrefDim =
         memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
     using edsc::op::operator+;
     auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
-    auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
+    cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
+  }
+
+  // Condition check 2: Masked in?
+  if (auto maskCond = maybeGenerateMaskCheck(builder, xferOp, iv)) {
+    if (cond) {
+      cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond);
+    } else {
+      cond = maskCond;
+    }
+  }
+
+  // If the condition is non-empty, generate an SCF::IfOp.
+  if (cond) {
     auto check = builder.create<scf::IfOp>(
         xferOp.getLoc(), resultTypes, cond,
         /*thenBuilder=*/[&](OpBuilder &builder, Location loc) {
@@ -173,7 +207,7 @@
     return hasRetVal ? check.getResult(0) : Value();
   }
 
-  // No runtime check needed if dim is guaranteed to be in-bounds.
+  // Condition is empty; no need for an SCF::IfOp.
   return inBoundsCase(builder, xferOp.getLoc());
 }
 
@@ -664,8 +698,6 @@
     return failure();
   if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
     return failure();
-  if (xferOp.mask())
-    return failure();
 
   // Loop bounds, step, state...
   auto vecType = xferOp.getVectorType();
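
For illustration, here is what the combined guard might look like for a masked,
potentially out-of-bounds 1-D transfer_read. This is a hand-written sketch, not
actual pass output; all SSA names, the vector<9xi1> mask type, and the
surrounding constants are invented:

  // Sketch only: condition check 1 (access in-bounds?) is AND-ed with
  // condition check 2 (masked in?) before guarding the scalar access.
  %dim = memref.dim %src, %c0 : memref<?xf32>    // dynamic extent of source
  %idx = addi %base, %iv : index                 // memref index for lane %iv
  %in_bounds = cmpi sgt, %dim, %idx : index      // condition check 1
  %iv_i32 = index_cast %iv : index to i32
  %mask_bit = vector.extractelement %mask[%iv_i32 : i32] : vector<9xi1>
  %cond = and %in_bounds, %mask_bit : i1         // condition check 2
  scf.if %cond {
    // in-bounds and masked in: perform the scalar access
  } else {
    // out-of-bounds or masked out: fall back to the padding value
  }
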
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -1,8 +1,3 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
@@ -10,34 +5,57 @@
 
 // Test for special cases of 1D vector transfer ops.
 
-func @transfer_read_2d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
-      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
-      : memref<?x?xf32>, vector<5x6xf32>
-  vector.print %f: vector<5x6xf32>
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : memref<?x?xf32>, vector<9xf32>
+  vector.print %f: vector<9xf32>
   return
 }
 
-func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+func @transfer_read_1d_broadcast(
+    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
-      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      {permutation_map = affine_map<(d0, d1) -> (0)>}
       : memref<?x?xf32>, vector<9xf32>
   vector.print %f: vector<9xf32>
   return
 }
 
-func @transfer_read_1d_broadcast(
+func @transfer_read_1d_in_bounds(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
-      {permutation_map = affine_map<(d0, d1) -> (0)>}
+      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
+      : memref<?x?xf32>, vector<3xf32>
+  vector.print %f: vector<3xf32>
+  return
+}
+
+func @transfer_read_1d_mask(
+    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
       : memref<?x?xf32>, vector<9xf32>
   vector.print %f: vector<9xf32>
   return
 }
 
+func @transfer_read_1d_mask_in_bounds(
+    %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1]> : vector<3xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
+      : memref<?x?xf32>, vector<3xf32>
+  vector.print %f: vector<3xf32>
+  return
+}
+
 func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<7xf32>
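
For @transfer_read_1d_mask_in_bounds above, in_bounds = [true] suppresses
condition check 1, so only the mask bit is tested inside the loop. Below is a
rough sketch of the expected lowering, with invented names and the result
vector threaded through iter_args:

  %init = splat %fm42 : vector<3xf32>        // every lane starts as padding
  %res = scf.for %iv = %c0 to %c3 step %c1
      iter_args(%vec = %init) -> (vector<3xf32>) {
    %iv_i32 = index_cast %iv : index to i32
    %mask_bit = vector.extractelement %mask[%iv_i32 : i32] : vector<3xi1>
    %new = scf.if %mask_bit -> (vector<3xf32>) {
      %row = addi %base1, %iv : index
      %el = memref.load %A[%row, %base2] : memref<?x?xf32>
      %v = vector.insertelement %el, %vec[%iv_i32 : i32] : vector<3xf32>
      scf.yield %v : vector<3xf32>
    } else {
      scf.yield %vec : vector<3xf32>         // masked out: lane keeps -42.0
    }
    scf.yield %new : vector<3xf32>
  }
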
@@ -69,14 +87,35 @@
     }
   }
 
+  // Read from 2D memref on first dimension. Cannot be lowered to an LLVM
+  // vector load. Instead, generates scalar loads.
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // Write to 2D memref on first dimension. Cannot be lowered to an LLVM
+  // vector store. Instead, generates scalar stores.
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // (Same as above.)
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // Read a scalar from a 2D memref and broadcast the value to a 1D vector.
+  // Generates a loop with vector.insertelement.
   call @transfer_read_1d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
+  // Read from 2D memref on first dimension. Accesses are in-bounds, so no
+  // if-check is generated inside the loop.
+  call @transfer_read_1d_in_bounds(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
+  // The optional mask attribute is specified and, in addition, there may be
+  // out-of-bounds accesses.
+  call @transfer_read_1d_mask(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
+  // Same as above, but accesses are in-bounds.
+  call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
   return
 }
 
 // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
 // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
 // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
+// CHECK: ( 12, 22, -1 )
+// CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
+// CHECK: ( 12, -42, -1 )
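
As a hand-worked sanity check on the new CHECK lines (assuming the 5-row
memref implied by the existing expectations, with rows 3 and 4 set to -1 by
the earlier @transfer_write_1d call): @transfer_read_1d_mask reads rows 1
through 9 at column 2, and a lane prints the -42 padding whenever its mask
bit is 0 or its row is out-of-bounds:

// lane:        0    1    2    3    4    5    6    7    8
// mask bit:    1    0    1    0    1    1    1    0    1
// unmasked:   12   22   -1   -1  oob  oob  oob  oob  oob
// printed:    12  -42   -1  -42  -42  -42  -42  -42  -42

The in-bounds variant behaves the same way over its three lanes, giving
( 12, -42, -1 ).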