diff --git a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
--- a/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.cpp
@@ -131,22 +131,57 @@
 LogicalResult IndexCastOpLowering::matchAndRewrite(
     arith::IndexCastOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  auto targetType = typeConverter->convertType(op.getResult().getType());
+  auto resultType = op.getResult().getType();
+  // Compare widths on the converted element types so that `index` is
+  // resolved to its LLVM integer type before choosing between a no-op, a
+  // truncation and a sign extension.
   auto targetElementType =
-      typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
-          .cast<IntegerType>();
+      typeConverter->convertType(getElementTypeOrSelf(resultType));
   auto sourceElementType =
-      getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
-  unsigned targetBits = targetElementType.getWidth();
-  unsigned sourceBits = sourceElementType.getWidth();
+      typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
+  if (!targetElementType || !sourceElementType)
+    return rewriter.notifyMatchFailure(op, "failed to convert element type");
+  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
+  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
 
-  if (targetBits == sourceBits)
+  // Same bitwidth after conversion: the cast is a no-op.
+  if (targetBits == sourceBits) {
     rewriter.replaceOp(op, adaptor.getIn());
-  else if (targetBits < sourceBits)
-    rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
-  else
-    rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
-  return success();
+    return success();
+  }
+
+  // Handle the scalar and 1D vector cases: the operand converts to a single
+  // LLVM value, so one trunc/sext suffices.
+  auto operandType = adaptor.getIn().getType();
+  if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    auto targetType = typeConverter->convertType(resultType);
+    if (targetBits < sourceBits)
+      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
+                                                 adaptor.getIn());
+    else
+      rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
+                                                adaptor.getIn());
+    return success();
+  }
+
+  // N-D vectors are converted to nested LLVM arrays of 1-D vectors; apply the
+  // cast to each innermost 1-D vector.
+  if (!resultType.isa<VectorType>())
+    return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+  return LLVM::detail::handleMultidimensionalVectors(
+      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+      [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+        // Named to avoid shadowing the enclosing `adaptor` parameter.
+        OpAdaptor vectorAdaptor(operands);
+        if (targetBits < sourceBits) {
+          return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
+                                                vectorAdaptor.getIn());
+        }
+        return rewriter.create<LLVM::SExtOp>(op.getLoc(), llvm1DVectorTy,
+                                             vectorAdaptor.getIn());
+      },
+      rewriter);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
--- a/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -210,3 +210,24 @@
   %0 = arith.select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
   func.return
 }
+
+// CHECK-LABEL: func @index_cast_2d(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2x3xi1>)
+func.func @index_cast_2d(%arg0: vector<1x2x3xi1>) {
+  // CHECK: %[[SRC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[SRC]][0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  // CHECK: %[[SEXT1:.*]] = llvm.sext %[[EXTRACT1]] : vector<3xi1> to vector<3xi{{.*}}>
+  // CHECK: %[[INSERT1:.*]] = llvm.insertvalue %[[SEXT1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[SRC]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  // CHECK: %[[SEXT2:.*]] = llvm.sext %[[EXTRACT2]] : vector<3xi1> to vector<3xi{{.*}}>
+  // CHECK: %[[INSERT2:.*]] = llvm.insertvalue %[[SEXT2]], %[[INSERT1]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  %0 = arith.index_cast %arg0: vector<1x2x3xi1> to vector<1x2x3xindex>
+  // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[INSERT2]][0, 0] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  // CHECK: %[[TRUNC1:.*]] = llvm.trunc %[[EXTRACT3]] : vector<3xi{{.*}}> to vector<3xi1>
+  // CHECK: %[[INSERT3:.*]] = llvm.insertvalue %[[TRUNC1]], %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  // CHECK: %[[EXTRACT4:.*]] = llvm.extractvalue %[[INSERT2]][0, 1] : !llvm.array<1 x array<2 x vector<3xi{{.*}}>>>
+  // CHECK: %[[TRUNC2:.*]] = llvm.trunc %[[EXTRACT4]] : vector<3xi{{.*}}> to vector<3xi1>
+  // CHECK: %[[INSERT4:.*]] = llvm.insertvalue %[[TRUNC2]], %[[INSERT3]][0, 1] : !llvm.array<1 x array<2 x vector<3xi1>>>
+  %1 = arith.index_cast %0: vector<1x2x3xindex> to vector<1x2x3xi1>
+  func.return
+}