package org.nio4r; import java.io.IOException; import java.io.Serializable; import java.nio.channels.Channel; import java.nio.channels.SelectableChannel; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; import java.nio.InvalidMarkException; import org.jruby.Ruby; import org.jruby.RubyClass; import org.jruby.RubyIO; import org.jruby.RubyNumeric; import org.jruby.RubyObject; import org.jruby.RubyString; import org.jruby.anno.JRubyMethod; import org.jruby.exceptions.RaiseException; import org.jruby.runtime.ThreadContext; import org.jruby.runtime.builtin.IRubyObject; import org.jruby.runtime.Block; /* created by Upekshej */ public class ByteBuffer extends RubyObject { private static final long serialVersionUID = -6903439483039149324L; private transient java.nio.ByteBuffer byteBuffer; public static RaiseException newOverflowError(ThreadContext context, String message) { RubyClass klass = context.runtime.getModule("NIO").getClass("ByteBuffer").getClass("OverflowError"); return context.runtime.newRaiseException(klass, message); } public static RaiseException newUnderflowError(ThreadContext context, String message) { RubyClass klass = context.runtime.getModule("NIO").getClass("ByteBuffer").getClass("UnderflowError"); return context.runtime.newRaiseException(klass, message); } public static RaiseException newMarkUnsetError(ThreadContext context, String message) { RubyClass klass = context.runtime.getModule("NIO").getClass("ByteBuffer").getClass("MarkUnsetError"); return context.runtime.newRaiseException(klass, message); } public ByteBuffer(final Ruby ruby, RubyClass rubyClass) { super(ruby, rubyClass); } @JRubyMethod public IRubyObject initialize(ThreadContext context, IRubyObject capacity) { this.byteBuffer = java.nio.ByteBuffer.allocate(RubyNumeric.num2int(capacity)); return this; } @JRubyMethod public IRubyObject clear(ThreadContext context) { this.byteBuffer.clear(); return this; } @JRubyMethod(name = "position") public IRubyObject getPosition(ThreadContext context) { return context.getRuntime().newFixnum(this.byteBuffer.position()); } @JRubyMethod(name = "position=") public IRubyObject setPosition(ThreadContext context, IRubyObject newPosition) { int pos = RubyNumeric.num2int(newPosition); if(pos < 0) { throw context.runtime.newArgumentError("negative position given"); } if(pos > this.byteBuffer.limit()) { throw context.runtime.newArgumentError("specified position exceeds limit"); } try { this.byteBuffer.position(pos); return newPosition; } catch(IllegalArgumentException e) { throw context.runtime.newArgumentError(e.getLocalizedMessage()); } } @JRubyMethod(name = "limit") public IRubyObject getLimit(ThreadContext context) { return context.getRuntime().newFixnum(this.byteBuffer.limit()); } @JRubyMethod(name = "limit=") public IRubyObject setLimit(ThreadContext context, IRubyObject newLimit) { int lim = RubyNumeric.num2int(newLimit); if(lim < 0) { throw context.runtime.newArgumentError("negative limit given"); } if(lim > this.byteBuffer.capacity()) { throw context.runtime.newArgumentError("specified limit exceeds capacity"); } try { this.byteBuffer.limit(lim); return newLimit; } catch(IllegalArgumentException e) { throw context.runtime.newArgumentError(e.getLocalizedMessage()); } } @JRubyMethod(name = {"capacity", "size"}) public IRubyObject capacity(ThreadContext context) { return context.getRuntime().newFixnum(this.byteBuffer.capacity()); } @JRubyMethod public IRubyObject remaining(ThreadContext context) { return context.getRuntime().newFixnum(this.byteBuffer.remaining()); } @JRubyMethod(name = "full?") public IRubyObject isFull(ThreadContext context) { if (this.byteBuffer.hasRemaining()) { return context.getRuntime().getFalse(); } else { return context.getRuntime().getTrue(); } } @JRubyMethod public IRubyObject get(ThreadContext context) { return this.get(context, context.getRuntime().newFixnum(this.byteBuffer.remaining())); } @JRubyMethod public IRubyObject get(ThreadContext context, IRubyObject length) { int len = RubyNumeric.num2int(length); byte[] bytes = new byte[len]; try { this.byteBuffer.get(bytes); } catch(BufferUnderflowException e) { throw ByteBuffer.newUnderflowError(context, "not enough data in buffer"); } return RubyString.newString(context.getRuntime(), bytes); } @JRubyMethod(name = "[]") public IRubyObject fetch(ThreadContext context, IRubyObject index) { int i = RubyNumeric.num2int(index); if(i < 0) { throw context.runtime.newArgumentError("negative index given"); } if(i >= this.byteBuffer.limit()) { throw context.runtime.newArgumentError("index exceeds limit"); } return context.getRuntime().newFixnum(this.byteBuffer.get(i)); } @JRubyMethod(name = "<<") public IRubyObject put(ThreadContext context, IRubyObject str) { try { this.byteBuffer.put(str.convertToString().getByteList().bytes()); } catch(BufferOverflowException e) { throw ByteBuffer.newOverflowError(context, "buffer is full"); } return this; } @JRubyMethod(name = "read_from") public IRubyObject readFrom(ThreadContext context, IRubyObject io) { Ruby runtime = context.runtime; Channel channel = RubyIO.convertToIO(context, io).getChannel(); if(!this.byteBuffer.hasRemaining()) { throw ByteBuffer.newOverflowError(context, "buffer is full"); } if(!(channel instanceof ReadableByteChannel) || !(channel instanceof SelectableChannel)) { throw runtime.newArgumentError("unsupported IO object: " + io.getType().toString()); } try { ((SelectableChannel)channel).configureBlocking(false); } catch(IOException ie) { throw runtime.newIOError(ie.getLocalizedMessage()); } try { int bytesRead = ((ReadableByteChannel)channel).read(this.byteBuffer); if(bytesRead >= 0) { return runtime.newFixnum(bytesRead); } else { throw runtime.newEOFError(); } } catch(IOException ie) { throw runtime.newIOError(ie.getLocalizedMessage()); } } @JRubyMethod(name = "write_to") public IRubyObject writeTo(ThreadContext context, IRubyObject io) { Ruby runtime = context.runtime; Channel channel = RubyIO.convertToIO(context, io).getChannel(); if(!this.byteBuffer.hasRemaining()) { throw ByteBuffer.newUnderflowError(context, "not enough data in buffer"); } if(!(channel instanceof WritableByteChannel) || !(channel instanceof SelectableChannel)) { throw runtime.newArgumentError("unsupported IO object: " + io.getType().toString()); } try { ((SelectableChannel)channel).configureBlocking(false); } catch(IOException ie) { throw runtime.newIOError(ie.getLocalizedMessage()); } try { int bytesWritten = ((WritableByteChannel)channel).write(this.byteBuffer); if(bytesWritten >= 0) { return runtime.newFixnum(bytesWritten); } else { throw runtime.newEOFError(); } } catch(IOException ie) { throw runtime.newIOError(ie.getLocalizedMessage()); } } @JRubyMethod public IRubyObject flip(ThreadContext context) { this.byteBuffer.flip(); return this; } @JRubyMethod public IRubyObject rewind(ThreadContext context) { this.byteBuffer.rewind(); return this; } @JRubyMethod public IRubyObject mark(ThreadContext context) { this.byteBuffer.mark(); return this; } @JRubyMethod public IRubyObject reset(ThreadContext context) { try { this.byteBuffer.reset(); return this; } catch(InvalidMarkException ie) { throw ByteBuffer.newMarkUnsetError(context, "mark has not been set"); } } @JRubyMethod public IRubyObject compact(ThreadContext context) { this.byteBuffer.compact(); return this; } @JRubyMethod public IRubyObject each(ThreadContext context, Block block) { for(int i = 0; i < this.byteBuffer.limit(); i++) { block.call(context, context.getRuntime().newFixnum(this.byteBuffer.get(i))); } return this; } @JRubyMethod public IRubyObject inspect(ThreadContext context) { return context.runtime.newString(String.format( "#<%s:0x%x @position=%d @limit=%d @capacity=%d>", this.getType().toString(), System.identityHashCode(this), this.byteBuffer.position(), this.byteBuffer.limit(), this.byteBuffer.capacity() )); } }