diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -358,7 +358,10 @@
   let options = [
     Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool",
            /*default=*/"false",
-           "Allows llvm to reassociate floating-point reductions for speed">
+           "Allows llvm to reassociate floating-point reductions for speed">,
+    Option<"enableIndexOptimizations", "enable-index-optimizations",
+           "bool", /*default=*/"false",
+           "Allows compiler to assume indices fit in 32-bit if that yields faster code">
   ];
 }

diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,8 +22,13 @@
 /// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td
 struct LowerVectorToLLVMOptions {
   bool reassociateFPReductions = false;
-  LowerVectorToLLVMOptions &setReassociateFPReductions(bool r) {
-    reassociateFPReductions = r;
+  bool enableIndexOptimizations = false;
+  LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
+    reassociateFPReductions = b;
+    return *this;
+  }
+  LowerVectorToLLVMOptions &setEnableIndexOptimizations(bool b) {
+    enableIndexOptimizations = b;
     return *this;
   }
 };
@@ -37,7 +42,8 @@
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool reassociateFPReductions = false);
+    bool reassociateFPReductions = false,
+    bool enableIndexOptimizations = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
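A quick usage note before the implementation: the flag is exercised as --convert-vector-to-llvm='enable-index-optimizations=1' on the mlir-opt command line (see the RUN lines in the new test further down), or programmatically via LowerVectorToLLVMOptions().setEnableIndexOptimizations(true). A minimal input it affects, as a sketch (the function name is mine, not from the patch):

    func @bounded_mask(%n : index) -> vector<8xi1> {
      // Sets lanes [0, %n) of the result to true; with the flag on, the
      // lowered bounds comparison runs on i32 rather than i64 lanes.
      %m = vector.create_mask %n : vector<8xi1>
      return %m : vector<8xi1>
    }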
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -117,6 +117,49 @@
   return res;
 }
 
+// Helper that returns a vector comparison that constructs a mask:
+//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
+//
+// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
+//       much more compact, IR for this operation, but LLVM eventually
+//       generates more elaborate instructions for this intrinsic since it
+//       is very conservative on the boundary conditions.
+static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
+                                   Operation *op, bool enableIndexOptimizations,
+                                   int64_t dim, Value b, Value *off = nullptr) {
+  auto loc = op->getLoc();
+  // If we can assume all indices fit in 32-bit, we perform the vector
+  // comparison in 32-bit to get a higher degree of SIMD parallelism.
+  // Otherwise we perform the vector comparison using 64-bit indices.
+  Value indices;
+  Type idxType;
+  if (enableIndexOptimizations) {
+    SmallVector<int32_t, 4> values(dim);
+    for (int64_t d = 0; d < dim; d++)
+      values[d] = d;
+    indices =
+        rewriter.create<ConstantOp>(loc, rewriter.getI32VectorAttr(values));
+    idxType = rewriter.getI32Type();
+  } else {
+    SmallVector<int64_t, 4> values(dim);
+    for (int64_t d = 0; d < dim; d++)
+      values[d] = d;
+    indices =
+        rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
+    idxType = rewriter.getI64Type();
+  }
+  // Add in an offset if requested.
+  if (off) {
+    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
+    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
+    indices = rewriter.create<AddIOp>(loc, ov, indices);
+  }
+  // Construct the vector comparison.
+  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
+  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
+}
+
 // Helper that returns data layout alignment of an operation with memref.
 template <typename T>
 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
@@ -512,10 +555,10 @@
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter,
-                                       bool reassociateFP)
+                                       bool reassociateFPRed)
       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                              typeConverter),
-        reassociateFPReductions(reassociateFP) {}
+        reassociateFPReductions(reassociateFPRed) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -589,6 +632,34 @@
   const bool reassociateFPReductions;
 };
 
+/// Conversion pattern for a vector.create_mask (1-D only).
+class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
+                                        LLVMTypeConverter &typeConverter,
+                                        bool enableIndexOpt)
+      : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
+                             typeConverter),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = op->getResult(0).getType().cast<VectorType>();
+    int64_t rank = dstType.getRank();
+    if (rank == 1) {
+      rewriter.replaceOp(
+          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                     dstType.getDimSize(0), operands[0]));
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1121,17 +1192,19 @@
 
 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
 /// sequence of:
-/// 1. Bitcast or addrspacecast to vector form.
-/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 3. Create a mask where offsetVector is compared against memref upper bound.
-/// 4. Rewrite op as a masked read or write.
+/// 1. Get the source/dst address as an LLVM vector pointer.
+/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+/// 4. Create a mask where offsetVector is compared against memref upper bound.
+/// 5. Rewrite op as a masked read or write.
 template <typename ConcreteOp>
 class VectorTransferConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorTransferConversion(MLIRContext *context,
-                                    LLVMTypeConverter &typeConv)
-      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
-                             typeConv) {}
+                                    LLVMTypeConverter &typeConv,
+                                    bool enableIndexOpt)
+      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
+        enableIndexOptimizations(enableIndexOpt) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1155,7 +1228,6 @@
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
 
     Location loc = op->getLoc();
-    Type i64Type = rewriter.getIntegerType(64);
     MemRefType memRefType = xferOp.getMemRefType();
 
     if (auto memrefVectorElementType =
@@ -1202,41 +1274,26 @@
           xferOp, operands, vectorDataPtr);
 
     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-    unsigned vecWidth = vecTy.getVectorNumElements();
-    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
-    SmallVector<int64_t, 8> indices;
-    indices.reserve(vecWidth);
-    for (unsigned i = 0; i < vecWidth; ++i)
-      indices.push_back(i);
-    Value linearIndices = rewriter.create<ConstantOp>(
-        loc, vectorCmpType,
-        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
-    linearIndices = rewriter.create<LLVM::DialectCastOp>(
-        loc, toLLVMTy(vectorCmpType), linearIndices);
-
     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-    // TODO: when the leaf transfer rank is k > 1 we need the last
-    //       `k` dimensions here.
-    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
-    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
-    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
-    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
-    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
-
     // 4. Let dim the memref dimension, compute the vector comparison mask:
     //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+    //
+    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
+    //       dimensions here.
+    unsigned vecWidth = vecTy.getVectorNumElements();
+    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
+    Value off = *(xferOp.indices().begin() + lastIndex);
     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
-    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
-    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
-    Value mask =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
-    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
-                                                mask);
+    Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
+                                       vecWidth, dim, &off);
 
     // 5. Rewrite as a masked read / write.
     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                        operands, vectorDataPtr, mask);
   }
+
+private:
+  const bool enableIndexOptimizations;
 };
 
 class VectorPrintOpConversion : public ConvertToLLVMPattern {
@@ -1444,7 +1501,7 @@
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool reassociateFPReductions) {
+    bool reassociateFPReductions, bool enableIndexOptimizations) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   // clang-format off
   patterns.insert<VectorFMAOpNDRewritePattern,
                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
   patterns.insert<VectorReductionOpConversion>(
       ctx, converter, reassociateFPReductions);
+  patterns.insert<VectorCreateMaskOpConversion,
+                  VectorTransferConversion<TransferReadOp>,
+                  VectorTransferConversion<TransferWriteOp>>(
+      ctx, converter, enableIndexOptimizations);
   patterns
       .insert<VectorShuffleOpConversion,
               VectorExtractElementOpConversion,
               VectorExtractOpConversion,
               VectorFMAOp1DConversion,
               VectorInsertElementOpConversion,
               VectorInsertOpConversion,
               VectorPrintOpConversion,
-              VectorTransferConversion<TransferReadOp>,
-              VectorTransferConversion<TransferWriteOp>,
               VectorTypeCastOpConversion,
               VectorMaskedLoadOpConversion,
               VectorMaskedStoreOpConversion,
@@ -1485,6 +1544,7 @@
     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
     this->reassociateFPReductions = options.reassociateFPReductions;
+    this->enableIndexOptimizations = options.enableIndexOptimizations;
   }
   void runOnOperation() override;
 };
@@ -1505,15 +1565,14 @@
   LLVMTypeConverter converter(&getContext());
   OwningRewritePatternList patterns;
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(converter, patterns,
-                                         reassociateFPReductions);
+  populateVectorToLLVMConversionPatterns(
+      converter, patterns, reassociateFPReductions, enableIndexOptimizations);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
 
   LLVMConversionTarget target(getContext());
-  if (failed(applyPartialConversion(getOperation(), target, patterns))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns)))
     signalPassFailure();
-  }
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
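In standard-dialect terms, buildVectorComparison above materializes mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]. A hand-written sketch of the 32-bit variant for n = 4 (the pattern emits the LLVM-dialect analogue of this directly; the function and value names are mine):

    func @mask_sketch(%off : index, %bound : index) -> vector<4xi1> {
      %indices = constant dense<[0, 1, 2, 3]> : vector<4xi32>
      // Only the transfer patterns pass an offset; create_mask skips the addi.
      %o = index_cast %off : index to i32
      %ov = splat %o : vector<4xi32>
      %sum = addi %ov, %indices : vector<4xi32>
      %b = index_cast %bound : index to i32
      %bv = splat %b : vector<4xi32>
      %mask = cmpi "slt", %sum, %bv : vector<4xi32>
      return %mask : vector<4xi1>
    }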
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1347,7 +1347,8 @@
   auto eltType = dstType.getElementType();
   auto dimSizes = op.mask_dim_sizes();
   int64_t rank = dimSizes.size();
-  int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
+  int64_t trueDim = std::min(dstType.getDimSize(0),
+                             dimSizes[0].cast<IntegerAttr>().getInt());
 
   if (rank == 1) {
     // Express constant 1-D case in explicit vector form:
@@ -1402,21 +1403,8 @@
   int64_t rank = dstType.getRank();
   Value idx = op.getOperand(0);
 
-  if (rank == 1) {
-    // Express dynamic 1-D case in explicit vector form:
-    //   mask = [0,1,..,n-1] < [a,a,..,a]
-    SmallVector<int64_t, 4> values(dim);
-    for (int64_t d = 0; d < dim; d++)
-      values[d] = d;
-    Value indices =
-        rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
-    Value bound =
-        rewriter.create<IndexCastOp>(loc, rewriter.getI64Type(), idx);
-    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
-    rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, indices,
-                                        bounds);
-    return success();
-  }
+  if (rank == 1)
+    return failure(); // leave for lowering
 
   VectorType lowType =
       VectorType::get(dstType.getShape().drop_front(), eltType);
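Two behavioral notes on the VectorTransforms.cpp changes above: the std::min clamp keeps a constant mask whose declared true-count exceeds the vector length within bounds when it is expanded to explicit vector form, and the dynamic 1-D case now deliberately returns failure() so that the new VectorCreateMaskOpConversion in the LLVM lowering picks it up. A sketch of an input that relies on the clamp (hypothetical, assuming an oversized bound reaches the constant lowering, e.g. after folding):

    %c5 = constant 5 : index
    // Constant folding must saturate at 4 true lanes, not touch lane 5.
    %m = vector.create_mask %c5 : vector<4xi1>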
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=1' | FileCheck %s --check-prefix=CMP32
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=0' | FileCheck %s --check-prefix=CMP64
+
+// CMP32-LABEL: llvm.func @genbool_var_1d(
+// CMP32-SAME: %[[A:.*]]: !llvm.i64)
+// CMP32: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>) : !llvm.vec<11 x i32>
+// CMP32: %[[T1:.*]] = llvm.trunc %[[A]] : !llvm.i64 to !llvm.i32
+// CMP32: %[[T2:.*]] = llvm.mlir.undef : !llvm.vec<11 x i32>
+// CMP32: %[[T3:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CMP32: %[[T4:.*]] = llvm.insertelement %[[T1]], %[[T2]][%[[T3]] : !llvm.i32] : !llvm.vec<11 x i32>
+// CMP32: %[[T5:.*]] = llvm.shufflevector %[[T4]], %[[T2]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i32>, !llvm.vec<11 x i32>
+// CMP32: %[[T6:.*]] = llvm.icmp "slt" %[[T0]], %[[T5]] : !llvm.vec<11 x i32>
+// CMP32: llvm.return %[[T6]] : !llvm.vec<11 x i1>
+
+// CMP64-LABEL: llvm.func @genbool_var_1d(
+// CMP64-SAME: %[[A:.*]]: !llvm.i64)
+// CMP64: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>) : !llvm.vec<11 x i64>
+// CMP64: %[[T1:.*]] = llvm.mlir.undef : !llvm.vec<11 x i64>
+// CMP64: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CMP64: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm.vec<11 x i64>
+// CMP64: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T1]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i64>, !llvm.vec<11 x i64>
+// CMP64: %[[T5:.*]] = llvm.icmp "slt" %[[T0]], %[[T4]] : !llvm.vec<11 x i64>
+// CMP64: llvm.return %[[T5]] : !llvm.vec<11 x i1>
+
+func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
+  %0 = vector.create_mask %arg0 : vector<11xi1>
+  return %0 : vector<11xi1>
+}
+
+// CMP32-LABEL: llvm.func @transfer_read_1d
+// CMP32: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>) : !llvm.vec<16 x i32>
+// CMP32: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i32>
+// CMP32: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i32>
+// CMP32: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
+// CMP32: llvm.return %[[L]] : !llvm.vec<16 x float>
+
+// CMP64-LABEL: llvm.func @transfer_read_1d
+// CMP64: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi64>) : !llvm.vec<16 x i64>
+// CMP64: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i64>
+// CMP64: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i64>
+// CMP64: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
+// CMP64: llvm.return %[[L]] : !llvm.vec<16 x float>
+
+func @transfer_read_1d(%A : memref<?xf32>, %i: index) -> vector<16xf32> {
+  %d = constant -1.0: f32
+  %f = vector.transfer_read %A[%i], %d {permutation_map = affine_map<(d0) -> (d0)>} : memref<?xf32>, vector<16xf32>
+  return %f : vector<16xf32>
+}
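The transfer_read checks above compress the five steps documented in VectorTransferConversion: the [0..15] constant is step 2, the llvm.add of the splatted offset is step 3, the llvm.icmp "slt" against the splatted dimension is step 4, and llvm.intr.masked.load is step 5. The per-lane semantics they pin down can be summarized as follows (illustration only, not IR the pass prints):

    // result[j] = (%i + j < dim %A, 0) ? %A[%i + j] : -1.0
    // The pass-through operand of llvm.intr.masked.load supplies the -1.0
    // padding, and the icmp "slt" vector is the load mask.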
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -749,10 +749,12 @@
 // CHECK-SAME: (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 // CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
 // CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
+// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
+// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
-// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(dense
+// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 // CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
 //
 // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -770,8 +772,6 @@
 //
 // 4. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
-// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
 // CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
 // CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
@@ -799,9 +799,9 @@
 // CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
-// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-// CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
+// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(dense
+// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
 //
 // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 // CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
@@ -832,6 +832,8 @@
 }
 // CHECK-LABEL: func @transfer_read_2d_to_1d
 // CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: !llvm.i64, %[[BASE_1:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<17 x float>
+// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
+// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 //
 // Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 // CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
@@ -847,8 +849,6 @@
 // Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 // Here we check we properly use %DIM[1]
-// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
-// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
 // CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
 // CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
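For the n-D create_mask cases below, the transform unrolls the outer dimension: one inner mask is built once and then either selected or replaced by an all-false row for each outer index. The genbool_var_2d CHECK sequence below corresponds to this hand-written form (a sketch under the same op semantics; the names are mine):

    func @genbool_2d_sketch(%a : index, %b : index) -> vector<2x3xi1> {
      %zero_row = constant dense<false> : vector<3xi1>
      %zero_all = constant dense<false> : vector<2x3xi1>
      %c0 = constant 0 : index
      %c1 = constant 1 : index
      %row = vector.create_mask %b : vector<3xi1>
      %p0 = cmpi "slt", %c0, %a : index
      %r0 = select %p0, %row, %zero_row : vector<3xi1>
      %m0 = vector.insert %r0, %zero_all [0] : vector<3xi1> into vector<2x3xi1>
      %p1 = cmpi "slt", %c1, %a : index
      %r1 = select %p1, %row, %zero_row : vector<3xi1>
      %m1 = vector.insert %r1, %m0 [1] : vector<3xi1> into vector<2x3xi1>
      return %m1 : vector<2x3xi1>
    }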
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -785,43 +785,63 @@
   return %v: vector<2x3xi1>
 }
 
-// CHECK-LABEL: func @genbool_var_1d
-// CHECK-SAME: %[[A:.*]]: index
-// CHECK: %[[C1:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
-// CHECK: %[[T0:.*]] = index_cast %[[A]] : index to i64
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
-// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[T1]] : vector<3xi64>
-// CHECK: return %[[T2]] : vector<3xi1>
+// CHECK-LABEL: func @genbool_var_1d(
+// CHECK-SAME: %[[A:.*]]: index)
+// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1>
+// CHECK: return %[[T0]] : vector<3xi1>
 
 func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
   %0 = vector.create_mask %arg0 : vector<3xi1>
   return %0 : vector<3xi1>
 }
 
-// CHECK-LABEL: func @genbool_var_2d
-// CHECK-SAME: %[[A:.*0]]: index
-// CHECK-SAME: %[[B:.*1]]: index
-// CHECK: %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
-// CHECK: %[[CF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-LABEL: func @genbool_var_2d(
+// CHECK-SAME: %[[A:.*0]]: index,
+// CHECK-SAME: %[[B:.*1]]: index)
+// CHECK: %[[C1:.*]] = constant dense<false> : vector<3xi1>
 // CHECK: %[[C2:.*]] = constant dense<false> : vector<2x3xi1>
 // CHECK: %[[c0:.*]] = constant 0 : index
 // CHECK: %[[c1:.*]] = constant 1 : index
-// CHECK: %[[T0:.*]] = index_cast %[[B]] : index to i64
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
-// CHECK: %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64>
-// CHECK: %[[T3:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
-// CHECK: %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK: %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
-// CHECK: %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1>
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1>
-// CHECK: return %[[T8]] : vector<2x3xi1>
+// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
+// CHECK: %[[T1:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK: %[[T5:.*]] = select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK: return %[[T6]] : vector<2x3xi1>
 
 func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
   %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
   return %0 : vector<2x3xi1>
 }
 
+// CHECK-LABEL: func @genbool_var_3d(
+// CHECK-SAME: %[[A:.*0]]: index,
+// CHECK-SAME: %[[B:.*1]]: index,
+// CHECK-SAME: %[[C:.*2]]: index)
+// CHECK: %[[C1:.*]] = constant dense<false> : vector<7xi1>
+// CHECK: %[[C2:.*]] = constant dense<false> : vector<1x7xi1>
+// CHECK: %[[C3:.*]] = constant dense<false> : vector<2x1x7xi1>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: %[[c1:.*]] = constant 1 : index
+// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
+// CHECK: %[[T1:.*]] = cmpi "slt", %[[c0]], %[[B]] : index
+// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK: %[[T5:.*]] = select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
+// CHECK: %[[T7:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK: %[[T8:.*]] = select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
+// CHECK: return %[[T9]] : vector<2x1x7xi1>
+
+func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> {
+  %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
+  return %0 : vector<2x1x7xi1>
+}
+
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,