diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -634,4 +634,50 @@
   let hasVerifier = 1;
 }
 
+// Loads 1, 2 or 4 8x8 16-bit matrices from shared memory into the threads of a
+// warp; see the op description below for the per-thread address layout.
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+  Results<(outs AnyType:$res)>,
+  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+
+  let summary = "cooperative matrix load";
+
+  string llvmBuilder = [{
+      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+      auto intId = getLdMatrixIntrinsicId($layout, $num);
+      $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
+  }];
+
+  string baseDescription = [{
+    The `nvvm.ldmatrix` operation collectively loads one or more matrices across
+    all threads in a warp from the location indicated by the address operand
+    `ptr` from shared memory.
+
+    The attribute `num` indicates how many 8x8 16-bit matrices are to be loaded.
+
+    All the threads in the warp must execute the same ldmatrix operations.
+
+    Each row of 8 elements needs to be consecutive in memory. Each lane of the
+    warp contains the start address of a row of 8 elements laid out as below:
+
+    ```
+    num | Threads 0--7 | Threads 8--15  | Threads 16--31
+    1   | addr0--addr7 |                |
+    2   | addr0--addr7 | addr8--addr15  |
+    4   | addr0--addr7 | addr8--addr15  | addr16--addr31
+    ```
+
+    Example:
+    ```mlir
+    %l1 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>} :
+      (!llvm.ptr<i32, 3>) -> i32
+    %l2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>} :
+      (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    ```
+  }];
+
+  let assemblyFormat = "$ptr attr-dict `:` functional-type($ptr, $res)";
+  let hasVerifier = 1;
+}
+
 #endif // NVVMIR_OPS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -219,6 +219,33 @@
   return success();
 }
 
+/// Verifies an `nvvm.ldmatrix` op. Emits an op error and fails when:
+///  - the source pointer is not in address space 3,
+///  - `num` is not 1, 2 or 4, or
+///  - the result is not `i32` (num == 1) or a literal LLVM struct of `num`
+///    `i32` elements (num == 2 or 4).
+LogicalResult NVVM::LdMatrixOp::verify() {
+  unsigned addressSpace =
+      ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+  if (addressSpace != 3)
+    return emitOpError("expected source pointer in memory space 3");
+
+  if (num() != 1 && num() != 2 && num() != 4)
+    return emitOpError("expected num attribute to be 1, 2 or 4");
+
+  Type i32 = IntegerType::get(getContext(), 32);
+  if (num() == 1 && getType() != i32)
+    return emitOpError("expected destination type is i32");
+  if (num() == 2 || num() == 4) {
+    Type dstType = LLVM::LLVMStructType::getLiteral(
+        getContext(), SmallVector<Type>(num(), i32));
+    if (getType() != dstType)
+      return emitOpError("expected destination type is a structure of ")
+             << num() << " elements of type i32";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -64,6 +64,31 @@
   llvm_unreachable("unknown shuffle kind");
 }
 
+/// Maps an ldmatrix layout and matrix count onto the matching
+/// `nvvm_ldmatrix_sync_aligned_m8n8_x{1,2,4}[_trans]_b16` intrinsic ID.
+/// The `col` layout selects the plain variants; any other layout selects the
+/// `_trans` variants. `num` must be 1, 2 or 4; other values are unreachable.
+static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
+                                                  int32_t num) {
+  const bool isColLayout = layout == NVVM::MMALayout::col;
+  switch (num) {
+  case 1:
+    return isColLayout
+               ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
+               : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+  case 2:
+    return isColLayout
+               ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
+               : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+  case 4:
+    return isColLayout
+               ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
+               : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+  default:
+    llvm_unreachable("unsupported number of matrix");
+  }
+}
+
 namespace {
 /// Implementation of the dialect interface that converts operations belonging
 /// to the NVVM dialect to LLVM IR.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1191,6 +1191,42 @@
 
 // -----
 
+// The source pointer of nvvm.ldmatrix must be in address space 3.
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32>) -> i32
+  llvm.return
+}
+
+// -----
+
+// The `num` attribute of nvvm.ldmatrix must be 1, 2 or 4.
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+  llvm.return
+}
+
+// -----
+
+// With num = 1 the result of nvvm.ldmatrix must be a plain i32.
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32)>
+  llvm.return
+}
+
+// -----
+
+// With num = 4 the result of nvvm.ldmatrix must be a struct of four i32.
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  llvm.return
+}
+
+// -----
+
 llvm.func @caller() {
   // expected-error @below {{expected function call to produce a value}}
   llvm.call @callee() : () -> ()
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -105,6 +105,17 @@
   llvm.return
 }
 
+// Round-trip of nvvm.ldmatrix for num = 1, 2 and 4.
+// CHECK-LABEL: llvm.func @ld_matrix
+llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<i32, 3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
 // -----
 
 // expected-error@below {{attribute attached to unexpected op}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -176,6 +176,18 @@
   llvm.return
 }
 
+// nvvm.ldmatrix translates to calls of the llvm.nvvm.ldmatrix.* intrinsics.
+// CHECK-LABEL: @ld_matrix(
+llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {nvvm.kernel} {