diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -135,5 +135,55 @@
 def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">;
 def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
 
+//===---------------------------------------------------------------------===//
+// Vector buffer load/store intrinsics
+
+// Buffer load op. Its LLVM IR translation calls the llvm.amdgcn.buffer.load
+// intrinsic, overloaded on the result type, with the operands in order.
+def ROCDL_MubufLoadOp :
+  ROCDL_Op<"buffer.load">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_Type:$rsrc,
+             LLVM_Type:$vindex,
+             LLVM_Type:$offset,
+             LLVM_Type:$glc,
+             LLVM_Type:$slc)>{
+  string llvmBuilder = [{
+    $res = createIntrinsicCall(builder,
+        llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc,
+        $slc}, {$_resultType});
+  }];
+  let parser = [{ return parseROCDLMubufLoadOp(parser, result); }];
+  let printer = [{
+    Operation *op = this->getOperation();
+    p << op->getName() << " " << op->getOperands()
+      << " : " << op->getResultTypes();
+  }];
+}
+
+// Buffer store op. Its LLVM IR translation calls the llvm.amdgcn.buffer.store
+// intrinsic, overloaded on the underlying LLVM type of $vdata.
+def ROCDL_MubufStoreOp :
+  ROCDL_Op<"buffer.store">,
+  Arguments<(ins LLVM_Type:$vdata,
+             LLVM_Type:$rsrc,
+             LLVM_Type:$vindex,
+             LLVM_Type:$offset,
+             LLVM_Type:$glc,
+             LLVM_Type:$slc)>{
+  string llvmBuilder = [{
+    auto vdataType = op.vdata().getType().cast<LLVM::LLVMType>()
+                       .getUnderlyingType();
+    createIntrinsicCall(builder,
+        llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
+        $offset, $glc, $slc}, {vdataType});
+  }];
+  let parser = [{ return parseROCDLMubufStoreOp(parser, result); }];
+  let printer = [{
+    Operation *op = this->getOperation();
+    p << op->getName() << " " << op->getOperands()
+      << " : " << vdata().getType();
+  }];
+}
 
 #endif // ROCDLIR_OPS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -31,6 +31,56 @@
 using namespace ROCDL;
 
//===----------------------------------------------------------------------===//
+// Parsing for ROCDL ops
+//===----------------------------------------------------------------------===//
+
+// Returns the LLVM dialect registered in the context of the parser's builder.
+static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) {
+  auto *context = parser.getBuilder().getContext();
+  return context->getRegisteredDialect<LLVM::LLVMDialect>();
+}
+
+// <operation> ::=
+//     `rocdl.buffer.load %rsrc, %vindex, %offset, %glc, %slc : result_type`
+// The operands resolve to <4 x i32>, i32, i32, i1 and i1, in that order.
+static ParseResult parseROCDLMubufLoadOp(OpAsmParser &parser,
+                                         OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 8> ops;
+  Type type;
+  if (parser.parseOperandList(ops, 5) || parser.parseColonType(type) ||
+      parser.addTypeToList(type, result.types))
+    return failure();
+
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
+  auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+  auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
+  return parser.resolveOperands(ops,
+                                {i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
+                                parser.getNameLoc(), result.operands);
+}
+
+// <operation> ::=
+//     `rocdl.buffer.store %vdata, %rsrc, %vindex, %offset, %glc, %slc :
+//     vdata_type`
+static ParseResult parseROCDLMubufStoreOp(OpAsmParser &parser,
+                                          OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 8> ops;
+  Type type;
+  if (parser.parseOperandList(ops, 6) || parser.parseColonType(type))
+    return failure();
+
+  auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser));
+  auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser));
+  auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4);
+
+  // The parsed trailing type is used for %vdata; the other operands have
+  // fixed types. The op has no result, so nothing is added to result.types.
+  return parser.resolveOperands(
+      ops, {type, i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
+      parser.getNameLoc(), result.operands);
+}
+
+//===----------------------------------------------------------------------===//
 // ROCDLDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
@@ -31,9 +31,11 @@
 // Create a call to llvm intrinsic
 static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
                                         llvm::Intrinsic::ID intrinsic,
