diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -791,8 +791,47 @@ // LLVM Matrix operations. // -/// As specified in the LLVM MatrixBuilder: -/// Create a llvm.matrix.multiply call, multiplying matrices LHS and RHS. +/// Create a columnwise, strided 2-D matrix load, as specified in the LLVM +/// MatrixBuilder. +/// data - Start address of the matrix read +/// rows - Number of rows in matrix (must be a constant) +/// columns - Number of columns in matrix (must be a constant) +/// stride - Space between columns +def LLVM_MatrixColumnsWiseLoadOp + : LLVM_OneResultOp<"intr.matrix.columnwise.load">, + Arguments<(ins LLVM_Type:$data, LLVM_Type:$stride, + I32Attr:$rows, I32Attr:$columns)> { + string llvmBuilder = [{ + llvm::MatrixBuilder<decltype(builder)> mb(builder); + $res = mb.CreateMatrixColumnwiseLoad( + $data, $rows.getZExtValue(), $columns.getZExtValue(), $stride); + }]; + let assemblyFormat = "$data `,` `<` `stride` `=` $stride `>` attr-dict " + "`:` type($res) `from` type($data) `stride` type($stride)"; +} + +/// Create a columnwise, strided 2-D matrix store, as specified in the LLVM +/// MatrixBuilder. 
+/// matrix - Matrix to store +/// data - Pointer to write back to +/// rows - Number of rows in matrix (must be a constant) +/// columns - Number of columns in matrix (must be a constant) +/// stride - Space between columns +def LLVM_MatrixColumnsWiseStoreOp + : LLVM_ZeroResultOp<"intr.matrix.columnwise.store">, + Arguments<(ins LLVM_Type:$matrix, LLVM_Type:$data, LLVM_Type:$stride, + I32Attr:$rows, I32Attr:$columns)> { + string llvmBuilder = [{ + llvm::MatrixBuilder<decltype(builder)> mb(builder); + mb.CreateMatrixColumnwiseStore( + $matrix, $data, $stride, $rows.getZExtValue(), $columns.getZExtValue()); + }]; + let assemblyFormat = "$matrix `,` $data `,` `<` `stride` `=` $stride `>` " + "attr-dict `:` type($matrix) `to` type($data) `stride` type($stride)"; +} + +/// Create a llvm.matrix.multiply call, multiplying 2-D matrices LHS and RHS, as +/// specified in the LLVM MatrixBuilder. def LLVM_MatrixMultiplyOp : LLVM_OneResultOp<"intr.matrix.multiply">, Arguments<( @@ -808,6 +847,19 @@ "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)"; } +/// Create a llvm.matrix.transpose call, transposing a `rows` x `columns` 2-D +/// `matrix`, as specified in the LLVM MatrixBuilder. +def LLVM_MatrixTranposeOp + : LLVM_OneResultOp<"intr.matrix.transpose">, + Arguments<(ins LLVM_Type:$matrix, I32Attr:$rows, I32Attr:$columns)> { + string llvmBuilder = [{ + llvm::MatrixBuilder<decltype(builder)> mb(builder); + $res = mb.CreateMatrixTranspose( + $matrix, $rows.getZExtValue(), $columns.getZExtValue()); + }]; + let assemblyFormat = "$matrix attr-dict `:` type($matrix) `into` type($res)"; +} + // // Atomic operations. 
// diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -132,15 +132,24 @@ // CHECK-LABEL: @matrix_intrinsics // 4x16 16x3 -llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>">) -// 4x3 - -> !llvm<"<12 x float>"> -{ +llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>">, + %ptr: !llvm<"float*">, %stride: !llvm.i32) { // CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3) %C = llvm.intr.matrix.multiply %A, %B { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} : (!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>"> - llvm.return %C: !llvm<"<12 x float>"> + // CHECK: call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %1, i32 3, i32 16) + %D = llvm.intr.matrix.transpose %B { rows = 3: i32, columns = 16: i32} : + !llvm<"<48 x float>"> into !llvm<"<48 x float>"> + // CHECK: call <48 x float> @llvm.matrix.columnwise.load.v48f32.p0f32(float* %2, i32 %3, i32 3, i32 16) + %E = llvm.intr.matrix.columnwise.load %ptr, <stride=%stride> + { rows = 3: i32, columns = 16: i32} : + !llvm<"<48 x float>"> from !llvm<"float*"> stride !llvm.i32 + // CHECK: call void @llvm.matrix.columnwise.store.v48f32.p0f32(<48 x float> %7, float* %2, i32 %3, i32 3, i32 16) + llvm.intr.matrix.columnwise.store %E, %ptr, <stride=%stride> + { rows = 3: i32, columns = 16: i32} : + !llvm<"<48 x float>"> to !llvm<"float*"> stride !llvm.i32 + llvm.return } // Check that intrinsics are declared with appropriate types. 
@@ -167,3 +176,6 @@ // CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0 // CHECK-DAG: declare float @llvm.copysign.f32(float, float) // CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg) +// CHECK-DAG: declare <48 x float> @llvm.matrix.transpose.v48f32(<48 x float>, i32 immarg, i32 immarg) +// CHECK-DAG: declare <48 x float> @llvm.matrix.columnwise.load.v48f32.p0f32(float*, i32, i32 immarg, i32 immarg) +// CHECK-DAG: declare void @llvm.matrix.columnwise.store.v48f32.p0f32(<48 x float>, float* writeonly, i32, i32 immarg, i32 immarg)