diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp @@ -135,15 +135,16 @@ // multipled. auto *BO = dyn_cast(Start); if (!BO || (BO->getOpcode() != Instruction::Add && + BO->getOpcode() != Instruction::Shl && BO->getOpcode() != Instruction::Mul)) return std::make_pair(nullptr, nullptr); // Look for an operand that is splatted. - unsigned OtherIndex = 1; - Value *Splat = getSplatValue(BO->getOperand(0)); - if (!Splat) { - Splat = getSplatValue(BO->getOperand(1)); - OtherIndex = 0; + unsigned OtherIndex = 0; + Value *Splat = getSplatValue(BO->getOperand(1)); + if (!Splat && Instruction::isCommutative(BO->getOpcode())) { + Splat = getSplatValue(BO->getOperand(0)); + OtherIndex = 1; } if (!Splat) return std::make_pair(nullptr, nullptr); @@ -158,13 +159,22 @@ Builder.SetCurrentDebugLocation(DebugLoc()); // Add the splat value to the start or multiply the start and stride by the // splat. - if (BO->getOpcode() == Instruction::Add) { + switch (BO->getOpcode()) { + default: + llvm_unreachable("Unexpected opcode"); + case Instruction::Add: Start = Builder.CreateAdd(Start, Splat); - } else { - assert(BO->getOpcode() == Instruction::Mul && "Unexpected opcode"); + break; + case Instruction::Mul: Start = Builder.CreateMul(Start, Splat); Stride = Builder.CreateMul(Stride, Splat); + break; + case Instruction::Shl: + Start = Builder.CreateShl(Start, Splat); + Stride = Builder.CreateShl(Stride, Splat); + break; } + return std::make_pair(Start, Stride); } diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll --- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll +++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll @@ -112,6 +112,94 @@ ret %x } +define @straightline_offset_add(ptr %p, i64 %offset) { +; CHECK-LABEL: @straightline_offset_add( +; CHECK-NEXT: [[TMP1:%.*]] = add i64 0, [[OFFSET:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]] +; CHECK-NEXT: [[X:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( poison, ptr [[TMP2]], i64 4, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) +; CHECK-NEXT: ret [[X]] +; + %step = call @llvm.experimental.stepvector.nxv1i64() + %splat.insert = insertelement poison, i64 %offset, i64 0 + %splat = shufflevector %splat.insert, poison, zeroinitializer + %offsetv = add %step, %splat + %ptrs = getelementptr i32, ptr %p, %offsetv + %x = call @llvm.masked.gather.nxv1i64.nxv1p0( + %ptrs, + i32 8, + shufflevector ( insertelement ( poison, i1 1, i64 0), poison, zeroinitializer), + poison + ) + ret %x +} + +define @straightline_offset_shl(ptr %p) { +; CHECK-LABEL: @straightline_offset_shl( +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 0 +; CHECK-NEXT: [[X:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( poison, ptr [[TMP1]], i64 32, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) +; CHECK-NEXT: ret [[X]] +; + %step = call @llvm.experimental.stepvector.nxv1i64() + %splat.insert = insertelement poison, i64 3, i64 0 + %splat = shufflevector %splat.insert, poison, zeroinitializer + %offset = shl %step, %splat + %ptrs = getelementptr i32, ptr %p, %offset + %x = call @llvm.masked.gather.nxv1i64.nxv1p0( + %ptrs, + i32 8, + shufflevector ( insertelement ( poison, i1 1, i64 0), poison, zeroinitializer), + poison + ) + ret %x +} + +define @neg_shl_is_not_commutative(ptr %p) { +; CHECK-LABEL: @neg_shl_is_not_commutative( +; CHECK-NEXT: [[STEP:%.*]] = call @llvm.experimental.stepvector.nxv1i64() +; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement poison, i64 3, i64 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[SPLAT_INSERT]], poison, zeroinitializer +; CHECK-NEXT: [[OFFSET:%.*]] = shl [[SPLAT]], [[STEP]] +; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], [[OFFSET]] +; CHECK-NEXT: [[X:%.*]] = call @llvm.masked.gather.nxv1i64.nxv1p0( [[PTRS]], i32 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer), poison) +; CHECK-NEXT: ret [[X]] +; + %step = call @llvm.experimental.stepvector.nxv1i64() + %splat.insert = insertelement poison, i64 3, i64 0 + %splat = shufflevector %splat.insert, poison, zeroinitializer + %offset = shl %splat, %step + %ptrs = getelementptr i32, ptr %p, %offset + %x = call @llvm.masked.gather.nxv1i64.nxv1p0( + %ptrs, + i32 8, + shufflevector ( insertelement ( poison, i1 1, i64 0), poison, zeroinitializer), + poison + ) + ret %x +} + +define @straightline_offset_shl_nonc(ptr %p, i64 %shift) { +; CHECK-LABEL: @straightline_offset_shl_nonc( +; CHECK-NEXT: [[TMP1:%.*]] = shl i64 0, [[SHIFT:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = shl i64 1, [[SHIFT]] +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4 +; CHECK-NEXT: [[X:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( poison, ptr [[TMP3]], i64 [[TMP4]], shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) +; CHECK-NEXT: ret [[X]] +; + %step = call @llvm.experimental.stepvector.nxv1i64() + %splat.insert = insertelement poison, i64 %shift, i64 0 + %splat = shufflevector %splat.insert, poison, zeroinitializer + %offset = shl %step, %splat + %ptrs = getelementptr i32, ptr %p, %offset + %x = call @llvm.masked.gather.nxv1i64.nxv1p0( + %ptrs, + i32 8, + shufflevector ( insertelement ( poison, i1 1, i64 0), poison, zeroinitializer), + poison + ) + ret %x +} + define void @scatter_loopless( %x, ptr %p, i64 %stride) { ; CHECK-LABEL: @scatter_loopless( ; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]