-                                        ArrayRef<llvm::Value *> args = {}) {
+                                        ArrayRef<llvm::Value *> args = {},
+                                        ArrayRef<llvm::Type *> tys = {}) {
   llvm::Module *module = builder.GetInsertBlock()->getModule();
-  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
+  // `tys` are the overload types used to pick the intrinsic declaration.
+  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, tys);
   return builder.CreateCall(fn, args);
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -144,3 +144,26 @@
 
   llvm.return %r0 : !llvm<"<32 x float>">
 }
+
+llvm.func @rocdl.mubuf(%rsrc : !llvm<"<4 x i32>">, %vindex : !llvm.i32,
+                       %offset : !llvm.i32, %glc : !llvm.i1,
+                       %slc : !llvm.i1, %vdata1 : !llvm<"<1 x float>">,
+                       %vdata2 : !llvm<"<2 x float>">, %vdata4 : !llvm<"<4 x float>">) {
+  // CHECK-LABEL: rocdl.mubuf
+  // CHECK: %{{.*}} = rocdl.buffer.load %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !llvm<"<1 x float>">
+  %r1 = rocdl.buffer.load %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<1 x float>">
+  // CHECK: %{{.*}} = rocdl.buffer.load %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !llvm<"<2 x float>">
+  %r2 = rocdl.buffer.load %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<2 x float>">
+  // CHECK: %{{.*}} = rocdl.buffer.load %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !llvm<"<4 x float>">
+  %r4 = rocdl.buffer.load %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<4 x float>">
+
+  // CHECK: rocdl.buffer.store %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !llvm<"<1 x float>">
+  rocdl.buffer.store %vdata1, %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<1 x float>">
+  // CHECK: rocdl.buffer.store %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !llvm<"<2 x float>">
+  rocdl.buffer.store %vdata2, %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<2 x float>">
+  // CHECK: rocdl.buffer.store %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !llvm<"<4 x float>">
+  rocdl.buffer.store %vdata4, %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<4 x float>">
+
+  llvm.return
+}
+
diff --git a/mlir/test/Target/rocdl.mlir b/mlir/test/Target/rocdl.mlir
--- a/mlir/test/Target/rocdl.mlir
+++ b/mlir/test/Target/rocdl.mlir
@@ -151,3 +151,26 @@
 
   llvm.return %r0 : !llvm<"<32 x float>">
 }
+
+llvm.func @rocdl.mubuf(%rsrc : !llvm<"<4 x i32>">, %vindex : !llvm.i32,
+                       %offset : !llvm.i32, %glc : !llvm.i1,
+                       %slc : !llvm.i1, %vdata1 : !llvm<"<1 x float>">,
+                       %vdata2 : !llvm<"<2 x float>">, %vdata4 : !llvm<"<4 x float>">) {
+  // CHECK-LABEL: rocdl.mubuf
+  // CHECK: call <1 x float> @llvm.amdgcn.buffer.load.v1f32(<4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i1 %{{.*}}, i1 %{{.*}})
+  %r1 = rocdl.buffer.load %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<1 x float>">
+  // CHECK: call <2 x float> @llvm.amdgcn.buffer.load.v2f32(<4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i1 %{{.*}}, i1 %{{.*}})
+  %r2 = rocdl.buffer.load %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<2 x float>">
+  // CHECK: call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i1 %{{.*}}, i1 %{{.*}})
+  %r4 = rocdl.buffer.load %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<4 x float>">
+
+  // CHECK: call void @llvm.amdgcn.buffer.store.v1f32(<1 x float> %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i1 %{{.*}}, i1 %{{.*}})
+  rocdl.buffer.store %vdata1, %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<1 x float>">
+  // CHECK: call void @llvm.amdgcn.buffer.store.v2f32(<2 x float> %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i1 %{{.*}}, i1 %{{.*}})
+  rocdl.buffer.store %vdata2, %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<2 x float>">
+  // CHECK: call void @llvm.amdgcn.buffer.store.v4f32(<4 x float> %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i1 %{{.*}}, i1 %{{.*}})
+  rocdl.buffer.store %vdata4, %rsrc, %vindex, %offset, %glc, %slc : !llvm<"<4 x float>">
+
+  llvm.return
+}
+