diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1643,16 +1643,18 @@ Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { - auto operandNDVectorType = op->getOperand(0).getType().dyn_cast(); - auto resultNDVectorType = op->getResult(0).getType().dyn_cast(); - assert(operandNDVectorType && resultNDVectorType && "expected vector types"); - + auto resultNDVectorType = op->getResult(0).getType().cast(); + + SmallVector operand1DVectorTypes; + for (Value operand : op->getOperands()) { + auto operandNDVectorType = operand.getType().cast(); + auto operandTypeInfo = + extractNDVectorTypeInfo(operandNDVectorType, typeConverter); + operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy); + } auto resultTypeInfo = extractNDVectorTypeInfo(resultNDVectorType, typeConverter); - auto operandTypeInfo = - extractNDVectorTypeInfo(operandNDVectorType, typeConverter); auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; - auto operand1DVectorTy = operandTypeInfo.llvm1DVectorTy; auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; auto loc = op->getLoc(); Value desc = rewriter.create(loc, resultNDVectoryTy); @@ -1660,9 +1662,11 @@ // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector extractedOperands; - for (auto operand : operands) + for (auto operand : llvm::enumerate(operands)) { extractedOperands.push_back(rewriter.create( - loc, operand1DVectorTy, operand, position)); + loc, operand1DVectorTypes[operand.index()], operand.value(), + position)); + } Value newVal = createOperand(result1DVectorTy, extractedOperands); desc = rewriter.create(loc, resultNDVectoryTy, desc, newVal, position); @@ -1723,7 +1727,7 @@ using OrOpLowering = VectorConvertToLLVMPattern; using PowFOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; -using SelectOpLowering = OneToOneConvertToLLVMPattern; +using SelectOpLowering = VectorConvertToLLVMPattern; using SignExtendIOpLowering = VectorConvertToLLVMPattern; using ShiftLeftOpLowering = diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -297,3 +297,16 @@ %0 = cmpi ult, %arg0, %arg1 : vector<4x3xi32> std.return } + +// ----- + +// CHECK-LABEL: func @select_2dvector( +func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) { + // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xi1>> + // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %arg1[0] : !llvm.array<4 x vector<3xi32>> + // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %arg2[0] : !llvm.array<4 x vector<3xi32>> + // CHECK: %[[SELECT:.*]] = llvm.select %[[EXTRACT1]], %[[EXTRACT2]], %[[EXTRACT3]] : vector<3xi1>, vector<3xi32> + // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[SELECT]], %0[0] : !llvm.array<4 x vector<3xi32>> + %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32> + std.return +}